arize-phoenix 4.14.1__py3-none-any.whl → 4.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. (The original page linked to further details here.)

Files changed (85):
  1. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/METADATA +5 -3
  2. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/RECORD +81 -71
  3. phoenix/db/bulk_inserter.py +131 -5
  4. phoenix/db/engines.py +2 -1
  5. phoenix/db/helpers.py +23 -1
  6. phoenix/db/insertion/constants.py +2 -0
  7. phoenix/db/insertion/document_annotation.py +157 -0
  8. phoenix/db/insertion/helpers.py +13 -0
  9. phoenix/db/insertion/span_annotation.py +144 -0
  10. phoenix/db/insertion/trace_annotation.py +144 -0
  11. phoenix/db/insertion/types.py +261 -0
  12. phoenix/experiments/functions.py +3 -2
  13. phoenix/experiments/types.py +3 -3
  14. phoenix/server/api/context.py +7 -9
  15. phoenix/server/api/dataloaders/__init__.py +2 -0
  16. phoenix/server/api/dataloaders/average_experiment_run_latency.py +3 -3
  17. phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
  18. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
  19. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
  20. phoenix/server/api/dataloaders/document_evaluations.py +2 -4
  21. phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
  22. phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
  23. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
  24. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -4
  25. phoenix/server/api/dataloaders/experiment_run_counts.py +2 -4
  26. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
  27. phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
  28. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
  29. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  30. phoenix/server/api/dataloaders/record_counts.py +2 -4
  31. phoenix/server/api/dataloaders/span_annotations.py +2 -4
  32. phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -4
  34. phoenix/server/api/dataloaders/span_evaluations.py +2 -4
  35. phoenix/server/api/dataloaders/span_projects.py +3 -3
  36. phoenix/server/api/dataloaders/token_counts.py +2 -4
  37. phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
  38. phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
  39. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  40. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  41. phoenix/server/api/mutations/span_annotations_mutations.py +8 -3
  42. phoenix/server/api/mutations/trace_annotations_mutations.py +8 -3
  43. phoenix/server/api/openapi/main.py +18 -2
  44. phoenix/server/api/openapi/schema.py +12 -12
  45. phoenix/server/api/routers/v1/__init__.py +36 -83
  46. phoenix/server/api/routers/v1/datasets.py +515 -509
  47. phoenix/server/api/routers/v1/evaluations.py +164 -73
  48. phoenix/server/api/routers/v1/experiment_evaluations.py +68 -91
  49. phoenix/server/api/routers/v1/experiment_runs.py +98 -155
  50. phoenix/server/api/routers/v1/experiments.py +132 -181
  51. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  52. phoenix/server/api/routers/v1/spans.py +164 -203
  53. phoenix/server/api/routers/v1/traces.py +134 -159
  54. phoenix/server/api/routers/v1/utils.py +95 -0
  55. phoenix/server/api/types/Span.py +27 -3
  56. phoenix/server/api/types/Trace.py +21 -4
  57. phoenix/server/api/utils.py +4 -4
  58. phoenix/server/app.py +172 -192
  59. phoenix/server/grpc_server.py +2 -2
  60. phoenix/server/main.py +5 -9
  61. phoenix/server/static/.vite/manifest.json +31 -31
  62. phoenix/server/static/assets/components-Ci5kMOk5.js +1175 -0
  63. phoenix/server/static/assets/{index-CQgXRwU0.js → index-BQG5WVX7.js} +2 -2
  64. phoenix/server/static/assets/{pages-hdjlFZhO.js → pages-BrevprVW.js} +451 -275
  65. phoenix/server/static/assets/{vendor-DPvSDRn3.js → vendor-CP0b0YG0.js} +2 -2
  66. phoenix/server/static/assets/{vendor-arizeai-CkvPT67c.js → vendor-arizeai-DTbiPGp6.js} +27 -27
  67. phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
  68. phoenix/server/static/assets/{vendor-recharts-5jlNaZuF.js → vendor-recharts-A0DA1O99.js} +1 -1
  69. phoenix/server/thread_server.py +2 -2
  70. phoenix/server/types.py +18 -0
  71. phoenix/session/client.py +5 -3
  72. phoenix/session/session.py +2 -2
  73. phoenix/trace/dsl/filter.py +2 -6
  74. phoenix/trace/fixtures.py +17 -23
  75. phoenix/trace/utils.py +23 -0
  76. phoenix/utilities/client.py +116 -0
  77. phoenix/utilities/project.py +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  80. phoenix/server/openapi/docs.py +0 -221
  81. phoenix/server/static/assets/components-DeS0YEmv.js +0 -1142
  82. phoenix/server/static/assets/vendor-codemirror-Cqwpwlua.js +0 -12
  83. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/WHEEL +0 -0
  84. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/IP_NOTICE +0 -0
  85. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,16 +1,14 @@
1
1
  from typing import (
2
- AsyncContextManager,
3
- Callable,
4
2
  List,
5
3
  Optional,
6
4
  )
7
5
 
8
6
  from sqlalchemy import distinct, func, select
9
- from sqlalchemy.ext.asyncio import AsyncSession
10
7
  from strawberry.dataloader import DataLoader
11
8
  from typing_extensions import TypeAlias
12
9
 
13
10
  from phoenix.db import models
11
+ from phoenix.server.types import DbSessionFactory
14
12
 
15
13
  ExperimentId: TypeAlias = int
16
14
  Key: TypeAlias = ExperimentId
@@ -18,7 +16,7 @@ Result: TypeAlias = Optional[int]
18
16
 
19
17
 
20
18
  class ExperimentSequenceNumberDataLoader(DataLoader[Key, Result]):
21
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
19
+ def __init__(self, db: DbSessionFactory) -> None:
22
20
  super().__init__(load_fn=self._load_fn)
23
21
  self._db = db
24
22
 
@@ -2,9 +2,7 @@ from collections import defaultdict
2
2
  from datetime import datetime
3
3
  from typing import (
4
4
  Any,
5
- AsyncContextManager,
6
5
  AsyncIterator,
7
- Callable,
8
6
  DefaultDict,
9
7
  List,
10
8
  Literal,
@@ -36,6 +34,7 @@ from phoenix.db import models
36
34
  from phoenix.db.helpers import SupportedSQLDialect
37
35
  from phoenix.server.api.dataloaders.cache import TwoTierCache
38
36
  from phoenix.server.api.input_types.TimeRange import TimeRange
37
+ from phoenix.server.types import DbSessionFactory
39
38
  from phoenix.trace.dsl import SpanFilter
40
39
 
41
40
  Kind: TypeAlias = Literal["span", "trace"]
@@ -88,7 +87,7 @@ class LatencyMsQuantileCache(
88
87
  class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
89
88
  def __init__(
90
89
  self,
91
- db: Callable[[], AsyncContextManager[AsyncSession]],
90
+ db: DbSessionFactory,
92
91
  cache_map: Optional[AbstractCache[Key, Result]] = None,
93
92
  ) -> None:
94
93
  super().__init__(
@@ -1,8 +1,6 @@
1
1
  from collections import defaultdict
2
2
  from datetime import datetime
3
3
  from typing import (
4
- AsyncContextManager,
5
- Callable,
6
4
  DefaultDict,
7
5
  List,
8
6
  Literal,
@@ -12,12 +10,12 @@ from typing import (
12
10
 
13
11
  from cachetools import LFUCache
14
12
  from sqlalchemy import func, select
15
- from sqlalchemy.ext.asyncio import AsyncSession
16
13
  from strawberry.dataloader import AbstractCache, DataLoader
17
14
  from typing_extensions import TypeAlias, assert_never
18
15
 
19
16
  from phoenix.db import models
20
17
  from phoenix.server.api.dataloaders.cache import TwoTierCache
18
+ from phoenix.server.types import DbSessionFactory
21
19
 
22
20
  Kind: TypeAlias = Literal["start", "end"]
23
21
  ProjectRowId: TypeAlias = int
@@ -50,7 +48,7 @@ class MinStartOrMaxEndTimeCache(
50
48
  class MinStartOrMaxEndTimeDataLoader(DataLoader[Key, Result]):
51
49
  def __init__(
52
50
  self,
53
- db: Callable[[], AsyncContextManager[AsyncSession]],
51
+ db: DbSessionFactory,
54
52
  cache_map: Optional[AbstractCache[Key, Result]] = None,
55
53
  ) -> None:
56
54
  super().__init__(
@@ -1,12 +1,12 @@
1
1
  from collections import defaultdict
2
- from typing import AsyncContextManager, Callable, DefaultDict, List, Optional
2
+ from typing import DefaultDict, List, Optional
3
3
 
4
4
  from sqlalchemy import select
5
- from sqlalchemy.ext.asyncio import AsyncSession
6
5
  from strawberry.dataloader import DataLoader
7
6
  from typing_extensions import TypeAlias
8
7
 
9
8
  from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
10
 
11
11
  ProjectName: TypeAlias = str
12
12
  Key: TypeAlias = ProjectName
@@ -14,7 +14,7 @@ Result: TypeAlias = Optional[models.Project]
14
14
 
15
15
 
16
16
  class ProjectByNameDataLoader(DataLoader[Key, Result]):
17
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
17
+ def __init__(self, db: DbSessionFactory) -> None:
18
18
  super().__init__(load_fn=self._load_fn)
19
19
  self._db = db
20
20
 
@@ -2,8 +2,6 @@ from collections import defaultdict
2
2
  from datetime import datetime
3
3
  from typing import (
4
4
  Any,
5
- AsyncContextManager,
6
- Callable,
7
5
  DefaultDict,
8
6
  List,
9
7
  Literal,
@@ -13,13 +11,13 @@ from typing import (
13
11
 
14
12
  from cachetools import LFUCache, TTLCache
15
13
  from sqlalchemy import Select, func, select
16
- from sqlalchemy.ext.asyncio import AsyncSession
17
14
  from strawberry.dataloader import AbstractCache, DataLoader
18
15
  from typing_extensions import TypeAlias, assert_never
19
16
 
20
17
  from phoenix.db import models
21
18
  from phoenix.server.api.dataloaders.cache import TwoTierCache
22
19
  from phoenix.server.api.input_types.TimeRange import TimeRange
20
+ from phoenix.server.types import DbSessionFactory
23
21
  from phoenix.trace.dsl import SpanFilter
24
22
 
25
23
  Kind: TypeAlias = Literal["span", "trace"]
@@ -69,7 +67,7 @@ class RecordCountCache(
69
67
  class RecordCountDataLoader(DataLoader[Key, Result]):
70
68
  def __init__(
71
69
  self,
72
- db: Callable[[], AsyncContextManager[AsyncSession]],
70
+ db: DbSessionFactory,
73
71
  cache_map: Optional[AbstractCache[Key, Result]] = None,
74
72
  ) -> None:
75
73
  super().__init__(
@@ -1,25 +1,23 @@
1
1
  from collections import defaultdict
2
2
  from typing import (
3
- AsyncContextManager,
4
- Callable,
5
3
  DefaultDict,
6
4
  List,
7
5
  )
8
6
 
9
7
  from sqlalchemy import select
10
- from sqlalchemy.ext.asyncio import AsyncSession
11
8
  from strawberry.dataloader import DataLoader
12
9
  from typing_extensions import TypeAlias
13
10
 
14
11
  from phoenix.db import models
15
12
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
13
+ from phoenix.server.types import DbSessionFactory
16
14
 
17
15
  Key: TypeAlias = int
18
16
  Result: TypeAlias = List[SpanAnnotation]
19
17
 
20
18
 
21
19
  class SpanAnnotationsDataLoader(DataLoader[Key, Result]):
22
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
20
+ def __init__(self, db: DbSessionFactory) -> None:
23
21
  super().__init__(load_fn=self._load_fn)
24
22
  self._db = db
25
23
 
@@ -0,0 +1,36 @@
1
+ from typing import (
2
+ Dict,
3
+ List,
4
+ )
5
+
6
+ from sqlalchemy import select
7
+ from strawberry.dataloader import DataLoader
8
+ from typing_extensions import TypeAlias
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.types import DbSessionFactory
12
+
13
+ SpanID: TypeAlias = int
14
+ Key: TypeAlias = SpanID
15
+ Result: TypeAlias = List[models.DatasetExample]
16
+
17
+
18
+ class SpanDatasetExamplesDataLoader(DataLoader[Key, Result]):
19
+ def __init__(self, db: DbSessionFactory) -> None:
20
+ super().__init__(load_fn=self._load_fn)
21
+ self._db = db
22
+
23
+ async def _load_fn(self, keys: List[Key]) -> List[Result]:
24
+ span_rowids = keys
25
+ async with self._db() as session:
26
+ dataset_examples: Dict[Key, List[models.DatasetExample]] = {
27
+ span_rowid: [] for span_rowid in span_rowids
28
+ }
29
+ async for span_rowid, dataset_example in await session.stream(
30
+ select(models.Span.id, models.DatasetExample)
31
+ .select_from(models.Span)
32
+ .join(models.DatasetExample, models.DatasetExample.span_rowid == models.Span.id)
33
+ .where(models.Span.id.in_(span_rowids))
34
+ ):
35
+ dataset_examples[span_rowid].append(dataset_example)
36
+ return [dataset_examples.get(span_rowid, []) for span_rowid in span_rowids]
@@ -1,19 +1,17 @@
1
1
  from random import randint
2
2
  from typing import (
3
- AsyncContextManager,
4
- Callable,
5
3
  Dict,
6
4
  List,
7
5
  )
8
6
 
9
7
  from aioitertools.itertools import groupby
10
8
  from sqlalchemy import select
11
- from sqlalchemy.ext.asyncio import AsyncSession
12
9
  from sqlalchemy.orm import joinedload
13
10
  from strawberry.dataloader import DataLoader
14
11
  from typing_extensions import TypeAlias
15
12
 
16
13
  from phoenix.db import models
14
+ from phoenix.server.types import DbSessionFactory
17
15
 
18
16
  SpanId: TypeAlias = str
19
17
 
@@ -22,7 +20,7 @@ Result: TypeAlias = List[models.Span]
22
20
 
23
21
 
24
22
  class SpanDescendantsDataLoader(DataLoader[Key, Result]):
25
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
23
+ def __init__(self, db: DbSessionFactory) -> None:
26
24
  super().__init__(load_fn=self._load_fn)
27
25
  self._db = db
28
26
 
@@ -1,25 +1,23 @@
1
1
  from collections import defaultdict
2
2
  from typing import (
3
- AsyncContextManager,
4
- Callable,
5
3
  DefaultDict,
6
4
  List,
7
5
  )
8
6
 
9
7
  from sqlalchemy import select
10
- from sqlalchemy.ext.asyncio import AsyncSession
11
8
  from strawberry.dataloader import DataLoader
12
9
  from typing_extensions import TypeAlias
13
10
 
14
11
  from phoenix.db import models
15
12
  from phoenix.server.api.types.Evaluation import SpanEvaluation
13
+ from phoenix.server.types import DbSessionFactory
16
14
 
17
15
  Key: TypeAlias = int
18
16
  Result: TypeAlias = List[SpanEvaluation]
19
17
 
20
18
 
21
19
  class SpanEvaluationsDataLoader(DataLoader[Key, Result]):
22
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
20
+ def __init__(self, db: DbSessionFactory) -> None:
23
21
  super().__init__(load_fn=self._load_fn)
24
22
  self._db = db
25
23
 
@@ -1,11 +1,11 @@
1
- from typing import AsyncContextManager, Callable, List, Union
1
+ from typing import List, Union
2
2
 
3
3
  from sqlalchemy import select
4
- from sqlalchemy.ext.asyncio import AsyncSession
5
4
  from strawberry.dataloader import DataLoader
6
5
  from typing_extensions import TypeAlias
7
6
 
8
7
  from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
9
 
10
10
  SpanID: TypeAlias = int
11
11
  Key: TypeAlias = SpanID
@@ -13,7 +13,7 @@ Result: TypeAlias = models.Project
13
13
 
14
14
 
15
15
  class SpanProjectsDataLoader(DataLoader[Key, Result]):
16
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
16
+ def __init__(self, db: DbSessionFactory) -> None:
17
17
  super().__init__(load_fn=self._load_fn)
18
18
  self._db = db
19
19
 
@@ -2,8 +2,6 @@ from collections import defaultdict
2
2
  from datetime import datetime
3
3
  from typing import (
4
4
  Any,
5
- AsyncContextManager,
6
- Callable,
7
5
  DefaultDict,
8
6
  List,
9
7
  Literal,
@@ -14,7 +12,6 @@ from typing import (
14
12
  from cachetools import LFUCache, TTLCache
15
13
  from openinference.semconv.trace import SpanAttributes
16
14
  from sqlalchemy import Select, func, select
17
- from sqlalchemy.ext.asyncio import AsyncSession
18
15
  from sqlalchemy.sql.functions import coalesce
19
16
  from strawberry.dataloader import AbstractCache, DataLoader
20
17
  from typing_extensions import TypeAlias
@@ -22,6 +19,7 @@ from typing_extensions import TypeAlias
22
19
  from phoenix.db import models
23
20
  from phoenix.server.api.dataloaders.cache import TwoTierCache
24
21
  from phoenix.server.api.input_types.TimeRange import TimeRange
22
+ from phoenix.server.types import DbSessionFactory
25
23
  from phoenix.trace.dsl import SpanFilter
26
24
 
27
25
  Kind: TypeAlias = Literal["prompt", "completion", "total"]
@@ -71,7 +69,7 @@ class TokenCountCache(
71
69
  class TokenCountDataLoader(DataLoader[Key, Result]):
72
70
  def __init__(
73
71
  self,
74
- db: Callable[[], AsyncContextManager[AsyncSession]],
72
+ db: DbSessionFactory,
75
73
  cache_map: Optional[AbstractCache[Key, Result]] = None,
76
74
  ) -> None:
77
75
  super().__init__(
@@ -1,25 +1,23 @@
1
1
  from collections import defaultdict
2
2
  from typing import (
3
- AsyncContextManager,
4
- Callable,
5
3
  DefaultDict,
6
4
  List,
7
5
  )
8
6
 
9
7
  from sqlalchemy import select
10
- from sqlalchemy.ext.asyncio import AsyncSession
11
8
  from strawberry.dataloader import DataLoader
12
9
  from typing_extensions import TypeAlias
13
10
 
14
11
  from phoenix.db import models
15
12
  from phoenix.server.api.types.Evaluation import TraceEvaluation
13
+ from phoenix.server.types import DbSessionFactory
16
14
 
17
15
  Key: TypeAlias = int
18
16
  Result: TypeAlias = List[TraceEvaluation]
19
17
 
20
18
 
21
19
  class TraceEvaluationsDataLoader(DataLoader[Key, Result]):
22
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
20
+ def __init__(self, db: DbSessionFactory) -> None:
23
21
  super().__init__(load_fn=self._load_fn)
24
22
  self._db = db
25
23
 
@@ -1,17 +1,15 @@
1
1
  from typing import (
2
- AsyncContextManager,
3
- Callable,
4
2
  List,
5
3
  Optional,
6
4
  Tuple,
7
5
  )
8
6
 
9
7
  from sqlalchemy import select
10
- from sqlalchemy.ext.asyncio import AsyncSession
11
8
  from strawberry.dataloader import DataLoader
12
9
  from typing_extensions import TypeAlias
13
10
 
14
11
  from phoenix.db import models
12
+ from phoenix.server.types import DbSessionFactory
15
13
 
16
14
  TraceId: TypeAlias = str
17
15
  Key: TypeAlias = TraceId
@@ -21,7 +19,7 @@ Result: TypeAlias = Optional[Tuple[TraceRowId, ProjectRowId]]
21
19
 
22
20
 
23
21
  class TraceRowIdsDataLoader(DataLoader[Key, Result]):
24
- def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
22
+ def __init__(self, db: DbSessionFactory) -> None:
25
23
  super().__init__(load_fn=self._load_fn)
26
24
  self._db = db
27
25
 
@@ -0,0 +1,17 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+ from phoenix.server.api.types.SortDir import SortDir
6
+
7
+
8
+ @strawberry.enum
9
+ class SpanAnnotationColumn(Enum):
10
+ createdAt = "created_at"
11
+ name = "name"
12
+
13
+
14
+ @strawberry.input(description="The sort key and direction for SpanAnnotation connections")
15
+ class SpanAnnotationSort:
16
+ col: SpanAnnotationColumn
17
+ dir: SortDir
@@ -0,0 +1,17 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+ from phoenix.server.api.types.SortDir import SortDir
6
+
7
+
8
+ @strawberry.enum
9
+ class TraceAnnotationColumn(Enum):
10
+ createdAt = "created_at"
11
+ name = "name"
12
+
13
+
14
+ @strawberry.input(description="The sort key and direction for TraceAnnotation connections")
15
+ class TraceAnnotationSort:
16
+ col: TraceAnnotationColumn
17
+ dir: SortDir
@@ -11,6 +11,7 @@ from phoenix.server.api.input_types.CreateSpanAnnotationInput import CreateSpanA
11
11
  from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
12
12
  from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
13
13
  from phoenix.server.api.mutations.auth import IsAuthenticated
14
+ from phoenix.server.api.queries import Query
14
15
  from phoenix.server.api.types.node import from_global_id_with_expected_type
15
16
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
16
17
 
@@ -18,6 +19,7 @@ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_
18
19
  @strawberry.type
19
20
  class SpanAnnotationMutationPayload:
20
21
  span_annotations: List[SpanAnnotation]
22
+ query: Query
21
23
 
22
24
 
23
25
  @strawberry.type
@@ -49,7 +51,8 @@ class SpanAnnotationMutationMixin:
49
51
  return SpanAnnotationMutationPayload(
50
52
  span_annotations=[
51
53
  to_gql_span_annotation(annotation) for annotation in inserted_annotations
52
- ]
54
+ ],
55
+ query=Query(),
53
56
  )
54
57
 
55
58
  @strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
@@ -89,7 +92,7 @@ class SpanAnnotationMutationMixin:
89
92
  if span_annotation is not None:
90
93
  patched_annotations.append(to_gql_span_annotation(span_annotation))
91
94
 
92
- return SpanAnnotationMutationPayload(span_annotations=patched_annotations)
95
+ return SpanAnnotationMutationPayload(span_annotations=patched_annotations, query=Query())
93
96
 
94
97
  @strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
95
98
  async def delete_span_annotations(
@@ -111,4 +114,6 @@ class SpanAnnotationMutationMixin:
111
114
  deleted_annotations_gql = [
112
115
  to_gql_span_annotation(annotation) for annotation in deleted_annotations
113
116
  ]
114
- return SpanAnnotationMutationPayload(span_annotations=deleted_annotations_gql)
117
+ return SpanAnnotationMutationPayload(
118
+ span_annotations=deleted_annotations_gql, query=Query()
119
+ )
@@ -11,6 +11,7 @@ from phoenix.server.api.input_types.CreateTraceAnnotationInput import CreateTrac
11
11
  from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
12
12
  from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
13
13
  from phoenix.server.api.mutations.auth import IsAuthenticated
14
+ from phoenix.server.api.queries import Query
14
15
  from phoenix.server.api.types.node import from_global_id_with_expected_type
15
16
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
16
17
 
@@ -18,6 +19,7 @@ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_tra
18
19
  @strawberry.type
19
20
  class TraceAnnotationMutationPayload:
20
21
  trace_annotations: List[TraceAnnotation]
22
+ query: Query
21
23
 
22
24
 
23
25
  @strawberry.type
@@ -49,7 +51,8 @@ class TraceAnnotationMutationMixin:
49
51
  return TraceAnnotationMutationPayload(
50
52
  trace_annotations=[
51
53
  to_gql_trace_annotation(annotation) for annotation in inserted_annotations
52
- ]
54
+ ],
55
+ query=Query(),
53
56
  )
54
57
 
55
58
  @strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
@@ -89,7 +92,7 @@ class TraceAnnotationMutationMixin:
89
92
  if trace_annotation:
90
93
  patched_annotations.append(to_gql_trace_annotation(trace_annotation))
91
94
 
92
- return TraceAnnotationMutationPayload(trace_annotations=patched_annotations)
95
+ return TraceAnnotationMutationPayload(trace_annotations=patched_annotations, query=Query())
93
96
 
94
97
  @strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
95
98
  async def delete_trace_annotations(
@@ -111,4 +114,6 @@ class TraceAnnotationMutationMixin:
111
114
  deleted_annotations_gql = [
112
115
  to_gql_trace_annotation(annotation) for annotation in deleted_annotations
113
116
  ]
114
- return TraceAnnotationMutationPayload(trace_annotations=deleted_annotations_gql)
117
+ return TraceAnnotationMutationPayload(
118
+ trace_annotations=deleted_annotations_gql, query=Query()
119
+ )
@@ -1,6 +1,22 @@
1
+ import json
2
+ from argparse import ArgumentParser
3
+ from typing import Optional, Tuple
4
+
1
5
  from .schema import get_openapi_schema
2
6
 
3
7
  if __name__ == "__main__":
4
- import yaml # type: ignore
8
+ parser = ArgumentParser()
9
+ parser.add_argument(
10
+ "--compress",
11
+ action="store_true",
12
+ help="Whether to output a compressed version of the OpenAPI schema",
13
+ )
14
+ args = parser.parse_args()
5
15
 
6
- print(yaml.dump(get_openapi_schema(), indent=2))
16
+ indent: Optional[int] = None
17
+ separator: Optional[Tuple[str, str]] = None
18
+ if args.compress:
19
+ separator = (",", ":")
20
+ else:
21
+ indent = 2
22
+ print(json.dumps(get_openapi_schema(), indent=indent, separators=separator))
@@ -1,16 +1,16 @@
1
- from typing import Any
1
+ from typing import Any, Dict
2
2
 
3
- from starlette.schemas import SchemaGenerator
3
+ from fastapi.openapi.utils import get_openapi
4
4
 
5
- from phoenix.server.api.routers.v1 import V1_ROUTES
5
+ from phoenix.server.api.routers.v1 import REST_API_VERSION
6
+ from phoenix.server.api.routers.v1 import router as v1_router
6
7
 
7
- OPENAPI_SCHEMA_GENERATOR = SchemaGenerator(
8
- {"openapi": "3.0.0", "info": {"title": "Arize-Phoenix API", "version": "1.0"}}
9
- )
10
8
 
11
-
12
- def get_openapi_schema() -> Any:
13
- """
14
- Exports an OpenAPI schema for the Phoenix REST API as a JSON object.
15
- """
16
- return OPENAPI_SCHEMA_GENERATOR.get_schema(V1_ROUTES) # type: ignore
9
+ def get_openapi_schema() -> Dict[str, Any]:
10
+ return get_openapi(
11
+ title="Arize-Phoenix REST API",
12
+ version=REST_API_VERSION,
13
+ openapi_version="3.1.0",
14
+ description="Schema for Arize-Phoenix REST API",
15
+ routes=v1_router.routes,
16
+ )