arize-phoenix 0.0.32rc1__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix has been flagged as potentially problematic.
- {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/METADATA +11 -5
- {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/RECORD +69 -40
- phoenix/__init__.py +3 -1
- phoenix/config.py +23 -1
- phoenix/core/model_schema.py +14 -37
- phoenix/core/model_schema_adapter.py +0 -1
- phoenix/core/traces.py +285 -0
- phoenix/datasets/dataset.py +14 -21
- phoenix/datasets/errors.py +4 -1
- phoenix/datasets/schema.py +1 -1
- phoenix/datetime_utils.py +87 -0
- phoenix/experimental/callbacks/__init__.py +0 -0
- phoenix/experimental/callbacks/langchain_tracer.py +228 -0
- phoenix/experimental/callbacks/llama_index_trace_callback_handler.py +364 -0
- phoenix/experimental/evals/__init__.py +33 -0
- phoenix/experimental/evals/functions/__init__.py +4 -0
- phoenix/experimental/evals/functions/binary.py +156 -0
- phoenix/experimental/evals/functions/common.py +31 -0
- phoenix/experimental/evals/functions/generate.py +50 -0
- phoenix/experimental/evals/models/__init__.py +4 -0
- phoenix/experimental/evals/models/base.py +130 -0
- phoenix/experimental/evals/models/openai.py +128 -0
- phoenix/experimental/evals/retrievals.py +2 -2
- phoenix/experimental/evals/templates/__init__.py +24 -0
- phoenix/experimental/evals/templates/default_templates.py +126 -0
- phoenix/experimental/evals/templates/template.py +107 -0
- phoenix/experimental/evals/utils/__init__.py +0 -0
- phoenix/experimental/evals/utils/downloads.py +33 -0
- phoenix/experimental/evals/utils/threads.py +27 -0
- phoenix/experimental/evals/utils/types.py +9 -0
- phoenix/experimental/evals/utils.py +33 -0
- phoenix/metrics/binning.py +0 -1
- phoenix/metrics/timeseries.py +2 -3
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/input_types/SpanSort.py +60 -0
- phoenix/server/api/schema.py +85 -4
- phoenix/server/api/types/DataQualityMetric.py +10 -1
- phoenix/server/api/types/Dataset.py +2 -4
- phoenix/server/api/types/DatasetInfo.py +10 -0
- phoenix/server/api/types/ExportEventsMutation.py +4 -1
- phoenix/server/api/types/Functionality.py +15 -0
- phoenix/server/api/types/MimeType.py +16 -0
- phoenix/server/api/types/Model.py +3 -5
- phoenix/server/api/types/SortDir.py +13 -0
- phoenix/server/api/types/Span.py +229 -0
- phoenix/server/api/types/TimeSeries.py +9 -2
- phoenix/server/api/types/pagination.py +2 -0
- phoenix/server/app.py +24 -4
- phoenix/server/main.py +60 -24
- phoenix/server/span_handler.py +39 -0
- phoenix/server/static/index.js +956 -479
- phoenix/server/thread_server.py +10 -2
- phoenix/services.py +39 -16
- phoenix/session/session.py +99 -27
- phoenix/trace/exporter.py +71 -0
- phoenix/trace/filter.py +181 -0
- phoenix/trace/fixtures.py +23 -8
- phoenix/trace/schemas.py +59 -6
- phoenix/trace/semantic_conventions.py +141 -1
- phoenix/trace/span_json_decoder.py +60 -6
- phoenix/trace/span_json_encoder.py +1 -9
- phoenix/trace/trace_dataset.py +100 -8
- phoenix/trace/tracer.py +26 -3
- phoenix/trace/v1/__init__.py +522 -0
- phoenix/trace/v1/trace_pb2.py +52 -0
- phoenix/trace/v1/trace_pb2.pyi +351 -0
- phoenix/core/dimension_data_type.py +0 -6
- phoenix/core/dimension_type.py +0 -9
- {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/WHEEL +0 -0
- {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-0.0.32rc1.dist-info → arize_phoenix-0.0.33.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/Span.py
ADDED

@@ -0,0 +1,229 @@
+import json
+from collections import defaultdict
+from datetime import datetime
+from enum import Enum
+from typing import Any, DefaultDict, List, Mapping, Optional, cast
+
+import strawberry
+from strawberry import ID
+from strawberry.types import Info
+
+import phoenix.trace.schemas as trace_schema
+from phoenix.core.traces import (
+    CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION,
+    CUMULATIVE_LLM_TOKEN_COUNT_PROMPT,
+    CUMULATIVE_LLM_TOKEN_COUNT_TOTAL,
+    LATENCY_MS,
+)
+from phoenix.server.api.context import Context
+from phoenix.server.api.types.MimeType import MimeType
+from phoenix.trace.schemas import SpanID
+from phoenix.trace.semantic_conventions import (
+    EXCEPTION_MESSAGE,
+    INPUT_MIME_TYPE,
+    INPUT_VALUE,
+    LLM_TOKEN_COUNT_COMPLETION,
+    LLM_TOKEN_COUNT_PROMPT,
+    LLM_TOKEN_COUNT_TOTAL,
+    OUTPUT_MIME_TYPE,
+    OUTPUT_VALUE,
+)
+
+
+@strawberry.enum
+class SpanKind(Enum):
+    """
+    The type of work that a Span encapsulates.
+
+    NB: this is actively under construction
+    """
+
+    chain = trace_schema.SpanKind.CHAIN
+    tool = trace_schema.SpanKind.TOOL
+    llm = trace_schema.SpanKind.LLM
+    retriever = trace_schema.SpanKind.RETRIEVER
+    embedding = trace_schema.SpanKind.EMBEDDING
+    agent = trace_schema.SpanKind.AGENT
+    unknown = trace_schema.SpanKind.UNKNOWN
+
+    @classmethod
+    def _missing_(cls, v: Any) -> Optional["SpanKind"]:
+        return None if v else cls.unknown
+
+
+@strawberry.type
+class SpanContext:
+    trace_id: ID
+    span_id: ID
+
+
+@strawberry.type
+class SpanIOValue:
+    mime_type: MimeType
+    value: str
+
+
+@strawberry.enum
+class SpanStatusCode(Enum):
+    OK = trace_schema.SpanStatusCode.OK
+    ERROR = trace_schema.SpanStatusCode.ERROR
+    UNSET = trace_schema.SpanStatusCode.UNSET
+
+    @classmethod
+    def _missing_(cls, v: Any) -> Optional["SpanStatusCode"]:
+        return None if v else cls.UNSET
+
+
+@strawberry.type
+class SpanEvent:
+    name: str
+    message: str
+    timestamp: datetime
+
+    @staticmethod
+    def from_event(
+        event: trace_schema.SpanEvent,
+    ) -> "SpanEvent":
+        return SpanEvent(
+            name=event.name,
+            message=cast(str, event.attributes.get(EXCEPTION_MESSAGE) or ""),
+            timestamp=event.timestamp,
+        )
+
+
+@strawberry.type
+class Span:
+    name: str
+    status_code: SpanStatusCode
+    start_time: datetime
+    end_time: Optional[datetime]
+    latency_ms: Optional[float]
+    parent_id: Optional[ID] = strawberry.field(
+        description="the parent span ID. If null, it is a root span"
+    )
+    span_kind: SpanKind
+    context: SpanContext
+    attributes: str = strawberry.field(
+        description="Span attributes as a JSON string",
+    )
+    token_count_total: Optional[int]
+    token_count_prompt: Optional[int]
+    token_count_completion: Optional[int]
+    input: Optional[SpanIOValue]
+    output: Optional[SpanIOValue]
+    events: List[SpanEvent]
+    cumulative_token_count_total: Optional[int] = strawberry.field(
+        description="Cumulative (prompt plus completion) token count from "
+        "self and all descendant spans (children, grandchildren, etc.)",
+    )
+    cumulative_token_count_prompt: Optional[int] = strawberry.field(
+        description="Cumulative (prompt) token count from self and all "
+        "descendant spans (children, grandchildren, etc.)",
+    )
+    cumulative_token_count_completion: Optional[int] = strawberry.field(
+        description="Cumulative (completion) token count from self and all "
+        "descendant spans (children, grandchildren, etc.)",
+    )
+
+    @strawberry.field(
+        description="All descendant spans (children, grandchildren, etc.)",
+    )  # type: ignore
+    def descendants(
+        self,
+        info: Info[Context, None],
+    ) -> List["Span"]:
+        if (traces := info.context.traces) is None:
+            return []
+        return [
+            to_gql_span(cast(trace_schema.Span, traces[span_id]))
+            for span_id in traces.get_descendant_span_ids(
+                cast(SpanID, self.context.span_id),
+            )
+        ]
+
+
+def to_gql_span(span: trace_schema.Span) -> "Span":
+    events: List[SpanEvent] = list(map(SpanEvent.from_event, span.events))
+    input_value = cast(Optional[str], span.attributes.get(INPUT_VALUE))
+    output_value = cast(Optional[str], span.attributes.get(OUTPUT_VALUE))
+    return Span(
+        name=span.name,
+        status_code=SpanStatusCode(span.status_code),
+        parent_id=cast(Optional[ID], span.parent_id),
+        span_kind=SpanKind(span.span_kind),
+        start_time=span.start_time,
+        end_time=span.end_time,
+        latency_ms=cast(Optional[float], span.attributes.get(LATENCY_MS)),
+        context=SpanContext(
+            trace_id=cast(ID, span.context.trace_id),
+            span_id=cast(ID, span.context.span_id),
+        ),
+        attributes=json.dumps(
+            _nested_attributes(span.attributes),
+            default=_json_encode,
+        ),
+        token_count_total=cast(
+            Optional[int],
+            span.attributes.get(LLM_TOKEN_COUNT_TOTAL),
+        ),
+        token_count_prompt=cast(
+            Optional[int],
+            span.attributes.get(LLM_TOKEN_COUNT_PROMPT),
+        ),
+        token_count_completion=cast(
+            Optional[int],
+            span.attributes.get(LLM_TOKEN_COUNT_COMPLETION),
+        ),
+        cumulative_token_count_total=cast(
+            Optional[int],
+            span.attributes.get(CUMULATIVE_LLM_TOKEN_COUNT_TOTAL),
+        ),
+        cumulative_token_count_prompt=cast(
+            Optional[int],
+            span.attributes.get(CUMULATIVE_LLM_TOKEN_COUNT_PROMPT),
+        ),
+        cumulative_token_count_completion=cast(
+            Optional[int],
+            span.attributes.get(CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION),
+        ),
+        events=events,
+        input=(
+            SpanIOValue(
+                mime_type=MimeType(span.attributes.get(INPUT_MIME_TYPE)),
+                value=input_value,
+            )
+            if input_value is not None
+            else None
+        ),
+        output=(
+            SpanIOValue(
+                mime_type=MimeType(span.attributes.get(OUTPUT_MIME_TYPE)),
+                value=output_value,
+            )
+            if output_value is not None
+            else None
+        ),
+    )
+
+
+def _json_encode(v: Any) -> str:
+    if isinstance(v, datetime):
+        return v.isoformat()
+    return str(v)
+
+
+def _trie() -> DefaultDict[str, Any]:
+    return defaultdict(_trie)
+
+
+def _nested_attributes(
+    attributes: Mapping[str, Any],
+) -> DefaultDict[str, Any]:
+    nested_attributes = _trie()
+    for attribute_name, attribute_value in attributes.items():
+        trie = nested_attributes
+        keys = attribute_name.split(".")
+        for key in keys[:-1]:
+            trie = trie[key]
+        trie[keys[-1]] = attribute_value
+    return nested_attributes
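A note on the helpers at the bottom of the new file: span attributes arrive as a flat mapping whose keys are dotted paths, and _nested_attributes un-flattens them into a trie of nested dicts before json.dumps renders the attributes field (with _json_encode handling non-serializable values such as datetime). A minimal standalone sketch of that behavior, with the two helpers copied verbatim so it runs without the package installed; the sample keys are illustrative:

import json
from collections import defaultdict
from typing import Any, DefaultDict, Mapping

def _trie() -> DefaultDict[str, Any]:
    # Every missing key materializes another nested defaultdict.
    return defaultdict(_trie)

def _nested_attributes(attributes: Mapping[str, Any]) -> DefaultDict[str, Any]:
    nested_attributes = _trie()
    for attribute_name, attribute_value in attributes.items():
        trie = nested_attributes
        keys = attribute_name.split(".")
        for key in keys[:-1]:  # walk (and create) the intermediate levels
            trie = trie[key]
        trie[keys[-1]] = attribute_value
    return nested_attributes

print(json.dumps(_nested_attributes({"llm.token_count.prompt": 10, "llm.token_count.total": 30})))
# {"llm": {"token_count": {"prompt": 10, "total": 30}}}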
phoenix/server/api/types/TimeSeries.py
CHANGED

@@ -7,7 +7,14 @@ import pandas as pd
 import strawberry
 from strawberry import UNSET
 
-from phoenix.core.model_schema import
+from phoenix.core.model_schema import (
+    CONTINUOUS,
+    PRIMARY,
+    REFERENCE,
+    Column,
+    Dataset,
+    Dimension,
+)
 from phoenix.metrics import Metric, binning
 from phoenix.metrics.mixins import UnaryOperator
 from phoenix.metrics.timeseries import timeseries
@@ -31,7 +38,7 @@ class TimeSeriesDataPoint:
     """The value of the data point"""
     value: Optional[float] = strawberry.field(default=GqlValueMediator())
 
-    def __lt__(self, other: "TimeSeriesDataPoint") -> bool:
+    def __lt__(self, other: "TimeSeriesDataPoint") -> bool:  # type: ignore
         return self.timestamp < other.timestamp
 
 
phoenix/server/api/types/pagination.py
CHANGED

@@ -35,6 +35,7 @@ class PageInfo:
     has_previous_page: bool
     start_cursor: Optional[str]
     end_cursor: Optional[str]
+    total_count: int
 
 
 # A type alias for the connection cursor implementation
@@ -168,5 +169,6 @@ def connection_from_list_slice(
             end_cursor=last_edge.cursor if last_edge else None,
             has_previous_page=start_offset > lower_bound if isinstance(args.last, int) else False,
             has_next_page=end_offset < upper_bound if isinstance(args.first, int) else False,
+            total_count=list_length,
         ),
     )
phoenix/server/app.py
CHANGED

@@ -19,9 +19,10 @@ from strawberry.schema import BaseSchema
 
 from phoenix.config import SERVER_DIR
 from phoenix.core.model_schema import Model
-
-from .api.context import Context
-from .api.schema import schema
+from phoenix.core.traces import Traces
+from phoenix.server.api.context import Context
+from phoenix.server.api.schema import schema
+from phoenix.server.span_handler import SpanHandler
 
 logger = logging.getLogger(__name__)
 
@@ -65,9 +66,11 @@ class GraphQLWithContext(GraphQL):  # type: ignore
         export_path: Path,
         graphiql: bool = False,
         corpus: Optional[Model] = None,
+        traces: Optional[Traces] = None,
     ) -> None:
         self.model = model
         self.corpus = corpus
+        self.traces = traces
         self.export_path = export_path
         super().__init__(schema, graphiql=graphiql)
 
@@ -81,6 +84,7 @@ class GraphQLWithContext(GraphQL):  # type: ignore
             response=response,
             model=self.model,
             corpus=self.corpus,
+            traces=self.traces,
             export_path=self.export_path,
         )
 
@@ -104,12 +108,14 @@ def create_app(
     export_path: Path,
     model: Model,
     corpus: Optional[Model] = None,
+    traces: Optional[Traces] = None,
     debug: bool = False,
 ) -> Starlette:
     graphql = GraphQLWithContext(
         schema=schema,
         model=model,
         corpus=corpus,
+        traces=traces,
         export_path=export_path,
         graphiql=True,
     )
@@ -118,7 +124,21 @@ def create_app(
             Middleware(HeadersMiddleware),
         ],
         debug=debug,
-        routes=
+        routes=(
+            []
+            if traces is None
+            else [
+                Route(
+                    "/v1/spans",
+                    type(
+                        "SpanEndpoint",
+                        (SpanHandler,),
+                        {"queue": traces},
+                    ),
+                ),
+            ]
+        )
+        + [
             Route(
                 "/exports",
                 type(
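A note on the three-argument type() call above: Starlette instantiates endpoint classes itself, so create_app cannot pass the Traces instance to a SpanHandler constructor; it instead mints a one-off subclass with the queue bound as a class attribute. A rough equivalent written as a class statement (a sketch, assuming arize-phoenix 0.0.33 is importable):

from phoenix.core.traces import Traces
from phoenix.server.span_handler import SpanHandler

traces = Traces()

# Same effect as type("SpanEndpoint", (SpanHandler,), {"queue": traces}):
class SpanEndpoint(SpanHandler):
    queue = traces  # class attribute; SpanHandler.post reads it via self.queue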
phoenix/server/main.py
CHANGED

@@ -1,53 +1,67 @@
 import atexit
-import errno
 import logging
 import os
 from argparse import ArgumentParser
 from pathlib import Path
+from threading import Thread
+from time import sleep, time
 from typing import Optional
 
-import
+from uvicorn import Config, Server
 
-
+from phoenix.config import EXPORT_DIR, get_env_host, get_env_port, get_pids_path
 from phoenix.core.model_schema_adapter import create_model_from_datasets
-from phoenix.
+from phoenix.core.traces import Traces
+from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
 from phoenix.datasets.fixtures import FIXTURES, get_datasets
 from phoenix.server.app import create_app
+from phoenix.trace.fixtures import (
+    TRACES_FIXTURES,
+    _download_traces_fixture,
+    _get_trace_fixture_by_name,
+)
+from phoenix.trace.span_json_decoder import json_string_to_span
 
 logger = logging.getLogger(__name__)
 
 
-def
-
-
+def _write_pid_file_when_ready(
+    server: Server,
+    wait_up_to_seconds: float = 5,
+) -> None:
+    """Write PID file after server is started (or when time is up)."""
+    time_limit = time() + wait_up_to_seconds
+    while time() < time_limit and not server.should_exit and not server.started:
+        sleep(1e-3)
+    if time() >= time_limit and not server.started:
+        server.should_exit = True
+    _get_pid_file().touch()
 
 
 def _remove_pid_file() -> None:
-    try:
-        os.unlink(_get_pid_file())
-    except OSError as e:
-        if e.errno == errno.ENOENT:
-            # If the pid file doesn't exist, ignore and continue on since
-            # we are already in the desired end state; This should not happen
-            pass
-        else:
-            raise
+    _get_pid_file().unlink(missing_ok=True)
 
 
-def _get_pid_file() ->
-    return
+def _get_pid_file() -> Path:
+    return get_pids_path() / str(os.getpid())
 
 
 if __name__ == "__main__":
     primary_dataset_name: str
     reference_dataset_name: Optional[str]
+    trace_dataset_name: Optional[str] = None
+
+    primary_dataset: Dataset = EMPTY_DATASET
+    reference_dataset: Optional[Dataset] = None
+    corpus_dataset: Optional[Dataset] = None
+
     # automatically remove the pid file when the process is being gracefully terminated
     atexit.register(_remove_pid_file)
-    _write_pid_file()
 
     parser = ArgumentParser()
     parser.add_argument("--export_path")
-    parser.add_argument("--
+    parser.add_argument("--host", type=str, required=False)
+    parser.add_argument("--port", type=int, required=False)
     parser.add_argument("--no-internet", action="store_true")
     parser.add_argument("--debug", action="store_false")  # TODO: Disable before public launch
     subparsers = parser.add_subparsers(dest="command", required=True)
@@ -55,11 +69,16 @@ if __name__ == "__main__":
     datasets_parser.add_argument("--primary", type=str, required=True)
     datasets_parser.add_argument("--reference", type=str, required=False)
     datasets_parser.add_argument("--corpus", type=str, required=False)
+    datasets_parser.add_argument("--trace", type=str, required=False)
     fixture_parser = subparsers.add_parser("fixture")
    fixture_parser.add_argument("fixture", type=str, choices=[fixture.name for fixture in FIXTURES])
     fixture_parser.add_argument("--primary-only", type=bool)
+    trace_fixture_parser = subparsers.add_parser("trace-fixture")
+    trace_fixture_parser.add_argument(
+        "fixture", type=str, choices=[fixture.name for fixture in TRACES_FIXTURES]
+    )
     args = parser.parse_args()
-    export_path = Path(args.export_path) if args.export_path else
+    export_path = Path(args.export_path) if args.export_path else EXPORT_DIR
     if args.command == "datasets":
         primary_dataset_name = args.primary
         reference_dataset_name = args.reference
@@ -73,7 +92,7 @@ if __name__ == "__main__":
         corpus_dataset = (
             None if corpus_dataset_name is None else Dataset.from_name(corpus_dataset_name)
         )
-
+    elif args.command == "fixture":
         fixture_name = args.fixture
         primary_only = args.primary_only
         primary_dataset, reference_dataset, corpus_dataset = get_datasets(
@@ -83,16 +102,33 @@ if __name__ == "__main__":
         if primary_only:
             reference_dataset_name = None
             reference_dataset = None
+    elif args.command == "trace-fixture":
+        trace_dataset_name = args.fixture
 
     model = create_model_from_datasets(
         primary_dataset,
         reference_dataset,
     )
+    traces = Traces()
+    if trace_dataset_name is not None:
+        for span in map(
+            json_string_to_span,
+            _download_traces_fixture(
+                _get_trace_fixture_by_name(
+                    trace_dataset_name,
+                ),
+            ),
+        ):
+            traces.put(span)
     app = create_app(
         export_path=export_path,
         model=model,
+        traces=traces,
         corpus=None if corpus_dataset is None else create_model_from_datasets(corpus_dataset),
         debug=args.debug,
     )
-
-
+    host = args.host or get_env_host()
+    port = args.port or get_env_port()
+    server = Server(config=Config(app, host=host, port=port))
+    Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
+    server.run()
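Two details of the rewrite above are easy to miss. Path.unlink(missing_ok=True) collapses the old errno.ENOENT handling into a single call but requires Python 3.8 or newer. And the PID file is no longer written unconditionally at startup: a daemon thread polls server.started and touches the file once uvicorn reports readiness, or when the five-second budget expires (in which case it also signals the server to exit). With the new flags in place the entry point can be exercised directly, e.g. (a hypothetical invocation; the port value is a placeholder and the fixture name must come from TRACES_FIXTURES):

python -m phoenix.server.main --port 6006 trace-fixture <fixture-name>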
phoenix/server/span_handler.py
ADDED

@@ -0,0 +1,39 @@
+import gzip
+from typing import Protocol
+
+from starlette.endpoints import HTTPEndpoint
+from starlette.requests import Request
+from starlette.responses import Response
+
+from phoenix.trace.schemas import Span
+from phoenix.trace.span_json_decoder import json_to_span
+from phoenix.trace.v1 import encode
+from phoenix.trace.v1 import trace_pb2 as pb
+
+
+class SupportsPutSpan(Protocol):
+    def put(self, span: pb.Span) -> None:
+        ...
+
+
+class SpanHandler(HTTPEndpoint):
+    queue: SupportsPutSpan
+
+    async def post(self, request: Request) -> Response:
+        try:
+            content_type = request.headers.get("content-type")
+            if content_type == "application/x-protobuf":
+                body = await request.body()
+                content_encoding = request.headers.get("content-encoding")
+                if content_encoding == "gzip":
+                    body = gzip.decompress(body)
+                pb_span = pb.Span()
+                pb_span.ParseFromString(body)
+            else:
+                span = json_to_span(await request.json())
+                assert isinstance(span, Span)
+                pb_span = encode(span)
+        except Exception:
+            return Response(status_code=422)
+        self.queue.put(pb_span)
+        return Response()
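For the client side of this content negotiation, a minimal sketch (assuming arize-phoenix 0.0.33 and requests are installed and a Phoenix server is listening locally; the host and port are placeholders):

import gzip

import requests  # any HTTP client works; requests is assumed here

from phoenix.trace.v1 import trace_pb2 as pb

span = pb.Span()  # field assignments omitted; the message schema is in trace_pb2.pyi
response = requests.post(
    "http://127.0.0.1:6006/v1/spans",  # placeholder host/port
    data=gzip.compress(span.SerializeToString()),
    headers={
        "content-type": "application/x-protobuf",  # selects the protobuf branch
        "content-encoding": "gzip",  # the handler gzip-decompresses the body first
    },
)
# A body that fails to parse returns HTTP 422; success returns 200 with an empty body.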