arize-phoenix 4.4.4rc6__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.


Files changed (123)
  1. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +8 -14
  2. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +58 -122
  3. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -42
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/datetime_utils.py +0 -4
  10. phoenix/db/bulk_inserter.py +14 -54
  11. phoenix/db/insertion/evaluation.py +10 -10
  12. phoenix/db/insertion/helpers.py +14 -17
  13. phoenix/db/insertion/span.py +3 -3
  14. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  15. phoenix/db/models.py +4 -236
  16. phoenix/inferences/fixtures.py +23 -23
  17. phoenix/inferences/inferences.py +7 -7
  18. phoenix/inferences/validation.py +1 -1
  19. phoenix/server/api/context.py +0 -20
  20. phoenix/server/api/dataloaders/__init__.py +0 -20
  21. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  22. phoenix/server/api/routers/v1/__init__.py +2 -77
  23. phoenix/server/api/routers/v1/evaluations.py +13 -8
  24. phoenix/server/api/routers/v1/spans.py +5 -9
  25. phoenix/server/api/routers/v1/traces.py +4 -1
  26. phoenix/server/api/schema.py +303 -2
  27. phoenix/server/api/types/Cluster.py +19 -19
  28. phoenix/server/api/types/Dataset.py +63 -282
  29. phoenix/server/api/types/DatasetRole.py +23 -0
  30. phoenix/server/api/types/Dimension.py +29 -30
  31. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  32. phoenix/server/api/types/Event.py +16 -16
  33. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  34. phoenix/server/api/types/Model.py +42 -43
  35. phoenix/server/api/types/Project.py +12 -26
  36. phoenix/server/api/types/Span.py +2 -79
  37. phoenix/server/api/types/TimeSeries.py +6 -6
  38. phoenix/server/api/types/Trace.py +4 -15
  39. phoenix/server/api/types/UMAPPoints.py +1 -1
  40. phoenix/server/api/types/node.py +111 -5
  41. phoenix/server/api/types/pagination.py +52 -10
  42. phoenix/server/app.py +49 -103
  43. phoenix/server/main.py +27 -49
  44. phoenix/server/openapi/docs.py +0 -3
  45. phoenix/server/static/index.js +1384 -2390
  46. phoenix/server/templates/index.html +0 -1
  47. phoenix/services.py +15 -15
  48. phoenix/session/client.py +23 -611
  49. phoenix/session/session.py +37 -47
  50. phoenix/trace/exporter.py +9 -14
  51. phoenix/trace/fixtures.py +7 -133
  52. phoenix/trace/schemas.py +2 -1
  53. phoenix/trace/span_evaluations.py +3 -3
  54. phoenix/trace/trace_dataset.py +6 -6
  55. phoenix/version.py +1 -1
  56. phoenix/db/insertion/dataset.py +0 -237
  57. phoenix/db/migrations/types.py +0 -29
  58. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  59. phoenix/experiments/__init__.py +0 -6
  60. phoenix/experiments/evaluators/__init__.py +0 -29
  61. phoenix/experiments/evaluators/base.py +0 -153
  62. phoenix/experiments/evaluators/code_evaluators.py +0 -99
  63. phoenix/experiments/evaluators/llm_evaluators.py +0 -244
  64. phoenix/experiments/evaluators/utils.py +0 -189
  65. phoenix/experiments/functions.py +0 -616
  66. phoenix/experiments/tracing.py +0 -85
  67. phoenix/experiments/types.py +0 -722
  68. phoenix/experiments/utils.py +0 -9
  69. phoenix/server/api/dataloaders/average_experiment_run_latency.py +0 -54
  70. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  71. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  72. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  73. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  74. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  75. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  76. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  77. phoenix/server/api/dataloaders/span_projects.py +0 -33
  78. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  79. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  80. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  81. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  82. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  83. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  84. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  85. phoenix/server/api/input_types/DatasetSort.py +0 -17
  86. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  87. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  88. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  89. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  90. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  91. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  92. phoenix/server/api/mutations/__init__.py +0 -13
  93. phoenix/server/api/mutations/auth.py +0 -11
  94. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  95. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  96. phoenix/server/api/mutations/project_mutations.py +0 -47
  97. phoenix/server/api/openapi/__init__.py +0 -0
  98. phoenix/server/api/openapi/main.py +0 -6
  99. phoenix/server/api/openapi/schema.py +0 -16
  100. phoenix/server/api/queries.py +0 -503
  101. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  102. phoenix/server/api/routers/v1/datasets.py +0 -965
  103. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -65
  104. phoenix/server/api/routers/v1/experiment_runs.py +0 -96
  105. phoenix/server/api/routers/v1/experiments.py +0 -174
  106. phoenix/server/api/types/AnnotatorKind.py +0 -10
  107. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  108. phoenix/server/api/types/DatasetExample.py +0 -85
  109. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  110. phoenix/server/api/types/DatasetVersion.py +0 -14
  111. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  112. phoenix/server/api/types/Experiment.py +0 -147
  113. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  114. phoenix/server/api/types/ExperimentComparison.py +0 -19
  115. phoenix/server/api/types/ExperimentRun.py +0 -91
  116. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  117. phoenix/server/api/types/Inferences.py +0 -80
  118. phoenix/server/api/types/InferencesRole.py +0 -23
  119. phoenix/utilities/json.py +0 -61
  120. phoenix/utilities/re.py +0 -50
  121. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.4.4rc6.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
phoenix/server/api/types/Cluster.py

@@ -9,9 +9,9 @@ from phoenix.core.model_schema import PRIMARY, REFERENCE
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.DataQualityMetricInput import DataQualityMetricInput
 from phoenix.server.api.input_types.PerformanceMetricInput import PerformanceMetricInput
+from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
 from phoenix.server.api.types.DatasetValues import DatasetValues
 from phoenix.server.api.types.Event import unpack_event_id
-from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole


 @strawberry.type
@@ -36,8 +36,8 @@ class Cluster:
         """
         Calculates the drift score of the cluster. The score will be a value
         representing the balance of points between the primary and the reference
-        inferences, and will be on a scale between 1 (all primary) and -1 (all
-        reference), with 0 being an even balance between the two inference sets.
+        datasets, and will be on a scale between 1 (all primary) and -1 (all
+        reference), with 0 being an even balance between the two datasets.

         Returns
         -------
@@ -47,8 +47,8 @@ class Cluster:
         if model[REFERENCE].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[InferencesRole.primary]
-        reference_count = count_by_role[InferencesRole.reference]
+        primary_count = count_by_role[DatasetRole.primary]
+        reference_count = count_by_role[DatasetRole.reference]
         return (
             None
             if not (denominator := (primary_count + reference_count))
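The hunk above cuts off mid-expression, so the non-None branch of the score is not shown. As a hedged sketch, assuming the balance is computed as (primary - reference) / (primary + reference), the only form consistent with the documented scale of 1 (all primary) to -1 (all reference):

# Sketch only: the diff truncates before the `else` branch, so the exact
# expression is an assumption consistent with the documented 1/-1 scale.
from typing import Optional

def drift_score(primary_count: int, reference_count: int) -> Optional[float]:
    denominator = primary_count + reference_count
    if not denominator:
        return None  # no points from either dataset in this cluster
    return (primary_count - reference_count) / denominator

assert drift_score(10, 0) == 1.0   # all primary
assert drift_score(0, 10) == -1.0  # all reference
assert drift_score(5, 5) == 0.0    # even balance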
@@ -76,8 +76,8 @@ class Cluster:
         if corpus is None or corpus[PRIMARY].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[InferencesRole.primary]
-        corpus_count = count_by_role[AncillaryInferencesRole.corpus]
+        primary_count = count_by_role[DatasetRole.primary]
+        corpus_count = count_by_role[AncillaryDatasetRole.corpus]
         return (
             None
             if not (denominator := (primary_count + corpus_count))
@@ -94,19 +94,19 @@ class Cluster:
         metric: DataQualityMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
-        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(inferences_role, InferencesRole):
+        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
+        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(dataset_role, DatasetRole):
                 continue
-            row_ids[inferences_role].append(row_id)
+            row_ids[dataset_role].append(row_id)
         return DatasetValues(
             primary_value=metric.metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[InferencesRole.primary],
+                subset_rows=row_ids[DatasetRole.primary],
             ),
             reference_value=metric.metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[InferencesRole.reference],
+                subset_rows=row_ids[DatasetRole.reference],
             ),
         )

@@ -120,20 +120,20 @@ class Cluster:
         metric: PerformanceMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
-        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(inferences_role, InferencesRole):
+        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
+        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(dataset_role, DatasetRole):
                 continue
-            row_ids[inferences_role].append(row_id)
+            row_ids[dataset_role].append(row_id)
         metric_instance = metric.metric_instance(model)
         return DatasetValues(
             primary_value=metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[InferencesRole.primary],
+                subset_rows=row_ids[DatasetRole.primary],
             ),
             reference_value=metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[InferencesRole.reference],
+                subset_rows=row_ids[DatasetRole.reference],
             ),
         )
phoenix/server/api/types/Dataset.py

@@ -1,299 +1,80 @@
 from datetime import datetime
-from typing import AsyncIterable, List, Optional, Tuple, cast
+from typing import Iterable, List, Optional, Set, Union

 import strawberry
-from sqlalchemy import and_, func, select
-from sqlalchemy.sql.functions import count
-from strawberry import UNSET
-from strawberry.relay import Connection, GlobalID, Node, NodeID
-from strawberry.scalars import JSON
-from strawberry.types import Info
+from strawberry.scalars import ID
+from strawberry.unset import UNSET

-from phoenix.db import models
-from phoenix.server.api.context import Context
-from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
-from phoenix.server.api.types.DatasetExample import DatasetExample
-from phoenix.server.api.types.DatasetVersion import DatasetVersion
-from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
-from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
-from phoenix.server.api.types.node import from_global_id_with_expected_type
-from phoenix.server.api.types.pagination import (
-    ConnectionArgs,
-    CursorString,
-    connection_from_list,
-)
-from phoenix.server.api.types.SortDir import SortDir
+import phoenix.core.model_schema as ms
+from phoenix.core.model_schema import FEATURE, TAG, ScalarDimension

+from ..input_types.DimensionInput import DimensionInput
+from .DatasetRole import AncillaryDatasetRole, DatasetRole
+from .Dimension import Dimension, to_gql_dimension
+from .Event import Event, create_event, create_event_id, parse_event_ids_by_dataset_role

-@strawberry.type
-class Dataset(Node):
-    id_attr: NodeID[int]
-    name: str
-    description: Optional[str]
-    metadata: JSON
-    created_at: datetime
-    updated_at: datetime
-
-    @strawberry.field
-    async def versions(
-        self,
-        info: Info[Context, None],
-        first: Optional[int] = 50,
-        last: Optional[int] = UNSET,
-        after: Optional[CursorString] = UNSET,
-        before: Optional[CursorString] = UNSET,
-        sort: Optional[DatasetVersionSort] = UNSET,
-    ) -> Connection[DatasetVersion]:
-        args = ConnectionArgs(
-            first=first,
-            after=after if isinstance(after, CursorString) else None,
-            last=last,
-            before=before if isinstance(before, CursorString) else None,
-        )
-        async with info.context.db() as session:
-            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
-            if sort:
-                # For now assume the the column names match 1:1 with the enum values
-                sort_col = getattr(models.DatasetVersion, sort.col.value)
-                if sort.dir is SortDir.desc:
-                    stmt = stmt.order_by(sort_col.desc(), models.DatasetVersion.id.desc())
-                else:
-                    stmt = stmt.order_by(sort_col.asc(), models.DatasetVersion.id.asc())
-            else:
-                stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
-            versions = await session.scalars(stmt)
-            data = [
-                DatasetVersion(
-                    id_attr=version.id,
-                    description=version.description,
-                    metadata=version.metadata_,
-                    created_at=version.created_at,
-                )
-                for version in versions
-            ]
-            return connection_from_list(data=data, args=args)

-    @strawberry.field(
-        description="Number of examples in a specific version if version is specified, or in the "
-        "latest version if version is not specified."
-    )  # type: ignore
-    async def example_count(
-        self,
-        info: Info[Context, None],
-        dataset_version_id: Optional[GlobalID] = UNSET,
-    ) -> int:
-        dataset_id = self.id_attr
-        version_id = (
-            from_global_id_with_expected_type(
-                global_id=dataset_version_id,
-                expected_type_name=DatasetVersion.__name__,
-            )
-            if dataset_version_id
-            else None
-        )
-        revision_ids = (
-            select(func.max(models.DatasetExampleRevision.id))
-            .join(models.DatasetExample)
-            .where(models.DatasetExample.dataset_id == dataset_id)
-            .group_by(models.DatasetExampleRevision.dataset_example_id)
-        )
-        if version_id:
-            version_id_subquery = (
-                select(models.DatasetVersion.id)
-                .where(models.DatasetVersion.dataset_id == dataset_id)
-                .where(models.DatasetVersion.id == version_id)
-                .scalar_subquery()
-            )
-            revision_ids = revision_ids.where(
-                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
-            )
-        stmt = (
-            select(count(models.DatasetExampleRevision.id))
-            .where(models.DatasetExampleRevision.id.in_(revision_ids))
-            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
-        )
-        async with info.context.db() as session:
-            return (await session.scalar(stmt)) or 0
-
-    @strawberry.field
-    async def examples(
-        self,
-        info: Info[Context, None],
-        dataset_version_id: Optional[GlobalID] = UNSET,
-        first: Optional[int] = 50,
-        last: Optional[int] = UNSET,
-        after: Optional[CursorString] = UNSET,
-        before: Optional[CursorString] = UNSET,
-    ) -> Connection[DatasetExample]:
-        args = ConnectionArgs(
-            first=first,
-            after=after if isinstance(after, CursorString) else None,
-            last=last,
-            before=before if isinstance(before, CursorString) else None,
-        )
-        dataset_id = self.id_attr
-        version_id = (
-            from_global_id_with_expected_type(
-                global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
-            )
-            if dataset_version_id
-            else None
-        )
-        revision_ids = (
-            select(func.max(models.DatasetExampleRevision.id))
-            .join(models.DatasetExample)
-            .where(models.DatasetExample.dataset_id == dataset_id)
-            .group_by(models.DatasetExampleRevision.dataset_example_id)
-        )
-        if version_id:
-            version_id_subquery = (
-                select(models.DatasetVersion.id)
-                .where(models.DatasetVersion.dataset_id == dataset_id)
-                .where(models.DatasetVersion.id == version_id)
-                .scalar_subquery()
-            )
-            revision_ids = revision_ids.where(
-                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
-            )
-        query = (
-            select(models.DatasetExample)
-            .join(
-                models.DatasetExampleRevision,
-                onclause=models.DatasetExample.id
-                == models.DatasetExampleRevision.dataset_example_id,
-            )
-            .where(
-                and_(
-                    models.DatasetExampleRevision.id.in_(revision_ids),
-                    models.DatasetExampleRevision.revision_kind != "DELETE",
-                )
-            )
-            .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
-        )
-        async with info.context.db() as session:
-            dataset_examples = [
-                DatasetExample(
-                    id_attr=example.id,
-                    version_id=version_id,
-                    created_at=example.created_at,
-                )
-                async for example in await session.stream_scalars(query)
-            ]
-            return connection_from_list(data=dataset_examples, args=args)
+@strawberry.type
+class Dataset:
+    start_time: datetime = strawberry.field(description="The start bookend of the data")
+    end_time: datetime = strawberry.field(description="The end bookend of the data")
+    record_count: int = strawberry.field(description="The record count of the data")
+    dataset: strawberry.Private[ms.Dataset]
+    dataset_role: strawberry.Private[Union[DatasetRole, AncillaryDatasetRole]]
+    model: strawberry.Private[ms.Model]

-    @strawberry.field(
-        description="Number of experiments for a specific version if version is specified, "
-        "or for all versions if version is not specified."
-    )  # type: ignore
-    async def experiment_count(
-        self,
-        info: Info[Context, None],
-        dataset_version_id: Optional[GlobalID] = UNSET,
-    ) -> int:
-        stmt = select(count(models.Experiment.id)).where(
-            models.Experiment.dataset_id == self.id_attr
-        )
-        version_id = (
-            from_global_id_with_expected_type(
-                global_id=dataset_version_id,
-                expected_type_name=DatasetVersion.__name__,
-            )
-            if dataset_version_id
-            else None
-        )
-        if version_id is not None:
-            stmt = stmt.where(models.Experiment.dataset_version_id == version_id)
-        async with info.context.db() as session:
-            return (await session.scalar(stmt)) or 0
+    # type ignored here to get around the following: https://github.com/strawberry-graphql/strawberry/issues/1929
+    @strawberry.field(description="Returns a human friendly name for the dataset.")  # type: ignore
+    def name(self) -> str:
+        return self.dataset.display_name

     @strawberry.field
-    async def experiments(
+    def events(
         self,
-        info: Info[Context, None],
-        first: Optional[int] = 50,
-        last: Optional[int] = UNSET,
-        after: Optional[CursorString] = UNSET,
-        before: Optional[CursorString] = UNSET,
-    ) -> Connection[Experiment]:
-        args = ConnectionArgs(
-            first=first,
-            after=after if isinstance(after, CursorString) else None,
-            last=last,
-            before=before if isinstance(before, CursorString) else None,
-        )
-        dataset_id = self.id_attr
-        row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
-        query = (
-            select(models.Experiment, row_number)
-            .where(models.Experiment.dataset_id == dataset_id)
-            .order_by(models.Experiment.id.desc())
+        event_ids: List[ID],
+        dimensions: Optional[List[DimensionInput]] = UNSET,
+    ) -> List[Event]:
+        """
+        Returns events for specific event IDs and dimensions. If no input
+        dimensions are provided, returns all features and tags.
+        """
+        if not event_ids:
+            return []
+        row_ids = parse_event_ids_by_dataset_role(event_ids)
+        if len(row_ids) > 1 or self.dataset_role not in row_ids:
+            raise ValueError("eventIds contains IDs from incorrect dataset.")
+        events = self.dataset[row_ids[self.dataset_role]]
+        requested_gql_dimensions = _get_requested_features_and_tags(
+            core_dimensions=self.model.scalar_dimensions,
+            requested_dimension_names=set(dim.name for dim in dimensions)
+            if isinstance(dimensions, list)
+            else None,
         )
-        async with info.context.db() as session:
-            experiments = [
-                to_gql_experiment(experiment, sequence_number)
-                async for experiment, sequence_number in cast(
-                    AsyncIterable[Tuple[models.Experiment, int]],
-                    await session.stream(query),
-                )
-            ]
-            return connection_from_list(data=experiments, args=args)
-
-    @strawberry.field
-    async def experiment_annotation_summaries(
-        self, info: Info[Context, None]
-    ) -> List[ExperimentAnnotationSummary]:
-        dataset_id = self.id_attr
-        query = (
-            select(
-                models.ExperimentRunAnnotation.name,
-                func.min(models.ExperimentRunAnnotation.score),
-                func.max(models.ExperimentRunAnnotation.score),
-                func.avg(models.ExperimentRunAnnotation.score),
-                func.count(),
-                func.count(models.ExperimentRunAnnotation.error),
+        return [
+            create_event(
+                event_id=create_event_id(event.id.row_id, self.dataset_role),
+                event=event,
+                dimensions=requested_gql_dimensions,
+                is_document_record=self.dataset_role is AncillaryDatasetRole.corpus,
             )
-            .join(
-                models.ExperimentRun,
-                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
-            )
-            .join(
-                models.Experiment,
-                models.ExperimentRun.experiment_id == models.Experiment.id,
-            )
-            .where(models.Experiment.dataset_id == dataset_id)
-            .group_by(models.ExperimentRunAnnotation.name)
-            .order_by(models.ExperimentRunAnnotation.name)
-        )
-        async with info.context.db() as session:
-            return [
-                ExperimentAnnotationSummary(
-                    annotation_name=annotation_name,
-                    min_score=min_score,
-                    max_score=max_score,
-                    mean_score=mean_score,
-                    count=count_,
-                    error_count=error_count,
-                )
-                async for (
-                    annotation_name,
-                    min_score,
-                    max_score,
-                    mean_score,
-                    count_,
-                    error_count,
-                ) in await session.stream(query)
-            ]
+            for event in events
+        ]


-def to_gql_dataset(dataset: models.Dataset) -> Dataset:
+def _get_requested_features_and_tags(
+    core_dimensions: Iterable[ScalarDimension],
+    requested_dimension_names: Optional[Set[str]] = UNSET,
+) -> List[Dimension]:
     """
-    Converts an ORM dataset to a GraphQL dataset.
+    Returns requested features and tags as a list of strawberry Datasets. If no
+    dimensions are explicitly requested, returns all features and tags.
     """
-    return Dataset(
-        id_attr=dataset.id,
-        name=dataset.name,
-        description=dataset.description,
-        metadata=dataset.metadata_,
-        created_at=dataset.created_at,
-        updated_at=dataset.updated_at,
-    )
+    requested_features_and_tags: List[Dimension] = []
+    for id, dim in enumerate(core_dimensions):
+        is_requested = (
+            not isinstance(requested_dimension_names, Set)
+        ) or dim.name in requested_dimension_names
+        is_feature_or_tag = dim.role in (FEATURE, TAG)
+        if is_requested and is_feature_or_tag:
+            requested_features_and_tags.append(to_gql_dimension(id_attr=id, dimension=dim))
+    return requested_features_and_tags
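The _get_requested_features_and_tags helper above treats any non-set value of requested_dimension_names (UNSET or None) as "no filter" and otherwise restricts by dimension name. A self-contained sketch of that selection logic, where Dim and the plain role strings are stand-ins for phoenix.core.model_schema.ScalarDimension and the FEATURE/TAG sentinels:

# Stand-alone illustration; Dim and the role strings are stand-ins, not phoenix types.
from dataclasses import dataclass
from typing import List, Optional, Set

FEATURE, TAG, PREDICTION = "feature", "tag", "prediction"

@dataclass
class Dim:
    name: str
    role: str

def features_and_tags(dims: List[Dim], requested: Optional[Set[str]] = None) -> List[Dim]:
    # None plays the role of UNSET here: anything that is not a set
    # means "return every feature and tag".
    return [
        dim
        for dim in dims
        if dim.role in (FEATURE, TAG)
        and (not isinstance(requested, set) or dim.name in requested)
    ]

dims = [Dim("age", FEATURE), Dim("cohort", TAG), Dim("score", PREDICTION)]
assert [d.name for d in features_and_tags(dims)] == ["age", "cohort"]
assert [d.name for d in features_and_tags(dims, {"age"})] == ["age"]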
phoenix/server/api/types/DatasetRole.py

@@ -0,0 +1,23 @@
+from enum import Enum
+from typing import Dict, Union
+
+import strawberry
+
+from phoenix.core.model_schema import PRIMARY, REFERENCE
+
+
+@strawberry.enum
+class DatasetRole(Enum):
+    primary = PRIMARY
+    reference = REFERENCE
+
+
+class AncillaryDatasetRole(Enum):
+    corpus = "DatasetRole.CORPUS"
+
+
+STR_TO_DATASET_ROLE: Dict[str, Union[DatasetRole, AncillaryDatasetRole]] = {
+    str(DatasetRole.primary.value): DatasetRole.primary,
+    str(DatasetRole.reference.value): DatasetRole.reference,
+    str(AncillaryDatasetRole.corpus.value): AncillaryDatasetRole.corpus,
+}
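The STR_TO_DATASET_ROLE mapping added above inverts str(role.value) so that a role serialized into an event ID string can be recovered when the ID comes back from a client. A minimal round-trip sketch, with enum values mirroring the str() forms implied by the "DatasetRole.CORPUS" literal and a hypothetical "row_id:role" ID layout (the real format lives in Event.py, which this diff does not show):

# Self-contained sketch; the enum values and the ID format are assumptions.
from enum import Enum
from typing import Dict, Tuple, Union

class DatasetRole(Enum):
    primary = "DatasetRole.PRIMARY"
    reference = "DatasetRole.REFERENCE"

class AncillaryDatasetRole(Enum):
    corpus = "DatasetRole.CORPUS"

STR_TO_DATASET_ROLE: Dict[str, Union[DatasetRole, AncillaryDatasetRole]] = {
    str(role.value): role
    for enum_cls in (DatasetRole, AncillaryDatasetRole)
    for role in enum_cls
}

def unpack_event_id(event_id: str) -> Tuple[int, Union[DatasetRole, AncillaryDatasetRole]]:
    # Split a serialized event ID back into (row_id, role).
    row_id, role_str = event_id.split(":", 1)
    return int(row_id), STR_TO_DATASET_ROLE[role_str]

assert unpack_event_id("42:DatasetRole.CORPUS") == (42, AncillaryDatasetRole.corpus)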
phoenix/server/api/types/Dimension.py

@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional
 import pandas as pd
 import strawberry
 from strawberry import UNSET
-from strawberry.relay import Node, NodeID
 from strawberry.types import Info
 from typing_extensions import Annotated

@@ -18,11 +17,12 @@ from ..context import Context
 from ..input_types.Granularity import Granularity
 from ..input_types.TimeRange import TimeRange
 from .DataQualityMetric import DataQualityMetric
+from .DatasetRole import DatasetRole
 from .DatasetValues import DatasetValues
 from .DimensionDataType import DimensionDataType
 from .DimensionShape import DimensionShape
 from .DimensionType import DimensionType
-from .InferencesRole import InferencesRole
+from .node import Node
 from .ScalarDriftMetricEnum import ScalarDriftMetric
 from .Segments import (
     GqlBinFactory,
@@ -40,7 +40,6 @@ from .TimeSeries import (

 @strawberry.type
 class Dimension(Node):
-    id_attr: NodeID[int]
     name: str = strawberry.field(description="The name of the dimension (a.k.a. the column name)")
     type: DimensionType = strawberry.field(
         description="Whether the dimension represents a feature, tag, prediction, or actual."
@@ -63,16 +62,16 @@ class Dimension(Node):
         """
         Computes a drift metric between all reference data and the primary data
         belonging to the input time range (inclusive of the time range start and
-        exclusive of the time range end). Returns None if no reference inferences
-        exist, if no primary data exists in the input time range, or if the
+        exclusive of the time range end). Returns None if no reference dataset
+        exists, if no primary data exists in the input time range, or if the
         input time range is invalid.
         """
         model = info.context.model
         if model[REFERENCE].empty:
             return None
-        inferences = model[PRIMARY]
+        dataset = model[PRIMARY]
         time_range, granularity = ensure_timeseries_parameters(
-            inferences,
+            dataset,
             time_range,
         )
         data = get_drift_timeseries_data(
@@ -93,18 +92,18 @@ class Dimension(Node):
         info: Info[Context, None],
         metric: DataQualityMetric,
         time_range: Optional[TimeRange] = UNSET,
-        inferences_role: Annotated[
-            Optional[InferencesRole],
+        dataset_role: Annotated[
+            Optional[DatasetRole],
             strawberry.argument(
-                description="The inferences (primary or reference) to query",
+                description="The dataset (primary or reference) to query",
             ),
-        ] = InferencesRole.primary,
+        ] = DatasetRole.primary,
     ) -> Optional[float]:
-        if not isinstance(inferences_role, InferencesRole):
-            inferences_role = InferencesRole.primary
-        inferences = info.context.model[inferences_role.value]
+        if not isinstance(dataset_role, DatasetRole):
+            dataset_role = DatasetRole.primary
+        dataset = info.context.model[dataset_role.value]
         time_range, granularity = ensure_timeseries_parameters(
-            inferences,
+            dataset,
             time_range,
         )
         data = get_data_quality_timeseries_data(
@@ -112,7 +111,7 @@ class Dimension(Node):
             metric,
             time_range,
             granularity,
-            inferences_role,
+            dataset_role,
         )
         return data[0].value if len(data) else None

@@ -140,18 +139,18 @@ class Dimension(Node):
         metric: DataQualityMetric,
         time_range: TimeRange,
         granularity: Granularity,
-        inferences_role: Annotated[
-            Optional[InferencesRole],
+        dataset_role: Annotated[
+            Optional[DatasetRole],
             strawberry.argument(
-                description="The inferences (primary or reference) to query",
+                description="The dataset (primary or reference) to query",
             ),
-        ] = InferencesRole.primary,
+        ] = DatasetRole.primary,
     ) -> DataQualityTimeSeries:
-        if not isinstance(inferences_role, InferencesRole):
-            inferences_role = InferencesRole.primary
-        inferences = info.context.model[inferences_role.value]
+        if not isinstance(dataset_role, DatasetRole):
+            dataset_role = DatasetRole.primary
+        dataset = info.context.model[dataset_role.value]
         time_range, granularity = ensure_timeseries_parameters(
-            inferences,
+            dataset,
             time_range,
             granularity,
         )
@@ -161,7 +160,7 @@ class Dimension(Node):
                 metric,
                 time_range,
                 granularity,
-                inferences_role,
+                dataset_role,
             )
         )

@@ -183,9 +182,9 @@ class Dimension(Node):
         model = info.context.model
         if model[REFERENCE].empty:
             return DriftTimeSeries(data=[])
-        inferences = model[PRIMARY]
+        dataset = model[PRIMARY]
         time_range, granularity = ensure_timeseries_parameters(
-            inferences,
+            dataset,
             time_range,
             granularity,
         )
@@ -203,7 +202,7 @@ class Dimension(Node):
         )

     @strawberry.field(
-        description="The segments across both inference sets and returns the counts per segment",
+        description="Returns the segments across both datasets and returns the counts per segment",
     )  # type: ignore
     def segments_comparison(
         self,
@@ -250,8 +249,8 @@ class Dimension(Node):
         if isinstance(binning_method, binning.IntervalBinning) and binning_method.bins is not None:
             all_bins = all_bins.union(binning_method.bins)
         for bin in all_bins:
-            values: Dict[ms.InferencesRole, Any] = defaultdict(lambda: None)
-            for role in ms.InferencesRole:
+            values: Dict[ms.DatasetRole, Any] = defaultdict(lambda: None)
+            for role in ms.DatasetRole:
                 if model[role].empty:
                     continue
                 try:
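Several of the Dimension.py docstrings above pin drift and time-series queries to a half-open time range, inclusive of the start and exclusive of the end. A one-function sketch of that bound (the helper name is invented for illustration):

# Hypothetical helper; only the start <= ts < end convention comes from the docstrings.
from datetime import datetime

def in_time_range(ts: datetime, start: datetime, end: datetime) -> bool:
    # inclusive of the time range start, exclusive of the time range end
    return start <= ts < end

start, end = datetime(2024, 1, 1), datetime(2024, 2, 1)
assert in_time_range(start, start, end)    # the start bookend is included
assert not in_time_range(end, start, end)  # the end bookend is excluded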