arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/db/helpers.py
CHANGED
|
@@ -1,22 +1,28 @@
|
|
|
1
1
|
from collections.abc import Callable, Hashable, Iterable
|
|
2
|
+
from datetime import datetime
|
|
2
3
|
from enum import Enum
|
|
3
|
-
from typing import Any, Optional, TypeVar
|
|
4
|
+
from typing import Any, Literal, Optional, Sequence, TypeVar, Union
|
|
4
5
|
|
|
5
|
-
|
|
6
|
-
OpenInferenceSpanKindValues,
|
|
7
|
-
RerankerAttributes,
|
|
8
|
-
SpanAttributes,
|
|
9
|
-
)
|
|
6
|
+
import sqlalchemy as sa
|
|
10
7
|
from sqlalchemy import (
|
|
11
|
-
|
|
8
|
+
Insert,
|
|
12
9
|
Select,
|
|
13
10
|
SQLColumnExpression,
|
|
14
11
|
and_,
|
|
15
12
|
case,
|
|
16
13
|
distinct,
|
|
14
|
+
exists,
|
|
17
15
|
func,
|
|
16
|
+
insert,
|
|
17
|
+
literal,
|
|
18
|
+
literal_column,
|
|
19
|
+
or_,
|
|
18
20
|
select,
|
|
21
|
+
util,
|
|
19
22
|
)
|
|
23
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
24
|
+
from sqlalchemy.orm import QueryableAttribute
|
|
25
|
+
from sqlalchemy.sql.roles import InElementRole
|
|
20
26
|
from typing_extensions import assert_never
|
|
21
27
|
|
|
22
28
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
@@ -34,30 +40,6 @@ class SupportedSQLDialect(Enum):
|
|
|
34
40
|
raise ValueError(f"`{v}` is not a supported SQL backend/dialect.")
|
|
35
41
|
|
|
36
42
|
|
|
37
|
-
def num_docs_col(dialect: SupportedSQLDialect) -> SQLColumnExpression[Integer]:
|
|
38
|
-
if dialect is SupportedSQLDialect.POSTGRESQL:
|
|
39
|
-
array_length = func.jsonb_array_length
|
|
40
|
-
elif dialect is SupportedSQLDialect.SQLITE:
|
|
41
|
-
array_length = func.json_array_length
|
|
42
|
-
else:
|
|
43
|
-
assert_never(dialect)
|
|
44
|
-
retrieval_docs = models.Span.attributes[_RETRIEVAL_DOCUMENTS]
|
|
45
|
-
num_retrieval_docs = array_length(retrieval_docs)
|
|
46
|
-
reranker_docs = models.Span.attributes[_RERANKER_OUTPUT_DOCUMENTS]
|
|
47
|
-
num_reranker_docs = array_length(reranker_docs)
|
|
48
|
-
return case(
|
|
49
|
-
(
|
|
50
|
-
func.upper(models.Span.span_kind) == OpenInferenceSpanKindValues.RERANKER.value.upper(),
|
|
51
|
-
num_reranker_docs,
|
|
52
|
-
),
|
|
53
|
-
else_=num_retrieval_docs,
|
|
54
|
-
).label("num_docs")
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
_RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS.split(".")
|
|
58
|
-
_RERANKER_OUTPUT_DOCUMENTS = RerankerAttributes.RERANKER_OUTPUT_DOCUMENTS.split(".")
|
|
59
|
-
|
|
60
|
-
|
|
61
43
|
def get_eval_trace_ids_for_datasets(*dataset_ids: int) -> Select[tuple[Optional[str]]]:
|
|
62
44
|
return (
|
|
63
45
|
select(distinct(models.ExperimentRunAnnotation.trace_id))
|
|
@@ -115,51 +97,205 @@ def dedup(
|
|
|
115
97
|
return ans
|
|
116
98
|
|
|
117
99
|
|
|
118
|
-
def
|
|
100
|
+
def _build_ranked_revisions_query(
|
|
119
101
|
dataset_version_id: int,
|
|
120
|
-
|
|
121
|
-
|
|
102
|
+
/,
|
|
103
|
+
*,
|
|
104
|
+
dataset_id: Optional[int] = None,
|
|
105
|
+
example_ids: Optional[Union[Sequence[int], InElementRole]] = None,
|
|
106
|
+
) -> Select[tuple[int]]:
|
|
107
|
+
"""
|
|
108
|
+
Build a query that ranks revisions per example within a dataset version.
|
|
109
|
+
|
|
110
|
+
This performs the core ranking logic using ROW_NUMBER() to find the latest
|
|
111
|
+
revision for each example within the specified dataset version.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
dataset_version_id: Maximum dataset version to consider
|
|
115
|
+
dataset_id: Optional dataset ID - if provided, avoids subquery lookup
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
SQLAlchemy SELECT query with revision ranking and basic dataset filtering
|
|
119
|
+
"""
|
|
120
|
+
stmt = (
|
|
122
121
|
select(
|
|
123
|
-
|
|
124
|
-
|
|
122
|
+
func.row_number()
|
|
123
|
+
.over(
|
|
124
|
+
partition_by=models.DatasetExampleRevision.dataset_example_id,
|
|
125
|
+
order_by=models.DatasetExampleRevision.dataset_version_id.desc(),
|
|
126
|
+
)
|
|
127
|
+
.label("rn"),
|
|
125
128
|
)
|
|
126
|
-
.
|
|
127
|
-
.
|
|
129
|
+
.join(models.DatasetExample)
|
|
130
|
+
.where(models.DatasetExampleRevision.dataset_version_id <= dataset_version_id)
|
|
128
131
|
)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
132
|
+
|
|
133
|
+
if dataset_id is None:
|
|
134
|
+
version_subquery = (
|
|
135
|
+
select(models.DatasetVersion.dataset_id)
|
|
136
|
+
.filter_by(id=dataset_version_id)
|
|
137
|
+
.scalar_subquery()
|
|
138
|
+
)
|
|
139
|
+
stmt = stmt.where(models.DatasetExample.dataset_id == version_subquery)
|
|
140
|
+
else:
|
|
141
|
+
stmt = stmt.where(models.DatasetExample.dataset_id == dataset_id)
|
|
142
|
+
|
|
143
|
+
if example_ids is not None:
|
|
144
|
+
stmt = stmt.where(models.DatasetExampleRevision.dataset_example_id.in_(example_ids))
|
|
145
|
+
|
|
146
|
+
return stmt
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def get_dataset_example_revisions(
|
|
150
|
+
dataset_version_id: int,
|
|
151
|
+
/,
|
|
152
|
+
*,
|
|
153
|
+
dataset_id: Optional[int] = None,
|
|
154
|
+
example_ids: Optional[Union[Sequence[int], InElementRole]] = None,
|
|
155
|
+
split_ids: Optional[Union[Sequence[int], InElementRole]] = None,
|
|
156
|
+
split_names: Optional[Union[Sequence[str], InElementRole]] = None,
|
|
157
|
+
) -> Select[tuple[models.DatasetExampleRevision]]:
|
|
158
|
+
"""
|
|
159
|
+
Get the latest revisions for all dataset examples within a specific dataset version.
|
|
160
|
+
|
|
161
|
+
Excludes examples where the latest revision is a DELETE.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
dataset_version_id: The dataset version to get revisions for
|
|
165
|
+
dataset_id: Optional dataset ID - if provided, avoids extra subquery lookup
|
|
166
|
+
example_ids: Optional filter by specific example IDs (subquery or list of IDs).
|
|
167
|
+
- None = no filtering
|
|
168
|
+
- Empty sequences/subqueries = no matches (strict filtering)
|
|
169
|
+
split_ids: Optional filter by split IDs (subquery or list of split IDs).
|
|
170
|
+
- None = no filtering
|
|
171
|
+
- Empty sequences/subqueries = no matches (strict filtering)
|
|
172
|
+
split_names: Optional filter by split names (subquery or list of split names).
|
|
173
|
+
- None = no filtering
|
|
174
|
+
- Empty sequences/subqueries = no matches (strict filtering)
|
|
175
|
+
|
|
176
|
+
Note:
|
|
177
|
+
- split_ids and split_names are mutually exclusive
|
|
178
|
+
- Use split_ids for better performance when IDs are available (avoids JOIN)
|
|
179
|
+
- Empty filters use strict behavior: empty inputs return zero results
|
|
180
|
+
"""
|
|
181
|
+
if split_ids is not None and split_names is not None:
|
|
182
|
+
raise ValueError(
|
|
183
|
+
"Cannot specify both split_ids and split_names - they are mutually exclusive"
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
stmt = _build_ranked_revisions_query(
|
|
187
|
+
dataset_version_id,
|
|
188
|
+
dataset_id=dataset_id,
|
|
189
|
+
example_ids=example_ids,
|
|
190
|
+
).add_columns(
|
|
191
|
+
models.DatasetExampleRevision.id,
|
|
192
|
+
models.DatasetExampleRevision.revision_kind,
|
|
149
193
|
)
|
|
194
|
+
|
|
195
|
+
if split_ids is not None or split_names is not None:
|
|
196
|
+
if split_names is not None:
|
|
197
|
+
split_example_ids_subquery = (
|
|
198
|
+
select(models.DatasetSplitDatasetExample.dataset_example_id)
|
|
199
|
+
.join(
|
|
200
|
+
models.DatasetSplit,
|
|
201
|
+
models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id,
|
|
202
|
+
)
|
|
203
|
+
.where(models.DatasetSplit.name.in_(split_names))
|
|
204
|
+
)
|
|
205
|
+
stmt = stmt.where(models.DatasetExample.id.in_(split_example_ids_subquery))
|
|
206
|
+
else:
|
|
207
|
+
assert split_ids is not None
|
|
208
|
+
split_example_ids_subquery = select(
|
|
209
|
+
models.DatasetSplitDatasetExample.dataset_example_id
|
|
210
|
+
).where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_ids))
|
|
211
|
+
stmt = stmt.where(models.DatasetExample.id.in_(split_example_ids_subquery))
|
|
212
|
+
|
|
213
|
+
ranked_subquery = stmt.subquery()
|
|
150
214
|
return (
|
|
151
|
-
select(
|
|
152
|
-
.where(table.revision_kind != "DELETE")
|
|
215
|
+
select(models.DatasetExampleRevision)
|
|
153
216
|
.join(
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
217
|
+
ranked_subquery,
|
|
218
|
+
models.DatasetExampleRevision.id == ranked_subquery.c.id,
|
|
219
|
+
)
|
|
220
|
+
.where(
|
|
221
|
+
ranked_subquery.c.rn == 1,
|
|
222
|
+
ranked_subquery.c.revision_kind != "DELETE",
|
|
159
223
|
)
|
|
160
224
|
)
|
|
161
225
|
|
|
162
226
|
|
|
227
|
+
def create_experiment_examples_snapshot_insert(
|
|
228
|
+
experiment: models.Experiment,
|
|
229
|
+
) -> Insert:
|
|
230
|
+
"""
|
|
231
|
+
Create an INSERT statement to snapshot dataset examples for an experiment.
|
|
232
|
+
|
|
233
|
+
This captures which examples belong to the experiment at the time of creation,
|
|
234
|
+
respecting any dataset splits assigned to the experiment.
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
experiment: The experiment to create the snapshot for
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
SQLAlchemy INSERT statement ready for execution
|
|
241
|
+
"""
|
|
242
|
+
stmt = _build_ranked_revisions_query(
|
|
243
|
+
experiment.dataset_version_id,
|
|
244
|
+
dataset_id=experiment.dataset_id,
|
|
245
|
+
).add_columns(
|
|
246
|
+
models.DatasetExampleRevision.id,
|
|
247
|
+
models.DatasetExampleRevision.dataset_example_id,
|
|
248
|
+
models.DatasetExampleRevision.revision_kind,
|
|
249
|
+
)
|
|
250
|
+
|
|
251
|
+
experiment_splits_subquery = select(models.ExperimentDatasetSplit.dataset_split_id).where(
|
|
252
|
+
models.ExperimentDatasetSplit.experiment_id == experiment.id
|
|
253
|
+
)
|
|
254
|
+
has_splits_condition = exists(experiment_splits_subquery)
|
|
255
|
+
split_filtered_example_ids = select(models.DatasetSplitDatasetExample.dataset_example_id).where(
|
|
256
|
+
models.DatasetSplitDatasetExample.dataset_split_id.in_(experiment_splits_subquery)
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
stmt = stmt.where(
|
|
260
|
+
or_(
|
|
261
|
+
~has_splits_condition, # No splits = include all examples
|
|
262
|
+
models.DatasetExampleRevision.dataset_example_id.in_(
|
|
263
|
+
split_filtered_example_ids
|
|
264
|
+
), # Has splits = filter by splits
|
|
265
|
+
)
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
ranked_subquery = stmt.subquery()
|
|
269
|
+
return insert(models.ExperimentDatasetExample).from_select(
|
|
270
|
+
[
|
|
271
|
+
models.ExperimentDatasetExample.experiment_id,
|
|
272
|
+
models.ExperimentDatasetExample.dataset_example_id,
|
|
273
|
+
models.ExperimentDatasetExample.dataset_example_revision_id,
|
|
274
|
+
],
|
|
275
|
+
select(
|
|
276
|
+
literal(experiment.id),
|
|
277
|
+
ranked_subquery.c.dataset_example_id,
|
|
278
|
+
ranked_subquery.c.id,
|
|
279
|
+
).where(
|
|
280
|
+
ranked_subquery.c.rn == 1,
|
|
281
|
+
ranked_subquery.c.revision_kind != "DELETE",
|
|
282
|
+
),
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
async def insert_experiment_with_examples_snapshot(
|
|
287
|
+
session: AsyncSession,
|
|
288
|
+
experiment: models.Experiment,
|
|
289
|
+
) -> None:
|
|
290
|
+
"""
|
|
291
|
+
Insert an experiment with its snapshot of dataset examples.
|
|
292
|
+
"""
|
|
293
|
+
session.add(experiment)
|
|
294
|
+
await session.flush()
|
|
295
|
+
insert_stmt = create_experiment_examples_snapshot_insert(experiment)
|
|
296
|
+
await session.execute(insert_stmt)
|
|
297
|
+
|
|
298
|
+
|
|
163
299
|
_AnyTuple = TypeVar("_AnyTuple", bound=tuple[Any, ...])
|
|
164
300
|
|
|
165
301
|
|
|
@@ -173,3 +309,802 @@ def exclude_experiment_projects(
|
|
|
173
309
|
models.Experiment.project_name != PLAYGROUND_PROJECT_NAME,
|
|
174
310
|
),
|
|
175
311
|
).where(models.Experiment.project_name.is_(None))
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def date_trunc(
|
|
315
|
+
dialect: SupportedSQLDialect,
|
|
316
|
+
field: Literal["minute", "hour", "day", "week", "month", "year"],
|
|
317
|
+
source: Union[QueryableAttribute[datetime], sa.TextClause],
|
|
318
|
+
utc_offset_minutes: int = 0,
|
|
319
|
+
) -> SQLColumnExpression[datetime]:
|
|
320
|
+
"""
|
|
321
|
+
Truncate a datetime to the specified field with optional UTC offset adjustment.
|
|
322
|
+
|
|
323
|
+
This function provides a cross-dialect way to truncate datetime values to a specific
|
|
324
|
+
time unit (minute, hour, day, week, month, or year). It handles UTC offset conversion
|
|
325
|
+
by applying the offset before truncation and then converting back to UTC.
|
|
326
|
+
|
|
327
|
+
Args:
|
|
328
|
+
dialect: The SQL dialect to use (PostgreSQL or SQLite).
|
|
329
|
+
field: The time unit to truncate to. Valid values are:
|
|
330
|
+
- "minute": Truncate to the start of the minute (seconds set to 0)
|
|
331
|
+
- "hour": Truncate to the start of the hour (minutes and seconds set to 0)
|
|
332
|
+
- "day": Truncate to the start of the day (time set to 00:00:00)
|
|
333
|
+
- "week": Truncate to the start of the week (Monday at 00:00:00)
|
|
334
|
+
- "month": Truncate to the first day of the month (day set to 1, time to 00:00:00)
|
|
335
|
+
- "year": Truncate to the first day of the year (date set to Jan 1, time to 00:00:00)
|
|
336
|
+
source: The datetime column or expression to truncate.
|
|
337
|
+
utc_offset_minutes: UTC offset in minutes to apply before truncation.
|
|
338
|
+
Positive values represent time zones ahead of UTC (e.g., +60 for UTC+1).
|
|
339
|
+
Negative values represent time zones behind UTC (e.g., -300 for UTC-5).
|
|
340
|
+
Defaults to 0 (no offset).
|
|
341
|
+
|
|
342
|
+
Returns:
|
|
343
|
+
A SQL column expression representing the truncated datetime in UTC.
|
|
344
|
+
|
|
345
|
+
Note:
|
|
346
|
+
- For PostgreSQL, uses the native `date_trunc` function with timezone support.
|
|
347
|
+
- For SQLite, implements custom truncation logic using datetime functions.
|
|
348
|
+
- Week truncation starts on Monday (ISO 8601 standard).
|
|
349
|
+
- The result is always returned in UTC, regardless of the input offset.
|
|
350
|
+
|
|
351
|
+
Examples:
|
|
352
|
+
>>> # Truncate to hour with no offset
|
|
353
|
+
>>> date_trunc(SupportedSQLDialect.POSTGRESQL, "hour", Span.start_time)
|
|
354
|
+
|
|
355
|
+
>>> # Truncate to day with UTC-5 offset (Eastern Time)
|
|
356
|
+
>>> date_trunc(SupportedSQLDialect.SQLITE, "day", Span.start_time, -300)
|
|
357
|
+
"""
|
|
358
|
+
if dialect is SupportedSQLDialect.POSTGRESQL:
|
|
359
|
+
# Note: the usage of the timezone parameter in the form of e.g. "+05:00"
|
|
360
|
+
# appears to be an undocumented feature of PostgreSQL's date_trunc function.
|
|
361
|
+
# Below is an example query and its output executed on PostgreSQL v12 and v17.
|
|
362
|
+
# SELECT date_trunc('day', TIMESTAMP WITH TIME ZONE '2001-02-16 15:38:40-05'),
|
|
363
|
+
# date_trunc('day', TIMESTAMP WITH TIME ZONE '2001-02-16 20:38:40+00', '+05:00'),
|
|
364
|
+
# date_trunc('day', TIMESTAMP WITH TIME ZONE '2001-02-16 20:38:40+00', '-05:00');
|
|
365
|
+
# ┌────────────────────────┬────────────────────────┬────────────────────────┐
|
|
366
|
+
# │ date_trunc │ date_trunc │ date_trunc │
|
|
367
|
+
# ├────────────────────────┼────────────────────────┼────────────────────────┤
|
|
368
|
+
# │ 2001-02-16 00:00:00+00 │ 2001-02-16 05:00:00+00 │ 2001-02-16 19:00:00+00 │
|
|
369
|
+
# └────────────────────────┴────────────────────────┴────────────────────────┘
|
|
370
|
+
# (1 row)
|
|
371
|
+
sign = "-" if utc_offset_minutes >= 0 else "+"
|
|
372
|
+
timezone = f"{sign}{abs(utc_offset_minutes) // 60}:{abs(utc_offset_minutes) % 60:02d}"
|
|
373
|
+
return sa.func.date_trunc(field, source, timezone)
|
|
374
|
+
elif dialect is SupportedSQLDialect.SQLITE:
|
|
375
|
+
return _date_trunc_for_sqlite(field, source, utc_offset_minutes)
|
|
376
|
+
else:
|
|
377
|
+
assert_never(dialect)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _date_trunc_for_sqlite(
    field: Literal["minute", "hour", "day", "week", "month", "year"],
    source: Union[QueryableAttribute[datetime], sa.TextClause],
    utc_offset_minutes: int = 0,
) -> SQLColumnExpression[datetime]:
    """
    Emulate PostgreSQL's date_trunc for SQLite, honoring a UTC offset.

    SQLite ships no native date_trunc, so this helper builds an equivalent
    expression out of datetime()/strftime()/date() calls. The offset is added
    before truncating and subtracted again afterwards, so the returned
    expression is always in UTC.

    Args:
        field: Unit to truncate to — "minute", "hour", "day", "week"
            (weeks start on Monday), "month", or "year".
        source: Datetime column or text expression to truncate.
        utc_offset_minutes: Offset in minutes applied before truncation
            (positive = ahead of UTC, e.g. +60; negative = behind, e.g. -300).

    Returns:
        A SQL column expression for the truncated datetime, expressed in UTC.

    Raises:
        ValueError: If ``field`` is not one of the supported units.

    Note:
        Private helper; only meant to be called by ``date_trunc`` when the
        dialect is ``SupportedSQLDialect.SQLITE``.
    """
    # Shift into the caller's local time zone before truncating.
    shifted = func.datetime(source, f"{utc_offset_minutes} minutes")

    # minute/hour/day truncation is a pure strftime reformat — dispatch via table.
    simple_formats = {
        "minute": "%Y-%m-%d %H:%M:00",
        "hour": "%Y-%m-%d %H:00:00",
        "day": "%Y-%m-%d 00:00:00",
    }
    if field in simple_formats:
        truncated = func.datetime(func.strftime(simple_formats[field], shifted))
    elif field == "week":
        # strftime('%w') yields 0=Sunday, 1=Monday, ..., 6=Saturday;
        # roll back to the most recent Monday at midnight.
        weekday = func.strftime("%w", shifted)
        truncated = func.datetime(
            case(
                (weekday == "0", func.date(shifted, "-6 days")),  # Sunday
                (weekday == "1", func.date(shifted, "+0 days")),  # Monday
                (weekday == "2", func.date(shifted, "-1 days")),  # Tuesday
                (weekday == "3", func.date(shifted, "-2 days")),  # Wednesday
                (weekday == "4", func.date(shifted, "-3 days")),  # Thursday
                (weekday == "5", func.date(shifted, "-4 days")),  # Friday
                (weekday == "6", func.date(shifted, "-5 days")),  # Saturday
            ),
            "00:00:00",
        )
    elif field == "month":
        # Rebuild the first instant of the month from its extracted components.
        year = func.strftime("%Y", shifted)
        month = func.strftime("%m", shifted)
        truncated = func.datetime(year + "-" + month + "-01 00:00:00")
    elif field == "year":
        # Rebuild the first instant of the year from the extracted year.
        year = func.strftime("%Y", shifted)
        truncated = func.datetime(year + "-01-01 00:00:00")
    else:
        raise ValueError(f"Unsupported field for date truncation: {field}")

    # Undo the offset so the result is expressed in UTC again.
    return func.datetime(truncated, f"{-utc_offset_minutes} minutes")
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def get_ancestor_span_rowids(parent_id: str) -> Select[tuple[int]]:
    """
    Build a query for the rowids of every ancestor of the span named by ``parent_id``.

    A recursive CTE walks the parent chain upward: the anchor member selects
    the span whose span_id equals ``parent_id``, and each recursive step joins
    to the parent of the previously found span.

    Args:
        parent_id: The span_id at which the upward traversal begins.

    Returns:
        A Select yielding one-column rows of (span rowid,) for all ancestors.
    """
    anchor = (
        select(models.Span.id, models.Span.parent_id)
        .where(models.Span.span_id == parent_id)
        .cte(recursive=True)
    )
    previous_level = anchor.alias()
    chain = anchor.union_all(
        select(models.Span.id, models.Span.parent_id).join(
            previous_level,
            models.Span.span_id == previous_level.c.parent_id,
        )
    )
    return select(chain.c.id)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def truncate_name(name: str, max_len: int = 63) -> str:
    """
    Shorten ``name`` to at most ``max_len`` characters while keeping it distinctive.

    Mirrors SQLAlchemy's identifier-truncation scheme: over-long names keep a
    prefix and gain an underscore plus the last four hex digits of the name's
    MD5, so distinct long names rarely collide after truncation.
    https://github.com/sqlalchemy/sqlalchemy/blob/e263825e3c5060bf4f47eed0e833c6660a31658e/lib/sqlalchemy/sql/compiler.py#L7844-L7845
    """
    if len(name) <= max_len:
        return name
    digest_suffix = "_" + util.md5_hex(name)[-4:]
    return name[: max_len - 8] + digest_suffix
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def get_successful_run_counts_subquery(
    experiment_id: int,
    repetitions: int,
) -> Any:
    """
    Build a subquery counting successful runs per dataset example of an experiment.

    Experiment dataset examples are outer-joined to their runs and only runs that
    both exist and finished without an error are counted. The HAVING clause keeps
    only examples whose successful-run count is below the required repetitions.

    Args:
        experiment_id: The experiment whose runs are counted.
        repetitions: Number of repetitions each example must reach.

    Returns:
        A subquery exposing:
            - dataset_example_revision_id: ID of the example revision
            - dataset_example_id: ID of the dataset example
            - successful_count: number of successful runs for the example
    """
    # A run counts as successful only when it actually exists (the outer join
    # produces NULL run columns for examples without runs) AND it has no error.
    success_flag = case(
        (
            and_(
                models.ExperimentRun.id.is_not(None),
                models.ExperimentRun.error.is_(None),
            ),
            1,
        ),
        else_=0,
    )
    run_join_condition = and_(
        models.ExperimentRun.experiment_id == experiment_id,
        models.ExperimentRun.dataset_example_id
        == models.ExperimentDatasetExample.dataset_example_id,
    )
    stmt = (
        select(
            models.ExperimentDatasetExample.dataset_example_revision_id,
            models.ExperimentDatasetExample.dataset_example_id,
            func.sum(success_flag).label("successful_count"),
        )
        .select_from(models.ExperimentDatasetExample)
        .outerjoin(models.ExperimentRun, run_join_condition)
        .where(models.ExperimentDatasetExample.experiment_id == experiment_id)
        .group_by(
            models.ExperimentDatasetExample.dataset_example_revision_id,
            models.ExperimentDatasetExample.dataset_example_id,
        )
        # Keep only incomplete examples (successful_count < repetitions).
        .having(func.coalesce(func.sum(success_flag), 0) < repetitions)
    )
    return stmt.subquery()
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def generate_expected_repetitions_cte(
    dialect: SupportedSQLDialect,
    run_counts_subquery: Any,
    repetitions: int,
) -> Any:
    """
    Build a CTE of expected repetition numbers for partially complete examples.

    Each example with at least one successful run (0 < successful_count <
    repetitions) is fanned out into one row per repetition number in
    [1..repetitions]. The mechanism is dialect-specific.

    Args:
        dialect: SQL dialect to target (PostgreSQL or SQLite).
        run_counts_subquery: Subquery from ``get_successful_run_counts_subquery``
            with dataset_example_revision_id, dataset_example_id, and
            successful_count columns.
        repetitions: Total repetitions required per example.

    Returns:
        A CTE with columns:
            - dataset_example_revision_id
            - dataset_example_id
            - successful_count
            - repetition_number (1..repetitions)

    Note:
        PostgreSQL uses generate_series; SQLite builds the sequence with a
        recursive CTE since it lacks generate_series by default.
    """
    if dialect is SupportedSQLDialect.POSTGRESQL:
        # generate_series in the SELECT list fans each qualifying row out into
        # one row per expected repetition number; SQLAlchemy binds safely.
        return (
            select(
                run_counts_subquery.c.dataset_example_revision_id,
                run_counts_subquery.c.dataset_example_id,
                run_counts_subquery.c.successful_count,
                func.generate_series(1, repetitions).label("repetition_number"),
            )
            .select_from(run_counts_subquery)
            .where(run_counts_subquery.c.successful_count > 0)  # partially complete only
            .cte("expected_runs")
        )
    elif dialect is SupportedSQLDialect.SQLITE:
        # Seed the recursion at repetition 1 for partially complete examples...
        seed = (
            select(
                run_counts_subquery.c.dataset_example_revision_id,
                run_counts_subquery.c.dataset_example_id,
                run_counts_subquery.c.successful_count,
                literal_column("1").label("repetition_number"),
            )
            .select_from(run_counts_subquery)
            .where(run_counts_subquery.c.successful_count > 0)  # partially complete only
            .cte("expected_runs", recursive=True)
        )
        # ...then increment repetition_number until ``repetitions`` is reached.
        return seed.union_all(
            select(
                seed.c.dataset_example_revision_id,
                seed.c.dataset_example_id,
                seed.c.successful_count,
                (seed.c.repetition_number + 1).label("repetition_number"),
            ).where(seed.c.repetition_number < repetitions)
        )
    else:
        assert_never(dialect)
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def get_incomplete_repetitions_query(
    dialect: SupportedSQLDialect,
    expected_runs_cte: Any,
    experiment_id: int,
) -> Select[tuple[Any, Any, Any]]:
    """
    Build a query for the incomplete repetitions of partially complete examples.

    Expected repetition numbers are outer-joined against actual successful runs;
    rows with no matching run are the missing/failed repetitions, which are then
    aggregated per example into an array (PostgreSQL) or JSON array (SQLite).

    Args:
        dialect: SQL dialect to target (PostgreSQL or SQLite).
        expected_runs_cte: CTE from ``generate_expected_repetitions_cte`` holding
            the expected repetition numbers for partially complete examples.
        experiment_id: The experiment whose runs are inspected.

    Returns:
        A SELECT with columns:
            - dataset_example_revision_id
            - successful_count
            - incomplete_reps: aggregated incomplete repetition numbers
              (int array on PostgreSQL, JSON string on SQLite)
    """
    # Dialect-appropriate aggregate for collecting the repetition numbers.
    if dialect is SupportedSQLDialect.POSTGRESQL:
        incomplete_reps_agg = func.coalesce(
            func.array_agg(expected_runs_cte.c.repetition_number),
            literal_column("ARRAY[]::int[]"),
        )
    elif dialect is SupportedSQLDialect.SQLITE:
        incomplete_reps_agg = func.coalesce(
            func.json_group_array(expected_runs_cte.c.repetition_number),
            literal_column("'[]'"),
        )
    else:
        assert_never(dialect)

    # Join only against *successful* runs; after the outer join, a NULL run id
    # therefore marks a repetition that is missing or failed.
    successful_run_join = and_(
        models.ExperimentRun.experiment_id == experiment_id,
        models.ExperimentRun.dataset_example_id == expected_runs_cte.c.dataset_example_id,
        models.ExperimentRun.repetition_number == expected_runs_cte.c.repetition_number,
        models.ExperimentRun.error.is_(None),
    )
    return (
        select(
            expected_runs_cte.c.dataset_example_revision_id,
            expected_runs_cte.c.successful_count,
            incomplete_reps_agg.label("incomplete_reps"),
        )
        .select_from(expected_runs_cte)
        .outerjoin(models.ExperimentRun, successful_run_join)
        .where(models.ExperimentRun.id.is_(None))  # incomplete = no matching run
        .group_by(
            expected_runs_cte.c.dataset_example_revision_id,
            expected_runs_cte.c.successful_count,
        )
    )
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def get_incomplete_runs_with_revisions_query(
    incomplete_runs_subquery: Any,
    *,
    cursor_example_rowid: Optional[int] = None,
    limit: Optional[int] = None,
) -> Select[tuple[models.DatasetExampleRevision, Any, Any]]:
    """
    Join incomplete-run info with dataset example revisions, with pagination.

    The given subquery of incomplete runs is joined to DatasetExampleRevision to
    pull the full example data, then cursor-based pagination is applied in SQL.

    Args:
        incomplete_runs_subquery: Subquery exposing dataset_example_revision_id,
            successful_count, and incomplete_reps columns.
        cursor_example_rowid: Optional pagination cursor (dataset_example_id);
            when set, only examples with ID >= this value are returned.
        limit: Optional result cap; limit+1 rows are fetched so the caller can
            detect whether a next page exists.

    Returns:
        A SELECT with columns:
            - DatasetExampleRevision: the full revision object
            - successful_count: count of successful runs
            - incomplete_reps: array/JSON array of incomplete repetition numbers

    Note:
        Results are ordered by dataset_example_id ascending for stable pagination.
    """
    revision = models.DatasetExampleRevision
    stmt = (
        select(
            revision,
            incomplete_runs_subquery.c.successful_count,
            incomplete_runs_subquery.c.incomplete_reps,
        )
        .select_from(incomplete_runs_subquery)
        .join(
            revision,
            revision.id == incomplete_runs_subquery.c.dataset_example_revision_id,
        )
        .order_by(revision.dataset_example_id.asc())
    )

    # Cursor filter is pushed into SQL so pagination scales with large datasets.
    if cursor_example_rowid is not None:
        stmt = stmt.where(revision.dataset_example_id >= cursor_example_rowid)

    # Fetch one extra row so the caller can detect a next page.
    if limit is not None:
        stmt = stmt.limit(limit + 1)

    return stmt
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def get_successful_experiment_runs_query(
    experiment_id: int,
    *,
    cursor_run_rowid: Optional[int] = None,
    limit: Optional[int] = None,
) -> Select[tuple[models.ExperimentRun, int]]:
    """
    Build a query for successful experiment runs plus their example revision IDs.

    Runs that finished without an error are joined to ExperimentDatasetExample to
    obtain each run's dataset_example_revision_id. Results are ordered by run ID
    ascending so cursor-based pagination is stable.

    Args:
        experiment_id: The experiment whose runs are queried.
        cursor_run_rowid: Optional pagination cursor (experiment run ID); when
            set, only runs with ID >= this value are returned.
        limit: Optional result cap; limit+1 rows are fetched so the caller can
            detect whether a next page exists.

    Returns:
        A SELECT with columns:
            - ExperimentRun: the full run object
            - dataset_example_revision_id: revision ID of the run's example (int)
    """
    run = models.ExperimentRun
    link = models.ExperimentDatasetExample
    stmt = (
        select(run, link.dataset_example_revision_id)
        .join(
            link,
            and_(
                link.experiment_id == experiment_id,
                link.dataset_example_id == run.dataset_example_id,
            ),
        )
        .where(
            and_(
                run.experiment_id == experiment_id,
                run.error.is_(None),  # only successful task runs
            )
        )
        .order_by(run.id.asc())
    )

    if cursor_run_rowid is not None:
        stmt = stmt.where(run.id >= cursor_run_rowid)

    # One extra row enables next-page detection.
    if limit is not None:
        stmt = stmt.limit(limit + 1)

    return stmt
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def get_experiment_run_annotations_query(
    run_ids: Sequence[int],
    evaluation_names: Sequence[str],
) -> Select[tuple[int, str, Optional[str]]]:
    """
    Build a query for annotations of specific runs, filtered by evaluation name.

    Only the fields needed to decide whether an evaluation is complete or errored
    are selected.

    Args:
        run_ids: Experiment run IDs whose annotations are fetched.
        evaluation_names: Evaluation names to include.

    Returns:
        A SELECT with columns:
            - experiment_run_id (int)
            - name: evaluation name (str)
            - error: error message, or None on success (Optional[str])

    Example:
        >>> query = get_experiment_run_annotations_query([1, 2, 3], ["relevance"])
        >>> results = await session.execute(query)
        >>> for run_id, name, error in results:
        ...     # Process annotations...
    """
    annotation = models.ExperimentRunAnnotation
    return select(
        annotation.experiment_run_id,
        annotation.name,
        annotation.error,
    ).where(
        annotation.experiment_run_id.in_(run_ids),
        annotation.name.in_(evaluation_names),
    )
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
def get_runs_with_incomplete_evaluations_query(
    experiment_id: int,
    evaluation_names: Sequence[str],
    dialect: SupportedSQLDialect,
    *,
    cursor_run_rowid: Optional[int] = None,
    limit: Optional[int] = None,
    include_annotations_and_revisions: bool = False,
) -> Select[Any]:
    """
    Get experiment runs that have incomplete evaluations.

    A run has incomplete evaluations if it's missing successful annotations for any of
    the requested evaluation names. Both missing (no annotation) and failed (error != NULL)
    evaluations are considered incomplete.

    Args:
        experiment_id: The experiment ID to query
        evaluation_names: Evaluation names to check for completeness
        dialect: SQL dialect (PostgreSQL or SQLite)
        cursor_run_rowid: Optional run ID for cursor-based pagination
        limit: Optional limit (fetches limit+1 for next-page detection)
        include_annotations_and_revisions: If True, also fetch revision and successful
            annotation names as JSON array

    Returns:
        Query returning (ExperimentRun, revision_id, [revision, annotations_json])
        Results ordered by run ID ascending
    """
    # Subquery: count successful annotations per run for the requested names.
    successful_annotations_count = (
        select(
            models.ExperimentRunAnnotation.experiment_run_id,
            func.count().label("successful_count"),
        )
        .where(
            models.ExperimentRunAnnotation.name.in_(evaluation_names),
            models.ExperimentRunAnnotation.error.is_(None),
        )
        .group_by(models.ExperimentRunAnnotation.experiment_run_id)
        .subquery()
    )

    # Base query: runs whose successful annotation count falls short of the
    # number of requested evaluations. COALESCE handles runs with no
    # annotations at all (NULL from the outer join).
    stmt = (
        select(
            models.ExperimentRun,
            models.ExperimentDatasetExample.dataset_example_revision_id,
        )
        .join(
            models.ExperimentDatasetExample,
            and_(
                models.ExperimentDatasetExample.experiment_id == experiment_id,
                models.ExperimentDatasetExample.dataset_example_id
                == models.ExperimentRun.dataset_example_id,
            ),
        )
        .outerjoin(
            successful_annotations_count,
            successful_annotations_count.c.experiment_run_id == models.ExperimentRun.id,
        )
        .where(
            models.ExperimentRun.experiment_id == experiment_id,
            models.ExperimentRun.error.is_(None),  # Only successful task runs
            func.coalesce(successful_annotations_count.c.successful_count, 0)
            < len(evaluation_names),
        )
    )

    # Optionally include revisions and successful annotation names
    if include_annotations_and_revisions:
        # Aggregate successful annotation names as a JSON array string.
        # NOTE: made the dialect dispatch exhaustive (elif + assert_never) to
        # match every other dialect branch in this module; the previous bare
        # `else` would have silently used SQLite syntax for any new dialect.
        if dialect is SupportedSQLDialect.POSTGRESQL:
            json_agg_expr = func.cast(
                func.coalesce(
                    func.json_agg(models.ExperimentRunAnnotation.name),
                    literal_column("'[]'::json"),
                ),
                sa.String,
            )
        elif dialect is SupportedSQLDialect.SQLITE:
            json_agg_expr = func.cast(
                func.coalesce(
                    func.json_group_array(models.ExperimentRunAnnotation.name),
                    literal_column("'[]'"),
                ),
                sa.String,
            )
        else:
            assert_never(dialect)

        successful_annotations_json = (
            select(
                models.ExperimentRunAnnotation.experiment_run_id,
                json_agg_expr.label("annotations_json"),
            )
            .where(
                models.ExperimentRunAnnotation.name.in_(evaluation_names),
                models.ExperimentRunAnnotation.error.is_(None),
            )
            .group_by(models.ExperimentRunAnnotation.experiment_run_id)
            .subquery()
        )

        stmt = (
            stmt.add_columns(
                models.DatasetExampleRevision,
                successful_annotations_json.c.annotations_json,
            )
            .join(
                models.DatasetExampleRevision,
                models.DatasetExampleRevision.id
                == models.ExperimentDatasetExample.dataset_example_revision_id,
            )
            .outerjoin(
                successful_annotations_json,
                successful_annotations_json.c.experiment_run_id == models.ExperimentRun.id,
            )
        )

    # Apply ordering, cursor, and limit (limit+1 enables next-page detection).
    stmt = stmt.order_by(models.ExperimentRun.id.asc())

    if cursor_run_rowid is not None:
        stmt = stmt.where(models.ExperimentRun.id >= cursor_run_rowid)

    if limit is not None:
        stmt = stmt.limit(limit + 1)

    return stmt
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
def get_experiment_incomplete_runs_query(
    experiment: models.Experiment,
    dialect: SupportedSQLDialect,
    *,
    cursor_example_rowid: Optional[int] = None,
    limit: Optional[int] = None,
) -> Select[tuple[models.DatasetExampleRevision, Any, Any]]:
    """
    High-level helper to build a complete query for incomplete runs in an experiment.

    This is the main entry point for querying incomplete runs. It finds runs that
    still need to be completed, covering both missing runs (not yet attempted)
    and failed runs (attempted but errored).

    The function automatically chooses the optimal query strategy:
    - For repetitions=1: simple fast path (no CTE needed)
    - For repetitions>1: two-path optimization separating completely missing
      examples from partially complete examples

    Args:
        experiment: The Experiment model instance to query incomplete runs for.
        dialect: The SQL dialect to use (PostgreSQL or SQLite).
        cursor_example_rowid: Optional pagination cursor (dataset_example_id);
            when set, only examples with ID >= this value are returned.
        limit: Optional result cap; limit+1 rows are fetched so the caller can
            detect whether a next page exists.

    Returns:
        A SELECT with columns:
            - DatasetExampleRevision: the full revision object with example data
            - successful_count: count of successful runs for the example (int)
            - incomplete_reps: incomplete repetition numbers —
              PostgreSQL: int array (empty for completely missing examples);
              SQLite: JSON string array ('[]' for completely missing examples)

    Note:
        For completely missing examples (successful_count=0) incomplete_reps is
        empty; callers should generate the full [1..repetitions] list themselves.
    """
    # Step 1: Get successful run counts for incomplete examples
    run_counts_subquery = get_successful_run_counts_subquery(experiment.id, experiment.repetitions)

    # Dialect-specific "no repetitions listed" literal, computed once — both
    # strategies below need it. (Previously duplicated in each branch.)
    empty_reps: Any
    if dialect is SupportedSQLDialect.POSTGRESQL:
        empty_reps = literal_column("ARRAY[]::int[]")
    elif dialect is SupportedSQLDialect.SQLITE:
        empty_reps = literal_column("'[]'")
    else:
        assert_never(dialect)

    # Step 2: Build the combined incomplete runs subquery.
    if experiment.repetitions == 1:
        # Fast path for repetitions=1: every incomplete example necessarily has
        # successful_count=0, so the expensive repetition-expansion CTE is skipped.
        combined_incomplete = (
            select(
                run_counts_subquery.c.dataset_example_revision_id,
                run_counts_subquery.c.successful_count,
                empty_reps.label("incomplete_reps"),
            )
            .select_from(run_counts_subquery)
            .subquery()
        )
    else:
        # Two-path optimization for repetitions > 1:
        # Path 1: Completely missing examples (successful_count = 0) — no CTE needed
        completely_missing_stmt = (
            select(
                run_counts_subquery.c.dataset_example_revision_id,
                run_counts_subquery.c.successful_count,
                empty_reps.label("incomplete_reps"),
            )
            .select_from(run_counts_subquery)
            .where(run_counts_subquery.c.successful_count == 0)
        )

        # Path 2: Partially complete examples (0 < successful_count < repetitions)
        expected_runs_cte = generate_expected_repetitions_cte(
            dialect, run_counts_subquery, experiment.repetitions
        )
        partially_complete_stmt = get_incomplete_repetitions_query(
            dialect, expected_runs_cte, experiment.id
        )

        # Combine both paths. Uses the module-level ``sa`` namespace instead of
        # the previous mid-function ``from sqlalchemy import union_all``.
        combined_incomplete = sa.union_all(
            completely_missing_stmt, partially_complete_stmt
        ).subquery()

    # Step 3: Join with revisions and apply pagination
    return get_incomplete_runs_with_revisions_query(
        combined_incomplete,
        cursor_example_rowid=cursor_example_rowid,
        limit=limit,
    )