arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import Optional

from sqlalchemy import func, select, tuple_
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentID: TypeAlias = int
DatasetExampleID: TypeAlias = int
RunLatency: TypeAlias = float
Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
Result: TypeAlias = Optional[RunLatency]


class AverageExperimentRepeatedRunGroupLatencyDataLoader(DataLoader[Key, Result]):
    """Batch-loads the mean latency (ms) over the runs of each repeated-run
    group, keyed by (experiment_id, dataset_example_id).

    Pairs with no matching runs resolve to None.
    """

    def __init__(
        self,
        db: DbSessionFactory,
    ) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # A single grouped query serves the whole batch; the composite
        # (experiment_id, dataset_example_id) key is matched via a tuple IN.
        stmt = (
            select(
                models.ExperimentRun.experiment_id.label("experiment_id"),
                models.ExperimentRun.dataset_example_id.label("example_id"),
                func.avg(models.ExperimentRun.latency_ms).label("average_repetition_latency_ms"),
            )
            .select_from(models.ExperimentRun)
            .where(
                tuple_(
                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
                ).in_(set(keys))
            )
            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id)
        )
        latencies: dict[Key, RunLatency] = {}
        async with self._db() as session:
            async for experiment_id, example_id, avg_latency_ms in await session.stream(stmt):
                latencies[(experiment_id, example_id)] = avg_latency_ms
        # dict.get keeps the output aligned with the input key order and
        # yields None for pairs that had no runs.
        return [latencies.get(key) for key in keys]
|
|
@@ -23,32 +23,25 @@ class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):
|
|
|
23
23
|
|
|
24
24
|
    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Batch-load the average run latency (ms) for each experiment id.

        The average is computed in two stages: first the mean latency of the
        repetitions for each (experiment, dataset example) pair, then the mean
        of those per-example means across each experiment. Experiments absent
        from the result set resolve to None.
        """
        experiment_ids = keys
        # Stage 1: mean repetition latency per (experiment, dataset example).
        average_repetition_latency_ms = (
            select(
                models.ExperimentRun.experiment_id.label("experiment_id"),
                func.avg(models.ExperimentRun.latency_ms).label("average_repetition_latency_ms"),
            )
            .select_from(models.ExperimentRun)
            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
            .group_by(models.ExperimentRun.dataset_example_id, models.ExperimentRun.experiment_id)
            .subquery()
        )
        # Stage 2: average the per-example means within each experiment.
        query = select(
            average_repetition_latency_ms.c.experiment_id,
            func.avg(average_repetition_latency_ms.c.average_repetition_latency_ms).label(
                "average_run_latency_ms"
            ),
        ).group_by(average_repetition_latency_ms.c.experiment_id)
        async with self._db() as session:
            average_run_latencies_ms = {
                experiment_id: average_run_latency_ms
                async for experiment_id, average_run_latency_ms in await session.stream(query)
            }
        # dict.get yields None for experiments with no runs.
        return [average_run_latencies_ms.get(experiment_id) for experiment_id in keys]
|
|
@@ -7,7 +7,7 @@ single-tier system we would need to check all the keys to see if they are in the
|
|
|
7
7
|
subset that we want to invalidate.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
-
from abc import
|
|
10
|
+
from abc import abstractmethod
|
|
11
11
|
from asyncio import Future
|
|
12
12
|
from collections.abc import Callable
|
|
13
13
|
from typing import Any, Generic, Optional, TypeVar
|
|
@@ -25,7 +25,6 @@ _SubKey = TypeVar("_SubKey")
|
|
|
25
25
|
class TwoTierCache(
|
|
26
26
|
AbstractCache[_Key, _Result],
|
|
27
27
|
Generic[_Key, _Result, _Section, _SubKey],
|
|
28
|
-
ABC,
|
|
29
28
|
):
|
|
30
29
|
def __init__(
|
|
31
30
|
self,
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

DatasetID: TypeAlias = int
Key: TypeAlias = DatasetID
Result: TypeAlias = list[models.DatasetSplit]


class DatasetDatasetSplitsDataLoader(DataLoader[Key, Result]):
    """Batch-loads, per dataset id, the splits attached to any of that
    dataset's examples, sorted by split name. Datasets with no splits
    resolve to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(
            load_fn=self._load_fn,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        requested_ids = keys
        stmt = (
            select(models.DatasetExample.dataset_id, models.DatasetSplit)
            .select_from(models.DatasetSplit)
            .join(
                models.DatasetSplitDatasetExample,
                onclause=(
                    models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
                ),
            )
            .join(
                models.DatasetExample,
                onclause=(
                    models.DatasetSplitDatasetExample.dataset_example_id
                    == models.DatasetExample.id
                ),
            )
            .where(models.DatasetExample.dataset_id.in_(requested_ids))
        )
        async with self._db() as session:
            # Inner dict keyed by split.id deduplicates a split attached to
            # several examples of the same dataset.
            splits_by_dataset: dict[DatasetID, dict[int, models.DatasetSplit]] = {
                dataset_id: {} for dataset_id in requested_ids
            }
            async for dataset_id, split in await session.stream(stmt):
                if dataset_id in splits_by_dataset:
                    splits_by_dataset[dataset_id][split.id] = split
        return [
            sorted(splits_by_dataset.get(dataset_id, {}).values(), key=lambda s: s.name)
            for dataset_id in keys
        ]
|
|
@@ -91,7 +91,6 @@ class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
|
|
|
91
91
|
onclause=revision_ids.c.version_id == models.DatasetVersion.id,
|
|
92
92
|
isouter=True, # keep rows where the version id is null
|
|
93
93
|
)
|
|
94
|
-
.where(models.DatasetExampleRevision.revision_kind != "DELETE")
|
|
95
94
|
)
|
|
96
95
|
async with self._db() as session:
|
|
97
96
|
results = {
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExampleID: TypeAlias = int
Key: TypeAlias = ExampleID
Result: TypeAlias = list[models.DatasetSplit]


class DatasetExampleSplitsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the dataset splits each dataset example belongs to,
    sorted by split name. Examples in no split resolve to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(
            load_fn=self._load_fn,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(models.DatasetSplitDatasetExample.dataset_example_id, models.DatasetSplit)
            .select_from(models.DatasetSplit)
            .join(
                models.DatasetSplitDatasetExample,
                onclause=(
                    models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
                ),
            )
            .where(models.DatasetSplitDatasetExample.dataset_example_id.in_(keys))
        )
        splits_by_example: dict[ExampleID, list[models.DatasetSplit]] = {}
        async with self._db() as session:
            async for example_id, split in await session.stream(stmt):
                splits_by_example.setdefault(example_id, []).append(split)
        return [
            sorted(splits_by_example.get(example_id, []), key=lambda s: s.name)
            for example_id in keys
        ]
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentRunID: TypeAlias = int
DatasetExampleID: TypeAlias = int
DatasetVersionID: TypeAlias = int
Key: TypeAlias = ExperimentRunID
Result: TypeAlias = tuple[models.DatasetExample, DatasetVersionID]


class DatasetExamplesAndVersionsByExperimentRunDataLoader(DataLoader[Key, Result]):
    """Batch-loads, for each experiment run id, the run's dataset example
    together with the dataset version id of the run's experiment.

    NOTE(review): a key with no matching run raises KeyError when results
    are assembled — presumably callers only pass valid run ids; confirm.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        run_ids = set(keys)
        stmt = (
            select(
                models.ExperimentRun.id.label("experiment_run_id"),
                models.DatasetExample,
                models.Experiment.dataset_version_id.label("dataset_version_id"),
            )
            .select_from(models.ExperimentRun)
            .join(
                models.DatasetExample,
                models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
            )
            .join(
                models.Experiment,
                models.Experiment.id == models.ExperimentRun.experiment_id,
            )
            .where(models.ExperimentRun.id.in_(run_ids))
        )
        async with self._db() as session:
            rows = (await session.execute(stmt)).all()
        results = {run_id: (example, version_id) for run_id, example, version_id in rows}
        return [results[key] for key in keys]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

DatasetID: TypeAlias = int
Key: TypeAlias = DatasetID
Result: TypeAlias = list[models.DatasetLabel]


class DatasetLabelsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the labels attached to each dataset, sorted by label
    name. Datasets with no labels resolve to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(models.DatasetsDatasetLabel.dataset_id, models.DatasetLabel)
            .select_from(models.DatasetLabel)
            .join(
                models.DatasetsDatasetLabel,
                models.DatasetLabel.id == models.DatasetsDatasetLabel.dataset_label_id,
            )
            .where(models.DatasetsDatasetLabel.dataset_id.in_(keys))
        )
        labels_by_dataset: dict[Key, Result] = {}
        async with self._db() as session:
            for dataset_id, label in await session.execute(stmt):
                labels_by_dataset.setdefault(dataset_id, []).append(label)
        return [
            sorted(labels_by_dataset.get(dataset_id, []), key=lambda lbl: lbl.name)
            for dataset_id in keys
        ]
|
|
@@ -10,7 +10,7 @@ from strawberry.dataloader import AbstractCache, DataLoader
|
|
|
10
10
|
from typing_extensions import TypeAlias
|
|
11
11
|
|
|
12
12
|
from phoenix.db import models
|
|
13
|
-
from phoenix.db.helpers import SupportedSQLDialect
|
|
13
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
14
14
|
from phoenix.metrics.retrieval_metrics import RetrievalMetrics
|
|
15
15
|
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
16
16
|
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
@@ -122,7 +122,7 @@ def _get_stmt(
|
|
|
122
122
|
select(
|
|
123
123
|
mda.name,
|
|
124
124
|
models.Span.id,
|
|
125
|
-
|
|
125
|
+
models.Span.num_documents.label("num_docs"),
|
|
126
126
|
mda.score,
|
|
127
127
|
mda.document_position,
|
|
128
128
|
)
|
|
@@ -5,11 +5,10 @@ from strawberry.dataloader import DataLoader
|
|
|
5
5
|
from typing_extensions import TypeAlias
|
|
6
6
|
|
|
7
7
|
from phoenix.db import models
|
|
8
|
-
from phoenix.server.api.types.Evaluation import DocumentEvaluation
|
|
9
8
|
from phoenix.server.types import DbSessionFactory
|
|
10
9
|
|
|
11
10
|
Key: TypeAlias = int
|
|
12
|
-
Result: TypeAlias = list[
|
|
11
|
+
Result: TypeAlias = list[models.DocumentAnnotation]
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
|
|
@@ -18,14 +17,12 @@ class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
|
|
|
18
17
|
self._db = db
|
|
19
18
|
|
|
20
19
|
    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Batch-load the document annotations for each span rowid in *keys*.

        Returns one list per key, in input order; spans with no document
        annotations resolve to an empty list (the defaultdict default).
        """
        document_annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
        mda = models.DocumentAnnotation
        async with self._db() as session:
            # One streamed query covers the whole batch; no annotator_kind
            # filter is applied, so annotations of every kind are returned.
            data = await session.stream_scalars(select(mda).where(mda.span_rowid.in_(keys)))
            async for document_evaluation in data:
                document_annotations_by_id[document_evaluation.span_rowid].append(
                    document_evaluation
                )
        return [document_annotations_by_id[key] for key in keys]
|
|
@@ -2,7 +2,7 @@ from collections import defaultdict
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
|
-
from sqlalchemy import func, select
|
|
5
|
+
from sqlalchemy import and_, func, select
|
|
6
6
|
from strawberry.dataloader import AbstractCache, DataLoader
|
|
7
7
|
from typing_extensions import TypeAlias
|
|
8
8
|
|
|
@@ -37,43 +37,97 @@ class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
|
|
|
37
37
|
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
38
38
|
experiment_ids = keys
|
|
39
39
|
summaries: defaultdict[ExperimentID, Result] = defaultdict(list)
|
|
40
|
+
repetition_mean_scores_by_example_subquery = (
|
|
41
|
+
select(
|
|
42
|
+
models.ExperimentRun.experiment_id.label("experiment_id"),
|
|
43
|
+
models.ExperimentRunAnnotation.name.label("annotation_name"),
|
|
44
|
+
func.avg(models.ExperimentRunAnnotation.score).label("mean_repetition_score"),
|
|
45
|
+
)
|
|
46
|
+
.select_from(models.ExperimentRunAnnotation)
|
|
47
|
+
.join(
|
|
48
|
+
models.ExperimentRun,
|
|
49
|
+
models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
|
|
50
|
+
)
|
|
51
|
+
.where(models.ExperimentRun.experiment_id.in_(experiment_ids))
|
|
52
|
+
.group_by(
|
|
53
|
+
models.ExperimentRun.experiment_id,
|
|
54
|
+
models.ExperimentRun.dataset_example_id,
|
|
55
|
+
models.ExperimentRunAnnotation.name,
|
|
56
|
+
)
|
|
57
|
+
.subquery()
|
|
58
|
+
.alias("repetition_mean_scores_by_example")
|
|
59
|
+
)
|
|
60
|
+
repetition_mean_scores_subquery = (
|
|
61
|
+
select(
|
|
62
|
+
repetition_mean_scores_by_example_subquery.c.experiment_id.label("experiment_id"),
|
|
63
|
+
repetition_mean_scores_by_example_subquery.c.annotation_name.label(
|
|
64
|
+
"annotation_name"
|
|
65
|
+
),
|
|
66
|
+
func.avg(repetition_mean_scores_by_example_subquery.c.mean_repetition_score).label(
|
|
67
|
+
"mean_score"
|
|
68
|
+
),
|
|
69
|
+
)
|
|
70
|
+
.select_from(repetition_mean_scores_by_example_subquery)
|
|
71
|
+
.group_by(
|
|
72
|
+
repetition_mean_scores_by_example_subquery.c.experiment_id,
|
|
73
|
+
repetition_mean_scores_by_example_subquery.c.annotation_name,
|
|
74
|
+
)
|
|
75
|
+
.subquery()
|
|
76
|
+
.alias("repetition_mean_scores")
|
|
77
|
+
)
|
|
78
|
+
repetitions_subquery = (
|
|
79
|
+
select(
|
|
80
|
+
models.ExperimentRun.experiment_id.label("experiment_id"),
|
|
81
|
+
models.ExperimentRunAnnotation.name.label("annotation_name"),
|
|
82
|
+
func.min(models.ExperimentRunAnnotation.score).label("min_score"),
|
|
83
|
+
func.max(models.ExperimentRunAnnotation.score).label("max_score"),
|
|
84
|
+
func.count().label("count"),
|
|
85
|
+
func.count(models.ExperimentRunAnnotation.error).label("error_count"),
|
|
86
|
+
)
|
|
87
|
+
.select_from(models.ExperimentRunAnnotation)
|
|
88
|
+
.join(
|
|
89
|
+
models.ExperimentRun,
|
|
90
|
+
models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
|
|
91
|
+
)
|
|
92
|
+
.where(models.ExperimentRun.experiment_id.in_(experiment_ids))
|
|
93
|
+
.group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
|
|
94
|
+
.subquery()
|
|
95
|
+
)
|
|
96
|
+
run_scores_query = (
|
|
97
|
+
select(
|
|
98
|
+
repetition_mean_scores_subquery.c.experiment_id.label("experiment_id"),
|
|
99
|
+
repetition_mean_scores_subquery.c.annotation_name.label("annotation_name"),
|
|
100
|
+
repetition_mean_scores_subquery.c.mean_score.label("mean_score"),
|
|
101
|
+
repetitions_subquery.c.min_score.label("min_score"),
|
|
102
|
+
repetitions_subquery.c.max_score.label("max_score"),
|
|
103
|
+
repetitions_subquery.c.count.label("count_"),
|
|
104
|
+
repetitions_subquery.c.error_count.label("error_count"),
|
|
105
|
+
)
|
|
106
|
+
.select_from(repetition_mean_scores_subquery)
|
|
107
|
+
.join(
|
|
108
|
+
repetitions_subquery,
|
|
109
|
+
and_(
|
|
110
|
+
repetitions_subquery.c.experiment_id
|
|
111
|
+
== repetition_mean_scores_subquery.c.experiment_id,
|
|
112
|
+
repetitions_subquery.c.annotation_name
|
|
113
|
+
== repetition_mean_scores_subquery.c.annotation_name,
|
|
114
|
+
),
|
|
115
|
+
)
|
|
116
|
+
.order_by(repetition_mean_scores_subquery.c.annotation_name)
|
|
117
|
+
)
|
|
40
118
|
async with self._db() as session:
|
|
41
|
-
async for (
|
|
42
|
-
experiment_id
|
|
43
|
-
annotation_name,
|
|
44
|
-
min_score,
|
|
45
|
-
max_score,
|
|
46
|
-
mean_score,
|
|
47
|
-
count,
|
|
48
|
-
error_count,
|
|
49
|
-
) in await session.stream(
|
|
50
|
-
select(
|
|
51
|
-
models.ExperimentRun.experiment_id,
|
|
52
|
-
models.ExperimentRunAnnotation.name,
|
|
53
|
-
func.min(models.ExperimentRunAnnotation.score),
|
|
54
|
-
func.max(models.ExperimentRunAnnotation.score),
|
|
55
|
-
func.avg(models.ExperimentRunAnnotation.score),
|
|
56
|
-
func.count(),
|
|
57
|
-
func.count(models.ExperimentRunAnnotation.error),
|
|
58
|
-
)
|
|
59
|
-
.join(
|
|
60
|
-
models.ExperimentRun,
|
|
61
|
-
models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
|
|
62
|
-
)
|
|
63
|
-
.where(models.ExperimentRun.experiment_id.in_(experiment_ids))
|
|
64
|
-
.group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
|
|
65
|
-
):
|
|
66
|
-
summaries[experiment_id].append(
|
|
119
|
+
async for scores_tuple in await session.stream(run_scores_query):
|
|
120
|
+
summaries[scores_tuple.experiment_id].append(
|
|
67
121
|
ExperimentAnnotationSummary(
|
|
68
|
-
annotation_name=annotation_name,
|
|
69
|
-
min_score=min_score,
|
|
70
|
-
max_score=max_score,
|
|
71
|
-
mean_score=mean_score,
|
|
72
|
-
count=
|
|
73
|
-
error_count=error_count,
|
|
122
|
+
annotation_name=scores_tuple.annotation_name,
|
|
123
|
+
min_score=scores_tuple.min_score,
|
|
124
|
+
max_score=scores_tuple.max_score,
|
|
125
|
+
mean_score=scores_tuple.mean_score,
|
|
126
|
+
count=scores_tuple.count_,
|
|
127
|
+
error_count=scores_tuple.error_count,
|
|
74
128
|
)
|
|
75
129
|
)
|
|
76
130
|
return [
|
|
77
131
|
sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
|
|
78
|
-
for experiment_id in
|
|
132
|
+
for experiment_id in experiment_ids
|
|
79
133
|
]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from sqlalchemy import select
|
|
2
|
+
from strawberry.dataloader import DataLoader
|
|
3
|
+
from typing_extensions import TypeAlias
|
|
4
|
+
|
|
5
|
+
from phoenix.db import models
|
|
6
|
+
from phoenix.server.types import DbSessionFactory
|
|
7
|
+
|
|
8
|
+
ExperimentID: TypeAlias = int
|
|
9
|
+
Key: TypeAlias = ExperimentID
|
|
10
|
+
Result: TypeAlias = list[models.DatasetSplit]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ExperimentDatasetSplitsDataLoader(DataLoader[Key, Result]):
|
|
14
|
+
def __init__(self, db: DbSessionFactory) -> None:
|
|
15
|
+
super().__init__(
|
|
16
|
+
load_fn=self._load_fn,
|
|
17
|
+
)
|
|
18
|
+
self._db = db
|
|
19
|
+
|
|
20
|
+
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
21
|
+
experiment_ids = keys
|
|
22
|
+
async with self._db() as session:
|
|
23
|
+
splits: dict[ExperimentID, list[models.DatasetSplit]] = {}
|
|
24
|
+
|
|
25
|
+
async for experiment_id, split in await session.stream(
|
|
26
|
+
select(models.ExperimentDatasetSplit.experiment_id, models.DatasetSplit)
|
|
27
|
+
.select_from(models.DatasetSplit)
|
|
28
|
+
.join(
|
|
29
|
+
models.ExperimentDatasetSplit,
|
|
30
|
+
onclause=(
|
|
31
|
+
models.DatasetSplit.id == models.ExperimentDatasetSplit.dataset_split_id
|
|
32
|
+
),
|
|
33
|
+
)
|
|
34
|
+
.where(models.ExperimentDatasetSplit.experiment_id.in_(experiment_ids))
|
|
35
|
+
):
|
|
36
|
+
if experiment_id not in splits:
|
|
37
|
+
splits[experiment_id] = []
|
|
38
|
+
splits[experiment_id].append(split)
|
|
39
|
+
|
|
40
|
+
return [
|
|
41
|
+
sorted(splits.get(experiment_id, []), key=lambda x: x.name)
|
|
42
|
+
for experiment_id in keys
|
|
43
|
+
]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
|
-
from sqlalchemy import
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
4
|
from strawberry.dataloader import DataLoader
|
|
5
5
|
from typing_extensions import TypeAlias
|
|
6
6
|
|
|
@@ -23,36 +23,29 @@ class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
|
|
|
23
23
|
|
|
24
24
|
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
25
25
|
experiment_ids = keys
|
|
26
|
-
|
|
27
|
-
select(models.Experiment.id)
|
|
28
|
-
.where(models.Experiment.id.in_(set(experiment_ids)))
|
|
29
|
-
.subquery()
|
|
30
|
-
)
|
|
31
|
-
query = (
|
|
26
|
+
average_repetition_error_rates_subquery = (
|
|
32
27
|
select(
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
(
|
|
36
|
-
|
|
37
|
-
func.count(models.ExperimentRun.error)
|
|
38
|
-
/ func.count(models.ExperimentRun.id),
|
|
39
|
-
),
|
|
40
|
-
else_=None,
|
|
41
|
-
),
|
|
28
|
+
models.ExperimentRun.experiment_id.label("experiment_id"),
|
|
29
|
+
(
|
|
30
|
+
func.count(models.ExperimentRun.error) / func.count(models.ExperimentRun.id)
|
|
31
|
+
).label("average_repetition_error_rate"),
|
|
42
32
|
)
|
|
43
|
-
.
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
)
|
|
48
|
-
.group_by(resolved_experiment_ids.c.id)
|
|
33
|
+
.where(models.ExperimentRun.experiment_id.in_(experiment_ids))
|
|
34
|
+
.group_by(models.ExperimentRun.dataset_example_id, models.ExperimentRun.experiment_id)
|
|
35
|
+
.subquery()
|
|
36
|
+
.alias("average_repetition_error_rates")
|
|
49
37
|
)
|
|
38
|
+
average_run_error_rates_query = select(
|
|
39
|
+
average_repetition_error_rates_subquery.c.experiment_id,
|
|
40
|
+
func.avg(average_repetition_error_rates_subquery.c.average_repetition_error_rate).label(
|
|
41
|
+
"average_run_error_rates"
|
|
42
|
+
),
|
|
43
|
+
).group_by(average_repetition_error_rates_subquery.c.experiment_id)
|
|
50
44
|
async with self._db() as session:
|
|
51
|
-
|
|
45
|
+
average_run_error_rates = {
|
|
52
46
|
experiment_id: error_rate
|
|
53
|
-
async for experiment_id, error_rate in await session.stream(
|
|
47
|
+
async for experiment_id, error_rate in await session.stream(
|
|
48
|
+
average_run_error_rates_query
|
|
49
|
+
)
|
|
54
50
|
}
|
|
55
|
-
return [
|
|
56
|
-
error_rates.get(experiment_id, ValueError(f"Unknown experiment ID: {experiment_id}"))
|
|
57
|
-
for experiment_id in keys
|
|
58
|
-
]
|
|
51
|
+
return [average_run_error_rates.get(experiment_id) for experiment_id in experiment_ids]
|