arize-phoenix 4.5.0__py3-none-any.whl → 4.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of arize-phoenix has been flagged as a potentially problematic release.
Files changed (123)
  1. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/METADATA +16 -8
  2. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/RECORD +122 -58
  3. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +42 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datetime_utils.py +4 -0
  10. phoenix/db/bulk_inserter.py +54 -14
  11. phoenix/db/insertion/dataset.py +237 -0
  12. phoenix/db/insertion/evaluation.py +10 -10
  13. phoenix/db/insertion/helpers.py +17 -14
  14. phoenix/db/insertion/span.py +3 -3
  15. phoenix/db/migrations/types.py +29 -0
  16. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  17. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  18. phoenix/db/models.py +236 -4
  19. phoenix/experiments/__init__.py +6 -0
  20. phoenix/experiments/evaluators/__init__.py +29 -0
  21. phoenix/experiments/evaluators/base.py +153 -0
  22. phoenix/experiments/evaluators/code_evaluators.py +99 -0
  23. phoenix/experiments/evaluators/llm_evaluators.py +244 -0
  24. phoenix/experiments/evaluators/utils.py +186 -0
  25. phoenix/experiments/functions.py +757 -0
  26. phoenix/experiments/tracing.py +85 -0
  27. phoenix/experiments/types.py +753 -0
  28. phoenix/experiments/utils.py +24 -0
  29. phoenix/inferences/fixtures.py +23 -23
  30. phoenix/inferences/inferences.py +7 -7
  31. phoenix/inferences/validation.py +1 -1
  32. phoenix/server/api/context.py +20 -0
  33. phoenix/server/api/dataloaders/__init__.py +20 -0
  34. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  35. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  36. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  37. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  38. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  39. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  40. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  41. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  42. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  43. phoenix/server/api/dataloaders/span_projects.py +33 -0
  44. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  45. phoenix/server/api/helpers/dataset_helpers.py +179 -0
  46. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  47. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  48. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  49. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  50. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  51. phoenix/server/api/input_types/DatasetSort.py +17 -0
  52. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  53. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  54. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  55. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  56. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  57. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  58. phoenix/server/api/mutations/__init__.py +13 -0
  59. phoenix/server/api/mutations/auth.py +11 -0
  60. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  61. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  62. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  63. phoenix/server/api/mutations/project_mutations.py +47 -0
  64. phoenix/server/api/openapi/__init__.py +0 -0
  65. phoenix/server/api/openapi/main.py +6 -0
  66. phoenix/server/api/openapi/schema.py +16 -0
  67. phoenix/server/api/queries.py +503 -0
  68. phoenix/server/api/routers/v1/__init__.py +77 -2
  69. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  70. phoenix/server/api/routers/v1/datasets.py +965 -0
  71. phoenix/server/api/routers/v1/evaluations.py +8 -13
  72. phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
  73. phoenix/server/api/routers/v1/experiment_runs.py +220 -0
  74. phoenix/server/api/routers/v1/experiments.py +302 -0
  75. phoenix/server/api/routers/v1/spans.py +9 -5
  76. phoenix/server/api/routers/v1/traces.py +1 -4
  77. phoenix/server/api/schema.py +2 -303
  78. phoenix/server/api/types/AnnotatorKind.py +10 -0
  79. phoenix/server/api/types/Cluster.py +19 -19
  80. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  81. phoenix/server/api/types/Dataset.py +282 -63
  82. phoenix/server/api/types/DatasetExample.py +85 -0
  83. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  84. phoenix/server/api/types/DatasetVersion.py +14 -0
  85. phoenix/server/api/types/Dimension.py +30 -29
  86. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  87. phoenix/server/api/types/Event.py +16 -16
  88. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  89. phoenix/server/api/types/Experiment.py +147 -0
  90. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  91. phoenix/server/api/types/ExperimentComparison.py +19 -0
  92. phoenix/server/api/types/ExperimentRun.py +91 -0
  93. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  94. phoenix/server/api/types/Inferences.py +80 -0
  95. phoenix/server/api/types/InferencesRole.py +23 -0
  96. phoenix/server/api/types/Model.py +43 -42
  97. phoenix/server/api/types/Project.py +26 -12
  98. phoenix/server/api/types/Span.py +79 -2
  99. phoenix/server/api/types/TimeSeries.py +6 -6
  100. phoenix/server/api/types/Trace.py +15 -4
  101. phoenix/server/api/types/UMAPPoints.py +1 -1
  102. phoenix/server/api/types/node.py +5 -111
  103. phoenix/server/api/types/pagination.py +10 -52
  104. phoenix/server/app.py +103 -49
  105. phoenix/server/main.py +49 -27
  106. phoenix/server/openapi/docs.py +3 -0
  107. phoenix/server/static/index.js +2300 -1294
  108. phoenix/server/templates/index.html +1 -0
  109. phoenix/services.py +15 -15
  110. phoenix/session/client.py +581 -22
  111. phoenix/session/session.py +47 -37
  112. phoenix/trace/exporter.py +14 -9
  113. phoenix/trace/fixtures.py +133 -7
  114. phoenix/trace/schemas.py +1 -2
  115. phoenix/trace/span_evaluations.py +3 -3
  116. phoenix/trace/trace_dataset.py +6 -6
  117. phoenix/utilities/json.py +61 -0
  118. phoenix/utilities/re.py +50 -0
  119. phoenix/version.py +1 -1
  120. phoenix/server/api/types/DatasetRole.py +0 -23
  121. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
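
The headline change is the new datasets-and-experiments surface: a phoenix/experiments package, dataset and experiment tables with an Alembic migration under phoenix/db, new GraphQL types and dataloaders, and REST routes under phoenix/server/api/routers/v1. A rough end-to-end sketch of how those pieces fit together follows (hypothetical usage: the module paths come from the file list above, but the exact signatures are assumptions this diff does not confirm):

# Hypothetical sketch: upload a dataset, then run an experiment against it.
# Module paths match the new files listed above; keyword arguments are
# assumptions and may differ in 4.6.1.
import pandas as pd

import phoenix as px
from phoenix.experiments import run_experiment  # new in phoenix/experiments/

client = px.Client()  # assumes a Phoenix server is already running locally
dataset = client.upload_dataset(  # upload support added in phoenix/session/client.py
    dataset_name="greetings",
    dataframe=pd.DataFrame({"question": ["Hello?"], "answer": ["Hi!"]}),
    input_keys=["question"],
    output_keys=["answer"],
)


def task(example):
    # Receives one dataset example; returns the output to be evaluated.
    return f"echo: {example.input['question']}"


run_experiment(dataset, task, experiment_name="echo-baseline")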
phoenix/server/api/types/Cluster.py
@@ -9,9 +9,9 @@ from phoenix.core.model_schema import PRIMARY, REFERENCE
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.DataQualityMetricInput import DataQualityMetricInput
 from phoenix.server.api.input_types.PerformanceMetricInput import PerformanceMetricInput
-from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
 from phoenix.server.api.types.DatasetValues import DatasetValues
 from phoenix.server.api.types.Event import unpack_event_id
+from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole


 @strawberry.type
@@ -36,8 +36,8 @@ class Cluster:
         """
         Calculates the drift score of the cluster. The score will be a value
         representing the balance of points between the primary and the reference
-        datasets, and will be on a scale between 1 (all primary) and -1 (all
-        reference), with 0 being an even balance between the two datasets.
+        inferences, and will be on a scale between 1 (all primary) and -1 (all
+        reference), with 0 being an even balance between the two inference sets.

         Returns
         -------
@@ -47,8 +47,8 @@ class Cluster:
         if model[REFERENCE].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[DatasetRole.primary]
-        reference_count = count_by_role[DatasetRole.reference]
+        primary_count = count_by_role[InferencesRole.primary]
+        reference_count = count_by_role[InferencesRole.reference]
         return (
             None
             if not (denominator := (primary_count + reference_count))
@@ -76,8 +76,8 @@ class Cluster:
         if corpus is None or corpus[PRIMARY].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[DatasetRole.primary]
-        corpus_count = count_by_role[AncillaryDatasetRole.corpus]
+        primary_count = count_by_role[InferencesRole.primary]
+        corpus_count = count_by_role[AncillaryInferencesRole.corpus]
         return (
             None
             if not (denominator := (primary_count + corpus_count))
@@ -94,19 +94,19 @@ class Cluster:
         metric: DataQualityMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
-        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(dataset_role, DatasetRole):
+        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
+        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(inferences_role, InferencesRole):
                 continue
-            row_ids[dataset_role].append(row_id)
+            row_ids[inferences_role].append(row_id)
         return DatasetValues(
             primary_value=metric.metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[DatasetRole.primary],
+                subset_rows=row_ids[InferencesRole.primary],
             ),
             reference_value=metric.metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[DatasetRole.reference],
+                subset_rows=row_ids[InferencesRole.reference],
             ),
         )

@@ -120,20 +120,20 @@ class Cluster:
         metric: PerformanceMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
-        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(dataset_role, DatasetRole):
+        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
+        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(inferences_role, InferencesRole):
                 continue
-            row_ids[dataset_role].append(row_id)
+            row_ids[inferences_role].append(row_id)
         metric_instance = metric.metric_instance(model)
         return DatasetValues(
             primary_value=metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[DatasetRole.primary],
+                subset_rows=row_ids[InferencesRole.primary],
             ),
             reference_value=metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[DatasetRole.reference],
+                subset_rows=row_ids[InferencesRole.reference],
             ),
         )

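The six hunks above are a mechanical rename: the GraphQL enums that label which inference set an event belongs to move from DatasetRole/AncillaryDatasetRole to InferencesRole/AncillaryInferencesRole (DatasetRole.py -23, InferencesRole.py +23 in the file list), freeing the Dataset name for the new dataset entities introduced below. A minimal sketch of the renamed enums, assuming the new module mirrors the removed one apart from the names:

# Hypothetical sketch of phoenix/server/api/types/InferencesRole.py, assuming a
# one-for-one rename of the removed DatasetRole module; member values are
# placeholders here.
from enum import Enum, auto

import strawberry


@strawberry.enum
class InferencesRole(Enum):
    primary = auto()  # events from the primary inference set
    reference = auto()  # events from the reference inference set


@strawberry.enum
class AncillaryInferencesRole(Enum):
    corpus = auto()  # events from the retrieval corpus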
phoenix/server/api/types/CreateDatasetPayload.py (new file)
@@ -0,0 +1,8 @@
+import strawberry
+
+from phoenix.server.api.types.Dataset import Dataset
+
+
+@strawberry.type
+class CreateDatasetPayload:
+    dataset: Dataset
phoenix/server/api/types/Dataset.py
@@ -1,80 +1,299 @@
 from datetime import datetime
-from typing import Iterable, List, Optional, Set, Union
+from typing import AsyncIterable, List, Optional, Tuple, cast

 import strawberry
-from strawberry.scalars import ID
-from strawberry.unset import UNSET
+from sqlalchemy import and_, func, select
+from sqlalchemy.sql.functions import count
+from strawberry import UNSET
+from strawberry.relay import Connection, GlobalID, Node, NodeID
+from strawberry.scalars import JSON
+from strawberry.types import Info

-import phoenix.core.model_schema as ms
-from phoenix.core.model_schema import FEATURE, TAG, ScalarDimension
-
-from ..input_types.DimensionInput import DimensionInput
-from .DatasetRole import AncillaryDatasetRole, DatasetRole
-from .Dimension import Dimension, to_gql_dimension
-from .Event import Event, create_event, create_event_id, parse_event_ids_by_dataset_role
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
+from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.pagination import (
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.SortDir import SortDir


 @strawberry.type
-class Dataset:
-    start_time: datetime = strawberry.field(description="The start bookend of the data")
-    end_time: datetime = strawberry.field(description="The end bookend of the data")
-    record_count: int = strawberry.field(description="The record count of the data")
-    dataset: strawberry.Private[ms.Dataset]
-    dataset_role: strawberry.Private[Union[DatasetRole, AncillaryDatasetRole]]
-    model: strawberry.Private[ms.Model]
-
-    # type ignored here to get around the following: https://github.com/strawberry-graphql/strawberry/issues/1929
-    @strawberry.field(description="Returns a human friendly name for the dataset.")  # type: ignore
-    def name(self) -> str:
-        return self.dataset.display_name
+class Dataset(Node):
+    id_attr: NodeID[int]
+    name: str
+    description: Optional[str]
+    metadata: JSON
+    created_at: datetime
+    updated_at: datetime

     @strawberry.field
-    def events(
+    async def versions(
         self,
-        event_ids: List[ID],
-        dimensions: Optional[List[DimensionInput]] = UNSET,
-    ) -> List[Event]:
-        """
-        Returns events for specific event IDs and dimensions. If no input
-        dimensions are provided, returns all features and tags.
-        """
-        if not event_ids:
-            return []
-        row_ids = parse_event_ids_by_dataset_role(event_ids)
-        if len(row_ids) > 1 or self.dataset_role not in row_ids:
-            raise ValueError("eventIds contains IDs from incorrect dataset.")
-        events = self.dataset[row_ids[self.dataset_role]]
-        requested_gql_dimensions = _get_requested_features_and_tags(
-            core_dimensions=self.model.scalar_dimensions,
-            requested_dimension_names=set(dim.name for dim in dimensions)
-            if isinstance(dimensions, list)
-            else None,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+        sort: Optional[DatasetVersionSort] = UNSET,
+    ) -> Connection[DatasetVersion]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
         )
-        return [
-            create_event(
-                event_id=create_event_id(event.id.row_id, self.dataset_role),
-                event=event,
-                dimensions=requested_gql_dimensions,
-                is_document_record=self.dataset_role is AncillaryDatasetRole.corpus,
+        async with info.context.db() as session:
+            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
+            if sort:
+                # For now, assume the column names match 1:1 with the enum values
+                sort_col = getattr(models.DatasetVersion, sort.col.value)
+                if sort.dir is SortDir.desc:
+                    stmt = stmt.order_by(sort_col.desc(), models.DatasetVersion.id.desc())
+                else:
+                    stmt = stmt.order_by(sort_col.asc(), models.DatasetVersion.id.asc())
+            else:
+                stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
+            versions = await session.scalars(stmt)
+            data = [
+                DatasetVersion(
+                    id_attr=version.id,
+                    description=version.description,
+                    metadata=version.metadata_,
+                    created_at=version.created_at,
                 )
-            for event in events
+                for version in versions
             ]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field(
+        description="Number of examples in a specific version if version is specified, or in the "
+        "latest version if version is not specified."
+    )  # type: ignore
+    async def example_count(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+    ) -> int:
+        dataset_id = self.id_attr
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=dataset_version_id,
+                expected_type_name=DatasetVersion.__name__,
+            )
+            if dataset_version_id
+            else None
+        )
+        revision_ids = (
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(models.DatasetExample)
+            .where(models.DatasetExample.dataset_id == dataset_id)
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
+        )
+        if version_id:
+            version_id_subquery = (
+                select(models.DatasetVersion.id)
+                .where(models.DatasetVersion.dataset_id == dataset_id)
+                .where(models.DatasetVersion.id == version_id)
+                .scalar_subquery()
+            )
+            revision_ids = revision_ids.where(
+                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+            )
+        stmt = (
+            select(count(models.DatasetExampleRevision.id))
+            .where(models.DatasetExampleRevision.id.in_(revision_ids))
+            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+        )
+        async with info.context.db() as session:
+            return (await session.scalar(stmt)) or 0
+
+    @strawberry.field
+    async def examples(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetExample]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        dataset_id = self.id_attr
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+            if dataset_version_id
+            else None
+        )
+        revision_ids = (
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(models.DatasetExample)
+            .where(models.DatasetExample.dataset_id == dataset_id)
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
+        )
+        if version_id:
+            version_id_subquery = (
+                select(models.DatasetVersion.id)
+                .where(models.DatasetVersion.dataset_id == dataset_id)
+                .where(models.DatasetVersion.id == version_id)
+                .scalar_subquery()
+            )
+            revision_ids = revision_ids.where(
+                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+            )
+        query = (
+            select(models.DatasetExample)
+            .join(
+                models.DatasetExampleRevision,
+                onclause=models.DatasetExample.id
+                == models.DatasetExampleRevision.dataset_example_id,
+            )
+            .where(
+                and_(
+                    models.DatasetExampleRevision.id.in_(revision_ids),
+                    models.DatasetExampleRevision.revision_kind != "DELETE",
+                )
+            )
+            .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
+        )
+        async with info.context.db() as session:
+            dataset_examples = [
+                DatasetExample(
+                    id_attr=example.id,
+                    version_id=version_id,
+                    created_at=example.created_at,
+                )
+                async for example in await session.stream_scalars(query)
+            ]
+        return connection_from_list(data=dataset_examples, args=args)
+
+    @strawberry.field(
+        description="Number of experiments for a specific version if version is specified, "
+        "or for all versions if version is not specified."
+    )  # type: ignore
+    async def experiment_count(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+    ) -> int:
+        stmt = select(count(models.Experiment.id)).where(
+            models.Experiment.dataset_id == self.id_attr
+        )
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=dataset_version_id,
+                expected_type_name=DatasetVersion.__name__,
+            )
+            if dataset_version_id
+            else None
+        )
+        if version_id is not None:
+            stmt = stmt.where(models.Experiment.dataset_version_id == version_id)
+        async with info.context.db() as session:
+            return (await session.scalar(stmt)) or 0
+
+    @strawberry.field
+    async def experiments(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[Experiment]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        dataset_id = self.id_attr
+        row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
+        query = (
+            select(models.Experiment, row_number)
+            .where(models.Experiment.dataset_id == dataset_id)
+            .order_by(models.Experiment.id.desc())
+        )
+        async with info.context.db() as session:
+            experiments = [
+                to_gql_experiment(experiment, sequence_number)
+                async for experiment, sequence_number in cast(
+                    AsyncIterable[Tuple[models.Experiment, int]],
+                    await session.stream(query),
+                )
+            ]
+        return connection_from_list(data=experiments, args=args)
+
+    @strawberry.field
+    async def experiment_annotation_summaries(
+        self, info: Info[Context, None]
+    ) -> List[ExperimentAnnotationSummary]:
+        dataset_id = self.id_attr
+        query = (
+            select(
+                models.ExperimentRunAnnotation.name,
+                func.min(models.ExperimentRunAnnotation.score),
+                func.max(models.ExperimentRunAnnotation.score),
+                func.avg(models.ExperimentRunAnnotation.score),
+                func.count(),
+                func.count(models.ExperimentRunAnnotation.error),
+            )
+            .join(
+                models.ExperimentRun,
+                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+            )
+            .join(
+                models.Experiment,
+                models.ExperimentRun.experiment_id == models.Experiment.id,
+            )
+            .where(models.Experiment.dataset_id == dataset_id)
+            .group_by(models.ExperimentRunAnnotation.name)
+            .order_by(models.ExperimentRunAnnotation.name)
+        )
+        async with info.context.db() as session:
+            return [
+                ExperimentAnnotationSummary(
+                    annotation_name=annotation_name,
+                    min_score=min_score,
+                    max_score=max_score,
+                    mean_score=mean_score,
+                    count=count_,
+                    error_count=error_count,
+                )
+                async for (
+                    annotation_name,
+                    min_score,
+                    max_score,
+                    mean_score,
+                    count_,
+                    error_count,
+                ) in await session.stream(query)
+            ]


-def _get_requested_features_and_tags(
-    core_dimensions: Iterable[ScalarDimension],
-    requested_dimension_names: Optional[Set[str]] = UNSET,
-) -> List[Dimension]:
+def to_gql_dataset(dataset: models.Dataset) -> Dataset:
     """
-    Returns requested features and tags as a list of strawberry Datasets. If no
-    dimensions are explicitly requested, returns all features and tags.
+    Converts an ORM dataset to a GraphQL dataset.
     """
-    requested_features_and_tags: List[Dimension] = []
-    for id, dim in enumerate(core_dimensions):
-        is_requested = (
-            not isinstance(requested_dimension_names, Set)
-        ) or dim.name in requested_dimension_names
-        is_feature_or_tag = dim.role in (FEATURE, TAG)
-        if is_requested and is_feature_or_tag:
-            requested_features_and_tags.append(to_gql_dimension(id_attr=id, dimension=dim))
-    return requested_features_and_tags
+    return Dataset(
+        id_attr=dataset.id,
+        name=dataset.name,
+        description=dataset.description,
+        metadata=dataset.metadata_,
+        created_at=dataset.created_at,
+        updated_at=dataset.updated_at,
+    )
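
Both example_count and examples above resolve a dataset's contents the same way: take each example's highest revision id at or before the requested version, then drop examples whose effective revision is a DELETE. A self-contained illustration of that pattern against in-memory SQLite (the table is a simplified stand-in for the real schema in phoenix/db/models.py):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE revisions (
        id INTEGER PRIMARY KEY,
        example_id INTEGER,
        version_id INTEGER,
        kind TEXT
    );
    -- example 1: created in version 1, patched in version 2
    INSERT INTO revisions VALUES (1, 1, 1, 'CREATE'), (3, 1, 2, 'PATCH');
    -- example 2: created in version 1, then deleted in version 2
    INSERT INTO revisions VALUES (2, 2, 1, 'CREATE'), (4, 2, 2, 'DELETE');
    """
)


def example_count(as_of_version: int) -> int:
    # Latest revision id per example at or before the version, then exclude
    # examples whose effective revision is a DELETE, mirroring the
    # revision_ids subquery and revision_kind filter above.
    (n,) = conn.execute(
        """
        SELECT COUNT(*) FROM revisions
        WHERE id IN (
            SELECT MAX(id) FROM revisions
            WHERE version_id <= ?
            GROUP BY example_id
        )
        AND kind != 'DELETE'
        """,
        (as_of_version,),
    ).fetchone()
    return n


assert example_count(1) == 2  # both examples are live at version 1
assert example_count(2) == 1  # example 2 is deleted as of version 2

The <= comparison on dataset_version_id is what makes versions cumulative: a version sees every revision created up to and including it.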
phoenix/server/api/types/DatasetExample.py (new file)
@@ -0,0 +1,85 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from sqlalchemy import select
+from sqlalchemy.orm import joinedload
+from strawberry import UNSET
+from strawberry.relay.types import Connection, GlobalID, Node, NodeID
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.pagination import (
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.Span import Span, to_gql_span
+
+
+@strawberry.type
+class DatasetExample(Node):
+    id_attr: NodeID[int]
+    created_at: datetime
+    version_id: strawberry.Private[Optional[int]] = None
+
+    @strawberry.field
+    async def revision(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+    ) -> DatasetExampleRevision:
+        example_id = self.id_attr
+        version_id: Optional[int] = None
+        if dataset_version_id:
+            version_id = from_global_id_with_expected_type(
+                global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+        elif self.version_id is not None:
+            version_id = self.version_id
+        return await info.context.data_loaders.dataset_example_revisions.load(
+            (example_id, version_id)
+        )
+
+    @strawberry.field
+    async def span(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[Span]:
+        return (
+            to_gql_span(span)
+            if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
+            else None
+        )
+
+    @strawberry.field
+    async def experiment_runs(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[ExperimentRun]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        example_id = self.id_attr
+        query = (
+            select(models.ExperimentRun)
+            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+            .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
+            .where(models.ExperimentRun.dataset_example_id == example_id)
+            .order_by(models.Experiment.id.desc())
+        )
+        async with info.context.db() as session:
+            runs = (await session.scalars(query)).all()
+        return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
phoenix/server/api/types/DatasetExampleRevision.py (new file)
@@ -0,0 +1,34 @@
+from datetime import datetime
+from enum import Enum
+
+import strawberry
+
+from phoenix.db import models
+from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
+
+
+@strawberry.enum
+class RevisionKind(Enum):
+    CREATE = "CREATE"
+    PATCH = "PATCH"
+    DELETE = "DELETE"
+
+
+@strawberry.type
+class DatasetExampleRevision(ExampleRevision):
+    """
+    Represents a revision (i.e., update or alteration) of a dataset example.
+    """
+
+    revision_kind: RevisionKind
+    created_at: datetime
+
+    @classmethod
+    def from_orm_revision(cls, revision: models.DatasetExampleRevision) -> "DatasetExampleRevision":
+        return cls(
+            input=revision.input,
+            output=revision.output,
+            metadata=revision.metadata_,
+            revision_kind=RevisionKind(revision.revision_kind),
+            created_at=revision.created_at,
+        )
phoenix/server/api/types/DatasetVersion.py (new file)
@@ -0,0 +1,14 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from strawberry.relay import Node, NodeID
+from strawberry.scalars import JSON
+
+
+@strawberry.type
+class DatasetVersion(Node):
+    id_attr: NodeID[int]
+    description: Optional[str]
+    metadata: JSON
+    created_at: datetime
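
Since Dataset, DatasetExample, and DatasetVersion are relay Nodes, the new surface is queryable over GraphQL. A quick smoke test against a local server (assumptions: Phoenix's default localhost:6006 address and /graphql route, and a top-level datasets connection in the new phoenix/server/api/queries.py; field names follow the types above after strawberry's snake_case to camelCase conversion):

# Hedged sketch: list datasets and their example/experiment counts via GraphQL.
import json
import urllib.request

query = """
query {
  datasets(first: 5) {
    edges {
      node { id name exampleCount experimentCount }
    }
  }
}
"""

request = urllib.request.Request(
    "http://localhost:6006/graphql",
    data=json.dumps({"query": query}).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    for edge in json.load(response)["data"]["datasets"]["edges"]:
        print(edge["node"])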