arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (109)
  1. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +4 -4
  2. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +108 -55
  3. phoenix/__init__.py +0 -27
  4. phoenix/config.py +21 -7
  5. phoenix/core/model.py +25 -25
  6. phoenix/core/model_schema.py +64 -62
  7. phoenix/core/model_schema_adapter.py +27 -25
  8. phoenix/datasets/__init__.py +0 -0
  9. phoenix/datasets/evaluators.py +275 -0
  10. phoenix/datasets/experiments.py +469 -0
  11. phoenix/datasets/tracing.py +66 -0
  12. phoenix/datasets/types.py +212 -0
  13. phoenix/db/bulk_inserter.py +54 -14
  14. phoenix/db/insertion/dataset.py +234 -0
  15. phoenix/db/insertion/evaluation.py +6 -6
  16. phoenix/db/insertion/helpers.py +13 -2
  17. phoenix/db/migrations/types.py +29 -0
  18. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  19. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  20. phoenix/db/models.py +230 -3
  21. phoenix/inferences/fixtures.py +23 -23
  22. phoenix/inferences/inferences.py +7 -7
  23. phoenix/inferences/validation.py +1 -1
  24. phoenix/server/api/context.py +16 -0
  25. phoenix/server/api/dataloaders/__init__.py +16 -0
  26. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  27. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  28. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  29. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  30. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  31. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  33. phoenix/server/api/dataloaders/span_projects.py +33 -0
  34. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  35. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  36. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  37. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  38. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  39. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  40. phoenix/server/api/input_types/DatasetSort.py +17 -0
  41. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  42. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  43. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  44. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  45. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  46. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  47. phoenix/server/api/mutations/__init__.py +13 -0
  48. phoenix/server/api/mutations/auth.py +11 -0
  49. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  50. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  51. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  52. phoenix/server/api/mutations/project_mutations.py +42 -0
  53. phoenix/server/api/queries.py +503 -0
  54. phoenix/server/api/routers/v1/__init__.py +77 -2
  55. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  56. phoenix/server/api/routers/v1/datasets.py +861 -0
  57. phoenix/server/api/routers/v1/evaluations.py +4 -2
  58. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  59. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  60. phoenix/server/api/routers/v1/experiments.py +174 -0
  61. phoenix/server/api/routers/v1/spans.py +3 -1
  62. phoenix/server/api/routers/v1/traces.py +1 -4
  63. phoenix/server/api/schema.py +2 -303
  64. phoenix/server/api/types/AnnotatorKind.py +10 -0
  65. phoenix/server/api/types/Cluster.py +19 -19
  66. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  67. phoenix/server/api/types/Dataset.py +282 -63
  68. phoenix/server/api/types/DatasetExample.py +85 -0
  69. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  70. phoenix/server/api/types/DatasetVersion.py +14 -0
  71. phoenix/server/api/types/Dimension.py +30 -29
  72. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  73. phoenix/server/api/types/Event.py +16 -16
  74. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  75. phoenix/server/api/types/Experiment.py +135 -0
  76. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  77. phoenix/server/api/types/ExperimentComparison.py +19 -0
  78. phoenix/server/api/types/ExperimentRun.py +91 -0
  79. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  80. phoenix/server/api/types/Inferences.py +80 -0
  81. phoenix/server/api/types/InferencesRole.py +23 -0
  82. phoenix/server/api/types/Model.py +43 -42
  83. phoenix/server/api/types/Project.py +26 -12
  84. phoenix/server/api/types/Span.py +78 -2
  85. phoenix/server/api/types/TimeSeries.py +6 -6
  86. phoenix/server/api/types/Trace.py +15 -4
  87. phoenix/server/api/types/UMAPPoints.py +1 -1
  88. phoenix/server/api/types/node.py +5 -111
  89. phoenix/server/api/types/pagination.py +10 -52
  90. phoenix/server/app.py +99 -49
  91. phoenix/server/main.py +49 -27
  92. phoenix/server/openapi/docs.py +3 -0
  93. phoenix/server/static/index.js +2246 -1368
  94. phoenix/server/templates/index.html +1 -0
  95. phoenix/services.py +15 -15
  96. phoenix/session/client.py +316 -21
  97. phoenix/session/session.py +47 -37
  98. phoenix/trace/exporter.py +14 -9
  99. phoenix/trace/fixtures.py +133 -7
  100. phoenix/trace/span_evaluations.py +3 -3
  101. phoenix/trace/trace_dataset.py +6 -6
  102. phoenix/utilities/json.py +61 -0
  103. phoenix/utilities/re.py +50 -0
  104. phoenix/version.py +1 -1
  105. phoenix/server/api/types/DatasetRole.py +0 -23
  106. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
  107. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
  108. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
  109. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/schema.py
@@ -1,308 +1,7 @@
- from collections import defaultdict
- from typing import Dict, List, Optional, Set, Union
-
- import numpy as np
- import numpy.typing as npt
  import strawberry
- from sqlalchemy import delete, select
- from sqlalchemy.orm import contains_eager, load_only
- from strawberry import ID, UNSET
- from strawberry.types import Info
- from typing_extensions import Annotated
-
- from phoenix.config import DEFAULT_PROJECT_NAME
- from phoenix.db import models
- from phoenix.db.insertion.span import ClearProjectSpansEvent
- from phoenix.pointcloud.clustering import Hdbscan
- from phoenix.server.api.context import Context
- from phoenix.server.api.helpers import ensure_list
- from phoenix.server.api.input_types.ClusterInput import ClusterInput
- from phoenix.server.api.input_types.Coordinates import (
-     InputCoordinate2D,
-     InputCoordinate3D,
- )
- from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
- from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
- from phoenix.server.api.types.Dimension import to_gql_dimension
- from phoenix.server.api.types.EmbeddingDimension import (
-     DEFAULT_CLUSTER_SELECTION_EPSILON,
-     DEFAULT_MIN_CLUSTER_SIZE,
-     DEFAULT_MIN_SAMPLES,
-     to_gql_embedding_dimension,
- )
- from phoenix.server.api.types.Event import create_event_id, unpack_event_id
- from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
- from phoenix.server.api.types.Functionality import Functionality
- from phoenix.server.api.types.Model import Model
- from phoenix.server.api.types.node import (
-     GlobalID,
-     Node,
-     from_global_id,
-     from_global_id_with_expected_type,
- )
- from phoenix.server.api.types.pagination import (
-     Connection,
-     ConnectionArgs,
-     CursorString,
-     connection_from_list,
- )
- from phoenix.server.api.types.Project import Project
- from phoenix.server.api.types.Span import to_gql_span
- from phoenix.server.api.types.Trace import Trace
-
-
- @strawberry.type
- class Query:
-     @strawberry.field
-     async def projects(
-         self,
-         info: Info[Context, None],
-         first: Optional[int] = 50,
-         last: Optional[int] = UNSET,
-         after: Optional[CursorString] = UNSET,
-         before: Optional[CursorString] = UNSET,
-     ) -> Connection[Project]:
-         args = ConnectionArgs(
-             first=first,
-             after=after if isinstance(after, CursorString) else None,
-             last=last,
-             before=before if isinstance(before, CursorString) else None,
-         )
-         async with info.context.db() as session:
-             projects = await session.scalars(select(models.Project))
-         data = [
-             Project(
-                 id_attr=project.id,
-                 name=project.name,
-                 gradient_start_color=project.gradient_start_color,
-                 gradient_end_color=project.gradient_end_color,
-             )
-             for project in projects
-         ]
-         return connection_from_list(data=data, args=args)
-
-     @strawberry.field
-     async def functionality(self, info: Info[Context, None]) -> "Functionality":
-         has_model_inferences = not info.context.model.is_empty
-         async with info.context.db() as session:
-             has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
-         return Functionality(
-             model_inferences=has_model_inferences,
-             tracing=has_traces,
-         )
-
-     @strawberry.field
-     def model(self) -> Model:
-         return Model()
-
-     @strawberry.field
-     async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
-         type_name, node_id = from_global_id(str(id))
-         if type_name == "Dimension":
-             dimension = info.context.model.scalar_dimensions[node_id]
-             return to_gql_dimension(node_id, dimension)
-         elif type_name == "EmbeddingDimension":
-             embedding_dimension = info.context.model.embedding_dimensions[node_id]
-             return to_gql_embedding_dimension(node_id, embedding_dimension)
-         elif type_name == "Project":
-             project_stmt = select(
-                 models.Project.id,
-                 models.Project.name,
-                 models.Project.gradient_start_color,
-                 models.Project.gradient_end_color,
-             ).where(models.Project.id == node_id)
-             async with info.context.db() as session:
-                 project = (await session.execute(project_stmt)).first()
-             if project is None:
-                 raise ValueError(f"Unknown project: {id}")
-             return Project(
-                 id_attr=project.id,
-                 name=project.name,
-                 gradient_start_color=project.gradient_start_color,
-                 gradient_end_color=project.gradient_end_color,
-             )
-         elif type_name == "Trace":
-             trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
-             async with info.context.db() as session:
-                 id_attr = await session.scalar(trace_stmt)
-             if id_attr is None:
-                 raise ValueError(f"Unknown trace: {id}")
-             return Trace(id_attr=id_attr)
-         elif type_name == "Span":
-             span_stmt = (
-                 select(models.Span)
-                 .join(models.Trace)
-                 .options(contains_eager(models.Span.trace))
-                 .where(models.Span.id == node_id)
-             )
-             async with info.context.db() as session:
-                 span = await session.scalar(span_stmt)
-             if span is None:
-                 raise ValueError(f"Unknown span: {id}")
-             return to_gql_span(span)
-         raise Exception(f"Unknown node type: {type_name}")
-
-     @strawberry.field
-     def clusters(
-         self,
-         clusters: List[ClusterInput],
-     ) -> List[Cluster]:
-         clustered_events: Dict[str, Set[ID]] = defaultdict(set)
-         for i, cluster in enumerate(clusters):
-             clustered_events[cluster.id or str(i)].update(cluster.event_ids)
-         return to_gql_clusters(
-             clustered_events=clustered_events,
-         )
-
-     @strawberry.field
-     def hdbscan_clustering(
-         self,
-         info: Info[Context, None],
-         event_ids: Annotated[
-             List[ID],
-             strawberry.argument(
-                 description="Event ID of the coordinates",
-             ),
-         ],
-         coordinates_2d: Annotated[
-             Optional[List[InputCoordinate2D]],
-             strawberry.argument(
-                 description="Point coordinates. Must be either 2D or 3D.",
-             ),
-         ] = UNSET,
-         coordinates_3d: Annotated[
-             Optional[List[InputCoordinate3D]],
-             strawberry.argument(
-                 description="Point coordinates. Must be either 2D or 3D.",
-             ),
-         ] = UNSET,
-         min_cluster_size: Annotated[
-             int,
-             strawberry.argument(
-                 description="HDBSCAN minimum cluster size",
-             ),
-         ] = DEFAULT_MIN_CLUSTER_SIZE,
-         cluster_min_samples: Annotated[
-             int,
-             strawberry.argument(
-                 description="HDBSCAN minimum samples",
-             ),
-         ] = DEFAULT_MIN_SAMPLES,
-         cluster_selection_epsilon: Annotated[
-             float,
-             strawberry.argument(
-                 description="HDBSCAN cluster selection epsilon",
-             ),
-         ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
-     ) -> List[Cluster]:
-         coordinates_3d = ensure_list(coordinates_3d)
-         coordinates_2d = ensure_list(coordinates_2d)
-
-         if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
-             raise ValueError("must specify only one of 2D or 3D coordinates")
-
-         if len(coordinates_3d) > 0:
-             coordinates = list(
-                 map(
-                     lambda coord: np.array(
-                         [coord.x, coord.y, coord.z],
-                     ),
-                     coordinates_3d,
-                 )
-             )
-         else:
-             coordinates = list(
-                 map(
-                     lambda coord: np.array(
-                         [coord.x, coord.y],
-                     ),
-                     coordinates_2d,
-                 )
-             )
-
-         if len(event_ids) != len(coordinates):
-             raise ValueError(
-                 f"length mismatch between "
-                 f"event_ids ({len(event_ids)}) "
-                 f"and coordinates ({len(coordinates)})"
-             )
-
-         if len(event_ids) == 0:
-             return []
-
-         grouped_event_ids: Dict[
-             Union[DatasetRole, AncillaryDatasetRole],
-             List[ID],
-         ] = defaultdict(list)
-         grouped_coordinates: Dict[
-             Union[DatasetRole, AncillaryDatasetRole],
-             List[npt.NDArray[np.float64]],
-         ] = defaultdict(list)
-
-         for event_id, coordinate in zip(event_ids, coordinates):
-             row_id, dataset_role = unpack_event_id(event_id)
-             grouped_coordinates[dataset_role].append(coordinate)
-             grouped_event_ids[dataset_role].append(create_event_id(row_id, dataset_role))
-
-         stacked_event_ids = (
-             grouped_event_ids[DatasetRole.primary]
-             + grouped_event_ids[DatasetRole.reference]
-             + grouped_event_ids[AncillaryDatasetRole.corpus]
-         )
-         stacked_coordinates = np.stack(
-             grouped_coordinates[DatasetRole.primary]
-             + grouped_coordinates[DatasetRole.reference]
-             + grouped_coordinates[AncillaryDatasetRole.corpus]
-         )
-
-         clusters = Hdbscan(
-             min_cluster_size=min_cluster_size,
-             min_samples=cluster_min_samples,
-             cluster_selection_epsilon=cluster_selection_epsilon,
-         ).find_clusters(stacked_coordinates)
-
-         clustered_events = {
-             str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
-             for i, cluster in enumerate(clusters)
-         }
-
-         return to_gql_clusters(
-             clustered_events=clustered_events,
-         )
-
-
- @strawberry.type
- class Mutation(ExportEventsMutation):
-     @strawberry.mutation
-     async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
-         if info.context.read_only:
-             return Query()
-         node_id = from_global_id_with_expected_type(str(id), "Project")
-         async with info.context.db() as session:
-             project = await session.scalar(
-                 select(models.Project)
-                 .where(models.Project.id == node_id)
-                 .options(load_only(models.Project.name))
-             )
-             if project is None:
-                 raise ValueError(f"Unknown project: {id}")
-             if project.name == DEFAULT_PROJECT_NAME:
-                 raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
-             await session.delete(project)
-         return Query()
-
-     @strawberry.mutation
-     async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
-         if info.context.read_only:
-             return Query()
-         project_id = from_global_id_with_expected_type(str(id), "Project")
-         delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
-         async with info.context.db() as session:
-             await session.execute(delete_statement)
-         if cache := info.context.cache_for_dataloaders:
-             cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
-         return Query()
 
+ from phoenix.server.api.mutations import Mutation
+ from phoenix.server.api.queries import Query
 
  # This is the schema for generating `schema.graphql`.
  # See https://strawberry.rocks/docs/guides/schema-export
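After this change, schema.py only imports the GraphQL root types that were split out into phoenix/server/api/queries.py (file 53) and the phoenix/server/api/mutations package (files 47-52). As a rough sketch, assuming the unchanged remainder of the file below this hunk still builds the schema object that the trailing comment refers to, the module now amounts to little more than:

import strawberry

from phoenix.server.api.mutations import Mutation
from phoenix.server.api.queries import Query

# Wiring only; resolver logic now lives in queries.py and mutations/.
# The schema assignment below is an assumption based on the comment about
# exporting `schema.graphql`; it is not shown in this hunk.
schema = strawberry.Schema(query=Query, mutation=Mutation)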
phoenix/server/api/types/AnnotatorKind.py
@@ -0,0 +1,10 @@
+ from enum import Enum
+
+ import strawberry
+
+
+ @strawberry.enum
+ class AnnotatorKind(Enum):
+     LLM = "LLM"
+     HUMAN = "HUMAN"
+     CODE = "CODE"
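AnnotatorKind is a plain strawberry enum, so its members surface in the GraphQL schema by name. For illustration only (the Annotation type below is hypothetical and not part of this diff), it would be consumed like any other field type:

from enum import Enum

import strawberry


@strawberry.enum
class AnnotatorKind(Enum):
    LLM = "LLM"
    HUMAN = "HUMAN"
    CODE = "CODE"


@strawberry.type
class Annotation:  # hypothetical consumer, for illustration only
    name: str
    annotator_kind: AnnotatorKind  # rendered in GraphQL as LLM | HUMAN | CODE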
phoenix/server/api/types/Cluster.py
@@ -9,9 +9,9 @@ from phoenix.core.model_schema import PRIMARY, REFERENCE
  from phoenix.server.api.context import Context
  from phoenix.server.api.input_types.DataQualityMetricInput import DataQualityMetricInput
  from phoenix.server.api.input_types.PerformanceMetricInput import PerformanceMetricInput
- from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
  from phoenix.server.api.types.DatasetValues import DatasetValues
  from phoenix.server.api.types.Event import unpack_event_id
+ from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
 
 
  @strawberry.type
@@ -36,8 +36,8 @@ class Cluster:
          """
          Calculates the drift score of the cluster. The score will be a value
          representing the balance of points between the primary and the reference
-         datasets, and will be on a scale between 1 (all primary) and -1 (all
-         reference), with 0 being an even balance between the two datasets.
+         inferences, and will be on a scale between 1 (all primary) and -1 (all
+         reference), with 0 being an even balance between the two inference sets.
 
          Returns
          -------
@@ -47,8 +47,8 @@ class Cluster:
          if model[REFERENCE].empty:
              return None
          count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-         primary_count = count_by_role[DatasetRole.primary]
-         reference_count = count_by_role[DatasetRole.reference]
+         primary_count = count_by_role[InferencesRole.primary]
+         reference_count = count_by_role[InferencesRole.reference]
          return (
              None
              if not (denominator := (primary_count + reference_count))
@@ -76,8 +76,8 @@ class Cluster:
          if corpus is None or corpus[PRIMARY].empty:
              return None
          count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-         primary_count = count_by_role[DatasetRole.primary]
-         corpus_count = count_by_role[AncillaryDatasetRole.corpus]
+         primary_count = count_by_role[InferencesRole.primary]
+         corpus_count = count_by_role[AncillaryInferencesRole.corpus]
          return (
              None
              if not (denominator := (primary_count + corpus_count))
@@ -94,19 +94,19 @@ class Cluster:
          metric: DataQualityMetricInput,
      ) -> DatasetValues:
          model = info.context.model
-         row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
-         for row_id, dataset_role in map(unpack_event_id, self.event_ids):
-             if not isinstance(dataset_role, DatasetRole):
+         row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
+         for row_id, inferences_role in map(unpack_event_id, self.event_ids):
+             if not isinstance(inferences_role, InferencesRole):
                  continue
-             row_ids[dataset_role].append(row_id)
+             row_ids[inferences_role].append(row_id)
          return DatasetValues(
              primary_value=metric.metric_instance(
                  model[PRIMARY],
-                 subset_rows=row_ids[DatasetRole.primary],
+                 subset_rows=row_ids[InferencesRole.primary],
              ),
              reference_value=metric.metric_instance(
                  model[REFERENCE],
-                 subset_rows=row_ids[DatasetRole.reference],
+                 subset_rows=row_ids[InferencesRole.reference],
              ),
          )
 
@@ -120,20 +120,20 @@ class Cluster:
          metric: PerformanceMetricInput,
      ) -> DatasetValues:
          model = info.context.model
-         row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
-         for row_id, dataset_role in map(unpack_event_id, self.event_ids):
-             if not isinstance(dataset_role, DatasetRole):
+         row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
+         for row_id, inferences_role in map(unpack_event_id, self.event_ids):
+             if not isinstance(inferences_role, InferencesRole):
                  continue
-             row_ids[dataset_role].append(row_id)
+             row_ids[inferences_role].append(row_id)
          metric_instance = metric.metric_instance(model)
          return DatasetValues(
              primary_value=metric_instance(
                  model[PRIMARY],
-                 subset_rows=row_ids[DatasetRole.primary],
+                 subset_rows=row_ids[InferencesRole.primary],
              ),
              reference_value=metric_instance(
                  model[REFERENCE],
-                 subset_rows=row_ids[DatasetRole.reference],
+                 subset_rows=row_ids[InferencesRole.reference],
              ),
          )
 
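The Cluster.py changes are a mechanical rename: DatasetRole/AncillaryDatasetRole become InferencesRole/AncillaryInferencesRole (compare files 81 and 105 in the list above, which swap DatasetRole.py for an equally sized InferencesRole.py). A minimal illustrative sketch of the renamed enums, using only the member names that appear at the call sites in this diff (the actual module contents are not shown here):

from enum import Enum, auto

import strawberry


@strawberry.enum
class InferencesRole(Enum):  # illustrative members only
    primary = auto()
    reference = auto()


@strawberry.enum
class AncillaryInferencesRole(Enum):  # illustrative members only
    corpus = auto()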
phoenix/server/api/types/CreateDatasetPayload.py
@@ -0,0 +1,8 @@
+ import strawberry
+
+ from phoenix.server.api.types.Dataset import Dataset
+
+
+ @strawberry.type
+ class CreateDatasetPayload:
+     dataset: Dataset