arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
--- phoenix/server/api/types/Dataset.py (11.23.1)
+++ phoenix/server/api/types/Dataset.py (12.28.1)
@@ -1,9 +1,9 @@
 from collections.abc import AsyncIterable
 from datetime import datetime
-from typing import
+from typing import Optional, cast
 
 import strawberry
-from sqlalchemy import and_, func, or_, select
+from sqlalchemy import Text, and_, func, or_, select
 from sqlalchemy.sql.functions import count
 from strawberry import UNSET
 from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -15,9 +15,13 @@ from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
 from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
+    DatasetExperimentAnnotationSummary,
+)
+from phoenix.server.api.types.DatasetLabel import DatasetLabel
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
-from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
@@ -29,13 +33,77 @@ from phoenix.server.api.types.SortDir import SortDir
 
 @strawberry.type
 class Dataset(Node):
-    id_attr: NodeID[int]
-    name: str
-    description: Optional[str]
-    metadata: JSON
-    created_at: datetime
-    updated_at: datetime
-
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.Dataset]] = None
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("Dataset ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.description),
+            )
+        return val
+
+    @strawberry.field
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        if self.db_record:
+            val = self.db_record.metadata_
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.metadata_),
+            )
+        return val
+
+    @strawberry.field
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.created_at),
+            )
+        return val
+
+    @strawberry.field
+    async def updated_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.updated_at
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.updated_at),
+            )
+        return val
 
     @strawberry.field
     async def versions(
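The hunk above replaces eagerly populated scalar fields with an `id` plus a private, optional `db_record`: when the resolver that created the object already holds the ORM row, fields are served from it directly; otherwise each field falls back to a batched dataloader keyed by `(id, column)`. A minimal, self-contained sketch of the idea (the `Record` dataclass and in-memory `load_field` below are hypothetical stand-ins for Phoenix's models and its `dataset_fields` loader, not Phoenix APIs):

```python
from dataclasses import dataclass
from typing import Optional

import strawberry


@dataclass
class Record:  # hypothetical stand-in for models.Dataset
    id: int
    name: str


DB = {1: Record(id=1, name="my-dataset")}  # fake table


async def load_field(key: tuple[int, str]) -> object:
    # hypothetical stand-in for the batched dataset_fields dataloader
    record_id, column = key
    return getattr(DB[record_id], column)


@strawberry.type
class Dataset:
    id: int
    # set when the creating resolver already fetched the row; stays None when
    # only the primary key is known (e.g. the object came from a relay node ID)
    db_record: strawberry.Private[Optional[Record]] = None

    @strawberry.field
    async def name(self) -> str:
        if self.db_record:
            return self.db_record.name  # no extra round trip
        return str(await load_field((self.id, "name")))  # batched fallback
```

The `__post_init__` guard in the diff simply rejects a mismatched `(id, db_record)` pair at construction time rather than letting the two sources of truth drift apart.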
@@ -54,7 +122,7 @@ class Dataset(Node):
             before=before if isinstance(before, CursorString) else None,
         )
         async with info.context.db() as session:
-            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
+            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id)
             if sort:
                 # For now assume the the column names match 1:1 with the enum values
                 sort_col = getattr(models.DatasetVersion, sort.col.value)
@@ -65,15 +133,7 @@ class Dataset(Node):
             else:
                 stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
             versions = await session.scalars(stmt)
-            data = [
-                DatasetVersion(
-                    id_attr=version.id,
-                    description=version.description,
-                    metadata=version.metadata_,
-                    created_at=version.created_at,
-                )
-                for version in versions
-            ]
+            data = [DatasetVersion(id=version.id, db_record=version) for version in versions]
         return connection_from_list(data=data, args=args)
@@ -84,8 +144,9 @@ class Dataset(Node):
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
+        split_ids: Optional[list[GlobalID]] = UNSET,
     ) -> int:
-        dataset_id = self.id_attr
+        dataset_id = self.id
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id,
@@ -94,6 +155,20 @@ class Dataset(Node):
             if dataset_version_id
             else None
         )
+
+        # Parse split IDs if provided
+        split_rowids: Optional[list[int]] = None
+        if split_ids:
+            split_rowids = []
+            for split_id in split_ids:
+                try:
+                    split_rowid = from_global_id_with_expected_type(
+                        global_id=split_id, expected_type_name=models.DatasetSplit.__name__
+                    )
+                    split_rowids.append(split_rowid)
+                except Exception:
+                    raise BadRequest(f"Invalid split ID: {split_id}")
+
         revision_ids = (
             select(func.max(models.DatasetExampleRevision.id))
             .join(models.DatasetExample)
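Client-supplied split IDs arrive as opaque relay `GlobalID`s and are decoded to integer row IDs before being used in SQL; a type-name mismatch surfaces as `BadRequest`. Roughly what the decoding amounts to (`decode_expected` is a simplified, hypothetical stand-in for `from_global_id_with_expected_type`):

```python
from strawberry.relay import GlobalID


def decode_expected(global_id: GlobalID, expected_type_name: str) -> int:
    # simplified stand-in for from_global_id_with_expected_type
    if global_id.type_name != expected_type_name:
        raise ValueError(f"expected {expected_type_name}, got {global_id.type_name}")
    return int(global_id.node_id)


gid = GlobalID(type_name="DatasetSplit", node_id="5")
print(str(gid))  # the opaque string a client would send in split_ids
assert decode_expected(gid, "DatasetSplit") == 5
```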
@@ -110,11 +185,36 @@ class Dataset(Node):
             revision_ids = revision_ids.where(
                 models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
             )
-        stmt = (
-            select(count(models.DatasetExampleRevision.id))
-            .where(models.DatasetExampleRevision.id.in_(revision_ids))
-            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
-        )
+
+        # Build the count query
+        if split_rowids:
+            # When filtering by splits, count distinct examples that belong to those splits
+            stmt = (
+                select(count(models.DatasetExample.id.distinct()))
+                .join(
+                    models.DatasetExampleRevision,
+                    onclause=(
+                        models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id
+                    ),
+                )
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetExample.id
+                        == models.DatasetSplitDatasetExample.dataset_example_id
+                    ),
+                )
+                .where(models.DatasetExampleRevision.id.in_(revision_ids))
+                .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+                .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
+            )
+        else:
+            stmt = (
+                select(count(models.DatasetExampleRevision.id))
+                .where(models.DatasetExampleRevision.id.in_(revision_ids))
+                .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+            )
+
         async with info.context.db() as session:
             return (await session.scalar(stmt)) or 0
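The `distinct()` in the split branch matters: joining through the example-to-split association table yields one row per matching split, so an example belonging to two of the selected splits would otherwise be counted twice. A tiny self-contained SQLite illustration (the `assoc` table is a hypothetical stand-in for `DatasetSplitDatasetExample`):

```python
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, distinct, func, select

metadata = MetaData()
assoc = Table(  # hypothetical stand-in for the split<->example association table
    "assoc",
    metadata,
    Column("example_id", Integer),
    Column("split_id", Integer),
)

engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(
        assoc.insert(),
        [
            {"example_id": 1, "split_id": 10},  # example 1 is in both splits
            {"example_id": 1, "split_id": 11},
            {"example_id": 2, "split_id": 10},
        ],
    )
    where = assoc.c.split_id.in_([10, 11])
    naive = conn.scalar(select(func.count(assoc.c.example_id)).where(where))
    deduped = conn.scalar(select(func.count(distinct(assoc.c.example_id))).where(where))
    print(naive, deduped)  # 3 2
```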
@@ -123,10 +223,12 @@ class Dataset(Node):
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
+        split_ids: Optional[list[GlobalID]] = UNSET,
         first: Optional[int] = 50,
         last: Optional[int] = UNSET,
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
+        filter: Optional[str] = UNSET,
     ) -> Connection[DatasetExample]:
         args = ConnectionArgs(
             first=first,
@@ -134,7 +236,7 @@ class Dataset(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        dataset_id = self.id_attr
+        dataset_id = self.id
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
@@ -142,6 +244,20 @@ class Dataset(Node):
             if dataset_version_id
             else None
         )
+
+        # Parse split IDs if provided
+        split_rowids: Optional[list[int]] = None
+        if split_ids:
+            split_rowids = []
+            for split_id in split_ids:
+                try:
+                    split_rowid = from_global_id_with_expected_type(
+                        global_id=split_id, expected_type_name=models.DatasetSplit.__name__
+                    )
+                    split_rowids.append(split_rowid)
+                except Exception:
+                    raise BadRequest(f"Invalid split ID: {split_id}")
+
         revision_ids = (
             select(func.max(models.DatasetExampleRevision.id))
             .join(models.DatasetExample)
@@ -171,19 +287,51 @@ class Dataset(Node):
                     models.DatasetExampleRevision.revision_kind != "DELETE",
                 )
             )
-            .order_by(models.
+            .order_by(models.DatasetExample.id.desc())
         )
+
+        # Filter by split IDs if provided
+        if split_rowids:
+            query = (
+                query.join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetExample.id
+                        == models.DatasetSplitDatasetExample.dataset_example_id
+                    ),
+                )
+                .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
+                .distinct()
+            )
+        # Apply filter if provided - search through JSON fields (input, output, metadata)
+        if filter is not UNSET and filter:
+            # Create a filter that searches for the filter string in JSON fields
+            # Using PostgreSQL's JSON operators for case-insensitive text search
+            filter_condition = or_(
+                func.cast(models.DatasetExampleRevision.input, Text).ilike(f"%{filter}%"),
+                func.cast(models.DatasetExampleRevision.output, Text).ilike(f"%{filter}%"),
+                func.cast(models.DatasetExampleRevision.metadata_, Text).ilike(f"%{filter}%"),
+            )
+            query = query.where(filter_condition)
+
         async with info.context.db() as session:
             dataset_examples = [
                 DatasetExample(
-                    id_attr=example.id,
+                    id=example.id,
+                    db_record=example,
                     version_id=version_id,
-                    created_at=example.created_at,
                 )
                 async for example in await session.stream_scalars(query)
             ]
         return connection_from_list(data=dataset_examples, args=args)
 
+    @strawberry.field
+    async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
+        return [
+            DatasetSplit(id=split.id, db_record=split)
+            for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id)
+        ]
+
     @strawberry.field(
         description="Number of experiments for a specific version if version is specified, "
         "or for all versions if version is not specified."
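The new `filter` argument works by casting each JSON column to `TEXT` and running a case-insensitive `ILIKE` over the serialized value, so it matches keys and values alike; on large tables this serializes and scans every candidate row, making it a convenience filter rather than an indexed search. A self-contained sketch of the same trick, with SQLite standing in for Postgres and `Revision` as a hypothetical stand-in for `models.DatasetExampleRevision`:

```python
from typing import Any

from sqlalchemy import JSON, Text, cast, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Revision(Base):  # hypothetical stand-in for models.DatasetExampleRevision
    __tablename__ = "revisions"
    id: Mapped[int] = mapped_column(primary_key=True)
    input: Mapped[dict[str, Any]] = mapped_column(JSON)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(
        Revision.__table__.insert(),
        [
            {"id": 1, "input": {"question": "What is Phoenix?"}},
            {"id": 2, "input": {"question": "unrelated"}},
        ],
    )
    # serialize the JSON column to text, then case-insensitive substring match
    stmt = select(Revision.id).where(cast(Revision.input, Text).ilike("%phoenix%"))
    print(conn.execute(stmt).scalars().all())  # [1]
```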
@@ -193,9 +341,7 @@ class Dataset(Node):
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
     ) -> int:
-        stmt = select(count(models.Experiment.id)).where(
-            models.Experiment.dataset_id == self.id_attr
-        )
+        stmt = select(count(models.Experiment.id)).where(models.Experiment.dataset_id == self.id)
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id,
@@ -228,7 +374,7 @@ class Dataset(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        dataset_id = self.id_attr
+        dataset_id = self.id
         row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
         query = (
             select(models.Experiment, row_number)
@@ -270,17 +416,15 @@ class Dataset(Node):
     @strawberry.field
     async def experiment_annotation_summaries(
         self, info: Info[Context, None]
-    ) -> list[ExperimentAnnotationSummary]:
-        dataset_id = self.id_attr
+    ) -> list[DatasetExperimentAnnotationSummary]:
+        dataset_id = self.id
         query = (
             select(
-                models.ExperimentRunAnnotation.name,
-                func.min(models.ExperimentRunAnnotation.score),
-                func.max(models.ExperimentRunAnnotation.score),
-                func.avg(models.ExperimentRunAnnotation.score),
-                func.count(),
-                func.count(models.ExperimentRunAnnotation.error),
+                models.ExperimentRunAnnotation.name.label("annotation_name"),
+                func.min(models.ExperimentRunAnnotation.score).label("min_score"),
+                func.max(models.ExperimentRunAnnotation.score).label("max_score"),
             )
+            .select_from(models.ExperimentRunAnnotation)
             .join(
                 models.ExperimentRun,
                 models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
@@ -295,38 +439,21 @@ class Dataset(Node):
         )
         async with info.context.db() as session:
             return [
-                ExperimentAnnotationSummary(
-                    annotation_name=annotation_name,
-                    min_score=min_score,
-                    max_score=max_score,
-                    mean_score=mean_score,
-                    count=count_,
-                    error_count=error_count,
+                DatasetExperimentAnnotationSummary(
+                    annotation_name=scores_tuple.annotation_name,
+                    min_score=scores_tuple.min_score,
+                    max_score=scores_tuple.max_score,
                 )
-                async for (
-                    annotation_name,
-                    min_score,
-                    max_score,
-                    mean_score,
-                    count_,
-                    error_count,
-                ) in await session.stream(query)
+                async for scores_tuple in await session.stream(query)
             ]
 
     @strawberry.field
-    def
-        return
-
+    async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
+        return [
+            DatasetLabel(id=label.id, db_record=label)
+            for label in await info.context.data_loaders.dataset_labels.load(self.id)
+        ]
 
-
-
-
-    """
-    return Dataset(
-        id_attr=dataset.id,
-        name=dataset.name,
-        description=dataset.description,
-        metadata=dataset.metadata_,
-        created_at=dataset.created_at,
-        updated_at=dataset.updated_at,
-    )
+    @strawberry.field
+    def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
+        return info.context.last_updated_at.get(models.Dataset, self.id)
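Taken together, the Dataset type now exposes split- and label-aware fields alongside the existing example and experiment connections. A hypothetical client-side query exercising them; the camelCase field and argument names assume strawberry's default name conversion, and the `GlobalID` scalar and `node` field assume strawberry's relay defaults, so treat the exact names as unverified:

```python
import httpx

# hypothetical endpoint of a locally running Phoenix server
PHOENIX_GRAPHQL_URL = "http://localhost:6006/graphql"

QUERY = """
query DatasetWithSplits($id: GlobalID!, $splitIds: [GlobalID!], $filter: String) {
  node(id: $id) {
    ... on Dataset {
      name
      lastUpdatedAt
      splits { name }
      labels { name color }
      examples(first: 10, splitIds: $splitIds, filter: $filter) {
        edges { node { id createdAt } }
      }
    }
  }
}
"""

response = httpx.post(
    PHOENIX_GRAPHQL_URL,
    json={
        "query": QUERY,
        # "RGF0YXNldDox" is base64 for "Dataset:1", the relay GlobalID shape
        "variables": {"id": "RGF0YXNldDox", "splitIds": None, "filter": "phoenix"},
    },
)
print(response.json())
```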
--- phoenix/server/api/types/DatasetExample.py (11.23.1)
+++ phoenix/server/api/types/DatasetExample.py (12.28.1)
@@ -1,40 +1,59 @@
 from datetime import datetime
-from typing import Optional
+from typing import TYPE_CHECKING, Annotated, Optional
 
 import strawberry
 from sqlalchemy import select
-from sqlalchemy.orm import joinedload
 from strawberry import UNSET
 from strawberry.relay.types import Connection, GlobalID, Node, NodeID
 from strawberry.types import Info
 
 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
-from phoenix.server.api.types.
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
+    ExperimentRepeatedRunGroup,
+)
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
     CursorString,
     connection_from_list,
 )
-
+
+if TYPE_CHECKING:
+    from .Span import Span
 
 
 @strawberry.type
 class DatasetExample(Node):
-    id_attr: NodeID[int]
-    created_at: datetime
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.DatasetExample]] = None
     version_id: strawberry.Private[Optional[int]] = None
 
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("DatasetExample ID mismatch")
+
+    @strawberry.field
+    async def created_at(self, info: Info[Context, None]) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.dataset_example_fields.load(
+                (self.id, models.DatasetExample.created_at),
+            )
+        return val
+
     @strawberry.field
     async def revision(
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
     ) -> DatasetExampleRevision:
-        example_id = self.id_attr
         version_id: Optional[int] = None
         if dataset_version_id:
             version_id = from_global_id_with_expected_type(
@@ -42,18 +61,18 @@ class DatasetExample(Node):
             )
         elif self.version_id is not None:
             version_id = self.version_id
-        return await info.context.data_loaders.dataset_example_revisions.load(
-            (example_id, version_id)
-        )
+        return await info.context.data_loaders.dataset_example_revisions.load((self.id, version_id))
 
     @strawberry.field
     async def span(
         self,
         info: Info[Context, None],
-    ) -> Optional[Span]:
+    ) -> Optional[Annotated["Span", strawberry.lazy(".Span")]]:
+        from .Span import Span
+
         return (
-            Span(
-            if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
+            Span(id=span.id, db_record=span)
+            if (span := await info.context.data_loaders.dataset_example_spans.load(self.id))
             else None
         )
 
@@ -65,6 +84,7 @@ class DatasetExample(Node):
         last: Optional[int] = UNSET,
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
+        experiment_ids: Optional[list[GlobalID]] = UNSET,
     ) -> Connection[ExperimentRun]:
         args = ConnectionArgs(
             first=first,
@@ -72,14 +92,64 @@ class DatasetExample(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        example_id = self.id_attr
         query = (
             select(models.ExperimentRun)
-            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
             .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
-            .where(models.ExperimentRun.dataset_example_id == example_id)
-            .order_by(
+            .where(models.ExperimentRun.dataset_example_id == self.id)
+            .order_by(
+                models.ExperimentRun.experiment_id.asc(),
+                models.ExperimentRun.repetition_number.asc(),
+            )
         )
+        if experiment_ids:
+            experiment_db_ids = [
+                from_global_id_with_expected_type(
+                    global_id=experiment_id,
+                    expected_type_name=models.Experiment.__name__,
+                )
+                for experiment_id in experiment_ids or []
+            ]
+            query = query.where(models.ExperimentRun.experiment_id.in_(experiment_db_ids))
         async with info.context.db() as session:
             runs = (await session.scalars(query)).all()
-        return connection_from_list([
+        return connection_from_list([ExperimentRun(id=run.id, db_record=run) for run in runs], args)
+
+    @strawberry.field
+    async def experiment_repeated_run_groups(
+        self,
+        info: Info[Context, None],
+        experiment_ids: list[GlobalID],
+    ) -> list[ExperimentRepeatedRunGroup]:
+        experiment_rowids = []
+        for experiment_id in experiment_ids:
+            try:
+                experiment_rowid = from_global_id_with_expected_type(
+                    global_id=experiment_id,
+                    expected_type_name=models.Experiment.__name__,
+                )
+            except Exception:
+                raise BadRequest(f"Invalid experiment ID: {experiment_id}")
+            experiment_rowids.append(experiment_rowid)
+        repeated_run_groups = (
+            await info.context.data_loaders.experiment_repeated_run_groups.load_many(
+                [(experiment_rowid, self.id) for experiment_rowid in experiment_rowids]
+            )
+        )
+        return [
+            ExperimentRepeatedRunGroup(
+                experiment_rowid=group.experiment_rowid,
+                dataset_example_rowid=group.dataset_example_rowid,
+                cached_runs=[ExperimentRun(id=run.id, db_record=run) for run in group.runs],
+            )
+            for group in repeated_run_groups
+        ]
+
+    @strawberry.field
+    async def dataset_splits(
+        self,
+        info: Info[Context, None],
+    ) -> list[DatasetSplit]:
+        return [
+            DatasetSplit(id=split.id, db_record=split)
+            for split in await info.context.data_loaders.dataset_example_splits.load(self.id)
+        ]
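Every `info.context.data_loaders.*.load(...)` call above relies on DataLoader batching: loads requested while resolving one wave of a GraphQL response are coalesced into a single batched fetch instead of one query per object. A minimal, self-contained sketch using strawberry's DataLoader (the in-memory `SPLITS_BY_EXAMPLE` map is a hypothetical stand-in for the SQL behind the `dataset_example_splits` loader):

```python
import asyncio

from strawberry.dataloader import DataLoader

SPLITS_BY_EXAMPLE = {1: ["train"], 2: ["train", "eval"]}  # fake association data


async def batch_load_splits(keys: list[int]) -> list[list[str]]:
    print(f"one batched fetch for keys={keys}")  # runs once, not once per key
    return [SPLITS_BY_EXAMPLE.get(key, []) for key in keys]


async def main() -> None:
    loader = DataLoader(load_fn=batch_load_splits)
    # concurrent .load() calls within one event-loop tick share a single batch
    results = await asyncio.gather(loader.load(1), loader.load(2))
    print(results)  # [['train'], ['train', 'eval']]


asyncio.run(main())
```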
--- /dev/null
+++ phoenix/server/api/types/DatasetLabel.py (12.28.1)
@@ -0,0 +1,57 @@
+from typing import Optional
+
+import strawberry
+from strawberry.relay import Node, NodeID
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+
+
+@strawberry.type
+class DatasetLabel(Node):
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.DatasetLabel]] = None
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("DatasetLabel ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.dataset_label_fields.load(
+                (self.id, models.DatasetLabel.name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.dataset_label_fields.load(
+                (self.id, models.DatasetLabel.description),
+            )
+        return val
+
+    @strawberry.field
+    async def color(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.color
+        else:
+            val = await info.context.data_loaders.dataset_label_fields.load(
+                (self.id, models.DatasetLabel.color),
+            )
+        return val