arize-phoenix 5.2.2__py3-none-any.whl → 5.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/METADATA +2 -1
- {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/RECORD +36 -29
- phoenix/core/model_schema_adapter.py +2 -2
- phoenix/db/migrations/versions/10460e46d750_datasets.py +1 -1
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +1 -1
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +1 -1
- phoenix/db/models.py +6 -6
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +12 -0
- phoenix/server/api/input_types/InvocationParameters.py +20 -0
- phoenix/server/api/openapi/main.py +8 -13
- phoenix/server/api/queries.py +74 -0
- phoenix/server/api/schema.py +2 -0
- phoenix/server/api/subscriptions.py +430 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +16 -0
- phoenix/server/app.py +36 -7
- phoenix/server/bearer_auth.py +12 -4
- phoenix/server/main.py +2 -2
- phoenix/server/static/.vite/manifest.json +35 -35
- phoenix/server/static/assets/{components-CPkaHQZs.js → components-DwERj42u.js} +153 -130
- phoenix/server/static/assets/{index-D4G-kYL9.js → index-CyOjvLOr.js} +2 -2
- phoenix/server/static/assets/{pages-B3HCyYqg.js → pages-BNYMd7SU.js} +360 -255
- phoenix/server/static/assets/vendor-D-NIjePD.js +872 -0
- phoenix/server/static/assets/{vendor-arizeai-D0FocbYu.js → vendor-arizeai-DoY5jUTO.js} +30 -28
- phoenix/server/static/assets/vendor-codemirror-CIhY_nEU.js +24 -0
- phoenix/server/static/assets/{vendor-recharts-JMOLUxWG.js → vendor-recharts-Dgcm35Jq.js} +1 -1
- phoenix/trace/fixtures.py +7 -9
- phoenix/utilities/client.py +2 -2
- phoenix/utilities/json.py +11 -0
- phoenix/utilities/template_formatters.py +70 -0
- phoenix/version.py +1 -1
- phoenix/server/static/assets/vendor-c0HJYYGN.js +0 -641
- phoenix/server/static/assets/vendor-codemirror-BbCMI-_D.js +0 -30
- {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/subscriptions.py
ADDED
@@ -0,0 +1,430 @@
+import json
+from collections import defaultdict
+from dataclasses import fields
+from datetime import datetime
+from enum import Enum
+from itertools import chain
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    AsyncIterator,
+    DefaultDict,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import strawberry
+from openinference.instrumentation import safe_json_dumps
+from openinference.semconv.trace import (
+    MessageAttributes,
+    OpenInferenceMimeTypeValues,
+    OpenInferenceSpanKindValues,
+    SpanAttributes,
+    ToolAttributes,
+    ToolCallAttributes,
+)
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace import StatusCode
+from sqlalchemy import insert, select
+from strawberry import UNSET
+from strawberry.scalars import JSON as JSONScalarType
+from strawberry.types import Info
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
+from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
+from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
+from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
+from phoenix.server.dml_event import SpanInsertEvent
+from phoenix.trace.attributes import unflatten
+from phoenix.utilities.json import jsonify
+from phoenix.utilities.template_formatters import (
+    FStringTemplateFormatter,
+    MustacheTemplateFormatter,
+    TemplateFormatter,
+)
+
+if TYPE_CHECKING:
+    from openai.types.chat import (
+        ChatCompletionMessageParam,
+    )
+
+PLAYGROUND_PROJECT_NAME = "playground"
+
+ToolCallIndex: TypeAlias = int
+
+
+@strawberry.enum
+class TemplateLanguage(Enum):
+    MUSTACHE = "MUSTACHE"
+    F_STRING = "F_STRING"
+
+
+@strawberry.input
+class TemplateOptions:
+    variables: JSONScalarType
+    language: TemplateLanguage
+
+
+@strawberry.type
+class TextChunk:
+    content: str
+
+
+@strawberry.type
+class FunctionCallChunk:
+    name: str
+    arguments: str
+
+
+@strawberry.type
+class ToolCallChunk:
+    id: str
+    function: FunctionCallChunk
+
+
+ChatCompletionChunk: TypeAlias = Annotated[
+    Union[TextChunk, ToolCallChunk], strawberry.union("ChatCompletionChunk")
+]
+
+
+@strawberry.input
+class GenerativeModelInput:
+    provider_key: GenerativeProviderKey
+    name: str
+    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
+    endpoint: Optional[str] = UNSET
+    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
+    api_version: Optional[str] = UNSET
+    """ The API version to use for the model. """
+
+
+@strawberry.input
+class ChatCompletionInput:
+    messages: List[ChatCompletionMessageInput]
+    model: GenerativeModelInput
+    invocation_parameters: InvocationParameters
+    tools: Optional[List[JSONScalarType]] = UNSET
+    template: Optional[TemplateOptions] = UNSET
+    api_key: Optional[str] = strawberry.field(default=None)
+
+
+def to_openai_chat_completion_param(
+    role: ChatCompletionMessageRole, content: JSONScalarType
+) -> "ChatCompletionMessageParam":
+    from openai.types.chat import (
+        ChatCompletionAssistantMessageParam,
+        ChatCompletionSystemMessageParam,
+        ChatCompletionUserMessageParam,
+    )
+
+    if role is ChatCompletionMessageRole.USER:
+        return ChatCompletionUserMessageParam(
+            {
+                "content": content,
+                "role": "user",
+            }
+        )
+    if role is ChatCompletionMessageRole.SYSTEM:
+        return ChatCompletionSystemMessageParam(
+            {
+                "content": content,
+                "role": "system",
+            }
+        )
+    if role is ChatCompletionMessageRole.AI:
+        return ChatCompletionAssistantMessageParam(
+            {
+                "content": content,
+                "role": "assistant",
+            }
+        )
+    if role is ChatCompletionMessageRole.TOOL:
+        raise NotImplementedError
+    assert_never(role)
+
+
+@strawberry.type
+class Subscription:
+    @strawberry.subscription
+    async def chat_completion(
+        self, info: Info[Context, None], input: ChatCompletionInput
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
+
+        client: Union[AsyncAzureOpenAI, AsyncOpenAI]
+
+        if input.model.provider_key == GenerativeProviderKey.AZURE_OPENAI:
+            if input.model.endpoint is None or input.model.api_version is None:
+                raise ValueError("endpoint and api_version are required for Azure OpenAI models")
+            client = AsyncAzureOpenAI(
+                api_key=input.api_key,
+                azure_endpoint=input.model.endpoint,
+                api_version=input.model.api_version,
+            )
+        else:
+            client = AsyncOpenAI(api_key=input.api_key)
+
+        invocation_parameters = jsonify(input.invocation_parameters)
+
+        messages: List[Tuple[ChatCompletionMessageRole, str]] = [
+            (message.role, message.content) for message in input.messages
+        ]
+        if template_options := input.template:
+            messages = list(_formatted_messages(messages, template_options))
+        openai_messages = [to_openai_chat_completion_param(*message) for message in messages]
+
+        in_memory_span_exporter = InMemorySpanExporter()
+        tracer_provider = TracerProvider()
+        tracer_provider.add_span_processor(
+            span_processor=SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
+        )
+        tracer = tracer_provider.get_tracer(__name__)
+        span_name = "ChatCompletion"
+        with tracer.start_span(
+            span_name,
+            attributes=dict(
+                chain(
+                    _llm_span_kind(),
+                    _llm_model_name(input.model.name),
+                    _llm_tools(input.tools or []),
+                    _llm_input_messages(messages),
+                    _llm_invocation_parameters(invocation_parameters),
+                    _input_value_and_mime_type(input),
+                )
+            ),
+        ) as span:
+            response_chunks = []
+            text_chunks: List[TextChunk] = []
+            tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list)
+            role: Optional[str] = None
+            async for chunk in await client.chat.completions.create(
+                messages=openai_messages,
+                model=input.model.name,
+                stream=True,
+                tools=input.tools or NOT_GIVEN,
+                **invocation_parameters,
+            ):
+                response_chunks.append(chunk)
+                choice = chunk.choices[0]
+                delta = choice.delta
+                if role is None:
+                    role = delta.role
+                if choice.finish_reason is None:
+                    if isinstance(chunk_content := delta.content, str):
+                        text_chunk = TextChunk(content=chunk_content)
+                        yield text_chunk
+                        text_chunks.append(text_chunk)
+                    if (tool_calls := delta.tool_calls) is not None:
+                        for tool_call_index, tool_call in enumerate(tool_calls):
+                            if (function := tool_call.function) is not None:
+                                if (tool_call_id := tool_call.id) is None:
+                                    first_tool_call_chunk = tool_call_chunks[tool_call_index][0]
+                                    tool_call_id = first_tool_call_chunk.id
+                                tool_call_chunk = ToolCallChunk(
+                                    id=tool_call_id,
+                                    function=FunctionCallChunk(
+                                        name=function.name or "",
+                                        arguments=function.arguments or "",
+                                    ),
+                                )
+                                yield tool_call_chunk
+                                tool_call_chunks[tool_call_index].append(tool_call_chunk)
+            span.set_status(StatusCode.OK)
+            assert role is not None
+            span.set_attributes(
+                dict(
+                    chain(
+                        _output_value_and_mime_type(response_chunks),
+                        _llm_output_messages(text_chunks, tool_call_chunks),
+                    )
+                )
+            )
+        assert len(spans := in_memory_span_exporter.get_finished_spans()) == 1
+        finished_span = spans[0]
+        assert finished_span.start_time is not None
+        assert finished_span.end_time is not None
+        assert (attributes := finished_span.attributes) is not None
+        start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
+        end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
+        trace_id = _hex(finished_span.context.trace_id)
+        span_id = _hex(finished_span.context.span_id)
+        status = finished_span.status
+        async with info.context.db() as session:
+            if (
+                playground_project_id := await session.scalar(
+                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
+                )
+            ) is None:
+                playground_project_id = await session.scalar(
+                    insert(models.Project)
+                    .returning(models.Project.id)
+                    .values(
+                        name=PLAYGROUND_PROJECT_NAME,
+                        description="Traces from prompt playground",
+                    )
+                )
+            trace_rowid = await session.scalar(
+                insert(models.Trace)
+                .returning(models.Trace.id)
+                .values(
+                    project_rowid=playground_project_id,
+                    trace_id=trace_id,
+                    start_time=start_time,
+                    end_time=end_time,
+                )
+            )
+            await session.execute(
+                insert(models.Span).values(
+                    trace_rowid=trace_rowid,
+                    span_id=span_id,
+                    parent_id=None,
+                    name=span_name,
+                    span_kind=LLM,
+                    start_time=start_time,
+                    end_time=end_time,
+                    attributes=unflatten(attributes.items()),
+                    events=finished_span.events,
+                    status_code=status.status_code.name,
+                    status_message=status.description or "",
+                    cumulative_error_count=int(not status.is_ok),
+                    cumulative_llm_token_count_prompt=0,
+                    cumulative_llm_token_count_completion=0,
+                    llm_token_count_prompt=0,
+                    llm_token_count_completion=0,
+                )
+            )
+        info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
+
+
+def _llm_span_kind() -> Iterator[Tuple[str, Any]]:
+    yield OPENINFERENCE_SPAN_KIND, LLM
+
+
+def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
+    yield LLM_MODEL_NAME, model_name
+
+
+def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
+    yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
+
+
+def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
+    for tool_index, tool in enumerate(tools):
+        yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
+
+
+def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
+    assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput))
+    yield INPUT_MIME_TYPE, JSON
+    yield INPUT_VALUE, safe_json_dumps({k: v for k, v in jsonify(input).items() if k != api_key})
+
+
+def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
+    yield OUTPUT_MIME_TYPE, JSON
+    yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
+
+
+def _llm_input_messages(
+    messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
+) -> Iterator[Tuple[str, Any]]:
+    for i, (role, content) in enumerate(messages):
+        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
+        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
+
+
+def _llm_output_messages(
+    text_chunks: List[TextChunk],
+    tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]],
+) -> Iterator[Tuple[str, Any]]:
+    yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
+    if content := "".join(chunk.content for chunk in text_chunks):
+        yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
+    for tool_call_index, tool_call_chunks_ in tool_call_chunks.items():
+        if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
+            yield (
+                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
+                name,
+            )
+        if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
+            yield (
+                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                arguments,
+            )
+
+
+def _hex(number: int) -> str:
+    """
+    Converts an integer to a hexadecimal string.
+    """
+    return hex(number)[2:]
+
+
+def _datetime(*, epoch_nanoseconds: float) -> datetime:
+    """
+    Converts a Unix epoch timestamp in nanoseconds to a datetime.
+    """
+    epoch_seconds = epoch_nanoseconds / 1e9
+    return datetime.fromtimestamp(epoch_seconds)
+
+
+def _formatted_messages(
+    messages: Iterable[Tuple[ChatCompletionMessageRole, str]], template_options: TemplateOptions
+) -> Iterator[Tuple[ChatCompletionMessageRole, str]]:
+    """
+    Formats the messages using the given template options.
+    """
+    template_formatter = _template_formatter(template_language=template_options.language)
+    roles, templates = zip(*messages)
+    formatted_templates = map(
+        lambda template: template_formatter.format(template, **template_options.variables),
+        templates,
+    )
+    formatted_messages = zip(roles, formatted_templates)
+    return formatted_messages
+
+
+def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
+    """
+    Instantiates the appropriate template formatter for the template language.
+    """
+    if template_language is TemplateLanguage.MUSTACHE:
+        return MustacheTemplateFormatter()
+    if template_language is TemplateLanguage.F_STRING:
+        return FStringTemplateFormatter()
+    assert_never(template_language)
+
+
+JSON = OpenInferenceMimeTypeValues.JSON.value
+
+LLM = OpenInferenceSpanKindValues.LLM.value
+
+OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
+INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
+INPUT_VALUE = SpanAttributes.INPUT_VALUE
+OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
+OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
+LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
+LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
+LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
+LLM_TOOLS = SpanAttributes.LLM_TOOLS

+MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
+MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
+MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

+TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
+TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON

+TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
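Note: the _formatted_messages / _template_formatter helpers above delegate to phoenix.utilities.template_formatters, which this release adds (+70 lines) but which is not expanded in this diff. Below is a minimal runnable sketch of that flow; StubMustacheFormatter is a hypothetical stand-in, and only the format(template, **variables) call shape is taken from the code above.

from typing import Dict, Iterable, Iterator, Tuple


class StubMustacheFormatter:
    """Hypothetical stand-in for MustacheTemplateFormatter."""

    def format(self, template: str, **variables: str) -> str:
        # Substitute {{name}} placeholders, mustache-style.
        for name, value in variables.items():
            template = template.replace("{{" + name + "}}", value)
        return template


def formatted_messages(
    messages: Iterable[Tuple[str, str]], variables: Dict[str, str]
) -> Iterator[Tuple[str, str]]:
    # Mirrors _formatted_messages above: split roles from templates, format
    # each template with the shared variables, then re-pair them.
    formatter = StubMustacheFormatter()
    roles, templates = zip(*messages)
    return zip(roles, (formatter.format(t, **variables) for t in templates))


print(list(formatted_messages([("USER", "Hello, {{name}}!")], {"name": "Ada"})))
# [('USER', 'Hello, Ada!')]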
phoenix/server/api/types/GenerativeProvider.py
ADDED
@@ -0,0 +1,16 @@
+from enum import Enum
+
+import strawberry
+
+
+@strawberry.enum
+class GenerativeProviderKey(Enum):
+    OPENAI = "OPENAI"
+    ANTHROPIC = "ANTHROPIC"
+    AZURE_OPENAI = "AZURE_OPENAI"
+
+
+@strawberry.type
+class GenerativeProvider:
+    name: str
+    key: GenerativeProviderKey
phoenix/server/app.py
CHANGED
@@ -36,22 +36,23 @@ from fastapi.utils import is_body_allowed_for_status_code
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
 from starlette.datastructures import State as StarletteState
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
 from starlette.middleware import Middleware
 from starlette.middleware.authentication import AuthenticationMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.requests import Request
-from starlette.responses import PlainTextResponse, Response
+from starlette.responses import JSONResponse, PlainTextResponse, Response
 from starlette.staticfiles import StaticFiles
 from starlette.status import HTTP_401_UNAUTHORIZED
 from starlette.templating import Jinja2Templates
 from starlette.types import Scope, StatefulLifespan
+from starlette.websockets import WebSocket
 from strawberry.extensions import SchemaExtension
 from strawberry.fastapi import GraphQLRouter
 from strawberry.schema import BaseSchema
+from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
 from typing_extensions import TypeAlias

-import phoenix
 import phoenix.trace.v1 as pb
 from phoenix.config import (
     DEFAULT_PROJECT_NAME,

@@ -135,6 +136,7 @@ from phoenix.trace.fixtures import (
 from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
 from phoenix.trace.schemas import Span
 from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER
+from phoenix.version import __version__ as phoenix_version

 if TYPE_CHECKING:
     from opentelemetry.trace import TracerProvider

@@ -217,7 +219,7 @@ class Static(StaticFiles):
                 "n_neighbors": self._app_config.n_neighbors,
                 "n_samples": self._app_config.n_samples,
                 "basename": self._sanitize_basename(request.scope.get("root_path", "")),
-                "platform_version": phoenix.__version__,
+                "platform_version": phoenix_version,
                 "request": request,
                 "is_development": self._app_config.is_development,
                 "manifest": self._web_manifest,

@@ -255,7 +257,7 @@ class HeadersMiddleware(BaseHTTPMiddleware):
         request: Request,
         call_next: RequestResponseEndpoint,
     ) -> Response:
-        from phoenix import __version__ as phoenix_version
+        from phoenix.version import __version__ as phoenix_version

         response = await call_next(request)
         response.headers["x-colab-notebook-cache-control"] = "no-cache"

@@ -268,7 +270,7 @@ ProjectRowId: TypeAlias = int

 @router.get("/arize_phoenix_version")
 async def version() -> PlainTextResponse:
-    return PlainTextResponse(f"{phoenix.__version__}")
+    return PlainTextResponse(f"{phoenix_version}")


 DB_MUTEX: Optional[asyncio.Lock] = None

@@ -580,6 +582,7 @@ def create_graphql_router(
         include_in_schema=False,
         prefix="/graphql",
         dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
+        subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
     )


@@ -630,6 +633,29 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
     return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)


+async def websocket_denial_response_handler(websocket: WebSocket, exc: WebSocketException) -> None:
+    """
+    Overrides the default exception handler for WebSocketException to ensure
+    that the HTTP response returned when a WebSocket connection is denied has
+    the same status code as the raised exception. This is in keeping with the
+    WebSocket Denial Response Extension of the ASGI specification described
+    below.
+
+    "Websocket connections start with the client sending a HTTP request
+    containing the appropriate upgrade headers. On receipt of this request a
+    server can choose to either upgrade the connection or respond with an HTTP
+    response (denying the upgrade). The core ASGI specification does not allow
+    for any control over the denial response, instead specifying that the HTTP
+    status code 403 should be returned, whereas this extension allows an ASGI
+    framework to control the denial response."
+
+    For details, see:
+    - https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response
+    """
+    assert isinstance(exc, WebSocketException)
+    await websocket.send_denial_response(JSONResponse(status_code=exc.code, content=exc.reason))
+
+
 def create_app(
     db: DbSessionFactory,
     export_path: Path,

@@ -776,7 +802,10 @@ def create_app(
            scaffolder_config=scaffolder_config,
        ),
        middleware=middlewares,
-       exception_handlers={HTTPException: plain_text_http_exception_handler},
+       exception_handlers={
+           HTTPException: plain_text_http_exception_handler,
+           WebSocketException: websocket_denial_response_handler,  # type: ignore[dict-item]
+       },
        debug=debug,
        swagger_ui_parameters={
            "defaultModelsExpandDepth": -1,  # hides the schema section in the Swagger UI
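With the router now advertising GRAPHQL_TRANSPORT_WS_PROTOCOL, the chat_completion subscription from subscriptions.py becomes reachable over a websocket. A rough client-side sketch follows, assuming a local server on Phoenix's default port and the third-party websockets package; the selection set mirrors the strawberry types above, while the host, model name, and invocationParameters payload are placeholders, not values taken from this diff.

import asyncio
import json

import websockets  # third-party: pip install websockets

QUERY = """
subscription ($input: ChatCompletionInput!) {
  chatCompletion(input: $input) {
    ... on TextChunk { content }
    ... on ToolCallChunk { id function { name arguments } }
  }
}
"""


async def main() -> None:
    uri = "ws://localhost:6006/graphql"  # assumed host and port
    async with websockets.connect(uri, subprotocols=["graphql-transport-ws"]) as ws:
        # graphql-transport-ws handshake: send connection_init, await the ack.
        await ws.send(json.dumps({"type": "connection_init"}))
        assert json.loads(await ws.recv())["type"] == "connection_ack"
        await ws.send(json.dumps({
            "id": "1",
            "type": "subscribe",
            "payload": {
                "query": QUERY,
                "variables": {"input": {
                    "messages": [{"role": "USER", "content": "Hello"}],
                    "model": {"providerKey": "OPENAI", "name": "gpt-4o-mini"},
                    "invocationParameters": {},  # placeholder shape
                }},
            },
        }))
        async for raw in ws:
            message = json.loads(raw)
            if message["type"] == "next":
                print(message["payload"]["data"]["chatCompletion"])
            elif message["type"] in ("complete", "error"):
                break


asyncio.run(main())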
phoenix/server/bearer_auth.py
CHANGED
@@ -7,10 +7,11 @@ from typing import (
     Callable,
     Optional,
     Tuple,
+    cast,
 )

 import grpc
-from fastapi import HTTPException, Request
+from fastapi import HTTPException, Request, WebSocket, WebSocketException
 from grpc_interceptor import AsyncServerInterceptor
 from grpc_interceptor.exceptions import Unauthenticated
 from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser

@@ -116,12 +117,19 @@ class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
         raise Unauthenticated()


-async def is_authenticated(request: Request) -> None:
+async def is_authenticated(
+    # fastapi dependencies require non-optional types
+    request: Request = cast(Request, None),
+    websocket: WebSocket = cast(WebSocket, None),
+) -> None:
     """
-    Raises a 401 if the request is not authenticated.
+    Raises a 401 if the request or websocket connection is not authenticated.
     """
-    if not isinstance((user := request.user), PhoenixUser):
+    assert request or websocket
+    if request and not isinstance((user := request.user), PhoenixUser):
         raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
+    if websocket and not isinstance((user := websocket.user), PhoenixUser):
+        raise WebSocketException(code=HTTP_401_UNAUTHORIZED, reason="Invalid token")
     claims = user.claims
     if claims.status is ClaimSetStatus.EXPIRED:
         raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
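The cast(Request, None) and cast(WebSocket, None) defaults above let a single dependency guard both HTTP routes and the new websocket-carried GraphQL subscriptions: FastAPI injects whichever connection object matches the scope and leaves the other parameter at its declared default. A self-contained sketch of that pattern, with illustrative route names that are not taken from Phoenix:

from typing import cast

from fastapi import Depends, FastAPI, Request, WebSocket

app = FastAPI()


async def connection_type(
    # Mirrors is_authenticated: non-optional annotations with cast defaults,
    # so exactly one of the two is populated per connection.
    request: Request = cast(Request, None),
    websocket: WebSocket = cast(WebSocket, None),
) -> None:
    print("http" if request else "websocket")


@app.get("/ping", dependencies=[Depends(connection_type)])
async def ping() -> str:  # prints "http"
    return "pong"


@app.websocket("/ws")
async def ws_endpoint(
    websocket: WebSocket, _: None = Depends(connection_type)
) -> None:  # prints "websocket"
    await websocket.accept()
    await websocket.close()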
phoenix/server/main.py
CHANGED
@@ -3,7 +3,6 @@ import codecs
 import os
 import sys
 from argparse import SUPPRESS, ArgumentParser
-from importlib.metadata import version
 from pathlib import Path
 from threading import Thread
 from time import sleep, time

@@ -72,6 +71,7 @@ from phoenix.trace.fixtures import (
 )
 from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
 from phoenix.trace.schemas import Span
+from phoenix.version import __version__ as phoenix_version

 _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""

@@ -351,7 +351,7 @@ def main() -> None:
     # Print information about the server
     root_path = urljoin(f"http://{host}:{port}", host_root_path)
     msg = _WELCOME_MESSAGE.render(
-        version=version("arize-phoenix"),
+        version=phoenix_version,
         ui_path=root_path,
         grpc_path=f"http://{host}:{get_env_grpc_port()}",
         http_path=urljoin(root_path, "v1/traces"),
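After this change, both the welcome banner and the /arize_phoenix_version route read the version from phoenix.version rather than importlib.metadata. A quick way to confirm against a running server (host and port are assumptions):

import urllib.request

# Prints the plain-text version string served by the route shown in the
# app.py diff above, e.g. "5.3.0" for this release.
print(urllib.request.urlopen("http://localhost:6006/arize_phoenix_version").read().decode())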