arize-phoenix 4.14.1__py3-none-any.whl → 4.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (85)
  1. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/METADATA +5 -3
  2. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/RECORD +81 -71
  3. phoenix/db/bulk_inserter.py +131 -5
  4. phoenix/db/engines.py +2 -1
  5. phoenix/db/helpers.py +23 -1
  6. phoenix/db/insertion/constants.py +2 -0
  7. phoenix/db/insertion/document_annotation.py +157 -0
  8. phoenix/db/insertion/helpers.py +13 -0
  9. phoenix/db/insertion/span_annotation.py +144 -0
  10. phoenix/db/insertion/trace_annotation.py +144 -0
  11. phoenix/db/insertion/types.py +261 -0
  12. phoenix/experiments/functions.py +3 -2
  13. phoenix/experiments/types.py +3 -3
  14. phoenix/server/api/context.py +7 -9
  15. phoenix/server/api/dataloaders/__init__.py +2 -0
  16. phoenix/server/api/dataloaders/average_experiment_run_latency.py +3 -3
  17. phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
  18. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
  19. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
  20. phoenix/server/api/dataloaders/document_evaluations.py +2 -4
  21. phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
  22. phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
  23. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
  24. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -4
  25. phoenix/server/api/dataloaders/experiment_run_counts.py +2 -4
  26. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
  27. phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
  28. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
  29. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  30. phoenix/server/api/dataloaders/record_counts.py +2 -4
  31. phoenix/server/api/dataloaders/span_annotations.py +2 -4
  32. phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -4
  34. phoenix/server/api/dataloaders/span_evaluations.py +2 -4
  35. phoenix/server/api/dataloaders/span_projects.py +3 -3
  36. phoenix/server/api/dataloaders/token_counts.py +2 -4
  37. phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
  38. phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
  39. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  40. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  41. phoenix/server/api/mutations/span_annotations_mutations.py +8 -3
  42. phoenix/server/api/mutations/trace_annotations_mutations.py +8 -3
  43. phoenix/server/api/openapi/main.py +18 -2
  44. phoenix/server/api/openapi/schema.py +12 -12
  45. phoenix/server/api/routers/v1/__init__.py +36 -83
  46. phoenix/server/api/routers/v1/datasets.py +515 -509
  47. phoenix/server/api/routers/v1/evaluations.py +164 -73
  48. phoenix/server/api/routers/v1/experiment_evaluations.py +68 -91
  49. phoenix/server/api/routers/v1/experiment_runs.py +98 -155
  50. phoenix/server/api/routers/v1/experiments.py +132 -181
  51. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  52. phoenix/server/api/routers/v1/spans.py +164 -203
  53. phoenix/server/api/routers/v1/traces.py +134 -159
  54. phoenix/server/api/routers/v1/utils.py +95 -0
  55. phoenix/server/api/types/Span.py +27 -3
  56. phoenix/server/api/types/Trace.py +21 -4
  57. phoenix/server/api/utils.py +4 -4
  58. phoenix/server/app.py +172 -192
  59. phoenix/server/grpc_server.py +2 -2
  60. phoenix/server/main.py +5 -9
  61. phoenix/server/static/.vite/manifest.json +31 -31
  62. phoenix/server/static/assets/components-Ci5kMOk5.js +1175 -0
  63. phoenix/server/static/assets/{index-CQgXRwU0.js → index-BQG5WVX7.js} +2 -2
  64. phoenix/server/static/assets/{pages-hdjlFZhO.js → pages-BrevprVW.js} +451 -275
  65. phoenix/server/static/assets/{vendor-DPvSDRn3.js → vendor-CP0b0YG0.js} +2 -2
  66. phoenix/server/static/assets/{vendor-arizeai-CkvPT67c.js → vendor-arizeai-DTbiPGp6.js} +27 -27
  67. phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
  68. phoenix/server/static/assets/{vendor-recharts-5jlNaZuF.js → vendor-recharts-A0DA1O99.js} +1 -1
  69. phoenix/server/thread_server.py +2 -2
  70. phoenix/server/types.py +18 -0
  71. phoenix/session/client.py +5 -3
  72. phoenix/session/session.py +2 -2
  73. phoenix/trace/dsl/filter.py +2 -6
  74. phoenix/trace/fixtures.py +17 -23
  75. phoenix/trace/utils.py +23 -0
  76. phoenix/utilities/client.py +116 -0
  77. phoenix/utilities/project.py +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  80. phoenix/server/openapi/docs.py +0 -221
  81. phoenix/server/static/assets/components-DeS0YEmv.js +0 -1142
  82. phoenix/server/static/assets/vendor-codemirror-Cqwpwlua.js +0 -12
  83. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/WHEEL +0 -0
  84. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/IP_NOTICE +0 -0
  85. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py CHANGED
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import contextlib
2
3
  import json
3
4
  import logging
@@ -21,25 +22,24 @@ from typing import (
21
22
  )
22
23
 
23
24
  import strawberry
25
+ from fastapi import APIRouter, FastAPI
26
+ from fastapi.middleware.gzip import GZipMiddleware
27
+ from fastapi.responses import FileResponse
28
+ from fastapi.utils import is_body_allowed_for_status_code
24
29
  from sqlalchemy.ext.asyncio import (
25
30
  AsyncEngine,
26
31
  AsyncSession,
27
32
  async_sessionmaker,
28
33
  )
29
- from starlette.applications import Starlette
30
- from starlette.datastructures import QueryParams
31
- from starlette.endpoints import HTTPEndpoint
32
34
  from starlette.exceptions import HTTPException
33
35
  from starlette.middleware import Middleware
34
36
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
35
37
  from starlette.requests import Request
36
- from starlette.responses import FileResponse, PlainTextResponse, Response
37
- from starlette.routing import Mount, Route
38
+ from starlette.responses import PlainTextResponse, Response
38
39
  from starlette.staticfiles import StaticFiles
39
40
  from starlette.templating import Jinja2Templates
40
41
  from starlette.types import Scope, StatefulLifespan
41
- from starlette.websockets import WebSocket
42
- from strawberry.asgi import GraphQL
42
+ from strawberry.fastapi import GraphQLRouter
43
43
  from strawberry.schema import BaseSchema
44
44
  from typing_extensions import TypeAlias
45
45
 
@@ -75,6 +75,7 @@ from phoenix.server.api.dataloaders import (
75
75
  ProjectByNameDataLoader,
76
76
  RecordCountDataLoader,
77
77
  SpanAnnotationsDataLoader,
78
+ SpanDatasetExamplesDataLoader,
78
79
  SpanDescendantsDataLoader,
79
80
  SpanEvaluationsDataLoader,
80
81
  SpanProjectsDataLoader,
@@ -82,19 +83,22 @@ from phoenix.server.api.dataloaders import (
82
83
  TraceEvaluationsDataLoader,
83
84
  TraceRowIdsDataLoader,
84
85
  )
85
- from phoenix.server.api.openapi.schema import OPENAPI_SCHEMA_GENERATOR
86
- from phoenix.server.api.routers.v1 import V1_ROUTES
86
+ from phoenix.server.api.routers.v1 import REST_API_VERSION
87
+ from phoenix.server.api.routers.v1 import router as v1_router
87
88
  from phoenix.server.api.schema import schema
88
89
  from phoenix.server.grpc_server import GrpcServer
89
- from phoenix.server.openapi.docs import get_swagger_ui_html
90
90
  from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
91
+ from phoenix.server.types import DbSessionFactory
91
92
  from phoenix.trace.schemas import Span
93
+ from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER
92
94
 
93
95
  if TYPE_CHECKING:
94
96
  from opentelemetry.trace import TracerProvider
95
97
 
96
98
  logger = logging.getLogger(__name__)
97
99
 
100
+ router = APIRouter(include_in_schema=False)
101
+
98
102
  templates = Jinja2Templates(directory=SERVER_DIR / "templates")
99
103
 
100
104
 
@@ -167,125 +171,35 @@ class HeadersMiddleware(BaseHTTPMiddleware):
167
171
  request: Request,
168
172
  call_next: RequestResponseEndpoint,
169
173
  ) -> Response:
174
+ from phoenix import __version__ as phoenix_version
175
+
170
176
  response = await call_next(request)
171
177
  response.headers["x-colab-notebook-cache-control"] = "no-cache"
178
+ response.headers[PHOENIX_SERVER_VERSION_HEADER] = phoenix_version
172
179
  return response
173
180
 
174
181
 
175
182
  ProjectRowId: TypeAlias = int
176
183
 
177
184
 
178
- class GraphQLWithContext(GraphQL): # type: ignore
179
- def __init__(
180
- self,
181
- schema: BaseSchema,
182
- db: Callable[[], AsyncContextManager[AsyncSession]],
183
- model: Model,
184
- export_path: Path,
185
- graphiql: bool = False,
186
- corpus: Optional[Model] = None,
187
- streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None,
188
- cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
189
- read_only: bool = False,
190
- ) -> None:
191
- self.db = db
192
- self.model = model
193
- self.corpus = corpus
194
- self.export_path = export_path
195
- self.streaming_last_updated_at = streaming_last_updated_at
196
- self.cache_for_dataloaders = cache_for_dataloaders
197
- self.read_only = read_only
198
- super().__init__(schema, graphiql=graphiql)
199
-
200
- async def get_context(
201
- self,
202
- request: Union[Request, WebSocket],
203
- response: Optional[Response] = None,
204
- ) -> Context:
205
- return Context(
206
- request=request,
207
- response=response,
208
- db=self.db,
209
- model=self.model,
210
- corpus=self.corpus,
211
- export_path=self.export_path,
212
- streaming_last_updated_at=self.streaming_last_updated_at,
213
- data_loaders=DataLoaders(
214
- average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(self.db),
215
- dataset_example_revisions=DatasetExampleRevisionsDataLoader(self.db),
216
- dataset_example_spans=DatasetExampleSpansDataLoader(self.db),
217
- document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
218
- self.db,
219
- cache_map=self.cache_for_dataloaders.document_evaluation_summary
220
- if self.cache_for_dataloaders
221
- else None,
222
- ),
223
- document_evaluations=DocumentEvaluationsDataLoader(self.db),
224
- document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(self.db),
225
- evaluation_summaries=EvaluationSummaryDataLoader(
226
- self.db,
227
- cache_map=self.cache_for_dataloaders.evaluation_summary
228
- if self.cache_for_dataloaders
229
- else None,
230
- ),
231
- experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(self.db),
232
- experiment_error_rates=ExperimentErrorRatesDataLoader(self.db),
233
- experiment_run_counts=ExperimentRunCountsDataLoader(self.db),
234
- experiment_sequence_number=ExperimentSequenceNumberDataLoader(self.db),
235
- latency_ms_quantile=LatencyMsQuantileDataLoader(
236
- self.db,
237
- cache_map=self.cache_for_dataloaders.latency_ms_quantile
238
- if self.cache_for_dataloaders
239
- else None,
240
- ),
241
- min_start_or_max_end_times=MinStartOrMaxEndTimeDataLoader(
242
- self.db,
243
- cache_map=self.cache_for_dataloaders.min_start_or_max_end_time
244
- if self.cache_for_dataloaders
245
- else None,
246
- ),
247
- record_counts=RecordCountDataLoader(
248
- self.db,
249
- cache_map=self.cache_for_dataloaders.record_count
250
- if self.cache_for_dataloaders
251
- else None,
252
- ),
253
- span_descendants=SpanDescendantsDataLoader(self.db),
254
- span_evaluations=SpanEvaluationsDataLoader(self.db),
255
- span_projects=SpanProjectsDataLoader(self.db),
256
- token_counts=TokenCountDataLoader(
257
- self.db,
258
- cache_map=self.cache_for_dataloaders.token_count
259
- if self.cache_for_dataloaders
260
- else None,
261
- ),
262
- trace_evaluations=TraceEvaluationsDataLoader(self.db),
263
- trace_row_ids=TraceRowIdsDataLoader(self.db),
264
- project_by_name=ProjectByNameDataLoader(self.db),
265
- span_annotations=SpanAnnotationsDataLoader(self.db),
266
- ),
267
- cache_for_dataloaders=self.cache_for_dataloaders,
268
- read_only=self.read_only,
269
- )
185
+ @router.get("/exports")
186
+ async def download_exported_file(request: Request, filename: str) -> FileResponse:
187
+ file = request.app.state.export_path / (filename + ".parquet")
188
+ if not file.is_file():
189
+ raise HTTPException(status_code=404)
190
+ return FileResponse(
191
+ path=file,
192
+ filename=file.name,
193
+ media_type="application/x-octet-stream",
194
+ )
270
195
 
271
196
 
272
- class Download(HTTPEndpoint):
273
- path: Path
197
+ @router.get("/arize_phoenix_version")
198
+ async def version() -> PlainTextResponse:
199
+ return PlainTextResponse(f"{phoenix.__version__}")
274
200
 
275
- async def get(self, request: Request) -> FileResponse:
276
- params = QueryParams(request.query_params)
277
- file = self.path / (params.get("filename", "") + ".parquet")
278
- if not file.is_file():
279
- raise HTTPException(status_code=404)
280
- return FileResponse(
281
- path=file,
282
- filename=file.name,
283
- media_type="application/x-octet-stream",
284
- )
285
201
 
286
-
287
- async def version(_: Request) -> PlainTextResponse:
288
- return PlainTextResponse(f"{phoenix.__version__}")
202
+ DB_MUTEX: Optional[asyncio.Lock] = None
289
203
 
290
204
 
291
205
  def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]:
@@ -293,23 +207,29 @@ def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]:
293
207
 
294
208
  @contextlib.asynccontextmanager
295
209
  async def factory() -> AsyncIterator[AsyncSession]:
296
- async with Session.begin() as session:
297
- yield session
210
+ async with contextlib.AsyncExitStack() as stack:
211
+ if DB_MUTEX:
212
+ await stack.enter_async_context(DB_MUTEX)
213
+ yield await stack.enter_async_context(Session.begin())
298
214
 
299
215
  return factory
300
216
 
301
217
 
302
218
  def _lifespan(
303
219
  *,
220
+ dialect: SupportedSQLDialect,
304
221
  bulk_inserter: BulkInserter,
305
222
  tracer_provider: Optional["TracerProvider"] = None,
306
223
  enable_prometheus: bool = False,
307
224
  clean_ups: Iterable[Callable[[], None]] = (),
308
225
  read_only: bool = False,
309
- ) -> StatefulLifespan[Starlette]:
226
+ ) -> StatefulLifespan[FastAPI]:
310
227
  @contextlib.asynccontextmanager
311
- async def lifespan(_: Starlette) -> AsyncIterator[Dict[str, Any]]:
228
+ async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]:
229
+ global DB_MUTEX
230
+ DB_MUTEX = asyncio.Lock() if dialect is SupportedSQLDialect.SQLITE else None
312
231
  async with bulk_inserter as (
232
+ enqueue,
313
233
  queue_span,
314
234
  queue_evaluation,
315
235
  enqueue_operation,
@@ -320,6 +240,7 @@ def _lifespan(
320
240
  enable_prometheus=enable_prometheus,
321
241
  ):
322
242
  yield {
243
+ "enqueue": enqueue,
323
244
  "queue_span_for_bulk_insert": queue_span,
324
245
  "queue_evaluation_for_bulk_insert": queue_evaluation,
325
246
  "enqueue_operation": enqueue_operation,
@@ -330,29 +251,91 @@ def _lifespan(
330
251
  return lifespan
331
252
 
332
253
 
254
+ @router.get("/healthz")
333
255
  async def check_healthz(_: Request) -> PlainTextResponse:
334
256
  return PlainTextResponse("OK")
335
257
 
336
258
 
337
- async def openapi_schema(request: Request) -> Response:
338
- return OPENAPI_SCHEMA_GENERATOR.OpenAPIResponse(request=request)
339
-
340
-
341
- async def api_docs(request: Request) -> Response:
342
- return get_swagger_ui_html(openapi_url="/schema", title="arize-phoenix API")
343
-
344
-
345
- class SessionFactory:
346
- def __init__(
347
- self,
348
- session_factory: Callable[[], AsyncContextManager[AsyncSession]],
349
- dialect: str,
350
- ):
351
- self.session_factory = session_factory
352
- self.dialect = SupportedSQLDialect(dialect)
259
+ def create_graphql_router(
260
+ *,
261
+ schema: BaseSchema,
262
+ db: DbSessionFactory,
263
+ model: Model,
264
+ export_path: Path,
265
+ corpus: Optional[Model] = None,
266
+ streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None,
267
+ cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
268
+ read_only: bool = False,
269
+ ) -> GraphQLRouter: # type: ignore[type-arg]
270
+ def get_context() -> Context:
271
+ return Context(
272
+ db=db,
273
+ model=model,
274
+ corpus=corpus,
275
+ export_path=export_path,
276
+ streaming_last_updated_at=streaming_last_updated_at,
277
+ data_loaders=DataLoaders(
278
+ average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
279
+ dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
280
+ dataset_example_spans=DatasetExampleSpansDataLoader(db),
281
+ document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
282
+ db,
283
+ cache_map=cache_for_dataloaders.document_evaluation_summary
284
+ if cache_for_dataloaders
285
+ else None,
286
+ ),
287
+ document_evaluations=DocumentEvaluationsDataLoader(db),
288
+ document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
289
+ evaluation_summaries=EvaluationSummaryDataLoader(
290
+ db,
291
+ cache_map=cache_for_dataloaders.evaluation_summary
292
+ if cache_for_dataloaders
293
+ else None,
294
+ ),
295
+ experiment_annotation_summaries=ExperimentAnnotationSummaryDataLoader(db),
296
+ experiment_error_rates=ExperimentErrorRatesDataLoader(db),
297
+ experiment_run_counts=ExperimentRunCountsDataLoader(db),
298
+ experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
299
+ latency_ms_quantile=LatencyMsQuantileDataLoader(
300
+ db,
301
+ cache_map=cache_for_dataloaders.latency_ms_quantile
302
+ if cache_for_dataloaders
303
+ else None,
304
+ ),
305
+ min_start_or_max_end_times=MinStartOrMaxEndTimeDataLoader(
306
+ db,
307
+ cache_map=cache_for_dataloaders.min_start_or_max_end_time
308
+ if cache_for_dataloaders
309
+ else None,
310
+ ),
311
+ record_counts=RecordCountDataLoader(
312
+ db,
313
+ cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
314
+ ),
315
+ span_annotations=SpanAnnotationsDataLoader(db),
316
+ span_dataset_examples=SpanDatasetExamplesDataLoader(db),
317
+ span_descendants=SpanDescendantsDataLoader(db),
318
+ span_evaluations=SpanEvaluationsDataLoader(db),
319
+ span_projects=SpanProjectsDataLoader(db),
320
+ token_counts=TokenCountDataLoader(
321
+ db,
322
+ cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
323
+ ),
324
+ trace_evaluations=TraceEvaluationsDataLoader(db),
325
+ trace_row_ids=TraceRowIdsDataLoader(db),
326
+ project_by_name=ProjectByNameDataLoader(db),
327
+ ),
328
+ cache_for_dataloaders=cache_for_dataloaders,
329
+ read_only=read_only,
330
+ )
353
331
 
354
- def __call__(self) -> AsyncContextManager[AsyncSession]:
355
- return self.session_factory()
332
+ return GraphQLRouter(
333
+ schema,
334
+ graphiql=True,
335
+ context_getter=get_context,
336
+ include_in_schema=False,
337
+ prefix="/graphql",
338
+ )
356
339
 
357
340
 
358
341
  def create_engine_and_run_migrations(
@@ -390,8 +373,20 @@ def instrument_engine_if_enabled(engine: AsyncEngine) -> List[Callable[[], None]
390
373
  return instrumentation_cleanups
391
374
 
392
375
 
376
+ async def plain_text_http_exception_handler(request: Request, exc: HTTPException) -> Response:
377
+ """
378
+ Overrides the default handler for HTTPExceptions to return a plain text
379
+ response instead of a JSON response. For the original source code, see
380
+ https://github.com/tiangolo/fastapi/blob/d3cdd3bbd14109f3b268df7ca496e24bb64593aa/fastapi/exception_handlers.py#L11
381
+ """
382
+ headers = getattr(exc, "headers", None)
383
+ if not is_body_allowed_for_status_code(exc.status_code):
384
+ return Response(status_code=exc.status_code, headers=headers)
385
+ return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)
386
+
387
+
393
388
  def create_app(
394
- db: SessionFactory,
389
+ db: DbSessionFactory,
395
390
  export_path: Path,
396
391
  model: Model,
397
392
  umap_params: UMAPParameters,
@@ -404,7 +399,7 @@ def create_app(
404
399
  initial_evaluations: Optional[Iterable[pb.Evaluation]] = None,
405
400
  serve_ui: bool = True,
406
401
  clean_up_callbacks: List[Callable[[], None]] = [],
407
- ) -> Starlette:
402
+ ) -> FastAPI:
408
403
  clean_ups: List[Callable[[], None]] = clean_up_callbacks # To be called at app shutdown.
409
404
  initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
410
405
  ()
@@ -447,7 +442,7 @@ def create_app(
447
442
 
448
443
  strawberry_extensions.append(_OpenTelemetryExtension)
449
444
 
450
- graphql = GraphQLWithContext(
445
+ graphql_router = create_graphql_router(
451
446
  db=db,
452
447
  schema=strawberry.Schema(
453
448
  query=schema.query,
@@ -458,7 +453,6 @@ def create_app(
458
453
  model=model,
459
454
  corpus=corpus,
460
455
  export_path=export_path,
461
- graphiql=True,
462
456
  streaming_last_updated_at=bulk_inserter.last_updated_at,
463
457
  cache_for_dataloaders=cache_for_dataloaders,
464
458
  read_only=read_only,
@@ -469,8 +463,11 @@ def create_app(
469
463
  prometheus_middlewares = [Middleware(PrometheusMiddleware)]
470
464
  else:
471
465
  prometheus_middlewares = []
472
- app = Starlette(
466
+ app = FastAPI(
467
+ title="Arize-Phoenix REST API",
468
+ version=REST_API_VERSION,
473
469
  lifespan=_lifespan(
470
+ dialect=db.dialect,
474
471
  read_only=read_only,
475
472
  bulk_inserter=bulk_inserter,
476
473
  tracer_provider=tracer_provider,
@@ -481,58 +478,41 @@ def create_app(
481
478
  Middleware(HeadersMiddleware),
482
479
  *prometheus_middlewares,
483
480
  ],
481
+ exception_handlers={HTTPException: plain_text_http_exception_handler},
484
482
  debug=debug,
485
- routes=V1_ROUTES
486
- + [
487
- Route("/schema", endpoint=openapi_schema, include_in_schema=False),
488
- Route("/arize_phoenix_version", version),
489
- Route("/healthz", check_healthz),
490
- Route(
491
- "/exports",
492
- type(
493
- "DownloadExports",
494
- (Download,),
495
- {"path": export_path},
496
- ),
497
- ),
498
- Route(
499
- "/docs",
500
- api_docs,
501
- ),
502
- Route(
503
- "/graphql",
504
- graphql,
505
- ),
506
- ]
507
- + (
508
- [
509
- Mount(
510
- "/",
511
- app=Static(
512
- directory=SERVER_DIR / "static",
513
- app_config=AppConfig(
514
- has_inferences=model.is_empty is not True,
515
- has_corpus=corpus is not None,
516
- min_dist=umap_params.min_dist,
517
- n_neighbors=umap_params.n_neighbors,
518
- n_samples=umap_params.n_samples,
519
- is_development=dev,
520
- web_manifest_path=SERVER_DIR / "static" / ".vite" / "manifest.json",
521
- ),
522
- ),
523
- name="static",
524
- ),
525
- ]
526
- if serve_ui
527
- else []
528
- ),
483
+ swagger_ui_parameters={
484
+ "defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
485
+ },
529
486
  )
530
487
  app.state.read_only = read_only
488
+ app.state.export_path = export_path
489
+ app.include_router(v1_router)
490
+ app.include_router(router)
491
+ app.include_router(graphql_router)
492
+ app.add_middleware(GZipMiddleware)
493
+ if serve_ui:
494
+ app.mount(
495
+ "/",
496
+ app=Static(
497
+ directory=SERVER_DIR / "static",
498
+ app_config=AppConfig(
499
+ has_inferences=model.is_empty is not True,
500
+ has_corpus=corpus is not None,
501
+ min_dist=umap_params.min_dist,
502
+ n_neighbors=umap_params.n_neighbors,
503
+ n_samples=umap_params.n_samples,
504
+ is_development=dev,
505
+ web_manifest_path=SERVER_DIR / "static" / ".vite" / "manifest.json",
506
+ ),
507
+ ),
508
+ name="static",
509
+ )
510
+
531
511
  app.state.db = db
532
512
  if tracer_provider:
533
- from opentelemetry.instrumentation.starlette import StarletteInstrumentor
513
+ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
534
514
 
535
- StarletteInstrumentor().instrument(tracer_provider=tracer_provider)
536
- StarletteInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
537
- clean_ups.append(StarletteInstrumentor().uninstrument)
515
+ FastAPIInstrumentor().instrument(tracer_provider=tracer_provider)
516
+ FastAPIInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
517
+ clean_ups.append(FastAPIInstrumentor().uninstrument)
538
518
  return app
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
23
23
  ProjectName: TypeAlias = str
24
24
 
25
25
 
26
- class Servicer(TraceServiceServicer):
26
+ class Servicer(TraceServiceServicer): # type:ignore
27
27
  def __init__(
28
28
  self,
29
29
  callback: Callable[[Span, ProjectName], Awaitable[None]],
@@ -78,7 +78,7 @@ class GrpcServer:
78
78
  interceptors=interceptors,
79
79
  )
80
80
  server.add_insecure_port(f"[::]:{get_env_grpc_port()}")
81
- add_TraceServiceServicer_to_server(Servicer(self._callback), server) # type: ignore
81
+ add_TraceServiceServicer_to_server(Servicer(self._callback), server)
82
82
  await server.start()
83
83
  self._server = server
84
84
 
phoenix/server/main.py CHANGED
@@ -33,25 +33,23 @@ from phoenix.pointcloud.umap_parameters import (
33
33
  UMAPParameters,
34
34
  )
35
35
  from phoenix.server.app import (
36
- SessionFactory,
37
36
  _db,
38
37
  create_app,
39
38
  create_engine_and_run_migrations,
40
39
  instrument_engine_if_enabled,
41
40
  )
41
+ from phoenix.server.types import DbSessionFactory
42
42
  from phoenix.settings import Settings
43
43
  from phoenix.trace.fixtures import (
44
44
  TRACES_FIXTURES,
45
- download_traces_fixture,
46
45
  get_dataset_fixtures,
47
46
  get_evals_from_fixture,
48
- get_trace_fixture_by_name,
47
+ load_example_traces,
49
48
  reset_fixture_span_ids_and_timestamps,
50
49
  send_dataset_fixtures,
51
50
  )
52
51
  from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
53
52
  from phoenix.trace.schemas import Span
54
- from phoenix.trace.span_json_decoder import json_string_to_span
55
53
 
56
54
  logger = logging.getLogger(__name__)
57
55
 
@@ -221,10 +219,8 @@ if __name__ == "__main__":
221
219
  (
222
220
  # Apply `encode` here because legacy jsonl files contains UUIDs as strings.
223
221
  # `encode` removes the hyphens in the UUIDs.
224
- decode_otlp_span(encode_span_to_otlp(json_string_to_span(json_span)))
225
- for json_span in download_traces_fixture(
226
- get_trace_fixture_by_name(trace_dataset_name)
227
- )
222
+ decode_otlp_span(encode_span_to_otlp(span))
223
+ for span in load_example_traces(trace_dataset_name).to_spans()
228
224
  ),
229
225
  get_evals_from_fixture(trace_dataset_name),
230
226
  )
@@ -250,7 +246,7 @@ if __name__ == "__main__":
250
246
  working_dir = get_working_dir().resolve()
251
247
  engine = create_engine_and_run_migrations(db_connection_str)
252
248
  instrumentation_cleanups = instrument_engine_if_enabled(engine)
253
- factory = SessionFactory(session_factory=_db(engine), dialect=engine.dialect.name)
249
+ factory = DbSessionFactory(db=_db(engine), dialect=engine.dialect.name)
254
250
  app = create_app(
255
251
  db=factory,
256
252
  export_path=export_path,