arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
from pydantic import BaseModel, ConfigDict
|
|
4
5
|
|
|
6
|
+
from phoenix.db.types.db_models import UNDEFINED
|
|
7
|
+
|
|
5
8
|
|
|
6
9
|
def datetime_encoder(dt: datetime) -> str:
|
|
7
10
|
"""
|
|
@@ -43,3 +46,7 @@ class V1RoutesBaseModel(BaseModel):
|
|
|
43
46
|
[]
|
|
44
47
|
), # suppress warnings about protected namespaces starting with `model_` on pydantic 2.9
|
|
45
48
|
)
|
|
49
|
+
|
|
50
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
51
|
+
kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
|
|
52
|
+
super().__init__(*args, **kwargs)
|
|
@@ -1,20 +1,13 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
|
-
from fastapi import APIRouter, HTTPException, Path, Query
|
|
3
|
+
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
|
4
4
|
from pydantic import Field
|
|
5
5
|
from sqlalchemy import select
|
|
6
6
|
from starlette.requests import Request
|
|
7
|
-
from starlette.status import (
|
|
8
|
-
HTTP_204_NO_CONTENT,
|
|
9
|
-
HTTP_403_FORBIDDEN,
|
|
10
|
-
HTTP_404_NOT_FOUND,
|
|
11
|
-
HTTP_422_UNPROCESSABLE_ENTITY,
|
|
12
|
-
)
|
|
13
7
|
from strawberry.relay import GlobalID
|
|
14
8
|
|
|
15
9
|
from phoenix.config import DEFAULT_PROJECT_NAME
|
|
16
10
|
from phoenix.db import models
|
|
17
|
-
from phoenix.db.enums import UserRole
|
|
18
11
|
from phoenix.db.helpers import exclude_experiment_projects
|
|
19
12
|
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
20
13
|
from phoenix.server.api.routers.v1.utils import (
|
|
@@ -24,6 +17,7 @@ from phoenix.server.api.routers.v1.utils import (
|
|
|
24
17
|
add_errors_to_responses,
|
|
25
18
|
)
|
|
26
19
|
from phoenix.server.api.types.Project import Project as ProjectNodeType
|
|
20
|
+
from phoenix.server.authorization import is_not_locked, require_admin
|
|
27
21
|
|
|
28
22
|
router = APIRouter(tags=["projects"])
|
|
29
23
|
|
|
@@ -69,7 +63,7 @@ class UpdateProjectResponseBody(ResponseBody[Project]):
|
|
|
69
63
|
response_description="A list of projects with pagination information", # noqa: E501
|
|
70
64
|
responses=add_errors_to_responses(
|
|
71
65
|
[
|
|
72
|
-
|
|
66
|
+
422,
|
|
73
67
|
]
|
|
74
68
|
),
|
|
75
69
|
)
|
|
@@ -114,7 +108,7 @@ async def get_projects(
|
|
|
114
108
|
except ValueError:
|
|
115
109
|
raise HTTPException(
|
|
116
110
|
detail=f"Invalid cursor format: {cursor}",
|
|
117
|
-
status_code=
|
|
111
|
+
status_code=422,
|
|
118
112
|
)
|
|
119
113
|
|
|
120
114
|
stmt = stmt.limit(limit + 1)
|
|
@@ -141,8 +135,8 @@ async def get_projects(
|
|
|
141
135
|
response_description="The requested project", # noqa: E501
|
|
142
136
|
responses=add_errors_to_responses(
|
|
143
137
|
[
|
|
144
|
-
|
|
145
|
-
|
|
138
|
+
404,
|
|
139
|
+
422,
|
|
146
140
|
]
|
|
147
141
|
),
|
|
148
142
|
)
|
|
@@ -174,13 +168,14 @@ async def get_project(
|
|
|
174
168
|
|
|
175
169
|
@router.post(
|
|
176
170
|
"/projects",
|
|
171
|
+
dependencies=[Depends(is_not_locked)],
|
|
177
172
|
operation_id="createProject",
|
|
178
173
|
summary="Create a new project", # noqa: E501
|
|
179
174
|
description="Create a new project with the specified configuration.", # noqa: E501
|
|
180
175
|
response_description="The newly created project", # noqa: E501
|
|
181
176
|
responses=add_errors_to_responses(
|
|
182
177
|
[
|
|
183
|
-
|
|
178
|
+
422,
|
|
184
179
|
]
|
|
185
180
|
),
|
|
186
181
|
)
|
|
@@ -214,15 +209,16 @@ async def create_project(
|
|
|
214
209
|
|
|
215
210
|
@router.put(
|
|
216
211
|
"/projects/{project_identifier}",
|
|
212
|
+
dependencies=[Depends(require_admin), Depends(is_not_locked)],
|
|
217
213
|
operation_id="updateProject",
|
|
218
214
|
summary="Update a project by ID or name", # noqa: E501
|
|
219
215
|
description="Update an existing project with new configuration. Project names cannot be changed. The project identifier is either project ID or project name. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
|
|
220
216
|
response_description="The updated project", # noqa: E501
|
|
221
217
|
responses=add_errors_to_responses(
|
|
222
218
|
[
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
219
|
+
403,
|
|
220
|
+
404,
|
|
221
|
+
422,
|
|
226
222
|
]
|
|
227
223
|
),
|
|
228
224
|
)
|
|
@@ -248,20 +244,6 @@ async def update_project(
|
|
|
248
244
|
Raises:
|
|
249
245
|
HTTPException: If the project identifier format is invalid or the project is not found.
|
|
250
246
|
""" # noqa: E501
|
|
251
|
-
if request.app.state.authentication_enabled:
|
|
252
|
-
async with request.app.state.db() as session:
|
|
253
|
-
# Check if the user is an admin
|
|
254
|
-
stmt = (
|
|
255
|
-
select(models.UserRole.name)
|
|
256
|
-
.join(models.User)
|
|
257
|
-
.where(models.User.id == int(request.user.identity))
|
|
258
|
-
)
|
|
259
|
-
role_name = await session.scalar(stmt)
|
|
260
|
-
if role_name != UserRole.ADMIN.value:
|
|
261
|
-
raise HTTPException(
|
|
262
|
-
status_code=HTTP_403_FORBIDDEN,
|
|
263
|
-
detail="Only admins can update projects",
|
|
264
|
-
)
|
|
265
247
|
async with request.app.state.db() as session:
|
|
266
248
|
project = await _get_project_by_identifier(session, project_identifier)
|
|
267
249
|
|
|
@@ -275,16 +257,17 @@ async def update_project(
|
|
|
275
257
|
|
|
276
258
|
@router.delete(
|
|
277
259
|
"/projects/{project_identifier}",
|
|
260
|
+
dependencies=[Depends(require_admin)],
|
|
278
261
|
operation_id="deleteProject",
|
|
279
262
|
summary="Delete a project by ID or name", # noqa: E501
|
|
280
263
|
description="Delete an existing project and all its associated data. The project identifier is either project ID or project name. The default project cannot be deleted. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
|
|
281
264
|
response_description="No content returned on successful deletion", # noqa: E501
|
|
282
|
-
status_code=
|
|
265
|
+
status_code=204,
|
|
283
266
|
responses=add_errors_to_responses(
|
|
284
267
|
[
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
268
|
+
403,
|
|
269
|
+
404,
|
|
270
|
+
422,
|
|
288
271
|
]
|
|
289
272
|
),
|
|
290
273
|
)
|
|
@@ -308,27 +291,13 @@ async def delete_project(
|
|
|
308
291
|
Raises:
|
|
309
292
|
HTTPException: If the project identifier format is invalid, the project is not found, or it's the default project.
|
|
310
293
|
""" # noqa: E501
|
|
311
|
-
if request.app.state.authentication_enabled:
|
|
312
|
-
async with request.app.state.db() as session:
|
|
313
|
-
# Check if the user is an admin
|
|
314
|
-
stmt = (
|
|
315
|
-
select(models.UserRole.name)
|
|
316
|
-
.join(models.User)
|
|
317
|
-
.where(models.User.id == int(request.user.identity))
|
|
318
|
-
)
|
|
319
|
-
role_name = await session.scalar(stmt)
|
|
320
|
-
if role_name != UserRole.ADMIN.value:
|
|
321
|
-
raise HTTPException(
|
|
322
|
-
status_code=HTTP_403_FORBIDDEN,
|
|
323
|
-
detail="Only admins can delete projects",
|
|
324
|
-
)
|
|
325
294
|
async with request.app.state.db() as session:
|
|
326
295
|
project = await _get_project_by_identifier(session, project_identifier)
|
|
327
296
|
|
|
328
297
|
# The default project must not be deleted - it's forbidden
|
|
329
298
|
if project.name == DEFAULT_PROJECT_NAME:
|
|
330
299
|
raise HTTPException(
|
|
331
|
-
status_code=
|
|
300
|
+
status_code=403,
|
|
332
301
|
detail="The default project cannot be deleted",
|
|
333
302
|
)
|
|
334
303
|
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from typing import Any, Optional, Union
|
|
3
3
|
|
|
4
|
-
from fastapi import APIRouter, HTTPException, Path, Query
|
|
4
|
+
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
|
5
5
|
from pydantic import ValidationError, model_validator
|
|
6
6
|
from sqlalchemy import select
|
|
7
|
+
from sqlalchemy.orm import joinedload
|
|
7
8
|
from sqlalchemy.sql import Select
|
|
8
9
|
from starlette.requests import Request
|
|
9
|
-
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
|
|
10
10
|
from strawberry.relay import GlobalID
|
|
11
11
|
from typing_extensions import Self, TypeAlias, assert_never
|
|
12
12
|
|
|
@@ -33,6 +33,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
|
33
33
|
from phoenix.server.api.types.Prompt import Prompt as PromptNodeType
|
|
34
34
|
from phoenix.server.api.types.PromptVersion import PromptVersion as PromptVersionNodeType
|
|
35
35
|
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag as PromptVersionTagNodeType
|
|
36
|
+
from phoenix.server.authorization import is_not_locked
|
|
36
37
|
from phoenix.server.bearer_auth import PhoenixUser
|
|
37
38
|
|
|
38
39
|
logger = logging.getLogger(__name__)
|
|
@@ -42,6 +43,7 @@ class PromptData(V1RoutesBaseModel):
|
|
|
42
43
|
name: Identifier
|
|
43
44
|
description: Optional[str] = None
|
|
44
45
|
source_prompt_id: Optional[str] = None
|
|
46
|
+
metadata: Optional[dict[str, Any]] = None
|
|
45
47
|
|
|
46
48
|
|
|
47
49
|
class Prompt(PromptData):
|
|
@@ -109,7 +111,7 @@ router = APIRouter(tags=["prompts"])
|
|
|
109
111
|
response_description="A list of prompts with pagination information",
|
|
110
112
|
responses=add_errors_to_responses(
|
|
111
113
|
[
|
|
112
|
-
|
|
114
|
+
422,
|
|
113
115
|
]
|
|
114
116
|
),
|
|
115
117
|
)
|
|
@@ -153,7 +155,7 @@ async def get_prompts(
|
|
|
153
155
|
except ValueError:
|
|
154
156
|
raise HTTPException(
|
|
155
157
|
detail=f"Invalid cursor format: {cursor}",
|
|
156
|
-
status_code=
|
|
158
|
+
status_code=422,
|
|
157
159
|
)
|
|
158
160
|
|
|
159
161
|
query = query.limit(limit + 1)
|
|
@@ -180,7 +182,7 @@ async def get_prompts(
|
|
|
180
182
|
description="Retrieve all versions of a specific prompt with pagination support. Each prompt "
|
|
181
183
|
"can have multiple versions with different configurations.",
|
|
182
184
|
response_description="A list of prompt versions with pagination information",
|
|
183
|
-
responses=add_errors_to_responses([
|
|
185
|
+
responses=add_errors_to_responses([422, 404]),
|
|
184
186
|
response_model_by_alias=True,
|
|
185
187
|
response_model_exclude_defaults=True,
|
|
186
188
|
response_model_exclude_unset=True,
|
|
@@ -213,7 +215,7 @@ async def list_prompt_versions(
|
|
|
213
215
|
HTTPException: If the cursor format is invalid, the prompt identifier is invalid,
|
|
214
216
|
or the prompt is not found.
|
|
215
217
|
"""
|
|
216
|
-
query = select(models.PromptVersion)
|
|
218
|
+
query = select(models.PromptVersion).options(joinedload(models.PromptVersion.prompt))
|
|
217
219
|
query = _filter_by_prompt_identifier(query.join(models.Prompt), prompt_identifier)
|
|
218
220
|
query = query.order_by(models.PromptVersion.id.desc())
|
|
219
221
|
|
|
@@ -225,7 +227,7 @@ async def list_prompt_versions(
|
|
|
225
227
|
except ValueError:
|
|
226
228
|
raise HTTPException(
|
|
227
229
|
detail=f"Invalid cursor format: {cursor}",
|
|
228
|
-
status_code=
|
|
230
|
+
status_code=422,
|
|
229
231
|
)
|
|
230
232
|
|
|
231
233
|
query = query.limit(limit + 1)
|
|
@@ -254,8 +256,8 @@ async def list_prompt_versions(
|
|
|
254
256
|
response_description="The requested prompt version",
|
|
255
257
|
responses=add_errors_to_responses(
|
|
256
258
|
[
|
|
257
|
-
|
|
258
|
-
|
|
259
|
+
404,
|
|
260
|
+
422,
|
|
259
261
|
]
|
|
260
262
|
),
|
|
261
263
|
response_model_by_alias=True,
|
|
@@ -285,11 +287,16 @@ async def get_prompt_version_by_prompt_version_id(
|
|
|
285
287
|
PromptVersionNodeType.__name__,
|
|
286
288
|
)
|
|
287
289
|
except ValueError:
|
|
288
|
-
raise HTTPException(
|
|
290
|
+
raise HTTPException(422, "Invalid prompt version ID")
|
|
289
291
|
async with request.app.state.db() as session:
|
|
290
|
-
|
|
292
|
+
stmt = (
|
|
293
|
+
select(models.PromptVersion)
|
|
294
|
+
.options(joinedload(models.PromptVersion.prompt))
|
|
295
|
+
.where(models.PromptVersion.id == id_)
|
|
296
|
+
)
|
|
297
|
+
prompt_version = await session.scalar(stmt)
|
|
291
298
|
if prompt_version is None:
|
|
292
|
-
raise HTTPException(
|
|
299
|
+
raise HTTPException(404)
|
|
293
300
|
data = _prompt_version_from_orm_version(prompt_version)
|
|
294
301
|
return GetPromptResponseBody(data=data)
|
|
295
302
|
|
|
@@ -303,8 +310,8 @@ async def get_prompt_version_by_prompt_version_id(
|
|
|
303
310
|
response_description="The prompt version with the specified tag",
|
|
304
311
|
responses=add_errors_to_responses(
|
|
305
312
|
[
|
|
306
|
-
|
|
307
|
-
|
|
313
|
+
404,
|
|
314
|
+
422,
|
|
308
315
|
]
|
|
309
316
|
),
|
|
310
317
|
response_model_by_alias=True,
|
|
@@ -333,9 +340,10 @@ async def get_prompt_version_by_tag_name(
|
|
|
333
340
|
try:
|
|
334
341
|
name = Identifier.model_validate(tag_name)
|
|
335
342
|
except ValidationError:
|
|
336
|
-
raise HTTPException(
|
|
343
|
+
raise HTTPException(422, "Invalid tag name")
|
|
337
344
|
stmt = (
|
|
338
345
|
select(models.PromptVersion)
|
|
346
|
+
.options(joinedload(models.PromptVersion.prompt))
|
|
339
347
|
.join_from(models.PromptVersion, models.PromptVersionTag)
|
|
340
348
|
.where(models.PromptVersionTag.name == name)
|
|
341
349
|
)
|
|
@@ -343,7 +351,7 @@ async def get_prompt_version_by_tag_name(
|
|
|
343
351
|
async with request.app.state.db() as session:
|
|
344
352
|
prompt_version: models.PromptVersion = await session.scalar(stmt)
|
|
345
353
|
if prompt_version is None:
|
|
346
|
-
raise HTTPException(
|
|
354
|
+
raise HTTPException(404)
|
|
347
355
|
data = _prompt_version_from_orm_version(prompt_version)
|
|
348
356
|
return GetPromptResponseBody(data=data)
|
|
349
357
|
|
|
@@ -356,8 +364,8 @@ async def get_prompt_version_by_tag_name(
|
|
|
356
364
|
response_description="The latest version of the specified prompt",
|
|
357
365
|
responses=add_errors_to_responses(
|
|
358
366
|
[
|
|
359
|
-
|
|
360
|
-
|
|
367
|
+
404,
|
|
368
|
+
422,
|
|
361
369
|
]
|
|
362
370
|
),
|
|
363
371
|
response_model_by_alias=True,
|
|
@@ -381,25 +389,31 @@ async def get_prompt_version_by_latest(
|
|
|
381
389
|
Raises:
|
|
382
390
|
HTTPException: If the prompt identifier is invalid or no prompt version is found.
|
|
383
391
|
"""
|
|
384
|
-
stmt =
|
|
392
|
+
stmt = (
|
|
393
|
+
select(models.PromptVersion)
|
|
394
|
+
.options(joinedload(models.PromptVersion.prompt))
|
|
395
|
+
.order_by(models.PromptVersion.id.desc())
|
|
396
|
+
.limit(1)
|
|
397
|
+
)
|
|
385
398
|
stmt = _filter_by_prompt_identifier(stmt.join(models.Prompt), prompt_identifier)
|
|
386
399
|
async with request.app.state.db() as session:
|
|
387
400
|
prompt_version: models.PromptVersion = await session.scalar(stmt)
|
|
388
401
|
if prompt_version is None:
|
|
389
|
-
raise HTTPException(
|
|
402
|
+
raise HTTPException(404)
|
|
390
403
|
data = _prompt_version_from_orm_version(prompt_version)
|
|
391
404
|
return GetPromptResponseBody(data=data)
|
|
392
405
|
|
|
393
406
|
|
|
394
407
|
@router.post(
|
|
395
408
|
"/prompts",
|
|
409
|
+
dependencies=[Depends(is_not_locked)],
|
|
396
410
|
operation_id="postPromptVersion",
|
|
397
411
|
summary="Create a new prompt",
|
|
398
412
|
description="Create a new prompt and its initial version. A prompt can have multiple versions.",
|
|
399
413
|
response_description="The newly created prompt version",
|
|
400
414
|
responses=add_errors_to_responses(
|
|
401
415
|
[
|
|
402
|
-
|
|
416
|
+
422,
|
|
403
417
|
]
|
|
404
418
|
),
|
|
405
419
|
response_model_by_alias=True,
|
|
@@ -429,7 +443,7 @@ async def create_prompt(
|
|
|
429
443
|
or request_body.version.template_type != PromptTemplateType.CHAT
|
|
430
444
|
):
|
|
431
445
|
raise HTTPException(
|
|
432
|
-
|
|
446
|
+
422,
|
|
433
447
|
"Only CHAT template type is supported for prompts",
|
|
434
448
|
)
|
|
435
449
|
prompt = request_body.prompt
|
|
@@ -437,7 +451,7 @@ async def create_prompt(
|
|
|
437
451
|
name = Identifier.model_validate(prompt.name)
|
|
438
452
|
except ValidationError as e:
|
|
439
453
|
raise HTTPException(
|
|
440
|
-
|
|
454
|
+
422,
|
|
441
455
|
"Invalid name identifier for prompt: " + e.errors()[0]["msg"],
|
|
442
456
|
)
|
|
443
457
|
version = request_body.version
|
|
@@ -446,17 +460,15 @@ async def create_prompt(
|
|
|
446
460
|
assert isinstance(user := request.user, PhoenixUser)
|
|
447
461
|
user_id = int(user.identity)
|
|
448
462
|
async with request.app.state.db() as session:
|
|
449
|
-
if not (
|
|
463
|
+
if not (prompt_orm := await session.scalar(select(models.Prompt).filter_by(name=name))):
|
|
450
464
|
prompt_orm = models.Prompt(
|
|
451
465
|
name=name,
|
|
452
466
|
description=prompt.description,
|
|
467
|
+
metadata_=prompt.metadata or {},
|
|
453
468
|
)
|
|
454
|
-
session.add(prompt_orm)
|
|
455
|
-
await session.flush()
|
|
456
|
-
prompt_id = prompt_orm.id
|
|
457
469
|
version_orm = models.PromptVersion(
|
|
458
470
|
user_id=user_id,
|
|
459
|
-
|
|
471
|
+
prompt=prompt_orm,
|
|
460
472
|
description=version.description,
|
|
461
473
|
model_provider=version.model_provider,
|
|
462
474
|
model_name=version.model_name,
|
|
@@ -494,8 +506,8 @@ class GetPromptVersionTagsResponseBody(PaginatedResponseBody[PromptVersionTag]):
|
|
|
494
506
|
response_description="A list of tags associated with the prompt version",
|
|
495
507
|
responses=add_errors_to_responses(
|
|
496
508
|
[
|
|
497
|
-
|
|
498
|
-
|
|
509
|
+
404,
|
|
510
|
+
422,
|
|
499
511
|
]
|
|
500
512
|
),
|
|
501
513
|
response_model_by_alias=True,
|
|
@@ -535,7 +547,7 @@ async def list_prompt_version_tags(
|
|
|
535
547
|
PromptVersionNodeType.__name__,
|
|
536
548
|
)
|
|
537
549
|
except ValueError:
|
|
538
|
-
raise HTTPException(
|
|
550
|
+
raise HTTPException(422, "Invalid prompt version ID")
|
|
539
551
|
|
|
540
552
|
# Build the query for tags
|
|
541
553
|
stmt = (
|
|
@@ -558,7 +570,7 @@ async def list_prompt_version_tags(
|
|
|
558
570
|
except ValueError:
|
|
559
571
|
raise HTTPException(
|
|
560
572
|
detail=f"Invalid cursor format: {cursor}",
|
|
561
|
-
status_code=
|
|
573
|
+
status_code=422,
|
|
562
574
|
)
|
|
563
575
|
|
|
564
576
|
# Apply limit
|
|
@@ -569,7 +581,7 @@ async def list_prompt_version_tags(
|
|
|
569
581
|
|
|
570
582
|
# Check if prompt version exists
|
|
571
583
|
if not result:
|
|
572
|
-
raise HTTPException(
|
|
584
|
+
raise HTTPException(404, "Prompt version not found")
|
|
573
585
|
|
|
574
586
|
# Check if there are any tags
|
|
575
587
|
has_tags = any(id_ is not None for _, id_, _, _ in result)
|
|
@@ -602,16 +614,17 @@ async def list_prompt_version_tags(
|
|
|
602
614
|
|
|
603
615
|
@router.post(
|
|
604
616
|
"/prompt_versions/{prompt_version_id}/tags",
|
|
617
|
+
dependencies=[Depends(is_not_locked)],
|
|
605
618
|
operation_id="createPromptVersionTag",
|
|
606
619
|
summary="Add tag to prompt version",
|
|
607
620
|
description="Add a new tag to a specific prompt version. Tags help identify and categorize "
|
|
608
621
|
"different versions of a prompt.",
|
|
609
622
|
response_description="No content returned on successful tag creation",
|
|
610
|
-
status_code=
|
|
623
|
+
status_code=204,
|
|
611
624
|
responses=add_errors_to_responses(
|
|
612
625
|
[
|
|
613
|
-
|
|
614
|
-
|
|
626
|
+
404,
|
|
627
|
+
422,
|
|
615
628
|
]
|
|
616
629
|
),
|
|
617
630
|
response_model_by_alias=True,
|
|
@@ -644,7 +657,7 @@ async def create_prompt_version_tag(
|
|
|
644
657
|
PromptVersionNodeType.__name__,
|
|
645
658
|
)
|
|
646
659
|
except ValueError:
|
|
647
|
-
raise HTTPException(
|
|
660
|
+
raise HTTPException(422, "Invalid prompt version ID")
|
|
648
661
|
user_id: Optional[int] = None
|
|
649
662
|
if request.app.state.authentication_enabled:
|
|
650
663
|
assert isinstance(user := request.user, PhoenixUser)
|
|
@@ -652,7 +665,7 @@ async def create_prompt_version_tag(
|
|
|
652
665
|
async with request.app.state.db() as session:
|
|
653
666
|
prompt_id = await session.scalar(select(models.PromptVersion.prompt_id).filter_by(id=id_))
|
|
654
667
|
if prompt_id is None:
|
|
655
|
-
raise HTTPException(
|
|
668
|
+
raise HTTPException(404)
|
|
656
669
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
657
670
|
values = dict(
|
|
658
671
|
name=request_body.name,
|
|
@@ -683,7 +696,7 @@ def _parse_prompt_identifier(
|
|
|
683
696
|
prompt_identifier: str,
|
|
684
697
|
) -> _PromptIdentifier:
|
|
685
698
|
if not prompt_identifier:
|
|
686
|
-
raise HTTPException(
|
|
699
|
+
raise HTTPException(422, "Invalid prompt identifier")
|
|
687
700
|
try:
|
|
688
701
|
prompt_id = from_global_id_with_expected_type(
|
|
689
702
|
GlobalID.from_id(prompt_identifier),
|
|
@@ -693,7 +706,7 @@ def _parse_prompt_identifier(
|
|
|
693
706
|
try:
|
|
694
707
|
return Identifier.model_validate(prompt_identifier)
|
|
695
708
|
except ValidationError:
|
|
696
|
-
raise HTTPException(
|
|
709
|
+
raise HTTPException(422, "Invalid prompt name")
|
|
697
710
|
return _PromptId(prompt_id)
|
|
698
711
|
|
|
699
712
|
|
|
@@ -739,4 +752,5 @@ def _prompt_from_orm_prompt(orm_prompt: models.Prompt) -> Prompt:
|
|
|
739
752
|
source_prompt_id=source_prompt_id,
|
|
740
753
|
name=orm_prompt.name,
|
|
741
754
|
description=orm_prompt.description,
|
|
755
|
+
metadata=orm_prompt.metadata_,
|
|
742
756
|
)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
from sqlalchemy import select
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
13
|
+
from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
|
|
14
|
+
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
15
|
+
from phoenix.server.authorization import is_not_locked
|
|
16
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
17
|
+
|
|
18
|
+
from .annotations import SessionAnnotationData
|
|
19
|
+
from .utils import RequestBody, ResponseBody, add_errors_to_responses
|
|
20
|
+
|
|
21
|
+
# Router for the session-annotation endpoints; routes registered here are grouped under the
# "sessions" tag in the generated OpenAPI schema.
router = APIRouter(tags=["sessions"])
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class InsertedSessionAnnotation(V1RoutesBaseModel):
    """Reference to one session annotation row written by the annotate-sessions endpoint."""

    # Row ID rendered as a string.
    id: str = Field(description="The ID of the inserted session annotation")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
    """Request payload for the ``/session_annotations`` route: a list of session annotations."""
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
    """Response payload for the ``/session_annotations`` route: IDs of written annotations."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@router.post(
    "/session_annotations",
    dependencies=[Depends(is_not_locked)],
    operation_id="annotateSessions",
    summary="Create session annotations",
    responses=add_errors_to_responses([{"status_code": 404, "description": "Session not found"}]),
    response_description="Session annotations inserted successfully",
    include_in_schema=True,
)
async def annotate_sessions(
    request: Request,
    request_body: AnnotateSessionsRequestBody,
    sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
) -> AnnotateSessionsResponseBody:
    """
    Create annotations on project sessions.

    Annotations named ``"note"`` are not supported here: they are dropped and a server-side
    ``UserWarning`` is emitted.

    Args:
        request: The incoming request. It supplies the DB session factory
            (``request.app.state.db``), the authenticated user when authentication is enabled,
            and the background queue (``request.state.enqueue_annotations``).
        request_body: The session annotations to write.
        sync: When False (the default), annotations are queued and an empty list is returned.
            When True, they are written before responding and the row IDs are returned.

    Returns:
        In sync mode, the IDs of the written annotation rows. Otherwise an empty list.

    Raises:
        HTTPException: 404 in sync mode if any referenced session ID does not exist. The
            check runs before any insert, so nothing is written in that case.
    """
    if not request_body.data:
        return AnnotateSessionsResponseBody(data=[])

    # Attribute annotations to the caller only when auth is on and the user is a PhoenixUser.
    user_id: Optional[int] = None
    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
        user_id = int(request.user.identity)

    session_annotations = request_body.data
    filtered_session_annotations = [d for d in session_annotations if d.name != "note"]
    if len(filtered_session_annotations) != len(session_annotations):
        warnings.warn(
            (
                "Session annotations with the name 'note' are not supported in this endpoint. "
                "They will be ignored."
            ),
            UserWarning,
        )
    precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
    if not precursors:
        # Everything was filtered out: there is nothing to queue and nothing to insert.
        return AnnotateSessionsResponseBody(data=[])

    if not sync:
        await request.state.enqueue_annotations(*precursors)
        return AnnotateSessionsResponseBody(data=[])

    session_ids = {p.session_id for p in precursors}
    # Resolve session IDs and perform the inserts inside ONE DB session. The original used
    # two, which took a second connection and widened the window in which a session could
    # disappear between the existence check and the insert.
    async with request.app.state.db() as session:
        existing_sessions = {
            session_id: rowid
            async for session_id, rowid in await session.stream(
                select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
                    models.ProjectSession.session_id.in_(session_ids)
                )
            )
        }

        # We prefer to fail the entire operation if there are missing sessions in sync mode.
        missing_session_ids = session_ids.difference(existing_sessions)
        if missing_session_ids:
            raise HTTPException(
                # Sorted so the error message is deterministic (sets are unordered).
                detail=f"Sessions with IDs {', '.join(sorted(missing_session_ids))} do not exist.",
                status_code=404,
            )

        dialect = SupportedSQLDialect(session.bind.dialect.name)
        inserted_ids = []
        for p in precursors:
            values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
            # Conflicts on (name, project_session_id, identifier) are handled by
            # insert_on_conflict; the returned row ID is collected either way.
            session_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.ProjectSessionAnnotation,
                    unique_by=("name", "project_session_id", "identifier"),
                ).returning(models.ProjectSessionAnnotation.id)
            )
            inserted_ids.append(session_annotation_id)

    return AnnotateSessionsResponseBody(
        data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
    )
|