arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
--- a/phoenix/server/api/types/Experiment.py
+++ b/phoenix/server/api/types/Experiment.py
@@ -1,41 +1,155 @@
 from datetime import datetime
-from typing import
+from typing import TYPE_CHECKING, Annotated, Optional
 
 import strawberry
 from sqlalchemy import func, select
-from sqlalchemy.orm import joinedload
-from sqlalchemy.sql.functions import coalesce
 from strawberry import UNSET, Private
-from strawberry.relay import Connection, Node, NodeID
+from strawberry.relay import Connection, GlobalID, Node, NodeID
 from strawberry.scalars import JSON
 from strawberry.types import Info
 
 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
+from phoenix.server.api.input_types.ExperimentRunSort import (
+    ExperimentRunSort,
+    add_order_by_and_page_start_to_query,
+    get_experiment_run_cursor,
+)
 from phoenix.server.api.types.CostBreakdown import CostBreakdown
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
-from phoenix.server.api.types.ExperimentRun import ExperimentRun
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
+    Cursor,
     CursorString,
+    connection_from_cursors_and_nodes,
     connection_from_list,
 )
-from phoenix.server.api.types.Project import Project
 from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
 from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
 
+_DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE = 50
+
+if TYPE_CHECKING:
+    from .Project import Project
+
 
 @strawberry.type
 class Experiment(Node):
-
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.Experiment]] = None
     cached_sequence_number: Private[Optional[int]] = None
-
-
-
-
-
-
-
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("Experiment ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.name),
+            )
+        return val
+
+    @strawberry.field
+    async def project_name(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.project_name
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.project_name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.description),
+            )
+        return val
+
+    @strawberry.field
+    async def repetitions(
+        self,
+        info: Info[Context, None],
+    ) -> int:
+        if self.db_record:
+            val = self.db_record.repetitions
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.repetitions),
+            )
+        return val
+
+    @strawberry.field
+    async def dataset_version_id(
+        self,
+        info: Info[Context, None],
+    ) -> GlobalID:
+        if self.db_record:
+            version_id = self.db_record.dataset_version_id
+        else:
+            version_id = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.dataset_version_id),
+            )
+        return GlobalID(DatasetVersion.__name__, str(version_id))
+
+    @strawberry.field
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        if self.db_record:
+            val = self.db_record.metadata_
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.metadata_),
+            )
+        return val
+
+    @strawberry.field
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.created_at),
+            )
+        return val
+
+    @strawberry.field
+    async def updated_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.updated_at
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.updated_at),
+            )
+        return val
 
     @strawberry.field(
         description="Sequence number (1-based) of experiments belonging to the same dataset"
@@ -45,9 +159,9 @@ class Experiment(Node):
         info: Info[Context, None],
     ) -> int:
         if self.cached_sequence_number is None:
-            seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.
+            seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id)
             if seq_num is None:
-                raise ValueError(f"invalid experiment: id={self.
+                raise ValueError(f"invalid experiment: id={self.id}")
             self.cached_sequence_number = seq_num
         return self.cached_sequence_number
 
@@ -55,41 +169,68 @@ class Experiment(Node):
     async def runs(
         self,
         info: Info[Context, None],
-        first: Optional[int] =
-        last: Optional[int] = UNSET,
+        first: Optional[int] = _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE,
         after: Optional[CursorString] = UNSET,
-
+        sort: Optional[ExperimentRunSort] = UNSET,
     ) -> Connection[ExperimentRun]:
-
-        first
-
-
-
+        if first is not None and first <= 0:
+            raise BadRequest("first must be a positive integer if set")
+        page_size = first or _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE
+        experiment_runs_query = (
+            select(models.ExperimentRun)
+            .where(models.ExperimentRun.experiment_id == self.id)
+            .limit(page_size + 1)
+        )
+
+        after_experiment_run_rowid = None
+        after_sort_column_value = None
+        if after:
+            cursor = Cursor.from_string(after)
+            after_experiment_run_rowid = cursor.rowid
+            if cursor.sort_column is not None:
+                after_sort_column_value = cursor.sort_column.value
+
+        experiment_runs_query = add_order_by_and_page_start_to_query(
+            query=experiment_runs_query,
+            sort=sort,
+            experiment_rowid=self.id,
+            after_experiment_run_rowid=after_experiment_run_rowid,
+            after_sort_column_value=after_sort_column_value,
         )
-
+
         async with info.context.db() as session:
-
-
-
-
-
-
-
-
-
-
-
+            results = (await session.execute(experiment_runs_query)).all()
+
+        has_next_page = False
+        if len(results) > page_size:
+            results = results[:page_size]
+            has_next_page = True
+
+        cursors_and_nodes = []
+        for result in results:
+            run = result[0]
+            annotation_score = result[1] if len(result) > 1 else None
+            gql_run = ExperimentRun(id=run.id, db_record=run)
+            cursor = get_experiment_run_cursor(
+                run=run, annotation_score=annotation_score, sort=sort
+            )
+            cursors_and_nodes.append((cursor, gql_run))
+
+        return connection_from_cursors_and_nodes(
+            cursors_and_nodes=cursors_and_nodes,
+            has_previous_page=False,  # set to false since we are only doing forward pagination (https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo.Fields) # noqa: E501
+            has_next_page=has_next_page,
+        )
 
     @strawberry.field
     async def run_count(self, info: Info[Context, None]) -> int:
-
-        return await info.context.data_loaders.experiment_run_counts.load(experiment_id)
+        return await info.context.data_loaders.experiment_run_counts.load(self.id)
 
     @strawberry.field
     async def annotation_summaries(
         self, info: Info[Context, None]
     ) -> list[ExperimentAnnotationSummary]:
-        experiment_id = self.
+        experiment_id = self.id
         return [
             ExperimentAnnotationSummary(
                 annotation_name=summary.annotation_name,
@@ -106,40 +247,42 @@ class Experiment(Node):
 
     @strawberry.field
     async def error_rate(self, info: Info[Context, None]) -> Optional[float]:
-        return await info.context.data_loaders.experiment_error_rates.load(self.
+        return await info.context.data_loaders.experiment_error_rates.load(self.id)
 
     @strawberry.field
     async def average_run_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
-
-
-        )
-        return latency_seconds * 1000 if latency_seconds is not None else None
+        latency_ms = await info.context.data_loaders.average_experiment_run_latency.load(self.id)
+        return latency_ms
 
     @strawberry.field
-    async def project(
-
+    async def project(
+        self, info: Info[Context, None]
+    ) -> Optional[Annotated["Project", strawberry.lazy(".Project")]]:
+        if self.db_record:
+            project_name = self.db_record.project_name
+        else:
+            project_name = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.project_name),
+            )
+
+        if project_name is None:
             return None
 
-        db_project = await info.context.data_loaders.project_by_name.load(
+        db_project = await info.context.data_loaders.project_by_name.load(project_name)
 
         if db_project is None:
             return None
+        from .Project import Project
 
-        return Project(
-            project_rowid=db_project.id,
-            db_project=db_project,
-        )
+        return Project(id=db_project.id, db_record=db_project)
 
     @strawberry.field
     def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
-        return info.context.last_updated_at.get(
+        return info.context.last_updated_at.get(models.Experiment, self.id)
 
     @strawberry.field
     async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
-
-        summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(
-            experiment_id
-        )
+        summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(self.id)
         return SpanCostSummary(
             prompt=CostBreakdown(
                 tokens=summary.prompt.tokens,
@@ -159,21 +302,19 @@ class Experiment(Node):
     async def cost_detail_summary_entries(
         self, info: Info[Context, None]
     ) -> list[SpanCostDetailSummaryEntry]:
-        experiment_id = self.id_attr
-
         stmt = (
             select(
                 models.SpanCostDetail.token_type,
                 models.SpanCostDetail.is_prompt,
-
-
+                func.sum(models.SpanCostDetail.cost).label("cost"),
+                func.sum(models.SpanCostDetail.tokens).label("tokens"),
             )
             .select_from(models.SpanCostDetail)
             .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
             .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
             .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
             .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
-            .where(models.ExperimentRun.experiment_id ==
+            .where(models.ExperimentRun.experiment_id == self.id)
             .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
         )
 
@@ -188,6 +329,17 @@ class Experiment(Node):
             async for token_type, is_prompt, cost, tokens in data
         ]
 
+    @strawberry.field
+    async def dataset_splits(
+        self,
+        info: Info[Context, None],
+    ) -> Connection[DatasetSplit]:
+        """Returns the dataset splits associated with this experiment."""
+        splits = await info.context.data_loaders.experiment_dataset_splits.load(self.id)
+        return connection_from_list(
+            [DatasetSplit(id=split.id, db_record=split) for split in splits], ConnectionArgs()
+        )
+
 
 def to_gql_experiment(
     experiment: models.Experiment,
@@ -197,12 +349,7 @@ def to_gql_experiment(
     Converts an ORM experiment to a GraphQL Experiment.
     """
     return Experiment(
+        id=experiment.id,
+        db_record=experiment,
         cached_sequence_number=sequence_number,
-        id_attr=experiment.id,
-        name=experiment.name,
-        project_name=experiment.project_name,
-        description=experiment.description,
-        metadata=experiment.metadata_,
-        created_at=experiment.created_at,
-        updated_at=experiment.updated_at,
     )
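The rewritten `runs` resolver above pages forward by over-fetching: it applies `.limit(page_size + 1)` to the SQL query, trims the extra row, and uses its presence to set `has_next_page`. Below is a minimal, self-contained sketch of that pattern; the `paginate_forward` helper and the in-memory rows are illustrative stand-ins, not part of Phoenix.

```python
from typing import Optional, Sequence

_DEFAULT_PAGE_SIZE = 50  # mirrors _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE above


def paginate_forward(
    rows: Sequence[int],
    first: Optional[int] = _DEFAULT_PAGE_SIZE,
) -> tuple[list[int], bool]:
    """Return one page of rows plus a has_next_page flag, using the limit+1 trick."""
    page_size = first or _DEFAULT_PAGE_SIZE
    window = list(rows[: page_size + 1])  # the resolver does this with .limit(page_size + 1)
    has_next_page = len(window) > page_size
    return window[:page_size], has_next_page


page, has_next = paginate_forward(list(range(120)), first=50)
assert len(page) == 50 and has_next is True
```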
--- a/phoenix/server/api/types/ExperimentComparison.py
+++ b/phoenix/server/api/types/ExperimentComparison.py
@@ -1,18 +1,12 @@
 import strawberry
-from strawberry.relay import
+from strawberry.relay import Node, NodeID
 
 from phoenix.server.api.types.DatasetExample import DatasetExample
-from phoenix.server.api.types.
-
-
-@strawberry.type
-class RunComparisonItem:
-    experiment_id: GlobalID
-    runs: list[ExperimentRun]
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import ExperimentRepeatedRunGroup
 
 
 @strawberry.type
 class ExperimentComparison(Node):
     id_attr: NodeID[int]
     example: DatasetExample
-
+    repeated_run_groups: list[ExperimentRepeatedRunGroup]
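Throughout these changes, GraphQL nodes are constructed as `Type(id=..., db_record=...)`: when the ORM row is already in hand the resolver reads it directly, and otherwise it falls back to a batched data loader keyed by `(row id, column)`. The sketch below illustrates that resolution pattern; `ExperimentRow` and `FieldLoader` are hypothetical stand-ins rather than Phoenix's actual models and loaders.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExperimentRow:  # hypothetical stand-in for models.Experiment
    id: int
    name: str


class FieldLoader:  # hypothetical stand-in for the experiment_fields data loader
    def __init__(self, rows: dict[int, ExperimentRow]) -> None:
        self._rows = rows

    async def load(self, key: tuple[int, str]) -> str:
        row_id, column = key
        return getattr(self._rows[row_id], column)


class ExperimentNode:
    def __init__(self, id: int, db_record: Optional[ExperimentRow] = None) -> None:
        self.id = id
        self.db_record = db_record

    async def name(self, loader: FieldLoader) -> str:
        # Read from the ORM row when the node was built from one; otherwise batch-load.
        if self.db_record is not None:
            return self.db_record.name
        return await loader.load((self.id, "name"))


async def main() -> None:
    loader = FieldLoader({1: ExperimentRow(id=1, name="baseline")})
    print(await ExperimentNode(id=1).name(loader))  # resolved via the loader
    print(await ExperimentNode(id=1, db_record=ExperimentRow(id=1, name="baseline")).name(loader))


asyncio.run(main())
```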
--- /dev/null
+++ b/phoenix/server/api/types/ExperimentRepeatedRunGroup.py
@@ -0,0 +1,155 @@
+import re
+from base64 import b64decode
+from typing import Optional
+
+import strawberry
+from sqlalchemy import func, select
+from strawberry.relay import GlobalID, Node
+from strawberry.types import Info
+from typing_extensions import Self, TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.types.CostBreakdown import CostBreakdown
+from phoenix.server.api.types.ExperimentRepeatedRunGroupAnnotationSummary import (
+    ExperimentRepeatedRunGroupAnnotationSummary,
+)
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
+from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
+from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
+
+ExperimentRowId: TypeAlias = int
+DatasetExampleRowId: TypeAlias = int
+
+
+@strawberry.type
+class ExperimentRepeatedRunGroup(Node):
+    experiment_rowid: strawberry.Private[ExperimentRowId]
+    dataset_example_rowid: strawberry.Private[DatasetExampleRowId]
+    cached_runs: strawberry.Private[Optional[list[ExperimentRun]]] = None
+
+    @strawberry.field
+    async def runs(self, info: Info[Context, None]) -> list[ExperimentRun]:
+        if self.cached_runs is not None:
+            return self.cached_runs
+        runs = await info.context.data_loaders.experiment_runs_by_experiment_and_example.load(
+            (self.experiment_rowid, self.dataset_example_rowid)
+        )
+        return [ExperimentRun(id=run.id, db_record=run) for run in runs]
+
+    @classmethod
+    def resolve_id(
+        cls,
+        root: Self,
+        *,
+        info: Info,
+    ) -> str:
+        return (
+            f"experiment_id={root.experiment_rowid}:dataset_example_id={root.dataset_example_rowid}"
+        )
+
+    @strawberry.field
+    def experiment_id(self) -> strawberry.ID:
+        from phoenix.server.api.types.Experiment import Experiment
+
+        return strawberry.ID(str(GlobalID(Experiment.__name__, str(self.experiment_rowid))))
+
+    @strawberry.field
+    async def average_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
+        return await info.context.data_loaders.average_experiment_repeated_run_group_latency.load(
+            (self.experiment_rowid, self.dataset_example_rowid)
+        )
+
+    @strawberry.field
+    async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
+        experiment_id = self.experiment_rowid
+        example_id = self.dataset_example_rowid
+        summary = (
+            await info.context.data_loaders.span_cost_summary_by_experiment_repeated_run_group.load(
+                (experiment_id, example_id)
+            )
+        )
+        return SpanCostSummary(
+            prompt=CostBreakdown(
+                tokens=summary.prompt.tokens,
+                cost=summary.prompt.cost,
+            ),
+            completion=CostBreakdown(
+                tokens=summary.completion.tokens,
+                cost=summary.completion.cost,
+            ),
+            total=CostBreakdown(
+                tokens=summary.total.tokens,
+                cost=summary.total.cost,
+            ),
+        )
+
+    @strawberry.field
+    async def cost_detail_summary_entries(
+        self, info: Info[Context, None]
+    ) -> list[SpanCostDetailSummaryEntry]:
+        experiment_id = self.experiment_rowid
+        example_id = self.dataset_example_rowid
+        stmt = (
+            select(
+                models.SpanCostDetail.token_type,
+                models.SpanCostDetail.is_prompt,
+                func.sum(models.SpanCostDetail.cost).label("cost"),
+                func.sum(models.SpanCostDetail.tokens).label("tokens"),
+            )
+            .select_from(models.SpanCostDetail)
+            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
+            .join(models.Trace, models.SpanCost.trace_rowid == models.Trace.id)
+            .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
+            .where(models.ExperimentRun.experiment_id == experiment_id)
+            .where(models.ExperimentRun.dataset_example_id == example_id)
+            .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
+        )
+
+        async with info.context.db() as session:
+            data = await session.stream(stmt)
+            return [
+                SpanCostDetailSummaryEntry(
+                    token_type=token_type,
+                    is_prompt=is_prompt,
+                    value=CostBreakdown(tokens=tokens, cost=cost),
+                )
+                async for token_type, is_prompt, cost, tokens in data
+            ]
+
+    @strawberry.field
+    async def annotation_summaries(
+        self,
+        info: Info[Context, None],
+    ) -> list[ExperimentRepeatedRunGroupAnnotationSummary]:
+        loader = info.context.data_loaders.experiment_repeated_run_group_annotation_summaries
+        summaries = await loader.load((self.experiment_rowid, self.dataset_example_rowid))
+        return [
+            ExperimentRepeatedRunGroupAnnotationSummary(
+                annotation_name=summary.annotation_name,
+                mean_score=summary.mean_score,
+            )
+            for summary in summaries
+        ]
+
+
+_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN = re.compile(
+    r"ExperimentRepeatedRunGroup:experiment_id=(\d+):dataset_example_id=(\d+)"
+)
+
+
+def parse_experiment_repeated_run_group_node_id(
+    node_id: str,
+) -> tuple[ExperimentRowId, DatasetExampleRowId]:
+    decoded_node_id = _base64_decode(node_id)
+    match = re.match(_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN, decoded_node_id)
+    if not match:
+        raise ValueError(f"Invalid node ID format: {node_id}")
+
+    experiment_id = int(match.group(1))
+    dataset_example_id = int(match.group(2))
+    return experiment_id, dataset_example_id
+
+
+def _base64_decode(string: str) -> str:
+    return b64decode(string.encode()).decode()