arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic; see the package registry page for more details.
- arize_phoenix-7.7.1.dist-info/METADATA +261 -0
- arize_phoenix-7.7.1.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from collections.abc import Awaitable, Callable
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
3
|
+
|
|
4
|
+
import grpc
|
|
5
|
+
from grpc.aio import RpcContext, Server, ServerInterceptor
|
|
6
|
+
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
|
|
7
|
+
ExportTraceServiceRequest,
|
|
8
|
+
ExportTraceServiceResponse,
|
|
9
|
+
)
|
|
10
|
+
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
|
|
11
|
+
TraceServiceServicer,
|
|
12
|
+
add_TraceServiceServicer_to_server,
|
|
13
|
+
)
|
|
14
|
+
from typing_extensions import TypeAlias
|
|
15
|
+
|
|
16
|
+
from phoenix.auth import CanReadToken
|
|
17
|
+
from phoenix.config import get_env_grpc_port
|
|
18
|
+
from phoenix.server.bearer_auth import ApiKeyInterceptor
|
|
19
|
+
from phoenix.trace.otel import decode_otlp_span
|
|
20
|
+
from phoenix.trace.schemas import Span
|
|
21
|
+
from phoenix.utilities.project import get_project_name
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from opentelemetry.trace import TracerProvider
|
|
25
|
+
|
|
26
|
+
# Alias for the project-name string passed alongside each decoded span to the
# ingestion callback (see the Callable[[Span, ProjectName], ...] signatures below).
ProjectName: TypeAlias = str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Servicer(TraceServiceServicer):  # type: ignore[misc,unused-ignore]
    """OTLP ``TraceService`` implementation that hands every received span to a callback.

    Args:
        callback: awaited once per span with the decoded span and the project name
            derived from the span's resource attributes.
    """

    def __init__(
        self,
        callback: Callable[[Span, ProjectName], Awaitable[None]],
    ) -> None:
        super().__init__()
        self._callback = callback

    async def Export(
        self,
        request: ExportTraceServiceRequest,
        context: RpcContext,
    ) -> ExportTraceServiceResponse:
        """Decode all spans in ``request`` and forward them, in order, to the callback.

        The project name is resolved once per resource and shared by all of that
        resource's scope spans. An empty response is returned once every callback
        has completed.
        """
        for resource_spans in request.resource_spans:
            project = get_project_name(resource_spans.resource.attributes)
            otlp_spans = (
                otlp_span
                for scope_span in resource_spans.scope_spans
                for otlp_span in scope_span.spans
            )
            for otlp_span in otlp_spans:
                decoded = decode_otlp_span(otlp_span)
                await self._callback(decoded, project)
        return ExportTraceServiceResponse()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class GrpcServer:
    """Async context manager that runs the OTLP trace-ingestion gRPC server.

    Entering the context builds and starts a ``grpc.aio`` server on the port from
    ``get_env_grpc_port()`` (unless ``disabled``). Exiting stops it with a 5 second
    grace period.

    Args:
        callback: awaited for every decoded span together with its project name.
        tracer_provider: when given, the gRPC aio server is instrumented with
            OpenTelemetry on enter and uninstrumented on exit.
        enable_prometheus: currently a no-op (see TODO in ``__aenter__``).
        disabled: when True, entering/exiting the context does nothing.
        token_store: when given, an ``ApiKeyInterceptor`` backed by it is installed.
        interceptors: extra server interceptors. The list is never mutated by this
            class.
    """

    def __init__(
        self,
        callback: Callable[[Span, ProjectName], Awaitable[None]],
        tracer_provider: Optional["TracerProvider"] = None,
        enable_prometheus: bool = False,
        disabled: bool = False,
        token_store: Optional[CanReadToken] = None,
        interceptors: Optional[list[ServerInterceptor]] = None,
    ) -> None:
        self._callback = callback
        self._server: Optional[Server] = None
        self._tracer_provider = tracer_provider
        self._enable_prometheus = enable_prometheus
        self._disabled = disabled
        self._token_store = token_store
        # A `None` default (instead of a shared `[]`) avoids state leaking between
        # GrpcServer instances.
        self._interceptors: list[ServerInterceptor] = (
            [] if interceptors is None else interceptors
        )

    async def __aenter__(self) -> None:
        if self._disabled:
            return
        # Work on a copy. Appending to self._interceptors directly would mutate the
        # caller's list and accumulate duplicate ApiKeyInterceptors on every start.
        interceptors: list[ServerInterceptor] = [*self._interceptors]
        if self._token_store is not None:
            interceptors.append(ApiKeyInterceptor(self._token_store))
        if self._enable_prometheus:
            # TODO: convert to async interceptor
            # from py_grpc_prometheus.prometheus_server_interceptor import PromServerInterceptor
            #
            # interceptors.append(PromServerInterceptor())
            pass
        if self._tracer_provider is not None:
            from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorServer

            GrpcAioInstrumentorServer().instrument(tracer_provider=self._tracer_provider)  # type: ignore
        server = grpc.aio.server(
            options=(("grpc.so_reuseport", 0),),
            interceptors=interceptors,
        )
        server.add_insecure_port(f"[::]:{get_env_grpc_port()}")
        add_TraceServiceServicer_to_server(Servicer(self._callback), server)  # type: ignore[no-untyped-call,unused-ignore]
        await server.start()
        self._server = server

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        # Nothing was started (disabled, or never entered).
        if self._server is None:
            return
        await self._server.stop(5)
        self._server = None
        if self._tracer_provider is not None:
            from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorServer

            GrpcAioInstrumentorServer().uninstrument()  # type: ignore
|
|
@@ -0,0 +1,505 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from asyncio import create_task, gather, sleep
|
|
4
|
+
from collections.abc import Callable, Coroutine
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from dataclasses import replace
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
from functools import cached_property, singledispatchmethod
|
|
9
|
+
from typing import Any, Generic, Optional, TypeVar
|
|
10
|
+
|
|
11
|
+
from authlib.jose import jwt
|
|
12
|
+
from authlib.jose.errors import JoseError
|
|
13
|
+
from sqlalchemy import Select, delete, select
|
|
14
|
+
|
|
15
|
+
from phoenix.auth import (
|
|
16
|
+
JWT_ALGORITHM,
|
|
17
|
+
ClaimSet,
|
|
18
|
+
Token,
|
|
19
|
+
)
|
|
20
|
+
from phoenix.config import get_env_enable_prometheus
|
|
21
|
+
from phoenix.db import models
|
|
22
|
+
from phoenix.db.enums import UserRole
|
|
23
|
+
from phoenix.server.types import (
|
|
24
|
+
AccessToken,
|
|
25
|
+
AccessTokenAttributes,
|
|
26
|
+
AccessTokenClaims,
|
|
27
|
+
AccessTokenId,
|
|
28
|
+
ApiKey,
|
|
29
|
+
ApiKeyAttributes,
|
|
30
|
+
ApiKeyClaims,
|
|
31
|
+
ApiKeyId,
|
|
32
|
+
DaemonTask,
|
|
33
|
+
DbSessionFactory,
|
|
34
|
+
PasswordResetToken,
|
|
35
|
+
PasswordResetTokenAttributes,
|
|
36
|
+
PasswordResetTokenClaims,
|
|
37
|
+
PasswordResetTokenId,
|
|
38
|
+
RefreshToken,
|
|
39
|
+
RefreshTokenAttributes,
|
|
40
|
+
RefreshTokenClaims,
|
|
41
|
+
RefreshTokenId,
|
|
42
|
+
TokenId,
|
|
43
|
+
UserId,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class JwtStore:
    """Facade over the per-kind token stores: password reset, access, refresh and API key.

    Each store is a ``_Store`` daemon task. Entering or exiting this object enters or
    exits all of them concurrently. ``read`` decodes a JWT, parses its ``jti`` claim
    into a typed ``TokenId`` and dispatches on that id's type to the matching store.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        secret: str,
        algorithm: str = JWT_ALGORITHM,
        sleep_seconds: int = 10,
        **kwargs: Any,
    ) -> None:
        assert secret
        super().__init__(**kwargs)
        self._db = db
        self._secret = secret
        args = (db, secret, algorithm, sleep_seconds)
        self._password_reset_token_store = _PasswordResetTokenStore(*args, **kwargs)
        self._access_token_store = _AccessTokenStore(*args, **kwargs)
        self._refresh_token_store = _RefreshTokenStore(*args, **kwargs)
        self._api_key_store = _ApiKeyStore(*args, **kwargs)

    @cached_property
    def _stores(self) -> tuple[DaemonTask, ...]:
        # Collects every _Store-valued instance attribute, i.e. the four stores above.
        return tuple(dt for dt in self.__dict__.values() if isinstance(dt, _Store))

    async def __aenter__(self) -> None:
        await gather(*(s.__aenter__() for s in self._stores))

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        await gather(*(s.__aexit__(*args, **kwargs) for s in self._stores))

    async def read(self, token: Token) -> Optional[ClaimSet]:
        """Return the claim set for ``token``, or None if it is invalid or unknown.

        None is returned when the JWT fails to decode (``JoseError``), carries no
        ``jti`` claim, or the ``jti`` does not parse into a known ``TokenId``.
        """
        # NOTE(review): decode is not restricted to the `algorithm` passed to __init__;
        # confirm that authlib's default accepted-algorithm set is acceptable here.
        try:
            payload = jwt.decode(
                s=token,
                key=self._secret,
            )
        except JoseError:
            return None
        if (jti := payload.get("jti")) is None:
            return None
        if (token_id := TokenId.parse(jti)) is None:
            return None
        return await self._get(token_id)

    # Dispatch on the concrete TokenId subclass to the matching store's cache/DB lookup.
    @singledispatchmethod
    async def _get(self, _: TokenId) -> Optional[ClaimSet]:
        return None

    @_get.register
    async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
        return await self._password_reset_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
        return await self._access_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]:
        return await self._refresh_token_store.get(token_id)

    @_get.register
    async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
        return await self._api_key_store.get(token_id)

    # Dispatch on the concrete TokenId subclass to drop the id from that store's cache.
    @singledispatchmethod
    async def _evict(self, _: TokenId) -> Optional[ClaimSet]:
        return None

    @_evict.register
    async def _(self, token_id: PasswordResetTokenId) -> Optional[ClaimSet]:
        return await self._password_reset_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: AccessTokenId) -> Optional[ClaimSet]:
        return await self._access_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: RefreshTokenId) -> Optional[ClaimSet]:
        return await self._refresh_token_store.evict(token_id)

    @_evict.register
    async def _(self, token_id: ApiKeyId) -> Optional[ClaimSet]:
        return await self._api_key_store.evict(token_id)

    async def create_password_reset_token(
        self,
        claim: PasswordResetTokenClaims,
    ) -> tuple[PasswordResetToken, PasswordResetTokenId]:
        return await self._password_reset_token_store.create(claim)

    async def create_access_token(
        self,
        claim: AccessTokenClaims,
    ) -> tuple[AccessToken, AccessTokenId]:
        return await self._access_token_store.create(claim)

    async def create_refresh_token(
        self,
        claim: RefreshTokenClaims,
    ) -> tuple[RefreshToken, RefreshTokenId]:
        return await self._refresh_token_store.create(claim)

    async def create_api_key(
        self,
        claim: ApiKeyClaims,
    ) -> tuple[ApiKey, ApiKeyId]:
        return await self._api_key_store.create(claim)

    async def revoke(self, *token_ids: TokenId) -> None:
        """Revoke the given tokens, grouping them per store and revoking concurrently.

        Ids of an unrecognized TokenId subclass are ignored.
        """
        if not token_ids:
            return
        password_reset_token_ids: list[PasswordResetTokenId] = []
        access_token_ids: list[AccessTokenId] = []
        refresh_token_ids: list[RefreshTokenId] = []
        api_key_ids: list[ApiKeyId] = []
        for token_id in token_ids:
            # A single if/elif chain routes each id to exactly one store.
            if isinstance(token_id, PasswordResetTokenId):
                password_reset_token_ids.append(token_id)
            elif isinstance(token_id, AccessTokenId):
                access_token_ids.append(token_id)
            elif isinstance(token_id, RefreshTokenId):
                refresh_token_ids.append(token_id)
            elif isinstance(token_id, ApiKeyId):
                api_key_ids.append(token_id)
        coroutines: list[Coroutine[None, None, None]] = []
        if password_reset_token_ids:
            coroutines.append(self._password_reset_token_store.revoke(*password_reset_token_ids))
        if access_token_ids:
            coroutines.append(self._access_token_store.revoke(*access_token_ids))
        if refresh_token_ids:
            coroutines.append(self._refresh_token_store.revoke(*refresh_token_ids))
        if api_key_ids:
            coroutines.append(self._api_key_store.revoke(*api_key_ids))
        await gather(*coroutines)

    async def log_out(self, user_id: UserId) -> None:
        """Delete all access and refresh tokens of ``user_id`` and evict them from cache.

        Password-reset tokens and API keys are left untouched.
        """
        for cls in (AccessTokenId, RefreshTokenId):
            table = cls.table
            stmt = delete(table).where(table.user_id == int(user_id)).returning(table.id)
            async with self._db() as session:
                async for id_ in await session.stream_scalars(stmt):
                    await self._evict(cls(id_))
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# Type parameters for the generic claim cache and token stores.
# Claim-set type held in the cache and returned by lookups.
_ClaimSetT = TypeVar("_ClaimSetT", bound=ClaimSet)
# Token value type produced by a store.
_TokenT = TypeVar("_TokenT", bound=Token)
# Typed id of a token; keys the claim cache.
_TokenIdT = TypeVar("_TokenIdT", bound=TokenId)
# Constrained (not bound) TypeVar: must be exactly one of these ORM token models.
_RecordT = TypeVar(
    "_RecordT",
    models.PasswordResetToken,
    models.AccessToken,
    models.RefreshToken,
    models.ApiKey,
)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class _Claims(Generic[_TokenIdT, _ClaimSetT]):
    """In-memory cache of claim sets keyed by token id.

    Claim sets are deep-copied when stored and when returned, so callers can never
    mutate the cached instance.
    """

    def __init__(self) -> None:
        self._cache: dict[_TokenIdT, _ClaimSetT] = {}

    def __getitem__(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        # Unlike dict, a missing key yields None instead of raising KeyError.
        return self.get(token_id)

    def __setitem__(self, token_id: _TokenIdT, claim: _ClaimSetT) -> None:
        self._cache[token_id] = deepcopy(claim)

    def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
        """Return a copy of the cached claim set for ``token_id``, or None if absent."""
        claim = self._cache.get(token_id)
        # Compare against None explicitly so a falsy-but-present value is not
        # mistaken for a cache miss.
        return None if claim is None else deepcopy(claim)

    def pop(
        self, token_id: _TokenIdT, default: Optional[_ClaimSetT] = None
    ) -> Optional[_ClaimSetT]:
        """Remove ``token_id`` and return a copy of its claim set (or of ``default``)."""
        claim = self._cache.pop(token_id, default)
        return None if claim is None else deepcopy(claim)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC):
|
|
227
|
+
_table: type[_RecordT]
|
|
228
|
+
_token_id: Callable[[int], _TokenIdT]
|
|
229
|
+
_token: Callable[[str], _TokenT]
|
|
230
|
+
|
|
231
|
+
def __init__(
|
|
232
|
+
self,
|
|
233
|
+
db: DbSessionFactory,
|
|
234
|
+
secret: str,
|
|
235
|
+
algorithm: str = JWT_ALGORITHM,
|
|
236
|
+
sleep_seconds: int = 10,
|
|
237
|
+
**kwargs: Any,
|
|
238
|
+
) -> None:
|
|
239
|
+
assert secret
|
|
240
|
+
super().__init__(**kwargs)
|
|
241
|
+
self._db = db
|
|
242
|
+
self._seconds = sleep_seconds
|
|
243
|
+
self._claims: _Claims[_TokenIdT, _ClaimSetT] = _Claims()
|
|
244
|
+
self._secret = secret
|
|
245
|
+
self._algorithm = algorithm
|
|
246
|
+
|
|
247
|
+
def _encode(self, claim: ClaimSet) -> str:
|
|
248
|
+
payload: dict[str, Any] = dict(jti=claim.token_id)
|
|
249
|
+
header = {"alg": self._algorithm}
|
|
250
|
+
jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=self._secret)
|
|
251
|
+
return jwt_bytes.decode()
|
|
252
|
+
|
|
253
|
+
async def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
|
|
254
|
+
if claims := self._claims.get(token_id):
|
|
255
|
+
return claims
|
|
256
|
+
stmt = self._update_stmt.where(self._table.id == int(token_id))
|
|
257
|
+
async with self._db() as session:
|
|
258
|
+
record = (await session.execute(stmt)).first()
|
|
259
|
+
if not record:
|
|
260
|
+
return None
|
|
261
|
+
token, role = record
|
|
262
|
+
_, claims = self._from_db(token, UserRole(role))
|
|
263
|
+
self._claims[token_id] = claims
|
|
264
|
+
return claims
|
|
265
|
+
|
|
266
|
+
async def evict(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
|
|
267
|
+
return self._claims.pop(token_id, None)
|
|
268
|
+
|
|
269
|
+
async def revoke(self, *token_ids: _TokenIdT) -> None:
|
|
270
|
+
if not token_ids:
|
|
271
|
+
return
|
|
272
|
+
for token_id in token_ids:
|
|
273
|
+
await self.evict(token_id)
|
|
274
|
+
stmt = delete(self._table).where(self._table.id.in_(map(int, token_ids)))
|
|
275
|
+
async with self._db() as session:
|
|
276
|
+
await session.execute(stmt)
|
|
277
|
+
|
|
278
|
+
@abstractmethod
|
|
279
|
+
def _from_db(self, record: _RecordT, role: UserRole) -> tuple[_TokenIdT, _ClaimSetT]: ...
|
|
280
|
+
|
|
281
|
+
@abstractmethod
|
|
282
|
+
def _to_db(self, claims: _ClaimSetT) -> _RecordT: ...
|
|
283
|
+
|
|
284
|
+
async def create(self, claim: _ClaimSetT) -> tuple[_TokenT, _TokenIdT]:
|
|
285
|
+
record = self._to_db(claim)
|
|
286
|
+
async with self._db() as session:
|
|
287
|
+
session.add(record)
|
|
288
|
+
await session.flush()
|
|
289
|
+
token_id = self._token_id(record.id)
|
|
290
|
+
claim = replace(claim, token_id=token_id)
|
|
291
|
+
self._claims[token_id] = claim
|
|
292
|
+
token = self._token(self._encode(claim))
|
|
293
|
+
return token, token_id
|
|
294
|
+
|
|
295
|
+
async def _update(self) -> None:
|
|
296
|
+
claims: _Claims[_TokenIdT, _ClaimSetT] = _Claims()
|
|
297
|
+
async with self._db() as session:
|
|
298
|
+
async with session.begin_nested():
|
|
299
|
+
await self._delete_expired_tokens(session)
|
|
300
|
+
async with session.begin_nested():
|
|
301
|
+
async for record, role in await session.stream(self._update_stmt):
|
|
302
|
+
token_id, claim_set = self._from_db(record, UserRole(role))
|
|
303
|
+
claims[token_id] = claim_set
|
|
304
|
+
self._claims = claims
|
|
305
|
+
|
|
306
|
+
@cached_property
|
|
307
|
+
def _update_stmt(self) -> Select[tuple[_RecordT, str]]:
|
|
308
|
+
return (
|
|
309
|
+
select(self._table, models.UserRole.name)
|
|
310
|
+
.join_from(self._table, models.User)
|
|
311
|
+
.join_from(models.User, models.UserRole)
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
async def _delete_expired_tokens(self, session: Any) -> None:
|
|
315
|
+
now = datetime.now(timezone.utc)
|
|
316
|
+
await session.execute(delete(self._table).where(self._table.expires_at < now))
|
|
317
|
+
|
|
318
|
+
async def _run(self) -> None:
|
|
319
|
+
while self._running:
|
|
320
|
+
self._tasks.append(create_task(self._update()))
|
|
321
|
+
await self._tasks[-1]
|
|
322
|
+
self._tasks.pop()
|
|
323
|
+
self._tasks.append(create_task(sleep(self._seconds)))
|
|
324
|
+
await self._tasks[-1]
|
|
325
|
+
self._tasks.pop()
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class _PasswordResetTokenStore(
|
|
329
|
+
_Store[
|
|
330
|
+
PasswordResetTokenClaims,
|
|
331
|
+
PasswordResetToken,
|
|
332
|
+
PasswordResetTokenId,
|
|
333
|
+
models.PasswordResetToken,
|
|
334
|
+
]
|
|
335
|
+
):
|
|
336
|
+
_table = models.PasswordResetToken
|
|
337
|
+
_token_id = PasswordResetTokenId
|
|
338
|
+
_token = PasswordResetToken
|
|
339
|
+
|
|
340
|
+
def _from_db(
|
|
341
|
+
self,
|
|
342
|
+
record: models.PasswordResetToken,
|
|
343
|
+
user_role: UserRole,
|
|
344
|
+
) -> tuple[PasswordResetTokenId, PasswordResetTokenClaims]:
|
|
345
|
+
token_id = PasswordResetTokenId(record.id)
|
|
346
|
+
return token_id, PasswordResetTokenClaims(
|
|
347
|
+
token_id=token_id,
|
|
348
|
+
subject=UserId(record.user_id),
|
|
349
|
+
issued_at=record.created_at,
|
|
350
|
+
expiration_time=record.expires_at,
|
|
351
|
+
attributes=PasswordResetTokenAttributes(
|
|
352
|
+
user_role=user_role,
|
|
353
|
+
),
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
def _to_db(self, claim: PasswordResetTokenClaims) -> models.PasswordResetToken:
|
|
357
|
+
assert claim.expiration_time
|
|
358
|
+
assert claim.subject
|
|
359
|
+
user_id = int(claim.subject)
|
|
360
|
+
return models.PasswordResetToken(
|
|
361
|
+
user_id=user_id,
|
|
362
|
+
created_at=claim.issued_at,
|
|
363
|
+
expires_at=claim.expiration_time,
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
class _AccessTokenStore(
|
|
368
|
+
_Store[
|
|
369
|
+
AccessTokenClaims,
|
|
370
|
+
AccessToken,
|
|
371
|
+
AccessTokenId,
|
|
372
|
+
models.AccessToken,
|
|
373
|
+
]
|
|
374
|
+
):
|
|
375
|
+
_table = models.AccessToken
|
|
376
|
+
_token_id = AccessTokenId
|
|
377
|
+
_token = AccessToken
|
|
378
|
+
|
|
379
|
+
def _from_db(
|
|
380
|
+
self,
|
|
381
|
+
record: models.AccessToken,
|
|
382
|
+
user_role: UserRole,
|
|
383
|
+
) -> tuple[AccessTokenId, AccessTokenClaims]:
|
|
384
|
+
token_id = AccessTokenId(record.id)
|
|
385
|
+
refresh_token_id = RefreshTokenId(record.refresh_token_id)
|
|
386
|
+
return token_id, AccessTokenClaims(
|
|
387
|
+
token_id=token_id,
|
|
388
|
+
subject=UserId(record.user_id),
|
|
389
|
+
issued_at=record.created_at,
|
|
390
|
+
expiration_time=record.expires_at,
|
|
391
|
+
attributes=AccessTokenAttributes(
|
|
392
|
+
user_role=user_role,
|
|
393
|
+
refresh_token_id=refresh_token_id,
|
|
394
|
+
),
|
|
395
|
+
)
|
|
396
|
+
|
|
397
|
+
def _to_db(self, claim: AccessTokenClaims) -> models.AccessToken:
|
|
398
|
+
assert claim.expiration_time
|
|
399
|
+
assert claim.subject
|
|
400
|
+
user_id = int(claim.subject)
|
|
401
|
+
assert claim.attributes
|
|
402
|
+
refresh_token_id = int(claim.attributes.refresh_token_id)
|
|
403
|
+
return models.AccessToken(
|
|
404
|
+
user_id=user_id,
|
|
405
|
+
created_at=claim.issued_at,
|
|
406
|
+
expires_at=claim.expiration_time,
|
|
407
|
+
refresh_token_id=refresh_token_id,
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
class _RefreshTokenStore(
|
|
412
|
+
_Store[
|
|
413
|
+
RefreshTokenClaims,
|
|
414
|
+
RefreshToken,
|
|
415
|
+
RefreshTokenId,
|
|
416
|
+
models.RefreshToken,
|
|
417
|
+
]
|
|
418
|
+
):
|
|
419
|
+
_table = models.RefreshToken
|
|
420
|
+
_token_id = RefreshTokenId
|
|
421
|
+
_token = RefreshToken
|
|
422
|
+
|
|
423
|
+
def _from_db(
|
|
424
|
+
self,
|
|
425
|
+
record: models.RefreshToken,
|
|
426
|
+
user_role: UserRole,
|
|
427
|
+
) -> tuple[RefreshTokenId, RefreshTokenClaims]:
|
|
428
|
+
token_id = RefreshTokenId(record.id)
|
|
429
|
+
return token_id, RefreshTokenClaims(
|
|
430
|
+
token_id=token_id,
|
|
431
|
+
subject=UserId(record.user_id),
|
|
432
|
+
issued_at=record.created_at,
|
|
433
|
+
expiration_time=record.expires_at,
|
|
434
|
+
attributes=RefreshTokenAttributes(
|
|
435
|
+
user_role=user_role,
|
|
436
|
+
),
|
|
437
|
+
)
|
|
438
|
+
|
|
439
|
+
def _to_db(self, claims: RefreshTokenClaims) -> models.RefreshToken:
|
|
440
|
+
assert claims.expiration_time
|
|
441
|
+
assert claims.subject
|
|
442
|
+
user_id = int(claims.subject)
|
|
443
|
+
return models.RefreshToken(
|
|
444
|
+
user_id=user_id,
|
|
445
|
+
created_at=claims.issued_at,
|
|
446
|
+
expires_at=claims.expiration_time,
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
async def _update(self) -> None:
|
|
450
|
+
await super()._update()
|
|
451
|
+
if get_env_enable_prometheus():
|
|
452
|
+
from phoenix.server.prometheus import JWT_STORE_TOKENS_ACTIVE
|
|
453
|
+
|
|
454
|
+
JWT_STORE_TOKENS_ACTIVE.set(len(self._claims._cache))
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
class _ApiKeyStore(
|
|
458
|
+
_Store[
|
|
459
|
+
ApiKeyClaims,
|
|
460
|
+
ApiKey,
|
|
461
|
+
ApiKeyId,
|
|
462
|
+
models.ApiKey,
|
|
463
|
+
]
|
|
464
|
+
):
|
|
465
|
+
_table = models.ApiKey
|
|
466
|
+
_token_id = ApiKeyId
|
|
467
|
+
_token = ApiKey
|
|
468
|
+
|
|
469
|
+
def _from_db(
|
|
470
|
+
self,
|
|
471
|
+
record: models.ApiKey,
|
|
472
|
+
user_role: UserRole,
|
|
473
|
+
) -> tuple[ApiKeyId, ApiKeyClaims]:
|
|
474
|
+
token_id = ApiKeyId(record.id)
|
|
475
|
+
return token_id, ApiKeyClaims(
|
|
476
|
+
token_id=token_id,
|
|
477
|
+
subject=UserId(record.user_id),
|
|
478
|
+
issued_at=record.created_at,
|
|
479
|
+
expiration_time=record.expires_at,
|
|
480
|
+
attributes=ApiKeyAttributes(
|
|
481
|
+
user_role=user_role,
|
|
482
|
+
name=record.name,
|
|
483
|
+
description=record.description,
|
|
484
|
+
),
|
|
485
|
+
)
|
|
486
|
+
|
|
487
|
+
def _to_db(self, claims: ApiKeyClaims) -> models.ApiKey:
|
|
488
|
+
assert claims.attributes
|
|
489
|
+
assert claims.attributes.name
|
|
490
|
+
assert claims.subject
|
|
491
|
+
user_id = int(claims.subject)
|
|
492
|
+
return models.ApiKey(
|
|
493
|
+
user_id=user_id,
|
|
494
|
+
name=claims.attributes.name,
|
|
495
|
+
description=claims.attributes.description or None,
|
|
496
|
+
created_at=claims.issued_at,
|
|
497
|
+
expires_at=claims.expiration_time or None,
|
|
498
|
+
)
|
|
499
|
+
|
|
500
|
+
async def _update(self) -> None:
|
|
501
|
+
await super()._update()
|
|
502
|
+
if get_env_enable_prometheus():
|
|
503
|
+
from phoenix.server.prometheus import JWT_STORE_API_KEYS_ACTIVE
|
|
504
|
+
|
|
505
|
+
JWT_STORE_API_KEYS_ACTIVE.set(len(self._claims._cache))
|