arize-phoenix 5.2.1__py3-none-any.whl → 5.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-5.2.1.dist-info → arize_phoenix-5.3.0.dist-info}/METADATA +4 -3
- {arize_phoenix-5.2.1.dist-info → arize_phoenix-5.3.0.dist-info}/RECORD +37 -30
- phoenix/config.py +24 -0
- phoenix/core/model_schema_adapter.py +2 -2
- phoenix/db/migrations/versions/10460e46d750_datasets.py +1 -1
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +1 -1
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +1 -1
- phoenix/db/models.py +6 -6
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +12 -0
- phoenix/server/api/input_types/InvocationParameters.py +20 -0
- phoenix/server/api/openapi/main.py +8 -13
- phoenix/server/api/queries.py +74 -0
- phoenix/server/api/schema.py +2 -0
- phoenix/server/api/subscriptions.py +430 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +16 -0
- phoenix/server/app.py +69 -7
- phoenix/server/bearer_auth.py +12 -4
- phoenix/server/main.py +2 -2
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-BqfGSfjl.js → components-DwERj42u.js} +153 -130
- phoenix/server/static/assets/{index-CCZJh73q.js → index-CyOjvLOr.js} +2 -2
- phoenix/server/static/assets/{pages-CCfzLmwR.js → pages-BNYMd7SU.js} +360 -255
- phoenix/server/static/assets/vendor-D-NIjePD.js +872 -0
- phoenix/server/static/assets/{vendor-arizeai-BCDjSYK3.js → vendor-arizeai-DoY5jUTO.js} +30 -28
- phoenix/server/static/assets/vendor-codemirror-CIhY_nEU.js +24 -0
- phoenix/server/static/assets/{vendor-recharts-CJm3CJf0.js → vendor-recharts-Dgcm35Jq.js} +1 -1
- phoenix/trace/fixtures.py +7 -9
- phoenix/utilities/client.py +2 -2
- phoenix/utilities/json.py +11 -0
- phoenix/utilities/template_formatters.py +70 -0
- phoenix/version.py +1 -1
- phoenix/server/static/assets/vendor-WKqfwbiB.js +0 -641
- phoenix/server/static/assets/vendor-codemirror-B_z6EOTv.js +0 -27
- {arize_phoenix-5.2.1.dist-info → arize_phoenix-5.3.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.2.1.dist-info → arize_phoenix-5.3.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.2.1.dist-info → arize_phoenix-5.3.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.2.1.dist-info → arize_phoenix-5.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,430 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from dataclasses import fields
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from itertools import chain
|
|
7
|
+
from typing import (
|
|
8
|
+
TYPE_CHECKING,
|
|
9
|
+
Annotated,
|
|
10
|
+
Any,
|
|
11
|
+
AsyncIterator,
|
|
12
|
+
DefaultDict,
|
|
13
|
+
Dict,
|
|
14
|
+
Iterable,
|
|
15
|
+
Iterator,
|
|
16
|
+
List,
|
|
17
|
+
Optional,
|
|
18
|
+
Tuple,
|
|
19
|
+
Union,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
import strawberry
|
|
23
|
+
from openinference.instrumentation import safe_json_dumps
|
|
24
|
+
from openinference.semconv.trace import (
|
|
25
|
+
MessageAttributes,
|
|
26
|
+
OpenInferenceMimeTypeValues,
|
|
27
|
+
OpenInferenceSpanKindValues,
|
|
28
|
+
SpanAttributes,
|
|
29
|
+
ToolAttributes,
|
|
30
|
+
ToolCallAttributes,
|
|
31
|
+
)
|
|
32
|
+
from opentelemetry.sdk.trace import TracerProvider
|
|
33
|
+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
|
34
|
+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
|
35
|
+
from opentelemetry.trace import StatusCode
|
|
36
|
+
from sqlalchemy import insert, select
|
|
37
|
+
from strawberry import UNSET
|
|
38
|
+
from strawberry.scalars import JSON as JSONScalarType
|
|
39
|
+
from strawberry.types import Info
|
|
40
|
+
from typing_extensions import TypeAlias, assert_never
|
|
41
|
+
|
|
42
|
+
from phoenix.db import models
|
|
43
|
+
from phoenix.server.api.context import Context
|
|
44
|
+
from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
|
|
45
|
+
from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
|
|
46
|
+
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
|
|
47
|
+
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
|
|
48
|
+
from phoenix.server.dml_event import SpanInsertEvent
|
|
49
|
+
from phoenix.trace.attributes import unflatten
|
|
50
|
+
from phoenix.utilities.json import jsonify
|
|
51
|
+
from phoenix.utilities.template_formatters import (
|
|
52
|
+
FStringTemplateFormatter,
|
|
53
|
+
MustacheTemplateFormatter,
|
|
54
|
+
TemplateFormatter,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
if TYPE_CHECKING:
|
|
58
|
+
from openai.types.chat import (
|
|
59
|
+
ChatCompletionMessageParam,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# Name of the project under which all prompt-playground traces are stored;
# looked up (and created on first use) in the chat_completion subscription.
PLAYGROUND_PROJECT_NAME = "playground"

# Index of a tool call within a single streamed chat completion; used to
# group the streamed chunks belonging to the same tool call.
ToolCallIndex: TypeAlias = int
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@strawberry.enum
class TemplateLanguage(Enum):
    """Templating syntax used to substitute variables into prompt messages."""

    MUSTACHE = "MUSTACHE"
    F_STRING = "F_STRING"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@strawberry.input
class TemplateOptions:
    """Variables and syntax used to format prompt message templates."""

    # Mapping of template variable names to their substitution values.
    variables: JSONScalarType
    language: TemplateLanguage
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@strawberry.type
class TextChunk:
    """A streamed fragment of the assistant's text response."""

    content: str
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@strawberry.type
class FunctionCallChunk:
    """A streamed fragment of a function call (name and partial arguments)."""

    name: str
    arguments: str
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@strawberry.type
class ToolCallChunk:
    """A streamed fragment of a tool call, tagged with the tool call's id."""

    id: str
    function: FunctionCallChunk
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# GraphQL union of the chunk types streamed back by the chat_completion
# subscription: plain text deltas or tool-call deltas.
ChatCompletionChunk: TypeAlias = Annotated[
    Union[TextChunk, ToolCallChunk], strawberry.union("ChatCompletionChunk")
]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@strawberry.input
class GenerativeModelInput:
    """Identifies the LLM (and, for Azure, the deployment) to run against."""

    provider_key: GenerativeProviderKey
    name: str
    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
    endpoint: Optional[str] = UNSET
    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
    api_version: Optional[str] = UNSET
    """ The API version to use for the model. """
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@strawberry.input
class ChatCompletionInput:
    """Input payload for the chat_completion subscription."""

    messages: List[ChatCompletionMessageInput]
    model: GenerativeModelInput
    invocation_parameters: InvocationParameters
    # Optional JSON schemas for tools the model may call.
    tools: Optional[List[JSONScalarType]] = UNSET
    # When present, message contents are treated as templates and formatted
    # with the given variables before being sent to the provider.
    template: Optional[TemplateOptions] = UNSET
    # Provider API key; redacted from the recorded span's input value.
    api_key: Optional[str] = strawberry.field(default=None)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def to_openai_chat_completion_param(
    role: ChatCompletionMessageRole, content: JSONScalarType
) -> "ChatCompletionMessageParam":
    """
    Convert a (role, content) pair into the corresponding OpenAI message
    param TypedDict.

    Raises:
        NotImplementedError: for TOOL messages, which are not yet supported.
    """
    # Imported lazily so openai stays an optional dependency of this module.
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
    )

    if role is ChatCompletionMessageRole.USER:
        return ChatCompletionUserMessageParam({"role": "user", "content": content})
    if role is ChatCompletionMessageRole.SYSTEM:
        return ChatCompletionSystemMessageParam({"role": "system", "content": content})
    if role is ChatCompletionMessageRole.AI:
        return ChatCompletionAssistantMessageParam({"role": "assistant", "content": content})
    if role is ChatCompletionMessageRole.TOOL:
        raise NotImplementedError
    # Exhaustiveness check: fails type-checking if a new role is added.
    assert_never(role)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@strawberry.type
class Subscription:
    """Root GraphQL subscription type for the prompt playground."""

    @strawberry.subscription
    async def chat_completion(
        self, info: Info[Context, None], input: ChatCompletionInput
    ) -> AsyncIterator[ChatCompletionChunk]:
        """
        Stream a chat completion from OpenAI (or Azure OpenAI) back to the
        client chunk by chunk, while recording the whole interaction as an
        LLM span in the "playground" project.

        Yields:
            TextChunk and ToolCallChunk objects as deltas arrive from the
            provider.

        Raises:
            ValueError: if an Azure OpenAI model is selected without an
                endpoint and api_version.
        """
        # Imported lazily so openai stays an optional dependency.
        from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI

        client: Union[AsyncAzureOpenAI, AsyncOpenAI]

        if input.model.provider_key == GenerativeProviderKey.AZURE_OPENAI:
            if input.model.endpoint is None or input.model.api_version is None:
                raise ValueError("endpoint and api_version are required for Azure OpenAI models")
            client = AsyncAzureOpenAI(
                api_key=input.api_key,
                azure_endpoint=input.model.endpoint,
                api_version=input.model.api_version,
            )
        else:
            client = AsyncOpenAI(api_key=input.api_key)

        # Plain-dict form of the invocation parameters, both forwarded to the
        # provider as kwargs and recorded on the span.
        invocation_parameters = jsonify(input.invocation_parameters)

        messages: List[Tuple[ChatCompletionMessageRole, str]] = [
            (message.role, message.content) for message in input.messages
        ]
        # Apply template substitution before converting to provider params.
        if template_options := input.template:
            messages = list(_formatted_messages(messages, template_options))
        openai_messages = [to_openai_chat_completion_param(*message) for message in messages]

        # Use a private tracer pipeline with an in-memory exporter so the
        # finished span can be read back and inserted into the database below.
        in_memory_span_exporter = InMemorySpanExporter()
        tracer_provider = TracerProvider()
        tracer_provider.add_span_processor(
            span_processor=SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
        )
        tracer = tracer_provider.get_tracer(__name__)
        span_name = "ChatCompletion"
        with tracer.start_span(
            span_name,
            attributes=dict(
                chain(
                    _llm_span_kind(),
                    _llm_model_name(input.model.name),
                    _llm_tools(input.tools or []),
                    _llm_input_messages(messages),
                    _llm_invocation_parameters(invocation_parameters),
                    _input_value_and_mime_type(input),
                )
            ),
        ) as span:
            response_chunks = []
            text_chunks: List[TextChunk] = []
            # Streamed tool-call fragments grouped by tool-call index.
            tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list)
            role: Optional[str] = None
            async for chunk in await client.chat.completions.create(
                messages=openai_messages,
                model=input.model.name,
                stream=True,
                tools=input.tools or NOT_GIVEN,
                **invocation_parameters,
            ):
                response_chunks.append(chunk)
                choice = chunk.choices[0]
                delta = choice.delta
                # The role is only present on the first delta of the stream.
                if role is None:
                    role = delta.role
                if choice.finish_reason is None:
                    if isinstance(chunk_content := delta.content, str):
                        text_chunk = TextChunk(content=chunk_content)
                        yield text_chunk
                        text_chunks.append(text_chunk)
                    if (tool_calls := delta.tool_calls) is not None:
                        for tool_call_index, tool_call in enumerate(tool_calls):
                            if (function := tool_call.function) is not None:
                                # Later deltas of the same tool call omit the
                                # id; reuse the id from the first chunk.
                                if (tool_call_id := tool_call.id) is None:
                                    first_tool_call_chunk = tool_call_chunks[tool_call_index][0]
                                    tool_call_id = first_tool_call_chunk.id
                                tool_call_chunk = ToolCallChunk(
                                    id=tool_call_id,
                                    function=FunctionCallChunk(
                                        name=function.name or "",
                                        arguments=function.arguments or "",
                                    ),
                                )
                                yield tool_call_chunk
                                tool_call_chunks[tool_call_index].append(tool_call_chunk)
            span.set_status(StatusCode.OK)
            assert role is not None
            # Record the assembled output once the stream is exhausted.
            span.set_attributes(
                dict(
                    chain(
                        _output_value_and_mime_type(response_chunks),
                        _llm_output_messages(text_chunks, tool_call_chunks),
                    )
                )
            )
        # Exiting the `with` block ends the span; the SimpleSpanProcessor
        # exports it synchronously, so exactly one finished span is expected.
        assert len(spans := in_memory_span_exporter.get_finished_spans()) == 1
        finished_span = spans[0]
        assert finished_span.start_time is not None
        assert finished_span.end_time is not None
        assert (attributes := finished_span.attributes) is not None
        start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
        end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
        trace_id = _hex(finished_span.context.trace_id)
        span_id = _hex(finished_span.context.span_id)
        status = finished_span.status
        async with info.context.db() as session:
            # Look up the playground project, creating it on first use.
            if (
                playground_project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                playground_project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            trace_rowid = await session.scalar(
                insert(models.Trace)
                .returning(models.Trace.id)
                .values(
                    project_rowid=playground_project_id,
                    trace_id=trace_id,
                    start_time=start_time,
                    end_time=end_time,
                )
            )
            # Token counts are zeroed because streamed responses do not
            # include usage information here.
            await session.execute(
                insert(models.Span).values(
                    trace_rowid=trace_rowid,
                    span_id=span_id,
                    parent_id=None,
                    name=span_name,
                    span_kind=LLM,
                    start_time=start_time,
                    end_time=end_time,
                    attributes=unflatten(attributes.items()),
                    events=finished_span.events,
                    status_code=status.status_code.name,
                    status_message=status.description or "",
                    cumulative_error_count=int(not status.is_ok),
                    cumulative_llm_token_count_prompt=0,
                    cumulative_llm_token_count_completion=0,
                    llm_token_count_prompt=0,
                    llm_token_count_completion=0,
                )
            )
        # Notify listeners (e.g. the UI) that a new span was inserted.
        info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _llm_span_kind() -> Iterator[Tuple[str, Any]]:
    """Yield the single attribute marking the span as an LLM span."""
    return iter(((OPENINFERENCE_SPAN_KIND, LLM),))
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
    """Yield the model-name span attribute."""
    return iter(((LLM_MODEL_NAME, model_name),))
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
    """Yield the invocation-parameters span attribute, JSON-serialized."""
    serialized = safe_json_dumps(invocation_parameters)
    yield LLM_INVOCATION_PARAMETERS, serialized
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
    """Yield one JSON-schema span attribute per tool, keyed by tool index."""
    return (
        (f"{LLM_TOOLS}.{index}.{TOOL_JSON_SCHEMA}", json.dumps(tool))
        for index, tool in enumerate(tools)
    )
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
    """Yield the span's input value (with the API key redacted) and MIME type."""
    api_key = "api_key"
    # Sanity check that the field being redacted still exists on the input type.
    assert any(field.name == api_key for field in fields(ChatCompletionInput))
    yield INPUT_MIME_TYPE, JSON
    redacted = {key: value for key, value in jsonify(input).items() if key != api_key}
    yield INPUT_VALUE, safe_json_dumps(redacted)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
    """Yield the span's output value (JSON-serialized) and MIME type."""
    serialized = safe_json_dumps(jsonify(output))
    yield OUTPUT_MIME_TYPE, JSON
    yield OUTPUT_VALUE, serialized
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _llm_input_messages(
    messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
) -> Iterator[Tuple[str, Any]]:
    """Yield flattened role/content span attributes for the input messages."""
    for index, (role, content) in enumerate(messages):
        prefix = f"{LLM_INPUT_MESSAGES}.{index}"
        yield f"{prefix}.{MESSAGE_ROLE}", role.value.lower()
        yield f"{prefix}.{MESSAGE_CONTENT}", content
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def _llm_output_messages(
    text_chunks: List[TextChunk],
    tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]],
) -> Iterator[Tuple[str, Any]]:
    """
    Yield flattened span attributes for the single assistant output message
    assembled from the streamed text and tool-call chunks.
    """
    yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
    # Only emit content if any text was streamed.
    if content := "".join(chunk.content for chunk in text_chunks):
        yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
    for tool_call_index, tool_call_chunks_ in tool_call_chunks.items():
        # The function name arrives on the first chunk of a tool call.
        if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
            yield (
                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
                name,
            )
        # Argument fragments are concatenated across all chunks of the call.
        if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
            yield (
                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                arguments,
            )
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _hex(number: int) -> str:
|
|
367
|
+
"""
|
|
368
|
+
Converts an integer to a hexadecimal string.
|
|
369
|
+
"""
|
|
370
|
+
return hex(number)[2:]
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _datetime(*, epoch_nanoseconds: float) -> datetime:
|
|
374
|
+
"""
|
|
375
|
+
Converts a Unix epoch timestamp in nanoseconds to a datetime.
|
|
376
|
+
"""
|
|
377
|
+
epoch_seconds = epoch_nanoseconds / 1e9
|
|
378
|
+
return datetime.fromtimestamp(epoch_seconds)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def _formatted_messages(
    messages: Iterable[Tuple[ChatCompletionMessageRole, str]], template_options: TemplateOptions
) -> Iterator[Tuple[ChatCompletionMessageRole, str]]:
    """
    Format each message's content as a template using the given template
    options, preserving the message roles.

    Returns an iterator over (role, formatted_content) pairs.
    """
    template_formatter = _template_formatter(template_language=template_options.language)
    # Materialize so that an empty input can be detected: unpacking
    # ``zip(*[])`` below would raise ``ValueError`` on an empty message list.
    messages = list(messages)
    if not messages:
        return iter(())
    roles, templates = zip(*messages)
    formatted_templates = map(
        lambda template: template_formatter.format(template, **template_options.variables),
        templates,
    )
    formatted_messages = zip(roles, formatted_templates)
    return formatted_messages
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
    """Return the formatter implementation for the given template language."""
    formatter_classes = {
        TemplateLanguage.MUSTACHE: MustacheTemplateFormatter,
        TemplateLanguage.F_STRING: FStringTemplateFormatter,
    }
    if (formatter_class := formatter_classes.get(template_language)) is not None:
        return formatter_class()
    # Exhaustiveness check: fails type-checking if a new language is added.
    assert_never(template_language)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# Short aliases for the OpenInference semantic-convention values and
# flattened span-attribute keys used throughout this module.
JSON = OpenInferenceMimeTypeValues.JSON.value

LLM = OpenInferenceSpanKindValues.LLM.value

# Span-level attribute keys.
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS

# Per-message attribute keys.
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

# Per-tool-call attribute keys.
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON

# Per-tool attribute keys.
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@strawberry.enum
class GenerativeProviderKey(Enum):
    """Supported LLM provider identifiers."""

    OPENAI = "OPENAI"
    ANTHROPIC = "ANTHROPIC"
    AZURE_OPENAI = "AZURE_OPENAI"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@strawberry.type
class GenerativeProvider:
    """A generative-model provider: display name plus its provider key."""

    name: str
    key: GenerativeProviderKey
|
phoenix/server/app.py
CHANGED
|
@@ -27,6 +27,7 @@ from typing import (
|
|
|
27
27
|
Union,
|
|
28
28
|
cast,
|
|
29
29
|
)
|
|
30
|
+
from urllib.parse import urlparse
|
|
30
31
|
|
|
31
32
|
import strawberry
|
|
32
33
|
from fastapi import APIRouter, Depends, FastAPI
|
|
@@ -35,26 +36,30 @@ from fastapi.utils import is_body_allowed_for_status_code
|
|
|
35
36
|
from sqlalchemy import select
|
|
36
37
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
37
38
|
from starlette.datastructures import State as StarletteState
|
|
38
|
-
from starlette.exceptions import HTTPException
|
|
39
|
+
from starlette.exceptions import HTTPException, WebSocketException
|
|
39
40
|
from starlette.middleware import Middleware
|
|
40
41
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
41
42
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
42
43
|
from starlette.requests import Request
|
|
43
|
-
from starlette.responses import PlainTextResponse, Response
|
|
44
|
+
from starlette.responses import JSONResponse, PlainTextResponse, Response
|
|
44
45
|
from starlette.staticfiles import StaticFiles
|
|
46
|
+
from starlette.status import HTTP_401_UNAUTHORIZED
|
|
45
47
|
from starlette.templating import Jinja2Templates
|
|
46
48
|
from starlette.types import Scope, StatefulLifespan
|
|
49
|
+
from starlette.websockets import WebSocket
|
|
47
50
|
from strawberry.extensions import SchemaExtension
|
|
48
51
|
from strawberry.fastapi import GraphQLRouter
|
|
49
52
|
from strawberry.schema import BaseSchema
|
|
53
|
+
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
|
|
50
54
|
from typing_extensions import TypeAlias
|
|
51
55
|
|
|
52
|
-
import phoenix
|
|
53
56
|
import phoenix.trace.v1 as pb
|
|
54
57
|
from phoenix.config import (
|
|
55
58
|
DEFAULT_PROJECT_NAME,
|
|
59
|
+
ENV_PHOENIX_CSRF_TRUSTED_ORIGINS,
|
|
56
60
|
SERVER_DIR,
|
|
57
61
|
OAuth2ClientConfig,
|
|
62
|
+
get_env_csrf_trusted_origins,
|
|
58
63
|
get_env_host,
|
|
59
64
|
get_env_port,
|
|
60
65
|
server_instrumentation_is_enabled,
|
|
@@ -131,6 +136,7 @@ from phoenix.trace.fixtures import (
|
|
|
131
136
|
from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
|
|
132
137
|
from phoenix.trace.schemas import Span
|
|
133
138
|
from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER
|
|
139
|
+
from phoenix.version import __version__ as phoenix_version
|
|
134
140
|
|
|
135
141
|
if TYPE_CHECKING:
|
|
136
142
|
from opentelemetry.trace import TracerProvider
|
|
@@ -213,7 +219,7 @@ class Static(StaticFiles):
|
|
|
213
219
|
"n_neighbors": self._app_config.n_neighbors,
|
|
214
220
|
"n_samples": self._app_config.n_samples,
|
|
215
221
|
"basename": self._sanitize_basename(request.scope.get("root_path", "")),
|
|
216
|
-
"platform_version":
|
|
222
|
+
"platform_version": phoenix_version,
|
|
217
223
|
"request": request,
|
|
218
224
|
"is_development": self._app_config.is_development,
|
|
219
225
|
"manifest": self._web_manifest,
|
|
@@ -226,13 +232,32 @@ class Static(StaticFiles):
|
|
|
226
232
|
return response
|
|
227
233
|
|
|
228
234
|
|
|
235
|
+
class RequestOriginHostnameValidator(BaseHTTPMiddleware):
    """
    CSRF-mitigation middleware: rejects requests whose Origin or Referer
    header resolves to a hostname outside the trusted list.

    NOTE(review): instantiated as ``Middleware(cls, trusted_hostnames)``, so
    Starlette is expected to pass its ``app`` argument by keyword (forwarded
    to ``BaseHTTPMiddleware`` via ``**kwargs``) — confirm against the pinned
    Starlette version.
    """

    def __init__(self, trusted_hostnames: List[str], *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._trusted_hostnames = trusted_hostnames

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        headers = request.headers
        for key in "origin", "referer":
            # A missing or empty header is not checked; only present headers
            # with untrusted hostnames cause rejection.
            if not (url := headers.get(key)):
                continue
            if urlparse(url).hostname not in self._trusted_hostnames:
                return Response(f"untrusted {key}", status_code=HTTP_401_UNAUTHORIZED)
        return await call_next(request)
|
|
252
|
+
|
|
253
|
+
|
|
229
254
|
class HeadersMiddleware(BaseHTTPMiddleware):
|
|
230
255
|
async def dispatch(
|
|
231
256
|
self,
|
|
232
257
|
request: Request,
|
|
233
258
|
call_next: RequestResponseEndpoint,
|
|
234
259
|
) -> Response:
|
|
235
|
-
from phoenix import __version__ as phoenix_version
|
|
260
|
+
from phoenix.version import __version__ as phoenix_version
|
|
236
261
|
|
|
237
262
|
response = await call_next(request)
|
|
238
263
|
response.headers["x-colab-notebook-cache-control"] = "no-cache"
|
|
@@ -245,7 +270,7 @@ ProjectRowId: TypeAlias = int
|
|
|
245
270
|
|
|
246
271
|
@router.get("/arize_phoenix_version")
|
|
247
272
|
async def version() -> PlainTextResponse:
|
|
248
|
-
return PlainTextResponse(f"{
|
|
273
|
+
return PlainTextResponse(f"{phoenix_version}")
|
|
249
274
|
|
|
250
275
|
|
|
251
276
|
DB_MUTEX: Optional[asyncio.Lock] = None
|
|
@@ -557,6 +582,7 @@ def create_graphql_router(
|
|
|
557
582
|
include_in_schema=False,
|
|
558
583
|
prefix="/graphql",
|
|
559
584
|
dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
|
|
585
|
+
subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
|
|
560
586
|
)
|
|
561
587
|
|
|
562
588
|
|
|
@@ -607,6 +633,29 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
|
|
|
607
633
|
return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)
|
|
608
634
|
|
|
609
635
|
|
|
636
|
+
async def websocket_denial_response_handler(websocket: WebSocket, exc: WebSocketException) -> None:
    """
    Overrides the default exception handler for WebSocketException to ensure
    that the HTTP response returned when a WebSocket connection is denied has
    the same status code as the raised exception. This is in keeping with the
    WebSocket Denial Response Extension of the ASGI specification described
    below.

    "Websocket connections start with the client sending a HTTP request
    containing the appropriate upgrade headers. On receipt of this request a
    server can choose to either upgrade the connection or respond with an HTTP
    response (denying the upgrade). The core ASGI specification does not allow
    for any control over the denial response, instead specifying that the HTTP
    status code 403 should be returned, whereas this extension allows an ASGI
    framework to control the denial response."

    For details, see:
    - https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response

    NOTE(review): ``exc.code`` is used directly as the HTTP status code, so
    callers are expected to raise WebSocketException with an HTTP status
    (e.g. 401) rather than a WebSocket close code — confirm all raisers
    follow this convention.
    """
    assert isinstance(exc, WebSocketException)
    await websocket.send_denial_response(JSONResponse(status_code=exc.code, content=exc.reason))
|
|
657
|
+
|
|
658
|
+
|
|
610
659
|
def create_app(
|
|
611
660
|
db: DbSessionFactory,
|
|
612
661
|
export_path: Path,
|
|
@@ -660,6 +709,16 @@ def create_app(
|
|
|
660
709
|
)
|
|
661
710
|
last_updated_at = LastUpdatedAt()
|
|
662
711
|
middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
|
|
712
|
+
if origins := get_env_csrf_trusted_origins():
|
|
713
|
+
trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
|
|
714
|
+
middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
|
|
715
|
+
elif email_sender or oauth2_client_configs:
|
|
716
|
+
logger.warning(
|
|
717
|
+
"CSRF protection can be enabled by listing trusted origins via "
|
|
718
|
+
f"the `{ENV_PHOENIX_CSRF_TRUSTED_ORIGINS}` environment variable. "
|
|
719
|
+
"This is recommended when setting up OAuth2 clients or sending "
|
|
720
|
+
"password reset emails."
|
|
721
|
+
)
|
|
663
722
|
if authentication_enabled and secret:
|
|
664
723
|
token_store = JwtStore(db, secret)
|
|
665
724
|
middlewares.append(
|
|
@@ -743,7 +802,10 @@ def create_app(
|
|
|
743
802
|
scaffolder_config=scaffolder_config,
|
|
744
803
|
),
|
|
745
804
|
middleware=middlewares,
|
|
746
|
-
exception_handlers={
|
|
805
|
+
exception_handlers={
|
|
806
|
+
HTTPException: plain_text_http_exception_handler,
|
|
807
|
+
WebSocketException: websocket_denial_response_handler, # type: ignore[dict-item]
|
|
808
|
+
},
|
|
747
809
|
debug=debug,
|
|
748
810
|
swagger_ui_parameters={
|
|
749
811
|
"defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
|
phoenix/server/bearer_auth.py
CHANGED
|
@@ -7,10 +7,11 @@ from typing import (
|
|
|
7
7
|
Callable,
|
|
8
8
|
Optional,
|
|
9
9
|
Tuple,
|
|
10
|
+
cast,
|
|
10
11
|
)
|
|
11
12
|
|
|
12
13
|
import grpc
|
|
13
|
-
from fastapi import HTTPException, Request
|
|
14
|
+
from fastapi import HTTPException, Request, WebSocket, WebSocketException
|
|
14
15
|
from grpc_interceptor import AsyncServerInterceptor
|
|
15
16
|
from grpc_interceptor.exceptions import Unauthenticated
|
|
16
17
|
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
|
|
@@ -116,12 +117,19 @@ class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
|
|
|
116
117
|
raise Unauthenticated()
|
|
117
118
|
|
|
118
119
|
|
|
119
|
-
async def is_authenticated(
|
|
120
|
+
async def is_authenticated(
|
|
121
|
+
# fastapi dependencies require non-optional types
|
|
122
|
+
request: Request = cast(Request, None),
|
|
123
|
+
websocket: WebSocket = cast(WebSocket, None),
|
|
124
|
+
) -> None:
|
|
120
125
|
"""
|
|
121
|
-
Raises a 401 if the request is not authenticated.
|
|
126
|
+
Raises a 401 if the request or websocket connection is not authenticated.
|
|
122
127
|
"""
|
|
123
|
-
|
|
128
|
+
assert request or websocket
|
|
129
|
+
if request and not isinstance((user := request.user), PhoenixUser):
|
|
124
130
|
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
|
131
|
+
if websocket and not isinstance((user := websocket.user), PhoenixUser):
|
|
132
|
+
raise WebSocketException(code=HTTP_401_UNAUTHORIZED, reason="Invalid token")
|
|
125
133
|
claims = user.claims
|
|
126
134
|
if claims.status is ClaimSetStatus.EXPIRED:
|
|
127
135
|
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
|
phoenix/server/main.py
CHANGED
|
@@ -3,7 +3,6 @@ import codecs
|
|
|
3
3
|
import os
|
|
4
4
|
import sys
|
|
5
5
|
from argparse import SUPPRESS, ArgumentParser
|
|
6
|
-
from importlib.metadata import version
|
|
7
6
|
from pathlib import Path
|
|
8
7
|
from threading import Thread
|
|
9
8
|
from time import sleep, time
|
|
@@ -72,6 +71,7 @@ from phoenix.trace.fixtures import (
|
|
|
72
71
|
)
|
|
73
72
|
from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
|
|
74
73
|
from phoenix.trace.schemas import Span
|
|
74
|
+
from phoenix.version import __version__ as phoenix_version
|
|
75
75
|
|
|
76
76
|
_WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
|
|
77
77
|
|
|
@@ -351,7 +351,7 @@ def main() -> None:
|
|
|
351
351
|
# Print information about the server
|
|
352
352
|
root_path = urljoin(f"http://{host}:{port}", host_root_path)
|
|
353
353
|
msg = _WELCOME_MESSAGE.render(
|
|
354
|
-
version=
|
|
354
|
+
version=phoenix_version,
|
|
355
355
|
ui_path=root_path,
|
|
356
356
|
grpc_path=f"http://{host}:{get_env_grpc_port()}",
|
|
357
357
|
http_path=urljoin(root_path, "v1/traces"),
|