arize-phoenix 0.0.32rc1__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (71)
  1. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/METADATA +11 -5
  2. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/RECORD +69 -40
  3. phoenix/__init__.py +3 -1
  4. phoenix/config.py +23 -1
  5. phoenix/core/model_schema.py +14 -37
  6. phoenix/core/model_schema_adapter.py +0 -1
  7. phoenix/core/traces.py +285 -0
  8. phoenix/datasets/dataset.py +14 -21
  9. phoenix/datasets/errors.py +4 -1
  10. phoenix/datasets/schema.py +1 -1
  11. phoenix/datetime_utils.py +87 -0
  12. phoenix/experimental/callbacks/__init__.py +0 -0
  13. phoenix/experimental/callbacks/langchain_tracer.py +228 -0
  14. phoenix/experimental/callbacks/llama_index_trace_callback_handler.py +364 -0
  15. phoenix/experimental/evals/__init__.py +33 -0
  16. phoenix/experimental/evals/functions/__init__.py +4 -0
  17. phoenix/experimental/evals/functions/binary.py +156 -0
  18. phoenix/experimental/evals/functions/common.py +31 -0
  19. phoenix/experimental/evals/functions/generate.py +50 -0
  20. phoenix/experimental/evals/models/__init__.py +4 -0
  21. phoenix/experimental/evals/models/base.py +130 -0
  22. phoenix/experimental/evals/models/openai.py +128 -0
  23. phoenix/experimental/evals/retrievals.py +2 -2
  24. phoenix/experimental/evals/templates/__init__.py +24 -0
  25. phoenix/experimental/evals/templates/default_templates.py +126 -0
  26. phoenix/experimental/evals/templates/template.py +107 -0
  27. phoenix/experimental/evals/utils/__init__.py +0 -0
  28. phoenix/experimental/evals/utils/downloads.py +33 -0
  29. phoenix/experimental/evals/utils/threads.py +27 -0
  30. phoenix/experimental/evals/utils/types.py +9 -0
  31. phoenix/experimental/evals/utils.py +33 -0
  32. phoenix/metrics/binning.py +0 -1
  33. phoenix/metrics/timeseries.py +2 -3
  34. phoenix/server/api/context.py +2 -0
  35. phoenix/server/api/input_types/SpanSort.py +60 -0
  36. phoenix/server/api/schema.py +85 -4
  37. phoenix/server/api/types/DataQualityMetric.py +10 -1
  38. phoenix/server/api/types/Dataset.py +2 -4
  39. phoenix/server/api/types/DatasetInfo.py +10 -0
  40. phoenix/server/api/types/ExportEventsMutation.py +4 -1
  41. phoenix/server/api/types/Functionality.py +15 -0
  42. phoenix/server/api/types/MimeType.py +16 -0
  43. phoenix/server/api/types/Model.py +3 -5
  44. phoenix/server/api/types/SortDir.py +13 -0
  45. phoenix/server/api/types/Span.py +229 -0
  46. phoenix/server/api/types/TimeSeries.py +9 -2
  47. phoenix/server/api/types/pagination.py +2 -0
  48. phoenix/server/app.py +24 -4
  49. phoenix/server/main.py +60 -24
  50. phoenix/server/span_handler.py +39 -0
  51. phoenix/server/static/index.js +956 -479
  52. phoenix/server/thread_server.py +10 -2
  53. phoenix/services.py +39 -16
  54. phoenix/session/session.py +99 -27
  55. phoenix/trace/exporter.py +71 -0
  56. phoenix/trace/filter.py +181 -0
  57. phoenix/trace/fixtures.py +23 -8
  58. phoenix/trace/schemas.py +59 -6
  59. phoenix/trace/semantic_conventions.py +141 -1
  60. phoenix/trace/span_json_decoder.py +60 -6
  61. phoenix/trace/span_json_encoder.py +1 -9
  62. phoenix/trace/trace_dataset.py +100 -8
  63. phoenix/trace/tracer.py +26 -3
  64. phoenix/trace/v1/__init__.py +522 -0
  65. phoenix/trace/v1/trace_pb2.py +52 -0
  66. phoenix/trace/v1/trace_pb2.pyi +351 -0
  67. phoenix/core/dimension_data_type.py +0 -6
  68. phoenix/core/dimension_type.py +0 -9
  69. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/WHEEL +0 -0
  70. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/licenses/IP_NOTICE +0 -0
  71. {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/Span.py ADDED
@@ -0,0 +1,229 @@
+ import json
+ from collections import defaultdict
+ from datetime import datetime
+ from enum import Enum
+ from typing import Any, DefaultDict, List, Mapping, Optional, cast
+
+ import strawberry
+ from strawberry import ID
+ from strawberry.types import Info
+
+ import phoenix.trace.schemas as trace_schema
+ from phoenix.core.traces import (
+     CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION,
+     CUMULATIVE_LLM_TOKEN_COUNT_PROMPT,
+     CUMULATIVE_LLM_TOKEN_COUNT_TOTAL,
+     LATENCY_MS,
+ )
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.types.MimeType import MimeType
+ from phoenix.trace.schemas import SpanID
+ from phoenix.trace.semantic_conventions import (
+     EXCEPTION_MESSAGE,
+     INPUT_MIME_TYPE,
+     INPUT_VALUE,
+     LLM_TOKEN_COUNT_COMPLETION,
+     LLM_TOKEN_COUNT_PROMPT,
+     LLM_TOKEN_COUNT_TOTAL,
+     OUTPUT_MIME_TYPE,
+     OUTPUT_VALUE,
+ )
+
+
+ @strawberry.enum
+ class SpanKind(Enum):
+     """
+     The type of work that a Span encapsulates.
+
+     NB: this is actively under construction
+     """
+
+     chain = trace_schema.SpanKind.CHAIN
+     tool = trace_schema.SpanKind.TOOL
+     llm = trace_schema.SpanKind.LLM
+     retriever = trace_schema.SpanKind.RETRIEVER
+     embedding = trace_schema.SpanKind.EMBEDDING
+     agent = trace_schema.SpanKind.AGENT
+     unknown = trace_schema.SpanKind.UNKNOWN
+
+     @classmethod
+     def _missing_(cls, v: Any) -> Optional["SpanKind"]:
+         return None if v else cls.unknown
+
+
+ @strawberry.type
+ class SpanContext:
+     trace_id: ID
+     span_id: ID
+
+
+ @strawberry.type
+ class SpanIOValue:
+     mime_type: MimeType
+     value: str
+
+
+ @strawberry.enum
+ class SpanStatusCode(Enum):
+     OK = trace_schema.SpanStatusCode.OK
+     ERROR = trace_schema.SpanStatusCode.ERROR
+     UNSET = trace_schema.SpanStatusCode.UNSET
+
+     @classmethod
+     def _missing_(cls, v: Any) -> Optional["SpanStatusCode"]:
+         return None if v else cls.UNSET
+
+
+ @strawberry.type
+ class SpanEvent:
+     name: str
+     message: str
+     timestamp: datetime
+
+     @staticmethod
+     def from_event(
+         event: trace_schema.SpanEvent,
+     ) -> "SpanEvent":
+         return SpanEvent(
+             name=event.name,
+             message=cast(str, event.attributes.get(EXCEPTION_MESSAGE) or ""),
+             timestamp=event.timestamp,
+         )
+
+
+ @strawberry.type
+ class Span:
+     name: str
+     status_code: SpanStatusCode
+     start_time: datetime
+     end_time: Optional[datetime]
+     latency_ms: Optional[float]
+     parent_id: Optional[ID] = strawberry.field(
+         description="the parent span ID. If null, it is a root span"
+     )
+     span_kind: SpanKind
+     context: SpanContext
+     attributes: str = strawberry.field(
+         description="Span attributes as a JSON string",
+     )
+     token_count_total: Optional[int]
+     token_count_prompt: Optional[int]
+     token_count_completion: Optional[int]
+     input: Optional[SpanIOValue]
+     output: Optional[SpanIOValue]
+     events: List[SpanEvent]
+     cumulative_token_count_total: Optional[int] = strawberry.field(
+         description="Cumulative (prompt plus completion) token count from "
+         "self and all descendant spans (children, grandchildren, etc.)",
+     )
+     cumulative_token_count_prompt: Optional[int] = strawberry.field(
+         description="Cumulative (prompt) token count from self and all "
+         "descendant spans (children, grandchildren, etc.)",
+     )
+     cumulative_token_count_completion: Optional[int] = strawberry.field(
+         description="Cumulative (completion) token count from self and all "
+         "descendant spans (children, grandchildren, etc.)",
+     )
+
+     @strawberry.field(
+         description="All descendant spans (children, grandchildren, etc.)",
+     )  # type: ignore
+     def descendants(
+         self,
+         info: Info[Context, None],
+     ) -> List["Span"]:
+         if (traces := info.context.traces) is None:
+             return []
+         return [
+             to_gql_span(cast(trace_schema.Span, traces[span_id]))
+             for span_id in traces.get_descendant_span_ids(
+                 cast(SpanID, self.context.span_id),
+             )
+         ]
+
+
+ def to_gql_span(span: trace_schema.Span) -> "Span":
+     events: List[SpanEvent] = list(map(SpanEvent.from_event, span.events))
+     input_value = cast(Optional[str], span.attributes.get(INPUT_VALUE))
+     output_value = cast(Optional[str], span.attributes.get(OUTPUT_VALUE))
+     return Span(
+         name=span.name,
+         status_code=SpanStatusCode(span.status_code),
+         parent_id=cast(Optional[ID], span.parent_id),
+         span_kind=SpanKind(span.span_kind),
+         start_time=span.start_time,
+         end_time=span.end_time,
+         latency_ms=cast(Optional[float], span.attributes.get(LATENCY_MS)),
+         context=SpanContext(
+             trace_id=cast(ID, span.context.trace_id),
+             span_id=cast(ID, span.context.span_id),
+         ),
+         attributes=json.dumps(
+             _nested_attributes(span.attributes),
+             default=_json_encode,
+         ),
+         token_count_total=cast(
+             Optional[int],
+             span.attributes.get(LLM_TOKEN_COUNT_TOTAL),
+         ),
+         token_count_prompt=cast(
+             Optional[int],
+             span.attributes.get(LLM_TOKEN_COUNT_PROMPT),
+         ),
+         token_count_completion=cast(
+             Optional[int],
+             span.attributes.get(LLM_TOKEN_COUNT_COMPLETION),
+         ),
+         cumulative_token_count_total=cast(
+             Optional[int],
+             span.attributes.get(CUMULATIVE_LLM_TOKEN_COUNT_TOTAL),
+         ),
+         cumulative_token_count_prompt=cast(
+             Optional[int],
+             span.attributes.get(CUMULATIVE_LLM_TOKEN_COUNT_PROMPT),
+         ),
+         cumulative_token_count_completion=cast(
+             Optional[int],
+             span.attributes.get(CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION),
+         ),
+         events=events,
+         input=(
+             SpanIOValue(
+                 mime_type=MimeType(span.attributes.get(INPUT_MIME_TYPE)),
+                 value=input_value,
+             )
+             if input_value is not None
+             else None
+         ),
+         output=(
+             SpanIOValue(
+                 mime_type=MimeType(span.attributes.get(OUTPUT_MIME_TYPE)),
+                 value=output_value,
+             )
+             if output_value is not None
+             else None
+         ),
+     )
+
+
+ def _json_encode(v: Any) -> str:
+     if isinstance(v, datetime):
+         return v.isoformat()
+     return str(v)
+
+
+ def _trie() -> DefaultDict[str, Any]:
+     return defaultdict(_trie)
+
+
+ def _nested_attributes(
+     attributes: Mapping[str, Any],
+ ) -> DefaultDict[str, Any]:
+     nested_attributes = _trie()
+     for attribute_name, attribute_value in attributes.items():
+         trie = nested_attributes
+         keys = attribute_name.split(".")
+         for key in keys[:-1]:
+             trie = trie[key]
+         trie[keys[-1]] = attribute_value
+     return nested_attributes
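Aside on the new Span.py: the attributes field is produced by unflattening dot-delimited attribute keys into nested JSON with the defaultdict-based trie above (_trie/_nested_attributes). A minimal standalone sketch of the same pattern, with hypothetical input keys:

import json
from collections import defaultdict
from typing import Any, DefaultDict, Mapping

def _trie() -> DefaultDict[str, Any]:
    # each missing key materializes another trie node
    return defaultdict(_trie)

def unflatten(attributes: Mapping[str, Any]) -> DefaultDict[str, Any]:
    nested = _trie()
    for name, value in attributes.items():
        node = nested
        keys = name.split(".")
        for key in keys[:-1]:
            node = node[key]  # intermediate nodes appear on first access
        node[keys[-1]] = value
    return nested

flat = {"llm.token_count.prompt": 10, "llm.token_count.completion": 5}
print(json.dumps(unflatten(flat)))
# {"llm": {"token_count": {"prompt": 10, "completion": 5}}}

Because defaultdict(_trie) creates intermediate nodes on first access, no existence checks are needed, and json.dumps accepts the result directly since defaultdict subclasses dict.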
phoenix/server/api/types/TimeSeries.py CHANGED
@@ -7,7 +7,14 @@ import pandas as pd
  import strawberry
  from strawberry import UNSET

- from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Column, Dataset, Dimension
+ from phoenix.core.model_schema import (
+     CONTINUOUS,
+     PRIMARY,
+     REFERENCE,
+     Column,
+     Dataset,
+     Dimension,
+ )
  from phoenix.metrics import Metric, binning
  from phoenix.metrics.mixins import UnaryOperator
  from phoenix.metrics.timeseries import timeseries
@@ -31,7 +38,7 @@ class TimeSeriesDataPoint:
      """The value of the data point"""
      value: Optional[float] = strawberry.field(default=GqlValueMediator())

-     def __lt__(self, other: "TimeSeriesDataPoint") -> bool:
+     def __lt__(self, other: "TimeSeriesDataPoint") -> bool:  # type: ignore
          return self.timestamp < other.timestamp

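The added # type: ignore silences the type checker, which flags __lt__ for narrowing the operand type relative to object's comparison protocol; functionally, defining __lt__ is what makes data points orderable by timestamp. A toy illustration with a hypothetical Point class:

from dataclasses import dataclass
from datetime import datetime

@dataclass
class Point:
    timestamp: datetime

    def __lt__(self, other: "Point") -> bool:  # type: ignore
        return self.timestamp < other.timestamp

points = [Point(datetime(2023, 8, 2)), Point(datetime(2023, 8, 1))]
assert min(points).timestamp == datetime(2023, 8, 1)  # ordered via __lt__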
phoenix/server/api/types/pagination.py CHANGED
@@ -35,6 +35,7 @@ class PageInfo:
      has_previous_page: bool
      start_cursor: Optional[str]
      end_cursor: Optional[str]
+     total_count: int


  # A type alias for the connection cursor implementation
@@ -168,5 +169,6 @@ def connection_from_list_slice(
              end_cursor=last_edge.cursor if last_edge else None,
              has_previous_page=start_offset > lower_bound if isinstance(args.last, int) else False,
              has_next_page=end_offset < upper_bound if isinstance(args.first, int) else False,
+             total_count=list_length,
          ),
      )
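The new total_count field lets clients render an overall result count without paging through the whole connection. A self-contained sketch of slice-based pagination carrying total_count; names here are illustrative, not the actual connection_from_list_slice signature:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PageInfo:
    has_next_page: bool
    has_previous_page: bool
    start_cursor: Optional[str]
    end_cursor: Optional[str]
    total_count: int

def paginate(items: List[str], first: int, offset: int = 0) -> PageInfo:
    window = items[offset : offset + first]
    return PageInfo(
        has_next_page=offset + first < len(items),
        has_previous_page=offset > 0,
        start_cursor=str(offset) if window else None,
        end_cursor=str(offset + len(window) - 1) if window else None,
        total_count=len(items),  # size of the full list, not the window
    )

print(paginate(["a", "b", "c", "d"], first=2))  # total_count=4, has_next_page=True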
phoenix/server/app.py CHANGED
@@ -19,9 +19,10 @@ from strawberry.schema import BaseSchema

  from phoenix.config import SERVER_DIR
  from phoenix.core.model_schema import Model
-
- from .api.context import Context
- from .api.schema import schema
+ from phoenix.core.traces import Traces
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.schema import schema
+ from phoenix.server.span_handler import SpanHandler

  logger = logging.getLogger(__name__)

@@ -65,9 +66,11 @@ class GraphQLWithContext(GraphQL):  # type: ignore
          export_path: Path,
          graphiql: bool = False,
          corpus: Optional[Model] = None,
+         traces: Optional[Traces] = None,
      ) -> None:
          self.model = model
          self.corpus = corpus
+         self.traces = traces
          self.export_path = export_path
          super().__init__(schema, graphiql=graphiql)

@@ -81,6 +84,7 @@ class GraphQLWithContext(GraphQL):  # type: ignore
              response=response,
              model=self.model,
              corpus=self.corpus,
+             traces=self.traces,
              export_path=self.export_path,
          )

@@ -104,12 +108,14 @@ def create_app(
      export_path: Path,
      model: Model,
      corpus: Optional[Model] = None,
+     traces: Optional[Traces] = None,
      debug: bool = False,
  ) -> Starlette:
      graphql = GraphQLWithContext(
          schema=schema,
          model=model,
          corpus=corpus,
+         traces=traces,
          export_path=export_path,
          graphiql=True,
      )
@@ -118,7 +124,21 @@ def create_app(
          Middleware(HeadersMiddleware),
      ],
      debug=debug,
-     routes=[
+     routes=(
+         []
+         if traces is None
+         else [
+             Route(
+                 "/v1/spans",
+                 type(
+                     "SpanEndpoint",
+                     (SpanHandler,),
+                     {"queue": traces},
+                 ),
+             ),
+         ]
+     )
+     + [
          Route(
              "/exports",
              type(
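Worth noting in the routing change: Starlette's Route takes an endpoint class, not an instance, so per-app state (here, the traces queue) is injected by synthesizing a one-off subclass with type(), the same trick the existing /exports route uses. A minimal sketch of the pattern with a hypothetical handler:

from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route

class EchoHandler(HTTPEndpoint):
    queue: list  # supplied as a class attribute by the dynamic subclass

    async def post(self, request: Request) -> Response:
        self.queue.append(await request.body())
        return Response()

queue: list = []
app = Starlette(
    routes=[
        # equivalent to declaring a subclass of EchoHandler whose
        # class attribute `queue` is bound to the list above
        Route("/v1/spans", type("Endpoint", (EchoHandler,), {"queue": queue})),
    ]
)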
phoenix/server/main.py CHANGED
@@ -1,53 +1,67 @@
  import atexit
- import errno
  import logging
  import os
  from argparse import ArgumentParser
  from pathlib import Path
+ from threading import Thread
+ from time import sleep, time
  from typing import Optional

- import uvicorn
+ from uvicorn import Config, Server

- import phoenix.config as config
+ from phoenix.config import EXPORT_DIR, get_env_host, get_env_port, get_pids_path
  from phoenix.core.model_schema_adapter import create_model_from_datasets
- from phoenix.datasets.dataset import Dataset
+ from phoenix.core.traces import Traces
+ from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
  from phoenix.datasets.fixtures import FIXTURES, get_datasets
  from phoenix.server.app import create_app
+ from phoenix.trace.fixtures import (
+     TRACES_FIXTURES,
+     _download_traces_fixture,
+     _get_trace_fixture_by_name,
+ )
+ from phoenix.trace.span_json_decoder import json_string_to_span

  logger = logging.getLogger(__name__)


- def _write_pid_file() -> None:
-     with open(_get_pid_file(), "w"):
-         pass
+ def _write_pid_file_when_ready(
+     server: Server,
+     wait_up_to_seconds: float = 5,
+ ) -> None:
+     """Write PID file after server is started (or when time is up)."""
+     time_limit = time() + wait_up_to_seconds
+     while time() < time_limit and not server.should_exit and not server.started:
+         sleep(1e-3)
+     if time() >= time_limit and not server.started:
+         server.should_exit = True
+     _get_pid_file().touch()


  def _remove_pid_file() -> None:
-     try:
-         os.unlink(_get_pid_file())
-     except OSError as e:
-         if e.errno == errno.ENOENT:
-             # If the pid file doesn't exist, ignore and continue on since
-             # we are already in the desired end state; This should not happen
-             pass
-         else:
-             raise
+     _get_pid_file().unlink(missing_ok=True)


- def _get_pid_file() -> str:
-     return os.path.join(config.get_pids_path(), "%d" % os.getpid())
+ def _get_pid_file() -> Path:
+     return get_pids_path() / str(os.getpid())


  if __name__ == "__main__":
      primary_dataset_name: str
      reference_dataset_name: Optional[str]
+     trace_dataset_name: Optional[str] = None
+
+     primary_dataset: Dataset = EMPTY_DATASET
+     reference_dataset: Optional[Dataset] = None
+     corpus_dataset: Optional[Dataset] = None
+
      # automatically remove the pid file when the process is being gracefully terminated
      atexit.register(_remove_pid_file)
-     _write_pid_file()

      parser = ArgumentParser()
      parser.add_argument("--export_path")
-     parser.add_argument("--port", type=int, default=config.PORT)
+     parser.add_argument("--host", type=str, required=False)
+     parser.add_argument("--port", type=int, required=False)
      parser.add_argument("--no-internet", action="store_true")
      parser.add_argument("--debug", action="store_false")  # TODO: Disable before public launch
      subparsers = parser.add_subparsers(dest="command", required=True)
@@ -55,11 +69,16 @@ if __name__ == "__main__":
      datasets_parser.add_argument("--primary", type=str, required=True)
      datasets_parser.add_argument("--reference", type=str, required=False)
      datasets_parser.add_argument("--corpus", type=str, required=False)
+     datasets_parser.add_argument("--trace", type=str, required=False)
      fixture_parser = subparsers.add_parser("fixture")
      fixture_parser.add_argument("fixture", type=str, choices=[fixture.name for fixture in FIXTURES])
      fixture_parser.add_argument("--primary-only", type=bool)
+     trace_fixture_parser = subparsers.add_parser("trace-fixture")
+     trace_fixture_parser.add_argument(
+         "fixture", type=str, choices=[fixture.name for fixture in TRACES_FIXTURES]
+     )
      args = parser.parse_args()
-     export_path = Path(args.export_path) if args.export_path else config.EXPORT_DIR
+     export_path = Path(args.export_path) if args.export_path else EXPORT_DIR
      if args.command == "datasets":
          primary_dataset_name = args.primary
          reference_dataset_name = args.reference
@@ -73,7 +92,7 @@ if __name__ == "__main__":
          corpus_dataset = (
              None if corpus_dataset_name is None else Dataset.from_name(corpus_dataset_name)
          )
-     else:
+     elif args.command == "fixture":
          fixture_name = args.fixture
          primary_only = args.primary_only
          primary_dataset, reference_dataset, corpus_dataset = get_datasets(
@@ -83,16 +102,33 @@ if __name__ == "__main__":
          if primary_only:
              reference_dataset_name = None
              reference_dataset = None
+     elif args.command == "trace-fixture":
+         trace_dataset_name = args.fixture

      model = create_model_from_datasets(
          primary_dataset,
          reference_dataset,
      )
+     traces = Traces()
+     if trace_dataset_name is not None:
+         for span in map(
+             json_string_to_span,
+             _download_traces_fixture(
+                 _get_trace_fixture_by_name(
+                     trace_dataset_name,
+                 ),
+             ),
+         ):
+             traces.put(span)
      app = create_app(
          export_path=export_path,
          model=model,
+         traces=traces,
          corpus=None if corpus_dataset is None else create_model_from_datasets(corpus_dataset),
          debug=args.debug,
      )
-
-     uvicorn.run(app, port=args.port)
+     host = args.host or get_env_host()
+     port = args.port or get_env_port()
+     server = Server(config=Config(app, host=host, port=port))
+     Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
+     server.run()
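The readiness dance above hinges on uvicorn's Server.started flag, which flips to True once the server is accepting connections; polling it from a daemon thread means the PID file appears only when the server is actually up, and the server is asked to exit if it never comes up in time. A condensed sketch of the same pattern (host, port, and PID path here are hypothetical):

from pathlib import Path
from threading import Thread
from time import sleep, time

from starlette.applications import Starlette
from uvicorn import Config, Server

def write_pid_when_ready(server: Server, pid_file: Path, timeout: float = 5.0) -> None:
    deadline = time() + timeout
    while time() < deadline and not server.should_exit and not server.started:
        sleep(1e-3)  # poll cheaply; uvicorn sets Server.started at startup
    if server.started:
        pid_file.touch()  # signal readiness to whoever watches the pid directory
    else:
        server.should_exit = True  # never started in time: request shutdown

app = Starlette()
server = Server(Config(app, host="127.0.0.1", port=6006))
Thread(target=write_pid_when_ready, args=(server, Path("/tmp/phoenix.pid")), daemon=True).start()
server.run()  # blocks until shutdown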
phoenix/server/span_handler.py ADDED
@@ -0,0 +1,39 @@
+ import gzip
+ from typing import Protocol
+
+ from starlette.endpoints import HTTPEndpoint
+ from starlette.requests import Request
+ from starlette.responses import Response
+
+ from phoenix.trace.schemas import Span
+ from phoenix.trace.span_json_decoder import json_to_span
+ from phoenix.trace.v1 import encode
+ from phoenix.trace.v1 import trace_pb2 as pb
+
+
+ class SupportsPutSpan(Protocol):
+     def put(self, span: pb.Span) -> None:
+         ...
+
+
+ class SpanHandler(HTTPEndpoint):
+     queue: SupportsPutSpan
+
+     async def post(self, request: Request) -> Response:
+         try:
+             content_type = request.headers.get("content-type")
+             if content_type == "application/x-protobuf":
+                 body = await request.body()
+                 content_encoding = request.headers.get("content-encoding")
+                 if content_encoding == "gzip":
+                     body = gzip.decompress(body)
+                 pb_span = pb.Span()
+                 pb_span.ParseFromString(body)
+             else:
+                 span = json_to_span(await request.json())
+                 assert isinstance(span, Span)
+                 pb_span = encode(span)
+         except Exception:
+             return Response(status_code=422)
+         self.queue.put(pb_span)
+         return Response()
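A client can hit this endpoint with either JSON or protobuf; for protobuf, the handler gunzips the body when content-encoding: gzip is set before calling ParseFromString. A hypothetical client sketch using only the standard library (constructing the pb.Span itself is elided, since its fields live in trace_pb2):

import gzip
import urllib.request

def post_span(pb_span, host: str = "127.0.0.1", port: int = 6006) -> int:
    # SerializeToString() is the standard protobuf API; gzip is optional,
    # but the handler above honors "content-encoding: gzip".
    body = gzip.compress(pb_span.SerializeToString())
    request = urllib.request.Request(
        f"http://{host}:{port}/v1/spans",
        data=body,
        headers={
            "content-type": "application/x-protobuf",
            "content-encoding": "gzip",
        },
        method="POST",
    )
    # note: urllib raises HTTPError if the handler answers 422 (unparseable span)
    with urllib.request.urlopen(request) as response:
        return response.status  # 200 on success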