arize-phoenix 11.4.0__py3-none-any.whl → 11.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/METADATA +2 -2
- {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/RECORD +42 -40
- phoenix/config.py +51 -2
- phoenix/server/api/auth.py +1 -1
- phoenix/server/api/queries.py +52 -44
- phoenix/server/api/routers/v1/annotation_configs.py +4 -1
- phoenix/server/api/routers/v1/datasets.py +3 -1
- phoenix/server/api/routers/v1/evaluations.py +3 -1
- phoenix/server/api/routers/v1/experiment_runs.py +3 -1
- phoenix/server/api/routers/v1/experiments.py +3 -1
- phoenix/server/api/routers/v1/projects.py +4 -1
- phoenix/server/api/routers/v1/prompts.py +4 -1
- phoenix/server/api/routers/v1/spans.py +6 -3
- phoenix/server/api/routers/v1/traces.py +4 -1
- phoenix/server/api/routers/v1/users.py +2 -2
- phoenix/server/api/types/Span.py +0 -99
- phoenix/server/app.py +47 -12
- phoenix/server/authorization.py +9 -0
- phoenix/server/bearer_auth.py +18 -15
- phoenix/server/cost_tracking/cost_model_lookup.py +1 -1
- phoenix/server/cost_tracking/model_cost_manifest.json +107 -107
- phoenix/server/daemons/db_disk_usage_monitor.py +215 -0
- phoenix/server/email/sender.py +25 -0
- phoenix/server/email/templates/db_disk_usage_notification.html +16 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +3 -3
- phoenix/server/prometheus.py +22 -0
- phoenix/server/static/.vite/manifest.json +44 -44
- phoenix/server/static/assets/{components-CVcMbu2U.js → components-mOUBHJ12.js} +297 -298
- phoenix/server/static/assets/{index-Dz7I-Hpn.js → index-CQ_A6K_M.js} +2 -2
- phoenix/server/static/assets/{pages-QK2o2V7x.js → pages-CCsLkNZY.js} +517 -498
- phoenix/server/static/assets/{vendor-pg5m6BWE.js → vendor-DRWIRkSJ.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-BwMsgSAG.js → vendor-arizeai-DUhQaeau.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-BwSDEu2g.js → vendor-codemirror-D_6Q6Auv.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-SW3HwAtG.js → vendor-recharts-BNBwj7vz.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-BsdYoDvs.js → vendor-shiki-k1qj_XjP.js} +1 -1
- phoenix/server/types.py +11 -2
- phoenix/version.py +1 -1
- {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/v1/prompts.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from typing import Any, Optional, Union
|
|
3
3
|
|
|
4
|
-
from fastapi import APIRouter, HTTPException, Path, Query
|
|
4
|
+
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
|
5
5
|
from pydantic import ValidationError, model_validator
|
|
6
6
|
from sqlalchemy import select
|
|
7
7
|
from sqlalchemy.sql import Select
|
|
@@ -33,6 +33,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
|
33
33
|
from phoenix.server.api.types.Prompt import Prompt as PromptNodeType
|
|
34
34
|
from phoenix.server.api.types.PromptVersion import PromptVersion as PromptVersionNodeType
|
|
35
35
|
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag as PromptVersionTagNodeType
|
|
36
|
+
from phoenix.server.authorization import is_not_locked
|
|
36
37
|
from phoenix.server.bearer_auth import PhoenixUser
|
|
37
38
|
|
|
38
39
|
logger = logging.getLogger(__name__)
|
|
@@ -393,6 +394,7 @@ async def get_prompt_version_by_latest(
|
|
|
393
394
|
|
|
394
395
|
@router.post(
|
|
395
396
|
"/prompts",
|
|
397
|
+
dependencies=[Depends(is_not_locked)],
|
|
396
398
|
operation_id="postPromptVersion",
|
|
397
399
|
summary="Create a new prompt",
|
|
398
400
|
description="Create a new prompt and its initial version. A prompt can have multiple versions.",
|
|
@@ -602,6 +604,7 @@ async def list_prompt_version_tags(
|
|
|
602
604
|
|
|
603
605
|
@router.post(
|
|
604
606
|
"/prompt_versions/{prompt_version_id}/tags",
|
|
607
|
+
dependencies=[Depends(is_not_locked)],
|
|
605
608
|
operation_id="createPromptVersionTag",
|
|
606
609
|
summary="Add tag to prompt version",
|
|
607
610
|
description="Add a new tag to a specific prompt version. Tags help identify and categorize "
|
phoenix/server/api/routers/v1/spans.py
CHANGED
|
@@ -8,7 +8,7 @@ from secrets import token_urlsafe
|
|
|
8
8
|
from typing import Annotated, Any, Literal, Optional, Union
|
|
9
9
|
|
|
10
10
|
import pandas as pd
|
|
11
|
-
from fastapi import APIRouter, Header, HTTPException, Path, Query
|
|
11
|
+
from fastapi import APIRouter, Depends, Header, HTTPException, Path, Query
|
|
12
12
|
from pydantic import BaseModel, Field
|
|
13
13
|
from sqlalchemy import select
|
|
14
14
|
from starlette.requests import Request
|
|
@@ -28,6 +28,7 @@ from phoenix.db.helpers import SupportedSQLDialect
|
|
|
28
28
|
from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
|
|
29
29
|
from phoenix.db.insertion.types import Precursors
|
|
30
30
|
from phoenix.server.api.routers.utils import df_to_bytes
|
|
31
|
+
from phoenix.server.authorization import is_not_locked
|
|
31
32
|
from phoenix.server.bearer_auth import PhoenixUser
|
|
32
33
|
from phoenix.server.dml_event import SpanAnnotationInsertEvent
|
|
33
34
|
from phoenix.trace.attributes import flatten
|
|
@@ -601,7 +602,7 @@ async def span_search_otlpv1(
|
|
|
601
602
|
models.Trace.trace_id,
|
|
602
603
|
)
|
|
603
604
|
.join(models.Trace, onclause=models.Trace.id == models.Span.trace_rowid)
|
|
604
|
-
.
|
|
605
|
+
.where(models.Trace.project_rowid == project_id)
|
|
605
606
|
.order_by(*order_by)
|
|
606
607
|
)
|
|
607
608
|
|
|
@@ -736,7 +737,7 @@ async def span_search(
|
|
|
736
737
|
models.Trace.trace_id,
|
|
737
738
|
)
|
|
738
739
|
.join(models.Trace, onclause=models.Trace.id == models.Span.trace_rowid)
|
|
739
|
-
.
|
|
740
|
+
.where(models.Trace.project_rowid == project_id)
|
|
740
741
|
.order_by(*order_by)
|
|
741
742
|
)
|
|
742
743
|
|
|
@@ -907,6 +908,7 @@ class AnnotateSpansResponseBody(ResponseBody[list[InsertedSpanAnnotation]]):
|
|
|
907
908
|
|
|
908
909
|
@router.post(
|
|
909
910
|
"/span_annotations",
|
|
911
|
+
dependencies=[Depends(is_not_locked)],
|
|
910
912
|
operation_id="annotateSpans",
|
|
911
913
|
summary="Create span annotations",
|
|
912
914
|
responses=add_errors_to_responses(
|
|
@@ -990,6 +992,7 @@ class CreateSpansResponseBody(V1RoutesBaseModel):
|
|
|
990
992
|
|
|
991
993
|
@router.post(
|
|
992
994
|
"/projects/{project_identifier}/spans",
|
|
995
|
+
dependencies=[Depends(is_not_locked)],
|
|
993
996
|
operation_id="createSpans",
|
|
994
997
|
summary="Create spans",
|
|
995
998
|
description=(
|
phoenix/server/api/routers/v1/traces.py
CHANGED
|
@@ -2,7 +2,7 @@ import gzip
|
|
|
2
2
|
import zlib
|
|
3
3
|
from typing import Any, Literal, Optional
|
|
4
4
|
|
|
5
|
-
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Query
|
|
5
|
+
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query
|
|
6
6
|
from google.protobuf.message import DecodeError
|
|
7
7
|
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
|
|
8
8
|
ExportTraceServiceRequest,
|
|
@@ -24,6 +24,7 @@ from strawberry.relay import GlobalID
|
|
|
24
24
|
from phoenix.db import models
|
|
25
25
|
from phoenix.db.insertion.helpers import as_kv
|
|
26
26
|
from phoenix.db.insertion.types import Precursors
|
|
27
|
+
from phoenix.server.authorization import is_not_locked
|
|
27
28
|
from phoenix.server.bearer_auth import PhoenixUser
|
|
28
29
|
from phoenix.server.dml_event import TraceAnnotationInsertEvent
|
|
29
30
|
from phoenix.trace.otel import decode_otlp_span
|
|
@@ -37,6 +38,7 @@ router = APIRouter(tags=["traces"])
|
|
|
37
38
|
|
|
38
39
|
@router.post(
|
|
39
40
|
"/traces",
|
|
41
|
+
dependencies=[Depends(is_not_locked)],
|
|
40
42
|
operation_id="addTraces",
|
|
41
43
|
summary="Send traces",
|
|
42
44
|
responses=add_errors_to_responses(
|
|
@@ -160,6 +162,7 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
|
|
|
160
162
|
|
|
161
163
|
@router.post(
|
|
162
164
|
"/trace_annotations",
|
|
165
|
+
dependencies=[Depends(is_not_locked)],
|
|
163
166
|
operation_id="annotateTraces",
|
|
164
167
|
summary="Create trace annotations",
|
|
165
168
|
responses=add_errors_to_responses(
|
phoenix/server/api/routers/v1/users.py
CHANGED
|
@@ -43,7 +43,7 @@ from phoenix.server.api.routers.v1.utils import (
|
|
|
43
43
|
add_errors_to_responses,
|
|
44
44
|
)
|
|
45
45
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
46
|
-
from phoenix.server.authorization import require_admin
|
|
46
|
+
from phoenix.server.authorization import is_not_locked, require_admin
|
|
47
47
|
|
|
48
48
|
logger = logging.getLogger(__name__)
|
|
49
49
|
|
|
@@ -194,7 +194,7 @@ async def list_users(
|
|
|
194
194
|
HTTP_422_UNPROCESSABLE_ENTITY,
|
|
195
195
|
]
|
|
196
196
|
),
|
|
197
|
-
dependencies=[Depends(require_admin)],
|
|
197
|
+
dependencies=[Depends(require_admin), Depends(is_not_locked)],
|
|
198
198
|
response_model_by_alias=True,
|
|
199
199
|
response_model_exclude_unset=True,
|
|
200
200
|
response_model_exclude_defaults=True,
|
phoenix/server/api/types/Span.py
CHANGED
|
@@ -19,7 +19,6 @@ from typing_extensions import Annotated, TypeAlias
|
|
|
19
19
|
import phoenix.trace.schemas as trace_schema
|
|
20
20
|
from phoenix.db import models
|
|
21
21
|
from phoenix.server.api.context import Context
|
|
22
|
-
from phoenix.server.api.dataloaders import types as dataloader_types
|
|
23
22
|
from phoenix.server.api.helpers.dataset_helpers import (
|
|
24
23
|
get_dataset_example_input,
|
|
25
24
|
get_dataset_example_output,
|
|
@@ -829,104 +828,6 @@ class Span(Node):
|
|
|
829
828
|
for entry in entries
|
|
830
829
|
]
|
|
831
830
|
|
|
832
|
-
@strawberry.field
|
|
833
|
-
async def cumulative_cost_summary(self, info: Info[Context, None]) -> Optional[SpanCostSummary]:
|
|
834
|
-
max_depth = 0
|
|
835
|
-
descendant_rowids = await info.context.data_loaders.span_descendants.load(
|
|
836
|
-
(self.span_rowid, max_depth)
|
|
837
|
-
)
|
|
838
|
-
span_costs = await info.context.data_loaders.span_cost_by_span.load_many(
|
|
839
|
-
(self.span_rowid, *descendant_rowids)
|
|
840
|
-
)
|
|
841
|
-
total_cost: Optional[float] = None
|
|
842
|
-
total_tokens: Optional[float] = None
|
|
843
|
-
prompt_cost: Optional[float] = None
|
|
844
|
-
prompt_tokens: Optional[float] = None
|
|
845
|
-
completion_cost: Optional[float] = None
|
|
846
|
-
completion_tokens: Optional[float] = None
|
|
847
|
-
for span_cost in span_costs:
|
|
848
|
-
if span_cost is None:
|
|
849
|
-
continue
|
|
850
|
-
if span_cost.total_cost is not None:
|
|
851
|
-
total_cost = (total_cost or 0) + span_cost.total_cost
|
|
852
|
-
if span_cost.total_tokens is not None:
|
|
853
|
-
total_tokens = (total_tokens or 0) + span_cost.total_tokens
|
|
854
|
-
if span_cost.prompt_cost is not None:
|
|
855
|
-
prompt_cost = (prompt_cost or 0) + span_cost.prompt_cost
|
|
856
|
-
if span_cost.prompt_tokens is not None:
|
|
857
|
-
prompt_tokens = (prompt_tokens or 0) + span_cost.prompt_tokens
|
|
858
|
-
if span_cost.completion_cost is not None:
|
|
859
|
-
completion_cost = (completion_cost or 0) + span_cost.completion_cost
|
|
860
|
-
if span_cost.completion_tokens is not None:
|
|
861
|
-
completion_tokens = (completion_tokens or 0) + span_cost.completion_tokens
|
|
862
|
-
return SpanCostSummary(
|
|
863
|
-
prompt=CostBreakdown(
|
|
864
|
-
tokens=prompt_tokens,
|
|
865
|
-
cost=prompt_cost,
|
|
866
|
-
),
|
|
867
|
-
completion=CostBreakdown(
|
|
868
|
-
tokens=completion_tokens,
|
|
869
|
-
cost=completion_cost,
|
|
870
|
-
),
|
|
871
|
-
total=CostBreakdown(
|
|
872
|
-
tokens=total_tokens,
|
|
873
|
-
cost=total_cost,
|
|
874
|
-
),
|
|
875
|
-
)
|
|
876
|
-
|
|
877
|
-
@strawberry.field
|
|
878
|
-
async def cumulative_cost_detail_summary_entries(
|
|
879
|
-
self, info: Info[Context, None]
|
|
880
|
-
) -> list[SpanCostDetailSummaryEntry]:
|
|
881
|
-
max_depth = 0
|
|
882
|
-
descendant_rowids = await info.context.data_loaders.span_descendants.load(
|
|
883
|
-
(self.span_rowid, max_depth)
|
|
884
|
-
)
|
|
885
|
-
entry_lists = (
|
|
886
|
-
await info.context.data_loaders.span_cost_detail_summary_entries_by_span.load_many(
|
|
887
|
-
(self.span_rowid, *descendant_rowids)
|
|
888
|
-
)
|
|
889
|
-
)
|
|
890
|
-
|
|
891
|
-
TokenType: TypeAlias = str
|
|
892
|
-
IsPrompt: TypeAlias = bool
|
|
893
|
-
grouped_entries: dict[
|
|
894
|
-
IsPrompt, dict[TokenType, list[dataloader_types.SpanCostDetailSummaryEntry]]
|
|
895
|
-
] = {}
|
|
896
|
-
|
|
897
|
-
for entries in entry_lists:
|
|
898
|
-
for entry in entries:
|
|
899
|
-
is_prompt = entry.is_prompt
|
|
900
|
-
token_type = entry.token_type
|
|
901
|
-
if is_prompt not in grouped_entries:
|
|
902
|
-
grouped_entries[is_prompt] = {}
|
|
903
|
-
if token_type not in grouped_entries[is_prompt]:
|
|
904
|
-
grouped_entries[is_prompt][token_type] = []
|
|
905
|
-
grouped_entries[is_prompt][token_type].append(entry)
|
|
906
|
-
|
|
907
|
-
result: list[SpanCostDetailSummaryEntry] = []
|
|
908
|
-
for is_prompt in (True, False):
|
|
909
|
-
entries_by_token_type = grouped_entries[is_prompt]
|
|
910
|
-
for token_type, entries in sorted(entries_by_token_type.items()):
|
|
911
|
-
cost: Optional[float] = None
|
|
912
|
-
tokens: Optional[float] = None
|
|
913
|
-
for entry in entries:
|
|
914
|
-
if entry.value.cost is not None:
|
|
915
|
-
cost = (cost or 0) + entry.value.cost
|
|
916
|
-
if entry.value.tokens is not None:
|
|
917
|
-
tokens = (tokens or 0) + entry.value.tokens
|
|
918
|
-
result.append(
|
|
919
|
-
SpanCostDetailSummaryEntry(
|
|
920
|
-
token_type=token_type,
|
|
921
|
-
is_prompt=is_prompt,
|
|
922
|
-
value=CostBreakdown(
|
|
923
|
-
tokens=tokens,
|
|
924
|
-
cost=cost,
|
|
925
|
-
),
|
|
926
|
-
)
|
|
927
|
-
)
|
|
928
|
-
return result
|
|
929
|
-
|
|
930
831
|
|
|
931
832
|
def _hide_embedding_vectors(attributes: Mapping[str, Any]) -> Mapping[str, Any]:
|
|
932
833
|
if not (
|
phoenix/server/app.py
CHANGED
|
@@ -16,17 +16,20 @@ from typing import (
|
|
|
16
16
|
Any,
|
|
17
17
|
NamedTuple,
|
|
18
18
|
Optional,
|
|
19
|
+
Protocol,
|
|
19
20
|
TypedDict,
|
|
20
21
|
Union,
|
|
21
22
|
cast,
|
|
22
23
|
)
|
|
23
24
|
from urllib.parse import urlparse
|
|
24
25
|
|
|
26
|
+
import grpc
|
|
25
27
|
import strawberry
|
|
26
28
|
from fastapi import APIRouter, Depends, FastAPI
|
|
27
29
|
from fastapi.middleware.cors import CORSMiddleware
|
|
28
30
|
from fastapi.utils import is_body_allowed_for_status_code
|
|
29
31
|
from grpc.aio import ServerInterceptor
|
|
32
|
+
from grpc_interceptor import AsyncServerInterceptor
|
|
30
33
|
from sqlalchemy import select
|
|
31
34
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
32
35
|
from starlette.datastructures import URL, Secret
|
|
@@ -44,7 +47,7 @@ from starlette.types import Scope, StatefulLifespan
|
|
|
44
47
|
from strawberry.extensions import SchemaExtension
|
|
45
48
|
from strawberry.fastapi import GraphQLRouter
|
|
46
49
|
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
|
|
47
|
-
from typing_extensions import TypeAlias
|
|
50
|
+
from typing_extensions import TypeAlias, override
|
|
48
51
|
|
|
49
52
|
import phoenix.trace.v1 as pb
|
|
50
53
|
from phoenix.config import (
|
|
@@ -134,6 +137,7 @@ from phoenix.server.api.routers import (
|
|
|
134
137
|
from phoenix.server.api.routers.v1 import REST_API_VERSION
|
|
135
138
|
from phoenix.server.api.schema import build_graphql_schema
|
|
136
139
|
from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
|
|
140
|
+
from phoenix.server.daemons.db_disk_usage_monitor import DbDiskUsageMonitor
|
|
137
141
|
from phoenix.server.daemons.generative_model_store import GenerativeModelStore
|
|
138
142
|
from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
|
|
139
143
|
from phoenix.server.dml_event import DmlEvent
|
|
@@ -382,19 +386,16 @@ async def version() -> PlainTextResponse:
|
|
|
382
386
|
return PlainTextResponse(f"{phoenix_version}")
|
|
383
387
|
|
|
384
388
|
|
|
385
|
-
DB_MUTEX: Optional[asyncio.Lock] = None
|
|
386
|
-
|
|
387
|
-
|
|
388
389
|
def _db(
|
|
389
|
-
engine: AsyncEngine,
|
|
390
|
-
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
|
|
390
|
+
engine: AsyncEngine,
|
|
391
|
+
) -> Callable[[Optional[asyncio.Lock]], AbstractAsyncContextManager[AsyncSession]]:
|
|
391
392
|
Session = async_sessionmaker(engine, expire_on_commit=False)
|
|
392
393
|
|
|
393
394
|
@contextlib.asynccontextmanager
|
|
394
|
-
async def factory() -> AsyncIterator[AsyncSession]:
|
|
395
|
+
async def factory(lock: Optional[asyncio.Lock] = None) -> AsyncIterator[AsyncSession]:
|
|
395
396
|
async with contextlib.AsyncExitStack() as stack:
|
|
396
|
-
if
|
|
397
|
-
await stack.enter_async_context(DB_MUTEX)
|
|
397
|
+
if lock:
|
|
398
|
+
await stack.enter_async_context(lock)
|
|
398
399
|
yield await stack.enter_async_context(Session.begin())
|
|
399
400
|
|
|
400
401
|
return factory
|
|
@@ -523,6 +524,7 @@ def _lifespan(
|
|
|
523
524
|
trace_data_sweeper: Optional[TraceDataSweeper],
|
|
524
525
|
span_cost_calculator: SpanCostCalculator,
|
|
525
526
|
generative_model_store: GenerativeModelStore,
|
|
527
|
+
db_disk_usage_monitor: DbDiskUsageMonitor,
|
|
526
528
|
token_store: Optional[TokenStore] = None,
|
|
527
529
|
tracer_provider: Optional["TracerProvider"] = None,
|
|
528
530
|
enable_prometheus: bool = False,
|
|
@@ -530,14 +532,14 @@ def _lifespan(
|
|
|
530
532
|
shutdown_callbacks: Iterable[_Callback] = (),
|
|
531
533
|
read_only: bool = False,
|
|
532
534
|
scaffolder_config: Optional[ScaffolderConfig] = None,
|
|
535
|
+
grpc_interceptors: Iterable[AsyncServerInterceptor] = (),
|
|
533
536
|
) -> StatefulLifespan[FastAPI]:
|
|
534
537
|
@contextlib.asynccontextmanager
|
|
535
538
|
async def lifespan(_: FastAPI) -> AsyncIterator[dict[str, Any]]:
|
|
536
539
|
for callback in startup_callbacks:
|
|
537
540
|
if isinstance((res := callback()), Awaitable):
|
|
538
541
|
await res
|
|
539
|
-
|
|
540
|
-
DB_MUTEX = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None
|
|
542
|
+
db.lock = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None
|
|
541
543
|
async with AsyncExitStack() as stack:
|
|
542
544
|
(
|
|
543
545
|
enqueue,
|
|
@@ -551,7 +553,7 @@ def _lifespan(
|
|
|
551
553
|
tracer_provider=tracer_provider,
|
|
552
554
|
enable_prometheus=enable_prometheus,
|
|
553
555
|
token_store=token_store,
|
|
554
|
-
interceptors=user_grpc_interceptors(),
|
|
556
|
+
interceptors=user_grpc_interceptors() + list(grpc_interceptors),
|
|
555
557
|
)
|
|
556
558
|
await stack.enter_async_context(grpc_server)
|
|
557
559
|
await stack.enter_async_context(dml_event_handler)
|
|
@@ -559,6 +561,7 @@ def _lifespan(
|
|
|
559
561
|
await stack.enter_async_context(trace_data_sweeper)
|
|
560
562
|
await stack.enter_async_context(span_cost_calculator)
|
|
561
563
|
await stack.enter_async_context(generative_model_store)
|
|
564
|
+
await stack.enter_async_context(db_disk_usage_monitor)
|
|
562
565
|
if scaffolder_config:
|
|
563
566
|
scaffolder = Scaffolder(
|
|
564
567
|
config=scaffolder_config,
|
|
@@ -826,6 +829,34 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
|
|
|
826
829
|
return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)
|
|
827
830
|
|
|
828
831
|
|
|
832
|
+
class _HasDbStatus(Protocol):
|
|
833
|
+
@property
|
|
834
|
+
def should_not_insert_or_update(self) -> bool: ...
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
class DbDiskUsageInterceptor(AsyncServerInterceptor):
|
|
838
|
+
def __init__(self, db: _HasDbStatus) -> None:
|
|
839
|
+
self._db = db
|
|
840
|
+
|
|
841
|
+
@override
|
|
842
|
+
async def intercept(
|
|
843
|
+
self,
|
|
844
|
+
method: Callable[[Any, grpc.aio.ServicerContext], Awaitable[Any]],
|
|
845
|
+
request_or_iterator: Any,
|
|
846
|
+
context: grpc.aio.ServicerContext,
|
|
847
|
+
method_name: str,
|
|
848
|
+
) -> Any:
|
|
849
|
+
if (
|
|
850
|
+
method_name.endswith("trace.v1.TraceService/Export")
|
|
851
|
+
and self._db.should_not_insert_or_update
|
|
852
|
+
):
|
|
853
|
+
await context.abort(
|
|
854
|
+
grpc.StatusCode.RESOURCE_EXHAUSTED,
|
|
855
|
+
"Database disk usage threshold exceeded",
|
|
856
|
+
)
|
|
857
|
+
return await method(request_or_iterator, context)
|
|
858
|
+
|
|
859
|
+
|
|
829
860
|
def create_app(
|
|
830
861
|
db: DbSessionFactory,
|
|
831
862
|
export_path: Path,
|
|
@@ -971,6 +1002,8 @@ def create_app(
|
|
|
971
1002
|
from phoenix.server.prometheus import PrometheusMiddleware
|
|
972
1003
|
|
|
973
1004
|
middlewares.append(Middleware(PrometheusMiddleware))
|
|
1005
|
+
grpc_interceptors: list[AsyncServerInterceptor] = []
|
|
1006
|
+
grpc_interceptors.append(DbDiskUsageInterceptor(db))
|
|
974
1007
|
app = FastAPI(
|
|
975
1008
|
title="Arize-Phoenix REST API",
|
|
976
1009
|
version=REST_API_VERSION,
|
|
@@ -982,6 +1015,8 @@ def create_app(
|
|
|
982
1015
|
trace_data_sweeper=trace_data_sweeper,
|
|
983
1016
|
span_cost_calculator=span_cost_calculator,
|
|
984
1017
|
generative_model_store=generative_model_store,
|
|
1018
|
+
db_disk_usage_monitor=DbDiskUsageMonitor(db, email_sender),
|
|
1019
|
+
grpc_interceptors=grpc_interceptors,
|
|
985
1020
|
token_store=token_store,
|
|
986
1021
|
tracer_provider=tracer_provider,
|
|
987
1022
|
enable_prometheus=enable_prometheus,
|
phoenix/server/authorization.py
CHANGED
|
@@ -51,3 +51,12 @@ def require_admin(request: Request) -> None:
|
|
|
51
51
|
status_code=fastapi_status.HTTP_403_FORBIDDEN,
|
|
52
52
|
detail="Only admin or system users can perform this action.",
|
|
53
53
|
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def is_not_locked(request: Request) -> None:
|
|
57
|
+
if request.app.state.db.should_not_insert_or_update:
|
|
58
|
+
raise HTTPException(
|
|
59
|
+
status_code=fastapi_status.HTTP_507_INSUFFICIENT_STORAGE,
|
|
60
|
+
detail="Operations that insert or update database "
|
|
61
|
+
"records are currently not allowed.",
|
|
62
|
+
)
|
phoenix/server/bearer_auth.py
CHANGED
|
@@ -7,10 +7,10 @@ from typing import Any, Optional, cast
|
|
|
7
7
|
import grpc
|
|
8
8
|
from fastapi import HTTPException, Request, WebSocket, WebSocketException
|
|
9
9
|
from grpc_interceptor import AsyncServerInterceptor
|
|
10
|
-
from grpc_interceptor.exceptions import Unauthenticated
|
|
11
10
|
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
|
|
12
11
|
from starlette.requests import HTTPConnection
|
|
13
12
|
from starlette.status import HTTP_401_UNAUTHORIZED
|
|
13
|
+
from typing_extensions import override
|
|
14
14
|
|
|
15
15
|
from phoenix import config
|
|
16
16
|
from phoenix.auth import (
|
|
@@ -100,16 +100,19 @@ class PhoenixSystemUser(PhoenixUser):
|
|
|
100
100
|
|
|
101
101
|
|
|
102
102
|
class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
|
|
103
|
+
@override
|
|
103
104
|
async def intercept(
|
|
104
105
|
self,
|
|
105
|
-
method: Callable[[Any, grpc.ServicerContext], Awaitable[Any]],
|
|
106
|
+
method: Callable[[Any, grpc.aio.ServicerContext], Awaitable[Any]],
|
|
106
107
|
request_or_iterator: Any,
|
|
107
|
-
context: grpc.ServicerContext,
|
|
108
|
+
context: grpc.aio.ServicerContext,
|
|
108
109
|
method_name: str,
|
|
109
110
|
) -> Any:
|
|
110
|
-
for
|
|
111
|
-
if
|
|
112
|
-
|
|
111
|
+
for key, value in context.invocation_metadata() or ():
|
|
112
|
+
if key.lower() == "authorization":
|
|
113
|
+
if isinstance(value, bytes):
|
|
114
|
+
value = value.decode("utf-8")
|
|
115
|
+
scheme, _, token = value.partition(" ")
|
|
113
116
|
if scheme.lower() != "bearer" or not token:
|
|
114
117
|
break
|
|
115
118
|
if (
|
|
@@ -119,16 +122,16 @@ class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
|
|
|
119
122
|
):
|
|
120
123
|
return await method(request_or_iterator, context)
|
|
121
124
|
claims = await self._token_store.read(Token(token))
|
|
122
|
-
if
|
|
125
|
+
if (
|
|
126
|
+
not (
|
|
127
|
+
isinstance(claims, (ApiKeyClaims, AccessTokenClaims))
|
|
128
|
+
and isinstance(claims.subject, UserId)
|
|
129
|
+
)
|
|
130
|
+
or claims.status is not ClaimSetStatus.VALID
|
|
131
|
+
):
|
|
123
132
|
break
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
if claims.status is ClaimSetStatus.EXPIRED:
|
|
127
|
-
raise Unauthenticated(details="Expired token")
|
|
128
|
-
if claims.status is ClaimSetStatus.VALID:
|
|
129
|
-
return await method(request_or_iterator, context)
|
|
130
|
-
raise Unauthenticated()
|
|
131
|
-
raise Unauthenticated()
|
|
133
|
+
return await method(request_or_iterator, context)
|
|
134
|
+
await context.abort(grpc.StatusCode.UNAUTHENTICATED)
|
|
132
135
|
|
|
133
136
|
|
|
134
137
|
async def is_authenticated(
|
phoenix/server/cost_tracking/cost_model_lookup.py
CHANGED
|
@@ -109,7 +109,7 @@ class CostModelLookup:
|
|
|
109
109
|
model
|
|
110
110
|
for model in self._models
|
|
111
111
|
if (not model.start_time or model.start_time <= start_time)
|
|
112
|
-
and model.name_pattern.
|
|
112
|
+
and model.name_pattern.search(model_name)
|
|
113
113
|
]
|
|
114
114
|
if not candidates:
|
|
115
115
|
return None
|