arize-phoenix 4.4.2__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (111)
  1. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +12 -11
  2. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +110 -57
  3. phoenix/__init__.py +0 -27
  4. phoenix/config.py +21 -7
  5. phoenix/core/model.py +25 -25
  6. phoenix/core/model_schema.py +66 -64
  7. phoenix/core/model_schema_adapter.py +27 -25
  8. phoenix/datasets/__init__.py +0 -0
  9. phoenix/datasets/evaluators.py +275 -0
  10. phoenix/datasets/experiments.py +469 -0
  11. phoenix/datasets/tracing.py +66 -0
  12. phoenix/datasets/types.py +212 -0
  13. phoenix/db/bulk_inserter.py +54 -14
  14. phoenix/db/insertion/dataset.py +234 -0
  15. phoenix/db/insertion/evaluation.py +6 -6
  16. phoenix/db/insertion/helpers.py +13 -2
  17. phoenix/db/migrations/types.py +29 -0
  18. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  19. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  20. phoenix/db/models.py +230 -3
  21. phoenix/inferences/fixtures.py +23 -23
  22. phoenix/inferences/inferences.py +7 -7
  23. phoenix/inferences/validation.py +1 -1
  24. phoenix/metrics/binning.py +2 -2
  25. phoenix/server/api/context.py +16 -0
  26. phoenix/server/api/dataloaders/__init__.py +16 -0
  27. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  28. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  29. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  30. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  31. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  32. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  34. phoenix/server/api/dataloaders/span_projects.py +33 -0
  35. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  36. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  37. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  38. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  39. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  40. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  41. phoenix/server/api/input_types/DatasetSort.py +17 -0
  42. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  43. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  44. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  45. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  46. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  47. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  48. phoenix/server/api/mutations/__init__.py +13 -0
  49. phoenix/server/api/mutations/auth.py +11 -0
  50. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  51. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  52. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  53. phoenix/server/api/mutations/project_mutations.py +42 -0
  54. phoenix/server/api/queries.py +503 -0
  55. phoenix/server/api/routers/v1/__init__.py +77 -2
  56. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  57. phoenix/server/api/routers/v1/datasets.py +861 -0
  58. phoenix/server/api/routers/v1/evaluations.py +4 -2
  59. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  60. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  61. phoenix/server/api/routers/v1/experiments.py +174 -0
  62. phoenix/server/api/routers/v1/spans.py +3 -1
  63. phoenix/server/api/routers/v1/traces.py +1 -4
  64. phoenix/server/api/schema.py +2 -303
  65. phoenix/server/api/types/AnnotatorKind.py +10 -0
  66. phoenix/server/api/types/Cluster.py +19 -19
  67. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  68. phoenix/server/api/types/Dataset.py +282 -63
  69. phoenix/server/api/types/DatasetExample.py +85 -0
  70. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  71. phoenix/server/api/types/DatasetVersion.py +14 -0
  72. phoenix/server/api/types/Dimension.py +30 -29
  73. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  74. phoenix/server/api/types/Event.py +16 -16
  75. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  76. phoenix/server/api/types/Experiment.py +135 -0
  77. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  78. phoenix/server/api/types/ExperimentComparison.py +19 -0
  79. phoenix/server/api/types/ExperimentRun.py +91 -0
  80. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  81. phoenix/server/api/types/Inferences.py +80 -0
  82. phoenix/server/api/types/InferencesRole.py +23 -0
  83. phoenix/server/api/types/Model.py +43 -42
  84. phoenix/server/api/types/Project.py +26 -12
  85. phoenix/server/api/types/Segments.py +1 -1
  86. phoenix/server/api/types/Span.py +78 -2
  87. phoenix/server/api/types/TimeSeries.py +6 -6
  88. phoenix/server/api/types/Trace.py +15 -4
  89. phoenix/server/api/types/UMAPPoints.py +1 -1
  90. phoenix/server/api/types/node.py +5 -111
  91. phoenix/server/api/types/pagination.py +10 -52
  92. phoenix/server/app.py +99 -49
  93. phoenix/server/main.py +49 -27
  94. phoenix/server/openapi/docs.py +3 -0
  95. phoenix/server/static/index.js +2246 -1368
  96. phoenix/server/templates/index.html +1 -0
  97. phoenix/services.py +15 -15
  98. phoenix/session/client.py +316 -21
  99. phoenix/session/session.py +47 -37
  100. phoenix/trace/exporter.py +14 -9
  101. phoenix/trace/fixtures.py +133 -7
  102. phoenix/trace/span_evaluations.py +3 -3
  103. phoenix/trace/trace_dataset.py +6 -6
  104. phoenix/utilities/json.py +61 -0
  105. phoenix/utilities/re.py +50 -0
  106. phoenix/version.py +1 -1
  107. phoenix/server/api/types/DatasetRole.py +0 -23
  108. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/types/Dataset.py
@@ -1,80 +1,299 @@
  from datetime import datetime
- from typing import Iterable, List, Optional, Set, Union
+ from typing import AsyncIterable, List, Optional, Tuple, cast

  import strawberry
- from strawberry.scalars import ID
- from strawberry.unset import UNSET
+ from sqlalchemy import and_, func, select
+ from sqlalchemy.sql.functions import count
+ from strawberry import UNSET
+ from strawberry.relay import Connection, GlobalID, Node, NodeID
+ from strawberry.scalars import JSON
+ from strawberry.types import Info

- import phoenix.core.model_schema as ms
- from phoenix.core.model_schema import FEATURE, TAG, ScalarDimension
-
- from ..input_types.DimensionInput import DimensionInput
- from .DatasetRole import AncillaryDatasetRole, DatasetRole
- from .Dimension import Dimension, to_gql_dimension
- from .Event import Event, create_event, create_event_id, parse_event_ids_by_dataset_role
+ from phoenix.db import models
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
+ from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
+ from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
+ from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+ from phoenix.server.api.types.pagination import (
+ ConnectionArgs,
+ CursorString,
+ connection_from_list,
+ )
+ from phoenix.server.api.types.SortDir import SortDir


  @strawberry.type
- class Dataset:
- start_time: datetime = strawberry.field(description="The start bookend of the data")
- end_time: datetime = strawberry.field(description="The end bookend of the data")
- record_count: int = strawberry.field(description="The record count of the data")
- dataset: strawberry.Private[ms.Dataset]
- dataset_role: strawberry.Private[Union[DatasetRole, AncillaryDatasetRole]]
- model: strawberry.Private[ms.Model]
-
- # type ignored here to get around the following: https://github.com/strawberry-graphql/strawberry/issues/1929
- @strawberry.field(description="Returns a human friendly name for the dataset.") # type: ignore
- def name(self) -> str:
- return self.dataset.display_name
+ class Dataset(Node):
+ id_attr: NodeID[int]
+ name: str
+ description: Optional[str]
+ metadata: JSON
+ created_at: datetime
+ updated_at: datetime

  @strawberry.field
- def events(
+ async def versions(
  self,
- event_ids: List[ID],
- dimensions: Optional[List[DimensionInput]] = UNSET,
- ) -> List[Event]:
- """
- Returns events for specific event IDs and dimensions. If no input
- dimensions are provided, returns all features and tags.
- """
- if not event_ids:
- return []
- row_ids = parse_event_ids_by_dataset_role(event_ids)
- if len(row_ids) > 1 or self.dataset_role not in row_ids:
- raise ValueError("eventIds contains IDs from incorrect dataset.")
- events = self.dataset[row_ids[self.dataset_role]]
- requested_gql_dimensions = _get_requested_features_and_tags(
- core_dimensions=self.model.scalar_dimensions,
- requested_dimension_names=set(dim.name for dim in dimensions)
- if isinstance(dimensions, list)
- else None,
+ info: Info[Context, None],
+ first: Optional[int] = 50,
+ last: Optional[int] = UNSET,
+ after: Optional[CursorString] = UNSET,
+ before: Optional[CursorString] = UNSET,
+ sort: Optional[DatasetVersionSort] = UNSET,
+ ) -> Connection[DatasetVersion]:
+ args = ConnectionArgs(
+ first=first,
+ after=after if isinstance(after, CursorString) else None,
+ last=last,
+ before=before if isinstance(before, CursorString) else None,
  )
- return [
- create_event(
- event_id=create_event_id(event.id.row_id, self.dataset_role),
- event=event,
- dimensions=requested_gql_dimensions,
- is_document_record=self.dataset_role is AncillaryDatasetRole.corpus,
+ async with info.context.db() as session:
+ stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
+ if sort:
+ # For now assume the the column names match 1:1 with the enum values
+ sort_col = getattr(models.DatasetVersion, sort.col.value)
+ if sort.dir is SortDir.desc:
+ stmt = stmt.order_by(sort_col.desc(), models.DatasetVersion.id.desc())
+ else:
+ stmt = stmt.order_by(sort_col.asc(), models.DatasetVersion.id.asc())
+ else:
+ stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
+ versions = await session.scalars(stmt)
+ data = [
+ DatasetVersion(
+ id_attr=version.id,
+ description=version.description,
+ metadata=version.metadata_,
+ created_at=version.created_at,
  )
- for event in events
+ for version in versions
  ]
+ return connection_from_list(data=data, args=args)
+
+ @strawberry.field(
+ description="Number of examples in a specific version if version is specified, or in the "
+ "latest version if version is not specified."
+ ) # type: ignore
+ async def example_count(
+ self,
+ info: Info[Context, None],
+ dataset_version_id: Optional[GlobalID] = UNSET,
+ ) -> int:
+ dataset_id = self.id_attr
+ version_id = (
+ from_global_id_with_expected_type(
+ global_id=dataset_version_id,
+ expected_type_name=DatasetVersion.__name__,
+ )
+ if dataset_version_id
+ else None
+ )
+ revision_ids = (
+ select(func.max(models.DatasetExampleRevision.id))
+ .join(models.DatasetExample)
+ .where(models.DatasetExample.dataset_id == dataset_id)
+ .group_by(models.DatasetExampleRevision.dataset_example_id)
+ )
+ if version_id:
+ version_id_subquery = (
+ select(models.DatasetVersion.id)
+ .where(models.DatasetVersion.dataset_id == dataset_id)
+ .where(models.DatasetVersion.id == version_id)
+ .scalar_subquery()
+ )
+ revision_ids = revision_ids.where(
+ models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+ )
+ stmt = (
+ select(count(models.DatasetExampleRevision.id))
+ .where(models.DatasetExampleRevision.id.in_(revision_ids))
+ .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+ )
+ async with info.context.db() as session:
+ return (await session.scalar(stmt)) or 0
+
+ @strawberry.field
+ async def examples(
+ self,
+ info: Info[Context, None],
+ dataset_version_id: Optional[GlobalID] = UNSET,
+ first: Optional[int] = 50,
+ last: Optional[int] = UNSET,
+ after: Optional[CursorString] = UNSET,
+ before: Optional[CursorString] = UNSET,
+ ) -> Connection[DatasetExample]:
+ args = ConnectionArgs(
+ first=first,
+ after=after if isinstance(after, CursorString) else None,
+ last=last,
+ before=before if isinstance(before, CursorString) else None,
+ )
+ dataset_id = self.id_attr
+ version_id = (
+ from_global_id_with_expected_type(
+ global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+ )
+ if dataset_version_id
+ else None
+ )
+ revision_ids = (
+ select(func.max(models.DatasetExampleRevision.id))
+ .join(models.DatasetExample)
+ .where(models.DatasetExample.dataset_id == dataset_id)
+ .group_by(models.DatasetExampleRevision.dataset_example_id)
+ )
+ if version_id:
+ version_id_subquery = (
+ select(models.DatasetVersion.id)
+ .where(models.DatasetVersion.dataset_id == dataset_id)
+ .where(models.DatasetVersion.id == version_id)
+ .scalar_subquery()
+ )
+ revision_ids = revision_ids.where(
+ models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+ )
+ query = (
+ select(models.DatasetExample)
+ .join(
+ models.DatasetExampleRevision,
+ onclause=models.DatasetExample.id
+ == models.DatasetExampleRevision.dataset_example_id,
+ )
+ .where(
+ and_(
+ models.DatasetExampleRevision.id.in_(revision_ids),
+ models.DatasetExampleRevision.revision_kind != "DELETE",
+ )
+ )
+ .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
+ )
+ async with info.context.db() as session:
+ dataset_examples = [
+ DatasetExample(
+ id_attr=example.id,
+ version_id=version_id,
+ created_at=example.created_at,
+ )
+ async for example in await session.stream_scalars(query)
+ ]
+ return connection_from_list(data=dataset_examples, args=args)
+
+ @strawberry.field(
+ description="Number of experiments for a specific version if version is specified, "
+ "or for all versions if version is not specified."
+ ) # type: ignore
+ async def experiment_count(
+ self,
+ info: Info[Context, None],
+ dataset_version_id: Optional[GlobalID] = UNSET,
+ ) -> int:
+ stmt = select(count(models.Experiment.id)).where(
+ models.Experiment.dataset_id == self.id_attr
+ )
+ version_id = (
+ from_global_id_with_expected_type(
+ global_id=dataset_version_id,
+ expected_type_name=DatasetVersion.__name__,
+ )
+ if dataset_version_id
+ else None
+ )
+ if version_id is not None:
+ stmt = stmt.where(models.Experiment.dataset_version_id == version_id)
+ async with info.context.db() as session:
+ return (await session.scalar(stmt)) or 0
+
+ @strawberry.field
+ async def experiments(
+ self,
+ info: Info[Context, None],
+ first: Optional[int] = 50,
+ last: Optional[int] = UNSET,
+ after: Optional[CursorString] = UNSET,
+ before: Optional[CursorString] = UNSET,
+ ) -> Connection[Experiment]:
+ args = ConnectionArgs(
+ first=first,
+ after=after if isinstance(after, CursorString) else None,
+ last=last,
+ before=before if isinstance(before, CursorString) else None,
+ )
+ dataset_id = self.id_attr
+ row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
+ query = (
+ select(models.Experiment, row_number)
+ .where(models.Experiment.dataset_id == dataset_id)
+ .order_by(models.Experiment.id.desc())
+ )
+ async with info.context.db() as session:
+ experiments = [
+ to_gql_experiment(experiment, sequence_number)
+ async for experiment, sequence_number in cast(
+ AsyncIterable[Tuple[models.Experiment, int]],
+ await session.stream(query),
+ )
+ ]
+ return connection_from_list(data=experiments, args=args)
+
+ @strawberry.field
+ async def experiment_annotation_summaries(
+ self, info: Info[Context, None]
+ ) -> List[ExperimentAnnotationSummary]:
+ dataset_id = self.id_attr
+ query = (
+ select(
+ models.ExperimentRunAnnotation.name,
+ func.min(models.ExperimentRunAnnotation.score),
+ func.max(models.ExperimentRunAnnotation.score),
+ func.avg(models.ExperimentRunAnnotation.score),
+ func.count(),
+ func.count(models.ExperimentRunAnnotation.error),
+ )
+ .join(
+ models.ExperimentRun,
+ models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+ )
+ .join(
+ models.Experiment,
+ models.ExperimentRun.experiment_id == models.Experiment.id,
+ )
+ .where(models.Experiment.dataset_id == dataset_id)
+ .group_by(models.ExperimentRunAnnotation.name)
+ .order_by(models.ExperimentRunAnnotation.name)
+ )
+ async with info.context.db() as session:
+ return [
+ ExperimentAnnotationSummary(
+ annotation_name=annotation_name,
+ min_score=min_score,
+ max_score=max_score,
+ mean_score=mean_score,
+ count=count_,
+ error_count=error_count,
+ )
+ async for (
+ annotation_name,
+ min_score,
+ max_score,
+ mean_score,
+ count_,
+ error_count,
+ ) in await session.stream(query)
+ ]


- def _get_requested_features_and_tags(
- core_dimensions: Iterable[ScalarDimension],
- requested_dimension_names: Optional[Set[str]] = UNSET,
- ) -> List[Dimension]:
+ def to_gql_dataset(dataset: models.Dataset) -> Dataset:
  """
- Returns requested features and tags as a list of strawberry Datasets. If no
- dimensions are explicitly requested, returns all features and tags.
+ Converts an ORM dataset to a GraphQL dataset.
  """
- requested_features_and_tags: List[Dimension] = []
- for id, dim in enumerate(core_dimensions):
- is_requested = (
- not isinstance(requested_dimension_names, Set)
- ) or dim.name in requested_dimension_names
- is_feature_or_tag = dim.role in (FEATURE, TAG)
- if is_requested and is_feature_or_tag:
- requested_features_and_tags.append(to_gql_dimension(id_attr=id, dimension=dim))
- return requested_features_and_tags
+ return Dataset(
+ id_attr=dataset.id,
+ name=dataset.name,
+ description=dataset.description,
+ metadata=dataset.metadata_,
+ created_at=dataset.created_at,
+ updated_at=dataset.updated_at,
+ )
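
A side note on the example_count and examples resolvers added above: both rely on the same "latest revision per example" subquery (the maximum revision id per example, optionally capped at a dataset version, with DELETE revisions filtered out afterwards). The snippet below is a minimal, self-contained sketch of that pattern against simplified stand-in SQLAlchemy models; it is illustrative only and does not use the package's actual phoenix.db.models.

from typing import Optional

from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DatasetExample(Base):
    # Simplified stand-in for the package's dataset_examples table.
    __tablename__ = "dataset_examples"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[int]


class DatasetExampleRevision(Base):
    # Simplified stand-in for the package's dataset_example_revisions table.
    __tablename__ = "dataset_example_revisions"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_example_id: Mapped[int] = mapped_column(ForeignKey("dataset_examples.id"))
    dataset_version_id: Mapped[int]
    revision_kind: Mapped[str]  # "CREATE" | "PATCH" | "DELETE"


def latest_revision_ids(dataset_id: int, version_id: Optional[int] = None):
    # Highest revision id per example, optionally capped at a given version id.
    stmt = (
        select(func.max(DatasetExampleRevision.id))
        .join(DatasetExample, DatasetExample.id == DatasetExampleRevision.dataset_example_id)
        .where(DatasetExample.dataset_id == dataset_id)
        .group_by(DatasetExampleRevision.dataset_example_id)
    )
    if version_id is not None:
        stmt = stmt.where(DatasetExampleRevision.dataset_version_id <= version_id)
    return stmt


def example_count_stmt(dataset_id: int, version_id: Optional[int] = None):
    # Count surviving examples: the latest revision per example that is not a DELETE.
    return (
        select(func.count(DatasetExampleRevision.id))
        .where(DatasetExampleRevision.id.in_(latest_revision_ids(dataset_id, version_id)))
        .where(DatasetExampleRevision.revision_kind != "DELETE")
    )


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        print(session.scalar(example_count_stmt(dataset_id=1)))  # prints 0 on an empty database
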
phoenix/server/api/types/DatasetExample.py
@@ -0,0 +1,85 @@
+ from datetime import datetime
+ from typing import Optional
+
+ import strawberry
+ from sqlalchemy import select
+ from sqlalchemy.orm import joinedload
+ from strawberry import UNSET
+ from strawberry.relay.types import Connection, GlobalID, Node, NodeID
+ from strawberry.types import Info
+
+ from phoenix.db import models
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+ from phoenix.server.api.types.pagination import (
+ ConnectionArgs,
+ CursorString,
+ connection_from_list,
+ )
+ from phoenix.server.api.types.Span import Span, to_gql_span
+
+
+ @strawberry.type
+ class DatasetExample(Node):
+ id_attr: NodeID[int]
+ created_at: datetime
+ version_id: strawberry.Private[Optional[int]] = None
+
+ @strawberry.field
+ async def revision(
+ self,
+ info: Info[Context, None],
+ dataset_version_id: Optional[GlobalID] = UNSET,
+ ) -> DatasetExampleRevision:
+ example_id = self.id_attr
+ version_id: Optional[int] = None
+ if dataset_version_id:
+ version_id = from_global_id_with_expected_type(
+ global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+ )
+ elif self.version_id is not None:
+ version_id = self.version_id
+ return await info.context.data_loaders.dataset_example_revisions.load(
+ (example_id, version_id)
+ )
+
+ @strawberry.field
+ async def span(
+ self,
+ info: Info[Context, None],
+ ) -> Optional[Span]:
+ return (
+ to_gql_span(span)
+ if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
+ else None
+ )
+
+ @strawberry.field
+ async def experiment_runs(
+ self,
+ info: Info[Context, None],
+ first: Optional[int] = 50,
+ last: Optional[int] = UNSET,
+ after: Optional[CursorString] = UNSET,
+ before: Optional[CursorString] = UNSET,
+ ) -> Connection[ExperimentRun]:
+ args = ConnectionArgs(
+ first=first,
+ after=after if isinstance(after, CursorString) else None,
+ last=last,
+ before=before if isinstance(before, CursorString) else None,
+ )
+ example_id = self.id_attr
+ query = (
+ select(models.ExperimentRun)
+ .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+ .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
+ .where(models.ExperimentRun.dataset_example_id == example_id)
+ .order_by(models.Experiment.id.desc())
+ )
+ async with info.context.db() as session:
+ runs = (await session.scalars(query)).all()
+ return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
phoenix/server/api/types/DatasetExampleRevision.py
@@ -0,0 +1,34 @@
+ from datetime import datetime
+ from enum import Enum
+
+ import strawberry
+
+ from phoenix.db import models
+ from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
+
+
+ @strawberry.enum
+ class RevisionKind(Enum):
+ CREATE = "CREATE"
+ PATCH = "PATCH"
+ DELETE = "DELETE"
+
+
+ @strawberry.type
+ class DatasetExampleRevision(ExampleRevision):
+ """
+ Represents a revision (i.e., update or alteration) of a dataset example.
+ """
+
+ revision_kind: RevisionKind
+ created_at: datetime
+
+ @classmethod
+ def from_orm_revision(cls, revision: models.DatasetExampleRevision) -> "DatasetExampleRevision":
+ return cls(
+ input=revision.input,
+ output=revision.output,
+ metadata=revision.metadata_,
+ revision_kind=RevisionKind(revision.revision_kind),
+ created_at=revision.created_at,
+ )
phoenix/server/api/types/DatasetVersion.py
@@ -0,0 +1,14 @@
+ from datetime import datetime
+ from typing import Optional
+
+ import strawberry
+ from strawberry.relay import Node, NodeID
+ from strawberry.scalars import JSON
+
+
+ @strawberry.type
+ class DatasetVersion(Node):
+ id_attr: NodeID[int]
+ description: Optional[str]
+ metadata: JSON
+ created_at: datetime
phoenix/server/api/types/Dimension.py
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
  import pandas as pd
  import strawberry
  from strawberry import UNSET
+ from strawberry.relay import Node, NodeID
  from strawberry.types import Info
  from typing_extensions import Annotated

@@ -17,12 +18,11 @@ from ..context import Context
  from ..input_types.Granularity import Granularity
  from ..input_types.TimeRange import TimeRange
  from .DataQualityMetric import DataQualityMetric
- from .DatasetRole import DatasetRole
  from .DatasetValues import DatasetValues
  from .DimensionDataType import DimensionDataType
  from .DimensionShape import DimensionShape
  from .DimensionType import DimensionType
- from .node import Node
+ from .InferencesRole import InferencesRole
  from .ScalarDriftMetricEnum import ScalarDriftMetric
  from .Segments import (
  GqlBinFactory,
@@ -40,6 +40,7 @@ from .TimeSeries import (

  @strawberry.type
  class Dimension(Node):
+ id_attr: NodeID[int]
  name: str = strawberry.field(description="The name of the dimension (a.k.a. the column name)")
  type: DimensionType = strawberry.field(
  description="Whether the dimension represents a feature, tag, prediction, or actual."
@@ -62,16 +63,16 @@ class Dimension(Node):
  """
  Computes a drift metric between all reference data and the primary data
  belonging to the input time range (inclusive of the time range start and
- exclusive of the time range end). Returns None if no reference dataset
- exists, if no primary data exists in the input time range, or if the
+ exclusive of the time range end). Returns None if no reference inferences
+ exist, if no primary data exists in the input time range, or if the
  input time range is invalid.
  """
  model = info.context.model
  if model[REFERENCE].empty:
  return None
- dataset = model[PRIMARY]
+ inferences = model[PRIMARY]
  time_range, granularity = ensure_timeseries_parameters(
- dataset,
+ inferences,
  time_range,
  )
  data = get_drift_timeseries_data(
@@ -92,18 +93,18 @@ class Dimension(Node):
  info: Info[Context, None],
  metric: DataQualityMetric,
  time_range: Optional[TimeRange] = UNSET,
- dataset_role: Annotated[
- Optional[DatasetRole],
+ inferences_role: Annotated[
+ Optional[InferencesRole],
  strawberry.argument(
- description="The dataset (primary or reference) to query",
+ description="The inferences (primary or reference) to query",
  ),
- ] = DatasetRole.primary,
+ ] = InferencesRole.primary,
  ) -> Optional[float]:
- if not isinstance(dataset_role, DatasetRole):
- dataset_role = DatasetRole.primary
- dataset = info.context.model[dataset_role.value]
+ if not isinstance(inferences_role, InferencesRole):
+ inferences_role = InferencesRole.primary
+ inferences = info.context.model[inferences_role.value]
  time_range, granularity = ensure_timeseries_parameters(
- dataset,
+ inferences,
  time_range,
  )
  data = get_data_quality_timeseries_data(
@@ -111,7 +112,7 @@ class Dimension(Node):
  metric,
  time_range,
  granularity,
- dataset_role,
+ inferences_role,
  )
  return data[0].value if len(data) else None

@@ -139,18 +140,18 @@ class Dimension(Node):
  metric: DataQualityMetric,
  time_range: TimeRange,
  granularity: Granularity,
- dataset_role: Annotated[
- Optional[DatasetRole],
+ inferences_role: Annotated[
+ Optional[InferencesRole],
  strawberry.argument(
- description="The dataset (primary or reference) to query",
+ description="The inferences (primary or reference) to query",
  ),
- ] = DatasetRole.primary,
+ ] = InferencesRole.primary,
  ) -> DataQualityTimeSeries:
- if not isinstance(dataset_role, DatasetRole):
- dataset_role = DatasetRole.primary
- dataset = info.context.model[dataset_role.value]
+ if not isinstance(inferences_role, InferencesRole):
+ inferences_role = InferencesRole.primary
+ inferences = info.context.model[inferences_role.value]
  time_range, granularity = ensure_timeseries_parameters(
- dataset,
+ inferences,
  time_range,
  granularity,
  )
@@ -160,7 +161,7 @@ class Dimension(Node):
  metric,
  time_range,
  granularity,
- dataset_role,
+ inferences_role,
  )
  )

@@ -182,9 +183,9 @@ class Dimension(Node):
  model = info.context.model
  if model[REFERENCE].empty:
  return DriftTimeSeries(data=[])
- dataset = model[PRIMARY]
+ inferences = model[PRIMARY]
  time_range, granularity = ensure_timeseries_parameters(
- dataset,
+ inferences,
  time_range,
  granularity,
  )
@@ -202,7 +203,7 @@ class Dimension(Node):
  )

  @strawberry.field(
- description="Returns the segments across both datasets and returns the counts per segment",
+ description="The segments across both inference sets and returns the counts per segment",
  ) # type: ignore
  def segments_comparison(
  self,
@@ -249,8 +250,8 @@ class Dimension(Node):
  if isinstance(binning_method, binning.IntervalBinning) and binning_method.bins is not None:
  all_bins = all_bins.union(binning_method.bins)
  for bin in all_bins:
- values: Dict[ms.DatasetRole, Any] = defaultdict(lambda: None)
- for role in ms.DatasetRole:
+ values: Dict[ms.InferencesRole, Any] = defaultdict(lambda: None)
+ for role in ms.InferencesRole:
  if model[role].empty:
  continue
  try: