arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from asyncio import sleep
|
|
5
|
+
from datetime import datetime, timedelta, timezone
|
|
6
|
+
from typing import Any, Mapping, Optional
|
|
7
|
+
|
|
8
|
+
import sqlalchemy as sa
|
|
9
|
+
from sqlalchemy.orm import joinedload
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.cost_tracking.cost_model_lookup import CostModelLookup
|
|
13
|
+
from phoenix.server.types import DaemonTask, DbSessionFactory
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GenerativeModelStore(DaemonTask):
    """A daemon that periodically fetches generative models and maintains an in-memory cache.

    This daemon periodically fetches generative models and their token prices from the
    database and maintains an in-memory cache for fast lookups. It uses an incremental
    fetch strategy to minimize database egress costs. Instead of fetching all models on
    every refresh, we track the last fetch time and only query for models that have
    changed since then (using updated_at/deleted_at).

    Rationale: Database egress is expensive in cloud environments (especially managed
    databases), and generative models change infrequently (mostly static reference data).
    The cost calculation daemon queries this store frequently (once per span), so trading
    memory for reduced database egress provides significant cost savings.

    Note:
        This strategy relies on GenerativeModel.updated_at being properly maintained. Any
        code that modifies GenerativeModel or TokenPrice records MUST ensure updated_at
        is explicitly set (see model_mutations.py). Relying solely on SQLAlchemy's
        onupdate=func.now() is insufficient because SQLAlchemy may skip the UPDATE if it
        detects no "real" changes to scalar fields (even if child records like TokenPrice
        are modified).
    """

    def __init__(
        self,
        db: DbSessionFactory,
        refresh_interval_seconds: int = 5,
    ) -> None:
        """
        Args:
            db: Factory used to open a database session for each refresh.
            refresh_interval_seconds: Pause between two consecutive refresh attempts.
        """
        super().__init__()
        self._db = db
        self._lookup = CostModelLookup()
        # Buffered start time of the last successful fetch; None until one succeeds.
        self._last_fetch_time: Optional[datetime] = None
        # Highest GenerativeModel.id seen so far; None until a fetch returns rows.
        self._last_fetch_id: Optional[int] = None
        self._refresh_interval_seconds = refresh_interval_seconds

    def find_model(
        self,
        start_time: datetime,
        attributes: Mapping[str, Any],
    ) -> Optional[models.GenerativeModel]:
        """Delegate to the in-memory lookup; returns None when it finds no model."""
        return self._lookup.find_model(start_time, attributes)

    async def _run(self) -> None:
        """Refresh the cache in a loop until the daemon is no longer running."""
        while self._running:
            # Capture time before query with 2-second buffer for clock skew tolerance
            fetch_start_time = datetime.now(timezone.utc) - timedelta(seconds=2)
            try:
                await self._fetch_models()
            except Exception:
                # Keep the previous _last_fetch_time so the failed window is retried.
                logger.exception("Failed to refresh generative models")
            else:
                self._last_fetch_time = fetch_start_time
            await sleep(self._refresh_interval_seconds)

    async def _fetch_models(self) -> None:
        """
        Fetch generative models from the database using an incremental strategy.

        On the first run, fetches all models. On subsequent runs, only fetches models
        where updated_at or deleted_at is at or after the last fetch time (with a 2-second
        buffer). Some models may be refetched, but .merge() handles duplicates idempotently.
        """
        stmt = sa.select(models.GenerativeModel).options(
            joinedload(models.GenerativeModel.token_prices)
        )
        if self._last_fetch_time is not None:
            # Incremental fetch: get models changed since last fetch.
            # Use >= for updated_at/deleted_at to catch models from the buffer window.
            changed_since_last_fetch = [
                models.GenerativeModel.updated_at >= self._last_fetch_time,
                models.GenerativeModel.deleted_at >= self._last_fetch_time,
            ]
            if self._last_fetch_id is not None:
                # Include id check as redundant safety check. It is skipped while no id
                # has been recorded, so the query never compares `id` against NULL.
                changed_since_last_fetch.append(
                    models.GenerativeModel.id > self._last_fetch_id
                )
            stmt = stmt.where(sa.or_(*changed_since_last_fetch))
        async with self._db() as session:
            generative_models = (await session.scalars(stmt)).unique().all()

        if not generative_models:
            return

        self._lookup.merge(generative_models)

        # Track max id for redundant safety check. The mark must never move backwards:
        # a batch holding only an older, recently-updated model would otherwise lower it
        # and make the next refresh refetch every model with a higher id.
        newest_id = max(model.id for model in generative_models)
        if self._last_fetch_id is None or newest_id > self._last_fetch_id:
            self._last_fetch_id = newest_id
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from asyncio import sleep
|
|
5
|
+
from collections import deque
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Mapping, NamedTuple, Optional
|
|
8
|
+
|
|
9
|
+
from typing_extensions import TypeAlias
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.cost_tracking.cost_details_calculator import SpanCostDetailsCalculator
|
|
13
|
+
from phoenix.server.daemons.generative_model_store import GenerativeModelStore
|
|
14
|
+
from phoenix.server.types import DaemonTask, DbSessionFactory
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
_GenerativeModelId: TypeAlias = int
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SpanCostCalculatorQueueItem(NamedTuple):
    """One unit of work queued for SpanCostCalculator: a span awaiting cost calculation."""

    # Copied onto the resulting SpanCost row as `span_rowid`.
    span_rowid: int
    # Copied onto the resulting SpanCost row as `trace_rowid`.
    trace_rowid: int
    # Span attributes handed to the cost calculation.
    attributes: Mapping[str, Any]
    # Span start time handed to the cost calculation and stored on the SpanCost row.
    span_start_time: datetime
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SpanCostCalculator(DaemonTask):
    """Daemon that drains a queue of spans, computes their cost, and inserts SpanCost rows.

    Items are added with `put_nowait` and processed in batches of at most
    `_max_items_per_transaction` per loop iteration. Items are removed from the queue
    before the insert is attempted, so a failed insert is logged and not retried.
    """

    _SLEEP_INTERVAL = 5  # seconds

    def __init__(
        self,
        db: DbSessionFactory,
        model_store: GenerativeModelStore,
    ) -> None:
        """
        Args:
            db: Factory used to open a database session for each batch insert.
            model_store: Source of the generative model (and token prices) for a span.
        """
        super().__init__()
        self._db = db
        self._model_store = model_store
        self._queue: deque[SpanCostCalculatorQueueItem] = deque()
        self._max_items_per_transaction = 1000

    async def _run(self) -> None:
        """Process one batch, sleep, and repeat until the daemon is no longer running."""
        while self._running:
            num_items_to_insert = min(self._max_items_per_transaction, len(self._queue))
            try:
                await self._insert_costs(num_items_to_insert)
            except Exception as e:
                logger.exception("Failed to insert costs: %s", e)
            await sleep(self._SLEEP_INTERVAL)

    async def _insert_costs(self, num_items_to_insert: int) -> None:
        """Pop up to `num_items_to_insert` queued items and insert their costs together.

        Items whose cost cannot be calculated, or that yield no cost, are dropped.
        """
        if not num_items_to_insert or not self._queue:
            return
        costs: list[models.SpanCost] = []
        # Never pop more than the queue holds, whatever count the caller passed in.
        for _ in range(min(num_items_to_insert, len(self._queue))):
            item = self._queue.popleft()
            try:
                cost = self.calculate_cost(item.span_start_time, item.attributes)
            except Exception as e:
                logger.exception(
                    "Failed to calculate cost for span %s: %s", item.span_rowid, e
                )
                continue
            if not cost:
                continue
            cost.span_rowid = item.span_rowid
            cost.trace_rowid = item.trace_rowid
            costs.append(cost)
        if not costs:
            # Nothing to persist; skip opening a session for an empty batch.
            return
        try:
            async with self._db() as session:
                session.add_all(costs)
        except Exception as e:
            logger.exception("Failed to insert costs: %s", e)

    def put_nowait(self, item: SpanCostCalculatorQueueItem) -> None:
        """Append `item` to the queue without blocking."""
        self._queue.append(item)

    def calculate_cost(
        self,
        start_time: datetime,
        attributes: Mapping[str, Any],
    ) -> Optional[models.SpanCost]:
        """Build an unsaved SpanCost for a span.

        Args:
            start_time: Span start time, passed to the model lookup and stored on the cost.
            attributes: Span attributes used for the model lookup and the cost details.

        Returns:
            A SpanCost carrying one detail per calculated entry, or None when
            `attributes` is empty or no details are produced. `span_rowid` and
            `trace_rowid` are not set here.
        """
        if not attributes:
            return None
        cost_model = self._model_store.find_model(
            start_time=start_time,
            attributes=attributes,
        )
        # With no matching model the calculator is given an empty price list.
        calculator = SpanCostDetailsCalculator(cost_model.token_prices if cost_model else [])
        details = calculator.calculate_details(attributes)
        if not details:
            return None

        cost = models.SpanCost(
            model_id=cost_model.id if cost_model else None,
            span_start_time=start_time,
        )
        for detail in details:
            cost.append_detail(detail)
        return cost
|
phoenix/server/dml_event.py
CHANGED
|
@@ -33,6 +33,10 @@ class ProjectDmlEvent(DmlEvent):
|
|
|
33
33
|
class ProjectDeleteEvent(ProjectDmlEvent): ...
|
|
34
34
|
|
|
35
35
|
|
|
36
|
+
@dataclass(frozen=True)
|
|
37
|
+
class ProjectInsertEvent(ProjectDmlEvent): ...
|
|
38
|
+
|
|
39
|
+
|
|
36
40
|
@dataclass(frozen=True)
|
|
37
41
|
class SpanDmlEvent(ProjectDmlEvent): ...
|
|
38
42
|
|
|
@@ -123,6 +127,19 @@ class TraceAnnotationInsertEvent(TraceAnnotationDmlEvent): ...
|
|
|
123
127
|
class TraceAnnotationDeleteEvent(TraceAnnotationDmlEvent): ...
|
|
124
128
|
|
|
125
129
|
|
|
130
|
+
@dataclass(frozen=True)
|
|
131
|
+
class ProjectSessionAnnotationDmlEvent(DmlEvent):
|
|
132
|
+
table = models.ProjectSessionAnnotation
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@dataclass(frozen=True)
|
|
136
|
+
class ProjectSessionAnnotationInsertEvent(ProjectSessionAnnotationDmlEvent): ...
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@dataclass(frozen=True)
|
|
140
|
+
class ProjectSessionAnnotationDeleteEvent(ProjectSessionAnnotationDmlEvent): ...
|
|
141
|
+
|
|
142
|
+
|
|
126
143
|
@dataclass(frozen=True)
|
|
127
144
|
class DocumentAnnotationDmlEvent(DmlEvent):
|
|
128
145
|
table = models.DocumentAnnotation
|
|
@@ -120,6 +120,7 @@ class _SpanDmlEventHandler(_DmlEventHandler[SpanDmlEvent]):
|
|
|
120
120
|
def _clear(cache: CacheForDataLoaders, project_id: int) -> None:
|
|
121
121
|
cache.latency_ms_quantile.invalidate(project_id)
|
|
122
122
|
cache.token_count.invalidate(project_id)
|
|
123
|
+
cache.token_cost.invalidate(project_id)
|
|
123
124
|
cache.record_count.invalidate(project_id)
|
|
124
125
|
cache.min_start_or_max_end_time.invalidate(project_id)
|
|
125
126
|
|
|
@@ -127,6 +128,10 @@ class _SpanDmlEventHandler(_DmlEventHandler[SpanDmlEvent]):
|
|
|
127
128
|
class _SpanDeleteEventHandler(_SpanDmlEventHandler):
|
|
128
129
|
@staticmethod
|
|
129
130
|
def _clear(cache: CacheForDataLoaders, project_id: int) -> None:
|
|
131
|
+
# Call parent's cache invalidation first (core span caches)
|
|
132
|
+
_SpanDmlEventHandler._clear(cache, project_id)
|
|
133
|
+
|
|
134
|
+
# Then invalidate annotation-specific caches
|
|
130
135
|
cache.annotation_summary.invalidate_project(project_id)
|
|
131
136
|
cache.document_evaluation_summary.invalidate_project(project_id)
|
|
132
137
|
|
phoenix/server/email/sender.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import smtplib
|
|
2
3
|
import ssl
|
|
3
4
|
from email.message import EmailMessage
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import Literal
|
|
6
|
-
from urllib.parse import urljoin
|
|
7
7
|
|
|
8
8
|
from anyio import to_thread
|
|
9
|
+
from email_validator import EmailNotValidError, validate_email
|
|
9
10
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
|
10
11
|
from typing_extensions import TypeAlias
|
|
11
12
|
|
|
12
|
-
from phoenix.config import get_env_root_url
|
|
13
|
+
from phoenix.config import get_env_root_url, get_env_support_email
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
13
16
|
|
|
14
17
|
EMAIL_TEMPLATE_FOLDER = Path(__file__).parent / "templates"
|
|
15
18
|
|
|
@@ -45,13 +48,20 @@ class SimpleEmailSender:
|
|
|
45
48
|
email: str,
|
|
46
49
|
name: str,
|
|
47
50
|
) -> None:
|
|
51
|
+
try:
|
|
52
|
+
email = validate_email(email, check_deliverability=False).normalized
|
|
53
|
+
except EmailNotValidError:
|
|
54
|
+
logger.warning("Skipping welcome email for user with invalid email address")
|
|
55
|
+
return
|
|
56
|
+
|
|
48
57
|
subject = "[Phoenix] Welcome to Arize Phoenix"
|
|
49
58
|
template_name = "welcome.html"
|
|
50
59
|
|
|
51
60
|
template = self.env.get_template(template_name)
|
|
61
|
+
|
|
52
62
|
html_content = template.render(
|
|
53
63
|
name=name,
|
|
54
|
-
welcome_url=
|
|
64
|
+
welcome_url=str(get_env_root_url()),
|
|
55
65
|
)
|
|
56
66
|
|
|
57
67
|
msg = EmailMessage()
|
|
@@ -67,6 +77,12 @@ class SimpleEmailSender:
|
|
|
67
77
|
email: str,
|
|
68
78
|
reset_url: str,
|
|
69
79
|
) -> None:
|
|
80
|
+
try:
|
|
81
|
+
email = validate_email(email, check_deliverability=False).normalized
|
|
82
|
+
except EmailNotValidError:
|
|
83
|
+
logger.warning("Skipping password reset email for user with invalid email address")
|
|
84
|
+
return
|
|
85
|
+
|
|
70
86
|
subject = "[Phoenix] Password Reset Request"
|
|
71
87
|
template_name = "password_reset.html"
|
|
72
88
|
|
|
@@ -81,6 +97,43 @@ class SimpleEmailSender:
|
|
|
81
97
|
|
|
82
98
|
await to_thread.run_sync(self._send_email, msg)
|
|
83
99
|
|
|
100
|
+
async def send_db_usage_warning_email(
|
|
101
|
+
self,
|
|
102
|
+
email: str,
|
|
103
|
+
current_usage_gibibytes: float,
|
|
104
|
+
allocated_storage_gibibytes: float,
|
|
105
|
+
notification_threshold_percentage: float,
|
|
106
|
+
) -> None:
|
|
107
|
+
try:
|
|
108
|
+
email = validate_email(email, check_deliverability=False).normalized
|
|
109
|
+
except EmailNotValidError:
|
|
110
|
+
logger.warning(
|
|
111
|
+
"Skipping database usage warning email for user with invalid email address"
|
|
112
|
+
)
|
|
113
|
+
return
|
|
114
|
+
|
|
115
|
+
subject = "[Phoenix] Database Disk Space Usage Threshold Exceeded"
|
|
116
|
+
template_name = "db_disk_usage_notification.html"
|
|
117
|
+
|
|
118
|
+
support_email = get_env_support_email()
|
|
119
|
+
template = self.env.get_template(template_name)
|
|
120
|
+
html_content = template.render(
|
|
121
|
+
current_usage_gibibytes=current_usage_gibibytes,
|
|
122
|
+
allocated_storage_gibibytes=allocated_storage_gibibytes,
|
|
123
|
+
notification_threshold_percentage=notification_threshold_percentage,
|
|
124
|
+
support_email=support_email,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
msg = EmailMessage()
|
|
128
|
+
msg["Subject"] = subject
|
|
129
|
+
msg["From"] = self.sender_email
|
|
130
|
+
msg["To"] = email
|
|
131
|
+
if support_email:
|
|
132
|
+
msg["Cc"] = support_email
|
|
133
|
+
msg.set_content(html_content, subtype="html")
|
|
134
|
+
|
|
135
|
+
await to_thread.run_sync(self._send_email, msg)
|
|
136
|
+
|
|
84
137
|
def _send_email(self, msg: EmailMessage) -> None:
|
|
85
138
|
context: ssl.SSLContext
|
|
86
139
|
if self.validate_certs:
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
<!DOCTYPE html>
|
|
2
|
+
<html>
|
|
3
|
+
<head>
|
|
4
|
+
<meta charset="UTF-8" />
|
|
5
|
+
<title>Database Usage Notification</title>
|
|
6
|
+
</head>
|
|
7
|
+
<body>
|
|
8
|
+
<h1>Database Usage Notification</h1>
|
|
9
|
+
<p>Your Phoenix database usage has exceeded the notification threshold.</p>
|
|
10
|
+
<p><strong>Current Usage:</strong> {{ current_usage_gibibytes|round(1) }} GiB</p>
|
|
11
|
+
<p><strong>Allocated Storage:</strong> {{ allocated_storage_gibibytes|round(1) }} GiB</p>
|
|
12
|
+
<p><strong>Usage Percentage:</strong> {{ ((current_usage_gibibytes / allocated_storage_gibibytes) * 100)|round(1) }}%</p>
|
|
13
|
+
<p><strong>Notification Threshold:</strong> {{ notification_threshold_percentage }}%</p>
|
|
14
|
+
<p>Please consider removing old data or increasing your storage allocation to prevent interruption.</p>
|
|
15
|
+
{% if support_email %}
|
|
16
|
+
<p>If you need assistance, please contact support at <a id="support-email" href="mailto:{{ support_email }}">{{ support_email }}</a>.</p>
|
|
17
|
+
{% endif %}
|
|
18
|
+
</body>
|
|
19
|
+
</html>
|
phoenix/server/email/types.py
CHANGED
|
@@ -19,8 +19,19 @@ class PasswordResetEmailSender(Protocol):
|
|
|
19
19
|
) -> None: ...
|
|
20
20
|
|
|
21
21
|
|
|
22
|
+
class DbUsageWarningEmailSender(Protocol):
    """Structural interface for objects that can email a database-usage warning."""

    async def send_db_usage_warning_email(
        self,
        email: str,
        current_usage_gibibytes: float,
        allocated_storage_gibibytes: float,
        notification_threshold_percentage: float,
    ) -> None:
        """Send a database-usage warning message to ``email``.

        Args:
            email: Address of the recipient.
            current_usage_gibibytes: Amount of database storage currently in use.
            allocated_storage_gibibytes: Amount of database storage allocated.
            notification_threshold_percentage: Usage percentage at which the
                notification is triggered.
        """
        ...
|
|
30
|
+
|
|
31
|
+
|
|
22
32
|
class EmailSender(
    WelcomeEmailSender,
    PasswordResetEmailSender,
    DbUsageWarningEmailSender,
    Protocol,
):
    """Composite protocol combining the welcome, password-reset and DB-usage-warning senders.

    It declares no members of its own; an implementation satisfies it by providing
    every method required by the three constituent protocols.
    """
|
|
File without changes
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from secrets import token_hex
|
|
5
|
+
|
|
6
|
+
# Shape of an auto-generated experiment project name: the literal prefix
# "Experiment-" followed by exactly 24 lowercase hexadecimal digits.
EXPERIMENT_PROJECT_NAME_PATTERN = re.compile(r"^Experiment-[0-9a-f]{24}$")


def generate_experiment_project_name() -> str:
    """Return a new random project name of the form ``Experiment-<24 hex digits>``.

    The suffix is ``secrets.token_hex(12)``: 12 random bytes rendered as 24
    lowercase hexadecimal digits, which is exactly what
    ``EXPERIMENT_PROJECT_NAME_PATTERN`` accepts.
    """
    return f"Experiment-{token_hex(12)}"


def is_experiment_project_name(name: str) -> bool:
    """Return whether ``name`` has the shape of a generated experiment project name.

    Args:
        name: Project name to check.

    Returns:
        ``True`` only when the entire string is ``Experiment-`` followed by
        exactly 24 lowercase hexadecimal digits.
    """
    # ``fullmatch`` rather than ``match``: with ``match`` the trailing ``$``
    # also succeeds just before a final newline, so "Experiment-<hex>\n" would
    # be accepted even though no generated name ever ends in a newline.
    return EXPERIMENT_PROJECT_NAME_PATTERN.fullmatch(name) is not None
|
phoenix/server/grpc_server.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
from
|
|
2
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Optional
|
|
3
2
|
|
|
4
3
|
import grpc
|
|
5
4
|
from grpc.aio import RpcContext, Server, ServerInterceptor
|
|
@@ -11,6 +10,7 @@ from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
|
|
|
11
10
|
TraceServiceServicer,
|
|
12
11
|
add_TraceServiceServicer_to_server,
|
|
13
12
|
)
|
|
13
|
+
from starlette.concurrency import run_in_threadpool
|
|
14
14
|
from typing_extensions import TypeAlias
|
|
15
15
|
|
|
16
16
|
from phoenix.auth import CanReadToken
|
|
@@ -34,10 +34,10 @@ ProjectName: TypeAlias = str
|
|
|
34
34
|
class Servicer(TraceServiceServicer): # type: ignore[misc,unused-ignore]
|
|
35
35
|
def __init__(
    self,
    enqueue_span: Callable[[Span, ProjectName], Awaitable[None]],
) -> None:
    """Initialise the servicer with the coroutine used to enqueue spans.

    Args:
        enqueue_span: Async callable taking a span and its project name; it is
            stored on the instance as ``_enqueue_span``.
    """
    super().__init__()
    self._enqueue_span = enqueue_span
|
|
41
41
|
|
|
42
42
|
async def Export(
|
|
43
43
|
self,
|
|
@@ -48,28 +48,28 @@ class Servicer(TraceServiceServicer): # type: ignore[misc,unused-ignore]
|
|
|
48
48
|
project_name = get_project_name(resource_spans.resource.attributes)
|
|
49
49
|
for scope_span in resource_spans.scope_spans:
|
|
50
50
|
for otlp_span in scope_span.spans:
|
|
51
|
-
span = decode_otlp_span
|
|
52
|
-
await self.
|
|
51
|
+
span = await run_in_threadpool(decode_otlp_span, otlp_span)
|
|
52
|
+
await self._enqueue_span(span, project_name)
|
|
53
53
|
return ExportTraceServiceResponse()
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
class GrpcServer:
|
|
57
57
|
def __init__(
    self,
    enqueue_span: Callable[[Span, ProjectName], Awaitable[None]],
    tracer_provider: Optional["TracerProvider"] = None,
    enable_prometheus: bool = False,
    disabled: bool = False,
    token_store: Optional[CanReadToken] = None,
    interceptors: Iterable[ServerInterceptor] = (),
) -> None:
    """Record the server configuration; no gRPC server is created here.

    Args:
        enqueue_span: Async callable taking a span and its project name.
        tracer_provider: Optional tracer provider, stored as given.
        enable_prometheus: Flag stored as ``_enable_prometheus``.
        disabled: Flag stored as ``_disabled``.
        token_store: Optional token reader, stored as given.
        interceptors: Server interceptors; copied into a new list so the
            instance owns its own collection.
    """
    # The server object is left unset at construction time.
    self._server: Optional[Server] = None

    self._enqueue_span = enqueue_span
    self._tracer_provider = tracer_provider
    self._token_store = token_store

    self._enable_prometheus = enable_prometheus
    self._disabled = disabled

    # Materialise the iterable once so a one-shot iterator argument is preserved.
    self._interceptors = list(interceptors)
|
|
73
73
|
|
|
74
74
|
async def __aenter__(self) -> None:
|
|
75
75
|
interceptors = self._interceptors
|
|
@@ -106,7 +106,7 @@ class GrpcServer:
|
|
|
106
106
|
server.add_secure_port(f"[::]:{get_env_grpc_port()}", server_credentials)
|
|
107
107
|
else:
|
|
108
108
|
server.add_insecure_port(f"[::]:{get_env_grpc_port()}")
|
|
109
|
-
add_TraceServiceServicer_to_server(Servicer(self.
|
|
109
|
+
add_TraceServiceServicer_to_server(Servicer(self._enqueue_span), server) # type: ignore[no-untyped-call,unused-ignore]
|
|
110
110
|
await server.start()
|
|
111
111
|
self._server = server
|
|
112
112
|
|
phoenix/server/jwt_store.py
CHANGED
|
@@ -20,7 +20,7 @@ from phoenix.auth import (
|
|
|
20
20
|
)
|
|
21
21
|
from phoenix.config import get_env_enable_prometheus
|
|
22
22
|
from phoenix.db import models
|
|
23
|
-
from phoenix.db.
|
|
23
|
+
from phoenix.db.models import UserRoleName
|
|
24
24
|
from phoenix.server.types import (
|
|
25
25
|
AccessToken,
|
|
26
26
|
AccessTokenAttributes,
|
|
@@ -164,7 +164,7 @@ class JwtStore:
|
|
|
164
164
|
for token_id in token_ids:
|
|
165
165
|
if isinstance(token_id, PasswordResetTokenId):
|
|
166
166
|
password_reset_token_ids.append(token_id)
|
|
167
|
-
|
|
167
|
+
elif isinstance(token_id, AccessTokenId):
|
|
168
168
|
access_token_ids.append(token_id)
|
|
169
169
|
elif isinstance(token_id, RefreshTokenId):
|
|
170
170
|
refresh_token_ids.append(token_id)
|
|
@@ -182,10 +182,10 @@ class JwtStore:
|
|
|
182
182
|
await gather(*coroutines)
|
|
183
183
|
|
|
184
184
|
async def log_out(self, user_id: UserId) -> None:
    """Delete all access and refresh tokens belonging to ``user_id``.

    Within one database session, rows of the access-token table and then the
    refresh-token table whose ``user_id`` matches are deleted. Each deleted
    row id is wrapped in the corresponding token-id class and passed to
    ``self._evict``.

    Args:
        user_id: Identifier of the user being logged out.
    """
    async with self._db() as session:
        for token_id_type in (AccessTokenId, RefreshTokenId):
            token_table = token_id_type.table
            deletion = (
                delete(token_table)
                .where(token_table.user_id == int(user_id))
                .returning(token_table.id)
            )
            # RETURNING yields the ids of the rows that were actually removed.
            deleted_ids = await session.stream_scalars(deletion)
            async for deleted_id in deleted_ids:
                await self._evict(token_id_type(deleted_id))
|
|
191
191
|
|
|
@@ -260,7 +260,7 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
260
260
|
if not record:
|
|
261
261
|
return None
|
|
262
262
|
token, role = record
|
|
263
|
-
_, claims = self._from_db(token,
|
|
263
|
+
_, claims = self._from_db(token, role)
|
|
264
264
|
self._claims[token_id] = claims
|
|
265
265
|
return claims
|
|
266
266
|
|
|
@@ -277,7 +277,7 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
277
277
|
await session.execute(stmt)
|
|
278
278
|
|
|
279
279
|
@abstractmethod
def _from_db(
    self,
    record: _RecordT,
    role: UserRoleName,
) -> tuple[_TokenIdT, _ClaimSetT]:
    """Convert a database record and the user's role name into ``(token_id, claims)``.

    Args:
        record: Row from this store's table.
        role: Role name selected alongside the record.

    Returns:
        A pair of the token id and the claim set built from the record.
    """
|
|
281
281
|
|
|
282
282
|
@abstractmethod
|
|
283
283
|
def _to_db(self, claims: _ClaimSetT) -> _RecordT: ...
|
|
@@ -300,12 +300,12 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
300
300
|
await self._delete_expired_tokens(session)
|
|
301
301
|
async with session.begin_nested():
|
|
302
302
|
async for record, role in await session.stream(self._update_stmt):
|
|
303
|
-
token_id, claim_set = self._from_db(record,
|
|
303
|
+
token_id, claim_set = self._from_db(record, role)
|
|
304
304
|
claims[token_id] = claim_set
|
|
305
305
|
self._claims = claims
|
|
306
306
|
|
|
307
307
|
@cached_property
|
|
308
|
-
def _update_stmt(self) -> Select[tuple[_RecordT,
|
|
308
|
+
def _update_stmt(self) -> Select[tuple[_RecordT, UserRoleName]]:
|
|
309
309
|
return (
|
|
310
310
|
select(self._table, models.UserRole.name)
|
|
311
311
|
.join_from(self._table, models.User)
|
|
@@ -314,7 +314,9 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
314
314
|
|
|
315
315
|
async def _delete_expired_tokens(self, session: Any) -> None:
    """Delete rows of this store's table that have expired as of the current UTC time.

    Args:
        session: Database session on which the DELETE statement is executed.
    """
    cutoff = datetime.now(timezone.utc)
    # Per JWT RFC 7519 Section 4.1.4, tokens expire "on or after" the expiration time.
    # Use <= to include tokens expiring at exactly this moment.
    purge_stmt = delete(self._table).where(self._table.expires_at <= cutoff)
    await session.execute(purge_stmt)
|
|
318
320
|
|
|
319
321
|
async def _run(self) -> None:
|
|
320
322
|
while self._running:
|
|
@@ -341,7 +343,7 @@ class _PasswordResetTokenStore(
|
|
|
341
343
|
def _from_db(
|
|
342
344
|
self,
|
|
343
345
|
record: models.PasswordResetToken,
|
|
344
|
-
user_role:
|
|
346
|
+
user_role: UserRoleName,
|
|
345
347
|
) -> tuple[PasswordResetTokenId, PasswordResetTokenClaims]:
|
|
346
348
|
token_id = PasswordResetTokenId(record.id)
|
|
347
349
|
return token_id, PasswordResetTokenClaims(
|
|
@@ -380,7 +382,7 @@ class _AccessTokenStore(
|
|
|
380
382
|
def _from_db(
|
|
381
383
|
self,
|
|
382
384
|
record: models.AccessToken,
|
|
383
|
-
user_role:
|
|
385
|
+
user_role: UserRoleName,
|
|
384
386
|
) -> tuple[AccessTokenId, AccessTokenClaims]:
|
|
385
387
|
token_id = AccessTokenId(record.id)
|
|
386
388
|
refresh_token_id = RefreshTokenId(record.refresh_token_id)
|
|
@@ -424,7 +426,7 @@ class _RefreshTokenStore(
|
|
|
424
426
|
def _from_db(
|
|
425
427
|
self,
|
|
426
428
|
record: models.RefreshToken,
|
|
427
|
-
user_role:
|
|
429
|
+
user_role: UserRoleName,
|
|
428
430
|
) -> tuple[RefreshTokenId, RefreshTokenClaims]:
|
|
429
431
|
token_id = RefreshTokenId(record.id)
|
|
430
432
|
return token_id, RefreshTokenClaims(
|
|
@@ -470,7 +472,7 @@ class _ApiKeyStore(
|
|
|
470
472
|
def _from_db(
|
|
471
473
|
self,
|
|
472
474
|
record: models.ApiKey,
|
|
473
|
-
user_role:
|
|
475
|
+
user_role: UserRoleName,
|
|
474
476
|
) -> tuple[ApiKeyId, ApiKeyClaims]:
|
|
475
477
|
token_id = ApiKeyId(record.id)
|
|
476
478
|
return token_id, ApiKeyClaims(
|