arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,24 +4,38 @@ from random import getrandbits
|
|
|
4
4
|
from typing import Any, Optional
|
|
5
5
|
|
|
6
6
|
import pandas as pd
|
|
7
|
-
|
|
7
|
+
import sqlalchemy as sa
|
|
8
|
+
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Response
|
|
8
9
|
from pydantic import Field
|
|
9
|
-
from sqlalchemy import and_, func, select
|
|
10
|
+
from sqlalchemy import and_, case, func, select
|
|
10
11
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
11
12
|
from sqlalchemy.orm import joinedload
|
|
12
13
|
from starlette.requests import Request
|
|
13
14
|
from starlette.responses import PlainTextResponse
|
|
14
|
-
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
|
|
15
15
|
from strawberry.relay import GlobalID
|
|
16
16
|
|
|
17
17
|
from phoenix.db import models
|
|
18
|
-
from phoenix.db.helpers import
|
|
18
|
+
from phoenix.db.helpers import (
|
|
19
|
+
SupportedSQLDialect,
|
|
20
|
+
get_experiment_incomplete_runs_query,
|
|
21
|
+
insert_experiment_with_examples_snapshot,
|
|
22
|
+
)
|
|
19
23
|
from phoenix.db.insertion.helpers import insert_on_conflict
|
|
24
|
+
from phoenix.server.api.routers.v1.datasets import DatasetExample
|
|
20
25
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
26
|
+
from phoenix.server.authorization import is_not_locked
|
|
27
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
21
28
|
from phoenix.server.dml_event import ExperimentInsertEvent
|
|
29
|
+
from phoenix.server.experiments.utils import generate_experiment_project_name
|
|
22
30
|
|
|
31
|
+
from .datasets import _resolve_split_identifiers
|
|
23
32
|
from .models import V1RoutesBaseModel
|
|
24
|
-
from .utils import
|
|
33
|
+
from .utils import (
|
|
34
|
+
PaginatedResponseBody,
|
|
35
|
+
ResponseBody,
|
|
36
|
+
add_errors_to_responses,
|
|
37
|
+
add_text_csv_content_to_responses,
|
|
38
|
+
)
|
|
25
39
|
|
|
26
40
|
router = APIRouter(tags=["experiments"], include_in_schema=True)
|
|
27
41
|
|
|
@@ -44,13 +58,19 @@ class Experiment(V1RoutesBaseModel):
|
|
|
44
58
|
dataset_version_id: str = Field(
|
|
45
59
|
description="The ID of the dataset version associated with the experiment"
|
|
46
60
|
)
|
|
47
|
-
repetitions: int = Field(description="Number of times the experiment is repeated")
|
|
61
|
+
repetitions: int = Field(description="Number of times the experiment is repeated", gt=0)
|
|
48
62
|
metadata: dict[str, Any] = Field(description="Metadata of the experiment")
|
|
49
63
|
project_name: Optional[str] = Field(
|
|
50
64
|
description="The name of the project associated with the experiment"
|
|
51
65
|
)
|
|
52
66
|
created_at: datetime = Field(description="The creation timestamp of the experiment")
|
|
53
67
|
updated_at: datetime = Field(description="The last update timestamp of the experiment")
|
|
68
|
+
example_count: int = Field(description="Number of examples in the experiment")
|
|
69
|
+
successful_run_count: int = Field(description="Number of successful runs in the experiment")
|
|
70
|
+
failed_run_count: int = Field(description="Number of failed runs in the experiment")
|
|
71
|
+
missing_run_count: int = Field(
|
|
72
|
+
description="Number of missing (not yet executed) runs in the experiment"
|
|
73
|
+
)
|
|
54
74
|
|
|
55
75
|
|
|
56
76
|
class CreateExperimentRequestBody(V1RoutesBaseModel):
|
|
@@ -75,6 +95,10 @@ class CreateExperimentRequestBody(V1RoutesBaseModel):
|
|
|
75
95
|
"(if omitted, the latest version will be used)"
|
|
76
96
|
),
|
|
77
97
|
)
|
|
98
|
+
splits: Optional[list[str]] = Field(
|
|
99
|
+
default=None,
|
|
100
|
+
description="List of dataset split identifiers (GlobalIDs or names) to filter by",
|
|
101
|
+
)
|
|
78
102
|
repetitions: int = Field(
|
|
79
103
|
default=1, description="Number of times the experiment should be repeated for each example"
|
|
80
104
|
)
|
|
@@ -86,10 +110,11 @@ class CreateExperimentResponseBody(ResponseBody[Experiment]):
|
|
|
86
110
|
|
|
87
111
|
@router.post(
|
|
88
112
|
"/datasets/{dataset_id}/experiments",
|
|
113
|
+
dependencies=[Depends(is_not_locked)],
|
|
89
114
|
operation_id="createExperiment",
|
|
90
115
|
summary="Create experiment on a dataset",
|
|
91
116
|
responses=add_errors_to_responses(
|
|
92
|
-
[{"status_code":
|
|
117
|
+
[{"status_code": 404, "description": "Dataset or DatasetVersion not found"}]
|
|
93
118
|
),
|
|
94
119
|
response_description="Experiment retrieved successfully",
|
|
95
120
|
)
|
|
@@ -98,26 +123,38 @@ async def create_experiment(
|
|
|
98
123
|
request_body: CreateExperimentRequestBody,
|
|
99
124
|
dataset_id: str = Path(..., title="Dataset ID"),
|
|
100
125
|
) -> CreateExperimentResponseBody:
|
|
101
|
-
|
|
126
|
+
try:
|
|
127
|
+
dataset_globalid = GlobalID.from_id(dataset_id)
|
|
128
|
+
except Exception as e:
|
|
129
|
+
raise HTTPException(
|
|
130
|
+
detail=f"Invalid dataset ID format: {dataset_id}",
|
|
131
|
+
status_code=422,
|
|
132
|
+
) from e
|
|
102
133
|
try:
|
|
103
134
|
dataset_rowid = from_global_id_with_expected_type(dataset_globalid, "Dataset")
|
|
104
135
|
except ValueError:
|
|
105
136
|
raise HTTPException(
|
|
106
137
|
detail="Dataset with ID {dataset_globalid} does not exist",
|
|
107
|
-
status_code=
|
|
138
|
+
status_code=404,
|
|
108
139
|
)
|
|
109
140
|
|
|
110
141
|
dataset_version_globalid_str = request_body.version_id
|
|
111
142
|
if dataset_version_globalid_str is not None:
|
|
112
143
|
try:
|
|
113
144
|
dataset_version_globalid = GlobalID.from_id(dataset_version_globalid_str)
|
|
145
|
+
except Exception as e:
|
|
146
|
+
raise HTTPException(
|
|
147
|
+
detail=f"Invalid dataset version ID format: {dataset_version_globalid_str}",
|
|
148
|
+
status_code=422,
|
|
149
|
+
) from e
|
|
150
|
+
try:
|
|
114
151
|
dataset_version_id = from_global_id_with_expected_type(
|
|
115
152
|
dataset_version_globalid, "DatasetVersion"
|
|
116
153
|
)
|
|
117
154
|
except ValueError:
|
|
118
155
|
raise HTTPException(
|
|
119
156
|
detail=f"DatasetVersion with ID {dataset_version_globalid_str} does not exist",
|
|
120
|
-
status_code=
|
|
157
|
+
status_code=404,
|
|
121
158
|
)
|
|
122
159
|
|
|
123
160
|
async with request.app.state.db() as session:
|
|
@@ -127,7 +164,7 @@ async def create_experiment(
|
|
|
127
164
|
if result is None:
|
|
128
165
|
raise HTTPException(
|
|
129
166
|
detail=f"Dataset with ID {dataset_globalid} does not exist",
|
|
130
|
-
status_code=
|
|
167
|
+
status_code=404,
|
|
131
168
|
)
|
|
132
169
|
dataset_name = result.name
|
|
133
170
|
if dataset_version_globalid_str is None:
|
|
@@ -140,7 +177,7 @@ async def create_experiment(
|
|
|
140
177
|
if not dataset_version:
|
|
141
178
|
raise HTTPException(
|
|
142
179
|
detail=f"Dataset {dataset_globalid} does not have any versions",
|
|
143
|
-
status_code=
|
|
180
|
+
status_code=404,
|
|
144
181
|
)
|
|
145
182
|
dataset_version_id = dataset_version.id
|
|
146
183
|
dataset_version_globalid = GlobalID("DatasetVersion", str(dataset_version_id))
|
|
@@ -152,12 +189,15 @@ async def create_experiment(
|
|
|
152
189
|
if not dataset_version:
|
|
153
190
|
raise HTTPException(
|
|
154
191
|
detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
|
|
155
|
-
status_code=
|
|
192
|
+
status_code=404,
|
|
156
193
|
)
|
|
194
|
+
user_id: Optional[int] = None
|
|
195
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
196
|
+
user_id = int(request.user.identity)
|
|
157
197
|
|
|
158
198
|
# generate a semi-unique name for the experiment
|
|
159
199
|
experiment_name = request_body.name or _generate_experiment_name(dataset_name)
|
|
160
|
-
project_name =
|
|
200
|
+
project_name = generate_experiment_project_name()
|
|
161
201
|
project_description = (
|
|
162
202
|
f"dataset_id: {dataset_globalid}\ndataset_version_id: {dataset_version_globalid}"
|
|
163
203
|
)
|
|
@@ -169,9 +209,23 @@ async def create_experiment(
|
|
|
169
209
|
repetitions=request_body.repetitions,
|
|
170
210
|
metadata_=request_body.metadata or {},
|
|
171
211
|
project_name=project_name,
|
|
212
|
+
user_id=user_id,
|
|
172
213
|
)
|
|
173
|
-
|
|
174
|
-
|
|
214
|
+
|
|
215
|
+
if request_body.splits is not None:
|
|
216
|
+
# Resolve split identifiers (IDs or names) to IDs and names
|
|
217
|
+
resolved_split_ids, _ = await _resolve_split_identifiers(session, request_body.splits)
|
|
218
|
+
|
|
219
|
+
# Generate experiment dataset splits relation
|
|
220
|
+
# prior to the crosswalk table insert
|
|
221
|
+
# in insert_experiment_with_examples_snapshot
|
|
222
|
+
experiment.experiment_dataset_splits = [
|
|
223
|
+
models.ExperimentDatasetSplit(dataset_split_id=split_id)
|
|
224
|
+
for split_id in resolved_split_ids
|
|
225
|
+
]
|
|
226
|
+
|
|
227
|
+
# crosswalk table assumes the relation is already present
|
|
228
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
175
229
|
|
|
176
230
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
177
231
|
project_rowid = await session.scalar(
|
|
@@ -194,6 +248,19 @@ async def create_experiment(
|
|
|
194
248
|
dataset_version_globalid = GlobalID(
|
|
195
249
|
"DatasetVersion", str(experiment.dataset_version_id)
|
|
196
250
|
)
|
|
251
|
+
|
|
252
|
+
# Optimization: We just created this experiment, so we know there are 0 runs.
|
|
253
|
+
# No need to query ExperimentRun table - just count the examples.
|
|
254
|
+
example_count = await session.scalar(
|
|
255
|
+
select(func.count())
|
|
256
|
+
.select_from(models.ExperimentDatasetExample)
|
|
257
|
+
.where(models.ExperimentDatasetExample.experiment_id == experiment.id)
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
# No runs exist yet for a newly created experiment
|
|
261
|
+
successful_run_count = 0
|
|
262
|
+
failed_run_count = 0
|
|
263
|
+
missing_run_count = (example_count or 0) * experiment.repetitions
|
|
197
264
|
request.state.event_queue.put(ExperimentInsertEvent((experiment.id,)))
|
|
198
265
|
return CreateExperimentResponseBody(
|
|
199
266
|
data=Experiment(
|
|
@@ -205,6 +272,10 @@ async def create_experiment(
|
|
|
205
272
|
project_name=experiment.project_name,
|
|
206
273
|
created_at=experiment.created_at,
|
|
207
274
|
updated_at=experiment.updated_at,
|
|
275
|
+
example_count=example_count or 0,
|
|
276
|
+
successful_run_count=successful_run_count or 0,
|
|
277
|
+
failed_run_count=failed_run_count or 0,
|
|
278
|
+
missing_run_count=missing_run_count,
|
|
208
279
|
)
|
|
209
280
|
)
|
|
210
281
|
|
|
@@ -218,18 +289,24 @@ class GetExperimentResponseBody(ResponseBody[Experiment]):
|
|
|
218
289
|
operation_id="getExperiment",
|
|
219
290
|
summary="Get experiment by ID",
|
|
220
291
|
responses=add_errors_to_responses(
|
|
221
|
-
[{"status_code":
|
|
292
|
+
[{"status_code": 404, "description": "Experiment not found"}]
|
|
222
293
|
),
|
|
223
294
|
response_description="Experiment retrieved successfully",
|
|
224
295
|
)
|
|
225
296
|
async def get_experiment(request: Request, experiment_id: str) -> GetExperimentResponseBody:
|
|
226
|
-
|
|
297
|
+
try:
|
|
298
|
+
experiment_globalid = GlobalID.from_id(experiment_id)
|
|
299
|
+
except Exception as e:
|
|
300
|
+
raise HTTPException(
|
|
301
|
+
detail=f"Invalid experiment ID format: {experiment_id}",
|
|
302
|
+
status_code=422,
|
|
303
|
+
) from e
|
|
227
304
|
try:
|
|
228
305
|
experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
|
|
229
306
|
except ValueError:
|
|
230
307
|
raise HTTPException(
|
|
231
308
|
detail="Experiment with ID {experiment_globalid} does not exist",
|
|
232
|
-
status_code=
|
|
309
|
+
status_code=404,
|
|
233
310
|
)
|
|
234
311
|
|
|
235
312
|
async with request.app.state.db() as session:
|
|
@@ -240,11 +317,48 @@ async def get_experiment(request: Request, experiment_id: str) -> GetExperimentR
|
|
|
240
317
|
if not experiment:
|
|
241
318
|
raise HTTPException(
|
|
242
319
|
detail=f"Experiment with ID {experiment_globalid} does not exist",
|
|
243
|
-
status_code=
|
|
320
|
+
status_code=404,
|
|
244
321
|
)
|
|
245
322
|
|
|
246
323
|
dataset_globalid = GlobalID("Dataset", str(experiment.dataset_id))
|
|
247
324
|
dataset_version_globalid = GlobalID("DatasetVersion", str(experiment.dataset_version_id))
|
|
325
|
+
|
|
326
|
+
# Get counts efficiently: use CASE to count successful and failed in single table scan
|
|
327
|
+
run_counts_subq = (
|
|
328
|
+
select(
|
|
329
|
+
func.sum(case((models.ExperimentRun.error.is_(None), 1), else_=0)).label(
|
|
330
|
+
"successful_run_count"
|
|
331
|
+
),
|
|
332
|
+
func.sum(case((models.ExperimentRun.error.is_not(None), 1), else_=0)).label(
|
|
333
|
+
"failed_run_count"
|
|
334
|
+
),
|
|
335
|
+
)
|
|
336
|
+
.select_from(models.ExperimentRun)
|
|
337
|
+
.where(models.ExperimentRun.experiment_id == experiment_rowid)
|
|
338
|
+
.subquery()
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
counts_result = await session.execute(
|
|
342
|
+
select(
|
|
343
|
+
select(func.count())
|
|
344
|
+
.select_from(models.ExperimentDatasetExample)
|
|
345
|
+
.where(models.ExperimentDatasetExample.experiment_id == experiment_rowid)
|
|
346
|
+
.scalar_subquery()
|
|
347
|
+
.label("example_count"),
|
|
348
|
+
run_counts_subq.c.successful_run_count,
|
|
349
|
+
run_counts_subq.c.failed_run_count,
|
|
350
|
+
).select_from(run_counts_subq)
|
|
351
|
+
)
|
|
352
|
+
counts = counts_result.one()
|
|
353
|
+
example_count = counts.example_count
|
|
354
|
+
successful_run_count = counts.successful_run_count
|
|
355
|
+
failed_run_count = counts.failed_run_count
|
|
356
|
+
|
|
357
|
+
# Calculate missing runs (no database query needed)
|
|
358
|
+
total_expected_runs = (example_count or 0) * experiment.repetitions
|
|
359
|
+
missing_run_count = (
|
|
360
|
+
total_expected_runs - (successful_run_count or 0) - (failed_run_count or 0)
|
|
361
|
+
)
|
|
248
362
|
return GetExperimentResponseBody(
|
|
249
363
|
data=Experiment(
|
|
250
364
|
id=str(experiment_globalid),
|
|
@@ -255,31 +369,246 @@ async def get_experiment(request: Request, experiment_id: str) -> GetExperimentR
|
|
|
255
369
|
project_name=experiment.project_name,
|
|
256
370
|
created_at=experiment.created_at,
|
|
257
371
|
updated_at=experiment.updated_at,
|
|
372
|
+
example_count=example_count or 0,
|
|
373
|
+
successful_run_count=successful_run_count or 0,
|
|
374
|
+
failed_run_count=failed_run_count or 0,
|
|
375
|
+
missing_run_count=missing_run_count,
|
|
376
|
+
)
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
@router.delete(
    "/experiments/{experiment_id}",
    operation_id="deleteExperiment",
    summary="Delete experiment by ID",
    responses=add_errors_to_responses(
        [{"status_code": 404, "description": "Experiment not found"}]
    ),
    response_description="Experiment deleted successfully",
    status_code=204,
)
async def delete_experiment(
    request: Request,
    experiment_id: str,
) -> None:
    # Deletes a single experiment addressed by its global ID.
    #
    # Raises HTTPException 422 when the ID string cannot be parsed as a global
    # ID, and 404 when the ID is not an "Experiment" ID or no row was deleted.
    # (Kept as comments rather than a docstring so the published operation
    # description is unchanged.)

    # Step 1: decode the opaque ID string; any decoding failure is a client error.
    try:
        gid = GlobalID.from_id(experiment_id)
    except Exception as exc:
        raise HTTPException(
            detail=f"Invalid experiment ID format: {experiment_id}",
            status_code=422,
        ) from exc

    # Step 2: the decoded ID must refer to an Experiment; otherwise report it as missing.
    try:
        rowid = from_global_id_with_expected_type(gid, "Experiment")
    except ValueError:
        raise HTTPException(
            detail=f"Experiment with ID {gid} does not exist",
            status_code=404,
        )

    # Step 3: delete and ask the database to echo the deleted primary key, so a
    # missing row can be detected without a separate lookup.
    delete_stmt = sa.delete(models.Experiment)
    delete_stmt = delete_stmt.where(models.Experiment.id == rowid)
    delete_stmt = delete_stmt.returning(models.Experiment.id)

    async with request.app.state.db() as session:
        deleted_id = await session.scalar(delete_stmt)
        if deleted_id is None:
            raise HTTPException(detail="Experiment does not exist", status_code=404)
|
|
260
417
|
|
|
261
418
|
|
|
262
|
-
class ListExperimentsResponseBody(
|
|
419
|
+
# Paginated response envelope for experiment listings: a `data` list of
# `Experiment` items plus a `next_cursor` for the following page.
class ListExperimentsResponseBody(PaginatedResponseBody[Experiment]):
    pass
|
|
264
421
|
|
|
265
422
|
|
|
423
|
+
class IncompleteExperimentRun(V1RoutesBaseModel):
    """
    Information about incomplete runs for a dataset example
    """

    # The dataset example that still has outstanding repetitions.
    dataset_example: DatasetExample = Field(description="The dataset example")
    # Repetition numbers still to be executed for that example. The
    # incomplete-runs endpoint fills this with an ascending list and numbers
    # repetitions starting at 1.
    repetition_numbers: list[int] = Field(
        description="List of repetition numbers that need to be run"
    )
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
# Paginated response envelope for the incomplete-runs endpoint: a `data` list
# of `IncompleteExperimentRun` items plus a `next_cursor` for the following page.
class GetIncompleteExperimentRunsResponseBody(PaginatedResponseBody[IncompleteExperimentRun]):
    pass
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
@router.get(
    "/experiments/{experiment_id}/incomplete-runs",
    operation_id="getIncompleteExperimentRuns",
    summary="Get incomplete runs for an experiment",
    responses=add_errors_to_responses(
        [
            {"status_code": 404, "description": "Experiment not found"},
            {"status_code": 422, "description": "Invalid cursor format"},
        ]
    ),
    response_description="Incomplete runs retrieved successfully",
)
async def get_incomplete_runs(
    request: Request,
    experiment_id: str,
    cursor: Optional[str] = Query(default=None, description="Cursor for pagination"),
    limit: int = Query(
        default=50, description="Maximum number of examples with incomplete runs to return", gt=0
    ),
) -> GetIncompleteExperimentRunsResponseBody:
    """
    Get runs that need to be completed for this experiment.

    Returns all incomplete runs, including both missing runs (not yet attempted)
    and failed runs (attempted but have errors).

    Args:
        experiment_id: The ID of the experiment
        cursor: Cursor for pagination
        limit: Maximum number of results to return

    Returns:
        Paginated list of incomplete runs grouped by dataset example,
        with repetition numbers that need to be run
    """
    # --- Resolve the experiment's row id from the opaque global ID -----------
    try:
        gid = GlobalID.from_id(experiment_id)
    except Exception as exc:
        raise HTTPException(
            detail=f"Invalid experiment ID format: {experiment_id}",
            status_code=422,
        ) from exc
    try:
        experiment_rowid = from_global_id_with_expected_type(gid, "Experiment")
    except ValueError:
        raise HTTPException(
            detail=f"Experiment with ID {gid} does not exist",
            status_code=404,
        )

    # --- Decode the pagination cursor (a DatasetExample global ID), if any ---
    start_example_rowid: Optional[int] = None
    if cursor:
        try:
            start_example_rowid = from_global_id_with_expected_type(
                GlobalID.from_id(cursor), "DatasetExample"
            )
        except (ValueError, AttributeError):
            raise HTTPException(
                detail=f"Invalid cursor format: {cursor}",
                status_code=422,
            )

    async with request.app.state.db() as session:
        # The experiment row is loaded first and handed to the query builder.
        lookup = await session.execute(select(models.Experiment).filter_by(id=experiment_rowid))
        experiment = lookup.scalar()
        if not experiment:
            raise HTTPException(
                detail=f"Experiment with ID {gid} does not exist",
                status_code=404,
            )

        dialect = request.app.state.db.dialect
        rows = (
            await session.execute(
                get_experiment_incomplete_runs_query(
                    experiment,
                    dialect,
                    cursor_example_rowid=start_example_rowid,
                    limit=limit,
                )
            )
        ).all()

        # More than `limit` rows means there is a further page; the row just
        # past the page boundary supplies the next cursor and is not returned.
        next_cursor = None
        if len(rows) > limit:
            overflow_revision = rows[limit][0]
            next_cursor = str(GlobalID("DatasetExample", str(overflow_revision.dataset_example_id)))
        page_rows = rows[:limit]

        # Shared by every example with zero successful runs: all repetitions 1..R.
        every_repetition = list(range(1, experiment.repetitions + 1))

        # Row processing stays inside the session scope so ORM attribute access
        # never happens on a closed session.
        payload: list[IncompleteExperimentRun] = []
        for revision, successful_count, incomplete_reps in page_rows:
            if successful_count == 0:
                # Nothing succeeded for this example: every repetition is outstanding.
                pending = every_repetition
            elif dialect is SupportedSQLDialect.POSTGRESQL:
                # PostgreSQL hands back a list; drop null entries.
                pending = [rep for rep in incomplete_reps if rep is not None]
            else:
                # Other dialects (SQLite) hand back a JSON-encoded string; decode, then drop nulls.
                pending = [rep for rep in json.loads(incomplete_reps) if rep is not None]

            example = DatasetExample(
                id=str(GlobalID("DatasetExample", str(revision.dataset_example_id))),
                input=revision.input,
                output=revision.output,
                metadata=revision.metadata_,
                updated_at=revision.created_at,
            )
            payload.append(
                IncompleteExperimentRun(
                    dataset_example=example,
                    repetition_numbers=sorted(pending),
                )
            )

        return GetIncompleteExperimentRunsResponseBody(data=payload, next_cursor=next_cursor)
|
|
578
|
+
|
|
579
|
+
|
|
266
580
|
@router.get(
|
|
267
581
|
"/datasets/{dataset_id}/experiments",
|
|
268
582
|
operation_id="listExperiments",
|
|
269
583
|
summary="List experiments by dataset",
|
|
270
|
-
|
|
584
|
+
description="Retrieve a paginated list of experiments for the specified dataset.",
|
|
585
|
+
response_description="Paginated list of experiments for the dataset",
|
|
586
|
+
responses=add_errors_to_responses([422]),
|
|
271
587
|
)
|
|
272
588
|
async def list_experiments(
|
|
273
589
|
request: Request,
|
|
274
590
|
dataset_id: str = Path(..., title="Dataset ID"),
|
|
591
|
+
cursor: Optional[str] = Query(
|
|
592
|
+
default=None,
|
|
593
|
+
description="Cursor for pagination (base64-encoded experiment ID)",
|
|
594
|
+
),
|
|
595
|
+
limit: int = Query(
|
|
596
|
+
default=50, description="The max number of experiments to return at a time.", gt=0
|
|
597
|
+
),
|
|
275
598
|
) -> ListExperimentsResponseBody:
|
|
276
|
-
|
|
599
|
+
try:
|
|
600
|
+
dataset_gid = GlobalID.from_id(dataset_id)
|
|
601
|
+
except Exception as e:
|
|
602
|
+
raise HTTPException(
|
|
603
|
+
detail=f"Invalid dataset ID format: {dataset_id}",
|
|
604
|
+
status_code=422,
|
|
605
|
+
) from e
|
|
277
606
|
try:
|
|
278
607
|
dataset_rowid = from_global_id_with_expected_type(dataset_gid, "Dataset")
|
|
279
608
|
except ValueError:
|
|
280
609
|
raise HTTPException(
|
|
281
610
|
detail=f"Dataset with ID {dataset_gid} does not exist",
|
|
282
|
-
status_code=
|
|
611
|
+
status_code=404,
|
|
283
612
|
)
|
|
284
613
|
async with request.app.state.db() as session:
|
|
285
614
|
query = (
|
|
@@ -288,29 +617,119 @@ async def list_experiments(
|
|
|
288
617
|
.order_by(models.Experiment.id.desc())
|
|
289
618
|
)
|
|
290
619
|
|
|
620
|
+
# Handle cursor for pagination
|
|
621
|
+
if cursor:
|
|
622
|
+
try:
|
|
623
|
+
cursor_gid = GlobalID.from_id(cursor)
|
|
624
|
+
cursor_rowid = from_global_id_with_expected_type(cursor_gid, "Experiment")
|
|
625
|
+
query = query.where(models.Experiment.id <= cursor_rowid)
|
|
626
|
+
except (ValueError, Exception):
|
|
627
|
+
raise HTTPException(
|
|
628
|
+
detail=f"Invalid cursor format: {cursor}",
|
|
629
|
+
status_code=422,
|
|
630
|
+
)
|
|
631
|
+
|
|
632
|
+
# Overfetch by 1 to determine if there's a next page
|
|
633
|
+
query = query.limit(limit + 1)
|
|
634
|
+
|
|
291
635
|
result = await session.execute(query)
|
|
292
636
|
experiments = result.scalars().all()
|
|
293
637
|
|
|
294
638
|
if not experiments:
|
|
295
|
-
return ListExperimentsResponseBody(data=[])
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
639
|
+
return ListExperimentsResponseBody(data=[], next_cursor=None)
|
|
640
|
+
|
|
641
|
+
# Get example counts and successful run counts for all experiments in a single query
|
|
642
|
+
experiment_ids = [exp.id for exp in experiments]
|
|
643
|
+
|
|
644
|
+
# Create subqueries for counts
|
|
645
|
+
example_count_subq = (
|
|
646
|
+
select(
|
|
647
|
+
models.ExperimentDatasetExample.experiment_id, func.count().label("example_count")
|
|
648
|
+
)
|
|
649
|
+
.where(models.ExperimentDatasetExample.experiment_id.in_(experiment_ids))
|
|
650
|
+
.group_by(models.ExperimentDatasetExample.experiment_id)
|
|
651
|
+
.subquery()
|
|
652
|
+
)
|
|
653
|
+
|
|
654
|
+
# Optimize: Use CASE to count successful and failed in single table scan
|
|
655
|
+
run_counts_subq = (
|
|
656
|
+
select(
|
|
657
|
+
models.ExperimentRun.experiment_id,
|
|
658
|
+
func.sum(case((models.ExperimentRun.error.is_(None), 1), else_=0)).label(
|
|
659
|
+
"successful_run_count"
|
|
660
|
+
),
|
|
661
|
+
func.sum(case((models.ExperimentRun.error.is_not(None), 1), else_=0)).label(
|
|
662
|
+
"failed_run_count"
|
|
303
663
|
),
|
|
304
|
-
repetitions=experiment.repetitions,
|
|
305
|
-
metadata=experiment.metadata_,
|
|
306
|
-
project_name=None,
|
|
307
|
-
created_at=experiment.created_at,
|
|
308
|
-
updated_at=experiment.updated_at,
|
|
309
664
|
)
|
|
310
|
-
|
|
311
|
-
|
|
665
|
+
.where(models.ExperimentRun.experiment_id.in_(experiment_ids))
|
|
666
|
+
.group_by(models.ExperimentRun.experiment_id)
|
|
667
|
+
.subquery()
|
|
668
|
+
)
|
|
312
669
|
|
|
313
|
-
|
|
670
|
+
# Get all counts in one query using outer join
|
|
671
|
+
counts_result = await session.execute(
|
|
672
|
+
select(
|
|
673
|
+
func.coalesce(
|
|
674
|
+
example_count_subq.c.experiment_id,
|
|
675
|
+
run_counts_subq.c.experiment_id,
|
|
676
|
+
).label("experiment_id"),
|
|
677
|
+
func.coalesce(example_count_subq.c.example_count, 0).label("example_count"),
|
|
678
|
+
func.coalesce(run_counts_subq.c.successful_run_count, 0).label(
|
|
679
|
+
"successful_run_count"
|
|
680
|
+
),
|
|
681
|
+
func.coalesce(run_counts_subq.c.failed_run_count, 0).label("failed_run_count"),
|
|
682
|
+
)
|
|
683
|
+
.select_from(example_count_subq)
|
|
684
|
+
.outerjoin(
|
|
685
|
+
run_counts_subq,
|
|
686
|
+
example_count_subq.c.experiment_id == run_counts_subq.c.experiment_id,
|
|
687
|
+
)
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
counts_by_experiment = {
|
|
691
|
+
row.experiment_id: (row.example_count, row.successful_run_count, row.failed_run_count)
|
|
692
|
+
for row in counts_result
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
# Handle pagination: check if we have a next page
|
|
696
|
+
next_cursor = None
|
|
697
|
+
if len(experiments) == limit + 1:
|
|
698
|
+
last_experiment = experiments[-1]
|
|
699
|
+
next_cursor = str(GlobalID("Experiment", str(last_experiment.id)))
|
|
700
|
+
experiments = experiments[:-1] # Remove the extra overfetched experiment
|
|
701
|
+
|
|
702
|
+
data = []
|
|
703
|
+
for experiment in experiments:
|
|
704
|
+
counts = counts_by_experiment.get(experiment.id, (0, 0, 0))
|
|
705
|
+
example_count = counts[0]
|
|
706
|
+
successful_run_count = counts[1]
|
|
707
|
+
failed_run_count = counts[2]
|
|
708
|
+
|
|
709
|
+
# Calculate missing runs (no database query needed)
|
|
710
|
+
total_expected_runs = example_count * experiment.repetitions
|
|
711
|
+
missing_run_count = total_expected_runs - successful_run_count - failed_run_count
|
|
712
|
+
|
|
713
|
+
data.append(
|
|
714
|
+
Experiment(
|
|
715
|
+
id=str(GlobalID("Experiment", str(experiment.id))),
|
|
716
|
+
dataset_id=str(GlobalID("Dataset", str(experiment.dataset_id))),
|
|
717
|
+
dataset_version_id=str(
|
|
718
|
+
GlobalID("DatasetVersion", str(experiment.dataset_version_id))
|
|
719
|
+
),
|
|
720
|
+
repetitions=experiment.repetitions,
|
|
721
|
+
metadata=experiment.metadata_,
|
|
722
|
+
project_name=experiment.project_name,
|
|
723
|
+
created_at=experiment.created_at,
|
|
724
|
+
updated_at=experiment.updated_at,
|
|
725
|
+
example_count=example_count,
|
|
726
|
+
successful_run_count=successful_run_count,
|
|
727
|
+
failed_run_count=failed_run_count,
|
|
728
|
+
missing_run_count=missing_run_count,
|
|
729
|
+
)
|
|
730
|
+
)
|
|
731
|
+
|
|
732
|
+
return ListExperimentsResponseBody(data=data, next_cursor=next_cursor)
|
|
314
733
|
|
|
315
734
|
|
|
316
735
|
async def _get_experiment_runs_and_revisions(
|
|
@@ -318,7 +737,7 @@ async def _get_experiment_runs_and_revisions(
|
|
|
318
737
|
) -> tuple[models.Experiment, tuple[models.ExperimentRun], tuple[models.DatasetExampleRevision]]:
|
|
319
738
|
experiment = await session.get(models.Experiment, experiment_rowid)
|
|
320
739
|
if not experiment:
|
|
321
|
-
raise HTTPException(detail="Experiment not found", status_code=
|
|
740
|
+
raise HTTPException(detail="Experiment not found", status_code=404)
|
|
322
741
|
revision_ids = (
|
|
323
742
|
select(func.max(models.DatasetExampleRevision.id))
|
|
324
743
|
.join(
|
|
@@ -367,7 +786,7 @@ async def _get_experiment_runs_and_revisions(
|
|
|
367
786
|
if not runs_and_revisions:
|
|
368
787
|
raise HTTPException(
|
|
369
788
|
detail="Experiment has no runs",
|
|
370
|
-
status_code=
|
|
789
|
+
status_code=404,
|
|
371
790
|
)
|
|
372
791
|
runs, revisions = zip(*runs_and_revisions)
|
|
373
792
|
return experiment, runs, revisions
|
|
@@ -380,7 +799,7 @@ async def _get_experiment_runs_and_revisions(
|
|
|
380
799
|
response_class=PlainTextResponse,
|
|
381
800
|
responses=add_errors_to_responses(
|
|
382
801
|
[
|
|
383
|
-
{"status_code":
|
|
802
|
+
{"status_code": 404, "description": "Experiment not found"},
|
|
384
803
|
]
|
|
385
804
|
),
|
|
386
805
|
)
|
|
@@ -388,13 +807,19 @@ async def get_experiment_json(
|
|
|
388
807
|
request: Request,
|
|
389
808
|
experiment_id: str = Path(..., title="Experiment ID"),
|
|
390
809
|
) -> Response:
|
|
391
|
-
|
|
810
|
+
try:
|
|
811
|
+
experiment_globalid = GlobalID.from_id(experiment_id)
|
|
812
|
+
except Exception as e:
|
|
813
|
+
raise HTTPException(
|
|
814
|
+
detail=f"Invalid experiment ID format: {experiment_id}",
|
|
815
|
+
status_code=422,
|
|
816
|
+
) from e
|
|
392
817
|
try:
|
|
393
818
|
experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
|
|
394
819
|
except ValueError:
|
|
395
820
|
raise HTTPException(
|
|
396
821
|
detail=f"Invalid experiment ID: {experiment_globalid}",
|
|
397
|
-
status_code=
|
|
822
|
+
status_code=422,
|
|
398
823
|
)
|
|
399
824
|
|
|
400
825
|
async with request.app.state.db() as session:
|
|
@@ -449,19 +874,25 @@ async def get_experiment_json(
|
|
|
449
874
|
"/experiments/{experiment_id}/csv",
|
|
450
875
|
operation_id="getExperimentCSV",
|
|
451
876
|
summary="Download experiment runs as a CSV file",
|
|
452
|
-
responses={**add_text_csv_content_to_responses(
|
|
877
|
+
responses={**add_text_csv_content_to_responses(200)},
|
|
453
878
|
)
|
|
454
879
|
async def get_experiment_csv(
|
|
455
880
|
request: Request,
|
|
456
881
|
experiment_id: str = Path(..., title="Experiment ID"),
|
|
457
882
|
) -> Response:
|
|
458
|
-
|
|
883
|
+
try:
|
|
884
|
+
experiment_globalid = GlobalID.from_id(experiment_id)
|
|
885
|
+
except Exception as e:
|
|
886
|
+
raise HTTPException(
|
|
887
|
+
detail=f"Invalid experiment ID format: {experiment_id}",
|
|
888
|
+
status_code=422,
|
|
889
|
+
) from e
|
|
459
890
|
try:
|
|
460
891
|
experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
|
|
461
892
|
except ValueError:
|
|
462
893
|
raise HTTPException(
|
|
463
894
|
detail=f"Invalid experiment ID: {experiment_globalid}",
|
|
464
|
-
status_code=
|
|
895
|
+
status_code=422,
|
|
465
896
|
)
|
|
466
897
|
|
|
467
898
|
async with request.app.state.db() as session:
|