arize-phoenix 5.5.2__py3-none-any.whl → 5.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/METADATA +3 -6
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/RECORD +171 -171
- phoenix/config.py +8 -8
- phoenix/core/model.py +3 -3
- phoenix/core/model_schema.py +41 -50
- phoenix/core/model_schema_adapter.py +17 -16
- phoenix/datetime_utils.py +2 -2
- phoenix/db/bulk_inserter.py +10 -20
- phoenix/db/engines.py +2 -1
- phoenix/db/enums.py +2 -2
- phoenix/db/helpers.py +8 -7
- phoenix/db/insertion/dataset.py +9 -19
- phoenix/db/insertion/document_annotation.py +14 -13
- phoenix/db/insertion/helpers.py +6 -16
- phoenix/db/insertion/span_annotation.py +14 -13
- phoenix/db/insertion/trace_annotation.py +14 -13
- phoenix/db/insertion/types.py +19 -30
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
- phoenix/db/models.py +28 -28
- phoenix/experiments/evaluators/base.py +2 -1
- phoenix/experiments/evaluators/code_evaluators.py +4 -5
- phoenix/experiments/evaluators/llm_evaluators.py +157 -4
- phoenix/experiments/evaluators/utils.py +3 -2
- phoenix/experiments/functions.py +10 -21
- phoenix/experiments/tracing.py +2 -1
- phoenix/experiments/types.py +20 -29
- phoenix/experiments/utils.py +2 -1
- phoenix/inferences/errors.py +6 -5
- phoenix/inferences/fixtures.py +6 -5
- phoenix/inferences/inferences.py +37 -37
- phoenix/inferences/schema.py +11 -10
- phoenix/inferences/validation.py +13 -14
- phoenix/logging/_formatter.py +3 -3
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +2 -1
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +2 -2
- phoenix/pointcloud/clustering.py +3 -4
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/umap_parameters.py +2 -1
- phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
- phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
- phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
- phoenix/server/api/dataloaders/document_evaluations.py +3 -7
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
- phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
- phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
- phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
- phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
- phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
- phoenix/server/api/dataloaders/project_by_name.py +3 -3
- phoenix/server/api/dataloaders/record_counts.py +11 -18
- phoenix/server/api/dataloaders/span_annotations.py +3 -7
- phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
- phoenix/server/api/dataloaders/span_descendants.py +3 -7
- phoenix/server/api/dataloaders/span_projects.py +2 -2
- phoenix/server/api/dataloaders/token_counts.py +12 -19
- phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
- phoenix/server/api/dataloaders/user_roles.py +3 -3
- phoenix/server/api/dataloaders/users.py +3 -3
- phoenix/server/api/helpers/__init__.py +4 -3
- phoenix/server/api/helpers/dataset_helpers.py +10 -9
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
- phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +2 -2
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
- phoenix/server/api/mutations/dataset_mutations.py +4 -4
- phoenix/server/api/mutations/experiment_mutations.py +1 -2
- phoenix/server/api/mutations/export_events_mutations.py +7 -7
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +20 -20
- phoenix/server/api/routers/oauth2.py +4 -4
- phoenix/server/api/routers/v1/datasets.py +22 -36
- phoenix/server/api/routers/v1/evaluations.py +6 -5
- phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
- phoenix/server/api/routers/v1/experiment_runs.py +2 -2
- phoenix/server/api/routers/v1/experiments.py +4 -4
- phoenix/server/api/routers/v1/spans.py +13 -12
- phoenix/server/api/routers/v1/traces.py +5 -5
- phoenix/server/api/routers/v1/utils.py +5 -5
- phoenix/server/api/subscriptions.py +284 -162
- phoenix/server/api/types/AnnotationSummary.py +3 -3
- phoenix/server/api/types/Cluster.py +8 -7
- phoenix/server/api/types/Dataset.py +5 -4
- phoenix/server/api/types/Dimension.py +3 -3
- phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
- phoenix/server/api/types/EmbeddingDimension.py +6 -5
- phoenix/server/api/types/EvaluationSummary.py +3 -3
- phoenix/server/api/types/Event.py +7 -7
- phoenix/server/api/types/Experiment.py +3 -3
- phoenix/server/api/types/ExperimentComparison.py +2 -4
- phoenix/server/api/types/Inferences.py +9 -8
- phoenix/server/api/types/InferencesRole.py +2 -2
- phoenix/server/api/types/Model.py +2 -2
- phoenix/server/api/types/Project.py +11 -18
- phoenix/server/api/types/Segments.py +3 -3
- phoenix/server/api/types/Span.py +8 -7
- phoenix/server/api/types/TimeSeries.py +8 -7
- phoenix/server/api/types/Trace.py +2 -2
- phoenix/server/api/types/UMAPPoints.py +6 -6
- phoenix/server/api/types/User.py +3 -3
- phoenix/server/api/types/node.py +1 -3
- phoenix/server/api/types/pagination.py +4 -4
- phoenix/server/api/utils.py +2 -4
- phoenix/server/app.py +16 -25
- phoenix/server/bearer_auth.py +4 -10
- phoenix/server/dml_event.py +3 -3
- phoenix/server/dml_event_handler.py +10 -24
- phoenix/server/grpc_server.py +3 -2
- phoenix/server/jwt_store.py +22 -21
- phoenix/server/main.py +3 -3
- phoenix/server/oauth2.py +3 -2
- phoenix/server/rate_limiters.py +5 -8
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/components-C70HJiXz.js +1612 -0
- phoenix/server/static/assets/{index-DCzakdJq.js → index-DLe1Oo3l.js} +2 -2
- phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-C8-Sl7JI.js} +269 -434
- phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
- phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
- phoenix/server/thread_server.py +1 -1
- phoenix/server/types.py +17 -29
- phoenix/services.py +4 -3
- phoenix/session/client.py +12 -24
- phoenix/session/data_extractor.py +3 -3
- phoenix/session/evaluation.py +1 -2
- phoenix/session/session.py +11 -20
- phoenix/trace/attributes.py +16 -28
- phoenix/trace/dsl/filter.py +17 -21
- phoenix/trace/dsl/helpers.py +3 -3
- phoenix/trace/dsl/query.py +13 -22
- phoenix/trace/fixtures.py +11 -17
- phoenix/trace/otel.py +5 -15
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +2 -2
- phoenix/trace/span_evaluations.py +9 -8
- phoenix/trace/span_json_decoder.py +3 -3
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +6 -5
- phoenix/trace/utils.py +6 -6
- phoenix/utilities/deprecation.py +3 -2
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +2 -1
- phoenix/utilities/logging.py +2 -2
- phoenix/utilities/project.py +1 -1
- phoenix/utilities/re.py +3 -4
- phoenix/utilities/template_formatters.py +5 -4
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/User.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from sqlalchemy import select
|
|
@@ -35,7 +35,7 @@ class User(Node):
|
|
|
35
35
|
return to_gql_user_role(role)
|
|
36
36
|
|
|
37
37
|
@strawberry.field
|
|
38
|
-
async def api_keys(self, info: Info[Context, None]) ->
|
|
38
|
+
async def api_keys(self, info: Info[Context, None]) -> list[UserApiKey]:
|
|
39
39
|
async with info.context.db() as session:
|
|
40
40
|
api_keys = await session.scalars(
|
|
41
41
|
select(models.ApiKey).where(models.ApiKey.user_id == self.id_attr)
|
|
@@ -43,7 +43,7 @@ class User(Node):
|
|
|
43
43
|
return [to_gql_api_key(api_key) for api_key in api_keys]
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
def to_gql_user(user: models.User, api_keys: Optional[
|
|
46
|
+
def to_gql_user(user: models.User, api_keys: Optional[list[models.ApiKey]] = None) -> User:
|
|
47
47
|
"""
|
|
48
48
|
Converts an ORM user to a GraphQL user.
|
|
49
49
|
"""
|
phoenix/server/api/types/node.py
CHANGED
|
@@ -2,7 +2,7 @@ import base64
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from datetime import datetime
|
|
4
4
|
from enum import Enum, auto
|
|
5
|
-
from typing import Any, ClassVar,
|
|
5
|
+
from typing import Any, ClassVar, Optional, Union
|
|
6
6
|
|
|
7
7
|
from strawberry import UNSET
|
|
8
8
|
from strawberry.relay.types import Connection, Edge, NodeType, PageInfo
|
|
@@ -176,7 +176,7 @@ class ConnectionArgs:
|
|
|
176
176
|
|
|
177
177
|
|
|
178
178
|
def connection_from_list(
|
|
179
|
-
data:
|
|
179
|
+
data: list[NodeType],
|
|
180
180
|
args: ConnectionArgs,
|
|
181
181
|
) -> Connection[NodeType]:
|
|
182
182
|
"""
|
|
@@ -188,7 +188,7 @@ def connection_from_list(
|
|
|
188
188
|
|
|
189
189
|
|
|
190
190
|
def connection_from_list_slice(
|
|
191
|
-
list_slice:
|
|
191
|
+
list_slice: list[NodeType],
|
|
192
192
|
args: ConnectionArgs,
|
|
193
193
|
slice_start: int,
|
|
194
194
|
list_length: int,
|
|
@@ -254,7 +254,7 @@ def connection_from_list_slice(
|
|
|
254
254
|
|
|
255
255
|
|
|
256
256
|
def connection_from_cursors_and_nodes(
|
|
257
|
-
cursors_and_nodes:
|
|
257
|
+
cursors_and_nodes: list[tuple[Any, NodeType]],
|
|
258
258
|
has_previous_page: bool,
|
|
259
259
|
has_next_page: bool,
|
|
260
260
|
) -> Connection[NodeType]:
|
phoenix/server/api/utils.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import List
|
|
2
|
-
|
|
3
1
|
from sqlalchemy import delete
|
|
4
2
|
|
|
5
3
|
from phoenix.db import models
|
|
@@ -9,7 +7,7 @@ from phoenix.server.types import DbSessionFactory
|
|
|
9
7
|
async def delete_projects(
|
|
10
8
|
db: DbSessionFactory,
|
|
11
9
|
*project_names: str,
|
|
12
|
-
) ->
|
|
10
|
+
) -> list[int]:
|
|
13
11
|
if not project_names:
|
|
14
12
|
return []
|
|
15
13
|
stmt = (
|
|
@@ -24,7 +22,7 @@ async def delete_projects(
|
|
|
24
22
|
async def delete_traces(
|
|
25
23
|
db: DbSessionFactory,
|
|
26
24
|
*trace_ids: str,
|
|
27
|
-
) ->
|
|
25
|
+
) -> list[int]:
|
|
28
26
|
if not trace_ids:
|
|
29
27
|
return []
|
|
30
28
|
stmt = (
|
phoenix/server/app.py
CHANGED
|
@@ -2,7 +2,8 @@ import asyncio
|
|
|
2
2
|
import contextlib
|
|
3
3
|
import json
|
|
4
4
|
import logging
|
|
5
|
-
from
|
|
5
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
|
|
6
|
+
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
|
6
7
|
from dataclasses import dataclass, field
|
|
7
8
|
from datetime import datetime, timedelta, timezone
|
|
8
9
|
from functools import cached_property
|
|
@@ -11,18 +12,8 @@ from types import MethodType
|
|
|
11
12
|
from typing import (
|
|
12
13
|
TYPE_CHECKING,
|
|
13
14
|
Any,
|
|
14
|
-
AsyncContextManager,
|
|
15
|
-
AsyncIterator,
|
|
16
|
-
Awaitable,
|
|
17
|
-
Callable,
|
|
18
|
-
Dict,
|
|
19
|
-
Iterable,
|
|
20
|
-
List,
|
|
21
15
|
NamedTuple,
|
|
22
16
|
Optional,
|
|
23
|
-
Sequence,
|
|
24
|
-
Tuple,
|
|
25
|
-
Type,
|
|
26
17
|
TypedDict,
|
|
27
18
|
Union,
|
|
28
19
|
cast,
|
|
@@ -188,10 +179,10 @@ class Static(StaticFiles):
|
|
|
188
179
|
super().__init__(**kwargs)
|
|
189
180
|
|
|
190
181
|
@cached_property
|
|
191
|
-
def _web_manifest(self) ->
|
|
182
|
+
def _web_manifest(self) -> dict[str, Any]:
|
|
192
183
|
try:
|
|
193
184
|
with open(self._app_config.web_manifest_path, "r") as f:
|
|
194
|
-
return cast(
|
|
185
|
+
return cast(dict[str, Any], json.load(f))
|
|
195
186
|
except FileNotFoundError as e:
|
|
196
187
|
if self._app_config.is_development:
|
|
197
188
|
return {}
|
|
@@ -233,7 +224,7 @@ class Static(StaticFiles):
|
|
|
233
224
|
|
|
234
225
|
|
|
235
226
|
class RequestOriginHostnameValidator(BaseHTTPMiddleware):
|
|
236
|
-
def __init__(self, trusted_hostnames:
|
|
227
|
+
def __init__(self, trusted_hostnames: list[str], *args: Any, **kwargs: Any) -> None:
|
|
237
228
|
super().__init__(*args, **kwargs)
|
|
238
229
|
self._trusted_hostnames = trusted_hostnames
|
|
239
230
|
|
|
@@ -278,7 +269,7 @@ DB_MUTEX: Optional[asyncio.Lock] = None
|
|
|
278
269
|
|
|
279
270
|
def _db(
|
|
280
271
|
engine: AsyncEngine, bypass_lock: bool = False
|
|
281
|
-
) -> Callable[[],
|
|
272
|
+
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
|
|
282
273
|
Session = async_sessionmaker(engine, expire_on_commit=False)
|
|
283
274
|
|
|
284
275
|
@contextlib.asynccontextmanager
|
|
@@ -420,7 +411,7 @@ def _lifespan(
|
|
|
420
411
|
scaffolder_config: Optional[ScaffolderConfig] = None,
|
|
421
412
|
) -> StatefulLifespan[FastAPI]:
|
|
422
413
|
@contextlib.asynccontextmanager
|
|
423
|
-
async def lifespan(_: FastAPI) -> AsyncIterator[
|
|
414
|
+
async def lifespan(_: FastAPI) -> AsyncIterator[dict[str, Any]]:
|
|
424
415
|
for callback in startup_callbacks:
|
|
425
416
|
if isinstance((res := callback()), Awaitable):
|
|
426
417
|
await res
|
|
@@ -449,7 +440,7 @@ def _lifespan(
|
|
|
449
440
|
queue_evaluation=queue_evaluation,
|
|
450
441
|
)
|
|
451
442
|
await stack.enter_async_context(scaffolder)
|
|
452
|
-
if isinstance(token_store,
|
|
443
|
+
if isinstance(token_store, AbstractAsyncContextManager):
|
|
453
444
|
await stack.enter_async_context(token_store)
|
|
454
445
|
yield {
|
|
455
446
|
"event_queue": dml_event_handler,
|
|
@@ -607,7 +598,7 @@ def create_engine_and_run_migrations(
|
|
|
607
598
|
raise PhoenixMigrationError(msg) from e
|
|
608
599
|
|
|
609
600
|
|
|
610
|
-
def instrument_engine_if_enabled(engine: AsyncEngine) ->
|
|
601
|
+
def instrument_engine_if_enabled(engine: AsyncEngine) -> list[Callable[[], None]]:
|
|
611
602
|
instrumentation_cleanups = []
|
|
612
603
|
if server_instrumentation_is_enabled():
|
|
613
604
|
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
|
@@ -667,7 +658,7 @@ def create_app(
|
|
|
667
658
|
dev: bool = False,
|
|
668
659
|
read_only: bool = False,
|
|
669
660
|
enable_prometheus: bool = False,
|
|
670
|
-
initial_spans: Optional[Iterable[Union[Span,
|
|
661
|
+
initial_spans: Optional[Iterable[Union[Span, tuple[Span, str]]]] = None,
|
|
671
662
|
initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
|
|
672
663
|
serve_ui: bool = True,
|
|
673
664
|
startup_callbacks: Iterable[_Callback] = (),
|
|
@@ -678,7 +669,7 @@ def create_app(
|
|
|
678
669
|
refresh_token_expiry: Optional[timedelta] = None,
|
|
679
670
|
scaffolder_config: Optional[ScaffolderConfig] = None,
|
|
680
671
|
email_sender: Optional[EmailSender] = None,
|
|
681
|
-
oauth2_client_configs: Optional[
|
|
672
|
+
oauth2_client_configs: Optional[list[OAuth2ClientConfig]] = None,
|
|
682
673
|
bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
|
|
683
674
|
) -> FastAPI:
|
|
684
675
|
if model.embedding_dimensions:
|
|
@@ -692,10 +683,10 @@ def create_app(
|
|
|
692
683
|
) from exc
|
|
693
684
|
logger.info(f"Server umap params: {umap_params}")
|
|
694
685
|
bulk_inserter_factory = bulk_inserter_factory or BulkInserter
|
|
695
|
-
startup_callbacks_list:
|
|
696
|
-
shutdown_callbacks_list:
|
|
686
|
+
startup_callbacks_list: list[_Callback] = list(startup_callbacks)
|
|
687
|
+
shutdown_callbacks_list: list[_Callback] = list(shutdown_callbacks)
|
|
697
688
|
startup_callbacks_list.append(Facilitator(db=db))
|
|
698
|
-
initial_batch_of_spans: Iterable[
|
|
689
|
+
initial_batch_of_spans: Iterable[tuple[Span, str]] = (
|
|
699
690
|
()
|
|
700
691
|
if initial_spans is None
|
|
701
692
|
else (
|
|
@@ -708,7 +699,7 @@ def create_app(
|
|
|
708
699
|
CacheForDataLoaders() if db.dialect is SupportedSQLDialect.SQLITE else None
|
|
709
700
|
)
|
|
710
701
|
last_updated_at = LastUpdatedAt()
|
|
711
|
-
middlewares:
|
|
702
|
+
middlewares: list[Middleware] = [Middleware(HeadersMiddleware)]
|
|
712
703
|
if origins := get_env_csrf_trusted_origins():
|
|
713
704
|
trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
|
|
714
705
|
middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
|
|
@@ -742,7 +733,7 @@ def create_app(
|
|
|
742
733
|
initial_batch_of_evaluations=initial_batch_of_evaluations,
|
|
743
734
|
)
|
|
744
735
|
tracer_provider = None
|
|
745
|
-
strawberry_extensions:
|
|
736
|
+
strawberry_extensions: list[Union[type[SchemaExtension], SchemaExtension]] = []
|
|
746
737
|
strawberry_extensions.extend(schema.get_extensions())
|
|
747
738
|
if server_instrumentation_is_enabled():
|
|
748
739
|
tracer_provider = initialize_opentelemetry_tracer_provider()
|
phoenix/server/bearer_auth.py
CHANGED
|
@@ -1,14 +1,8 @@
|
|
|
1
1
|
from abc import ABC
|
|
2
|
+
from collections.abc import Awaitable, Callable
|
|
2
3
|
from datetime import datetime, timedelta, timezone
|
|
3
4
|
from functools import cached_property
|
|
4
|
-
from typing import
|
|
5
|
-
Any,
|
|
6
|
-
Awaitable,
|
|
7
|
-
Callable,
|
|
8
|
-
Optional,
|
|
9
|
-
Tuple,
|
|
10
|
-
cast,
|
|
11
|
-
)
|
|
5
|
+
from typing import Any, Optional, cast
|
|
12
6
|
|
|
13
7
|
import grpc
|
|
14
8
|
from fastapi import HTTPException, Request, WebSocket, WebSocketException
|
|
@@ -51,7 +45,7 @@ class BearerTokenAuthBackend(HasTokenStore, AuthenticationBackend):
|
|
|
51
45
|
async def authenticate(
|
|
52
46
|
self,
|
|
53
47
|
conn: HTTPConnection,
|
|
54
|
-
) -> Optional[
|
|
48
|
+
) -> Optional[tuple[AuthCredentials, BaseUser]]:
|
|
55
49
|
if header := conn.headers.get("Authorization"):
|
|
56
50
|
scheme, _, token = header.partition(" ")
|
|
57
51
|
if scheme.lower() != "bearer" or not token:
|
|
@@ -143,7 +137,7 @@ async def create_access_and_refresh_tokens(
|
|
|
143
137
|
user: OrmUser,
|
|
144
138
|
access_token_expiry: timedelta,
|
|
145
139
|
refresh_token_expiry: timedelta,
|
|
146
|
-
) ->
|
|
140
|
+
) -> tuple[AccessToken, RefreshToken]:
|
|
147
141
|
issued_at = datetime.now(timezone.utc)
|
|
148
142
|
user_id = UserId(user.id)
|
|
149
143
|
user_role = UserRole(user.role.name)
|
phoenix/server/dml_event.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from abc import ABC
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
-
from typing import ClassVar
|
|
5
|
+
from typing import ClassVar
|
|
6
6
|
|
|
7
7
|
from phoenix.db import models
|
|
8
8
|
|
|
@@ -14,8 +14,8 @@ class DmlEvent(ABC):
|
|
|
14
14
|
operation, e.g. insertion, update, or deletion.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
table: ClassVar[
|
|
18
|
-
ids:
|
|
17
|
+
table: ClassVar[type[models.Base]]
|
|
18
|
+
ids: tuple[int, ...] = field(default_factory=tuple)
|
|
19
19
|
|
|
20
20
|
def __bool__(self) -> bool:
|
|
21
21
|
return bool(self.ids)
|
|
@@ -2,24 +2,10 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
from asyncio import gather
|
|
5
|
+
from collections.abc import Callable, Iterable, Iterator, Mapping
|
|
5
6
|
from inspect import getmro
|
|
6
7
|
from itertools import chain
|
|
7
|
-
from typing import
|
|
8
|
-
Any,
|
|
9
|
-
Callable,
|
|
10
|
-
Generic,
|
|
11
|
-
Iterable,
|
|
12
|
-
Iterator,
|
|
13
|
-
Mapping,
|
|
14
|
-
Optional,
|
|
15
|
-
Set,
|
|
16
|
-
Tuple,
|
|
17
|
-
Type,
|
|
18
|
-
TypedDict,
|
|
19
|
-
TypeVar,
|
|
20
|
-
Union,
|
|
21
|
-
cast,
|
|
22
|
-
)
|
|
8
|
+
from typing import Any, Generic, Optional, TypedDict, TypeVar, Union, cast
|
|
23
9
|
|
|
24
10
|
from sqlalchemy import Select, select
|
|
25
11
|
from typing_extensions import TypeAlias, Unpack
|
|
@@ -54,7 +40,7 @@ _DmlEventT = TypeVar("_DmlEventT", bound=DmlEvent)
|
|
|
54
40
|
class _DmlEventQueue(Generic[_DmlEventT]):
|
|
55
41
|
def __init__(self, **kwargs: Any) -> None:
|
|
56
42
|
super().__init__(**kwargs)
|
|
57
|
-
self._events:
|
|
43
|
+
self._events: set[_DmlEventT] = set()
|
|
58
44
|
|
|
59
45
|
@property
|
|
60
46
|
def empty(self) -> bool:
|
|
@@ -120,7 +106,7 @@ class _GenericDmlEventHandler(_DmlEventHandler[DmlEvent]):
|
|
|
120
106
|
for id_ in e.ids:
|
|
121
107
|
self._update(e.table, id_)
|
|
122
108
|
|
|
123
|
-
def _update(self, table:
|
|
109
|
+
def _update(self, table: type[Base], id_: int) -> None:
|
|
124
110
|
self._last_updated_at.set(table, id_)
|
|
125
111
|
|
|
126
112
|
|
|
@@ -146,9 +132,9 @@ class _SpanDeleteEventHandler(_SpanDmlEventHandler):
|
|
|
146
132
|
|
|
147
133
|
|
|
148
134
|
_AnnotationTable: TypeAlias = Union[
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
135
|
+
type[SpanAnnotation],
|
|
136
|
+
type[TraceAnnotation],
|
|
137
|
+
type[DocumentAnnotation],
|
|
152
138
|
]
|
|
153
139
|
|
|
154
140
|
_AnnotationDmlEventT = TypeVar(
|
|
@@ -165,7 +151,7 @@ class _AnnotationDmlEventHandler(
|
|
|
165
151
|
ABC,
|
|
166
152
|
):
|
|
167
153
|
_table: _AnnotationTable
|
|
168
|
-
_base_stmt: Union[Select[
|
|
154
|
+
_base_stmt: Union[Select[tuple[int, str]], Select[tuple[int]]] = (
|
|
169
155
|
select(Project.id).join_from(Project, Trace).distinct()
|
|
170
156
|
)
|
|
171
157
|
|
|
@@ -175,7 +161,7 @@ class _AnnotationDmlEventHandler(
|
|
|
175
161
|
if self._cache_for_dataloaders:
|
|
176
162
|
self._stmt = self._stmt.add_columns(self._table.name)
|
|
177
163
|
|
|
178
|
-
def _get_stmt(self) -> Union[Select[
|
|
164
|
+
def _get_stmt(self) -> Union[Select[tuple[int, str]], Select[tuple[int]]]:
|
|
179
165
|
ids = set(chain.from_iterable(e.ids for e in self._batch))
|
|
180
166
|
return self._stmt.where(self._table.id.in_(ids))
|
|
181
167
|
|
|
@@ -242,7 +228,7 @@ class DmlEventHandler:
|
|
|
242
228
|
cache_for_dataloaders=cache_for_dataloaders,
|
|
243
229
|
sleep_seconds=sleep_seconds,
|
|
244
230
|
)
|
|
245
|
-
self._handlers: Mapping[
|
|
231
|
+
self._handlers: Mapping[type[DmlEvent], Iterable[_DmlEventHandler[Any]]] = {
|
|
246
232
|
DmlEvent: [_GenericDmlEventHandler(**kwargs)],
|
|
247
233
|
SpanDmlEvent: [_SpanDmlEventHandler(**kwargs)],
|
|
248
234
|
SpanDeleteEvent: [_SpanDeleteEventHandler(**kwargs)],
|
phoenix/server/grpc_server.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections.abc import Awaitable, Callable
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
2
3
|
|
|
3
4
|
import grpc
|
|
4
5
|
from grpc.aio import RpcContext, Server, ServerInterceptor
|
|
@@ -66,7 +67,7 @@ class GrpcServer:
|
|
|
66
67
|
async def __aenter__(self) -> None:
|
|
67
68
|
if self._disabled:
|
|
68
69
|
return
|
|
69
|
-
interceptors:
|
|
70
|
+
interceptors: list[ServerInterceptor] = []
|
|
70
71
|
if self._token_store:
|
|
71
72
|
interceptors.append(ApiKeyInterceptor(self._token_store))
|
|
72
73
|
if self._enable_prometheus:
|
phoenix/server/jwt_store.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
3
|
from asyncio import create_task, gather, sleep
|
|
4
|
+
from collections.abc import Callable, Coroutine
|
|
4
5
|
from copy import deepcopy
|
|
5
6
|
from dataclasses import replace
|
|
6
7
|
from datetime import datetime, timezone
|
|
7
8
|
from functools import cached_property, singledispatchmethod
|
|
8
|
-
from typing import Any,
|
|
9
|
+
from typing import Any, Generic, Optional, TypeVar
|
|
9
10
|
|
|
10
11
|
from authlib.jose import jwt
|
|
11
12
|
from authlib.jose.errors import JoseError
|
|
@@ -65,7 +66,7 @@ class JwtStore:
|
|
|
65
66
|
self._api_key_store = _ApiKeyStore(*args, **kwargs)
|
|
66
67
|
|
|
67
68
|
@cached_property
|
|
68
|
-
def _stores(self) ->
|
|
69
|
+
def _stores(self) -> tuple[DaemonTask, ...]:
|
|
69
70
|
return tuple(dt for dt in self.__dict__.values() if isinstance(dt, _Store))
|
|
70
71
|
|
|
71
72
|
async def __aenter__(self) -> None:
|
|
@@ -131,34 +132,34 @@ class JwtStore:
|
|
|
131
132
|
async def create_password_reset_token(
|
|
132
133
|
self,
|
|
133
134
|
claim: PasswordResetTokenClaims,
|
|
134
|
-
) ->
|
|
135
|
+
) -> tuple[PasswordResetToken, PasswordResetTokenId]:
|
|
135
136
|
return await self._password_reset_token_store.create(claim)
|
|
136
137
|
|
|
137
138
|
async def create_access_token(
|
|
138
139
|
self,
|
|
139
140
|
claim: AccessTokenClaims,
|
|
140
|
-
) ->
|
|
141
|
+
) -> tuple[AccessToken, AccessTokenId]:
|
|
141
142
|
return await self._access_token_store.create(claim)
|
|
142
143
|
|
|
143
144
|
async def create_refresh_token(
|
|
144
145
|
self,
|
|
145
146
|
claim: RefreshTokenClaims,
|
|
146
|
-
) ->
|
|
147
|
+
) -> tuple[RefreshToken, RefreshTokenId]:
|
|
147
148
|
return await self._refresh_token_store.create(claim)
|
|
148
149
|
|
|
149
150
|
async def create_api_key(
|
|
150
151
|
self,
|
|
151
152
|
claim: ApiKeyClaims,
|
|
152
|
-
) ->
|
|
153
|
+
) -> tuple[ApiKey, ApiKeyId]:
|
|
153
154
|
return await self._api_key_store.create(claim)
|
|
154
155
|
|
|
155
156
|
async def revoke(self, *token_ids: TokenId) -> None:
|
|
156
157
|
if not token_ids:
|
|
157
158
|
return
|
|
158
|
-
password_reset_token_ids:
|
|
159
|
-
access_token_ids:
|
|
160
|
-
refresh_token_ids:
|
|
161
|
-
api_key_ids:
|
|
159
|
+
password_reset_token_ids: list[PasswordResetTokenId] = []
|
|
160
|
+
access_token_ids: list[AccessTokenId] = []
|
|
161
|
+
refresh_token_ids: list[RefreshTokenId] = []
|
|
162
|
+
api_key_ids: list[ApiKeyId] = []
|
|
162
163
|
for token_id in token_ids:
|
|
163
164
|
if isinstance(token_id, PasswordResetTokenId):
|
|
164
165
|
password_reset_token_ids.append(token_id)
|
|
@@ -168,7 +169,7 @@ class JwtStore:
|
|
|
168
169
|
refresh_token_ids.append(token_id)
|
|
169
170
|
elif isinstance(token_id, ApiKeyId):
|
|
170
171
|
api_key_ids.append(token_id)
|
|
171
|
-
coroutines:
|
|
172
|
+
coroutines: list[Coroutine[None, None, None]] = []
|
|
172
173
|
if password_reset_token_ids:
|
|
173
174
|
coroutines.append(self._password_reset_token_store.revoke(*password_reset_token_ids))
|
|
174
175
|
if access_token_ids:
|
|
@@ -202,7 +203,7 @@ _RecordT = TypeVar(
|
|
|
202
203
|
|
|
203
204
|
class _Claims(Generic[_TokenIdT, _ClaimSetT]):
|
|
204
205
|
def __init__(self) -> None:
|
|
205
|
-
self._cache:
|
|
206
|
+
self._cache: dict[_TokenIdT, _ClaimSetT] = {}
|
|
206
207
|
|
|
207
208
|
def __getitem__(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
|
|
208
209
|
claim = self._cache.get(token_id)
|
|
@@ -223,7 +224,7 @@ class _Claims(Generic[_TokenIdT, _ClaimSetT]):
|
|
|
223
224
|
|
|
224
225
|
|
|
225
226
|
class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC):
|
|
226
|
-
_table:
|
|
227
|
+
_table: type[_RecordT]
|
|
227
228
|
_token_id: Callable[[int], _TokenIdT]
|
|
228
229
|
_token: Callable[[str], _TokenT]
|
|
229
230
|
|
|
@@ -244,7 +245,7 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
244
245
|
self._algorithm = algorithm
|
|
245
246
|
|
|
246
247
|
def _encode(self, claim: ClaimSet) -> str:
|
|
247
|
-
payload:
|
|
248
|
+
payload: dict[str, Any] = dict(jti=claim.token_id)
|
|
248
249
|
header = {"alg": self._algorithm}
|
|
249
250
|
jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=self._secret)
|
|
250
251
|
return jwt_bytes.decode()
|
|
@@ -275,12 +276,12 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
275
276
|
await session.execute(stmt)
|
|
276
277
|
|
|
277
278
|
@abstractmethod
|
|
278
|
-
def _from_db(self, record: _RecordT, role: UserRole) ->
|
|
279
|
+
def _from_db(self, record: _RecordT, role: UserRole) -> tuple[_TokenIdT, _ClaimSetT]: ...
|
|
279
280
|
|
|
280
281
|
@abstractmethod
|
|
281
282
|
def _to_db(self, claims: _ClaimSetT) -> _RecordT: ...
|
|
282
283
|
|
|
283
|
-
async def create(self, claim: _ClaimSetT) ->
|
|
284
|
+
async def create(self, claim: _ClaimSetT) -> tuple[_TokenT, _TokenIdT]:
|
|
284
285
|
record = self._to_db(claim)
|
|
285
286
|
async with self._db() as session:
|
|
286
287
|
session.add(record)
|
|
@@ -303,7 +304,7 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
303
304
|
self._claims = claims
|
|
304
305
|
|
|
305
306
|
@cached_property
|
|
306
|
-
def _update_stmt(self) -> Select[
|
|
307
|
+
def _update_stmt(self) -> Select[tuple[_RecordT, str]]:
|
|
307
308
|
return (
|
|
308
309
|
select(self._table, models.UserRole.name)
|
|
309
310
|
.join_from(self._table, models.User)
|
|
@@ -340,7 +341,7 @@ class _PasswordResetTokenStore(
|
|
|
340
341
|
self,
|
|
341
342
|
record: models.PasswordResetToken,
|
|
342
343
|
user_role: UserRole,
|
|
343
|
-
) ->
|
|
344
|
+
) -> tuple[PasswordResetTokenId, PasswordResetTokenClaims]:
|
|
344
345
|
token_id = PasswordResetTokenId(record.id)
|
|
345
346
|
return token_id, PasswordResetTokenClaims(
|
|
346
347
|
token_id=token_id,
|
|
@@ -379,7 +380,7 @@ class _AccessTokenStore(
|
|
|
379
380
|
self,
|
|
380
381
|
record: models.AccessToken,
|
|
381
382
|
user_role: UserRole,
|
|
382
|
-
) ->
|
|
383
|
+
) -> tuple[AccessTokenId, AccessTokenClaims]:
|
|
383
384
|
token_id = AccessTokenId(record.id)
|
|
384
385
|
refresh_token_id = RefreshTokenId(record.refresh_token_id)
|
|
385
386
|
return token_id, AccessTokenClaims(
|
|
@@ -423,7 +424,7 @@ class _RefreshTokenStore(
|
|
|
423
424
|
self,
|
|
424
425
|
record: models.RefreshToken,
|
|
425
426
|
user_role: UserRole,
|
|
426
|
-
) ->
|
|
427
|
+
) -> tuple[RefreshTokenId, RefreshTokenClaims]:
|
|
427
428
|
token_id = RefreshTokenId(record.id)
|
|
428
429
|
return token_id, RefreshTokenClaims(
|
|
429
430
|
token_id=token_id,
|
|
@@ -469,7 +470,7 @@ class _ApiKeyStore(
|
|
|
469
470
|
self,
|
|
470
471
|
record: models.ApiKey,
|
|
471
472
|
user_role: UserRole,
|
|
472
|
-
) ->
|
|
473
|
+
) -> tuple[ApiKeyId, ApiKeyClaims]:
|
|
473
474
|
token_id = ApiKeyId(record.id)
|
|
474
475
|
return token_id, ApiKeyClaims(
|
|
475
476
|
token_id=token_id,
|
phoenix/server/main.py
CHANGED
|
@@ -6,7 +6,7 @@ from argparse import SUPPRESS, ArgumentParser
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from threading import Thread
|
|
8
8
|
from time import sleep, time
|
|
9
|
-
from typing import
|
|
9
|
+
from typing import Optional
|
|
10
10
|
from urllib.parse import urljoin
|
|
11
11
|
|
|
12
12
|
from jinja2 import BaseLoader, Environment
|
|
@@ -312,8 +312,8 @@ def main() -> None:
|
|
|
312
312
|
|
|
313
313
|
authentication_enabled, secret = get_env_auth_settings()
|
|
314
314
|
|
|
315
|
-
fixture_spans:
|
|
316
|
-
fixture_evals:
|
|
315
|
+
fixture_spans: list[Span] = []
|
|
316
|
+
fixture_evals: list[pb.Evaluation] = []
|
|
317
317
|
if trace_dataset_name is not None:
|
|
318
318
|
fixture_spans, fixture_evals = reset_fixture_span_ids_and_timestamps(
|
|
319
319
|
(
|
phoenix/server/oauth2.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
from authlib.integrations.base_client import BaseApp
|
|
4
5
|
from authlib.integrations.base_client.async_app import AsyncOAuth2Mixin
|
|
@@ -24,7 +25,7 @@ class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp): # type:ignore[
|
|
|
24
25
|
|
|
25
26
|
class OAuth2Clients:
|
|
26
27
|
def __init__(self) -> None:
|
|
27
|
-
self._clients:
|
|
28
|
+
self._clients: dict[str, OAuth2Client] = {}
|
|
28
29
|
|
|
29
30
|
def add_client(self, config: OAuth2ClientConfig) -> None:
|
|
30
31
|
if (idp_name := config.idp_name) in self._clients:
|
phoenix/server/rate_limiters.py
CHANGED
|
@@ -1,13 +1,10 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import time
|
|
3
3
|
from collections import defaultdict
|
|
4
|
+
from collections.abc import Callable, Coroutine
|
|
4
5
|
from functools import partial
|
|
5
6
|
from typing import (
|
|
6
7
|
Any,
|
|
7
|
-
Callable,
|
|
8
|
-
Coroutine,
|
|
9
|
-
DefaultDict,
|
|
10
|
-
List,
|
|
11
8
|
Optional,
|
|
12
9
|
Pattern, # import from re module when we drop support for 3.8
|
|
13
10
|
Union,
|
|
@@ -98,7 +95,7 @@ class ServerRateLimiter:
|
|
|
98
95
|
self._last_cleanup_time = time.time()
|
|
99
96
|
|
|
100
97
|
def _reset_rate_limiters(self) -> None:
|
|
101
|
-
self.cache_partitions:
|
|
98
|
+
self.cache_partitions: list[defaultdict[Any, TokenBucket]] = [
|
|
102
99
|
defaultdict(self.bucket_factory) for _ in range(self.num_partitions)
|
|
103
100
|
]
|
|
104
101
|
|
|
@@ -107,10 +104,10 @@ class ServerRateLimiter:
|
|
|
107
104
|
int(timestamp // self.partition_seconds) % self.num_partitions
|
|
108
105
|
) # a cyclic bucket index
|
|
109
106
|
|
|
110
|
-
def _active_partition_indices(self, current_index: int) ->
|
|
107
|
+
def _active_partition_indices(self, current_index: int) -> list[int]:
|
|
111
108
|
return [(current_index - ii) % self.num_partitions for ii in range(self.active_partitions)]
|
|
112
109
|
|
|
113
|
-
def _inactive_partition_indices(self, current_index: int) ->
|
|
110
|
+
def _inactive_partition_indices(self, current_index: int) -> list[int]:
|
|
114
111
|
active_indices = set(self._active_partition_indices(current_index))
|
|
115
112
|
all_indices = set(range(self.num_partitions))
|
|
116
113
|
return list(all_indices - active_indices)
|
|
@@ -156,7 +153,7 @@ class ServerRateLimiter:
|
|
|
156
153
|
|
|
157
154
|
|
|
158
155
|
def fastapi_ip_rate_limiter(
|
|
159
|
-
rate_limiter: ServerRateLimiter, paths: Optional[
|
|
156
|
+
rate_limiter: ServerRateLimiter, paths: Optional[list[Union[str, Pattern[str]]]] = None
|
|
160
157
|
) -> Callable[[Request], Coroutine[Any, Any, Request]]:
|
|
161
158
|
async def dependency(request: Request) -> Request:
|
|
162
159
|
if paths is None or any(path_match(request.url.path, path) for path in paths):
|