arize-phoenix 5.5.2__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Further details were linked from the original page (the link was not preserved here).

Files changed (186)
  1. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +4 -7
  2. arize_phoenix-5.7.0.dist-info/RECORD +330 -0
  3. phoenix/config.py +50 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/helpers/playground_clients.py +671 -0
  72. phoenix/server/api/helpers/playground_registry.py +70 -0
  73. phoenix/server/api/helpers/playground_spans.py +325 -0
  74. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  75. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  76. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  77. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  78. phoenix/server/api/input_types/ClusterInput.py +2 -2
  79. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  80. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  81. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  82. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  83. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  84. phoenix/server/api/input_types/Granularity.py +1 -1
  85. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  86. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  87. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  88. phoenix/server/api/mutations/__init__.py +4 -0
  89. phoenix/server/api/mutations/chat_mutations.py +374 -0
  90. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  91. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  92. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  93. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  94. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  95. phoenix/server/api/mutations/user_mutations.py +4 -4
  96. phoenix/server/api/openapi/schema.py +2 -2
  97. phoenix/server/api/queries.py +61 -72
  98. phoenix/server/api/routers/oauth2.py +4 -4
  99. phoenix/server/api/routers/v1/datasets.py +22 -36
  100. phoenix/server/api/routers/v1/evaluations.py +6 -5
  101. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  102. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  103. phoenix/server/api/routers/v1/experiments.py +4 -4
  104. phoenix/server/api/routers/v1/spans.py +13 -12
  105. phoenix/server/api/routers/v1/traces.py +5 -5
  106. phoenix/server/api/routers/v1/utils.py +5 -5
  107. phoenix/server/api/schema.py +42 -10
  108. phoenix/server/api/subscriptions.py +347 -494
  109. phoenix/server/api/types/AnnotationSummary.py +3 -3
  110. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  111. phoenix/server/api/types/Cluster.py +8 -7
  112. phoenix/server/api/types/Dataset.py +5 -4
  113. phoenix/server/api/types/Dimension.py +3 -3
  114. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  115. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  116. phoenix/server/api/types/EvaluationSummary.py +3 -3
  117. phoenix/server/api/types/Event.py +7 -7
  118. phoenix/server/api/types/Experiment.py +3 -3
  119. phoenix/server/api/types/ExperimentComparison.py +2 -4
  120. phoenix/server/api/types/GenerativeProvider.py +27 -3
  121. phoenix/server/api/types/Inferences.py +9 -8
  122. phoenix/server/api/types/InferencesRole.py +2 -2
  123. phoenix/server/api/types/Model.py +2 -2
  124. phoenix/server/api/types/Project.py +11 -18
  125. phoenix/server/api/types/Segments.py +3 -3
  126. phoenix/server/api/types/Span.py +45 -7
  127. phoenix/server/api/types/TemplateLanguage.py +9 -0
  128. phoenix/server/api/types/TimeSeries.py +8 -7
  129. phoenix/server/api/types/Trace.py +2 -2
  130. phoenix/server/api/types/UMAPPoints.py +6 -6
  131. phoenix/server/api/types/User.py +3 -3
  132. phoenix/server/api/types/node.py +1 -3
  133. phoenix/server/api/types/pagination.py +4 -4
  134. phoenix/server/api/utils.py +2 -4
  135. phoenix/server/app.py +76 -37
  136. phoenix/server/bearer_auth.py +4 -10
  137. phoenix/server/dml_event.py +3 -3
  138. phoenix/server/dml_event_handler.py +10 -24
  139. phoenix/server/grpc_server.py +3 -2
  140. phoenix/server/jwt_store.py +22 -21
  141. phoenix/server/main.py +17 -4
  142. phoenix/server/oauth2.py +3 -2
  143. phoenix/server/rate_limiters.py +5 -8
  144. phoenix/server/static/.vite/manifest.json +31 -31
  145. phoenix/server/static/assets/components-Csu8UKOs.js +1612 -0
  146. phoenix/server/static/assets/{index-DCzakdJq.js → index-Bk5C9EA7.js} +2 -2
  147. phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-UeWaKXNs.js} +337 -442
  148. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  149. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  150. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  151. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  152. phoenix/server/templates/index.html +1 -0
  153. phoenix/server/thread_server.py +1 -1
  154. phoenix/server/types.py +17 -29
  155. phoenix/services.py +8 -3
  156. phoenix/session/client.py +12 -24
  157. phoenix/session/data_extractor.py +3 -3
  158. phoenix/session/evaluation.py +1 -2
  159. phoenix/session/session.py +26 -21
  160. phoenix/trace/attributes.py +16 -28
  161. phoenix/trace/dsl/filter.py +17 -21
  162. phoenix/trace/dsl/helpers.py +3 -3
  163. phoenix/trace/dsl/query.py +13 -22
  164. phoenix/trace/fixtures.py +11 -17
  165. phoenix/trace/otel.py +5 -15
  166. phoenix/trace/projects.py +3 -2
  167. phoenix/trace/schemas.py +2 -2
  168. phoenix/trace/span_evaluations.py +9 -8
  169. phoenix/trace/span_json_decoder.py +3 -3
  170. phoenix/trace/span_json_encoder.py +2 -2
  171. phoenix/trace/trace_dataset.py +6 -5
  172. phoenix/trace/utils.py +6 -6
  173. phoenix/utilities/deprecation.py +3 -2
  174. phoenix/utilities/error_handling.py +3 -2
  175. phoenix/utilities/json.py +2 -1
  176. phoenix/utilities/logging.py +2 -2
  177. phoenix/utilities/project.py +1 -1
  178. phoenix/utilities/re.py +3 -4
  179. phoenix/utilities/template_formatters.py +16 -5
  180. phoenix/version.py +1 -1
  181. arize_phoenix-5.5.2.dist-info/RECORD +0 -321
  182. phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
  183. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  184. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  185. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  186. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,8 +1,9 @@
1
1
  import json
2
+ from collections.abc import Mapping, Sized
2
3
  from dataclasses import dataclass
3
4
  from datetime import datetime
4
5
  from enum import Enum
5
- from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sized, cast
6
+ from typing import TYPE_CHECKING, Any, Optional, cast
6
7
 
7
8
  import numpy as np
8
9
  import strawberry
@@ -19,10 +20,12 @@ from phoenix.server.api.helpers.dataset_helpers import (
19
20
  get_dataset_example_input,
20
21
  get_dataset_example_output,
21
22
  )
23
+ from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
22
24
  from phoenix.server.api.input_types.SpanAnnotationSort import (
23
25
  SpanAnnotationColumn,
24
26
  SpanAnnotationSort,
25
27
  )
28
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
26
29
  from phoenix.server.api.types.SortDir import SortDir
27
30
  from phoenix.server.api.types.SpanAnnotation import to_gql_span_annotation
28
31
  from phoenix.trace.attributes import get_attribute_value
@@ -152,7 +155,7 @@ class Span(Node):
152
155
  token_count_completion: Optional[int]
153
156
  input: Optional[SpanIOValue]
154
157
  output: Optional[SpanIOValue]
155
- events: List[SpanEvent]
158
+ events: list[SpanEvent]
156
159
  cumulative_token_count_total: Optional[int] = strawberry.field(
157
160
  description="Cumulative (prompt plus completion) token count from "
158
161
  "self and all descendant spans (children, grandchildren, etc.)",
@@ -180,7 +183,7 @@ class Span(Node):
180
183
  self,
181
184
  info: Info[Context, None],
182
185
  sort: Optional[SpanAnnotationSort] = UNSET,
183
- ) -> List[SpanAnnotation]:
186
+ ) -> list[SpanAnnotation]:
184
187
  span_id = self.id_attr
185
188
  annotations = await info.context.data_loaders.span_annotations.load(span_id)
186
189
  sort_key = SpanAnnotationColumn.name.value
@@ -201,7 +204,7 @@ class Span(Node):
201
204
  "a list, and each evaluation is identified by its document's (zero-based) "
202
205
  "index in that list."
203
206
  ) # type: ignore
204
- async def document_evaluations(self, info: Info[Context, None]) -> List[DocumentEvaluation]:
207
+ async def document_evaluations(self, info: Info[Context, None]) -> list[DocumentEvaluation]:
205
208
  return await info.context.data_loaders.document_evaluations.load(self.id_attr)
206
209
 
207
210
  @strawberry.field(
@@ -211,7 +214,7 @@ class Span(Node):
211
214
  self,
212
215
  info: Info[Context, None],
213
216
  evaluation_name: Optional[str] = UNSET,
214
- ) -> List[DocumentRetrievalMetrics]:
217
+ ) -> list[DocumentRetrievalMetrics]:
215
218
  if not self.num_documents:
216
219
  return []
217
220
  return await info.context.data_loaders.document_retrieval_metrics.load(
@@ -224,7 +227,7 @@ class Span(Node):
224
227
  async def descendants(
225
228
  self,
226
229
  info: Info[Context, None],
227
- ) -> List["Span"]:
230
+ ) -> list["Span"]:
228
231
  span_id = str(self.context.span_id)
229
232
  spans = await info.context.data_loaders.span_descendants.load(span_id)
230
233
  return [to_gql_span(span) for span in spans]
@@ -290,9 +293,44 @@ class Span(Node):
290
293
  examples = await info.context.data_loaders.span_dataset_examples.load(self.id_attr)
291
294
  return bool(examples)
292
295
 
296
@strawberry.field(description="Invocation parameters for the span")  # type: ignore
async def invocation_parameters(self, info: Info[Context, None]) -> list[InvocationParameter]:
    """Return the playground-supported invocation parameters recorded on this span.

    Reads the span's LLM provider, model-name and invocation-parameters
    attributes from ``self.db_span.attributes``. It looks up the playground client
    registered for that provider/model, falling back to ``OpenAIStreamingClient``
    when none is registered. It then keeps only the client's supported
    parameters whose ``canonical_name`` or ``invocation_name`` is a key of the
    span's recorded parameters.

    Returns an empty list in three cases:
    - the span recorded no invocation parameters;
    - the recorded value is not valid JSON;
    - the recorded value does not decode to a mapping.
    """
    # Imported lazily, as in the original; presumably to avoid an import cycle
    # with the playground helpers -- TODO confirm.
    from phoenix.server.api.helpers.playground_clients import OpenAIStreamingClient
    from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY

    attributes = self.db_span.attributes
    # NOTE(review): the attribute holds whatever the instrumentor recorded
    # (likely a plain string), not necessarily a GenerativeProviderKey member.
    # Confirm that PLAYGROUND_CLIENT_REGISTRY.get_client accepts that form.
    llm_provider: GenerativeProviderKey = (
        get_attribute_value(attributes, SpanAttributes.LLM_PROVIDER)
        or GenerativeProviderKey.OPENAI
    )
    llm_model = get_attribute_value(attributes, SpanAttributes.LLM_MODEL_NAME)
    raw_parameters = get_attribute_value(attributes, SpanAttributes.LLM_INVOCATION_PARAMETERS)
    if raw_parameters is None:
        return []
    if isinstance(raw_parameters, (str, bytes, bytearray)):
        try:
            invocation_parameters = json.loads(raw_parameters)
        except ValueError:
            # JSONDecodeError (or UnicodeDecodeError for bytes): a malformed
            # attribute should not fail the whole GraphQL query.
            return []
    else:
        # Tolerate a value that was stored already decoded instead of raising
        # a TypeError from json.loads.
        invocation_parameters = raw_parameters
    if not isinstance(invocation_parameters, Mapping):
        # The membership tests below are only meaningful against parameter names.
        return []
    # Find the client class for the provider/model; default to OpenAI.
    client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(llm_provider, llm_model)
    if not client_class:
        client_class = OpenAIStreamingClient
    return [
        ip
        for ip in client_class.supported_invocation_parameters()
        if ip.canonical_name in invocation_parameters
        or ip.invocation_name in invocation_parameters
    ]
330
+
293
331
 
294
332
  def to_gql_span(span: models.Span) -> Span:
295
- events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
333
+ events: list[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
296
334
  input_value = cast(Optional[str], get_attribute_value(span.attributes, INPUT_VALUE))
297
335
  output_value = cast(Optional[str], get_attribute_value(span.attributes, OUTPUT_VALUE))
298
336
  retrieval_documents = get_attribute_value(span.attributes, RETRIEVAL_DOCUMENTS)
@@ -0,0 +1,9 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+
6
@strawberry.enum
class TemplateLanguage(Enum):
    """Template languages a prompt template may be written in.

    Exposed to the GraphQL schema as an enum via ``@strawberry.enum``; each
    member's value equals its name.
    """

    MUSTACHE = "MUSTACHE"
    F_STRING = "F_STRING"
@@ -1,7 +1,8 @@
1
+ from collections.abc import Iterable
1
2
  from dataclasses import replace
2
3
  from datetime import datetime, timedelta
3
4
  from functools import total_ordering
4
- from typing import Iterable, List, Optional, Tuple, Union, cast
5
+ from typing import Optional, Union, cast
5
6
 
6
7
  import pandas as pd
7
8
  import strawberry
@@ -39,7 +40,7 @@ def to_gql_datapoints(
39
40
  df: pd.DataFrame,
40
41
  metric: Metric,
41
42
  timestamps: Iterable[datetime],
42
- ) -> List[TimeSeriesDataPoint]:
43
+ ) -> list[TimeSeriesDataPoint]:
43
44
  data = []
44
45
  for timestamp in timestamps:
45
46
  try:
@@ -59,7 +60,7 @@ def to_gql_datapoints(
59
60
  class TimeSeries:
60
61
  """A collection of data points over time"""
61
62
 
62
- data: List[TimeSeriesDataPoint]
63
+ data: list[TimeSeriesDataPoint]
63
64
 
64
65
 
65
66
  def get_timeseries_data(
@@ -67,7 +68,7 @@ def get_timeseries_data(
67
68
  metric: Metric,
68
69
  time_range: TimeRange,
69
70
  granularity: Granularity,
70
- ) -> List[TimeSeriesDataPoint]:
71
+ ) -> list[TimeSeriesDataPoint]:
71
72
  return df.pipe(
72
73
  timeseries(
73
74
  start_time=time_range.start,
@@ -98,7 +99,7 @@ def get_data_quality_timeseries_data(
98
99
  time_range: TimeRange,
99
100
  granularity: Granularity,
100
101
  inferences_role: InferencesRole,
101
- ) -> List[TimeSeriesDataPoint]:
102
+ ) -> list[TimeSeriesDataPoint]:
102
103
  metric_instance = metric.value()
103
104
  if isinstance(metric_instance, UnaryOperator):
104
105
  metric_instance = replace(
@@ -128,7 +129,7 @@ def get_drift_timeseries_data(
128
129
  time_range: TimeRange,
129
130
  granularity: Granularity,
130
131
  reference_data: pd.DataFrame,
131
- ) -> List[TimeSeriesDataPoint]:
132
+ ) -> list[TimeSeriesDataPoint]:
132
133
  metric_instance = metric.value()
133
134
  metric_instance = replace(
134
135
  metric_instance,
@@ -163,7 +164,7 @@ def ensure_timeseries_parameters(
163
164
  inferences: Inferences,
164
165
  time_range: Optional[TimeRange] = UNSET,
165
166
  granularity: Optional[Granularity] = UNSET,
166
- ) -> Tuple[TimeRange, Granularity]:
167
+ ) -> tuple[TimeRange, Granularity]:
167
168
  if not isinstance(time_range, TimeRange):
168
169
  start, stop = inferences.time_range
169
170
  time_range = TimeRange(start=start, end=stop)
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import List, Optional
3
+ from typing import Optional
4
4
 
5
5
  import strawberry
6
6
  from sqlalchemy import desc, select
@@ -69,7 +69,7 @@ class Trace(Node):
69
69
  self,
70
70
  info: Info[Context, None],
71
71
  sort: Optional[TraceAnnotationSort] = None,
72
- ) -> List[TraceAnnotation]:
72
+ ) -> list[TraceAnnotation]:
73
73
  async with info.context.db() as session:
74
74
  stmt = select(models.TraceAnnotation).filter_by(span_rowid=self.id_attr)
75
75
  if sort:
@@ -1,4 +1,4 @@
1
- from typing import List, Union
1
+ from typing import Union
2
2
 
3
3
  import numpy as np
4
4
  import numpy.typing as npt
@@ -57,8 +57,8 @@ class UMAPPoint:
57
57
 
58
58
  @strawberry.type
59
59
  class UMAPPoints:
60
- data: List[UMAPPoint]
61
- reference_data: List[UMAPPoint]
62
- clusters: List[Cluster]
63
- corpus_data: List[UMAPPoint] = strawberry.field(default_factory=list)
64
- context_retrievals: List[Retrieval] = strawberry.field(default_factory=list)
60
+ data: list[UMAPPoint]
61
+ reference_data: list[UMAPPoint]
62
+ clusters: list[Cluster]
63
+ corpus_data: list[UMAPPoint] = strawberry.field(default_factory=list)
64
+ context_retrievals: list[Retrieval] = strawberry.field(default_factory=list)
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime
2
- from typing import List, Optional
2
+ from typing import Optional
3
3
 
4
4
  import strawberry
5
5
  from sqlalchemy import select
@@ -35,7 +35,7 @@ class User(Node):
35
35
  return to_gql_user_role(role)
36
36
 
37
37
  @strawberry.field
38
- async def api_keys(self, info: Info[Context, None]) -> List[UserApiKey]:
38
+ async def api_keys(self, info: Info[Context, None]) -> list[UserApiKey]:
39
39
  async with info.context.db() as session:
40
40
  api_keys = await session.scalars(
41
41
  select(models.ApiKey).where(models.ApiKey.user_id == self.id_attr)
@@ -43,7 +43,7 @@ class User(Node):
43
43
  return [to_gql_api_key(api_key) for api_key in api_keys]
44
44
 
45
45
 
46
- def to_gql_user(user: models.User, api_keys: Optional[List[models.ApiKey]] = None) -> User:
46
+ def to_gql_user(user: models.User, api_keys: Optional[list[models.ApiKey]] = None) -> User:
47
47
  """
48
48
  Converts an ORM user to a GraphQL user.
49
49
  """
@@ -1,9 +1,7 @@
1
- from typing import Tuple
2
-
3
1
  from strawberry.relay import GlobalID
4
2
 
5
3
 
6
- def from_global_id(global_id: GlobalID) -> Tuple[str, int]:
4
+ def from_global_id(global_id: GlobalID) -> tuple[str, int]:
7
5
  """
8
6
  Decode the given global id into a type and id.
9
7
 
@@ -2,7 +2,7 @@ import base64
2
2
  from dataclasses import dataclass
3
3
  from datetime import datetime
4
4
  from enum import Enum, auto
5
- from typing import Any, ClassVar, List, Optional, Tuple, Union
5
+ from typing import Any, ClassVar, Optional, Union
6
6
 
7
7
  from strawberry import UNSET
8
8
  from strawberry.relay.types import Connection, Edge, NodeType, PageInfo
@@ -176,7 +176,7 @@ class ConnectionArgs:
176
176
 
177
177
 
178
178
  def connection_from_list(
179
- data: List[NodeType],
179
+ data: list[NodeType],
180
180
  args: ConnectionArgs,
181
181
  ) -> Connection[NodeType]:
182
182
  """
@@ -188,7 +188,7 @@ def connection_from_list(
188
188
 
189
189
 
190
190
  def connection_from_list_slice(
191
- list_slice: List[NodeType],
191
+ list_slice: list[NodeType],
192
192
  args: ConnectionArgs,
193
193
  slice_start: int,
194
194
  list_length: int,
@@ -254,7 +254,7 @@ def connection_from_list_slice(
254
254
 
255
255
 
256
256
  def connection_from_cursors_and_nodes(
257
- cursors_and_nodes: List[Tuple[Any, NodeType]],
257
+ cursors_and_nodes: list[tuple[Any, NodeType]],
258
258
  has_previous_page: bool,
259
259
  has_next_page: bool,
260
260
  ) -> Connection[NodeType]:
@@ -1,5 +1,3 @@
1
- from typing import List
2
-
3
1
  from sqlalchemy import delete
4
2
 
5
3
  from phoenix.db import models
@@ -9,7 +7,7 @@ from phoenix.server.types import DbSessionFactory
9
7
  async def delete_projects(
10
8
  db: DbSessionFactory,
11
9
  *project_names: str,
12
- ) -> List[int]:
10
+ ) -> list[int]:
13
11
  if not project_names:
14
12
  return []
15
13
  stmt = (
@@ -24,7 +22,7 @@ async def delete_projects(
24
22
  async def delete_traces(
25
23
  db: DbSessionFactory,
26
24
  *trace_ids: str,
27
- ) -> List[int]:
25
+ ) -> list[int]:
28
26
  if not trace_ids:
29
27
  return []
30
28
  stmt = (
phoenix/server/app.py CHANGED
@@ -1,8 +1,11 @@
1
1
  import asyncio
2
2
  import contextlib
3
+ import importlib
3
4
  import json
4
5
  import logging
5
- from contextlib import AsyncExitStack
6
+ import os
7
+ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
8
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack
6
9
  from dataclasses import dataclass, field
7
10
  from datetime import datetime, timedelta, timezone
8
11
  from functools import cached_property
@@ -11,18 +14,8 @@ from types import MethodType
11
14
  from typing import (
12
15
  TYPE_CHECKING,
13
16
  Any,
14
- AsyncContextManager,
15
- AsyncIterator,
16
- Awaitable,
17
- Callable,
18
- Dict,
19
- Iterable,
20
- List,
21
17
  NamedTuple,
22
18
  Optional,
23
- Sequence,
24
- Tuple,
25
- Type,
26
19
  TypedDict,
27
20
  Union,
28
21
  cast,
@@ -49,7 +42,6 @@ from starlette.types import Scope, StatefulLifespan
49
42
  from starlette.websockets import WebSocket
50
43
  from strawberry.extensions import SchemaExtension
51
44
  from strawberry.fastapi import GraphQLRouter
52
- from strawberry.schema import BaseSchema
53
45
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
54
46
  from typing_extensions import TypeAlias
55
47
 
@@ -60,6 +52,8 @@ from phoenix.config import (
60
52
  SERVER_DIR,
61
53
  OAuth2ClientConfig,
62
54
  get_env_csrf_trusted_origins,
55
+ get_env_fastapi_middleware_paths,
56
+ get_env_gql_extension_paths,
63
57
  get_env_host,
64
58
  get_env_port,
65
59
  server_instrumentation_is_enabled,
@@ -107,7 +101,7 @@ from phoenix.server.api.routers import (
107
101
  oauth2_router,
108
102
  )
109
103
  from phoenix.server.api.routers.v1 import REST_API_VERSION
110
- from phoenix.server.api.schema import schema
104
+ from phoenix.server.api.schema import build_graphql_schema
111
105
  from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
112
106
  from phoenix.server.dml_event import DmlEvent
113
107
  from phoenix.server.dml_event_handler import DmlEventHandler
@@ -159,6 +153,28 @@ ProjectName: TypeAlias = str
159
153
  _Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]]
160
154
 
161
155
 
156
def import_object_from_file(file_path: str, object_name: str) -> Any:
    """Import an object (class or function) from a Python source file.

    The file is executed as a fresh module named ``custom_module_<hash of path>``.
    As in the stdlib recipe for importing a source file directly, the module is
    registered in ``sys.modules`` before it runs. This lets code in the file that
    looks up its own module entry work, for example ``dataclasses`` with postponed
    annotations. If execution fails, the registration is rolled back.

    Args:
        file_path: Path of the Python source file to load.
        object_name: Name of the module-level object to return.

    Returns:
        The object bound to ``object_name`` in the loaded module.

    Raises:
        ImportError: If the file does not exist, cannot be loaded or executed,
            or does not define ``object_name``. The message always has the form
            ``Could not import '<object_name>' from '<file_path>': <reason>``.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util
    import sys

    def _error(reason: object) -> ImportError:
        # Single message format for every failure, without nested re-wrapping.
        return ImportError(f"Could not import '{object_name}' from '{file_path}': {reason}")

    if not os.path.isfile(file_path):
        raise _error(f"File '{file_path}' does not exist.")
    module_name = f"custom_module_{hash(file_path)}"
    try:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
    except Exception as e:
        raise _error(e) from e
    if spec is None:
        raise _error(f"Could not load spec for '{file_path}'")
    loader = spec.loader
    if loader is None:
        raise _error(f"No loader found for '{file_path}'")
    try:
        module = importlib.util.module_from_spec(spec)
    except Exception as e:
        raise _error(e) from e
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        loader.exec_module(module)
    except Exception as e:
        # Do not leave a half-initialized module registered.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise _error(e) from e
    try:
        return getattr(module, object_name)
    except AttributeError:
        raise _error(f"Module '{file_path}' does not have an object '{object_name}'.") from None
176
+
177
+
162
178
  class OAuth2Idp(TypedDict):
163
179
  name: str
164
180
  displayName: str
@@ -175,6 +191,7 @@ class AppConfig(NamedTuple):
175
191
  web_manifest_path: Path
176
192
  authentication_enabled: bool
177
193
  """ Whether authentication is enabled """
194
+ websockets_enabled: bool
178
195
  oauth2_idps: Sequence[OAuth2Idp]
179
196
 
180
197
 
@@ -188,10 +205,10 @@ class Static(StaticFiles):
188
205
  super().__init__(**kwargs)
189
206
 
190
207
  @cached_property
191
- def _web_manifest(self) -> Dict[str, Any]:
208
+ def _web_manifest(self) -> dict[str, Any]:
192
209
  try:
193
210
  with open(self._app_config.web_manifest_path, "r") as f:
194
- return cast(Dict[str, Any], json.load(f))
211
+ return cast(dict[str, Any], json.load(f))
195
212
  except FileNotFoundError as e:
196
213
  if self._app_config.is_development:
197
214
  return {}
@@ -225,6 +242,7 @@ class Static(StaticFiles):
225
242
  "manifest": self._web_manifest,
226
243
  "authentication_enabled": self._app_config.authentication_enabled,
227
244
  "oauth2_idps": self._app_config.oauth2_idps,
245
+ "websockets_enabled": self._app_config.websockets_enabled,
228
246
  },
229
247
  )
230
248
  except Exception as e:
@@ -233,7 +251,7 @@ class Static(StaticFiles):
233
251
 
234
252
 
235
253
  class RequestOriginHostnameValidator(BaseHTTPMiddleware):
236
- def __init__(self, trusted_hostnames: List[str], *args: Any, **kwargs: Any) -> None:
254
+ def __init__(self, trusted_hostnames: list[str], *args: Any, **kwargs: Any) -> None:
237
255
  super().__init__(*args, **kwargs)
238
256
  self._trusted_hostnames = trusted_hostnames
239
257
 
@@ -265,6 +283,28 @@ class HeadersMiddleware(BaseHTTPMiddleware):
265
283
  return response
266
284
 
267
285
 
286
def user_fastapi_middlewares() -> list[Middleware]:
    """Load user-supplied HTTP middlewares configured via the environment.

    Each ``(file_path, object_name)`` pair returned by
    ``get_env_fastapi_middleware_paths()`` is loaded with
    ``import_object_from_file``. The loaded object must be a subclass of
    ``BaseHTTPMiddleware``, and it is wrapped in ``Middleware`` with no extra
    options.

    Raises:
        ImportError: If a configured object cannot be imported.
        TypeError: If a configured object is not a ``BaseHTTPMiddleware`` subclass.
    """
    middlewares: list[Middleware] = []
    for file_path, object_name in get_env_fastapi_middleware_paths():
        middleware_class = import_object_from_file(file_path, object_name)
        # Check for a class first: ``issubclass`` itself raises an unhelpful
        # TypeError when handed a function or an instance.
        if not (
            isinstance(middleware_class, type)
            and issubclass(middleware_class, BaseHTTPMiddleware)
        ):
            raise TypeError(f"{middleware_class} is not a subclass of BaseHTTPMiddleware")
        middlewares.append(Middleware(middleware_class))
    return middlewares
295
+
296
+
297
def user_gql_extensions() -> list[Union[type[SchemaExtension], SchemaExtension]]:
    """Load user-supplied Strawberry schema extensions configured via the environment.

    Each ``(file_path, object_name)`` pair returned by
    ``get_env_gql_extension_paths()`` is loaded with ``import_object_from_file``.
    The loaded object must be a subclass of ``SchemaExtension``, and the class
    itself (not an instance) is collected.

    Raises:
        ImportError: If a configured object cannot be imported.
        TypeError: If a configured object is not a ``SchemaExtension`` subclass.
    """
    extensions: list[Union[type[SchemaExtension], SchemaExtension]] = []
    for file_path, object_name in get_env_gql_extension_paths():
        extension_class = import_object_from_file(file_path, object_name)
        # Check for a class first: ``issubclass`` itself raises an unhelpful
        # TypeError when handed a function or an instance.
        if not (
            isinstance(extension_class, type)
            and issubclass(extension_class, SchemaExtension)
        ):
            raise TypeError(f"{extension_class} is not a subclass of SchemaExtension")
        extensions.append(extension_class)
    return extensions
306
+
307
+
268
308
  ProjectRowId: TypeAlias = int
269
309
 
270
310
 
@@ -278,7 +318,7 @@ DB_MUTEX: Optional[asyncio.Lock] = None
278
318
 
279
319
  def _db(
280
320
  engine: AsyncEngine, bypass_lock: bool = False
281
- ) -> Callable[[], AsyncContextManager[AsyncSession]]:
321
+ ) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
282
322
  Session = async_sessionmaker(engine, expire_on_commit=False)
283
323
 
284
324
  @contextlib.asynccontextmanager
@@ -420,7 +460,7 @@ def _lifespan(
420
460
  scaffolder_config: Optional[ScaffolderConfig] = None,
421
461
  ) -> StatefulLifespan[FastAPI]:
422
462
  @contextlib.asynccontextmanager
423
- async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]:
463
+ async def lifespan(_: FastAPI) -> AsyncIterator[dict[str, Any]]:
424
464
  for callback in startup_callbacks:
425
465
  if isinstance((res := callback()), Awaitable):
426
466
  await res
@@ -449,7 +489,7 @@ def _lifespan(
449
489
  queue_evaluation=queue_evaluation,
450
490
  )
451
491
  await stack.enter_async_context(scaffolder)
452
- if isinstance(token_store, AsyncContextManager):
492
+ if isinstance(token_store, AbstractAsyncContextManager):
453
493
  await stack.enter_async_context(token_store)
454
494
  yield {
455
495
  "event_queue": dml_event_handler,
@@ -472,7 +512,7 @@ async def check_healthz(_: Request) -> PlainTextResponse:
472
512
 
473
513
  def create_graphql_router(
474
514
  *,
475
- schema: BaseSchema,
515
+ graphql_schema: strawberry.Schema,
476
516
  db: DbSessionFactory,
477
517
  model: Model,
478
518
  export_path: Path,
@@ -576,7 +616,7 @@ def create_graphql_router(
576
616
  )
577
617
 
578
618
  return GraphQLRouter(
579
- schema,
619
+ graphql_schema,
580
620
  graphql_ide="graphiql",
581
621
  context_getter=get_context,
582
622
  include_in_schema=False,
@@ -607,7 +647,7 @@ def create_engine_and_run_migrations(
607
647
  raise PhoenixMigrationError(msg) from e
608
648
 
609
649
 
610
- def instrument_engine_if_enabled(engine: AsyncEngine) -> List[Callable[[], None]]:
650
+ def instrument_engine_if_enabled(engine: AsyncEngine) -> list[Callable[[], None]]:
611
651
  instrumentation_cleanups = []
612
652
  if server_instrumentation_is_enabled():
613
653
  from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
@@ -662,12 +702,13 @@ def create_app(
662
702
  model: Model,
663
703
  authentication_enabled: bool,
664
704
  umap_params: UMAPParameters,
705
+ enable_websockets: bool,
665
706
  corpus: Optional[Model] = None,
666
707
  debug: bool = False,
667
708
  dev: bool = False,
668
709
  read_only: bool = False,
669
710
  enable_prometheus: bool = False,
670
- initial_spans: Optional[Iterable[Union[Span, Tuple[Span, str]]]] = None,
711
+ initial_spans: Optional[Iterable[Union[Span, tuple[Span, str]]]] = None,
671
712
  initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
672
713
  serve_ui: bool = True,
673
714
  startup_callbacks: Iterable[_Callback] = (),
@@ -678,7 +719,7 @@ def create_app(
678
719
  refresh_token_expiry: Optional[timedelta] = None,
679
720
  scaffolder_config: Optional[ScaffolderConfig] = None,
680
721
  email_sender: Optional[EmailSender] = None,
681
- oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None,
722
+ oauth2_client_configs: Optional[list[OAuth2ClientConfig]] = None,
682
723
  bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
683
724
  ) -> FastAPI:
684
725
  if model.embedding_dimensions:
@@ -692,10 +733,10 @@ def create_app(
692
733
  ) from exc
693
734
  logger.info(f"Server umap params: {umap_params}")
694
735
  bulk_inserter_factory = bulk_inserter_factory or BulkInserter
695
- startup_callbacks_list: List[_Callback] = list(startup_callbacks)
696
- shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
736
+ startup_callbacks_list: list[_Callback] = list(startup_callbacks)
737
+ shutdown_callbacks_list: list[_Callback] = list(shutdown_callbacks)
697
738
  startup_callbacks_list.append(Facilitator(db=db))
698
- initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
739
+ initial_batch_of_spans: Iterable[tuple[Span, str]] = (
699
740
  ()
700
741
  if initial_spans is None
701
742
  else (
@@ -708,7 +749,8 @@ def create_app(
708
749
  CacheForDataLoaders() if db.dialect is SupportedSQLDialect.SQLITE else None
709
750
  )
710
751
  last_updated_at = LastUpdatedAt()
711
- middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
752
+ middlewares: list[Middleware] = [Middleware(HeadersMiddleware)]
753
+ middlewares.extend(user_fastapi_middlewares())
712
754
  if origins := get_env_csrf_trusted_origins():
713
755
  trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
714
756
  middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
@@ -742,8 +784,9 @@ def create_app(
742
784
  initial_batch_of_evaluations=initial_batch_of_evaluations,
743
785
  )
744
786
  tracer_provider = None
745
- strawberry_extensions: List[Union[Type[SchemaExtension], SchemaExtension]] = []
746
- strawberry_extensions.extend(schema.get_extensions())
787
+ graphql_schema_extensions: list[Union[type[SchemaExtension], SchemaExtension]] = []
788
+ graphql_schema_extensions.extend(user_gql_extensions())
789
+
747
790
  if server_instrumentation_is_enabled():
748
791
  tracer_provider = initialize_opentelemetry_tracer_provider()
749
792
  from opentelemetry.trace import TracerProvider
@@ -761,16 +804,11 @@ def create_app(
761
804
  # used by OpenInference.
762
805
  self._tracer = cast(TracerProvider, tracer_provider).get_tracer("strawberry")
763
806
 
764
- strawberry_extensions.append(_OpenTelemetryExtension)
807
+ graphql_schema_extensions.append(_OpenTelemetryExtension)
765
808
 
766
809
  graphql_router = create_graphql_router(
767
810
  db=db,
768
- schema=strawberry.Schema(
769
- query=schema.query,
770
- mutation=schema.mutation,
771
- subscription=schema.subscription,
772
- extensions=strawberry_extensions,
773
- ),
811
+ graphql_schema=build_graphql_schema(graphql_schema_extensions),
774
812
  model=model,
775
813
  corpus=corpus,
776
814
  authentication_enabled=authentication_enabled,
@@ -839,6 +877,7 @@ def create_app(
839
877
  authentication_enabled=authentication_enabled,
840
878
  web_manifest_path=web_manifest_path,
841
879
  oauth2_idps=oauth2_idps,
880
+ websockets_enabled=enable_websockets,
842
881
  ),
843
882
  ),
844
883
  name="static",
@@ -1,14 +1,8 @@
1
1
  from abc import ABC
2
+ from collections.abc import Awaitable, Callable
2
3
  from datetime import datetime, timedelta, timezone
3
4
  from functools import cached_property
4
- from typing import (
5
- Any,
6
- Awaitable,
7
- Callable,
8
- Optional,
9
- Tuple,
10
- cast,
11
- )
5
+ from typing import Any, Optional, cast
12
6
 
13
7
  import grpc
14
8
  from fastapi import HTTPException, Request, WebSocket, WebSocketException
@@ -51,7 +45,7 @@ class BearerTokenAuthBackend(HasTokenStore, AuthenticationBackend):
51
45
  async def authenticate(
52
46
  self,
53
47
  conn: HTTPConnection,
54
- ) -> Optional[Tuple[AuthCredentials, BaseUser]]:
48
+ ) -> Optional[tuple[AuthCredentials, BaseUser]]:
55
49
  if header := conn.headers.get("Authorization"):
56
50
  scheme, _, token = header.partition(" ")
57
51
  if scheme.lower() != "bearer" or not token:
@@ -143,7 +137,7 @@ async def create_access_and_refresh_tokens(
143
137
  user: OrmUser,
144
138
  access_token_expiry: timedelta,
145
139
  refresh_token_expiry: timedelta,
146
- ) -> Tuple[AccessToken, RefreshToken]:
140
+ ) -> tuple[AccessToken, RefreshToken]:
147
141
  issued_at = datetime.now(timezone.utc)
148
142
  user_id = UserId(user.id)
149
143
  user_role = UserRole(user.role.name)
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC
4
4
  from dataclasses import dataclass, field
5
- from typing import ClassVar, Tuple, Type
5
+ from typing import ClassVar
6
6
 
7
7
  from phoenix.db import models
8
8
 
@@ -14,8 +14,8 @@ class DmlEvent(ABC):
14
14
  operation, e.g. insertion, update, or deletion.
15
15
  """
16
16
 
17
- table: ClassVar[Type[models.Base]]
18
- ids: Tuple[int, ...] = field(default_factory=tuple)
17
+ table: ClassVar[type[models.Base]]
18
+ ids: tuple[int, ...] = field(default_factory=tuple)
19
19
 
20
20
  def __bool__(self) -> bool:
21
21
  return bool(self.ids)