arize-phoenix 4.5.0__py3-none-any.whl → 4.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (123)
  1. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/METADATA +16 -8
  2. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/RECORD +122 -58
  3. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +42 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datetime_utils.py +4 -0
  10. phoenix/db/bulk_inserter.py +54 -14
  11. phoenix/db/insertion/dataset.py +237 -0
  12. phoenix/db/insertion/evaluation.py +10 -10
  13. phoenix/db/insertion/helpers.py +17 -14
  14. phoenix/db/insertion/span.py +3 -3
  15. phoenix/db/migrations/types.py +29 -0
  16. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  17. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  18. phoenix/db/models.py +236 -4
  19. phoenix/experiments/__init__.py +6 -0
  20. phoenix/experiments/evaluators/__init__.py +29 -0
  21. phoenix/experiments/evaluators/base.py +153 -0
  22. phoenix/experiments/evaluators/code_evaluators.py +99 -0
  23. phoenix/experiments/evaluators/llm_evaluators.py +244 -0
  24. phoenix/experiments/evaluators/utils.py +186 -0
  25. phoenix/experiments/functions.py +757 -0
  26. phoenix/experiments/tracing.py +85 -0
  27. phoenix/experiments/types.py +753 -0
  28. phoenix/experiments/utils.py +24 -0
  29. phoenix/inferences/fixtures.py +23 -23
  30. phoenix/inferences/inferences.py +7 -7
  31. phoenix/inferences/validation.py +1 -1
  32. phoenix/server/api/context.py +20 -0
  33. phoenix/server/api/dataloaders/__init__.py +20 -0
  34. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  35. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  36. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  37. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  38. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  39. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  40. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  41. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  42. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  43. phoenix/server/api/dataloaders/span_projects.py +33 -0
  44. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  45. phoenix/server/api/helpers/dataset_helpers.py +179 -0
  46. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  47. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  48. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  49. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  50. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  51. phoenix/server/api/input_types/DatasetSort.py +17 -0
  52. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  53. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  54. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  55. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  56. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  57. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  58. phoenix/server/api/mutations/__init__.py +13 -0
  59. phoenix/server/api/mutations/auth.py +11 -0
  60. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  61. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  62. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  63. phoenix/server/api/mutations/project_mutations.py +47 -0
  64. phoenix/server/api/openapi/__init__.py +0 -0
  65. phoenix/server/api/openapi/main.py +6 -0
  66. phoenix/server/api/openapi/schema.py +16 -0
  67. phoenix/server/api/queries.py +503 -0
  68. phoenix/server/api/routers/v1/__init__.py +77 -2
  69. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  70. phoenix/server/api/routers/v1/datasets.py +965 -0
  71. phoenix/server/api/routers/v1/evaluations.py +8 -13
  72. phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
  73. phoenix/server/api/routers/v1/experiment_runs.py +220 -0
  74. phoenix/server/api/routers/v1/experiments.py +302 -0
  75. phoenix/server/api/routers/v1/spans.py +9 -5
  76. phoenix/server/api/routers/v1/traces.py +1 -4
  77. phoenix/server/api/schema.py +2 -303
  78. phoenix/server/api/types/AnnotatorKind.py +10 -0
  79. phoenix/server/api/types/Cluster.py +19 -19
  80. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  81. phoenix/server/api/types/Dataset.py +282 -63
  82. phoenix/server/api/types/DatasetExample.py +85 -0
  83. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  84. phoenix/server/api/types/DatasetVersion.py +14 -0
  85. phoenix/server/api/types/Dimension.py +30 -29
  86. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  87. phoenix/server/api/types/Event.py +16 -16
  88. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  89. phoenix/server/api/types/Experiment.py +147 -0
  90. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  91. phoenix/server/api/types/ExperimentComparison.py +19 -0
  92. phoenix/server/api/types/ExperimentRun.py +91 -0
  93. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  94. phoenix/server/api/types/Inferences.py +80 -0
  95. phoenix/server/api/types/InferencesRole.py +23 -0
  96. phoenix/server/api/types/Model.py +43 -42
  97. phoenix/server/api/types/Project.py +26 -12
  98. phoenix/server/api/types/Span.py +79 -2
  99. phoenix/server/api/types/TimeSeries.py +6 -6
  100. phoenix/server/api/types/Trace.py +15 -4
  101. phoenix/server/api/types/UMAPPoints.py +1 -1
  102. phoenix/server/api/types/node.py +5 -111
  103. phoenix/server/api/types/pagination.py +10 -52
  104. phoenix/server/app.py +103 -49
  105. phoenix/server/main.py +49 -27
  106. phoenix/server/openapi/docs.py +3 -0
  107. phoenix/server/static/index.js +2300 -1294
  108. phoenix/server/templates/index.html +1 -0
  109. phoenix/services.py +15 -15
  110. phoenix/session/client.py +581 -22
  111. phoenix/session/session.py +47 -37
  112. phoenix/trace/exporter.py +14 -9
  113. phoenix/trace/fixtures.py +133 -7
  114. phoenix/trace/schemas.py +1 -2
  115. phoenix/trace/span_evaluations.py +3 -3
  116. phoenix/trace/trace_dataset.py +6 -6
  117. phoenix/utilities/json.py +61 -0
  118. phoenix/utilities/re.py +50 -0
  119. phoenix/version.py +1 -1
  120. phoenix/server/api/types/DatasetRole.py +0 -23
  121. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/types/TimeSeries.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
 import strawberry
 from strawberry import UNSET
 
-from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dataset, Dimension
+from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dimension, Inferences
 from phoenix.metrics import Metric, binning
 from phoenix.metrics.mixins import UnaryOperator
 from phoenix.metrics.timeseries import timeseries
@@ -15,7 +15,7 @@ from phoenix.server.api.input_types.Granularity import Granularity, to_timestamp
 from phoenix.server.api.input_types.TimeRange import TimeRange
 from phoenix.server.api.interceptor import GqlValueMediator
 from phoenix.server.api.types.DataQualityMetric import DataQualityMetric
-from phoenix.server.api.types.DatasetRole import DatasetRole
+from phoenix.server.api.types.InferencesRole import InferencesRole
 from phoenix.server.api.types.ScalarDriftMetricEnum import ScalarDriftMetric
 from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric
 
@@ -97,7 +97,7 @@ def get_data_quality_timeseries_data(
     metric: DataQualityMetric,
     time_range: TimeRange,
     granularity: Granularity,
-    dataset_role: DatasetRole,
+    inferences_role: InferencesRole,
 ) -> List[TimeSeriesDataPoint]:
     metric_instance = metric.value()
     if isinstance(metric_instance, UnaryOperator):
@@ -106,7 +106,7 @@ def get_data_quality_timeseries_data(
         operand=Column(dimension.name),
     )
     df = pd.DataFrame(
-        {dimension.name: dimension[dataset_role.value]},
+        {dimension.name: dimension[inferences_role.value]},
         copy=False,
     )
     return get_timeseries_data(
@@ -160,12 +160,12 @@ class PerformanceTimeSeries(TimeSeries):
 
 
 def ensure_timeseries_parameters(
-    dataset: Dataset,
+    inferences: Inferences,
     time_range: Optional[TimeRange] = UNSET,
     granularity: Optional[Granularity] = UNSET,
 ) -> Tuple[TimeRange, Granularity]:
     if not isinstance(time_range, TimeRange):
-        start, stop = dataset.time_range
+        start, stop = inferences.time_range
         time_range = TimeRange(start=start, end=stop)
     if not isinstance(granularity, Granularity):
         total_minutes = int((time_range.end - time_range.start).total_seconds()) // 60
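
The hunks above are the API-level fallout of renaming the model-schema Dataset abstraction to Inferences (see items 94, 95, and 120 in the file list), which frees the Dataset name for the new datasets feature. A minimal sketch of the renamed call, assuming these hunks belong to phoenix/server/api/types/TimeSeries.py (item 99) and that the import paths are unchanged:

    # Hedged sketch, not part of the diff: callers now pass an Inferences
    # object (formerly model_schema.Dataset) when normalizing parameters.
    from phoenix.core.model_schema import Inferences
    from phoenix.server.api.types.TimeSeries import ensure_timeseries_parameters

    def resolve_window(inferences: Inferences):  # hypothetical caller
        # Falls back to the full time range of the inferences and a derived
        # granularity when neither argument is supplied.
        time_range, granularity = ensure_timeseries_parameters(inferences)
        return time_range, granularity
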
phoenix/server/api/types/Trace.py CHANGED
@@ -1,17 +1,18 @@
+from __future__ import annotations
+
 from typing import List, Optional
 
 import strawberry
 from sqlalchemy import desc, select
 from sqlalchemy.orm import contains_eager
-from strawberry import UNSET
+from strawberry import UNSET, Private
+from strawberry.relay import Connection, GlobalID, Node, NodeID
 from strawberry.types import Info
 
 from phoenix.db import models
 from phoenix.server.api.context import Context
 from phoenix.server.api.types.Evaluation import TraceEvaluation
-from phoenix.server.api.types.node import Node
 from phoenix.server.api.types.pagination import (
-    Connection,
     ConnectionArgs,
     CursorString,
     connection_from_list,
@@ -21,6 +22,16 @@ from phoenix.server.api.types.Span import Span, to_gql_span
 
 @strawberry.type
 class Trace(Node):
+    id_attr: NodeID[int]
+    project_rowid: Private[int]
+    trace_id: str
+
+    @strawberry.field
+    async def project_id(self) -> GlobalID:
+        from phoenix.server.api.types.Project import Project
+
+        return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
+
     @strawberry.field
     async def spans(
         self,
@@ -40,7 +51,7 @@ class Trace(Node):
             select(models.Span)
             .join(models.Trace)
             .where(models.Trace.id == self.id_attr)
-            .options(contains_eager(models.Span.trace))
+            .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
             # Sort descending because the root span tends to show up later
             # in the ingestion process.
             .order_by(desc(models.Span.id))
phoenix/server/api/types/UMAPPoints.py CHANGED
@@ -3,13 +3,13 @@ from typing import List, Union
 import numpy as np
 import numpy.typing as npt
 import strawberry
+from strawberry.relay.types import GlobalID
 from strawberry.scalars import ID
 
 from phoenix.server.api.types.Cluster import Cluster
 
 from .EmbeddingMetadata import EmbeddingMetadata
 from .EventMetadata import EventMetadata
-from .node import GlobalID
 from .Retrieval import Retrieval
 
 
phoenix/server/api/types/node.py CHANGED
@@ -1,36 +1,19 @@
-import base64
-import dataclasses
-from typing import Tuple, Union
+from typing import Tuple
 
-import strawberry
-from graphql import GraphQLID
-from strawberry.custom_scalar import ScalarDefinition
-from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
+from strawberry.relay import GlobalID
 
 
-def to_global_id(type_name: str, node_id: int) -> str:
-    """
-    Encode the given id into a global id.
-
-    :param type_name: The type of the node.
-    :param node_id: The id of the node.
-    :return: A global id.
-    """
-    return base64.b64encode(f"{type_name}:{node_id}".encode("utf-8")).decode()
-
-
-def from_global_id(global_id: str) -> Tuple[str, int]:
+def from_global_id(global_id: GlobalID) -> Tuple[str, int]:
     """
     Decode the given global id into a type and id.
 
     :param global_id: The global id to decode.
     :return: A tuple of type and id.
     """
-    type_name, node_id = base64.b64decode(global_id).decode().split(":")
-    return type_name, int(node_id)
+    return global_id.type_name, int(global_id.node_id)
 
 
-def from_global_id_with_expected_type(global_id: str, expected_type_name: str) -> int:
+def from_global_id_with_expected_type(global_id: GlobalID, expected_type_name: str) -> int:
     """
    Decodes the given global id and return the id, checking that the type
     matches the expected type.
@@ -42,92 +25,3 @@ def from_global_id_with_expected_type(global_id: str, expected_type_name: str) -
         f"but instead corresponds to a node of type: {type_name}"
     )
     return node_id
-
-
-class GlobalIDValueError(ValueError):
-    """GlobalID value error, usually related to parsing or serialization."""
-
-
-@dataclasses.dataclass(frozen=True)
-class GlobalID:
-    """Global ID for relay types.
-    Different from `strawberry.ID`, this ID wraps the original object ID in a string
-    that contains both its GraphQL type name and the ID itself, and encodes it
-    to a base64_ string.
-    This object contains helpers to work with that, including method to retrieve
-    the python object type or even the encoded node itself.
-    Attributes:
-        type_name:
-            The type name part of the id
-        node_id:
-            The node id part of the id
-    .. _base64:
-        https://en.wikipedia.org/wiki/Base64
-    """
-
-    type_name: str
-    node_id: int
-
-    def __post_init__(self) -> None:
-        if not isinstance(self.type_name, str):
-            raise GlobalIDValueError(
-                f"type_name is expected to be a string, found {self.type_name}"
-            )
-        try:
-            # node_id could be numpy.int64, hence the need for coercion
-            object.__setattr__(self, "node_id", int(self.node_id))
-        except ValueError:
-            raise GlobalIDValueError(f"node_id is expected to be an int, found {self.node_id}")
-
-    def __str__(self) -> str:
-        return to_global_id(self.type_name, self.node_id)
-
-    @classmethod
-    def from_id(cls, value: Union[str, strawberry.ID]) -> "GlobalID":
-        """Create a new GlobalID from parsing the given value.
-        Args:
-            value:
-                The value to be parsed, as a base64 string in the "TypeName:NodeID" format
-        Returns:
-            An instance of GLobalID
-        Raises:
-            GlobalIDValueError:
-                If the value is not in a GLobalID format
-        """
-        try:
-            type_name, node_id = from_global_id(value)
-        except ValueError as e:
-            raise GlobalIDValueError(str(e)) from e
-
-        return cls(type_name=type_name, node_id=node_id)
-
-
-@strawberry.interface(description="A node in the graph with a globally unique ID")
-class Node:
-    """
-    All types that are relay ready should inherit from this interface and
-    implement the following methods.
-
-    Attributes:
-        id_attr:
-            The raw id field of node. Typically a database id or index
-    """
-
-    id_attr: strawberry.Private[int]
-
-    @strawberry.field
-    def id(self) -> GlobalID:
-        return GlobalID(type(self).__name__, self.id_attr)
-
-
-# Register our GlobalID scalar
-DEFAULT_SCALAR_REGISTRY[GlobalID] = ScalarDefinition(
-    # Use the same name/description/parse_literal from GraphQLID
-    # specs expect this type to be "ID".
-    name="GlobalID",
-    description=GraphQLID.description,
-    parse_literal=lambda v, vars=None: GlobalID.from_id(GraphQLID.parse_literal(v, vars)),
-    parse_value=GlobalID.from_id,
-    serialize=str,
-    specified_by_url="https://relay.dev/graphql/objectidentification.htm",
-)
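
node.py keeps only the decoding helpers; the hand-rolled GlobalID scalar and Node interface are replaced by strawberry's built-in relay types, which use the same base64 "TypeName:NodeID" encoding. A minimal round-trip sketch, assuming only that the strawberry-graphql package is installed:

    # Hedged sketch, not part of the diff: strawberry.relay.GlobalID covers
    # what the deleted to_global_id/GlobalID helpers did.
    from strawberry.relay import GlobalID

    gid = GlobalID(type_name="Trace", node_id="5")
    encoded = str(gid)  # base64 of "Trace:5", i.e. "VHJhY2U6NQ=="
    decoded = GlobalID.from_id(encoded)
    assert (decoded.type_name, int(decoded.node_id)) == ("Trace", 5)

Note that strawberry stores node_id as a string, which is why the updated from_global_id above coerces it with int().
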
phoenix/server/api/types/pagination.py CHANGED
@@ -2,60 +2,18 @@ import base64
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum, auto
-from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar, Union
+from typing import Any, ClassVar, List, Optional, Tuple, Union
 
-import strawberry
 from strawberry import UNSET
+from strawberry.relay.types import Connection, Edge, NodeType, PageInfo
 from typing_extensions import TypeAlias, assert_never
 
 ID: TypeAlias = int
-GenericType = TypeVar("GenericType")
 CursorSortColumnValue: TypeAlias = Union[str, int, float, datetime]
 
-
-@strawberry.type
-class Connection(Generic[GenericType]):
-    """Represents a paginated relationship between two entities
-
-    This pattern is used when the relationship itself has attributes.
-    """
-
-    page_info: "PageInfo"
-    edges: List["Edge[GenericType]"]
-
-
-@strawberry.type
-class PageInfo:
-    """Pagination context to navigate objects with cursor-based pagination
-
-    Instead of classic offset pagination via `page` and `limit` parameters,
-    here we have a cursor of the last object and we fetch items starting from that one
-
-    Read more at:
-    - https://graphql.org/learn/pagination/#pagination-and-edges
-    - https://relay.dev/graphql/connections.htm
-    """
-
-    has_next_page: bool
-    has_previous_page: bool
-    start_cursor: Optional[str]
-    end_cursor: Optional[str]
-
-
 # A type alias for the connection cursor implementation
 CursorString = str
 
-
-@strawberry.type
-class Edge(Generic[GenericType]):
-    """
-    An edge may contain additional information of the relationship. This is the trivial case
-    """
-
-    node: GenericType
-    cursor: str
-
-
 # The hashing prefix for a connection cursor
 CURSOR_PREFIX = "connection:"
@@ -218,9 +176,9 @@ class ConnectionArgs:
 
 
 def connection_from_list(
-    data: List[GenericType],
+    data: List[NodeType],
     args: ConnectionArgs,
-) -> Connection[GenericType]:
+) -> Connection[NodeType]:
     """
     A simple function that accepts a list and connection arguments, and returns
     a connection object for use in GraphQL. It uses list offsets as pagination,
@@ -230,11 +188,11 @@ def connection_from_list(
 
 
 def connection_from_list_slice(
-    list_slice: List[GenericType],
+    list_slice: List[NodeType],
     args: ConnectionArgs,
     slice_start: int,
     list_length: int,
-) -> Connection[GenericType]:
+) -> Connection[NodeType]:
     """
     Given a slice (subset) of a list, returns a connection object for use in
     GraphQL.
@@ -295,12 +253,12 @@ def connection_from_list_slice(
     )
 
 
-def connections(
-    data: List[Tuple[Cursor, GenericType]],
+def connection_from_cursors_and_nodes(
+    cursors_and_nodes: List[Tuple[Any, NodeType]],
     has_previous_page: bool,
     has_next_page: bool,
-) -> Connection[GenericType]:
-    edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in data]
+) -> Connection[NodeType]:
+    edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in cursors_and_nodes]
     has_edges = len(edges) > 0
     first_edge = edges[0] if has_edges else None
     last_edge = edges[-1] if has_edges else None
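
Pagination now builds on strawberry's relay types instead of the local Connection, PageInfo, and Edge classes. A sketch of the construction pattern used by connection_from_cursors_and_nodes above; the build_connection wrapper and its string nodes are illustrative, not part of the package:

    # Hedged sketch: assemble a strawberry.relay Connection directly from
    # (cursor, node) pairs, mirroring connection_from_cursors_and_nodes.
    from typing import List, Tuple

    from strawberry.relay.types import Connection, Edge, PageInfo

    def build_connection(pairs: List[Tuple[str, str]]) -> Connection[str]:
        edges = [Edge(node=node, cursor=cursor) for cursor, node in pairs]
        return Connection(
            edges=edges,
            page_info=PageInfo(
                has_previous_page=False,  # illustrative values
                has_next_page=False,
                start_cursor=edges[0].cursor if edges else None,
                end_cursor=edges[-1].cursor if edges else None,
            ),
        )
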
phoenix/server/app.py CHANGED
@@ -33,7 +33,6 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin
 from starlette.requests import Request
 from starlette.responses import FileResponse, PlainTextResponse, Response
 from starlette.routing import Mount, Route
-from starlette.schemas import SchemaGenerator
 from starlette.staticfiles import StaticFiles
 from starlette.templating import Jinja2Templates
 from starlette.types import Scope, StatefulLifespan
@@ -57,19 +56,30 @@ from phoenix.exceptions import PhoenixMigrationError
 from phoenix.pointcloud.umap_parameters import UMAPParameters
 from phoenix.server.api.context import Context, DataLoaders
 from phoenix.server.api.dataloaders import (
+    AverageExperimentRunLatencyDataLoader,
     CacheForDataLoaders,
+    DatasetExampleRevisionsDataLoader,
+    DatasetExampleSpansDataLoader,
     DocumentEvaluationsDataLoader,
     DocumentEvaluationSummaryDataLoader,
     DocumentRetrievalMetricsDataLoader,
     EvaluationSummaryDataLoader,
+    ExperimentAnnotationSummaryDataLoader,
+    ExperimentErrorRatesDataLoader,
+    ExperimentRunCountsDataLoader,
+    ExperimentSequenceNumberDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
+    ProjectByNameDataLoader,
     RecordCountDataLoader,
     SpanDescendantsDataLoader,
     SpanEvaluationsDataLoader,
+    SpanProjectsDataLoader,
     TokenCountDataLoader,
     TraceEvaluationsDataLoader,
+    TraceRowIdsDataLoader,
 )
+from phoenix.server.api.openapi.schema import OPENAPI_SCHEMA_GENERATOR
 from phoenix.server.api.routers.v1 import V1_ROUTES
 from phoenix.server.api.schema import schema
 from phoenix.server.grpc_server import GrpcServer
@@ -84,10 +94,6 @@ logger = logging.getLogger(__name__)
 
 templates = Jinja2Templates(directory=SERVER_DIR / "templates")
 
-schemas = SchemaGenerator(
-    {"openapi": "3.0.0", "info": {"title": "ArizePhoenix API", "version": "1.0"}}
-)
-
 
 class AppConfig(NamedTuple):
     has_inferences: bool
@@ -126,6 +132,7 @@ class Static(StaticFiles):
                 "n_neighbors": self._app_config.n_neighbors,
                 "n_samples": self._app_config.n_samples,
                 "basename": request.scope.get("root_path", ""),
+                "platform_version": phoenix.__version__,
                 "request": request,
             },
         )
@@ -185,6 +192,9 @@ class GraphQLWithContext(GraphQL): # type: ignore
             export_path=self.export_path,
             streaming_last_updated_at=self.streaming_last_updated_at,
             data_loaders=DataLoaders(
+                average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(self.db),
+                dataset_example_revisions=DatasetExampleRevisionsDataLoader(self.db),
+                dataset_example_spans=DatasetExampleSpansDataLoader(self.db),
                 document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
                     self.db,
                     cache_map=self.cache_for_dataloaders.document_evaluation_summary
@@ -199,6 +209,10 @@ class GraphQLWithContext(GraphQL): # type: ignore
                     if self.cache_for_dataloaders
                     else None,
                 ),
+                experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(self.db),
+                experiment_error_rates=ExperimentErrorRatesDataLoader(self.db),
+                experiment_run_counts=ExperimentRunCountsDataLoader(self.db),
+                experiment_sequence_number=ExperimentSequenceNumberDataLoader(self.db),
                 latency_ms_quantile=LatencyMsQuantileDataLoader(
                     self.db,
                     cache_map=self.cache_for_dataloaders.latency_ms_quantile
@@ -219,6 +233,7 @@ class GraphQLWithContext(GraphQL): # type: ignore
                 ),
                 span_descendants=SpanDescendantsDataLoader(self.db),
                 span_evaluations=SpanEvaluationsDataLoader(self.db),
+                span_projects=SpanProjectsDataLoader(self.db),
                 token_counts=TokenCountDataLoader(
                     self.db,
                     cache_map=self.cache_for_dataloaders.token_count
@@ -226,6 +241,8 @@ class GraphQLWithContext(GraphQL): # type: ignore
                     else None,
                 ),
                 trace_evaluations=TraceEvaluationsDataLoader(self.db),
+                trace_row_ids=TraceRowIdsDataLoader(self.db),
+                project_by_name=ProjectByNameDataLoader(self.db),
             ),
             cache_for_dataloaders=self.cache_for_dataloaders,
             read_only=self.read_only,
@@ -272,7 +289,11 @@ def _lifespan(
 ) -> StatefulLifespan[Starlette]:
     @contextlib.asynccontextmanager
     async def lifespan(_: Starlette) -> AsyncIterator[Dict[str, Any]]:
-        async with bulk_inserter as (queue_span, queue_evaluation), GrpcServer(
+        async with bulk_inserter as (
+            queue_span,
+            queue_evaluation,
+            enqueue_operation,
+        ), GrpcServer(
             queue_span,
             disabled=read_only,
             tracer_provider=tracer_provider,
@@ -281,6 +302,7 @@ def _lifespan(
             yield {
                 "queue_span_for_bulk_insert": queue_span,
                 "queue_evaluation_for_bulk_insert": queue_evaluation,
+                "enqueue_operation": enqueue_operation,
             }
         for clean_up in clean_ups:
             clean_up()
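
The bulk inserter's context manager now yields a third callable, enqueue_operation, and the lifespan publishes it through Starlette's lifespan state. A hedged sketch of how a request handler could reach it; the handler and the operation object are hypothetical, not from the package:

    # Hedged sketch: keys yielded from the lifespan dict surface on
    # request.state, so a v1 route can queue a database insertion.
    from starlette.requests import Request
    from starlette.responses import Response

    async def create_example(request: Request) -> Response:  # hypothetical
        operation = ...  # e.g. an insertion event built from the request body
        request.state.enqueue_operation(operation)
        return Response(status_code=202)
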
@@ -293,15 +315,63 @@ async def check_healthz(_: Request) -> PlainTextResponse:
 
 
 async def openapi_schema(request: Request) -> Response:
-    return schemas.OpenAPIResponse(request=request)
+    return OPENAPI_SCHEMA_GENERATOR.OpenAPIResponse(request=request)
 
 
 async def api_docs(request: Request) -> Response:
     return get_swagger_ui_html(openapi_url="/schema", title="arize-phoenix API")
 
 
-def create_app(
+class SessionFactory:
+    def __init__(
+        self,
+        session_factory: Callable[[], AsyncContextManager[AsyncSession]],
+        dialect: str,
+    ):
+        self.session_factory = session_factory
+        self.dialect = SupportedSQLDialect(dialect)
+
+    def __call__(self) -> AsyncContextManager[AsyncSession]:
+        return self.session_factory()
+
+
+def create_engine_and_run_migrations(
     database_url: str,
+) -> AsyncEngine:
+    try:
+        return create_engine(database_url)
+    except PhoenixMigrationError as e:
+        msg = (
+            "\n\n⚠️⚠️ Phoenix failed to migrate the database to the latest version. ⚠️⚠️\n\n"
+            "The database may be in a dirty state. To resolve this, the Alembic CLI can be used\n"
+            "from the `src/phoenix/db` directory inside the Phoenix project root. From here,\n"
+            "revert any partial migrations and run `alembic stamp` to reset the migration state,\n"
+            "then try starting Phoenix again.\n\n"
+            "If issues persist, please reach out for support in the Arize community Slack:\n"
+            "https://arize-ai.slack.com\n\n"
+            "You can also refer to the Alembic documentation for more information:\n"
+            "https://alembic.sqlalchemy.org/en/latest/tutorial.html\n\n"
+            ""
+        )
+        raise PhoenixMigrationError(msg) from e
+
+
+def instrument_engine_if_enabled(engine: AsyncEngine) -> List[Callable[[], None]]:
+    instrumentation_cleanups = []
+    if server_instrumentation_is_enabled():
+        from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+
+        tracer_provider = initialize_opentelemetry_tracer_provider()
+        SQLAlchemyInstrumentor().instrument(
+            engine=engine.sync_engine,
+            tracer_provider=tracer_provider,
+        )
+        instrumentation_cleanups.append(SQLAlchemyInstrumentor().uninstrument)
+    return instrumentation_cleanups
+
+
+def create_app(
+    db: SessionFactory,
     export_path: Path,
     model: Model,
     umap_params: UMAPParameters,
@@ -311,8 +381,10 @@ def create_app(
     enable_prometheus: bool = False,
     initial_spans: Optional[Iterable[Union[Span, Tuple[Span, str]]]] = None,
     initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
+    serve_ui: bool = True,
+    clean_up_callbacks: List[Callable[[], None]] = [],
 ) -> Starlette:
-    clean_ups: List[Callable[[], None]] = []  # To be called at app shutdown.
+    clean_ups: List[Callable[[], None]] = clean_up_callbacks  # To be called at app shutdown.
     initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
         ()
         if initial_spans is None
@@ -322,28 +394,10 @@ def create_app(
         )
     )
     initial_batch_of_evaluations = () if initial_evaluations is None else initial_evaluations
-    try:
-        engine = create_engine(database_url)
-    except PhoenixMigrationError as e:
-        msg = (
-            "\n\n⚠️⚠️ Phoenix failed to migrate the database to the latest version. ⚠️⚠️\n\n"
-            "The database may be in a dirty state. To resolve this, the Alembic CLI can be used\n"
-            "from the `src/phoenix/db` directory inside the Phoenix project root. From here,\n"
-            "revert any partial migrations and run `alembic stamp` to reset the migration state,\n"
-            "then try starting Phoenix again.\n\n"
-            "If issues persist, please reach out for support in the Arize community Slack:\n"
-            "https://arize-ai.slack.com\n\n"
-            "You can also refer to the Alembic documentation for more information:\n"
-            "https://alembic.sqlalchemy.org/en/latest/tutorial.html\n\n"
-            ""
-        )
-        raise PhoenixMigrationError(msg) from e
     cache_for_dataloaders = (
-        CacheForDataLoaders()
-        if SupportedSQLDialect(engine.dialect.name) is SupportedSQLDialect.SQLITE
-        else None
+        CacheForDataLoaders() if db.dialect is SupportedSQLDialect.SQLITE else None
     )
-    db = _db(engine)
+
     bulk_inserter = BulkInserter(
         db,
         enable_prometheus=enable_prometheus,
@@ -354,16 +408,9 @@ def create_app(
     tracer_provider = None
     strawberry_extensions = schema.get_extensions()
     if server_instrumentation_is_enabled():
-        from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
         from opentelemetry.trace import TracerProvider
         from strawberry.extensions.tracing import OpenTelemetryExtension
 
-        tracer_provider = initialize_opentelemetry_tracer_provider()
-        SQLAlchemyInstrumentor().instrument(
-            engine=engine.sync_engine,
-            tracer_provider=tracer_provider,
-        )
-        clean_ups.append(SQLAlchemyInstrumentor().uninstrument)
         if TYPE_CHECKING:
             # Type-check the class before monkey-patching its private attribute.
             assert OpenTelemetryExtension._tracer
@@ -377,6 +424,7 @@ def create_app(
                 self._tracer = cast(TracerProvider, tracer_provider).get_tracer("strawberry")
 
         strawberry_extensions.append(_OpenTelemetryExtension)
+
     graphql = GraphQLWithContext(
         db=db,
         schema=strawberry.Schema(
@@ -433,21 +481,27 @@ def create_app(
             "/graphql",
             graphql,
         ),
-        Mount(
-            "/",
-            app=Static(
-                directory=SERVER_DIR / "static",
-                app_config=AppConfig(
-                    has_inferences=model.is_empty is not True,
-                    has_corpus=corpus is not None,
-                    min_dist=umap_params.min_dist,
-                    n_neighbors=umap_params.n_neighbors,
-                    n_samples=umap_params.n_samples,
+    ]
+    + (
+        [
+            Mount(
+                "/",
+                app=Static(
+                    directory=SERVER_DIR / "static",
+                    app_config=AppConfig(
+                        has_inferences=model.is_empty is not True,
+                        has_corpus=corpus is not None,
+                        min_dist=umap_params.min_dist,
+                        n_neighbors=umap_params.n_neighbors,
+                        n_samples=umap_params.n_samples,
+                    ),
                 ),
+                name="static",
             ),
-            name="static",
-        ),
-    ],
+        ]
+        if serve_ui
+        else []
+    ),
 )
 app.state.read_only = read_only
 app.state.db = db
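
Taken together, create_app no longer owns the database engine: the caller creates it, runs migrations, wires up instrumentation, and hands over a SessionFactory. A sketch of the caller-side flow implied by this diff, assuming database_url, export_path, model, and umap_params are in scope and that session_factory is a callable returning an AsyncSession context manager (like the module's private _db(engine) helper):

    # Hedged sketch of the new wiring; the names noted above are assumptions.
    from phoenix.server.app import (
        SessionFactory,
        create_app,
        create_engine_and_run_migrations,
        instrument_engine_if_enabled,
    )

    engine = create_engine_and_run_migrations(database_url)  # may raise PhoenixMigrationError
    clean_ups = instrument_engine_if_enabled(engine)
    app = create_app(
        db=SessionFactory(session_factory=session_factory, dialect=engine.dialect.name),
        export_path=export_path,
        model=model,
        umap_params=umap_params,
        serve_ui=True,  # False runs the API without mounting the static UI
        clean_up_callbacks=clean_ups,
    )
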