arize-phoenix 4.12.1rc1__py3-none-any.whl → 4.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of arize-phoenix has been flagged as potentially problematic.

Files changed (73)
  1. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/METADATA +10 -6
  2. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/RECORD +70 -68
  3. phoenix/db/bulk_inserter.py +5 -4
  4. phoenix/db/engines.py +2 -1
  5. phoenix/experiments/evaluators/base.py +4 -0
  6. phoenix/experiments/evaluators/code_evaluators.py +80 -0
  7. phoenix/experiments/evaluators/llm_evaluators.py +77 -1
  8. phoenix/experiments/evaluators/utils.py +70 -21
  9. phoenix/experiments/functions.py +17 -16
  10. phoenix/server/api/context.py +5 -3
  11. phoenix/server/api/dataloaders/__init__.py +2 -0
  12. phoenix/server/api/dataloaders/average_experiment_run_latency.py +25 -25
  13. phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
  14. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
  15. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
  16. phoenix/server/api/dataloaders/document_evaluations.py +2 -4
  17. phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
  18. phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
  19. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
  20. phoenix/server/api/dataloaders/experiment_error_rates.py +32 -14
  21. phoenix/server/api/dataloaders/experiment_run_counts.py +20 -9
  22. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
  23. phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
  24. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
  25. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  26. phoenix/server/api/dataloaders/record_counts.py +2 -4
  27. phoenix/server/api/dataloaders/span_annotations.py +2 -4
  28. phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
  29. phoenix/server/api/dataloaders/span_descendants.py +2 -4
  30. phoenix/server/api/dataloaders/span_evaluations.py +2 -4
  31. phoenix/server/api/dataloaders/span_projects.py +3 -3
  32. phoenix/server/api/dataloaders/token_counts.py +2 -4
  33. phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
  34. phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
  35. phoenix/server/api/input_types/{CreateSpanAnnotationsInput.py → CreateSpanAnnotationInput.py} +4 -2
  36. phoenix/server/api/input_types/{CreateTraceAnnotationsInput.py → CreateTraceAnnotationInput.py} +4 -2
  37. phoenix/server/api/input_types/{PatchAnnotationsInput.py → PatchAnnotationInput.py} +4 -2
  38. phoenix/server/api/mutations/span_annotations_mutations.py +20 -9
  39. phoenix/server/api/mutations/trace_annotations_mutations.py +20 -9
  40. phoenix/server/api/routers/v1/datasets.py +132 -10
  41. phoenix/server/api/routers/v1/evaluations.py +3 -5
  42. phoenix/server/api/routers/v1/experiments.py +1 -1
  43. phoenix/server/api/types/Experiment.py +2 -2
  44. phoenix/server/api/types/Inferences.py +1 -2
  45. phoenix/server/api/types/Model.py +1 -2
  46. phoenix/server/api/types/Span.py +5 -0
  47. phoenix/server/api/utils.py +4 -4
  48. phoenix/server/app.py +21 -18
  49. phoenix/server/grpc_server.py +2 -2
  50. phoenix/server/main.py +5 -9
  51. phoenix/server/static/.vite/manifest.json +31 -31
  52. phoenix/server/static/assets/{components-C8sm_r1F.js → components-kGgeFkHp.js} +150 -110
  53. phoenix/server/static/assets/index-BctFO6S7.js +100 -0
  54. phoenix/server/static/assets/{pages-bN7juCjh.js → pages-DabDCmVd.js} +432 -255
  55. phoenix/server/static/assets/{vendor-CUDAPm8e.js → vendor-CP0b0YG0.js} +2 -2
  56. phoenix/server/static/assets/{vendor-arizeai-Do2HOmcL.js → vendor-arizeai-B5Hti8OB.js} +27 -27
  57. phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
  58. phoenix/server/static/assets/{vendor-recharts-PKRvByVe.js → vendor-recharts-A0DA1O99.js} +1 -1
  59. phoenix/server/types.py +18 -0
  60. phoenix/session/client.py +9 -6
  61. phoenix/session/session.py +2 -2
  62. phoenix/trace/dsl/filter.py +40 -25
  63. phoenix/trace/fixtures.py +17 -23
  64. phoenix/trace/utils.py +23 -0
  65. phoenix/utilities/client.py +116 -0
  66. phoenix/utilities/project.py +1 -1
  67. phoenix/version.py +1 -1
  68. phoenix/server/api/routers/v1/dataset_examples.py +0 -157
  69. phoenix/server/static/assets/index-BEKPzgQs.js +0 -100
  70. phoenix/server/static/assets/vendor-codemirror-CrdxOlMs.js +0 -12
  71. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/WHEEL +0 -0
  72. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/licenses/IP_NOTICE +0 -0
  73. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/trace_annotations_mutations.py CHANGED
@@ -7,10 +7,11 @@ from strawberry.types import Info

 from phoenix.db import models
 from phoenix.server.api.context import Context
-from phoenix.server.api.input_types.CreateTraceAnnotationsInput import CreateTraceAnnotationsInput
+from phoenix.server.api.input_types.CreateTraceAnnotationInput import CreateTraceAnnotationInput
 from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
-from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
+from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
 from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.queries import Query
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation

@@ -18,13 +19,14 @@ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
 @strawberry.type
 class TraceAnnotationMutationPayload:
     trace_annotations: List[TraceAnnotation]
+    query: Query


 @strawberry.type
 class TraceAnnotationMutationMixin:
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
     async def create_trace_annotations(
-        self, info: Info[Context, None], input: List[CreateTraceAnnotationsInput]
+        self, info: Info[Context, None], input: List[CreateTraceAnnotationInput]
     ) -> TraceAnnotationMutationPayload:
         inserted_annotations: Sequence[models.TraceAnnotation] = []
         async with info.context.db() as session:
@@ -35,7 +37,7 @@ class TraceAnnotationMutationMixin:
                     label=annotation.label,
                     score=annotation.score,
                     explanation=annotation.explanation,
-                    annotator_kind=annotation.annotator_kind,
+                    annotator_kind=annotation.annotator_kind.value,
                     metadata_=annotation.metadata,
                 )
                 for annotation in input
@@ -49,12 +51,13 @@ class TraceAnnotationMutationMixin:
         return TraceAnnotationMutationPayload(
             trace_annotations=[
                 to_gql_trace_annotation(annotation) for annotation in inserted_annotations
-            ]
+            ],
+            query=Query(),
         )

     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
     async def patch_trace_annotations(
-        self, info: Info[Context, None], input: List[PatchAnnotationsInput]
+        self, info: Info[Context, None], input: List[PatchAnnotationInput]
     ) -> TraceAnnotationMutationPayload:
         patched_annotations = []
         async with info.context.db() as session:
@@ -66,7 +69,13 @@ class TraceAnnotationMutationMixin:
                     column.key: patch_value
                     for column, patch_value, column_is_nullable in (
                         (models.TraceAnnotation.name, annotation.name, False),
-                        (models.TraceAnnotation.annotator_kind, annotation.annotator_kind, False),
+                        (
+                            models.TraceAnnotation.annotator_kind,
+                            annotation.annotator_kind.value
+                            if annotation.annotator_kind is not None
+                            else None,
+                            False,
+                        ),
                         (models.TraceAnnotation.label, annotation.label, True),
                         (models.TraceAnnotation.score, annotation.score, True),
                         (models.TraceAnnotation.explanation, annotation.explanation, True),
@@ -83,7 +92,7 @@ class TraceAnnotationMutationMixin:
             if trace_annotation:
                 patched_annotations.append(to_gql_trace_annotation(trace_annotation))

-        return TraceAnnotationMutationPayload(trace_annotations=patched_annotations)
+        return TraceAnnotationMutationPayload(trace_annotations=patched_annotations, query=Query())

     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
     async def delete_trace_annotations(
@@ -105,4 +114,6 @@ class TraceAnnotationMutationMixin:
         deleted_annotations_gql = [
             to_gql_trace_annotation(annotation) for annotation in deleted_annotations
         ]
-        return TraceAnnotationMutationPayload(trace_annotations=deleted_annotations_gql)
+        return TraceAnnotationMutationPayload(
+            trace_annotations=deleted_annotations_gql, query=Query()
+        )
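
Note: the new `.value` conversions are needed because strawberry resolves GraphQL enum arguments to Python Enum members, while the `annotator_kind` database column stores a plain string. A minimal sketch of the pattern (the `AnnotatorKind` enum below is a hypothetical stand-in; the real input types live in CreateSpanAnnotationInput.py / CreateTraceAnnotationInput.py, which this diff only renames):

    import enum

    import strawberry


    @strawberry.enum
    class AnnotatorKind(enum.Enum):  # hypothetical stand-in for the real enum
        LLM = "LLM"
        HUMAN = "HUMAN"


    # A GraphQL argument typed as AnnotatorKind arrives as an Enum member,
    # so persisting it to a string column requires .value:
    kind = AnnotatorKind.HUMAN
    assert kind.value == "HUMAN"  # what gets written to annotator_kind
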
phoenix/server/api/routers/v1/datasets.py CHANGED
@@ -56,12 +56,11 @@ from phoenix.db.insertion.dataset import (
     add_dataset_examples,
 )
 from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
-from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
 from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.utils import delete_projects, delete_traces

-from .dataset_examples import router as dataset_examples_router
 from .pydantic_compat import V1RoutesBaseModel
 from .utils import (
     PaginatedResponseBody,
@@ -122,7 +121,7 @@ async def list_datasets(
             status_code=HTTP_422_UNPROCESSABLE_ENTITY,
         )
         if name:
-            query = query.filter(models.Dataset.name.is_(name))
+            query = query.filter(models.Dataset.name == name)

         query = query.limit(limit + 1)
         result = await session.execute(query)
@@ -669,12 +668,135 @@ async def _parse_form_data(
     )


-# including the dataset examples router here ensures the dataset example routes
-# are included in a natural order in the openapi schema and the swagger ui
-#
-# todo: move the dataset examples routes here and remove the dataset_examples
-# sub-module
-router.include_router(dataset_examples_router)
+class DatasetExample(V1RoutesBaseModel):
+    id: str
+    input: Dict[str, Any]
+    output: Dict[str, Any]
+    metadata: Dict[str, Any]
+    updated_at: datetime
+
+
+class ListDatasetExamplesData(V1RoutesBaseModel):
+    dataset_id: str
+    version_id: str
+    examples: List[DatasetExample]
+
+
+class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
+    pass
+
+
+@router.get(
+    "/datasets/{id}/examples",
+    operation_id="getDatasetExamples",
+    summary="Get examples from a dataset",
+    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+)
+async def get_dataset_examples(
+    request: Request,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version "
+            "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> ListDatasetExamplesResponseBody:
+    dataset_gid = GlobalID.from_id(id)
+    version_gid = GlobalID.from_id(version_id) if version_id else None
+
+    if (dataset_type := dataset_gid.type_name) != "Dataset":
+        raise HTTPException(
+            detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
+        )
+
+    if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
+        raise HTTPException(
+            detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
+        )
+
+    async with request.app.state.db() as session:
+        if (
+            resolved_dataset_id := await session.scalar(
+                select(models.Dataset.id).where(models.Dataset.id == int(dataset_gid.node_id))
+            )
+        ) is None:
+            raise HTTPException(
+                detail=f"No dataset with id {dataset_gid} can be found.",
+                status_code=HTTP_404_NOT_FOUND,
+            )
+
+        # Subquery to find the maximum created_at for each dataset_example_id
+        # timestamp tiebreaks are resolved by the largest id
+        partial_subquery = select(
+            func.max(models.DatasetExampleRevision.id).label("max_id"),
+        ).group_by(models.DatasetExampleRevision.dataset_example_id)
+
+        if version_gid:
+            if (
+                resolved_version_id := await session.scalar(
+                    select(models.DatasetVersion.id).where(
+                        and_(
+                            models.DatasetVersion.dataset_id == resolved_dataset_id,
+                            models.DatasetVersion.id == int(version_gid.node_id),
+                        )
+                    )
+                )
+            ) is None:
+                raise HTTPException(
+                    detail=f"No dataset version with id {version_id} can be found.",
+                    status_code=HTTP_404_NOT_FOUND,
+                )
+            # if a version_id is provided, filter the subquery to only include revisions from that
+            partial_subquery = partial_subquery.filter(
+                models.DatasetExampleRevision.dataset_version_id <= resolved_version_id
+            )
+        else:
+            if (
+                resolved_version_id := await session.scalar(
+                    select(func.max(models.DatasetVersion.id)).where(
+                        models.DatasetVersion.dataset_id == resolved_dataset_id
+                    )
+                )
+            ) is None:
+                raise HTTPException(
+                    detail="Dataset has no versions.",
+                    status_code=HTTP_404_NOT_FOUND,
+                )
+
+        subquery = partial_subquery.subquery()
+        # Query for the most recent example revisions that are not deleted
+        query = (
+            select(models.DatasetExample, models.DatasetExampleRevision)
+            .join(
+                models.DatasetExampleRevision,
+                models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
+            )
+            .join(
+                subquery,
+                (subquery.c.max_id == models.DatasetExampleRevision.id),
+            )
+            .filter(models.DatasetExample.dataset_id == resolved_dataset_id)
+            .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
+            .order_by(models.DatasetExample.id.asc())
+        )
+        examples = [
+            DatasetExample(
+                id=str(GlobalID("DatasetExample", str(example.id))),
+                input=revision.input,
+                output=revision.output,
+                metadata=revision.metadata_,
+                updated_at=revision.created_at,
+            )
+            async for example, revision in await session.stream(query)
+        ]
+        return ListDatasetExamplesResponseBody(
+            data=ListDatasetExamplesData(
+                dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
+                version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+                examples=examples,
+            )
+        )


 @router.get(
@@ -794,7 +916,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
     records = [
         {
             "example_id": GlobalID(
-                type_name=DatasetExample.__name__,
+                type_name=DatasetExampleNodeType.__name__,
                 node_id=str(ex.dataset_example_id),
             ),
             **{f"input_{k}": v for k, v in ex.input.items()},
phoenix/server/api/routers/v1/evaluations.py CHANGED
@@ -1,6 +1,6 @@
 import gzip
 from itertools import chain
-from typing import AsyncContextManager, Callable, Iterator, Optional, Tuple
+from typing import Iterator, Optional, Tuple

 import pandas as pd
 import pyarrow as pa
@@ -9,9 +9,6 @@ from google.protobuf.message import DecodeError
 from pandas import DataFrame
 from sqlalchemy import select
 from sqlalchemy.engine import Connectable
-from sqlalchemy.ext.asyncio import (
-    AsyncSession,
-)
 from starlette.background import BackgroundTask
 from starlette.datastructures import State
 from starlette.requests import Request
@@ -29,6 +26,7 @@ from phoenix.config import DEFAULT_PROJECT_NAME
 from phoenix.db import models
 from phoenix.exceptions import PhoenixEvaluationNameIsMissing
 from phoenix.server.api.routers.utils import table_to_bytes
+from phoenix.server.types import DbSessionFactory
 from phoenix.session.evaluation import encode_evaluations
 from phoenix.trace.span_evaluations import (
     DocumentEvaluations,
@@ -128,7 +126,7 @@ async def get_evaluations(
         or DEFAULT_PROJECT_NAME
     )

-    db: Callable[[], AsyncContextManager[AsyncSession]] = request.app.state.db
+    db: DbSessionFactory = request.app.state.db
     async with db() as session:
         connection = await session.connection()
         trace_evals_dataframe = await connection.run_sync(
phoenix/server/api/routers/v1/experiments.py CHANGED
@@ -110,7 +110,7 @@ async def create_experiment(
         )
     except ValueError:
         raise HTTPException(
-            detail="DatasetVersion with ID {dataset_version_globalid} does not exist",
+            detail=f"DatasetVersion with ID {dataset_version_globalid_str} does not exist",
             status_code=HTTP_404_NOT_FOUND,
         )

phoenix/server/api/types/Experiment.py CHANGED
@@ -104,11 +104,11 @@ class Experiment(Node):
         return await info.context.data_loaders.experiment_error_rates.load(self.id_attr)

     @strawberry.field
-    async def average_run_latency_ms(self, info: Info[Context, None]) -> float:
+    async def average_run_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
         latency_seconds = await info.context.data_loaders.average_experiment_run_latency.load(
             self.id_attr
         )
-        return latency_seconds * 1000
+        return latency_seconds * 1000 if latency_seconds is not None else None

     @strawberry.field
     async def project(self, info: Info[Context, None]) -> Optional[Project]:
phoenix/server/api/types/Inferences.py CHANGED
@@ -2,8 +2,7 @@ from datetime import datetime
 from typing import Iterable, List, Optional, Set, Union

 import strawberry
-from strawberry.scalars import ID
-from strawberry.unset import UNSET
+from strawberry import ID, UNSET

 import phoenix.core.model_schema as ms
 from phoenix.core.model_schema import FEATURE, TAG, ScalarDimension
phoenix/server/api/types/Model.py CHANGED
@@ -2,9 +2,8 @@ import asyncio
 from typing import List, Optional

 import strawberry
+from strawberry import UNSET, Info
 from strawberry.relay import Connection
-from strawberry.types import Info
-from strawberry.unset import UNSET
 from typing_extensions import Annotated

 from phoenix.config import get_exported_files
phoenix/server/api/types/Span.py CHANGED
@@ -258,6 +258,11 @@ class Span(Node):
         project = await info.context.data_loaders.span_projects.load(span_id)
         return to_gql_project(project)

+    @strawberry.field(description="Indicates if the span is contained in any dataset")  # type: ignore
+    async def contained_in_dataset(self, info: Info[Context, None]) -> bool:
+        examples = await info.context.data_loaders.span_dataset_examples.load(self.id_attr)
+        return bool(examples)
+

 def to_gql_span(span: models.Span) -> Span:
     events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
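
Note: `contained_in_dataset` is backed by the new SpanDatasetExamplesDataLoader (phoenix/server/api/dataloaders/span_dataset_examples.py, +36 lines, not shown in this view). A plausible sketch of such a loader, following the batching style the other dataloaders in this diff use; the exact query and column names are assumptions:

    from typing import Dict, List

    from sqlalchemy import select
    from strawberry.dataloader import DataLoader

    from phoenix.db import models
    from phoenix.server.types import DbSessionFactory


    class SpanDatasetExamplesDataLoader(DataLoader[int, List[models.DatasetExample]]):
        """Batches span rowids into a single query for their dataset examples."""

        def __init__(self, db: DbSessionFactory) -> None:
            super().__init__(load_fn=self._load_fn)
            self._db = db

        async def _load_fn(self, keys: List[int]) -> List[List[models.DatasetExample]]:
            results: Dict[int, List[models.DatasetExample]] = {key: [] for key in keys}
            async with self._db() as session:
                data = await session.stream(
                    select(models.DatasetExample).where(
                        models.DatasetExample.span_rowid.in_(keys)  # assumed column
                    )
                )
                async for example in data.scalars():
                    results[example.span_rowid].append(example)
            return [results[key] for key in keys]
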
phoenix/server/api/utils.py CHANGED
@@ -1,13 +1,13 @@
-from typing import AsyncContextManager, Callable, List
+from typing import List

 from sqlalchemy import delete
-from sqlalchemy.ext.asyncio import AsyncSession

 from phoenix.db import models
+from phoenix.server.types import DbSessionFactory


 async def delete_projects(
-    db: Callable[[], AsyncContextManager[AsyncSession]],
+    db: DbSessionFactory,
     *project_names: str,
 ) -> List[int]:
     if not project_names:
@@ -22,7 +22,7 @@ async def delete_projects(


 async def delete_traces(
-    db: Callable[[], AsyncContextManager[AsyncSession]],
+    db: DbSessionFactory,
     *trace_ids: str,
 ) -> List[int]:
     if not trace_ids:
phoenix/server/app.py CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 import json
 import logging
@@ -74,6 +75,7 @@ from phoenix.server.api.dataloaders import (
     ProjectByNameDataLoader,
     RecordCountDataLoader,
     SpanAnnotationsDataLoader,
+    SpanDatasetExamplesDataLoader,
     SpanDescendantsDataLoader,
     SpanEvaluationsDataLoader,
     SpanProjectsDataLoader,
@@ -86,7 +88,9 @@ from phoenix.server.api.routers.v1 import router as v1_router
 from phoenix.server.api.schema import schema
 from phoenix.server.grpc_server import GrpcServer
 from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
+from phoenix.server.types import DbSessionFactory
 from phoenix.trace.schemas import Span
+from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER

 if TYPE_CHECKING:
     from opentelemetry.trace import TracerProvider
@@ -167,9 +171,11 @@ class HeadersMiddleware(BaseHTTPMiddleware):
         request: Request,
         call_next: RequestResponseEndpoint,
     ) -> Response:
+        from phoenix import __version__ as phoenix_version
+
         response = await call_next(request)
         response.headers["x-colab-notebook-cache-control"] = "no-cache"
-        response.headers["Cache-Control"] = "no-store"
+        response.headers[PHOENIX_SERVER_VERSION_HEADER] = phoenix_version
         return response


@@ -193,19 +199,25 @@ async def version() -> PlainTextResponse:
     return PlainTextResponse(f"{phoenix.__version__}")


+DB_MUTEX: Optional[asyncio.Lock] = None
+
+
 def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]:
     Session = async_sessionmaker(engine, expire_on_commit=False)

     @contextlib.asynccontextmanager
     async def factory() -> AsyncIterator[AsyncSession]:
-        async with Session.begin() as session:
-            yield session
+        async with contextlib.AsyncExitStack() as stack:
+            if DB_MUTEX:
+                await stack.enter_async_context(DB_MUTEX)
+            yield await stack.enter_async_context(Session.begin())

     return factory


 def _lifespan(
     *,
+    dialect: SupportedSQLDialect,
     bulk_inserter: BulkInserter,
     tracer_provider: Optional["TracerProvider"] = None,
     enable_prometheus: bool = False,
@@ -214,6 +226,8 @@ def _lifespan(
 ) -> StatefulLifespan[FastAPI]:
     @contextlib.asynccontextmanager
     async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]:
+        global DB_MUTEX
+        DB_MUTEX = asyncio.Lock() if dialect is SupportedSQLDialect.SQLITE else None
         async with bulk_inserter as (
             queue_span,
             queue_evaluation,
@@ -243,7 +257,7 @@ async def check_healthz(_: Request) -> PlainTextResponse:
 def create_graphql_router(
     *,
     schema: BaseSchema,
-    db: Callable[[], AsyncContextManager[AsyncSession]],
+    db: DbSessionFactory,
     model: Model,
     export_path: Path,
     corpus: Optional[Model] = None,
@@ -297,6 +311,7 @@ def create_graphql_router(
             cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
         ),
         span_annotations=SpanAnnotationsDataLoader(db),
+        span_dataset_examples=SpanDatasetExamplesDataLoader(db),
         span_descendants=SpanDescendantsDataLoader(db),
         span_evaluations=SpanEvaluationsDataLoader(db),
         span_projects=SpanProjectsDataLoader(db),
@@ -321,19 +336,6 @@ def create_graphql_router(
     )


-class SessionFactory:
-    def __init__(
-        self,
-        session_factory: Callable[[], AsyncContextManager[AsyncSession]],
-        dialect: str,
-    ):
-        self.session_factory = session_factory
-        self.dialect = SupportedSQLDialect(dialect)
-
-    def __call__(self) -> AsyncContextManager[AsyncSession]:
-        return self.session_factory()
-
-
 def create_engine_and_run_migrations(
     database_url: str,
 ) -> AsyncEngine:
@@ -382,7 +384,7 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
 
 
 def create_app(
-    db: SessionFactory,
+    db: DbSessionFactory,
     export_path: Path,
     model: Model,
     umap_params: UMAPParameters,
@@ -463,6 +465,7 @@ def create_app(
         title="Arize-Phoenix REST API",
         version=REST_API_VERSION,
         lifespan=_lifespan(
+            dialect=db.dialect,
             read_only=read_only,
             bulk_inserter=bulk_inserter,
             tracer_provider=tracer_provider,
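
Note: the rewritten `_db` uses contextlib.AsyncExitStack to enter the lock only when one is set, so writes are serialized on SQLite (a single-writer database) while PostgreSQL deployments pay nothing. The pattern in isolation, as a minimal sketch:

    import asyncio
    import contextlib
    from typing import AsyncIterator, Optional

    DB_MUTEX: Optional[asyncio.Lock] = None  # set to asyncio.Lock() for SQLite


    @contextlib.asynccontextmanager
    async def guarded_session() -> AsyncIterator[None]:
        async with contextlib.AsyncExitStack() as stack:
            if DB_MUTEX:
                # Entered only when a lock exists; a no-op for PostgreSQL.
                await stack.enter_async_context(DB_MUTEX)
            yield  # the real factory yields a session here
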
phoenix/server/grpc_server.py CHANGED
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
 ProjectName: TypeAlias = str


-class Servicer(TraceServiceServicer):
+class Servicer(TraceServiceServicer):  # type:ignore
     def __init__(
         self,
         callback: Callable[[Span, ProjectName], Awaitable[None]],
@@ -78,7 +78,7 @@ class GrpcServer:
             interceptors=interceptors,
         )
         server.add_insecure_port(f"[::]:{get_env_grpc_port()}")
-        add_TraceServiceServicer_to_server(Servicer(self._callback), server)  # type: ignore
+        add_TraceServiceServicer_to_server(Servicer(self._callback), server)
         await server.start()
         self._server = server

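Note: the SessionFactory class deleted from app.py is replaced by DbSessionFactory from phoenix/server/types.py (+18 lines, not shown in this view). Judging from its call sites above and in main.py below, it is presumably close to the following sketch; everything beyond the `db` and `dialect` names visible in this diff is a guess:

    from typing import AsyncContextManager, Callable

    from sqlalchemy.ext.asyncio import AsyncSession

    from phoenix.db.engines import SupportedSQLDialect  # assumed import path


    class DbSessionFactory:
        """Callable session factory that also carries the SQL dialect."""

        def __init__(
            self,
            db: Callable[[], AsyncContextManager[AsyncSession]],
            dialect: str,
        ) -> None:
            self._db = db
            self.dialect = SupportedSQLDialect(dialect)

        def __call__(self) -> AsyncContextManager[AsyncSession]:
            return self._db()
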
phoenix/server/main.py CHANGED
@@ -33,25 +33,23 @@ from phoenix.pointcloud.umap_parameters import (
     UMAPParameters,
 )
 from phoenix.server.app import (
-    SessionFactory,
     _db,
     create_app,
     create_engine_and_run_migrations,
     instrument_engine_if_enabled,
 )
+from phoenix.server.types import DbSessionFactory
 from phoenix.settings import Settings
 from phoenix.trace.fixtures import (
     TRACES_FIXTURES,
-    download_traces_fixture,
     get_dataset_fixtures,
     get_evals_from_fixture,
-    get_trace_fixture_by_name,
+    load_example_traces,
     reset_fixture_span_ids_and_timestamps,
     send_dataset_fixtures,
 )
 from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
 from phoenix.trace.schemas import Span
-from phoenix.trace.span_json_decoder import json_string_to_span

 logger = logging.getLogger(__name__)

@@ -221,10 +219,8 @@ if __name__ == "__main__":
         (
             # Apply `encode` here because legacy jsonl files contains UUIDs as strings.
             # `encode` removes the hyphens in the UUIDs.
-            decode_otlp_span(encode_span_to_otlp(json_string_to_span(json_span)))
-            for json_span in download_traces_fixture(
-                get_trace_fixture_by_name(trace_dataset_name)
-            )
+            decode_otlp_span(encode_span_to_otlp(span))
+            for span in load_example_traces(trace_dataset_name).to_spans()
         ),
         get_evals_from_fixture(trace_dataset_name),
     )
@@ -250,7 +246,7 @@ if __name__ == "__main__":
     working_dir = get_working_dir().resolve()
     engine = create_engine_and_run_migrations(db_connection_str)
     instrumentation_cleanups = instrument_engine_if_enabled(engine)
-    factory = SessionFactory(session_factory=_db(engine), dialect=engine.dialect.name)
+    factory = DbSessionFactory(db=_db(engine), dialect=engine.dialect.name)
     app = create_app(
         db=factory,
         export_path=export_path,
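
Note: phoenix/utilities/client.py (+116 lines, not shown in this view) defines PHOENIX_SERVER_VERSION_HEADER, which HeadersMiddleware in app.py now attaches to every response, so a client can discover the server version from any endpoint. A minimal sketch, assuming a local server on the default port:

    import httpx

    from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER

    resp = httpx.get("http://localhost:6006/healthz")  # any endpoint carries the header
    server_version = resp.headers.get(PHOENIX_SERVER_VERSION_HEADER)
    print(f"connected to arize-phoenix {server_version}")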