arize-phoenix 4.4.2__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +12 -11
  2. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +110 -57
  3. phoenix/__init__.py +0 -27
  4. phoenix/config.py +21 -7
  5. phoenix/core/model.py +25 -25
  6. phoenix/core/model_schema.py +66 -64
  7. phoenix/core/model_schema_adapter.py +27 -25
  8. phoenix/datasets/__init__.py +0 -0
  9. phoenix/datasets/evaluators.py +275 -0
  10. phoenix/datasets/experiments.py +469 -0
  11. phoenix/datasets/tracing.py +66 -0
  12. phoenix/datasets/types.py +212 -0
  13. phoenix/db/bulk_inserter.py +54 -14
  14. phoenix/db/insertion/dataset.py +234 -0
  15. phoenix/db/insertion/evaluation.py +6 -6
  16. phoenix/db/insertion/helpers.py +13 -2
  17. phoenix/db/migrations/types.py +29 -0
  18. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  19. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  20. phoenix/db/models.py +230 -3
  21. phoenix/inferences/fixtures.py +23 -23
  22. phoenix/inferences/inferences.py +7 -7
  23. phoenix/inferences/validation.py +1 -1
  24. phoenix/metrics/binning.py +2 -2
  25. phoenix/server/api/context.py +16 -0
  26. phoenix/server/api/dataloaders/__init__.py +16 -0
  27. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  28. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  29. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  30. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  31. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  32. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  34. phoenix/server/api/dataloaders/span_projects.py +33 -0
  35. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  36. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  37. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  38. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  39. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  40. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  41. phoenix/server/api/input_types/DatasetSort.py +17 -0
  42. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  43. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  44. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  45. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  46. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  47. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  48. phoenix/server/api/mutations/__init__.py +13 -0
  49. phoenix/server/api/mutations/auth.py +11 -0
  50. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  51. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  52. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  53. phoenix/server/api/mutations/project_mutations.py +42 -0
  54. phoenix/server/api/queries.py +503 -0
  55. phoenix/server/api/routers/v1/__init__.py +77 -2
  56. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  57. phoenix/server/api/routers/v1/datasets.py +861 -0
  58. phoenix/server/api/routers/v1/evaluations.py +4 -2
  59. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  60. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  61. phoenix/server/api/routers/v1/experiments.py +174 -0
  62. phoenix/server/api/routers/v1/spans.py +3 -1
  63. phoenix/server/api/routers/v1/traces.py +1 -4
  64. phoenix/server/api/schema.py +2 -303
  65. phoenix/server/api/types/AnnotatorKind.py +10 -0
  66. phoenix/server/api/types/Cluster.py +19 -19
  67. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  68. phoenix/server/api/types/Dataset.py +282 -63
  69. phoenix/server/api/types/DatasetExample.py +85 -0
  70. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  71. phoenix/server/api/types/DatasetVersion.py +14 -0
  72. phoenix/server/api/types/Dimension.py +30 -29
  73. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  74. phoenix/server/api/types/Event.py +16 -16
  75. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  76. phoenix/server/api/types/Experiment.py +135 -0
  77. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  78. phoenix/server/api/types/ExperimentComparison.py +19 -0
  79. phoenix/server/api/types/ExperimentRun.py +91 -0
  80. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  81. phoenix/server/api/types/Inferences.py +80 -0
  82. phoenix/server/api/types/InferencesRole.py +23 -0
  83. phoenix/server/api/types/Model.py +43 -42
  84. phoenix/server/api/types/Project.py +26 -12
  85. phoenix/server/api/types/Segments.py +1 -1
  86. phoenix/server/api/types/Span.py +78 -2
  87. phoenix/server/api/types/TimeSeries.py +6 -6
  88. phoenix/server/api/types/Trace.py +15 -4
  89. phoenix/server/api/types/UMAPPoints.py +1 -1
  90. phoenix/server/api/types/node.py +5 -111
  91. phoenix/server/api/types/pagination.py +10 -52
  92. phoenix/server/app.py +99 -49
  93. phoenix/server/main.py +49 -27
  94. phoenix/server/openapi/docs.py +3 -0
  95. phoenix/server/static/index.js +2246 -1368
  96. phoenix/server/templates/index.html +1 -0
  97. phoenix/services.py +15 -15
  98. phoenix/session/client.py +316 -21
  99. phoenix/session/session.py +47 -37
  100. phoenix/trace/exporter.py +14 -9
  101. phoenix/trace/fixtures.py +133 -7
  102. phoenix/trace/span_evaluations.py +3 -3
  103. phoenix/trace/trace_dataset.py +6 -6
  104. phoenix/utilities/json.py +61 -0
  105. phoenix/utilities/re.py +50 -0
  106. phoenix/version.py +1 -1
  107. phoenix/server/api/types/DatasetRole.py +0 -23
  108. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
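The hunks below reproduce the individual file diffs. A per-file summary like the list above can be approximated locally by unpacking both wheels (they are ordinary zip archives) and diffing their contents; the following is a minimal sketch using only the Python standard library, and the two wheel paths are assumptions about locally downloaded copies rather than anything taken from this page.

# Minimal sketch: compare two wheels file by file and print a "+added -removed"
# summary for every file whose text content differs.
import difflib
import zipfile

OLD_WHEEL = "arize_phoenix-4.4.2-py3-none-any.whl"     # assumed local path
NEW_WHEEL = "arize_phoenix-4.4.4rc0-py3-none-any.whl"  # assumed local path

with zipfile.ZipFile(OLD_WHEEL) as old, zipfile.ZipFile(NEW_WHEEL) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())
    for name in sorted(old_names | new_names):
        old_lines = old.read(name).decode("utf-8", "replace").splitlines() if name in old_names else []
        new_lines = new.read(name).decode("utf-8", "replace").splitlines() if name in new_names else []
        diff = list(difflib.unified_diff(old_lines, new_lines, fromfile=name, tofile=name, lineterm=""))
        if diff:
            added = sum(1 for d in diff if d.startswith("+") and not d.startswith("+++"))
            removed = sum(1 for d in diff if d.startswith("-") and not d.startswith("---"))
            print(f"{name} +{added} -{removed}")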
phoenix/server/api/types/EmbeddingDimension.py

@@ -8,6 +8,7 @@ import numpy.typing as npt
  import pandas as pd
  import strawberry
  from strawberry import UNSET
+ from strawberry.relay import GlobalID, Node, NodeID
  from strawberry.scalars import ID
  from strawberry.types import Info
  from typing_extensions import Annotated
@@ -22,7 +23,7 @@ from phoenix.core.model_schema import (
      PRIMARY,
      PROMPT,
      REFERENCE,
-     Dataset,
+     Inferences,
  )
  from phoenix.metrics.timeseries import row_interval_from_sorted_time_index
  from phoenix.pointcloud.clustering import Hdbscan
@@ -31,7 +32,7 @@ from phoenix.pointcloud.projectors import Umap
  from phoenix.server.api.context import Context
  from phoenix.server.api.input_types.TimeRange import TimeRange
  from phoenix.server.api.types.Cluster import to_gql_clusters
- from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
+ from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
  from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric

  from ..input_types.Granularity import Granularity
@@ -39,7 +40,6 @@ from .DataQualityMetric import DataQualityMetric
  from .EmbeddingMetadata import EmbeddingMetadata
  from .Event import create_event_id, unpack_event_id
  from .EventMetadata import EventMetadata
- from .node import GlobalID, Node
  from .Retrieval import Retrieval
  from .TimeSeries import (
      DataQualityTimeSeries,
@@ -70,6 +70,7 @@ CORPUS = "CORPUS"
  class EmbeddingDimension(Node):
      """A embedding dimension of a model. Represents unstructured data"""

+     id_attr: NodeID[int]
      name: str
      dimension: strawberry.Private[ms.EmbeddingDimension]

@@ -155,16 +156,16 @@ class EmbeddingDimension(Node):
          metric: DataQualityMetric,
          time_range: TimeRange,
          granularity: Granularity,
-         dataset_role: Annotated[
-             Optional[DatasetRole],
+         inferences_role: Annotated[
+             Optional[InferencesRole],
              strawberry.argument(
                  description="The dataset (primary or reference) to query",
              ),
-         ] = DatasetRole.primary,
+         ] = InferencesRole.primary,
      ) -> DataQualityTimeSeries:
-         if not isinstance(dataset_role, DatasetRole):
-             dataset_role = DatasetRole.primary
-         dataset = info.context.model[dataset_role.value]
+         if not isinstance(inferences_role, InferencesRole):
+             inferences_role = InferencesRole.primary
+         dataset = info.context.model[inferences_role.value]
          time_range, granularity = ensure_timeseries_parameters(
              dataset,
              time_range,
@@ -176,7 +177,7 @@ class EmbeddingDimension(Node):
                  metric,
                  time_range,
                  granularity,
-                 dataset_role,
+                 inferences_role,
              )
          )

@@ -314,16 +315,16 @@ class EmbeddingDimension(Node):
          model = info.context.model
          data: Dict[ID, npt.NDArray[np.float64]] = {}
          retrievals: List[Tuple[ID, Any, Any]] = []
-         for dataset in model[Dataset]:
-             dataset_id = dataset.role
-             row_id_start, row_id_stop = 0, len(dataset)
-             if dataset_id is PRIMARY:
+         for inferences in model[Inferences]:
+             inferences_id = inferences.role
+             row_id_start, row_id_stop = 0, len(inferences)
+             if inferences_id is PRIMARY:
                  row_id_start, row_id_stop = row_interval_from_sorted_time_index(
-                     time_index=cast(pd.DatetimeIndex, dataset.index),
+                     time_index=cast(pd.DatetimeIndex, inferences.index),
                      time_start=time_range.start,
                      time_stop=time_range.end,
                  )
-             vector_column = self.dimension[dataset_id]
+             vector_column = self.dimension[inferences_id]
              samples_collected = 0
              for row_id in _row_indices(
                  row_id_start,
@@ -337,7 +338,7 @@ class EmbeddingDimension(Node):
                  # of dunder method __len__.
                  if not hasattr(embedding_vector, "__len__"):
                      continue
-                 event_id = create_event_id(row_id, dataset_id)
+                 event_id = create_event_id(row_id, inferences_id)
                  data[event_id] = embedding_vector
                  samples_collected += 1
                  if isinstance(
@@ -347,8 +348,8 @@ class EmbeddingDimension(Node):
                      retrievals.append(
                          (
                              event_id,
-                             self.dimension.context_retrieval_ids(dataset).iloc[row_id],
-                             self.dimension.context_retrieval_scores(dataset).iloc[row_id],
+                             self.dimension.context_retrieval_ids(inferences).iloc[row_id],
+                             self.dimension.context_retrieval_scores(inferences).iloc[row_id],
                          )
                      )

@@ -357,13 +358,13 @@ class EmbeddingDimension(Node):
              self.dimension,
              ms.RetrievalEmbeddingDimension,
          ) and (corpus := info.context.corpus):
-             corpus_dataset = corpus[PRIMARY]
-             for row_id, document_embedding_vector in enumerate(corpus_dataset[PROMPT]):
+             corpus_inferences = corpus[PRIMARY]
+             for row_id, document_embedding_vector in enumerate(corpus_inferences[PROMPT]):
                  if not hasattr(document_embedding_vector, "__len__"):
                      continue
-                 event_id = create_event_id(row_id, AncillaryDatasetRole.corpus)
+                 event_id = create_event_id(row_id, AncillaryInferencesRole.corpus)
                  data[event_id] = document_embedding_vector
-             corpus_primary_key = corpus_dataset.primary_key
+             corpus_primary_key = corpus_inferences.primary_key
              for event_id, retrieval_ids, retrieval_scores in retrievals:
                  if not isinstance(retrieval_ids, Iterable):
                      continue
@@ -385,7 +386,7 @@ class EmbeddingDimension(Node):
                      )
                  except KeyError:
                      continue
-                 document_embedding_vector = corpus_dataset[PROMPT].iloc[document_row_id]
+                 document_embedding_vector = corpus_inferences[PROMPT].iloc[document_row_id]
                  if not hasattr(document_embedding_vector, "__len__"):
                      continue
                  context_retrievals.append(
@@ -393,7 +394,7 @@ class EmbeddingDimension(Node):
                          query_id=event_id,
                          document_id=create_event_id(
                              document_row_id,
-                             AncillaryDatasetRole.corpus,
+                             AncillaryInferencesRole.corpus,
                          ),
                          relevance=document_score,
                      )
@@ -413,11 +414,13 @@ class EmbeddingDimension(Node):
              ),
          ).generate(data, n_components=n_components)

-         points: Dict[Union[DatasetRole, AncillaryDatasetRole], List[UMAPPoint]] = defaultdict(list)
+         points: Dict[Union[InferencesRole, AncillaryInferencesRole], List[UMAPPoint]] = defaultdict(
+             list
+         )
          for event_id, vector in vectors.items():
-             row_id, dataset_role = unpack_event_id(event_id)
-             if isinstance(dataset_role, DatasetRole):
-                 dataset = model[dataset_role.value]
+             row_id, inferences_role = unpack_event_id(event_id)
+             if isinstance(inferences_role, InferencesRole):
+                 dataset = model[inferences_role.value]
              embedding_metadata = EmbeddingMetadata(
                  prediction_id=dataset[PREDICTION_ID][row_id],
                  link_to_data=dataset[self.dimension.link_to_data][row_id],
@@ -433,9 +436,12 @@ class EmbeddingDimension(Node):
                  )
              else:
                  continue
-             points[dataset_role].append(
+             points[inferences_role].append(
                  UMAPPoint(
-                     id=GlobalID(f"{type(self).__name__}:{str(dataset_role)}", row_id),
+                     id=GlobalID(
+                         type_name=f"{type(self).__name__}:{str(inferences_role)}",
+                         node_id=str(row_id),
+                     ),
                      event_id=event_id,
                      coordinates=to_gql_coordinates(vector),
                      event_metadata=EventMetadata(
@@ -449,12 +455,12 @@ class EmbeddingDimension(Node):
              )

          return UMAPPoints(
-             data=points[DatasetRole.primary],
-             reference_data=points[DatasetRole.reference],
+             data=points[InferencesRole.primary],
+             reference_data=points[InferencesRole.reference],
              clusters=to_gql_clusters(
                  clustered_events=clustered_events,
              ),
-             corpus_data=points[AncillaryDatasetRole.corpus],
+             corpus_data=points[AncillaryInferencesRole.corpus],
              context_retrievals=context_retrievals,
          )
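The hunks above also swap the module-local GlobalID and Node helpers (from .node) for strawberry.relay, whose GlobalID takes explicit type_name and node_id keyword arguments. A minimal sketch of that construction is shown below; it assumes the strawberry-graphql package is installed, and the values are illustrative rather than ones computed by the server.

# Sketch of the strawberry.relay.GlobalID construction used in the new code path.
from strawberry.relay import GlobalID

gid = GlobalID(type_name="EmbeddingDimension:primary", node_id="42")  # illustrative values
print(gid.type_name, gid.node_id)  # -> EmbeddingDimension:primary 42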
phoenix/server/api/types/Event.py

@@ -17,10 +17,10 @@ from phoenix.core.model_schema import (
  )

  from ..interceptor import GqlValueMediator
- from .DatasetRole import STR_TO_DATASET_ROLE, AncillaryDatasetRole, DatasetRole
  from .Dimension import Dimension
  from .DimensionWithValue import DimensionWithValue
  from .EventMetadata import EventMetadata
+ from .InferencesRole import STR_TO_INFEREENCES_ROLE, AncillaryInferencesRole, InferencesRole
  from .PromptResponse import PromptResponse


@@ -41,35 +41,35 @@ class Event:

  def create_event_id(
      row_id: int,
-     dataset_role: Union[DatasetRole, AncillaryDatasetRole, ms.DatasetRole],
+     inferences_role: Union[InferencesRole, AncillaryInferencesRole, ms.InferencesRole],
  ) -> ID:
-     dataset_role_str = (
-         dataset_role.value
-         if isinstance(dataset_role, (DatasetRole, AncillaryDatasetRole))
-         else dataset_role
+     inferences_role_str = (
+         inferences_role.value
+         if isinstance(inferences_role, (InferencesRole, AncillaryInferencesRole))
+         else inferences_role
      )
-     return ID(f"{row_id}:{dataset_role_str}")
+     return ID(f"{row_id}:{inferences_role_str}")


  def unpack_event_id(
      event_id: ID,
- ) -> Tuple[int, Union[DatasetRole, AncillaryDatasetRole]]:
-     row_id_str, dataset_role_str = str(event_id).split(":")
+ ) -> Tuple[int, Union[InferencesRole, AncillaryInferencesRole]]:
+     row_id_str, inferences_role_str = str(event_id).split(":")
      row_id = int(row_id_str)
-     dataset_role = STR_TO_DATASET_ROLE[dataset_role_str]
-     return row_id, dataset_role
+     inferences_role = STR_TO_INFEREENCES_ROLE[inferences_role_str]
+     return row_id, inferences_role


- def parse_event_ids_by_dataset_role(
+ def parse_event_ids_by_inferences_role(
      event_ids: List[ID],
- ) -> Dict[Union[DatasetRole, AncillaryDatasetRole], List[int]]:
+ ) -> Dict[Union[InferencesRole, AncillaryInferencesRole], List[int]]:
      """
      Parses event IDs and returns the corresponding row indexes.
      """
-     row_indexes: Dict[Union[DatasetRole, AncillaryDatasetRole], List[int]] = defaultdict(list)
+     row_indexes: Dict[Union[InferencesRole, AncillaryInferencesRole], List[int]] = defaultdict(list)
      for event_id in event_ids:
-         row_id, dataset_role = unpack_event_id(event_id)
-         row_indexes[dataset_role].append(row_id)
+         row_id, inferences_role = unpack_event_id(event_id)
+         row_indexes[inferences_role].append(row_id)
      return row_indexes

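The renamed helpers keep the original ID scheme: an event ID packs the row index and the role value with a colon, and unpacking splits on that colon and restores the integer. The standalone sketch below re-implements the round trip for illustration only; it does not import phoenix, and the "primary" string is a stand-in for the real role value.

# Standalone illustration of the event-ID packing scheme shown above.
def create_event_id(row_id: int, role: str) -> str:
    # Pack the row index and role into a single opaque ID.
    return f"{row_id}:{role}"

def unpack_event_id(event_id: str) -> tuple:
    # Split on the colon and restore the integer row index.
    row_id_str, role = event_id.split(":")
    return int(row_id_str), role

event_id = create_event_id(42, "primary")
assert unpack_event_id(event_id) == (42, "primary")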
phoenix/server/api/types/ExampleRevisionInterface.py (new file)

@@ -0,0 +1,14 @@
+ import strawberry
+ from strawberry.scalars import JSON
+
+
+ @strawberry.interface
+ class ExampleRevision:
+     """
+     Represents an example revision for generative tasks.
+     For example, you might have text -> text, text -> labels, etc.
+     """
+
+     input: JSON
+     output: JSON
+     metadata: JSON
phoenix/server/api/types/Experiment.py (new file)

@@ -0,0 +1,135 @@
+ from datetime import datetime
+ from typing import List, Optional
+
+ import strawberry
+ from sqlalchemy import select
+ from sqlalchemy.orm import joinedload
+ from strawberry import UNSET, Private
+ from strawberry.relay import Connection, Node, NodeID
+ from strawberry.scalars import JSON
+ from strawberry.types import Info
+
+ from phoenix.db import models
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+ from phoenix.server.api.types.pagination import (
+     ConnectionArgs,
+     CursorString,
+     connection_from_list,
+ )
+ from phoenix.server.api.types.Project import Project
+
+
+ @strawberry.type
+ class Experiment(Node):
+     cached_sequence_number: Private[Optional[int]] = None
+     id_attr: NodeID[int]
+     name: str
+     project_name: Optional[str]
+     description: Optional[str]
+     metadata: JSON
+     created_at: datetime
+     updated_at: datetime
+
+     @strawberry.field(
+         description="Sequence number (1-based) of experiments belonging to the same dataset"
+     ) # type: ignore
+     async def sequence_number(
+         self,
+         info: Info[Context, None],
+     ) -> int:
+         if self.cached_sequence_number is None:
+             seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id_attr)
+             if seq_num is None:
+                 raise ValueError(f"invalid experiment: id={self.id_attr}")
+             self.cached_sequence_number = seq_num
+         return self.cached_sequence_number
+
+     @strawberry.field
+     async def runs(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[ExperimentRun]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         experiment_id = self.id_attr
+         async with info.context.db() as session:
+             runs = (
+                 await session.scalars(
+                     select(models.ExperimentRun)
+                     .where(models.ExperimentRun.experiment_id == experiment_id)
+                     .order_by(models.ExperimentRun.id.desc())
+                     .options(
+                         joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
+                     )
+                 )
+             ).all()
+         return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
+
+     @strawberry.field
+     async def annotation_summaries(
+         self, info: Info[Context, None]
+     ) -> List[ExperimentAnnotationSummary]:
+         experiment_id = self.id_attr
+         return [
+             ExperimentAnnotationSummary(
+                 annotation_name=summary.annotation_name,
+                 min_score=summary.min_score,
+                 max_score=summary.max_score,
+                 mean_score=summary.mean_score,
+                 count=summary.count,
+                 error_count=summary.error_count,
+             )
+             for summary in await info.context.data_loaders.experiment_annotation_summaries.load(
+                 experiment_id
+             )
+         ]
+
+     @strawberry.field
+     async def error_rate(self, info: Info[Context, None]) -> Optional[float]:
+         return await info.context.data_loaders.experiment_error_rates.load(self.id_attr)
+
+     @strawberry.field
+     async def project(self, info: Info[Context, None]) -> Optional[Project]:
+         if self.project_name is None:
+             return None
+
+         db_project = await info.context.data_loaders.project_by_name.load(self.project_name)
+
+         if db_project is None:
+             return None
+
+         return Project(
+             id_attr=db_project.id,
+             name=db_project.name,
+             gradient_start_color=db_project.gradient_start_color,
+             gradient_end_color=db_project.gradient_end_color,
+         )
+
+
+ def to_gql_experiment(
+     experiment: models.Experiment,
+     sequence_number: Optional[int] = None,
+ ) -> Experiment:
+     """
+     Converts an ORM experiment to a GraphQL Experiment.
+     """
+     return Experiment(
+         cached_sequence_number=sequence_number,
+         id_attr=experiment.id,
+         name=experiment.name,
+         project_name=experiment.project_name,
+         description=experiment.description,
+         metadata=experiment.metadata_,
+         created_at=experiment.created_at,
+         updated_at=experiment.updated_at,
+     )
phoenix/server/api/types/ExperimentAnnotationSummary.py (new file)

@@ -0,0 +1,13 @@
+ from typing import Optional
+
+ import strawberry
+
+
+ @strawberry.type
+ class ExperimentAnnotationSummary:
+     annotation_name: str
+     min_score: Optional[float]
+     max_score: Optional[float]
+     mean_score: Optional[float]
+     count: int
+     error_count: int
phoenix/server/api/types/ExperimentComparison.py (new file)

@@ -0,0 +1,19 @@
+ from typing import List
+
+ import strawberry
+ from strawberry.relay import GlobalID
+
+ from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
+
+
+ @strawberry.type
+ class RunComparisonItem:
+     experiment_id: GlobalID
+     runs: List[ExperimentRun]
+
+
+ @strawberry.type
+ class ExperimentComparison:
+     example: DatasetExample
+     run_comparison_items: List[RunComparisonItem]
phoenix/server/api/types/ExperimentRun.py (new file)

@@ -0,0 +1,91 @@
+ from datetime import datetime
+ from typing import Optional
+
+ import strawberry
+ from sqlalchemy import select
+ from strawberry import UNSET
+ from strawberry.relay import Connection, GlobalID, Node, NodeID
+ from strawberry.scalars import JSON
+ from strawberry.types import Info
+
+ from phoenix.db import models
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.types.ExperimentRunAnnotation import (
+     ExperimentRunAnnotation,
+     to_gql_experiment_run_annotation,
+ )
+ from phoenix.server.api.types.pagination import (
+     ConnectionArgs,
+     CursorString,
+     connection_from_list,
+ )
+ from phoenix.server.api.types.Trace import Trace
+
+
+ @strawberry.type
+ class ExperimentRun(Node):
+     id_attr: NodeID[int]
+     experiment_id: GlobalID
+     trace_id: Optional[str]
+     output: Optional[JSON]
+     start_time: datetime
+     end_time: datetime
+     error: Optional[str]
+
+     @strawberry.field
+     async def annotations(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[ExperimentRunAnnotation]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         run_id = self.id_attr
+         async with info.context.db() as session:
+             annotations = (
+                 await session.scalars(
+                     select(models.ExperimentRunAnnotation)
+                     .where(models.ExperimentRunAnnotation.experiment_run_id == run_id)
+                     .order_by(models.ExperimentRunAnnotation.id.desc())
+                 )
+             ).all()
+         return connection_from_list(
+             [to_gql_experiment_run_annotation(annotation) for annotation in annotations], args
+         )
+
+     @strawberry.field
+     async def trace(self, info: Info) -> Optional[Trace]:
+         if not self.trace_id:
+             return None
+         dataloader = info.context.data_loaders.trace_row_ids
+         if (trace := await dataloader.load(self.trace_id)) is None:
+             return None
+         trace_rowid, project_rowid = trace
+         return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
+
+
+ def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
+     """
+     Converts an ORM experiment run to a GraphQL ExperimentRun.
+     """
+
+     from phoenix.server.api.types.Experiment import Experiment
+
+     return ExperimentRun(
+         id_attr=run.id,
+         experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
+         trace_id=trace_id
+         if (trace := run.trace) and (trace_id := trace.trace_id) is not None
+         else None,
+         output=run.output,
+         start_time=run.start_time,
+         end_time=run.end_time,
+         error=run.error,
+     )
phoenix/server/api/types/ExperimentRunAnnotation.py (new file)

@@ -0,0 +1,57 @@
+ from datetime import datetime
+ from typing import Optional
+
+ import strawberry
+ from strawberry import Info
+ from strawberry.relay import Node, NodeID
+ from strawberry.scalars import JSON
+
+ from phoenix.db import models
+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+ from phoenix.server.api.types.Trace import Trace
+
+
+ @strawberry.type
+ class ExperimentRunAnnotation(Node):
+     id_attr: NodeID[int]
+     name: str
+     annotator_kind: AnnotatorKind
+     label: Optional[str]
+     score: Optional[float]
+     explanation: Optional[str]
+     error: Optional[str]
+     metadata: JSON
+     start_time: datetime
+     end_time: datetime
+     trace_id: Optional[str]
+
+     @strawberry.field
+     async def trace(self, info: Info) -> Optional[Trace]:
+         if not self.trace_id:
+             return None
+         dataloader = info.context.data_loaders.trace_row_ids
+         if (trace := await dataloader.load(self.trace_id)) is None:
+             return None
+         trace_row_id, project_row_id = trace
+         return Trace(id_attr=trace_row_id, trace_id=trace.trace_id, project_rowid=project_row_id)
+
+
+ def to_gql_experiment_run_annotation(
+     annotation: models.ExperimentRunAnnotation,
+ ) -> ExperimentRunAnnotation:
+     """
+     Converts an ORM experiment run annotation to a GraphQL ExperimentRunAnnotation.
+     """
+     return ExperimentRunAnnotation(
+         id_attr=annotation.id,
+         name=annotation.name,
+         annotator_kind=AnnotatorKind(annotation.annotator_kind),
+         label=annotation.label,
+         score=annotation.score,
+         explanation=annotation.explanation,
+         error=annotation.error,
+         metadata=annotation.metadata_,
+         start_time=annotation.start_time,
+         end_time=annotation.end_time,
+         trace_id=annotation.trace_id,
+     )
phoenix/server/api/types/Inferences.py (new file)

@@ -0,0 +1,80 @@
+ from datetime import datetime
+ from typing import Iterable, List, Optional, Set, Union
+
+ import strawberry
+ from strawberry.scalars import ID
+ from strawberry.unset import UNSET
+
+ import phoenix.core.model_schema as ms
+ from phoenix.core.model_schema import FEATURE, TAG, ScalarDimension
+
+ from ..input_types.DimensionInput import DimensionInput
+ from .Dimension import Dimension, to_gql_dimension
+ from .Event import Event, create_event, create_event_id, parse_event_ids_by_inferences_role
+ from .InferencesRole import AncillaryInferencesRole, InferencesRole
+
+
+ @strawberry.type
+ class Inferences:
+     start_time: datetime = strawberry.field(description="The start bookend of the data")
+     end_time: datetime = strawberry.field(description="The end bookend of the data")
+     record_count: int = strawberry.field(description="The record count of the data")
+     inferences: strawberry.Private[ms.Inferences]
+     inferences_role: strawberry.Private[Union[InferencesRole, AncillaryInferencesRole]]
+     model: strawberry.Private[ms.Model]
+
+     # type ignored here to get around the following: https://github.com/strawberry-graphql/strawberry/issues/1929
+     @strawberry.field(description="Returns a human friendly name for the inferences.") # type: ignore
+     def name(self) -> str:
+         return self.inferences.display_name
+
+     @strawberry.field
+     def events(
+         self,
+         event_ids: List[ID],
+         dimensions: Optional[List[DimensionInput]] = UNSET,
+     ) -> List[Event]:
+         """
+         Returns events for specific event IDs and dimensions. If no input
+         dimensions are provided, returns all features and tags.
+         """
+         if not event_ids:
+             return []
+         row_ids = parse_event_ids_by_inferences_role(event_ids)
+         if len(row_ids) > 1 or self.inferences_role not in row_ids:
+             raise ValueError("eventIds contains IDs from incorrect inferences.")
+         events = self.inferences[row_ids[self.inferences_role]]
+         requested_gql_dimensions = _get_requested_features_and_tags(
+             core_dimensions=self.model.scalar_dimensions,
+             requested_dimension_names=set(dim.name for dim in dimensions)
+             if isinstance(dimensions, list)
+             else None,
+         )
+         return [
+             create_event(
+                 event_id=create_event_id(event.id.row_id, self.inferences_role),
+                 event=event,
+                 dimensions=requested_gql_dimensions,
+                 is_document_record=self.inferences_role is AncillaryInferencesRole.corpus,
+             )
+             for event in events
+         ]
+
+
+ def _get_requested_features_and_tags(
+     core_dimensions: Iterable[ScalarDimension],
+     requested_dimension_names: Optional[Set[str]] = UNSET,
+ ) -> List[Dimension]:
+     """
+     Returns requested features and tags as a list of strawberry Inferences. If no
+     dimensions are explicitly requested, returns all features and tags.
+     """
+     requested_features_and_tags: List[Dimension] = []
+     for id, dim in enumerate(core_dimensions):
+         is_requested = (
+             not isinstance(requested_dimension_names, Set)
+         ) or dim.name in requested_dimension_names
+         is_feature_or_tag = dim.role in (FEATURE, TAG)
+         if is_requested and is_feature_or_tag:
+             requested_features_and_tags.append(to_gql_dimension(id_attr=id, dimension=dim))
+     return requested_features_and_tags