arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (109)
  1. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +4 -4
  2. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +108 -55
  3. phoenix/__init__.py +0 -27
  4. phoenix/config.py +21 -7
  5. phoenix/core/model.py +25 -25
  6. phoenix/core/model_schema.py +64 -62
  7. phoenix/core/model_schema_adapter.py +27 -25
  8. phoenix/datasets/__init__.py +0 -0
  9. phoenix/datasets/evaluators.py +275 -0
  10. phoenix/datasets/experiments.py +469 -0
  11. phoenix/datasets/tracing.py +66 -0
  12. phoenix/datasets/types.py +212 -0
  13. phoenix/db/bulk_inserter.py +54 -14
  14. phoenix/db/insertion/dataset.py +234 -0
  15. phoenix/db/insertion/evaluation.py +6 -6
  16. phoenix/db/insertion/helpers.py +13 -2
  17. phoenix/db/migrations/types.py +29 -0
  18. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  19. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  20. phoenix/db/models.py +230 -3
  21. phoenix/inferences/fixtures.py +23 -23
  22. phoenix/inferences/inferences.py +7 -7
  23. phoenix/inferences/validation.py +1 -1
  24. phoenix/server/api/context.py +16 -0
  25. phoenix/server/api/dataloaders/__init__.py +16 -0
  26. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  27. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  28. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  29. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  30. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  31. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  33. phoenix/server/api/dataloaders/span_projects.py +33 -0
  34. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  35. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  36. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  37. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  38. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  39. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  40. phoenix/server/api/input_types/DatasetSort.py +17 -0
  41. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  42. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  43. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  44. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  45. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  46. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  47. phoenix/server/api/mutations/__init__.py +13 -0
  48. phoenix/server/api/mutations/auth.py +11 -0
  49. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  50. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  51. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  52. phoenix/server/api/mutations/project_mutations.py +42 -0
  53. phoenix/server/api/queries.py +503 -0
  54. phoenix/server/api/routers/v1/__init__.py +77 -2
  55. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  56. phoenix/server/api/routers/v1/datasets.py +861 -0
  57. phoenix/server/api/routers/v1/evaluations.py +4 -2
  58. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  59. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  60. phoenix/server/api/routers/v1/experiments.py +174 -0
  61. phoenix/server/api/routers/v1/spans.py +3 -1
  62. phoenix/server/api/routers/v1/traces.py +1 -4
  63. phoenix/server/api/schema.py +2 -303
  64. phoenix/server/api/types/AnnotatorKind.py +10 -0
  65. phoenix/server/api/types/Cluster.py +19 -19
  66. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  67. phoenix/server/api/types/Dataset.py +282 -63
  68. phoenix/server/api/types/DatasetExample.py +85 -0
  69. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  70. phoenix/server/api/types/DatasetVersion.py +14 -0
  71. phoenix/server/api/types/Dimension.py +30 -29
  72. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  73. phoenix/server/api/types/Event.py +16 -16
  74. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  75. phoenix/server/api/types/Experiment.py +135 -0
  76. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  77. phoenix/server/api/types/ExperimentComparison.py +19 -0
  78. phoenix/server/api/types/ExperimentRun.py +91 -0
  79. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  80. phoenix/server/api/types/Inferences.py +80 -0
  81. phoenix/server/api/types/InferencesRole.py +23 -0
  82. phoenix/server/api/types/Model.py +43 -42
  83. phoenix/server/api/types/Project.py +26 -12
  84. phoenix/server/api/types/Span.py +78 -2
  85. phoenix/server/api/types/TimeSeries.py +6 -6
  86. phoenix/server/api/types/Trace.py +15 -4
  87. phoenix/server/api/types/UMAPPoints.py +1 -1
  88. phoenix/server/api/types/node.py +5 -111
  89. phoenix/server/api/types/pagination.py +10 -52
  90. phoenix/server/app.py +99 -49
  91. phoenix/server/main.py +49 -27
  92. phoenix/server/openapi/docs.py +3 -0
  93. phoenix/server/static/index.js +2246 -1368
  94. phoenix/server/templates/index.html +1 -0
  95. phoenix/services.py +15 -15
  96. phoenix/session/client.py +316 -21
  97. phoenix/session/session.py +47 -37
  98. phoenix/trace/exporter.py +14 -9
  99. phoenix/trace/fixtures.py +133 -7
  100. phoenix/trace/span_evaluations.py +3 -3
  101. phoenix/trace/trace_dataset.py +6 -6
  102. phoenix/utilities/json.py +61 -0
  103. phoenix/utilities/re.py +50 -0
  104. phoenix/version.py +1 -1
  105. phoenix/server/api/types/DatasetRole.py +0 -23
  106. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
  107. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
  108. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
  109. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/types/InferencesRole.py (new file)
@@ -0,0 +1,23 @@
+ from enum import Enum
+ from typing import Dict, Union
+
+ import strawberry
+
+ from phoenix.core.model_schema import PRIMARY, REFERENCE
+
+
+ @strawberry.enum
+ class InferencesRole(Enum):
+     primary = PRIMARY
+     reference = REFERENCE
+
+
+ class AncillaryInferencesRole(Enum):
+     corpus = "InferencesRole.CORPUS"
+
+
+ STR_TO_INFEREENCES_ROLE: Dict[str, Union[InferencesRole, AncillaryInferencesRole]] = {
+     str(InferencesRole.primary.value): InferencesRole.primary,
+     str(InferencesRole.reference.value): InferencesRole.reference,
+     str(AncillaryInferencesRole.corpus.value): AncillaryInferencesRole.corpus,
+ }
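
Note: the STR_TO_INFEREENCES_ROLE map above is keyed by the string form of each enum value, so a role that has been serialized to a string can be resolved back to its enum member. A self-contained sketch of that keying pattern, with illustrative names rather than the package's own:

    from enum import Enum
    from typing import Dict

    class Role(Enum):
        primary = "primary"
        reference = "reference"

    # Keying by str(member.value), as the map above does, lets a serialized
    # role string round-trip back to the corresponding enum member.
    STR_TO_ROLE: Dict[str, Role] = {str(role.value): role for role in Role}

    assert STR_TO_ROLE["primary"] is Role.primary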
phoenix/server/api/types/Model.py
@@ -2,6 +2,7 @@ import asyncio
  from typing import List, Optional

  import strawberry
+ from strawberry.relay import Connection
  from strawberry.types import Info
  from strawberry.unset import UNSET
  from typing_extensions import Annotated
@@ -14,12 +15,12 @@ from ..input_types.DimensionFilter import DimensionFilter
  from ..input_types.Granularity import Granularity
  from ..input_types.PerformanceMetricInput import PerformanceMetricInput
  from ..input_types.TimeRange import TimeRange
- from .Dataset import Dataset
- from .DatasetRole import AncillaryDatasetRole, DatasetRole
  from .Dimension import Dimension, to_gql_dimension
  from .EmbeddingDimension import EmbeddingDimension, to_gql_embedding_dimension
  from .ExportedFile import ExportedFile
- from .pagination import Connection, ConnectionArgs, CursorString, connection_from_list
+ from .Inferences import Inferences
+ from .InferencesRole import AncillaryInferencesRole, InferencesRole
+ from .pagination import ConnectionArgs, CursorString, connection_from_list
  from .TimeSeries import (
      PerformanceTimeSeries,
      ensure_timeseries_parameters,
@@ -57,45 +58,45 @@ class Model:
      )

      @strawberry.field
-     def primary_dataset(self, info: Info[Context, None]) -> Dataset:
-         dataset = info.context.model[PRIMARY]
-         start, stop = dataset.time_range
-         return Dataset(
+     def primary_inferences(self, info: Info[Context, None]) -> Inferences:
+         inferences = info.context.model[PRIMARY]
+         start, stop = inferences.time_range
+         return Inferences(
              start_time=start,
              end_time=stop,
-             record_count=len(dataset),
-             dataset=dataset,
-             dataset_role=DatasetRole.primary,
+             record_count=len(inferences),
+             inferences=inferences,
+             inferences_role=InferencesRole.primary,
              model=info.context.model,
          )

      @strawberry.field
-     def reference_dataset(self, info: Info[Context, None]) -> Optional[Dataset]:
-         if (dataset := info.context.model[REFERENCE]).empty:
+     def reference_inferences(self, info: Info[Context, None]) -> Optional[Inferences]:
+         if (inferences := info.context.model[REFERENCE]).empty:
              return None
-         start, stop = dataset.time_range
-         return Dataset(
+         start, stop = inferences.time_range
+         return Inferences(
              start_time=start,
              end_time=stop,
-             record_count=len(dataset),
-             dataset=dataset,
-             dataset_role=DatasetRole.reference,
+             record_count=len(inferences),
+             inferences=inferences,
+             inferences_role=InferencesRole.reference,
              model=info.context.model,
          )

      @strawberry.field
-     def corpus_dataset(self, info: Info[Context, None]) -> Optional[Dataset]:
+     def corpus_inferences(self, info: Info[Context, None]) -> Optional[Inferences]:
          if info.context.corpus is None:
              return None
-         if (dataset := info.context.corpus[PRIMARY]).empty:
+         if (inferences := info.context.corpus[PRIMARY]).empty:
              return None
-         start, stop = dataset.time_range
-         return Dataset(
+         start, stop = inferences.time_range
+         return Inferences(
              start_time=start,
              end_time=stop,
-             record_count=len(dataset),
-             dataset=dataset,
-             dataset_role=AncillaryDatasetRole.corpus,
+             record_count=len(inferences),
+             inferences=inferences,
+             inferences_role=AncillaryInferencesRole.corpus,
              model=info.context.corpus,
          )

@@ -156,24 +157,24 @@ class Model:
          info: Info[Context, None],
          metric: PerformanceMetricInput,
          time_range: Optional[TimeRange] = UNSET,
-         dataset_role: Annotated[
-             Optional[DatasetRole],
+         inferences_role: Annotated[
+             Optional[InferencesRole],
              strawberry.argument(
-                 description="The dataset (primary or reference) to query",
+                 description="The inferences (primary or reference) to query",
              ),
-         ] = DatasetRole.primary,
+         ] = InferencesRole.primary,
      ) -> Optional[float]:
-         if not isinstance(dataset_role, DatasetRole):
-             dataset_role = DatasetRole.primary
+         if not isinstance(inferences_role, InferencesRole):
+             inferences_role = InferencesRole.primary
          model = info.context.model
-         dataset = model[dataset_role.value]
+         inferences = model[inferences_role.value]
          time_range, granularity = ensure_timeseries_parameters(
-             dataset,
+             inferences,
              time_range,
          )
          metric_instance = metric.metric_instance(model)
          data = get_timeseries_data(
-             dataset,
+             inferences,
              metric_instance,
              time_range,
              granularity,
@@ -194,26 +195,26 @@ class Model:
          metric: PerformanceMetricInput,
          time_range: TimeRange,
          granularity: Granularity,
-         dataset_role: Annotated[
-             Optional[DatasetRole],
+         inferences_role: Annotated[
+             Optional[InferencesRole],
              strawberry.argument(
-                 description="The dataset (primary or reference) to query",
+                 description="The inferences (primary or reference) to query",
              ),
-         ] = DatasetRole.primary,
+         ] = InferencesRole.primary,
      ) -> PerformanceTimeSeries:
-         if not isinstance(dataset_role, DatasetRole):
-             dataset_role = DatasetRole.primary
+         if not isinstance(inferences_role, InferencesRole):
+             inferences_role = InferencesRole.primary
          model = info.context.model
-         dataset = model[dataset_role.value]
+         inferences = model[inferences_role.value]
          time_range, granularity = ensure_timeseries_parameters(
-             dataset,
+             inferences,
              time_range,
              granularity,
          )
          metric_instance = metric.metric_instance(model)
          return PerformanceTimeSeries(
              data=get_timeseries_data(
-                 dataset,
+                 inferences,
                  metric_instance,
                  time_range,
                  granularity,
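
Note: Strawberry exposes these resolvers in camelCase, so the renames are breaking for GraphQL clients: selections of primaryDataset, referenceDataset, and corpusDataset become primaryInferences, referenceInferences, and corpusInferences. A hypothetical client-side query exercising the renamed fields (the local port and /graphql path are assumptions, not taken from this diff):

    import json
    from urllib.request import Request, urlopen

    # Hypothetical query against a locally running Phoenix server.
    query = """
    query {
      model {
        primaryInferences { startTime endTime recordCount }
        referenceInferences { startTime endTime recordCount }
      }
    }
    """
    request = Request(
        "http://localhost:6006/graphql",
        data=json.dumps({"query": query}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urlopen(request) as response:
        print(json.load(response))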
phoenix/server/api/types/Project.py
@@ -1,6 +1,10 @@
  import operator
  from datetime import datetime
- from typing import Any, List, Optional
+ from typing import (
+     Any,
+     List,
+     Optional,
+ )

  import strawberry
  from aioitertools.itertools import islice
@@ -8,6 +12,7 @@ from sqlalchemy import and_, desc, distinct, select
  from sqlalchemy.orm import contains_eager
  from sqlalchemy.sql.expression import tuple_
  from strawberry import ID, UNSET
+ from strawberry.relay import Connection, Node, NodeID
  from strawberry.types import Info

  from phoenix.datetime_utils import right_open_time_range
@@ -17,13 +22,11 @@ from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
  from phoenix.server.api.input_types.TimeRange import TimeRange
  from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
  from phoenix.server.api.types.EvaluationSummary import EvaluationSummary
- from phoenix.server.api.types.node import Node
  from phoenix.server.api.types.pagination import (
-     Connection,
      Cursor,
      CursorSortColumn,
      CursorString,
-     connections,
+     connection_from_cursors_and_nodes,
  )
  from phoenix.server.api.types.SortDir import SortDir
  from phoenix.server.api.types.Span import Span, to_gql_span
@@ -31,11 +34,10 @@ from phoenix.server.api.types.Trace import Trace
  from phoenix.server.api.types.ValidationResult import ValidationResult
  from phoenix.trace.dsl import SpanFilter

- SPANS_LIMIT = 1000
-

  @strawberry.type
  class Project(Node):
+     id_attr: NodeID[int]
      name: str
      gradient_start_color: str
      gradient_end_color: str
@@ -149,7 +151,7 @@
          async with info.context.db() as session:
              if (id_attr := await session.scalar(stmt)) is None:
                  return None
-             return Trace(id_attr=id_attr)
+             return Trace(id_attr=id_attr, trace_id=trace_id, project_rowid=self.id_attr)

      @strawberry.field
      async def spans(
@@ -168,7 +170,7 @@
              select(models.Span)
              .join(models.Trace)
              .where(models.Trace.project_rowid == self.id_attr)
-             .options(contains_eager(models.Span.trace))
+             .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
          )
          if time_range:
              stmt = stmt.where(
@@ -213,7 +215,7 @@
                  first + 1  # overfetch by one to determine whether there's a next page
              )
          stmt = stmt.order_by(cursor_rowid_column)
-         data = []
+         cursors_and_nodes = []
          async with info.context.db() as session:
              span_records = await session.execute(stmt)
              async for span_record in islice(span_records, first):
@@ -230,15 +232,15 @@
                          else None
                      ),
                  )
-                 data.append((cursor, to_gql_span(span)))
+                 cursors_and_nodes.append((cursor, to_gql_span(span)))
              has_next_page = True
              try:
                  next(span_records)
              except StopIteration:
                  has_next_page = False

-         return connections(
-             data,
+         return connection_from_cursors_and_nodes(
+             cursors_and_nodes,
              has_previous_page=False,
              has_next_page=has_next_page,
          )
@@ -355,3 +357,15 @@
              is_valid=False,
              error_message=e.msg,
          )
+
+
+ def to_gql_project(project: models.Project) -> Project:
+     """
+     Converts an ORM project to a GraphQL Project.
+     """
+     return Project(
+         id_attr=project.id,
+         name=project.name,
+         gradient_start_color=project.gradient_start_color,
+         gradient_end_color=project.gradient_end_color,
+     )
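
Note: the rewritten spans resolver keeps the overfetch-by-one idiom: it asks the database for first + 1 rows, emits at most first of them, and uses the leftover row only to decide hasNextPage. A minimal, self-contained sketch of the idiom (names are illustrative):

    from itertools import islice
    from typing import Iterable, List, Tuple

    def paginate(rows: Iterable[int], first: int) -> Tuple[List[int], bool]:
        # Overfetch by one: take `first` rows for the page, then probe the
        # iterator once more to learn whether a next page exists.
        iterator = iter(rows)
        page = list(islice(iterator, first))
        has_next_page = next(iterator, None) is not None
        return page, has_next_page

    assert paginate(range(10), first=3) == ([0, 1, 2], True)
    assert paginate(range(2), first=3) == ([0, 1], False)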
phoenix/server/api/types/Span.py
@@ -1,23 +1,33 @@
  import json
+ from dataclasses import dataclass
  from datetime import datetime
  from enum import Enum
- from typing import Any, List, Mapping, Optional, Sized, cast
+ from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sized, cast

  import numpy as np
  import strawberry
  from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
  from strawberry import ID, UNSET
+ from strawberry.relay import Node, NodeID
  from strawberry.types import Info
+ from typing_extensions import Annotated

  import phoenix.trace.schemas as trace_schema
  from phoenix.db import models
  from phoenix.server.api.context import Context
+ from phoenix.server.api.helpers.dataset_helpers import (
+     get_dataset_example_input,
+     get_dataset_example_output,
+ )
  from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
  from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation
+ from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
  from phoenix.server.api.types.MimeType import MimeType
- from phoenix.server.api.types.node import Node
  from phoenix.trace.attributes import get_attribute_value

+ if TYPE_CHECKING:
+     from phoenix.server.api.types.Project import Project
+
  EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
  EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR
  INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
@@ -25,6 +35,9 @@ INPUT_VALUE = SpanAttributes.INPUT_VALUE
  LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
  LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
  LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+ LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
+ LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
+ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
  METADATA = SpanAttributes.METADATA
  OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
  OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
@@ -101,8 +114,14 @@ class SpanEvent:
      )


+ @strawberry.type
+ class SpanAsExampleRevision(ExampleRevision): ...
+
+
  @strawberry.type
  class Span(Node):
+     id_attr: NodeID[int]
+     db_span: strawberry.Private[models.Span]
      name: str
      status_code: SpanStatusCode
      status_message: str
@@ -188,6 +207,44 @@ class Span(Node):
          spans = await info.context.data_loaders.span_descendants.load(span_id)
          return [to_gql_span(span) for span in spans]

+     @strawberry.field(
+         description="The span's attributes translated into an example revision for a dataset",
+     )  # type: ignore
+     def as_example_revision(self) -> SpanAsExampleRevision:
+         db_span = self.db_span
+         attributes = db_span.attributes
+         span_io = _SpanIO(
+             span_kind=db_span.span_kind,
+             input_value=get_attribute_value(attributes, INPUT_VALUE),
+             input_mime_type=get_attribute_value(attributes, INPUT_MIME_TYPE),
+             output_value=get_attribute_value(attributes, OUTPUT_VALUE),
+             output_mime_type=get_attribute_value(attributes, OUTPUT_MIME_TYPE),
+             llm_prompt_template_variables=get_attribute_value(
+                 attributes, LLM_PROMPT_TEMPLATE_VARIABLES
+             ),
+             llm_input_messages=get_attribute_value(attributes, LLM_INPUT_MESSAGES),
+             llm_output_messages=get_attribute_value(attributes, LLM_OUTPUT_MESSAGES),
+             retrieval_documents=get_attribute_value(attributes, RETRIEVAL_DOCUMENTS),
+         )
+         return SpanAsExampleRevision(
+             input=get_dataset_example_input(span_io),
+             output=get_dataset_example_output(span_io),
+             metadata=attributes,
+         )
+
+     @strawberry.field(description="The project that this span belongs to.")  # type: ignore
+     async def project(
+         self,
+         info: Info[Context, None],
+     ) -> Annotated[
+         "Project", strawberry.lazy("phoenix.server.api.types.Project")
+     ]:  # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
+         from phoenix.server.api.types.Project import to_gql_project
+
+         span_id = self.id_attr
+         project = await info.context.data_loaders.span_projects.load(span_id)
+         return to_gql_project(project)
+

  def to_gql_span(span: models.Span) -> Span:
      events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
@@ -197,6 +254,7 @@ def to_gql_span(span: models.Span) -> Span:
      num_documents = len(retrieval_documents) if isinstance(retrieval_documents, Sized) else None
      return Span(
          id_attr=span.id,
+         db_span=span,
          name=span.name,
          status_code=SpanStatusCode(span.status_code),
          status_message=span.status_message,
@@ -302,3 +360,21 @@ def _convert_metadata_to_string(metadata: Any) -> Optional[str]:
          return json.dumps(metadata)
      except Exception:
          return str(metadata)
+
+
+ @dataclass
+ class _SpanIO:
+     """
+     A class that contains the information needed to extract dataset example
+     input and output values from a span.
+     """
+
+     span_kind: Optional[str]
+     input_value: Any
+     input_mime_type: Optional[str]
+     output_value: Any
+     output_mime_type: Optional[str]
+     llm_prompt_template_variables: Any
+     llm_input_messages: Any
+     llm_output_messages: Any
+     retrieval_documents: Any
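
Note: the new project field uses strawberry.lazy so that Span and Project can reference each other without a circular import; the annotation names the target module as a string and Strawberry resolves it at schema-build time. A minimal sketch of the pattern under the assumption of two mutually referencing modules (module names are illustrative):

    # span_module.py
    from typing import TYPE_CHECKING, Annotated

    import strawberry

    if TYPE_CHECKING:
        from project_module import Project  # for type checkers only; never runs

    @strawberry.type
    class Span:
        @strawberry.field
        def project(self) -> Annotated["Project", strawberry.lazy("project_module")]:
            # Importing inside the resolver defers the import to call time,
            # breaking the module-level cycle.
            from project_module import Project

            return Project()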
phoenix/server/api/types/TimeSeries.py
@@ -7,7 +7,7 @@ import pandas as pd
  import strawberry
  from strawberry import UNSET

- from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dataset, Dimension
+ from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dimension, Inferences
  from phoenix.metrics import Metric, binning
  from phoenix.metrics.mixins import UnaryOperator
  from phoenix.metrics.timeseries import timeseries
@@ -15,7 +15,7 @@ from phoenix.server.api.input_types.Granularity import Granularity, to_timestamp
  from phoenix.server.api.input_types.TimeRange import TimeRange
  from phoenix.server.api.interceptor import GqlValueMediator
  from phoenix.server.api.types.DataQualityMetric import DataQualityMetric
- from phoenix.server.api.types.DatasetRole import DatasetRole
+ from phoenix.server.api.types.InferencesRole import InferencesRole
  from phoenix.server.api.types.ScalarDriftMetricEnum import ScalarDriftMetric
  from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric

@@ -97,7 +97,7 @@ def get_data_quality_timeseries_data(
      metric: DataQualityMetric,
      time_range: TimeRange,
      granularity: Granularity,
-     dataset_role: DatasetRole,
+     inferences_role: InferencesRole,
  ) -> List[TimeSeriesDataPoint]:
      metric_instance = metric.value()
      if isinstance(metric_instance, UnaryOperator):
@@ -106,7 +106,7 @@ def get_data_quality_timeseries_data(
          operand=Column(dimension.name),
      )
      df = pd.DataFrame(
-         {dimension.name: dimension[dataset_role.value]},
+         {dimension.name: dimension[inferences_role.value]},
          copy=False,
      )
      return get_timeseries_data(
@@ -160,12 +160,12 @@ class PerformanceTimeSeries(TimeSeries):


  def ensure_timeseries_parameters(
-     dataset: Dataset,
+     inferences: Inferences,
      time_range: Optional[TimeRange] = UNSET,
      granularity: Optional[Granularity] = UNSET,
  ) -> Tuple[TimeRange, Granularity]:
      if not isinstance(time_range, TimeRange):
-         start, stop = dataset.time_range
+         start, stop = inferences.time_range
          time_range = TimeRange(start=start, end=stop)
      if not isinstance(granularity, Granularity):
          total_minutes = int((time_range.end - time_range.start).total_seconds()) // 60
phoenix/server/api/types/Trace.py
@@ -1,17 +1,18 @@
+ from __future__ import annotations
+
  from typing import List, Optional

  import strawberry
  from sqlalchemy import desc, select
  from sqlalchemy.orm import contains_eager
- from strawberry import UNSET
+ from strawberry import UNSET, Private
+ from strawberry.relay import Connection, GlobalID, Node, NodeID
  from strawberry.types import Info

  from phoenix.db import models
  from phoenix.server.api.context import Context
  from phoenix.server.api.types.Evaluation import TraceEvaluation
- from phoenix.server.api.types.node import Node
  from phoenix.server.api.types.pagination import (
-     Connection,
      ConnectionArgs,
      CursorString,
      connection_from_list,
@@ -21,6 +22,16 @@ from phoenix.server.api.types.Span import Span, to_gql_span

  @strawberry.type
  class Trace(Node):
+     id_attr: NodeID[int]
+     project_rowid: Private[int]
+     trace_id: str
+
+     @strawberry.field
+     async def project_id(self) -> GlobalID:
+         from phoenix.server.api.types.Project import Project
+
+         return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
+
      @strawberry.field
      async def spans(
          self,
@@ -40,7 +51,7 @@ class Trace(Node):
          select(models.Span)
          .join(models.Trace)
          .where(models.Trace.id == self.id_attr)
-         .options(contains_eager(models.Span.trace))
+         .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
          # Sort descending because the root span tends to show up later
          # in the ingestion process.
          .order_by(desc(models.Span.id))
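
Note: both span queries now chain .load_only(models.Trace.trace_id) onto contains_eager, so the Trace rows materialized by the JOIN are populated with only the trace_id column rather than every mapped attribute. A generic SQLAlchemy 2.0 sketch of the pattern (models are illustrative, not the package's):

    from sqlalchemy import ForeignKey, select
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        contains_eager,
        mapped_column,
        relationship,
    )

    class Base(DeclarativeBase): ...

    class Trace(Base):
        __tablename__ = "traces"
        id: Mapped[int] = mapped_column(primary_key=True)
        trace_id: Mapped[str]

    class Span(Base):
        __tablename__ = "spans"
        id: Mapped[int] = mapped_column(primary_key=True)
        trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id"))
        trace: Mapped[Trace] = relationship()

    # The JOIN already brings back the trace row; contains_eager reuses it,
    # and load_only restricts the eagerly loaded columns to trace_id.
    stmt = (
        select(Span)
        .join(Span.trace)
        .options(contains_eager(Span.trace).load_only(Trace.trace_id))
    )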
phoenix/server/api/types/UMAPPoints.py
@@ -3,13 +3,13 @@ from typing import List, Union
  import numpy as np
  import numpy.typing as npt
  import strawberry
+ from strawberry.relay.types import GlobalID
  from strawberry.scalars import ID

  from phoenix.server.api.types.Cluster import Cluster

  from .EmbeddingMetadata import EmbeddingMetadata
  from .EventMetadata import EventMetadata
- from .node import GlobalID
  from .Retrieval import Retrieval

phoenix/server/api/types/node.py
@@ -1,36 +1,19 @@
- import base64
- import dataclasses
- from typing import Tuple, Union
+ from typing import Tuple

- import strawberry
- from graphql import GraphQLID
- from strawberry.custom_scalar import ScalarDefinition
- from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
+ from strawberry.relay import GlobalID


- def to_global_id(type_name: str, node_id: int) -> str:
-     """
-     Encode the given id into a global id.
-
-     :param type_name: The type of the node.
-     :param node_id: The id of the node.
-     :return: A global id.
-     """
-     return base64.b64encode(f"{type_name}:{node_id}".encode("utf-8")).decode()
-
-
- def from_global_id(global_id: str) -> Tuple[str, int]:
+ def from_global_id(global_id: GlobalID) -> Tuple[str, int]:
      """
      Decode the given global id into a type and id.

      :param global_id: The global id to decode.
      :return: A tuple of type and id.
      """
-     type_name, node_id = base64.b64decode(global_id).decode().split(":")
-     return type_name, int(node_id)
+     return global_id.type_name, int(global_id.node_id)


- def from_global_id_with_expected_type(global_id: str, expected_type_name: str) -> int:
+ def from_global_id_with_expected_type(global_id: GlobalID, expected_type_name: str) -> int:
      """
      Decodes the given global id and return the id, checking that the type
      matches the expected type.
@@ -42,92 +25,3 @@ def from_global_id_with_expected_type(global_id: str, expected_type_name: str) -> int:
          f"but instead corresponds to a node of type: {type_name}"
      )
      return node_id
-
-
- class GlobalIDValueError(ValueError):
-     """GlobalID value error, usually related to parsing or serialization."""
-
-
- @dataclasses.dataclass(frozen=True)
- class GlobalID:
-     """Global ID for relay types.
-     Different from `strawberry.ID`, this ID wraps the original object ID in a string
-     that contains both its GraphQL type name and the ID itself, and encodes it
-     to a base64_ string.
-     This object contains helpers to work with that, including method to retrieve
-     the python object type or even the encoded node itself.
-     Attributes:
-         type_name:
-             The type name part of the id
-         node_id:
-             The node id part of the id
-     .. _base64:
-         https://en.wikipedia.org/wiki/Base64
-     """
-
-     type_name: str
-     node_id: int
-
-     def __post_init__(self) -> None:
-         if not isinstance(self.type_name, str):
-             raise GlobalIDValueError(
-                 f"type_name is expected to be a string, found {self.type_name}"
-             )
-         try:
-             # node_id could be numpy.int64, hence the need for coercion
-             object.__setattr__(self, "node_id", int(self.node_id))
-         except ValueError:
-             raise GlobalIDValueError(f"node_id is expected to be an int, found {self.node_id}")
-
-     def __str__(self) -> str:
-         return to_global_id(self.type_name, self.node_id)
-
-     @classmethod
-     def from_id(cls, value: Union[str, strawberry.ID]) -> "GlobalID":
-         """Create a new GlobalID from parsing the given value.
-         Args:
-             value:
-                 The value to be parsed, as a base64 string in the "TypeName:NodeID" format
-         Returns:
-             An instance of GlobalID
-         Raises:
-             GlobalIDValueError:
-                 If the value is not in a GlobalID format
-         """
-         try:
-             type_name, node_id = from_global_id(value)
-         except ValueError as e:
-             raise GlobalIDValueError(str(e)) from e
-
-         return cls(type_name=type_name, node_id=node_id)
-
-
- @strawberry.interface(description="A node in the graph with a globally unique ID")
- class Node:
-     """
-     All types that are relay ready should inherit from this interface and
-     implement the following methods.
-
-     Attributes:
-         id_attr:
-             The raw id field of node. Typically a database id or index
-     """
-
-     id_attr: strawberry.Private[int]
-
-     @strawberry.field
-     def id(self) -> GlobalID:
-         return GlobalID(type(self).__name__, self.id_attr)
-
-
- # Register our GlobalID scalar
- DEFAULT_SCALAR_REGISTRY[GlobalID] = ScalarDefinition(
-     # Use the same name/description/parse_literal from GraphQLID
-     # specs expect this type to be "ID".
-     name="GlobalID",
-     description=GraphQLID.description,
-     parse_literal=lambda v, vars=None: GlobalID.from_id(GraphQLID.parse_literal(v, vars)),
-     parse_value=GlobalID.from_id,
-     serialize=str,
-     specified_by_url="https://relay.dev/graphql/objectidentification.htm",
- )
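
Note: strawberry.relay.GlobalID provides what the deleted helpers implemented by hand: it pairs a GraphQL type name with a node id and serializes as a base64-encoded "TypeName:NodeID" string. A short round trip showing what the slimmed-down from_global_id now consumes:

    from strawberry.relay import GlobalID

    # str(GlobalID(...)) yields base64("TypeName:NodeID"), the same encoding
    # the deleted to_global_id helper produced.
    gid = GlobalID(type_name="Trace", node_id="7")
    encoded = str(gid)  # base64 of "Trace:7"

    decoded = GlobalID.from_id(encoded)
    assert (decoded.type_name, int(decoded.node_id)) == ("Trace", 7)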