arize-phoenix 4.5.0__py3-none-any.whl → 4.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; review the release details below before upgrading.

Files changed (123)
  1. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/METADATA +16 -8
  2. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/RECORD +122 -58
  3. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +42 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datetime_utils.py +4 -0
  10. phoenix/db/bulk_inserter.py +54 -14
  11. phoenix/db/insertion/dataset.py +237 -0
  12. phoenix/db/insertion/evaluation.py +10 -10
  13. phoenix/db/insertion/helpers.py +17 -14
  14. phoenix/db/insertion/span.py +3 -3
  15. phoenix/db/migrations/types.py +29 -0
  16. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  17. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  18. phoenix/db/models.py +236 -4
  19. phoenix/experiments/__init__.py +6 -0
  20. phoenix/experiments/evaluators/__init__.py +29 -0
  21. phoenix/experiments/evaluators/base.py +153 -0
  22. phoenix/experiments/evaluators/code_evaluators.py +99 -0
  23. phoenix/experiments/evaluators/llm_evaluators.py +244 -0
  24. phoenix/experiments/evaluators/utils.py +186 -0
  25. phoenix/experiments/functions.py +757 -0
  26. phoenix/experiments/tracing.py +85 -0
  27. phoenix/experiments/types.py +753 -0
  28. phoenix/experiments/utils.py +24 -0
  29. phoenix/inferences/fixtures.py +23 -23
  30. phoenix/inferences/inferences.py +7 -7
  31. phoenix/inferences/validation.py +1 -1
  32. phoenix/server/api/context.py +20 -0
  33. phoenix/server/api/dataloaders/__init__.py +20 -0
  34. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  35. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  36. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  37. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  38. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  39. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  40. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  41. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  42. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  43. phoenix/server/api/dataloaders/span_projects.py +33 -0
  44. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  45. phoenix/server/api/helpers/dataset_helpers.py +179 -0
  46. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  47. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  48. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  49. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  50. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  51. phoenix/server/api/input_types/DatasetSort.py +17 -0
  52. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  53. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  54. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  55. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  56. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  57. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  58. phoenix/server/api/mutations/__init__.py +13 -0
  59. phoenix/server/api/mutations/auth.py +11 -0
  60. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  61. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  62. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  63. phoenix/server/api/mutations/project_mutations.py +47 -0
  64. phoenix/server/api/openapi/__init__.py +0 -0
  65. phoenix/server/api/openapi/main.py +6 -0
  66. phoenix/server/api/openapi/schema.py +16 -0
  67. phoenix/server/api/queries.py +503 -0
  68. phoenix/server/api/routers/v1/__init__.py +77 -2
  69. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  70. phoenix/server/api/routers/v1/datasets.py +965 -0
  71. phoenix/server/api/routers/v1/evaluations.py +8 -13
  72. phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
  73. phoenix/server/api/routers/v1/experiment_runs.py +220 -0
  74. phoenix/server/api/routers/v1/experiments.py +302 -0
  75. phoenix/server/api/routers/v1/spans.py +9 -5
  76. phoenix/server/api/routers/v1/traces.py +1 -4
  77. phoenix/server/api/schema.py +2 -303
  78. phoenix/server/api/types/AnnotatorKind.py +10 -0
  79. phoenix/server/api/types/Cluster.py +19 -19
  80. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  81. phoenix/server/api/types/Dataset.py +282 -63
  82. phoenix/server/api/types/DatasetExample.py +85 -0
  83. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  84. phoenix/server/api/types/DatasetVersion.py +14 -0
  85. phoenix/server/api/types/Dimension.py +30 -29
  86. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  87. phoenix/server/api/types/Event.py +16 -16
  88. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  89. phoenix/server/api/types/Experiment.py +147 -0
  90. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  91. phoenix/server/api/types/ExperimentComparison.py +19 -0
  92. phoenix/server/api/types/ExperimentRun.py +91 -0
  93. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  94. phoenix/server/api/types/Inferences.py +80 -0
  95. phoenix/server/api/types/InferencesRole.py +23 -0
  96. phoenix/server/api/types/Model.py +43 -42
  97. phoenix/server/api/types/Project.py +26 -12
  98. phoenix/server/api/types/Span.py +79 -2
  99. phoenix/server/api/types/TimeSeries.py +6 -6
  100. phoenix/server/api/types/Trace.py +15 -4
  101. phoenix/server/api/types/UMAPPoints.py +1 -1
  102. phoenix/server/api/types/node.py +5 -111
  103. phoenix/server/api/types/pagination.py +10 -52
  104. phoenix/server/app.py +103 -49
  105. phoenix/server/main.py +49 -27
  106. phoenix/server/openapi/docs.py +3 -0
  107. phoenix/server/static/index.js +2300 -1294
  108. phoenix/server/templates/index.html +1 -0
  109. phoenix/services.py +15 -15
  110. phoenix/session/client.py +581 -22
  111. phoenix/session/session.py +47 -37
  112. phoenix/trace/exporter.py +14 -9
  113. phoenix/trace/fixtures.py +133 -7
  114. phoenix/trace/schemas.py +1 -2
  115. phoenix/trace/span_evaluations.py +3 -3
  116. phoenix/trace/trace_dataset.py +6 -6
  117. phoenix/utilities/json.py +61 -0
  118. phoenix/utilities/re.py +50 -0
  119. phoenix/version.py +1 -1
  120. phoenix/server/api/types/DatasetRole.py +0 -23
  121. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.1.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
4
4
  import pandas as pd
5
5
  import strawberry
6
6
  from strawberry import UNSET
7
+ from strawberry.relay import Node, NodeID
7
8
  from strawberry.types import Info
8
9
  from typing_extensions import Annotated
9
10
 
@@ -17,12 +18,11 @@ from ..context import Context
17
18
  from ..input_types.Granularity import Granularity
18
19
  from ..input_types.TimeRange import TimeRange
19
20
  from .DataQualityMetric import DataQualityMetric
20
- from .DatasetRole import DatasetRole
21
21
  from .DatasetValues import DatasetValues
22
22
  from .DimensionDataType import DimensionDataType
23
23
  from .DimensionShape import DimensionShape
24
24
  from .DimensionType import DimensionType
25
- from .node import Node
25
+ from .InferencesRole import InferencesRole
26
26
  from .ScalarDriftMetricEnum import ScalarDriftMetric
27
27
  from .Segments import (
28
28
  GqlBinFactory,
@@ -40,6 +40,7 @@ from .TimeSeries import (
40
40
 
41
41
  @strawberry.type
42
42
  class Dimension(Node):
43
+ id_attr: NodeID[int]
43
44
  name: str = strawberry.field(description="The name of the dimension (a.k.a. the column name)")
44
45
  type: DimensionType = strawberry.field(
45
46
  description="Whether the dimension represents a feature, tag, prediction, or actual."
@@ -62,16 +63,16 @@ class Dimension(Node):
62
63
  """
63
64
  Computes a drift metric between all reference data and the primary data
64
65
  belonging to the input time range (inclusive of the time range start and
65
- exclusive of the time range end). Returns None if no reference dataset
66
- exists, if no primary data exists in the input time range, or if the
66
+ exclusive of the time range end). Returns None if no reference inferences
67
+ exist, if no primary data exists in the input time range, or if the
67
68
  input time range is invalid.
68
69
  """
69
70
  model = info.context.model
70
71
  if model[REFERENCE].empty:
71
72
  return None
72
- dataset = model[PRIMARY]
73
+ inferences = model[PRIMARY]
73
74
  time_range, granularity = ensure_timeseries_parameters(
74
- dataset,
75
+ inferences,
75
76
  time_range,
76
77
  )
77
78
  data = get_drift_timeseries_data(
@@ -92,18 +93,18 @@ class Dimension(Node):
92
93
  info: Info[Context, None],
93
94
  metric: DataQualityMetric,
94
95
  time_range: Optional[TimeRange] = UNSET,
95
- dataset_role: Annotated[
96
- Optional[DatasetRole],
96
+ inferences_role: Annotated[
97
+ Optional[InferencesRole],
97
98
  strawberry.argument(
98
- description="The dataset (primary or reference) to query",
99
+ description="The inferences (primary or reference) to query",
99
100
  ),
100
- ] = DatasetRole.primary,
101
+ ] = InferencesRole.primary,
101
102
  ) -> Optional[float]:
102
- if not isinstance(dataset_role, DatasetRole):
103
- dataset_role = DatasetRole.primary
104
- dataset = info.context.model[dataset_role.value]
103
+ if not isinstance(inferences_role, InferencesRole):
104
+ inferences_role = InferencesRole.primary
105
+ inferences = info.context.model[inferences_role.value]
105
106
  time_range, granularity = ensure_timeseries_parameters(
106
- dataset,
107
+ inferences,
107
108
  time_range,
108
109
  )
109
110
  data = get_data_quality_timeseries_data(
@@ -111,7 +112,7 @@ class Dimension(Node):
111
112
  metric,
112
113
  time_range,
113
114
  granularity,
114
- dataset_role,
115
+ inferences_role,
115
116
  )
116
117
  return data[0].value if len(data) else None
117
118
 
@@ -139,18 +140,18 @@ class Dimension(Node):
139
140
  metric: DataQualityMetric,
140
141
  time_range: TimeRange,
141
142
  granularity: Granularity,
142
- dataset_role: Annotated[
143
- Optional[DatasetRole],
143
+ inferences_role: Annotated[
144
+ Optional[InferencesRole],
144
145
  strawberry.argument(
145
- description="The dataset (primary or reference) to query",
146
+ description="The inferences (primary or reference) to query",
146
147
  ),
147
- ] = DatasetRole.primary,
148
+ ] = InferencesRole.primary,
148
149
  ) -> DataQualityTimeSeries:
149
- if not isinstance(dataset_role, DatasetRole):
150
- dataset_role = DatasetRole.primary
151
- dataset = info.context.model[dataset_role.value]
150
+ if not isinstance(inferences_role, InferencesRole):
151
+ inferences_role = InferencesRole.primary
152
+ inferences = info.context.model[inferences_role.value]
152
153
  time_range, granularity = ensure_timeseries_parameters(
153
- dataset,
154
+ inferences,
154
155
  time_range,
155
156
  granularity,
156
157
  )
@@ -160,7 +161,7 @@ class Dimension(Node):
160
161
  metric,
161
162
  time_range,
162
163
  granularity,
163
- dataset_role,
164
+ inferences_role,
164
165
  )
165
166
  )
166
167
 
@@ -182,9 +183,9 @@ class Dimension(Node):
182
183
  model = info.context.model
183
184
  if model[REFERENCE].empty:
184
185
  return DriftTimeSeries(data=[])
185
- dataset = model[PRIMARY]
186
+ inferences = model[PRIMARY]
186
187
  time_range, granularity = ensure_timeseries_parameters(
187
- dataset,
188
+ inferences,
188
189
  time_range,
189
190
  granularity,
190
191
  )
@@ -202,7 +203,7 @@ class Dimension(Node):
202
203
  )
203
204
 
204
205
  @strawberry.field(
205
- description="Returns the segments across both datasets and returns the counts per segment",
206
+ description="The segments across both inference sets and returns the counts per segment",
206
207
  ) # type: ignore
207
208
  def segments_comparison(
208
209
  self,
@@ -249,8 +250,8 @@ class Dimension(Node):
249
250
  if isinstance(binning_method, binning.IntervalBinning) and binning_method.bins is not None:
250
251
  all_bins = all_bins.union(binning_method.bins)
251
252
  for bin in all_bins:
252
- values: Dict[ms.DatasetRole, Any] = defaultdict(lambda: None)
253
- for role in ms.DatasetRole:
253
+ values: Dict[ms.InferencesRole, Any] = defaultdict(lambda: None)
254
+ for role in ms.InferencesRole:
254
255
  if model[role].empty:
255
256
  continue
256
257
  try:
@@ -8,6 +8,7 @@ import numpy.typing as npt
8
8
  import pandas as pd
9
9
  import strawberry
10
10
  from strawberry import UNSET
11
+ from strawberry.relay import GlobalID, Node, NodeID
11
12
  from strawberry.scalars import ID
12
13
  from strawberry.types import Info
13
14
  from typing_extensions import Annotated
@@ -22,7 +23,7 @@ from phoenix.core.model_schema import (
22
23
  PRIMARY,
23
24
  PROMPT,
24
25
  REFERENCE,
25
- Dataset,
26
+ Inferences,
26
27
  )
27
28
  from phoenix.metrics.timeseries import row_interval_from_sorted_time_index
28
29
  from phoenix.pointcloud.clustering import Hdbscan
@@ -31,7 +32,7 @@ from phoenix.pointcloud.projectors import Umap
31
32
  from phoenix.server.api.context import Context
32
33
  from phoenix.server.api.input_types.TimeRange import TimeRange
33
34
  from phoenix.server.api.types.Cluster import to_gql_clusters
34
- from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
35
+ from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
35
36
  from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric
36
37
 
37
38
  from ..input_types.Granularity import Granularity
@@ -39,7 +40,6 @@ from .DataQualityMetric import DataQualityMetric
39
40
  from .EmbeddingMetadata import EmbeddingMetadata
40
41
  from .Event import create_event_id, unpack_event_id
41
42
  from .EventMetadata import EventMetadata
42
- from .node import GlobalID, Node
43
43
  from .Retrieval import Retrieval
44
44
  from .TimeSeries import (
45
45
  DataQualityTimeSeries,
@@ -70,6 +70,7 @@ CORPUS = "CORPUS"
70
70
  class EmbeddingDimension(Node):
71
71
  """A embedding dimension of a model. Represents unstructured data"""
72
72
 
73
+ id_attr: NodeID[int]
73
74
  name: str
74
75
  dimension: strawberry.Private[ms.EmbeddingDimension]
75
76
 
@@ -155,16 +156,16 @@ class EmbeddingDimension(Node):
155
156
  metric: DataQualityMetric,
156
157
  time_range: TimeRange,
157
158
  granularity: Granularity,
158
- dataset_role: Annotated[
159
- Optional[DatasetRole],
159
+ inferences_role: Annotated[
160
+ Optional[InferencesRole],
160
161
  strawberry.argument(
161
162
  description="The dataset (primary or reference) to query",
162
163
  ),
163
- ] = DatasetRole.primary,
164
+ ] = InferencesRole.primary,
164
165
  ) -> DataQualityTimeSeries:
165
- if not isinstance(dataset_role, DatasetRole):
166
- dataset_role = DatasetRole.primary
167
- dataset = info.context.model[dataset_role.value]
166
+ if not isinstance(inferences_role, InferencesRole):
167
+ inferences_role = InferencesRole.primary
168
+ dataset = info.context.model[inferences_role.value]
168
169
  time_range, granularity = ensure_timeseries_parameters(
169
170
  dataset,
170
171
  time_range,
@@ -176,7 +177,7 @@ class EmbeddingDimension(Node):
176
177
  metric,
177
178
  time_range,
178
179
  granularity,
179
- dataset_role,
180
+ inferences_role,
180
181
  )
181
182
  )
182
183
 
@@ -314,16 +315,16 @@ class EmbeddingDimension(Node):
314
315
  model = info.context.model
315
316
  data: Dict[ID, npt.NDArray[np.float64]] = {}
316
317
  retrievals: List[Tuple[ID, Any, Any]] = []
317
- for dataset in model[Dataset]:
318
- dataset_id = dataset.role
319
- row_id_start, row_id_stop = 0, len(dataset)
320
- if dataset_id is PRIMARY:
318
+ for inferences in model[Inferences]:
319
+ inferences_id = inferences.role
320
+ row_id_start, row_id_stop = 0, len(inferences)
321
+ if inferences_id is PRIMARY:
321
322
  row_id_start, row_id_stop = row_interval_from_sorted_time_index(
322
- time_index=cast(pd.DatetimeIndex, dataset.index),
323
+ time_index=cast(pd.DatetimeIndex, inferences.index),
323
324
  time_start=time_range.start,
324
325
  time_stop=time_range.end,
325
326
  )
326
- vector_column = self.dimension[dataset_id]
327
+ vector_column = self.dimension[inferences_id]
327
328
  samples_collected = 0
328
329
  for row_id in _row_indices(
329
330
  row_id_start,
@@ -337,7 +338,7 @@ class EmbeddingDimension(Node):
337
338
  # of dunder method __len__.
338
339
  if not hasattr(embedding_vector, "__len__"):
339
340
  continue
340
- event_id = create_event_id(row_id, dataset_id)
341
+ event_id = create_event_id(row_id, inferences_id)
341
342
  data[event_id] = embedding_vector
342
343
  samples_collected += 1
343
344
  if isinstance(
@@ -347,8 +348,8 @@ class EmbeddingDimension(Node):
347
348
  retrievals.append(
348
349
  (
349
350
  event_id,
350
- self.dimension.context_retrieval_ids(dataset).iloc[row_id],
351
- self.dimension.context_retrieval_scores(dataset).iloc[row_id],
351
+ self.dimension.context_retrieval_ids(inferences).iloc[row_id],
352
+ self.dimension.context_retrieval_scores(inferences).iloc[row_id],
352
353
  )
353
354
  )
354
355
 
@@ -357,13 +358,13 @@ class EmbeddingDimension(Node):
357
358
  self.dimension,
358
359
  ms.RetrievalEmbeddingDimension,
359
360
  ) and (corpus := info.context.corpus):
360
- corpus_dataset = corpus[PRIMARY]
361
- for row_id, document_embedding_vector in enumerate(corpus_dataset[PROMPT]):
361
+ corpus_inferences = corpus[PRIMARY]
362
+ for row_id, document_embedding_vector in enumerate(corpus_inferences[PROMPT]):
362
363
  if not hasattr(document_embedding_vector, "__len__"):
363
364
  continue
364
- event_id = create_event_id(row_id, AncillaryDatasetRole.corpus)
365
+ event_id = create_event_id(row_id, AncillaryInferencesRole.corpus)
365
366
  data[event_id] = document_embedding_vector
366
- corpus_primary_key = corpus_dataset.primary_key
367
+ corpus_primary_key = corpus_inferences.primary_key
367
368
  for event_id, retrieval_ids, retrieval_scores in retrievals:
368
369
  if not isinstance(retrieval_ids, Iterable):
369
370
  continue
@@ -385,7 +386,7 @@ class EmbeddingDimension(Node):
385
386
  )
386
387
  except KeyError:
387
388
  continue
388
- document_embedding_vector = corpus_dataset[PROMPT].iloc[document_row_id]
389
+ document_embedding_vector = corpus_inferences[PROMPT].iloc[document_row_id]
389
390
  if not hasattr(document_embedding_vector, "__len__"):
390
391
  continue
391
392
  context_retrievals.append(
@@ -393,7 +394,7 @@ class EmbeddingDimension(Node):
393
394
  query_id=event_id,
394
395
  document_id=create_event_id(
395
396
  document_row_id,
396
- AncillaryDatasetRole.corpus,
397
+ AncillaryInferencesRole.corpus,
397
398
  ),
398
399
  relevance=document_score,
399
400
  )
@@ -413,11 +414,13 @@ class EmbeddingDimension(Node):
413
414
  ),
414
415
  ).generate(data, n_components=n_components)
415
416
 
416
- points: Dict[Union[DatasetRole, AncillaryDatasetRole], List[UMAPPoint]] = defaultdict(list)
417
+ points: Dict[Union[InferencesRole, AncillaryInferencesRole], List[UMAPPoint]] = defaultdict(
418
+ list
419
+ )
417
420
  for event_id, vector in vectors.items():
418
- row_id, dataset_role = unpack_event_id(event_id)
419
- if isinstance(dataset_role, DatasetRole):
420
- dataset = model[dataset_role.value]
421
+ row_id, inferences_role = unpack_event_id(event_id)
422
+ if isinstance(inferences_role, InferencesRole):
423
+ dataset = model[inferences_role.value]
421
424
  embedding_metadata = EmbeddingMetadata(
422
425
  prediction_id=dataset[PREDICTION_ID][row_id],
423
426
  link_to_data=dataset[self.dimension.link_to_data][row_id],
@@ -433,9 +436,12 @@ class EmbeddingDimension(Node):
433
436
  )
434
437
  else:
435
438
  continue
436
- points[dataset_role].append(
439
+ points[inferences_role].append(
437
440
  UMAPPoint(
438
- id=GlobalID(f"{type(self).__name__}:{str(dataset_role)}", row_id),
441
+ id=GlobalID(
442
+ type_name=f"{type(self).__name__}:{str(inferences_role)}",
443
+ node_id=str(row_id),
444
+ ),
439
445
  event_id=event_id,
440
446
  coordinates=to_gql_coordinates(vector),
441
447
  event_metadata=EventMetadata(
@@ -449,12 +455,12 @@ class EmbeddingDimension(Node):
449
455
  )
450
456
 
451
457
  return UMAPPoints(
452
- data=points[DatasetRole.primary],
453
- reference_data=points[DatasetRole.reference],
458
+ data=points[InferencesRole.primary],
459
+ reference_data=points[InferencesRole.reference],
454
460
  clusters=to_gql_clusters(
455
461
  clustered_events=clustered_events,
456
462
  ),
457
- corpus_data=points[AncillaryDatasetRole.corpus],
463
+ corpus_data=points[AncillaryInferencesRole.corpus],
458
464
  context_retrievals=context_retrievals,
459
465
  )
460
466
 
@@ -17,10 +17,10 @@ from phoenix.core.model_schema import (
17
17
  )
18
18
 
19
19
  from ..interceptor import GqlValueMediator
20
- from .DatasetRole import STR_TO_DATASET_ROLE, AncillaryDatasetRole, DatasetRole
21
20
  from .Dimension import Dimension
22
21
  from .DimensionWithValue import DimensionWithValue
23
22
  from .EventMetadata import EventMetadata
23
+ from .InferencesRole import STR_TO_INFEREENCES_ROLE, AncillaryInferencesRole, InferencesRole
24
24
  from .PromptResponse import PromptResponse
25
25
 
26
26
 
@@ -41,35 +41,35 @@ class Event:
41
41
 
42
42
  def create_event_id(
43
43
  row_id: int,
44
- dataset_role: Union[DatasetRole, AncillaryDatasetRole, ms.DatasetRole],
44
+ inferences_role: Union[InferencesRole, AncillaryInferencesRole, ms.InferencesRole],
45
45
  ) -> ID:
46
- dataset_role_str = (
47
- dataset_role.value
48
- if isinstance(dataset_role, (DatasetRole, AncillaryDatasetRole))
49
- else dataset_role
46
+ inferences_role_str = (
47
+ inferences_role.value
48
+ if isinstance(inferences_role, (InferencesRole, AncillaryInferencesRole))
49
+ else inferences_role
50
50
  )
51
- return ID(f"{row_id}:{dataset_role_str}")
51
+ return ID(f"{row_id}:{inferences_role_str}")
52
52
 
53
53
 
54
54
  def unpack_event_id(
55
55
  event_id: ID,
56
- ) -> Tuple[int, Union[DatasetRole, AncillaryDatasetRole]]:
57
- row_id_str, dataset_role_str = str(event_id).split(":")
56
+ ) -> Tuple[int, Union[InferencesRole, AncillaryInferencesRole]]:
57
+ row_id_str, inferences_role_str = str(event_id).split(":")
58
58
  row_id = int(row_id_str)
59
- dataset_role = STR_TO_DATASET_ROLE[dataset_role_str]
60
- return row_id, dataset_role
59
+ inferences_role = STR_TO_INFEREENCES_ROLE[inferences_role_str]
60
+ return row_id, inferences_role
61
61
 
62
62
 
63
- def parse_event_ids_by_dataset_role(
63
+ def parse_event_ids_by_inferences_role(
64
64
  event_ids: List[ID],
65
- ) -> Dict[Union[DatasetRole, AncillaryDatasetRole], List[int]]:
65
+ ) -> Dict[Union[InferencesRole, AncillaryInferencesRole], List[int]]:
66
66
  """
67
67
  Parses event IDs and returns the corresponding row indexes.
68
68
  """
69
- row_indexes: Dict[Union[DatasetRole, AncillaryDatasetRole], List[int]] = defaultdict(list)
69
+ row_indexes: Dict[Union[InferencesRole, AncillaryInferencesRole], List[int]] = defaultdict(list)
70
70
  for event_id in event_ids:
71
- row_id, dataset_role = unpack_event_id(event_id)
72
- row_indexes[dataset_role].append(row_id)
71
+ row_id, inferences_role = unpack_event_id(event_id)
72
+ row_indexes[inferences_role].append(row_id)
73
73
  return row_indexes
74
74
 
75
75
 
@@ -0,0 +1,14 @@
import strawberry
from strawberry.scalars import JSON


@strawberry.interface
class ExampleRevision:
    """
    Represents an example revision for generative tasks.
    For example, you might have text -> text, text -> labels, etc.
    """

    # Arbitrary JSON payload the task receives as input.
    input: JSON
    # Arbitrary JSON payload the task produces as output.
    output: JSON
    # Arbitrary JSON metadata attached to this revision.
    metadata: JSON
@@ -0,0 +1,147 @@
1
+ from datetime import datetime
2
+ from typing import List, Optional
3
+
4
+ import strawberry
5
+ from sqlalchemy import select
6
+ from sqlalchemy.orm import joinedload
7
+ from strawberry import UNSET, Private
8
+ from strawberry.relay import Connection, Node, NodeID
9
+ from strawberry.scalars import JSON
10
+ from strawberry.types import Info
11
+
12
+ from phoenix.db import models
13
+ from phoenix.server.api.context import Context
14
+ from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
15
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
16
+ from phoenix.server.api.types.pagination import (
17
+ ConnectionArgs,
18
+ CursorString,
19
+ connection_from_list,
20
+ )
21
+ from phoenix.server.api.types.Project import Project
22
+
23
+
24
+ @strawberry.type
25
+ class Experiment(Node):
26
+ cached_sequence_number: Private[Optional[int]] = None
27
+ id_attr: NodeID[int]
28
+ name: str
29
+ project_name: Optional[str]
30
+ description: Optional[str]
31
+ metadata: JSON
32
+ created_at: datetime
33
+ updated_at: datetime
34
+
35
+ @strawberry.field(
36
+ description="Sequence number (1-based) of experiments belonging to the same dataset"
37
+ ) # type: ignore
38
+ async def sequence_number(
39
+ self,
40
+ info: Info[Context, None],
41
+ ) -> int:
42
+ if self.cached_sequence_number is None:
43
+ seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id_attr)
44
+ if seq_num is None:
45
+ raise ValueError(f"invalid experiment: id={self.id_attr}")
46
+ self.cached_sequence_number = seq_num
47
+ return self.cached_sequence_number
48
+
49
+ @strawberry.field
50
+ async def runs(
51
+ self,
52
+ info: Info[Context, None],
53
+ first: Optional[int] = 50,
54
+ last: Optional[int] = UNSET,
55
+ after: Optional[CursorString] = UNSET,
56
+ before: Optional[CursorString] = UNSET,
57
+ ) -> Connection[ExperimentRun]:
58
+ args = ConnectionArgs(
59
+ first=first,
60
+ after=after if isinstance(after, CursorString) else None,
61
+ last=last,
62
+ before=before if isinstance(before, CursorString) else None,
63
+ )
64
+ experiment_id = self.id_attr
65
+ async with info.context.db() as session:
66
+ runs = (
67
+ await session.scalars(
68
+ select(models.ExperimentRun)
69
+ .where(models.ExperimentRun.experiment_id == experiment_id)
70
+ .order_by(models.ExperimentRun.id.desc())
71
+ .options(
72
+ joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
73
+ )
74
+ )
75
+ ).all()
76
+ return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
77
+
78
+ @strawberry.field
79
+ async def run_count(self, info: Info[Context, None]) -> int:
80
+ experiment_id = self.id_attr
81
+ return await info.context.data_loaders.experiment_run_counts.load(experiment_id)
82
+
83
+ @strawberry.field
84
+ async def annotation_summaries(
85
+ self, info: Info[Context, None]
86
+ ) -> List[ExperimentAnnotationSummary]:
87
+ experiment_id = self.id_attr
88
+ return [
89
+ ExperimentAnnotationSummary(
90
+ annotation_name=summary.annotation_name,
91
+ min_score=summary.min_score,
92
+ max_score=summary.max_score,
93
+ mean_score=summary.mean_score,
94
+ count=summary.count,
95
+ error_count=summary.error_count,
96
+ )
97
+ for summary in await info.context.data_loaders.experiment_annotation_summaries.load(
98
+ experiment_id
99
+ )
100
+ ]
101
+
102
+ @strawberry.field
103
+ async def error_rate(self, info: Info[Context, None]) -> Optional[float]:
104
+ return await info.context.data_loaders.experiment_error_rates.load(self.id_attr)
105
+
106
+ @strawberry.field
107
+ async def average_run_latency_ms(self, info: Info[Context, None]) -> float:
108
+ latency_seconds = await info.context.data_loaders.average_experiment_run_latency.load(
109
+ self.id_attr
110
+ )
111
+ return latency_seconds * 1000
112
+
113
+ @strawberry.field
114
+ async def project(self, info: Info[Context, None]) -> Optional[Project]:
115
+ if self.project_name is None:
116
+ return None
117
+
118
+ db_project = await info.context.data_loaders.project_by_name.load(self.project_name)
119
+
120
+ if db_project is None:
121
+ return None
122
+
123
+ return Project(
124
+ id_attr=db_project.id,
125
+ name=db_project.name,
126
+ gradient_start_color=db_project.gradient_start_color,
127
+ gradient_end_color=db_project.gradient_end_color,
128
+ )
def to_gql_experiment(
    experiment: models.Experiment,
    sequence_number: Optional[int] = None,
) -> Experiment:
    """
    Build the GraphQL ``Experiment`` type from its ORM counterpart.

    Args:
        experiment: The ``models.Experiment`` row to convert.
        sequence_number: Optional pre-computed 1-based sequence number; when
            provided it is stored on the GraphQL object so the
            ``sequence_number`` resolver can skip a data-loader round trip.

    Returns:
        The equivalent GraphQL ``Experiment`` object.
    """
    fields = dict(
        cached_sequence_number=sequence_number,
        id_attr=experiment.id,
        name=experiment.name,
        project_name=experiment.project_name,
        description=experiment.description,
        metadata=experiment.metadata_,  # ORM column is named ``metadata_``
        created_at=experiment.created_at,
        updated_at=experiment.updated_at,
    )
    return Experiment(**fields)
@@ -0,0 +1,13 @@
from typing import Optional

import strawberry


@strawberry.type
class ExperimentAnnotationSummary:
    """Aggregate statistics for one annotation name across an experiment's runs."""

    # The annotation name being summarized.
    annotation_name: str
    # Score aggregates; Optional — presumably None when no numeric scores
    # exist for this annotation name. NOTE(review): confirm against the
    # experiment_annotation_summaries data loader.
    min_score: Optional[float]
    max_score: Optional[float]
    mean_score: Optional[float]
    # Total number of annotations with this name.
    count: int
    # Number of those annotations that recorded an error.
    error_count: int
@@ -0,0 +1,19 @@
from typing import List

import strawberry
from strawberry.relay import GlobalID

from phoenix.server.api.types.DatasetExample import DatasetExample
from phoenix.server.api.types.ExperimentRun import ExperimentRun


@strawberry.type
class RunComparisonItem:
    """The runs a single experiment produced for one dataset example."""

    # Relay global ID of the experiment the runs belong to.
    experiment_id: GlobalID
    runs: List[ExperimentRun]


@strawberry.type
class ExperimentComparison:
    """Pairs a dataset example with per-experiment groups of runs against it."""

    example: DatasetExample
    run_comparison_items: List[RunComparisonItem]