arize-phoenix 4.12.1rc1__py3-none-any.whl → 4.14.1__py3-none-any.whl

This diff shows the content of these publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (51)
  1. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.14.1.dist-info}/METADATA +12 -9
  2. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.14.1.dist-info}/RECORD +48 -49
  3. phoenix/db/bulk_inserter.py +3 -1
  4. phoenix/experiments/evaluators/base.py +4 -0
  5. phoenix/experiments/evaluators/code_evaluators.py +80 -0
  6. phoenix/experiments/evaluators/llm_evaluators.py +77 -1
  7. phoenix/experiments/evaluators/utils.py +70 -21
  8. phoenix/experiments/functions.py +14 -14
  9. phoenix/server/api/context.py +7 -3
  10. phoenix/server/api/dataloaders/average_experiment_run_latency.py +23 -23
  11. phoenix/server/api/dataloaders/experiment_error_rates.py +30 -10
  12. phoenix/server/api/dataloaders/experiment_run_counts.py +18 -5
  13. phoenix/server/api/input_types/{CreateSpanAnnotationsInput.py → CreateSpanAnnotationInput.py} +4 -2
  14. phoenix/server/api/input_types/{CreateTraceAnnotationsInput.py → CreateTraceAnnotationInput.py} +4 -2
  15. phoenix/server/api/input_types/{PatchAnnotationsInput.py → PatchAnnotationInput.py} +4 -2
  16. phoenix/server/api/mutations/span_annotations_mutations.py +12 -6
  17. phoenix/server/api/mutations/trace_annotations_mutations.py +12 -6
  18. phoenix/server/api/openapi/main.py +2 -18
  19. phoenix/server/api/openapi/schema.py +12 -12
  20. phoenix/server/api/routers/v1/__init__.py +83 -36
  21. phoenix/server/api/routers/v1/dataset_examples.py +123 -102
  22. phoenix/server/api/routers/v1/datasets.py +506 -390
  23. phoenix/server/api/routers/v1/evaluations.py +66 -73
  24. phoenix/server/api/routers/v1/experiment_evaluations.py +91 -68
  25. phoenix/server/api/routers/v1/experiment_runs.py +155 -98
  26. phoenix/server/api/routers/v1/experiments.py +181 -132
  27. phoenix/server/api/routers/v1/spans.py +173 -144
  28. phoenix/server/api/routers/v1/traces.py +128 -115
  29. phoenix/server/api/types/Experiment.py +2 -2
  30. phoenix/server/api/types/Inferences.py +1 -2
  31. phoenix/server/api/types/Model.py +1 -2
  32. phoenix/server/app.py +177 -152
  33. phoenix/server/openapi/docs.py +221 -0
  34. phoenix/server/static/.vite/manifest.json +31 -31
  35. phoenix/server/static/assets/{components-C8sm_r1F.js → components-DeS0YEmv.js} +2 -2
  36. phoenix/server/static/assets/index-CQgXRwU0.js +100 -0
  37. phoenix/server/static/assets/{pages-bN7juCjh.js → pages-hdjlFZhO.js} +275 -198
  38. phoenix/server/static/assets/{vendor-CUDAPm8e.js → vendor-DPvSDRn3.js} +1 -1
  39. phoenix/server/static/assets/{vendor-arizeai-Do2HOmcL.js → vendor-arizeai-CkvPT67c.js} +2 -2
  40. phoenix/server/static/assets/{vendor-codemirror-CrdxOlMs.js → vendor-codemirror-Cqwpwlua.js} +1 -1
  41. phoenix/server/static/assets/{vendor-recharts-PKRvByVe.js → vendor-recharts-5jlNaZuF.js} +1 -1
  42. phoenix/server/thread_server.py +2 -2
  43. phoenix/session/client.py +9 -8
  44. phoenix/trace/dsl/filter.py +40 -25
  45. phoenix/version.py +1 -1
  46. phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
  47. phoenix/server/api/routers/v1/utils.py +0 -95
  48. phoenix/server/static/assets/index-BEKPzgQs.js +0 -100
  49. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.14.1.dist-info}/WHEEL +0 -0
  50. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.14.1.dist-info}/licenses/IP_NOTICE +0 -0
  51. {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.14.1.dist-info}/licenses/LICENSE +0 -0

phoenix/experiments/evaluators/utils.py

@@ -1,6 +1,5 @@
  import functools
  import inspect
- from itertools import chain, islice, repeat
  from typing import TYPE_CHECKING, Any, Callable, Optional, Union

  from phoenix.experiments.types import (
@@ -75,6 +74,72 @@ def create_evaluator(
      name: Optional[str] = None,
      scorer: Optional[Callable[[Any], EvaluationResult]] = None,
  ) -> Callable[[Callable[..., Any]], "Evaluator"]:
+     """
+     A decorator that configures a sync or async function to be used as an experiment evaluator.
+
+     If the `evaluator` is a function of one argument then that argument will be
+     bound to the `output` of an experiment task. Alternatively, the `evaluator` can be a function
+     of any combination of specific argument names that will be bound to special values:
+     `input`: The input field of the dataset example
+     `output`: The output of an experiment task
+     `expected`: The expected or reference output of the dataset example
+     `reference`: An alias for `expected`
+     `metadata`: Metadata associated with the dataset example
+
+     Args:
+         kind (str | AnnotatorKind): Broadly indicates how the evaluator scores an experiment run.
+             Valid kinds are: "CODE", "LLM". Defaults to "CODE".
+         name (str, optional): The name of the evaluator. If not provided, the name of the function
+             will be used.
+         scorer (callable, optional): An optional function that converts the output of the wrapped
+             function into an `EvaluationResult`. This allows configuring the evaluation
+             payload by setting a label, score and explanation. By default, numeric outputs will
+             be recorded as scores, boolean outputs will be recorded as scores and labels, and
+             string outputs will be recorded as labels. If the output is a 2-tuple, the first item
+             will be recorded as the score and the second item will be recorded as the explanation.
+
+     Examples:
+         Configuring an evaluator that returns a boolean
+
+         .. code-block:: python
+             @create_evaluator(kind="CODE", name="exact-match")
+             def match(output: str, expected: str) -> bool:
+                 return output == expected
+
+         Configuring an evaluator that returns a label
+
+         .. code-block:: python
+             client = openai.Client()
+
+             @create_evaluator(kind="LLM")
+             def label(output: str) -> str:
+                 res = client.chat.completions.create(
+                     model = "gpt-4",
+                     messages = [
+                         {
+                             "role": "user",
+                             "content": (
+                                 "in one word, characterize the sentiment of the following customer "
+                                 f"request: {output}"
+                             )
+                         },
+                     ],
+                 )
+                 label = res.choices[0].message.content
+                 return label
+
+         Configuring an evaluator that returns a score and explanation
+
+         .. code-block:: python
+             from textdistance import levenshtein
+
+             @create_evaluator(kind="CODE", name="levenshtein-distance")
+             def ld(output: str, expected: str) -> Tuple[float, str]:
+                 return (
+                     levenshtein(output, expected),
+                     f"Levenshtein distance between {output} and {expected}"
+                 )
+     """
      if scorer is None:
          scorer = _default_eval_scorer

@@ -163,24 +228,8 @@ def _default_eval_scorer(result: Any) -> EvaluationResult:
          return EvaluationResult(score=float(result))
      if isinstance(result, str):
          return EvaluationResult(label=result)
-     if isinstance(result, (tuple, list)) and 0 < len(result) <= 3:
-         # Possible interpretations are:
-         # - 3-tuple: (Score, Label, Explanation)
-         # - 2-tuple: (Score, Explanation) or (Label, Explanation)
-         # - 1-tuple: (Score, ) or (Label, )
-         # Note that (Score, Label) conflicts with (Score, Explanation) and we
-         # pick the latter because it's probably more prevalent. To get
-         # (Score, Label), use a 3-tuple instead, i.e. (Score, Label, None).
-         a, b, c = islice(chain(result, repeat(None)), 3)
-         score, label, explanation = None, a, b
-         if hasattr(a, "__float__"):
-             try:
-                 score = float(a)
-             except ValueError:
-                 pass
-             else:
-                 label, explanation = (None, b) if len(result) < 3 else (b, c)
-         return EvaluationResult(score=score, label=label, explanation=explanation)
-     if result is None:
-         return EvaluationResult(score=0)
+     if isinstance(result, (tuple, list)) and len(result) == 2:
+         # If the result is a 2-tuple, the first item will be recorded as the score
+         # and the second item will be recorded as the explanation.
+         return EvaluationResult(score=float(result[0]), explanation=str(result[1]))
      raise ValueError(f"Unsupported evaluation result type: {type(result)}")
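
The net effect of this hunk is that the default scorer no longer guesses among 1-, 2-, and 3-tuple interpretations and no longer maps a None result to a score of 0: a 2-element tuple or list is always read as (score, explanation), and anything else unsupported raises. A minimal, self-contained sketch of the resulting scoring rules (the EvalResult stand-in and the boolean-label formatting are assumptions, not the package's actual types):

    from dataclasses import dataclass
    from typing import Any, Optional


    @dataclass
    class EvalResult:
        # Stand-in for phoenix.experiments.types.EvaluationResult (field names assumed).
        score: Optional[float] = None
        label: Optional[str] = None
        explanation: Optional[str] = None


    def default_scorer(result: Any) -> EvalResult:
        # Simplified mirror of the post-change branches shown in the hunk above.
        if isinstance(result, bool):
            # Per the new docstring, booleans become both a score and a label.
            return EvalResult(score=float(result), label=str(result))
        if isinstance(result, (int, float)):
            return EvalResult(score=float(result))
        if isinstance(result, str):
            return EvalResult(label=result)
        if isinstance(result, (tuple, list)) and len(result) == 2:
            # 2-tuples are now always (score, explanation).
            return EvalResult(score=float(result[0]), explanation=str(result[1]))
        raise ValueError(f"Unsupported evaluation result type: {type(result)}")


    print(default_scorer(True))            # EvalResult(score=1.0, label='True', explanation=None)
    print(default_scorer((0.8, "close")))  # EvalResult(score=0.8, label=None, explanation='close')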

phoenix/experiments/functions.py

@@ -120,21 +120,23 @@ def run_experiment(
      output. If the `task` is a function of one argument then that argument will be bound to the
      `input` field of the dataset example. Alternatively, the `task` can be a function of any
      combination of specific argument names that will be bound to special values:
-     `input`: The input field of the dataset example
-     `expected`: The expected or reference output of the dataset example
-     `reference`: An alias for `expected`
-     `metadata`: Metadata associated with the dataset example
-     `example`: The dataset `Example` object with all associated fields
+
+     - `input`: The input field of the dataset example
+     - `expected`: The expected or reference output of the dataset example
+     - `reference`: An alias for `expected`
+     - `metadata`: Metadata associated with the dataset example
+     - `example`: The dataset `Example` object with all associated fields

      An `evaluator` is either a synchronous or asynchronous function that returns either a boolean
      or numeric "score". If the `evaluator` is a function of one argument then that argument will be
      bound to the `output` of the task. Alternatively, the `evaluator` can be a function of any
      combination of specific argument names that will be bound to special values:
-     `input`: The input field of the dataset example
-     `output`: The output of the task
-     `expected`: The expected or reference output of the dataset example
-     `reference`: An alias for `expected`
-     `metadata`: Metadata associated with the dataset example
+
+     - `input`: The input field of the dataset example
+     - `output`: The output of the task
+     - `expected`: The expected or reference output of the dataset example
+     - `reference`: An alias for `expected`
+     - `metadata`: Metadata associated with the dataset example

      Phoenix also provides pre-built evaluators in the `phoenix.experiments.evaluators` module.

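Since only the argument names matter for binding, a task and an evaluator can be plain functions. A hypothetical pair written against the rules listed in this docstring (the run_experiment call shape in the trailing comment is an assumption, as the full signature is not shown in this diff):

    def task(input: str, metadata: dict) -> str:
        # `input` and `metadata` are bound by name to fields of the dataset example.
        return f"{input} ({metadata.get('locale', 'en')})"


    def contains_expected(output: str, expected: str) -> bool:
        # `output` is the task's output; `expected` is the example's reference output.
        return expected in output


    # The functions are ordinary callables and can be checked without Phoenix:
    print(task("hello", {"locale": "fr"}))           # hello (fr)
    print(contains_expected("hello (fr)", "hello"))  # True

    # With Phoenix they would be passed to run_experiment, roughly:
    # run_experiment(dataset, task, evaluators=[contains_expected])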
@@ -366,10 +368,9 @@ def run_experiment(
          return exp_run

      _errors: Tuple[Type[BaseException], ...]
-     if not hasattr(rate_limit_errors, "__iter__"):
+     if not isinstance(rate_limit_errors, Sequence):
          _errors = (rate_limit_errors,) if rate_limit_errors is not None else ()
      else:
-         rate_limit_errors = cast(Sequence[Type[BaseException]], rate_limit_errors)
          _errors = tuple(filter(None, rate_limit_errors))
      rate_limiters = [RateLimiter(rate_limit_error=rate_limit_error) for rate_limit_error in _errors]

@@ -606,10 +607,9 @@ def evaluate_experiment(
          return eval_run

      _errors: Tuple[Type[BaseException], ...]
-     if not hasattr(rate_limit_errors, "__iter__"):
+     if not isinstance(rate_limit_errors, Sequence):
          _errors = (rate_limit_errors,) if rate_limit_errors is not None else ()
      else:
-         rate_limit_errors = cast(Sequence[Type[BaseException]], rate_limit_errors)
          _errors = tuple(filter(None, rate_limit_errors))
      rate_limiters = [RateLimiter(rate_limit_error=rate_limit_error) for rate_limit_error in _errors]

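Both of these hunks replace a duck-typed hasattr(..., "__iter__") test with an explicit Sequence check when normalizing rate_limit_errors. A standalone sketch of the normalization as it reads after the change (names simplified, not the package's code):

    from collections.abc import Sequence
    from typing import Tuple, Type, Union

    RateLimitErrors = Union[None, Type[BaseException], Sequence]


    def normalize(rate_limit_errors: RateLimitErrors) -> Tuple[Type[BaseException], ...]:
        # Exception classes are not Sequences, so the isinstance test cleanly separates
        # "a single error type (or None)" from "a sequence of error types".
        if not isinstance(rate_limit_errors, Sequence):
            return (rate_limit_errors,) if rate_limit_errors is not None else ()
        # A sequence may contain None entries; filter(None, ...) drops them.
        return tuple(filter(None, rate_limit_errors))


    print(normalize(None))                  # ()
    print(normalize(TimeoutError))          # (TimeoutError,)
    print(normalize([TimeoutError, None]))  # (TimeoutError,)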

phoenix/server/api/context.py

@@ -1,10 +1,12 @@
  from dataclasses import dataclass
  from datetime import datetime
  from pathlib import Path
- from typing import AsyncContextManager, Callable, Optional
+ from typing import AsyncContextManager, Callable, Optional, Union

  from sqlalchemy.ext.asyncio import AsyncSession
- from strawberry.fastapi import BaseContext
+ from starlette.requests import Request
+ from starlette.responses import Response
+ from starlette.websockets import WebSocket
  from typing_extensions import TypeAlias

  from phoenix.core.model_schema import Model
@@ -65,7 +67,9 @@ ProjectRowId: TypeAlias = int


  @dataclass
- class Context(BaseContext):
+ class Context:
+     request: Union[Request, WebSocket]
+     response: Optional[Response]
      db: Callable[[], AsyncContextManager[AsyncSession]]
      data_loaders: DataLoaders
      cache_for_dataloaders: Optional[CacheForDataLoaders]

phoenix/server/api/dataloaders/average_experiment_run_latency.py

@@ -1,8 +1,4 @@
- from typing import (
-     AsyncContextManager,
-     Callable,
-     List,
- )
+ from typing import AsyncContextManager, Callable, List, Optional

  from sqlalchemy import func, select
  from sqlalchemy.ext.asyncio import AsyncSession
@@ -12,7 +8,7 @@ from typing_extensions import TypeAlias
  from phoenix.db import models

  ExperimentID: TypeAlias = int
- RunLatency: TypeAlias = float
+ RunLatency: TypeAlias = Optional[float]
  Key: TypeAlias = ExperimentID
  Result: TypeAlias = RunLatency

@@ -27,26 +23,30 @@ class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):

      async def _load_fn(self, keys: List[Key]) -> List[Result]:
          experiment_ids = keys
+         resolved_experiment_ids = (
+             select(models.Experiment.id)
+             .where(models.Experiment.id.in_(set(experiment_ids)))
+             .subquery()
+         )
+         query = (
+             select(
+                 resolved_experiment_ids.c.id,
+                 func.avg(
+                     func.extract("epoch", models.ExperimentRun.end_time)
+                     - func.extract("epoch", models.ExperimentRun.start_time)
+                 ),
+             )
+             .outerjoin_from(
+                 from_=resolved_experiment_ids,
+                 target=models.ExperimentRun,
+                 onclause=resolved_experiment_ids.c.id == models.ExperimentRun.experiment_id,
+             )
+             .group_by(resolved_experiment_ids.c.id)
+         )
          async with self._db() as session:
              avg_latencies = {
                  experiment_id: avg_latency
-                 async for experiment_id, avg_latency in await session.stream(
-                     select(
-                         models.ExperimentRun.experiment_id,
-                         func.avg(
-                             func.extract(
-                                 "epoch",
-                                 models.ExperimentRun.end_time,
-                             )
-                             - func.extract(
-                                 "epoch",
-                                 models.ExperimentRun.start_time,
-                             )
-                         ),
-                     )
-                     .where(models.ExperimentRun.experiment_id.in_(set(experiment_ids)))
-                     .group_by(models.ExperimentRun.experiment_id)
-                 )
+                 async for experiment_id, avg_latency in await session.stream(query)
              }
          return [
              avg_latencies.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
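
The rewritten loader (and the error-rate and run-count loaders changed below) resolves the requested experiment ids into a subquery and LEFT OUTER JOINs the runs onto it, so an experiment with zero runs still produces a row with a NULL aggregate instead of being silently missing; that is also why RunLatency became Optional[float] above. An illustrative toy version of the pattern using SQLAlchemy Core on an in-memory SQLite database (table and column names are invented for the example):

    from sqlalchemy import Column, Float, Integer, MetaData, Table, create_engine, func, select

    metadata = MetaData()
    experiments = Table("experiments", metadata, Column("id", Integer, primary_key=True))
    runs = Table(
        "runs",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("experiment_id", Integer),
        Column("latency", Float),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    with engine.begin() as conn:
        conn.execute(experiments.insert(), [{"id": 1}, {"id": 2}])  # experiment 2 has no runs
        conn.execute(
            runs.insert(),
            [{"experiment_id": 1, "latency": 0.5}, {"experiment_id": 1, "latency": 1.5}],
        )

    # Resolve the requested ids first, then outer-join the runs onto that subquery.
    requested = select(experiments.c.id).where(experiments.c.id.in_([1, 2])).subquery()
    query = (
        select(requested.c.id, func.avg(runs.c.latency))
        .outerjoin_from(requested, runs, requested.c.id == runs.c.experiment_id)
        .group_by(requested.c.id)
    )
    with engine.connect() as conn:
        print(conn.execute(query).all())  # [(1, 1.0), (2, None)] -- experiment 2 is not dropped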

phoenix/server/api/dataloaders/experiment_error_rates.py

@@ -5,7 +5,7 @@ from typing import (
      Optional,
  )

- from sqlalchemy import func, select
+ from sqlalchemy import case, func, select
  from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias
@@ -28,16 +28,36 @@ class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):

      async def _load_fn(self, keys: List[Key]) -> List[Result]:
          experiment_ids = keys
+         resolved_experiment_ids = (
+             select(models.Experiment.id)
+             .where(models.Experiment.id.in_(set(experiment_ids)))
+             .subquery()
+         )
+         query = (
+             select(
+                 resolved_experiment_ids.c.id,
+                 case(
+                     (
+                         func.count(models.ExperimentRun.id) != 0,
+                         func.count(models.ExperimentRun.error)
+                         / func.count(models.ExperimentRun.id),
+                     ),
+                     else_=None,
+                 ),
+             )
+             .outerjoin_from(
+                 from_=resolved_experiment_ids,
+                 target=models.ExperimentRun,
+                 onclause=resolved_experiment_ids.c.id == models.ExperimentRun.experiment_id,
+             )
+             .group_by(resolved_experiment_ids.c.id)
+         )
          async with self._db() as session:
              error_rates = {
                  experiment_id: error_rate
-                 async for experiment_id, error_rate in await session.stream(
-                     select(
-                         models.ExperimentRun.experiment_id,
-                         func.count(models.ExperimentRun.error) / func.count(),
-                     )
-                     .group_by(models.ExperimentRun.experiment_id)
-                     .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
-                 )
+                 async for experiment_id, error_rate in await session.stream(query)
              }
-         return [error_rates.get(experiment_id) for experiment_id in experiment_ids]
+         return [
+             error_rates.get(experiment_id, ValueError(f"Unknown experiment ID: {experiment_id}"))
+             for experiment_id in experiment_ids
+         ]

phoenix/server/api/dataloaders/experiment_run_counts.py

@@ -27,14 +27,27 @@ class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):

      async def _load_fn(self, keys: List[Key]) -> List[Result]:
          experiment_ids = keys
+         resolved_experiment_ids = (
+             select(models.Experiment.id)
+             .where(models.Experiment.id.in_(set(experiment_ids)))
+             .subquery()
+         )
+         query = (
+             select(
+                 resolved_experiment_ids.c.id,
+                 func.count(models.ExperimentRun.experiment_id),
+             )
+             .outerjoin_from(
+                 from_=resolved_experiment_ids,
+                 target=models.ExperimentRun,
+                 onclause=resolved_experiment_ids.c.id == models.ExperimentRun.experiment_id,
+             )
+             .group_by(resolved_experiment_ids.c.id)
+         )
          async with self._db() as session:
              run_counts = {
                  experiment_id: run_count
-                 async for experiment_id, run_count in await session.stream(
-                     select(models.ExperimentRun.experiment_id, func.count())
-                     .where(models.ExperimentRun.experiment_id.in_(set(experiment_ids)))
-                     .group_by(models.ExperimentRun.experiment_id)
-                 )
+                 async for experiment_id, run_count in await session.stream(query)
              }
          return [
              run_counts.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))

phoenix/server/api/input_types/{CreateSpanAnnotationsInput.py → CreateSpanAnnotationInput.py}

@@ -4,12 +4,14 @@ import strawberry
  from strawberry.relay import GlobalID
  from strawberry.scalars import JSON

+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+

  @strawberry.input
- class CreateSpanAnnotationsInput:
+ class CreateSpanAnnotationInput:
      span_id: GlobalID
      name: str
-     annotator_kind: str
+     annotator_kind: AnnotatorKind
      label: Optional[str] = None
      score: Optional[float] = None
      explanation: Optional[str] = None

phoenix/server/api/input_types/{CreateTraceAnnotationsInput.py → CreateTraceAnnotationInput.py}

@@ -4,12 +4,14 @@ import strawberry
  from strawberry.relay import GlobalID
  from strawberry.scalars import JSON

+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+

  @strawberry.input
- class CreateTraceAnnotationsInput:
+ class CreateTraceAnnotationInput:
      trace_id: GlobalID
      name: str
-     annotator_kind: str
+     annotator_kind: AnnotatorKind
      label: Optional[str] = None
      score: Optional[float] = None
      explanation: Optional[str] = None

phoenix/server/api/input_types/{PatchAnnotationsInput.py → PatchAnnotationInput.py}

@@ -5,12 +5,14 @@ from strawberry import UNSET
  from strawberry.relay import GlobalID
  from strawberry.scalars import JSON

+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
+

  @strawberry.input
- class PatchAnnotationsInput:
+ class PatchAnnotationInput:
      annotation_id: GlobalID
      name: Optional[str] = UNSET
-     annotator_kind: Optional[str] = UNSET
+     annotator_kind: Optional[AnnotatorKind] = UNSET
      label: Optional[str] = UNSET
      score: Optional[float] = UNSET
      explanation: Optional[str] = UNSET
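
The common thread in these three renamed input types is that annotator_kind is now a strawberry enum rather than a free-form string, so invalid kinds are rejected at the GraphQL layer. A hedged sketch of that pattern (the enum members shown are illustrative, not necessarily Phoenix's actual AnnotatorKind values):

    import enum
    from typing import Optional

    import strawberry


    @strawberry.enum
    class AnnotatorKind(enum.Enum):
        # Illustrative members only; the real enum lives in phoenix.server.api.types.AnnotatorKind.
        LLM = "LLM"
        HUMAN = "HUMAN"


    @strawberry.input
    class ExampleAnnotationInput:
        name: str
        annotator_kind: AnnotatorKind  # was a plain `str` before this release
        label: Optional[str] = None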

phoenix/server/api/mutations/span_annotations_mutations.py

@@ -7,9 +7,9 @@ from strawberry.types import Info

  from phoenix.db import models
  from phoenix.server.api.context import Context
- from phoenix.server.api.input_types.CreateSpanAnnotationsInput import CreateSpanAnnotationsInput
+ from phoenix.server.api.input_types.CreateSpanAnnotationInput import CreateSpanAnnotationInput
  from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
- from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
+ from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
  from phoenix.server.api.mutations.auth import IsAuthenticated
  from phoenix.server.api.types.node import from_global_id_with_expected_type
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
@@ -24,7 +24,7 @@ class SpanAnnotationMutationPayload:
  class SpanAnnotationMutationMixin:
      @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
      async def create_span_annotations(
-         self, info: Info[Context, None], input: List[CreateSpanAnnotationsInput]
+         self, info: Info[Context, None], input: List[CreateSpanAnnotationInput]
      ) -> SpanAnnotationMutationPayload:
          inserted_annotations: Sequence[models.SpanAnnotation] = []
          async with info.context.db() as session:
@@ -35,7 +35,7 @@ class SpanAnnotationMutationMixin:
                      label=annotation.label,
                      score=annotation.score,
                      explanation=annotation.explanation,
-                     annotator_kind=annotation.annotator_kind,
+                     annotator_kind=annotation.annotator_kind.value,
                      metadata_=annotation.metadata,
                  )
                  for annotation in input
@@ -54,7 +54,7 @@ class SpanAnnotationMutationMixin:

      @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
      async def patch_span_annotations(
-         self, info: Info[Context, None], input: List[PatchAnnotationsInput]
+         self, info: Info[Context, None], input: List[PatchAnnotationInput]
      ) -> SpanAnnotationMutationPayload:
          patched_annotations = []
          async with info.context.db() as session:
@@ -66,7 +66,13 @@ class SpanAnnotationMutationMixin:
                      column.key: patch_value
                      for column, patch_value, column_is_nullable in (
                          (models.SpanAnnotation.name, annotation.name, False),
-                         (models.SpanAnnotation.annotator_kind, annotation.annotator_kind, False),
+                         (
+                             models.SpanAnnotation.annotator_kind,
+                             annotation.annotator_kind.value
+                             if annotation.annotator_kind is not None
+                             else None,
+                             False,
+                         ),
                          (models.SpanAnnotation.label, annotation.label, True),
                          (models.SpanAnnotation.score, annotation.score, True),
                          (models.SpanAnnotation.explanation, annotation.explanation, True),

phoenix/server/api/mutations/trace_annotations_mutations.py

@@ -7,9 +7,9 @@ from strawberry.types import Info

  from phoenix.db import models
  from phoenix.server.api.context import Context
- from phoenix.server.api.input_types.CreateTraceAnnotationsInput import CreateTraceAnnotationsInput
+ from phoenix.server.api.input_types.CreateTraceAnnotationInput import CreateTraceAnnotationInput
  from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
- from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
+ from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
  from phoenix.server.api.mutations.auth import IsAuthenticated
  from phoenix.server.api.types.node import from_global_id_with_expected_type
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
@@ -24,7 +24,7 @@ class TraceAnnotationMutationPayload:
  class TraceAnnotationMutationMixin:
      @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
      async def create_trace_annotations(
-         self, info: Info[Context, None], input: List[CreateTraceAnnotationsInput]
+         self, info: Info[Context, None], input: List[CreateTraceAnnotationInput]
      ) -> TraceAnnotationMutationPayload:
          inserted_annotations: Sequence[models.TraceAnnotation] = []
          async with info.context.db() as session:
@@ -35,7 +35,7 @@ class TraceAnnotationMutationMixin:
                      label=annotation.label,
                      score=annotation.score,
                      explanation=annotation.explanation,
-                     annotator_kind=annotation.annotator_kind,
+                     annotator_kind=annotation.annotator_kind.value,
                      metadata_=annotation.metadata,
                  )
                  for annotation in input
@@ -54,7 +54,7 @@ class TraceAnnotationMutationMixin:

      @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
      async def patch_trace_annotations(
-         self, info: Info[Context, None], input: List[PatchAnnotationsInput]
+         self, info: Info[Context, None], input: List[PatchAnnotationInput]
      ) -> TraceAnnotationMutationPayload:
          patched_annotations = []
          async with info.context.db() as session:
@@ -66,7 +66,13 @@ class TraceAnnotationMutationMixin:
                      column.key: patch_value
                      for column, patch_value, column_is_nullable in (
                          (models.TraceAnnotation.name, annotation.name, False),
-                         (models.TraceAnnotation.annotator_kind, annotation.annotator_kind, False),
+                         (
+                             models.TraceAnnotation.annotator_kind,
+                             annotation.annotator_kind.value
+                             if annotation.annotator_kind is not None
+                             else None,
+                             False,
+                         ),
                          (models.TraceAnnotation.label, annotation.label, True),
                          (models.TraceAnnotation.score, annotation.score, True),
                          (models.TraceAnnotation.explanation, annotation.explanation, True),

phoenix/server/api/openapi/main.py

@@ -1,22 +1,6 @@
- import json
- from argparse import ArgumentParser
- from typing import Optional, Tuple
-
  from .schema import get_openapi_schema

  if __name__ == "__main__":
-     parser = ArgumentParser()
-     parser.add_argument(
-         "--compress",
-         action="store_true",
-         help="Whether to output a compressed version of the OpenAPI schema",
-     )
-     args = parser.parse_args()
+     import yaml  # type: ignore

-     indent: Optional[int] = None
-     separator: Optional[Tuple[str, str]] = None
-     if args.compress:
-         separator = (",", ":")
-     else:
-         indent = 2
-     print(json.dumps(get_openapi_schema(), indent=indent, separators=separator))
+     print(yaml.dump(get_openapi_schema(), indent=2))

phoenix/server/api/openapi/schema.py

@@ -1,16 +1,16 @@
- from typing import Any, Dict
+ from typing import Any

- from fastapi.openapi.utils import get_openapi
+ from starlette.schemas import SchemaGenerator

- from phoenix.server.api.routers.v1 import REST_API_VERSION
- from phoenix.server.api.routers.v1 import router as v1_router
+ from phoenix.server.api.routers.v1 import V1_ROUTES

+ OPENAPI_SCHEMA_GENERATOR = SchemaGenerator(
+     {"openapi": "3.0.0", "info": {"title": "Arize-Phoenix API", "version": "1.0"}}
+ )

- def get_openapi_schema() -> Dict[str, Any]:
-     return get_openapi(
-         title="Arize-Phoenix REST API",
-         version=REST_API_VERSION,
-         openapi_version="3.1.0",
-         description="Schema for Arize-Phoenix REST API",
-         routes=v1_router.routes,
-     )
+
+ def get_openapi_schema() -> Any:
+     """
+     Exports an OpenAPI schema for the Phoenix REST API as a JSON object.
+     """
+     return OPENAPI_SCHEMA_GENERATOR.get_schema(V1_ROUTES)  # type: ignore
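
Taken together, the last two hunks drop the FastAPI-based generator and the --compress flag: the schema is now built with Starlette's SchemaGenerator from V1_ROUTES and printed as YAML. Assuming arize-phoenix 4.14.1 and PyYAML are installed, the equivalent programmatic export is simply:

    import yaml  # PyYAML

    # Import path taken verbatim from this diff.
    from phoenix.server.api.openapi.schema import get_openapi_schema

    print(yaml.dump(get_openapi_schema(), indent=2))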