arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/db/facilitator.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import json
|
|
4
5
|
import logging
|
|
6
|
+
import re
|
|
5
7
|
import secrets
|
|
6
8
|
from asyncio import gather
|
|
9
|
+
from datetime import datetime, timedelta, timezone
|
|
7
10
|
from functools import partial
|
|
8
|
-
from
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import NamedTuple, Optional, Union
|
|
9
13
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
select,
|
|
15
|
-
)
|
|
14
|
+
import sqlalchemy as sa
|
|
15
|
+
from sqlalchemy import select
|
|
16
|
+
from sqlalchemy.orm import InstrumentedAttribute, joinedload
|
|
17
|
+
from sqlalchemy.sql.dml import ReturningDelete
|
|
16
18
|
|
|
17
19
|
from phoenix import config
|
|
18
20
|
from phoenix.auth import (
|
|
@@ -24,13 +26,16 @@ from phoenix.auth import (
|
|
|
24
26
|
compute_password_hash,
|
|
25
27
|
)
|
|
26
28
|
from phoenix.config import (
|
|
29
|
+
LDAPConfig,
|
|
27
30
|
get_env_admins,
|
|
28
31
|
get_env_default_admin_initial_password,
|
|
32
|
+
get_env_default_retention_policy_days,
|
|
29
33
|
get_env_disable_basic_auth,
|
|
34
|
+
get_env_oauth2_settings,
|
|
30
35
|
)
|
|
31
36
|
from phoenix.db import models
|
|
32
37
|
from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
|
|
33
|
-
from phoenix.db.enums import
|
|
38
|
+
from phoenix.db.enums import ENUM_COLUMNS
|
|
34
39
|
from phoenix.db.types.trace_retention import (
|
|
35
40
|
MaxDaysRule,
|
|
36
41
|
TraceRetentionCronExpression,
|
|
@@ -66,6 +71,8 @@ class Facilitator:
|
|
|
66
71
|
_get_system_user_id,
|
|
67
72
|
partial(_ensure_admins, email_sender=self._email_sender),
|
|
68
73
|
_ensure_default_project_trace_retention_policy,
|
|
74
|
+
_ensure_model_costs,
|
|
75
|
+
_delete_expired_childless_records,
|
|
69
76
|
):
|
|
70
77
|
await fn(self._db)
|
|
71
78
|
|
|
@@ -76,18 +83,17 @@ async def _ensure_enums(db: DbSessionFactory) -> None:
|
|
|
76
83
|
they will be added. If any values are present in the database but not in the enum, an error will
|
|
77
84
|
be raised. This function is idempotent: it will not add duplicate values to the database.
|
|
78
85
|
"""
|
|
79
|
-
for column
|
|
86
|
+
for column in ENUM_COLUMNS:
|
|
80
87
|
table = column.class_
|
|
88
|
+
assert isinstance(column.type, sa.Enum)
|
|
81
89
|
async with db() as session:
|
|
82
|
-
existing = set(
|
|
83
|
-
|
|
84
|
-
)
|
|
85
|
-
expected = set(e.value for e in enum)
|
|
90
|
+
existing = set(await session.scalars(sa.select(column)))
|
|
91
|
+
expected = set(column.type.enums)
|
|
86
92
|
if unexpected := existing - expected:
|
|
87
93
|
raise ValueError(f"Unexpected values in {table.name}.{column.key}: {unexpected}")
|
|
88
94
|
if not (missing := expected - existing):
|
|
89
95
|
continue
|
|
90
|
-
await session.execute(insert(table), [{column.key: v} for v in missing])
|
|
96
|
+
await session.execute(sa.insert(table), [{column.key: v} for v in missing])
|
|
91
97
|
|
|
92
98
|
|
|
93
99
|
async def _ensure_user_roles(db: DbSessionFactory) -> None:
|
|
@@ -97,21 +103,22 @@ async def _ensure_user_roles(db: DbSessionFactory) -> None:
|
|
|
97
103
|
the email "admin@localhost".
|
|
98
104
|
"""
|
|
99
105
|
async with db() as session:
|
|
100
|
-
role_ids = {
|
|
106
|
+
role_ids: dict[models.UserRoleName, int] = {
|
|
101
107
|
name: id_
|
|
102
108
|
async for name, id_ in await session.stream(
|
|
103
|
-
select(models.UserRole.name, models.UserRole.id)
|
|
109
|
+
sa.select(models.UserRole.name, models.UserRole.id)
|
|
104
110
|
)
|
|
105
111
|
}
|
|
106
|
-
existing_roles = [
|
|
112
|
+
existing_roles: list[models.UserRoleName] = [
|
|
107
113
|
name
|
|
108
114
|
async for name in await session.stream_scalars(
|
|
109
|
-
select(distinct(models.UserRole.name)).join_from(models.User, models.UserRole)
|
|
115
|
+
sa.select(sa.distinct(models.UserRole.name)).join_from(models.User, models.UserRole)
|
|
110
116
|
)
|
|
111
117
|
]
|
|
112
|
-
if (
|
|
113
|
-
|
|
114
|
-
|
|
118
|
+
if (
|
|
119
|
+
"SYSTEM" not in existing_roles
|
|
120
|
+
and (system_role_id := role_ids.get("SYSTEM")) is not None
|
|
121
|
+
):
|
|
115
122
|
system_user = models.LocalUser(
|
|
116
123
|
user_role_id=system_role_id,
|
|
117
124
|
username=DEFAULT_SYSTEM_USERNAME,
|
|
@@ -121,9 +128,7 @@ async def _ensure_user_roles(db: DbSessionFactory) -> None:
|
|
|
121
128
|
password_hash=secrets.token_bytes(DEFAULT_SECRET_LENGTH),
|
|
122
129
|
)
|
|
123
130
|
session.add(system_user)
|
|
124
|
-
if
|
|
125
|
-
admin_role_id := role_ids.get(admin_role)
|
|
126
|
-
) is not None:
|
|
131
|
+
if "ADMIN" not in existing_roles and (admin_role_id := role_ids.get("ADMIN")) is not None:
|
|
127
132
|
salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
|
|
128
133
|
password = get_env_default_admin_initial_password()
|
|
129
134
|
compute = partial(compute_password_hash, password=password, salt=salt)
|
|
@@ -147,9 +152,9 @@ async def _get_system_user_id(db: DbSessionFactory) -> None:
|
|
|
147
152
|
"""
|
|
148
153
|
async with db() as session:
|
|
149
154
|
system_user_id = await session.scalar(
|
|
150
|
-
select(models.User.id)
|
|
155
|
+
sa.select(models.User.id)
|
|
151
156
|
.join(models.UserRole)
|
|
152
|
-
.where(models.UserRole.name ==
|
|
157
|
+
.where(models.UserRole.name == "SYSTEM")
|
|
153
158
|
.order_by(models.User.id)
|
|
154
159
|
.limit(1)
|
|
155
160
|
)
|
|
@@ -173,7 +178,7 @@ async def _ensure_admins(
|
|
|
173
178
|
async with db() as session:
|
|
174
179
|
existing_emails = set(
|
|
175
180
|
await session.scalars(
|
|
176
|
-
select(models.User.email).where(models.User.email.in_(admins.keys()))
|
|
181
|
+
sa.select(models.User.email).where(models.User.email.in_(admins.keys()))
|
|
177
182
|
)
|
|
178
183
|
)
|
|
179
184
|
admins = {
|
|
@@ -183,7 +188,7 @@ async def _ensure_admins(
|
|
|
183
188
|
return
|
|
184
189
|
existing_usernames = set(
|
|
185
190
|
await session.scalars(
|
|
186
|
-
select(models.User.username).where(models.User.username.in_(admins.values()))
|
|
191
|
+
sa.select(models.User.username).where(models.User.username.in_(admins.values()))
|
|
187
192
|
)
|
|
188
193
|
)
|
|
189
194
|
admins = {
|
|
@@ -193,10 +198,22 @@ async def _ensure_admins(
|
|
|
193
198
|
}
|
|
194
199
|
if not admins:
|
|
195
200
|
return
|
|
196
|
-
admin_role_id = await session.scalar(
|
|
197
|
-
select(models.UserRole.id).filter_by(name=UserRole.ADMIN.value)
|
|
198
|
-
)
|
|
201
|
+
admin_role_id = await session.scalar(sa.select(models.UserRole.id).filter_by(name="ADMIN"))
|
|
199
202
|
assert admin_role_id is not None, "Admin role not found in database"
|
|
203
|
+
|
|
204
|
+
# Determine which auth method to use for admin users
|
|
205
|
+
# Priority: LOCAL (if enabled) > LDAP (if configured and no OAuth2) > OAuth2
|
|
206
|
+
# Use try/except to handle invalid configurations gracefully
|
|
207
|
+
try:
|
|
208
|
+
ldap_config = LDAPConfig.from_env()
|
|
209
|
+
except Exception:
|
|
210
|
+
ldap_config = None
|
|
211
|
+
try:
|
|
212
|
+
oauth2_configs = get_env_oauth2_settings()
|
|
213
|
+
except Exception:
|
|
214
|
+
oauth2_configs = []
|
|
215
|
+
use_ldap = disable_basic_auth and ldap_config is not None and not oauth2_configs
|
|
216
|
+
|
|
200
217
|
user: models.User
|
|
201
218
|
for email, username in admins.items():
|
|
202
219
|
if not disable_basic_auth:
|
|
@@ -206,6 +223,11 @@ async def _ensure_admins(
|
|
|
206
223
|
password_salt=secrets.token_bytes(DEFAULT_SECRET_LENGTH),
|
|
207
224
|
password_hash=secrets.token_bytes(DEFAULT_SECRET_LENGTH),
|
|
208
225
|
)
|
|
226
|
+
elif use_ldap:
|
|
227
|
+
user = models.LDAPUser(
|
|
228
|
+
email=email,
|
|
229
|
+
username=username,
|
|
230
|
+
)
|
|
209
231
|
else:
|
|
210
232
|
user = models.OAuth2User(
|
|
211
233
|
email=email,
|
|
@@ -224,6 +246,83 @@ async def _ensure_admins(
|
|
|
224
246
|
logger.error(f"Failed to send welcome email: {exc}")
|
|
225
247
|
|
|
226
248
|
|
|
249
|
+
_CHILDLESS_RECORD_DELETION_GRACE_PERIOD_DAYS = 1
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _stmt_to_delete_expired_childless_records(
|
|
253
|
+
table: type[models.HasId],
|
|
254
|
+
foreign_key: Union[InstrumentedAttribute[int], InstrumentedAttribute[Optional[int]]],
|
|
255
|
+
) -> ReturningDelete[tuple[int]]:
|
|
256
|
+
"""
|
|
257
|
+
Creates a SQLAlchemy DELETE statement to permanently remove childless records.
|
|
258
|
+
|
|
259
|
+
Args:
|
|
260
|
+
table: The table model class that has a deleted_at column
|
|
261
|
+
foreign_key: The foreign key attribute to check for child relationships
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
A DELETE statement that removes childless records marked for deletion more than
|
|
265
|
+
_CHILDLESS_RECORD_DELETION_GRACE_PERIOD_DAYS days ago
|
|
266
|
+
""" # noqa: E501
|
|
267
|
+
if not hasattr(table, "deleted_at"):
|
|
268
|
+
raise TypeError("Table must have a 'deleted_at' column")
|
|
269
|
+
cutoff_time = datetime.now(timezone.utc) - timedelta(
|
|
270
|
+
days=_CHILDLESS_RECORD_DELETION_GRACE_PERIOD_DAYS
|
|
271
|
+
)
|
|
272
|
+
return (
|
|
273
|
+
sa.delete(table)
|
|
274
|
+
.where(table.deleted_at.isnot(None))
|
|
275
|
+
.where(table.deleted_at < cutoff_time)
|
|
276
|
+
.where(~sa.exists().where(table.id == foreign_key))
|
|
277
|
+
.returning(table.id)
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
async def _delete_expired_childless_records_on_generative_models(
|
|
282
|
+
db: DbSessionFactory,
|
|
283
|
+
) -> None:
|
|
284
|
+
"""
|
|
285
|
+
Permanently deletes childless GenerativeModel records that have been marked for deletion.
|
|
286
|
+
|
|
287
|
+
This function removes GenerativeModel records that:
|
|
288
|
+
- Have been marked for deletion (deleted_at is not NULL)
|
|
289
|
+
- Were marked more than 1 day ago (grace period expired)
|
|
290
|
+
- Have no associated SpanCost records (childless)
|
|
291
|
+
|
|
292
|
+
This cleanup is necessary to remove orphaned records that may have been left behind
|
|
293
|
+
due to previous migrations or deletions.
|
|
294
|
+
""" # noqa: E501
|
|
295
|
+
stmt = _stmt_to_delete_expired_childless_records(
|
|
296
|
+
models.GenerativeModel,
|
|
297
|
+
models.SpanCost.model_id,
|
|
298
|
+
)
|
|
299
|
+
async with db() as session:
|
|
300
|
+
result = (await session.scalars(stmt)).all()
|
|
301
|
+
if result:
|
|
302
|
+
logger.info(f"Permanently deleted {len(result)} expired childless GenerativeModel records")
|
|
303
|
+
else:
|
|
304
|
+
logger.debug("No expired childless GenerativeModel records found for permanent deletion")
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
async def _delete_expired_childless_records(
|
|
308
|
+
db: DbSessionFactory,
|
|
309
|
+
) -> None:
|
|
310
|
+
"""
|
|
311
|
+
Permanently deletes childless records across all relevant tables.
|
|
312
|
+
|
|
313
|
+
This function runs the deletion process for all table types that support soft deletion,
|
|
314
|
+
handling any exceptions that occur during the process. Only records that have been
|
|
315
|
+
marked for deletion for more than the grace period (1 day) are permanently removed.
|
|
316
|
+
""" # noqa: E501
|
|
317
|
+
exceptions = await gather(
|
|
318
|
+
_delete_expired_childless_records_on_generative_models(db),
|
|
319
|
+
return_exceptions=True,
|
|
320
|
+
)
|
|
321
|
+
for exc in exceptions:
|
|
322
|
+
if isinstance(exc, Exception):
|
|
323
|
+
logger.error(f"Failed to delete childless records: {exc}")
|
|
324
|
+
|
|
325
|
+
|
|
227
326
|
async def _ensure_default_project_trace_retention_policy(db: DbSessionFactory) -> None:
|
|
228
327
|
"""
|
|
229
328
|
Ensures the default trace retention policy (id=1) exists in the database. Default policy
|
|
@@ -248,8 +347,8 @@ async def _ensure_default_project_trace_retention_policy(db: DbSessionFactory) -
|
|
|
248
347
|
assert DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID == 0
|
|
249
348
|
async with db() as session:
|
|
250
349
|
if await session.scalar(
|
|
251
|
-
select(
|
|
252
|
-
exists().where(
|
|
350
|
+
sa.select(
|
|
351
|
+
sa.exists().where(
|
|
253
352
|
models.ProjectTraceRetentionPolicy.id
|
|
254
353
|
== DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
|
|
255
354
|
)
|
|
@@ -257,9 +356,11 @@ async def _ensure_default_project_trace_retention_policy(db: DbSessionFactory) -
|
|
|
257
356
|
):
|
|
258
357
|
return
|
|
259
358
|
cron_expression = TraceRetentionCronExpression(root="0 0 * * 0")
|
|
260
|
-
rule = TraceRetentionRule(
|
|
359
|
+
rule = TraceRetentionRule(
|
|
360
|
+
root=MaxDaysRule(max_days=get_env_default_retention_policy_days())
|
|
361
|
+
)
|
|
261
362
|
await session.execute(
|
|
262
|
-
insert(models.ProjectTraceRetentionPolicy),
|
|
363
|
+
sa.insert(models.ProjectTraceRetentionPolicy),
|
|
263
364
|
[
|
|
264
365
|
{
|
|
265
366
|
"id": DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID,
|
|
@@ -269,3 +370,169 @@ async def _ensure_default_project_trace_retention_policy(db: DbSessionFactory) -
|
|
|
269
370
|
}
|
|
270
371
|
],
|
|
271
372
|
)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
_COST_MODEL_MANIFEST: Path = (
|
|
376
|
+
Path(__file__).parent.parent / "server" / "cost_tracking" / "model_cost_manifest.json"
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class _TokenTypeKey(NamedTuple):
|
|
381
|
+
"""
|
|
382
|
+
Composite key for uniquely identifying token price configurations.
|
|
383
|
+
|
|
384
|
+
Token prices are differentiated by both their type (e.g., "input", "output", "audio")
|
|
385
|
+
and whether they represent prompt tokens (input to the model) or completion tokens
|
|
386
|
+
(output from the model). Some token types like "audio" can exist in both categories.
|
|
387
|
+
|
|
388
|
+
Attributes:
|
|
389
|
+
token_type: The category of token (e.g., "input", "output", "audio", "cache_write")
|
|
390
|
+
is_prompt: True if these are prompt/input tokens, False if completion/output tokens
|
|
391
|
+
"""
|
|
392
|
+
|
|
393
|
+
token_type: str
|
|
394
|
+
is_prompt: bool
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
async def _ensure_model_costs(db: DbSessionFactory) -> None:
    """
    Synchronize built-in generative models and their token prices with the cost manifest.

    The manifest file at ``_COST_MODEL_MANIFEST`` is treated as the source of truth.
    Synchronization runs in three steps:

    1. **Model management**: each valid manifest entry either updates the existing
       non-deleted built-in model with the same (whitespace-stripped) name, by
       refreshing its ``provider`` and ``name_pattern``, or creates a new built-in model.
    2. **Token price synchronization**: a model's token prices are keyed by
       ``(token_type, is_prompt)``, so the same token type (e.g. audio) can be priced
       separately as a prompt and a non-prompt token. Missing prices are created, prices
       whose ``base_rate`` changed are updated, and prices not listed in the manifest are
       removed from the model's ``token_prices`` collection. Manifest price entries with
       a falsy ``base_rate`` (missing, ``None`` or ``0``) are skipped, i.e. treated as
       absent.
    3. **Cleanup**: built-in models not matched by any manifest entry are soft-deleted
       (``deleted_at`` set to the database's current time), but only when no
       ``SpanCost`` row references them.

    Manifest entries are skipped with a warning when:

    - their name is empty;
    - their name duplicates an earlier entry;
    - their ``name_pattern`` is not a valid regular expression; or
    - their ``(name_pattern, provider)`` pair duplicates an earlier entry.

    An existing built-in model whose manifest entry is skipped is therefore treated as
    obsolete by step 3.

    Args:
        db (DbSessionFactory): Factory producing async database sessions.

    Returns:
        None

    Raises:
        FileNotFoundError: If the model cost manifest file does not exist.
        json.JSONDecodeError: If the manifest file contains invalid JSON.
        KeyError: If the manifest lacks a required field (``models``, or per entry
            ``token_prices``; or per price ``token_type``/``is_prompt`` when the price
            has a truthy ``base_rate``).
    """
    # Load the authoritative model cost data from the manifest file. JSON is UTF-8 by
    # specification, so do not depend on the platform's default text encoding.
    with open(_COST_MODEL_MANIFEST, encoding="utf-8") as f:
        manifest = json.load(f)

    async with db() as session:
        # Fetch all existing non-deleted built-in models with their token prices eagerly
        # loaded. .unique() is required because the joined eager load of a collection
        # yields one row per (model, token price) pair.
        built_in_models = {
            omodel.name: omodel
            for omodel in (
                await session.scalars(
                    select(models.GenerativeModel)
                    .where(models.GenerativeModel.deleted_at.is_(None))
                    .where(models.GenerativeModel.is_built_in.is_(True))
                    .options(joinedload(models.GenerativeModel.token_prices))
                )
            ).unique()
        }

        seen_names: set[str] = set()
        seen_patterns: set[tuple[re.Pattern[str], str]] = set()

        for model_data in manifest["models"]:
            name = str(model_data.get("name") or "").strip()
            if not name:
                logger.warning("Skipping model with empty name in manifest")
                continue
            if name in seen_names:
                logger.warning("Skipping model '%s' with duplicate name in manifest", name)
                continue
            seen_names.add(name)

            regex = str(model_data.get("name_pattern") or "").strip()
            try:
                pattern = re.compile(regex)
            except re.error as e:
                logger.warning("Skipping model '%s' with invalid regex: %s", name, e)
                continue

            provider = str(model_data.get("provider") or "").strip()
            if (pattern, provider) in seen_patterns:
                logger.warning(
                    "Skipping model '%s' with duplicate name_pattern/provider combination",
                    name,
                )
                continue
            seen_patterns.add((pattern, provider))

            # Claim the existing model, or create a new one if none is found. Popping it
            # marks it as still present in the manifest, so the cleanup below leaves it
            # alone. Look it up by the normalized ``name``: models are stored under the
            # stripped name, so the raw manifest value would miss them whenever it carries
            # surrounding whitespace. That would create a duplicate model and soft-delete
            # the original.
            model = built_in_models.pop(name, None)
            if model is None:
                # Create new built-in model from manifest data
                model = models.GenerativeModel(
                    name=name,
                    provider=provider,
                    name_pattern=pattern,
                    is_built_in=True,
                )
                session.add(model)
            else:
                # Update existing model's metadata from manifest
                model.provider = provider
                model.name_pattern = pattern

            # Index existing token prices by (token_type, is_prompt). Entries are popped
            # as the manifest is walked, so whatever remains afterwards is stale.
            existing_token_prices = {
                _TokenTypeKey(token_price.token_type, token_price.is_prompt): token_price
                for token_price in model.token_prices
            }

            for manifest_token_price in model_data["token_prices"]:
                # Skip price entries with no (or a zero) rate in the manifest
                if not (base_rate := manifest_token_price.get("base_rate")):
                    continue

                token_type = manifest_token_price["token_type"]
                is_prompt = manifest_token_price["is_prompt"]
                key = _TokenTypeKey(token_type, is_prompt)
                if not (token_price := existing_token_prices.pop(key, None)):
                    # Create new token price if it doesn't exist
                    token_price = models.TokenPrice(
                        token_type=token_type,
                        is_prompt=is_prompt,
                        base_rate=base_rate,
                    )
                    model.token_prices.append(token_price)
                elif token_price.base_rate != base_rate:
                    # Update existing price if rate has changed
                    token_price.base_rate = base_rate

            # Remove token prices that are no longer listed in the manifest, i.e. the
            # entries that were not popped from existing_token_prices above.
            for token_price in existing_token_prices.values():
                model.token_prices.remove(token_price)

    # Built-in models never popped from built_in_models above are no longer in the
    # manifest (or their manifest entry was skipped).
    remaining_models = list(built_in_models.values())
    if not remaining_models:
        return

    # Soft delete obsolete built-in models, restricted by the NOT EXISTS clause to
    # models that have no SpanCost rows referencing them.
    async with db() as session:
        await session.execute(
            sa.update(models.GenerativeModel)
            .values(deleted_at=sa.func.now())
            .where(models.GenerativeModel.id.in_([m.id for m in remaining_models]))
            .where(~sa.exists().where(models.GenerativeModel.id == models.SpanCost.model_id))
        )
|