arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
```diff
--- arize_phoenix-11.23.1/phoenix/server/api/routers/v1/experiments.py
+++ arize_phoenix-12.28.1/phoenix/server/api/routers/v1/experiments.py
@@ -4,26 +4,38 @@ from random import getrandbits
 from typing import Any, Optional
 
 import pandas as pd
-
+import sqlalchemy as sa
+from fastapi import APIRouter, Depends, HTTPException, Path, Query, Response
 from pydantic import Field
-from sqlalchemy import and_, func, select
+from sqlalchemy import and_, case, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload
 from starlette.requests import Request
 from starlette.responses import PlainTextResponse
-from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
 from strawberry.relay import GlobalID
 
 from phoenix.db import models
-from phoenix.db.helpers import
+from phoenix.db.helpers import (
+    SupportedSQLDialect,
+    get_experiment_incomplete_runs_query,
+    insert_experiment_with_examples_snapshot,
+)
 from phoenix.db.insertion.helpers import insert_on_conflict
+from phoenix.server.api.routers.v1.datasets import DatasetExample
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.authorization import is_not_locked
+from phoenix.server.bearer_auth import PhoenixUser
 from phoenix.server.dml_event import ExperimentInsertEvent
 from phoenix.server.experiments.utils import generate_experiment_project_name
 
+from .datasets import _resolve_split_identifiers
 from .models import V1RoutesBaseModel
-from .utils import
+from .utils import (
+    PaginatedResponseBody,
+    ResponseBody,
+    add_errors_to_responses,
+    add_text_csv_content_to_responses,
+)
 
 router = APIRouter(tags=["experiments"], include_in_schema=True)
 
@@ -46,13 +58,19 @@ class Experiment(V1RoutesBaseModel):
     dataset_version_id: str = Field(
         description="The ID of the dataset version associated with the experiment"
     )
-    repetitions: int = Field(description="Number of times the experiment is repeated")
+    repetitions: int = Field(description="Number of times the experiment is repeated", gt=0)
     metadata: dict[str, Any] = Field(description="Metadata of the experiment")
     project_name: Optional[str] = Field(
        description="The name of the project associated with the experiment"
     )
     created_at: datetime = Field(description="The creation timestamp of the experiment")
     updated_at: datetime = Field(description="The last update timestamp of the experiment")
+    example_count: int = Field(description="Number of examples in the experiment")
+    successful_run_count: int = Field(description="Number of successful runs in the experiment")
+    failed_run_count: int = Field(description="Number of failed runs in the experiment")
+    missing_run_count: int = Field(
+        description="Number of missing (not yet executed) runs in the experiment"
+    )
 
 
 class CreateExperimentRequestBody(V1RoutesBaseModel):
@@ -77,6 +95,10 @@ class CreateExperimentRequestBody(V1RoutesBaseModel):
             "(if omitted, the latest version will be used)"
         ),
     )
+    splits: Optional[list[str]] = Field(
+        default=None,
+        description="List of dataset split identifiers (GlobalIDs or names) to filter by",
+    )
     repetitions: int = Field(
         default=1, description="Number of times the experiment should be repeated for each example"
     )
@@ -92,7 +114,7 @@ class CreateExperimentResponseBody(ResponseBody[Experiment]):
     operation_id="createExperiment",
     summary="Create experiment on a dataset",
     responses=add_errors_to_responses(
-        [{"status_code":
+        [{"status_code": 404, "description": "Dataset or DatasetVersion not found"}]
     ),
     response_description="Experiment retrieved successfully",
 )
@@ -101,26 +123,38 @@ async def create_experiment(
     request_body: CreateExperimentRequestBody,
     dataset_id: str = Path(..., title="Dataset ID"),
 ) -> CreateExperimentResponseBody:
-
+    try:
+        dataset_globalid = GlobalID.from_id(dataset_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {dataset_id}",
+            status_code=422,
+        ) from e
     try:
         dataset_rowid = from_global_id_with_expected_type(dataset_globalid, "Dataset")
     except ValueError:
         raise HTTPException(
             detail="Dataset with ID {dataset_globalid} does not exist",
-            status_code=
+            status_code=404,
         )
 
     dataset_version_globalid_str = request_body.version_id
     if dataset_version_globalid_str is not None:
         try:
             dataset_version_globalid = GlobalID.from_id(dataset_version_globalid_str)
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {dataset_version_globalid_str}",
+                status_code=422,
+            ) from e
+        try:
             dataset_version_id = from_global_id_with_expected_type(
                 dataset_version_globalid, "DatasetVersion"
             )
         except ValueError:
             raise HTTPException(
                 detail=f"DatasetVersion with ID {dataset_version_globalid_str} does not exist",
-                status_code=
+                status_code=404,
             )
 
     async with request.app.state.db() as session:
@@ -130,7 +164,7 @@ async def create_experiment(
         if result is None:
             raise HTTPException(
                 detail=f"Dataset with ID {dataset_globalid} does not exist",
-                status_code=
+                status_code=404,
             )
         dataset_name = result.name
         if dataset_version_globalid_str is None:
@@ -143,7 +177,7 @@ async def create_experiment(
             if not dataset_version:
                 raise HTTPException(
                     detail=f"Dataset {dataset_globalid} does not have any versions",
-                    status_code=
+                    status_code=404,
                 )
             dataset_version_id = dataset_version.id
             dataset_version_globalid = GlobalID("DatasetVersion", str(dataset_version_id))
@@ -155,8 +189,11 @@ async def create_experiment(
             if not dataset_version:
                 raise HTTPException(
                     detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
-                    status_code=
+                    status_code=404,
                 )
+        user_id: Optional[int] = None
+        if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
+            user_id = int(request.user.identity)
 
         # generate a semi-unique name for the experiment
         experiment_name = request_body.name or _generate_experiment_name(dataset_name)
@@ -172,9 +209,23 @@ async def create_experiment(
             repetitions=request_body.repetitions,
             metadata_=request_body.metadata or {},
             project_name=project_name,
+            user_id=user_id,
         )
-
-
+
+        if request_body.splits is not None:
+            # Resolve split identifiers (IDs or names) to IDs and names
+            resolved_split_ids, _ = await _resolve_split_identifiers(session, request_body.splits)
+
+            # Generate experiment dataset splits relation
+            # prior to the crosswalk table insert
+            # in insert_experiment_with_examples_snapshot
+            experiment.experiment_dataset_splits = [
+                models.ExperimentDatasetSplit(dataset_split_id=split_id)
+                for split_id in resolved_split_ids
+            ]
+
+        # crosswalk table assumes the relation is already present
+        await insert_experiment_with_examples_snapshot(session, experiment)
 
         dialect = SupportedSQLDialect(session.bind.dialect.name)
         project_rowid = await session.scalar(
@@ -197,6 +248,19 @@ async def create_experiment(
         dataset_version_globalid = GlobalID(
             "DatasetVersion", str(experiment.dataset_version_id)
         )
+
+        # Optimization: We just created this experiment, so we know there are 0 runs.
+        # No need to query ExperimentRun table - just count the examples.
+        example_count = await session.scalar(
+            select(func.count())
+            .select_from(models.ExperimentDatasetExample)
+            .where(models.ExperimentDatasetExample.experiment_id == experiment.id)
+        )
+
+        # No runs exist yet for a newly created experiment
+        successful_run_count = 0
+        failed_run_count = 0
+        missing_run_count = (example_count or 0) * experiment.repetitions
         request.state.event_queue.put(ExperimentInsertEvent((experiment.id,)))
         return CreateExperimentResponseBody(
             data=Experiment(
@@ -208,6 +272,10 @@ async def create_experiment(
                 project_name=experiment.project_name,
                 created_at=experiment.created_at,
                 updated_at=experiment.updated_at,
+                example_count=example_count or 0,
+                successful_run_count=successful_run_count or 0,
+                failed_run_count=failed_run_count or 0,
+                missing_run_count=missing_run_count,
             )
         )
 
@@ -221,18 +289,24 @@ class GetExperimentResponseBody(ResponseBody[Experiment]):
     operation_id="getExperiment",
     summary="Get experiment by ID",
     responses=add_errors_to_responses(
-        [{"status_code":
+        [{"status_code": 404, "description": "Experiment not found"}]
     ),
     response_description="Experiment retrieved successfully",
 )
 async def get_experiment(request: Request, experiment_id: str) -> GetExperimentResponseBody:
-
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
     try:
         experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
     except ValueError:
         raise HTTPException(
             detail="Experiment with ID {experiment_globalid} does not exist",
-            status_code=
+            status_code=404,
         )
 
     async with request.app.state.db() as session:
@@ -243,11 +317,48 @@ async def get_experiment(request: Request, experiment_id: str) -> GetExperimentResponseBody:
         if not experiment:
             raise HTTPException(
                 detail=f"Experiment with ID {experiment_globalid} does not exist",
-                status_code=
+                status_code=404,
             )
 
         dataset_globalid = GlobalID("Dataset", str(experiment.dataset_id))
         dataset_version_globalid = GlobalID("DatasetVersion", str(experiment.dataset_version_id))
+
+        # Get counts efficiently: use CASE to count successful and failed in single table scan
+        run_counts_subq = (
+            select(
+                func.sum(case((models.ExperimentRun.error.is_(None), 1), else_=0)).label(
+                    "successful_run_count"
+                ),
+                func.sum(case((models.ExperimentRun.error.is_not(None), 1), else_=0)).label(
+                    "failed_run_count"
+                ),
+            )
+            .select_from(models.ExperimentRun)
+            .where(models.ExperimentRun.experiment_id == experiment_rowid)
+            .subquery()
+        )
+
+        counts_result = await session.execute(
+            select(
+                select(func.count())
+                .select_from(models.ExperimentDatasetExample)
+                .where(models.ExperimentDatasetExample.experiment_id == experiment_rowid)
+                .scalar_subquery()
+                .label("example_count"),
+                run_counts_subq.c.successful_run_count,
+                run_counts_subq.c.failed_run_count,
+            ).select_from(run_counts_subq)
+        )
+        counts = counts_result.one()
+        example_count = counts.example_count
+        successful_run_count = counts.successful_run_count
+        failed_run_count = counts.failed_run_count
+
+        # Calculate missing runs (no database query needed)
+        total_expected_runs = (example_count or 0) * experiment.repetitions
+        missing_run_count = (
+            total_expected_runs - (successful_run_count or 0) - (failed_run_count or 0)
+        )
         return GetExperimentResponseBody(
             data=Experiment(
                 id=str(experiment_globalid),
@@ -258,31 +369,246 @@ async def get_experiment(request: Request, experiment_id: str) -> GetExperimentResponseBody:
                 project_name=experiment.project_name,
                 created_at=experiment.created_at,
                 updated_at=experiment.updated_at,
+                example_count=example_count or 0,
+                successful_run_count=successful_run_count or 0,
+                failed_run_count=failed_run_count or 0,
+                missing_run_count=missing_run_count,
+            )
+        )
+
+
+@router.delete(
+    "/experiments/{experiment_id}",
+    operation_id="deleteExperiment",
+    summary="Delete experiment by ID",
+    responses=add_errors_to_responses(
+        [{"status_code": 404, "description": "Experiment not found"}]
+    ),
+    response_description="Experiment deleted successfully",
+    status_code=204,
+)
+async def delete_experiment(
+    request: Request,
+    experiment_id: str,
+) -> None:
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
+    try:
+        experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
+    except ValueError:
+        raise HTTPException(
+            detail=f"Experiment with ID {experiment_globalid} does not exist",
+            status_code=404,
         )
+
+    stmt = (
+        sa.delete(models.Experiment)
+        .where(models.Experiment.id == experiment_rowid)
+        .returning(models.Experiment.id)
     )
+    async with request.app.state.db() as session:
+        if (await session.scalar(stmt)) is None:
+            raise HTTPException(detail="Experiment does not exist", status_code=404)
 
 
-class ListExperimentsResponseBody(
+class ListExperimentsResponseBody(PaginatedResponseBody[Experiment]):
     pass
 
 
+class IncompleteExperimentRun(V1RoutesBaseModel):
+    """
+    Information about incomplete runs for a dataset example
+    """
+
+    dataset_example: DatasetExample = Field(description="The dataset example")
+    repetition_numbers: list[int] = Field(
+        description="List of repetition numbers that need to be run"
+    )
+
+
+class GetIncompleteExperimentRunsResponseBody(PaginatedResponseBody[IncompleteExperimentRun]):
+    pass
+
+
+@router.get(
+    "/experiments/{experiment_id}/incomplete-runs",
+    operation_id="getIncompleteExperimentRuns",
+    summary="Get incomplete runs for an experiment",
+    responses=add_errors_to_responses(
+        [
+            {"status_code": 404, "description": "Experiment not found"},
+            {"status_code": 422, "description": "Invalid cursor format"},
+        ]
+    ),
+    response_description="Incomplete runs retrieved successfully",
+)
+async def get_incomplete_runs(
+    request: Request,
+    experiment_id: str,
+    cursor: Optional[str] = Query(default=None, description="Cursor for pagination"),
+    limit: int = Query(
+        default=50, description="Maximum number of examples with incomplete runs to return", gt=0
+    ),
+) -> GetIncompleteExperimentRunsResponseBody:
+    """
+    Get runs that need to be completed for this experiment.
+
+    Returns all incomplete runs, including both missing runs (not yet attempted)
+    and failed runs (attempted but have errors).
+
+    Args:
+        experiment_id: The ID of the experiment
+        cursor: Cursor for pagination
+        limit: Maximum number of results to return
+
+    Returns:
+        Paginated list of incomplete runs grouped by dataset example,
+        with repetition numbers that need to be run
+    """
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
+    try:
+        id_ = from_global_id_with_expected_type(experiment_globalid, "Experiment")
+    except ValueError:
+        raise HTTPException(
+            detail=f"Experiment with ID {experiment_globalid} does not exist",
+            status_code=404,
+        )
+
+    # Parse cursor if provided
+    cursor_example_rowid: Optional[int] = None
+    if cursor:
+        try:
+            cursor_gid = GlobalID.from_id(cursor)
+            cursor_example_rowid = from_global_id_with_expected_type(cursor_gid, "DatasetExample")
+        except (ValueError, AttributeError):
+            raise HTTPException(
+                detail=f"Invalid cursor format: {cursor}",
+                status_code=422,
+            )
+
+    # Fetch experiment first (we need its repetitions count for the query)
+    async with request.app.state.db() as session:
+        experiment_result = await session.execute(select(models.Experiment).filter_by(id=id_))
+        experiment = experiment_result.scalar()
+        if not experiment:
+            raise HTTPException(
+                detail=f"Experiment with ID {experiment_globalid} does not exist",
+                status_code=404,
+            )
+
+        dialect = request.app.state.db.dialect
+
+        stmt = get_experiment_incomplete_runs_query(
+            experiment,
+            dialect,
+            cursor_example_rowid=cursor_example_rowid,
+            limit=limit,
+        )
+
+        result = await session.execute(stmt)
+        all_examples = result.all()
+
+        # Check if there's a next page
+        has_next_page = len(all_examples) > limit
+        if has_next_page:
+            # Remove the extra row
+            examples_to_process = all_examples[:limit]
+            # The cursor points to the FIRST item of the NEXT page
+            next_item_id = all_examples[limit][0].dataset_example_id
+            next_cursor = str(GlobalID("DatasetExample", str(next_item_id)))
+        else:
+            examples_to_process = all_examples
+            next_cursor = None
+
+        # Parse incomplete repetitions and build response
+        # Optimization: Precompute the "all repetitions" list for completely missing examples
+        # to avoid recomputing it for every missing example
+        all_repetitions = list(range(1, experiment.repetitions + 1))
+        incomplete_runs_list: list[IncompleteExperimentRun] = []
+
+        for revision, successful_count, incomplete_reps in examples_to_process:
+            example_id = revision.dataset_example_id
+
+            # Three regimes:
+            # 1. Completely missing (successful_count = 0): all repetitions are incomplete
+            # 2. Partially completed (0 < successful_count < R): parse from SQL result
+            # 3. Totally completed (successful_count = R): filtered out by SQL HAVING clause
+
+            if successful_count == 0:
+                # Regime 1: Completely missing - use precomputed list
+                incomplete = all_repetitions
+            else:
+                # Regime 2: Partially completed - parse incomplete reps from SQL
+                if dialect is SupportedSQLDialect.POSTGRESQL:
+                    # PostgreSQL returns array (list), filter out nulls
+                    incomplete = [r for r in incomplete_reps if r is not None]
+                else:
+                    # SQLite returns JSON string
+                    incomplete = [r for r in json.loads(incomplete_reps) if r is not None]
+
+            # Build response
+            example_globalid = GlobalID("DatasetExample", str(example_id))
+            incomplete_runs_list.append(
+                IncompleteExperimentRun(
+                    dataset_example=DatasetExample(
+                        id=str(example_globalid),
+                        input=revision.input,
+                        output=revision.output,
+                        metadata=revision.metadata_,
+                        updated_at=revision.created_at,
+                    ),
+                    repetition_numbers=sorted(incomplete),
+                )
+            )
+
+    return GetIncompleteExperimentRunsResponseBody(
+        data=incomplete_runs_list, next_cursor=next_cursor
+    )
+
+
 @router.get(
     "/datasets/{dataset_id}/experiments",
     operation_id="listExperiments",
     summary="List experiments by dataset",
-
+    description="Retrieve a paginated list of experiments for the specified dataset.",
+    response_description="Paginated list of experiments for the dataset",
+    responses=add_errors_to_responses([422]),
 )
 async def list_experiments(
     request: Request,
     dataset_id: str = Path(..., title="Dataset ID"),
+    cursor: Optional[str] = Query(
+        default=None,
+        description="Cursor for pagination (base64-encoded experiment ID)",
+    ),
+    limit: int = Query(
+        default=50, description="The max number of experiments to return at a time.", gt=0
+    ),
 ) -> ListExperimentsResponseBody:
-
+    try:
+        dataset_gid = GlobalID.from_id(dataset_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {dataset_id}",
+            status_code=422,
+        ) from e
     try:
         dataset_rowid = from_global_id_with_expected_type(dataset_gid, "Dataset")
     except ValueError:
         raise HTTPException(
             detail=f"Dataset with ID {dataset_gid} does not exist",
-            status_code=
+            status_code=404,
         )
     async with request.app.state.db() as session:
         query = (
@@ -291,29 +617,119 @@ async def list_experiments(
             .order_by(models.Experiment.id.desc())
         )
 
+        # Handle cursor for pagination
+        if cursor:
+            try:
+                cursor_gid = GlobalID.from_id(cursor)
+                cursor_rowid = from_global_id_with_expected_type(cursor_gid, "Experiment")
+                query = query.where(models.Experiment.id <= cursor_rowid)
+            except (ValueError, Exception):
+                raise HTTPException(
+                    detail=f"Invalid cursor format: {cursor}",
+                    status_code=422,
+                )
+
+        # Overfetch by 1 to determine if there's a next page
+        query = query.limit(limit + 1)
+
         result = await session.execute(query)
         experiments = result.scalars().all()
 
         if not experiments:
-            return ListExperimentsResponseBody(data=[])
-
-
-
-
-
-
-
+            return ListExperimentsResponseBody(data=[], next_cursor=None)
+
+        # Get example counts and successful run counts for all experiments in a single query
+        experiment_ids = [exp.id for exp in experiments]
+
+        # Create subqueries for counts
+        example_count_subq = (
+            select(
+                models.ExperimentDatasetExample.experiment_id, func.count().label("example_count")
+            )
+            .where(models.ExperimentDatasetExample.experiment_id.in_(experiment_ids))
+            .group_by(models.ExperimentDatasetExample.experiment_id)
+            .subquery()
+        )
+
+        # Optimize: Use CASE to count successful and failed in single table scan
+        run_counts_subq = (
+            select(
+                models.ExperimentRun.experiment_id,
+                func.sum(case((models.ExperimentRun.error.is_(None), 1), else_=0)).label(
+                    "successful_run_count"
+                ),
+                func.sum(case((models.ExperimentRun.error.is_not(None), 1), else_=0)).label(
+                    "failed_run_count"
                 ),
-                repetitions=experiment.repetitions,
-                metadata=experiment.metadata_,
-                project_name=experiment.project_name,
-                created_at=experiment.created_at,
-                updated_at=experiment.updated_at,
             )
-
-
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+            .group_by(models.ExperimentRun.experiment_id)
+            .subquery()
+        )
 
-
+        # Get all counts in one query using outer join
+        counts_result = await session.execute(
+            select(
+                func.coalesce(
+                    example_count_subq.c.experiment_id,
+                    run_counts_subq.c.experiment_id,
+                ).label("experiment_id"),
+                func.coalesce(example_count_subq.c.example_count, 0).label("example_count"),
+                func.coalesce(run_counts_subq.c.successful_run_count, 0).label(
+                    "successful_run_count"
+                ),
+                func.coalesce(run_counts_subq.c.failed_run_count, 0).label("failed_run_count"),
+            )
+            .select_from(example_count_subq)
+            .outerjoin(
+                run_counts_subq,
+                example_count_subq.c.experiment_id == run_counts_subq.c.experiment_id,
+            )
+        )
+
+        counts_by_experiment = {
+            row.experiment_id: (row.example_count, row.successful_run_count, row.failed_run_count)
+            for row in counts_result
+        }
+
+        # Handle pagination: check if we have a next page
+        next_cursor = None
+        if len(experiments) == limit + 1:
+            last_experiment = experiments[-1]
+            next_cursor = str(GlobalID("Experiment", str(last_experiment.id)))
+            experiments = experiments[:-1]  # Remove the extra overfetched experiment
+
+        data = []
+        for experiment in experiments:
+            counts = counts_by_experiment.get(experiment.id, (0, 0, 0))
+            example_count = counts[0]
+            successful_run_count = counts[1]
+            failed_run_count = counts[2]
+
+            # Calculate missing runs (no database query needed)
+            total_expected_runs = example_count * experiment.repetitions
+            missing_run_count = total_expected_runs - successful_run_count - failed_run_count
+
+            data.append(
+                Experiment(
+                    id=str(GlobalID("Experiment", str(experiment.id))),
+                    dataset_id=str(GlobalID("Dataset", str(experiment.dataset_id))),
+                    dataset_version_id=str(
+                        GlobalID("DatasetVersion", str(experiment.dataset_version_id))
+                    ),
+                    repetitions=experiment.repetitions,
+                    metadata=experiment.metadata_,
+                    project_name=experiment.project_name,
+                    created_at=experiment.created_at,
+                    updated_at=experiment.updated_at,
+                    example_count=example_count,
+                    successful_run_count=successful_run_count,
+                    failed_run_count=failed_run_count,
+                    missing_run_count=missing_run_count,
+                )
+            )
+
+        return ListExperimentsResponseBody(data=data, next_cursor=next_cursor)
 
 
 async def _get_experiment_runs_and_revisions(
@@ -321,7 +737,7 @@ async def _get_experiment_runs_and_revisions(
 ) -> tuple[models.Experiment, tuple[models.ExperimentRun], tuple[models.DatasetExampleRevision]]:
     experiment = await session.get(models.Experiment, experiment_rowid)
     if not experiment:
-        raise HTTPException(detail="Experiment not found", status_code=
+        raise HTTPException(detail="Experiment not found", status_code=404)
     revision_ids = (
         select(func.max(models.DatasetExampleRevision.id))
         .join(
@@ -370,7 +786,7 @@ async def _get_experiment_runs_and_revisions(
     if not runs_and_revisions:
         raise HTTPException(
             detail="Experiment has no runs",
-            status_code=
+            status_code=404,
         )
     runs, revisions = zip(*runs_and_revisions)
     return experiment, runs, revisions
@@ -383,7 +799,7 @@ async def _get_experiment_runs_and_revisions(
     response_class=PlainTextResponse,
     responses=add_errors_to_responses(
         [
-            {"status_code":
+            {"status_code": 404, "description": "Experiment not found"},
         ]
     ),
 )
@@ -391,13 +807,19 @@ async def get_experiment_json(
     request: Request,
     experiment_id: str = Path(..., title="Experiment ID"),
 ) -> Response:
-
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
     try:
         experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
     except ValueError:
         raise HTTPException(
             detail=f"Invalid experiment ID: {experiment_globalid}",
-            status_code=
+            status_code=422,
         )
 
     async with request.app.state.db() as session:
@@ -452,19 +874,25 @@ async def get_experiment_json(
     "/experiments/{experiment_id}/csv",
     operation_id="getExperimentCSV",
     summary="Download experiment runs as a CSV file",
-    responses={**add_text_csv_content_to_responses(
+    responses={**add_text_csv_content_to_responses(200)},
 )
 async def get_experiment_csv(
     request: Request,
     experiment_id: str = Path(..., title="Experiment ID"),
 ) -> Response:
-
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
     try:
         experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
     except ValueError:
         raise HTTPException(
             detail=f"Invalid experiment ID: {experiment_globalid}",
-            status_code=
+            status_code=422,
        )
 
     async with request.app.state.db() as session:
```