arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,37 +1,155 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
-
from sqlalchemy import select
|
|
6
|
-
from sqlalchemy.orm import joinedload
|
|
5
|
+
from sqlalchemy import func, select
|
|
7
6
|
from strawberry import UNSET, Private
|
|
8
|
-
from strawberry.relay import Connection, Node, NodeID
|
|
7
|
+
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
9
8
|
from strawberry.scalars import JSON
|
|
10
9
|
from strawberry.types import Info
|
|
11
10
|
|
|
12
11
|
from phoenix.db import models
|
|
13
12
|
from phoenix.server.api.context import Context
|
|
13
|
+
from phoenix.server.api.exceptions import BadRequest
|
|
14
|
+
from phoenix.server.api.input_types.ExperimentRunSort import (
|
|
15
|
+
ExperimentRunSort,
|
|
16
|
+
add_order_by_and_page_start_to_query,
|
|
17
|
+
get_experiment_run_cursor,
|
|
18
|
+
)
|
|
19
|
+
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
20
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
21
|
+
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
14
22
|
from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
|
|
15
|
-
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
23
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
16
24
|
from phoenix.server.api.types.pagination import (
|
|
17
25
|
ConnectionArgs,
|
|
26
|
+
Cursor,
|
|
18
27
|
CursorString,
|
|
28
|
+
connection_from_cursors_and_nodes,
|
|
19
29
|
connection_from_list,
|
|
20
30
|
)
|
|
21
|
-
from phoenix.server.api.types.
|
|
31
|
+
from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
|
|
32
|
+
from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
|
|
33
|
+
|
|
34
|
+
_DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE = 50
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from .Project import Project
|
|
22
38
|
|
|
23
39
|
|
|
24
40
|
@strawberry.type
|
|
25
41
|
class Experiment(Node):
|
|
26
|
-
|
|
42
|
+
id: NodeID[int]
|
|
43
|
+
db_record: strawberry.Private[Optional[models.Experiment]] = None
|
|
27
44
|
cached_sequence_number: Private[Optional[int]] = None
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
45
|
+
|
|
46
|
+
def __post_init__(self) -> None:
|
|
47
|
+
if self.db_record and self.id != self.db_record.id:
|
|
48
|
+
raise ValueError("Experiment ID mismatch")
|
|
49
|
+
|
|
50
|
+
@strawberry.field
|
|
51
|
+
async def name(
|
|
52
|
+
self,
|
|
53
|
+
info: Info[Context, None],
|
|
54
|
+
) -> str:
|
|
55
|
+
if self.db_record:
|
|
56
|
+
val = self.db_record.name
|
|
57
|
+
else:
|
|
58
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
59
|
+
(self.id, models.Experiment.name),
|
|
60
|
+
)
|
|
61
|
+
return val
|
|
62
|
+
|
|
63
|
+
@strawberry.field
|
|
64
|
+
async def project_name(
|
|
65
|
+
self,
|
|
66
|
+
info: Info[Context, None],
|
|
67
|
+
) -> Optional[str]:
|
|
68
|
+
if self.db_record:
|
|
69
|
+
val = self.db_record.project_name
|
|
70
|
+
else:
|
|
71
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
72
|
+
(self.id, models.Experiment.project_name),
|
|
73
|
+
)
|
|
74
|
+
return val
|
|
75
|
+
|
|
76
|
+
@strawberry.field
|
|
77
|
+
async def description(
|
|
78
|
+
self,
|
|
79
|
+
info: Info[Context, None],
|
|
80
|
+
) -> Optional[str]:
|
|
81
|
+
if self.db_record:
|
|
82
|
+
val = self.db_record.description
|
|
83
|
+
else:
|
|
84
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
85
|
+
(self.id, models.Experiment.description),
|
|
86
|
+
)
|
|
87
|
+
return val
|
|
88
|
+
|
|
89
|
+
@strawberry.field
|
|
90
|
+
async def repetitions(
|
|
91
|
+
self,
|
|
92
|
+
info: Info[Context, None],
|
|
93
|
+
) -> int:
|
|
94
|
+
if self.db_record:
|
|
95
|
+
val = self.db_record.repetitions
|
|
96
|
+
else:
|
|
97
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
98
|
+
(self.id, models.Experiment.repetitions),
|
|
99
|
+
)
|
|
100
|
+
return val
|
|
101
|
+
|
|
102
|
+
@strawberry.field
|
|
103
|
+
async def dataset_version_id(
|
|
104
|
+
self,
|
|
105
|
+
info: Info[Context, None],
|
|
106
|
+
) -> GlobalID:
|
|
107
|
+
if self.db_record:
|
|
108
|
+
version_id = self.db_record.dataset_version_id
|
|
109
|
+
else:
|
|
110
|
+
version_id = await info.context.data_loaders.experiment_fields.load(
|
|
111
|
+
(self.id, models.Experiment.dataset_version_id),
|
|
112
|
+
)
|
|
113
|
+
return GlobalID(DatasetVersion.__name__, str(version_id))
|
|
114
|
+
|
|
115
|
+
@strawberry.field
|
|
116
|
+
async def metadata(
|
|
117
|
+
self,
|
|
118
|
+
info: Info[Context, None],
|
|
119
|
+
) -> JSON:
|
|
120
|
+
if self.db_record:
|
|
121
|
+
val = self.db_record.metadata_
|
|
122
|
+
else:
|
|
123
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
124
|
+
(self.id, models.Experiment.metadata_),
|
|
125
|
+
)
|
|
126
|
+
return val
|
|
127
|
+
|
|
128
|
+
@strawberry.field
|
|
129
|
+
async def created_at(
|
|
130
|
+
self,
|
|
131
|
+
info: Info[Context, None],
|
|
132
|
+
) -> datetime:
|
|
133
|
+
if self.db_record:
|
|
134
|
+
val = self.db_record.created_at
|
|
135
|
+
else:
|
|
136
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
137
|
+
(self.id, models.Experiment.created_at),
|
|
138
|
+
)
|
|
139
|
+
return val
|
|
140
|
+
|
|
141
|
+
@strawberry.field
|
|
142
|
+
async def updated_at(
|
|
143
|
+
self,
|
|
144
|
+
info: Info[Context, None],
|
|
145
|
+
) -> datetime:
|
|
146
|
+
if self.db_record:
|
|
147
|
+
val = self.db_record.updated_at
|
|
148
|
+
else:
|
|
149
|
+
val = await info.context.data_loaders.experiment_fields.load(
|
|
150
|
+
(self.id, models.Experiment.updated_at),
|
|
151
|
+
)
|
|
152
|
+
return val
|
|
35
153
|
|
|
36
154
|
@strawberry.field(
|
|
37
155
|
description="Sequence number (1-based) of experiments belonging to the same dataset"
|
|
@@ -41,9 +159,9 @@ class Experiment(Node):
|
|
|
41
159
|
info: Info[Context, None],
|
|
42
160
|
) -> int:
|
|
43
161
|
if self.cached_sequence_number is None:
|
|
44
|
-
seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.
|
|
162
|
+
seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id)
|
|
45
163
|
if seq_num is None:
|
|
46
|
-
raise ValueError(f"invalid experiment: id={self.
|
|
164
|
+
raise ValueError(f"invalid experiment: id={self.id}")
|
|
47
165
|
self.cached_sequence_number = seq_num
|
|
48
166
|
return self.cached_sequence_number
|
|
49
167
|
|
|
@@ -51,41 +169,68 @@ class Experiment(Node):
|
|
|
51
169
|
async def runs(
|
|
52
170
|
self,
|
|
53
171
|
info: Info[Context, None],
|
|
54
|
-
first: Optional[int] =
|
|
55
|
-
last: Optional[int] = UNSET,
|
|
172
|
+
first: Optional[int] = _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE,
|
|
56
173
|
after: Optional[CursorString] = UNSET,
|
|
57
|
-
|
|
174
|
+
sort: Optional[ExperimentRunSort] = UNSET,
|
|
58
175
|
) -> Connection[ExperimentRun]:
|
|
59
|
-
|
|
60
|
-
first
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
176
|
+
if first is not None and first <= 0:
|
|
177
|
+
raise BadRequest("first must be a positive integer if set")
|
|
178
|
+
page_size = first or _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE
|
|
179
|
+
experiment_runs_query = (
|
|
180
|
+
select(models.ExperimentRun)
|
|
181
|
+
.where(models.ExperimentRun.experiment_id == self.id)
|
|
182
|
+
.limit(page_size + 1)
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
after_experiment_run_rowid = None
|
|
186
|
+
after_sort_column_value = None
|
|
187
|
+
if after:
|
|
188
|
+
cursor = Cursor.from_string(after)
|
|
189
|
+
after_experiment_run_rowid = cursor.rowid
|
|
190
|
+
if cursor.sort_column is not None:
|
|
191
|
+
after_sort_column_value = cursor.sort_column.value
|
|
192
|
+
|
|
193
|
+
experiment_runs_query = add_order_by_and_page_start_to_query(
|
|
194
|
+
query=experiment_runs_query,
|
|
195
|
+
sort=sort,
|
|
196
|
+
experiment_rowid=self.id,
|
|
197
|
+
after_experiment_run_rowid=after_experiment_run_rowid,
|
|
198
|
+
after_sort_column_value=after_sort_column_value,
|
|
64
199
|
)
|
|
65
|
-
|
|
200
|
+
|
|
66
201
|
async with info.context.db() as session:
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
202
|
+
results = (await session.execute(experiment_runs_query)).all()
|
|
203
|
+
|
|
204
|
+
has_next_page = False
|
|
205
|
+
if len(results) > page_size:
|
|
206
|
+
results = results[:page_size]
|
|
207
|
+
has_next_page = True
|
|
208
|
+
|
|
209
|
+
cursors_and_nodes = []
|
|
210
|
+
for result in results:
|
|
211
|
+
run = result[0]
|
|
212
|
+
annotation_score = result[1] if len(result) > 1 else None
|
|
213
|
+
gql_run = ExperimentRun(id=run.id, db_record=run)
|
|
214
|
+
cursor = get_experiment_run_cursor(
|
|
215
|
+
run=run, annotation_score=annotation_score, sort=sort
|
|
216
|
+
)
|
|
217
|
+
cursors_and_nodes.append((cursor, gql_run))
|
|
218
|
+
|
|
219
|
+
return connection_from_cursors_and_nodes(
|
|
220
|
+
cursors_and_nodes=cursors_and_nodes,
|
|
221
|
+
has_previous_page=False, # set to false since we are only doing forward pagination (https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo.Fields) # noqa: E501
|
|
222
|
+
has_next_page=has_next_page,
|
|
223
|
+
)
|
|
78
224
|
|
|
79
225
|
@strawberry.field
|
|
80
226
|
async def run_count(self, info: Info[Context, None]) -> int:
|
|
81
|
-
|
|
82
|
-
return await info.context.data_loaders.experiment_run_counts.load(experiment_id)
|
|
227
|
+
return await info.context.data_loaders.experiment_run_counts.load(self.id)
|
|
83
228
|
|
|
84
229
|
@strawberry.field
|
|
85
230
|
async def annotation_summaries(
|
|
86
231
|
self, info: Info[Context, None]
|
|
87
232
|
) -> list[ExperimentAnnotationSummary]:
|
|
88
|
-
experiment_id = self.
|
|
233
|
+
experiment_id = self.id
|
|
89
234
|
return [
|
|
90
235
|
ExperimentAnnotationSummary(
|
|
91
236
|
annotation_name=summary.annotation_name,
|
|
@@ -102,33 +247,98 @@ class Experiment(Node):
|
|
|
102
247
|
|
|
103
248
|
@strawberry.field
|
|
104
249
|
async def error_rate(self, info: Info[Context, None]) -> Optional[float]:
|
|
105
|
-
return await info.context.data_loaders.experiment_error_rates.load(self.
|
|
250
|
+
return await info.context.data_loaders.experiment_error_rates.load(self.id)
|
|
106
251
|
|
|
107
252
|
@strawberry.field
|
|
108
253
|
async def average_run_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
)
|
|
112
|
-
return latency_seconds * 1000 if latency_seconds is not None else None
|
|
254
|
+
latency_ms = await info.context.data_loaders.average_experiment_run_latency.load(self.id)
|
|
255
|
+
return latency_ms
|
|
113
256
|
|
|
114
257
|
@strawberry.field
|
|
115
|
-
async def project(
|
|
116
|
-
|
|
258
|
+
async def project(
|
|
259
|
+
self, info: Info[Context, None]
|
|
260
|
+
) -> Optional[Annotated["Project", strawberry.lazy(".Project")]]:
|
|
261
|
+
if self.db_record:
|
|
262
|
+
project_name = self.db_record.project_name
|
|
263
|
+
else:
|
|
264
|
+
project_name = await info.context.data_loaders.experiment_fields.load(
|
|
265
|
+
(self.id, models.Experiment.project_name),
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
if project_name is None:
|
|
117
269
|
return None
|
|
118
270
|
|
|
119
|
-
db_project = await info.context.data_loaders.project_by_name.load(
|
|
271
|
+
db_project = await info.context.data_loaders.project_by_name.load(project_name)
|
|
120
272
|
|
|
121
273
|
if db_project is None:
|
|
122
274
|
return None
|
|
275
|
+
from .Project import Project
|
|
123
276
|
|
|
124
|
-
return Project(
|
|
125
|
-
project_rowid=db_project.id,
|
|
126
|
-
db_project=db_project,
|
|
127
|
-
)
|
|
277
|
+
return Project(id=db_project.id, db_record=db_project)
|
|
128
278
|
|
|
129
279
|
@strawberry.field
|
|
130
280
|
def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
|
|
131
|
-
return info.context.last_updated_at.get(
|
|
281
|
+
return info.context.last_updated_at.get(models.Experiment, self.id)
|
|
282
|
+
|
|
283
|
+
@strawberry.field
|
|
284
|
+
async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
|
|
285
|
+
summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(self.id)
|
|
286
|
+
return SpanCostSummary(
|
|
287
|
+
prompt=CostBreakdown(
|
|
288
|
+
tokens=summary.prompt.tokens,
|
|
289
|
+
cost=summary.prompt.cost,
|
|
290
|
+
),
|
|
291
|
+
completion=CostBreakdown(
|
|
292
|
+
tokens=summary.completion.tokens,
|
|
293
|
+
cost=summary.completion.cost,
|
|
294
|
+
),
|
|
295
|
+
total=CostBreakdown(
|
|
296
|
+
tokens=summary.total.tokens,
|
|
297
|
+
cost=summary.total.cost,
|
|
298
|
+
),
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
@strawberry.field
|
|
302
|
+
async def cost_detail_summary_entries(
|
|
303
|
+
self, info: Info[Context, None]
|
|
304
|
+
) -> list[SpanCostDetailSummaryEntry]:
|
|
305
|
+
stmt = (
|
|
306
|
+
select(
|
|
307
|
+
models.SpanCostDetail.token_type,
|
|
308
|
+
models.SpanCostDetail.is_prompt,
|
|
309
|
+
func.sum(models.SpanCostDetail.cost).label("cost"),
|
|
310
|
+
func.sum(models.SpanCostDetail.tokens).label("tokens"),
|
|
311
|
+
)
|
|
312
|
+
.select_from(models.SpanCostDetail)
|
|
313
|
+
.join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
|
|
314
|
+
.join(models.Span, models.SpanCost.span_rowid == models.Span.id)
|
|
315
|
+
.join(models.Trace, models.Span.trace_rowid == models.Trace.id)
|
|
316
|
+
.join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
|
|
317
|
+
.where(models.ExperimentRun.experiment_id == self.id)
|
|
318
|
+
.group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
async with info.context.db() as session:
|
|
322
|
+
data = await session.stream(stmt)
|
|
323
|
+
return [
|
|
324
|
+
SpanCostDetailSummaryEntry(
|
|
325
|
+
token_type=token_type,
|
|
326
|
+
is_prompt=is_prompt,
|
|
327
|
+
value=CostBreakdown(tokens=tokens, cost=cost),
|
|
328
|
+
)
|
|
329
|
+
async for token_type, is_prompt, cost, tokens in data
|
|
330
|
+
]
|
|
331
|
+
|
|
332
|
+
@strawberry.field
|
|
333
|
+
async def dataset_splits(
|
|
334
|
+
self,
|
|
335
|
+
info: Info[Context, None],
|
|
336
|
+
) -> Connection[DatasetSplit]:
|
|
337
|
+
"""Returns the dataset splits associated with this experiment."""
|
|
338
|
+
splits = await info.context.data_loaders.experiment_dataset_splits.load(self.id)
|
|
339
|
+
return connection_from_list(
|
|
340
|
+
[DatasetSplit(id=split.id, db_record=split) for split in splits], ConnectionArgs()
|
|
341
|
+
)
|
|
132
342
|
|
|
133
343
|
|
|
134
344
|
def to_gql_experiment(
|
|
@@ -139,12 +349,7 @@ def to_gql_experiment(
|
|
|
139
349
|
Converts an ORM experiment to a GraphQL Experiment.
|
|
140
350
|
"""
|
|
141
351
|
return Experiment(
|
|
352
|
+
id=experiment.id,
|
|
353
|
+
db_record=experiment,
|
|
142
354
|
cached_sequence_number=sequence_number,
|
|
143
|
-
id_attr=experiment.id,
|
|
144
|
-
name=experiment.name,
|
|
145
|
-
project_name=experiment.project_name,
|
|
146
|
-
description=experiment.description,
|
|
147
|
-
metadata=experiment.metadata_,
|
|
148
|
-
created_at=experiment.created_at,
|
|
149
|
-
updated_at=experiment.updated_at,
|
|
150
355
|
)
|
|
@@ -1,17 +1,12 @@
|
|
|
1
1
|
import strawberry
|
|
2
|
-
from strawberry.relay import
|
|
2
|
+
from strawberry.relay import Node, NodeID
|
|
3
3
|
|
|
4
4
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
5
|
-
from phoenix.server.api.types.
|
|
5
|
+
from phoenix.server.api.types.ExperimentRepeatedRunGroup import ExperimentRepeatedRunGroup
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
@strawberry.type
class ExperimentComparison(Node):
    """Relay node pairing a dataset example with its experiment run groups.

    Each entry in ``repeated_run_groups`` collects the repeated runs for one
    experiment against ``example``.
    """

    # Integer row id backing the relay GlobalID for this node.
    id_attr: NodeID[int]
    example: DatasetExample
    repeated_run_groups: list[ExperimentRepeatedRunGroup]
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from base64 import b64decode
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import strawberry
|
|
6
|
+
from sqlalchemy import func, select
|
|
7
|
+
from strawberry.relay import GlobalID, Node
|
|
8
|
+
from strawberry.types import Info
|
|
9
|
+
from typing_extensions import Self, TypeAlias
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.api.context import Context
|
|
13
|
+
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
14
|
+
from phoenix.server.api.types.ExperimentRepeatedRunGroupAnnotationSummary import (
|
|
15
|
+
ExperimentRepeatedRunGroupAnnotationSummary,
|
|
16
|
+
)
|
|
17
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
18
|
+
from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
|
|
19
|
+
from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
|
|
20
|
+
|
|
21
|
+
# Raw database row ids (integer primary keys), as distinct from relay GlobalIDs.
ExperimentRowId: TypeAlias = int
DatasetExampleRowId: TypeAlias = int
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@strawberry.type
class ExperimentRepeatedRunGroup(Node):
    """The group of repeated runs of one experiment against one dataset example.

    Identified by the ``(experiment_rowid, dataset_example_rowid)`` pair rather
    than a single database row; ``resolve_id`` encodes that pair into the relay
    node id.
    """

    # Row ids identifying the (experiment, example) pair; not exposed in the schema.
    experiment_rowid: strawberry.Private[ExperimentRowId]
    dataset_example_rowid: strawberry.Private[DatasetExampleRowId]
    # Optionally pre-populated by the parent resolver; when set, `runs` skips the
    # data-loader round trip.
    cached_runs: strawberry.Private[Optional[list[ExperimentRun]]] = None

    @strawberry.field
    async def runs(self, info: Info[Context, None]) -> list[ExperimentRun]:
        """Return the experiment runs in this group, using the cache if present."""
        if self.cached_runs is not None:
            return self.cached_runs
        runs = await info.context.data_loaders.experiment_runs_by_experiment_and_example.load(
            (self.experiment_rowid, self.dataset_example_rowid)
        )
        return [ExperimentRun(id=run.id, db_record=run) for run in runs]

    @classmethod
    def resolve_id(
        cls,
        root: Self,
        *,
        info: Info,
    ) -> str:
        # Composite node id; the format must stay in sync with
        # _EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN at module level.
        return (
            f"experiment_id={root.experiment_rowid}:dataset_example_id={root.dataset_example_rowid}"
        )

    @strawberry.field
    def experiment_id(self) -> strawberry.ID:
        """Return the relay GlobalID of the owning experiment."""
        # Imported locally to avoid a circular import with the Experiment module.
        from phoenix.server.api.types.Experiment import Experiment

        return strawberry.ID(str(GlobalID(Experiment.__name__, str(self.experiment_rowid))))

    @strawberry.field
    async def average_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
        """Return the mean latency across the group's runs, or None if unavailable."""
        return await info.context.data_loaders.average_experiment_repeated_run_group_latency.load(
            (self.experiment_rowid, self.dataset_example_rowid)
        )

    @strawberry.field
    async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
        """Return prompt/completion/total token counts and costs for this group."""
        experiment_id = self.experiment_rowid
        example_id = self.dataset_example_rowid
        summary = (
            await info.context.data_loaders.span_cost_summary_by_experiment_repeated_run_group.load(
                (experiment_id, example_id)
            )
        )
        return SpanCostSummary(
            prompt=CostBreakdown(
                tokens=summary.prompt.tokens,
                cost=summary.prompt.cost,
            ),
            completion=CostBreakdown(
                tokens=summary.completion.tokens,
                cost=summary.completion.cost,
            ),
            total=CostBreakdown(
                tokens=summary.total.tokens,
                cost=summary.total.cost,
            ),
        )

    @strawberry.field
    async def cost_detail_summary_entries(
        self, info: Info[Context, None]
    ) -> list[SpanCostDetailSummaryEntry]:
        """Aggregate span cost details (by token type and prompt/completion) for this group."""
        experiment_id = self.experiment_rowid
        example_id = self.dataset_example_rowid
        # Sum cost/tokens per (token_type, is_prompt), joining span cost details
        # to this group's runs via their traces.
        stmt = (
            select(
                models.SpanCostDetail.token_type,
                models.SpanCostDetail.is_prompt,
                func.sum(models.SpanCostDetail.cost).label("cost"),
                func.sum(models.SpanCostDetail.tokens).label("tokens"),
            )
            .select_from(models.SpanCostDetail)
            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
            .join(models.Trace, models.SpanCost.trace_rowid == models.Trace.id)
            .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
            .where(models.ExperimentRun.experiment_id == experiment_id)
            .where(models.ExperimentRun.dataset_example_id == example_id)
            .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
        )

        async with info.context.db() as session:
            # Stream aggregated rows rather than materializing the full result first.
            data = await session.stream(stmt)
            return [
                SpanCostDetailSummaryEntry(
                    token_type=token_type,
                    is_prompt=is_prompt,
                    value=CostBreakdown(tokens=tokens, cost=cost),
                )
                async for token_type, is_prompt, cost, tokens in data
            ]

    @strawberry.field
    async def annotation_summaries(
        self,
        info: Info[Context, None],
    ) -> list[ExperimentRepeatedRunGroupAnnotationSummary]:
        """Return per-annotation-name mean scores for the runs in this group."""
        loader = info.context.data_loaders.experiment_repeated_run_group_annotation_summaries
        summaries = await loader.load((self.experiment_rowid, self.dataset_example_rowid))
        return [
            ExperimentRepeatedRunGroupAnnotationSummary(
                annotation_name=summary.annotation_name,
                mean_score=summary.mean_score,
            )
            for summary in summaries
        ]
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# Shape of a decoded ExperimentRepeatedRunGroup global node id; must stay in
# sync with ExperimentRepeatedRunGroup.resolve_id.
_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN = re.compile(
    r"ExperimentRepeatedRunGroup:experiment_id=(\d+):dataset_example_id=(\d+)"
)


def parse_experiment_repeated_run_group_node_id(
    node_id: str,
) -> tuple[ExperimentRowId, DatasetExampleRowId]:
    """Parse a base64-encoded ExperimentRepeatedRunGroup node id.

    Args:
        node_id: Base64-encoded global id of the form
            ``ExperimentRepeatedRunGroup:experiment_id=<int>:dataset_example_id=<int>``.

    Returns:
        The ``(experiment_rowid, dataset_example_rowid)`` pair.

    Raises:
        ValueError: If the decoded id does not match the expected format.
    """
    decoded_node_id = _base64_decode(node_id)
    # Call .match on the pre-compiled pattern instead of re.match(compiled, ...),
    # which would redundantly route through the re module's cache.
    # NOTE(review): .match anchors only at the start, so trailing characters
    # after the final digits are tolerated — confirm whether fullmatch is intended.
    match = _EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN.match(decoded_node_id)
    if not match:
        raise ValueError(f"Invalid node ID format: {node_id}")

    experiment_id = int(match.group(1))
    dataset_example_id = int(match.group(2))
    return experiment_id, dataset_example_id
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _base64_decode(string: str) -> str:
|
|
155
|
+
return b64decode(string.encode()).decode()
|