arize-phoenix 5.2.2__py3-none-any.whl → 5.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; further details are linked from the original registry listing.

Files changed (38)
  1. {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/METADATA +2 -1
  2. {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/RECORD +36 -29
  3. phoenix/core/model_schema_adapter.py +2 -2
  4. phoenix/db/migrations/versions/10460e46d750_datasets.py +1 -1
  5. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +1 -1
  6. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +1 -1
  7. phoenix/db/models.py +6 -6
  8. phoenix/server/api/input_types/ChatCompletionMessageInput.py +12 -0
  9. phoenix/server/api/input_types/InvocationParameters.py +20 -0
  10. phoenix/server/api/openapi/main.py +8 -13
  11. phoenix/server/api/queries.py +74 -0
  12. phoenix/server/api/schema.py +2 -0
  13. phoenix/server/api/subscriptions.py +430 -0
  14. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  15. phoenix/server/api/types/GenerativeModel.py +9 -0
  16. phoenix/server/api/types/GenerativeProvider.py +16 -0
  17. phoenix/server/app.py +36 -7
  18. phoenix/server/bearer_auth.py +12 -4
  19. phoenix/server/main.py +2 -2
  20. phoenix/server/static/.vite/manifest.json +35 -35
  21. phoenix/server/static/assets/{components-CPkaHQZs.js → components-DwERj42u.js} +153 -130
  22. phoenix/server/static/assets/{index-D4G-kYL9.js → index-CyOjvLOr.js} +2 -2
  23. phoenix/server/static/assets/{pages-B3HCyYqg.js → pages-BNYMd7SU.js} +360 -255
  24. phoenix/server/static/assets/vendor-D-NIjePD.js +872 -0
  25. phoenix/server/static/assets/{vendor-arizeai-D0FocbYu.js → vendor-arizeai-DoY5jUTO.js} +30 -28
  26. phoenix/server/static/assets/vendor-codemirror-CIhY_nEU.js +24 -0
  27. phoenix/server/static/assets/{vendor-recharts-JMOLUxWG.js → vendor-recharts-Dgcm35Jq.js} +1 -1
  28. phoenix/trace/fixtures.py +7 -9
  29. phoenix/utilities/client.py +2 -2
  30. phoenix/utilities/json.py +11 -0
  31. phoenix/utilities/template_formatters.py +70 -0
  32. phoenix/version.py +1 -1
  33. phoenix/server/static/assets/vendor-c0HJYYGN.js +0 -641
  34. phoenix/server/static/assets/vendor-codemirror-BbCMI-_D.js +0 -30
  35. {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/WHEEL +0 -0
  36. {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/entry_points.txt +0 -0
  37. {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/licenses/IP_NOTICE +0 -0
  38. {arize_phoenix-5.2.2.dist-info → arize_phoenix-5.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,430 @@
1
+ import json
2
+ from collections import defaultdict
3
+ from dataclasses import fields
4
+ from datetime import datetime
5
+ from enum import Enum
6
+ from itertools import chain
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Annotated,
10
+ Any,
11
+ AsyncIterator,
12
+ DefaultDict,
13
+ Dict,
14
+ Iterable,
15
+ Iterator,
16
+ List,
17
+ Optional,
18
+ Tuple,
19
+ Union,
20
+ )
21
+
22
+ import strawberry
23
+ from openinference.instrumentation import safe_json_dumps
24
+ from openinference.semconv.trace import (
25
+ MessageAttributes,
26
+ OpenInferenceMimeTypeValues,
27
+ OpenInferenceSpanKindValues,
28
+ SpanAttributes,
29
+ ToolAttributes,
30
+ ToolCallAttributes,
31
+ )
32
+ from opentelemetry.sdk.trace import TracerProvider
33
+ from opentelemetry.sdk.trace.export import SimpleSpanProcessor
34
+ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
35
+ from opentelemetry.trace import StatusCode
36
+ from sqlalchemy import insert, select
37
+ from strawberry import UNSET
38
+ from strawberry.scalars import JSON as JSONScalarType
39
+ from strawberry.types import Info
40
+ from typing_extensions import TypeAlias, assert_never
41
+
42
+ from phoenix.db import models
43
+ from phoenix.server.api.context import Context
44
+ from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
45
+ from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
46
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
47
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
48
+ from phoenix.server.dml_event import SpanInsertEvent
49
+ from phoenix.trace.attributes import unflatten
50
+ from phoenix.utilities.json import jsonify
51
+ from phoenix.utilities.template_formatters import (
52
+ FStringTemplateFormatter,
53
+ MustacheTemplateFormatter,
54
+ TemplateFormatter,
55
+ )
56
+
57
+ if TYPE_CHECKING:
58
+ from openai.types.chat import (
59
+ ChatCompletionMessageParam,
60
+ )
61
+
62
+ PLAYGROUND_PROJECT_NAME = "playground"
63
+
64
+ ToolCallIndex: TypeAlias = int
65
+
66
+
67
@strawberry.enum
class TemplateLanguage(Enum):
    """Template syntax used to render variables into chat message content."""

    MUSTACHE = "MUSTACHE"  # rendered with MustacheTemplateFormatter
    F_STRING = "F_STRING"  # rendered with FStringTemplateFormatter
71
+
72
+
73
@strawberry.input
class TemplateOptions:
    """GraphQL input describing how message contents are rendered as templates."""

    # Values substituted into each template; unpacked as keyword arguments to the
    # formatter, so this is expected to be a JSON object (mapping).
    variables: JSONScalarType
    # Selects which template formatter renders the messages.
    language: TemplateLanguage
77
+
78
+
79
@strawberry.type
class TextChunk:
    """A fragment of streamed completion text."""

    # The text carried by a single streamed delta.
    content: str
82
+
83
+
84
@strawberry.type
class FunctionCallChunk:
    """A fragment of a streamed function call."""

    # Function name; an empty string when the delta carries no name.
    name: str
    # A slice of the argument text; the full arguments are the concatenation of
    # the slices streamed for one tool call.
    arguments: str
88
+
89
+
90
@strawberry.type
class ToolCallChunk:
    """A fragment of a streamed tool call."""

    # Identifier of the tool call this fragment belongs to.
    id: str
    # The function-call fragment carried by this delta.
    function: FunctionCallChunk
94
+
95
+
96
# GraphQL union of the chunk types streamed by the chat completion subscription:
# each chunk is either a piece of text or a piece of a tool call.
ChatCompletionChunk: TypeAlias = Annotated[
    Union[TextChunk, ToolCallChunk], strawberry.union("ChatCompletionChunk")
]
99
+
100
+
101
@strawberry.input
class GenerativeModelInput:
    """GraphQL input identifying which provider and model to call."""

    provider_key: GenerativeProviderKey
    name: str
    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
    # Omitted optional fields default to UNSET rather than None.
    endpoint: Optional[str] = UNSET
    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
    # In this module only the Azure OpenAI client is given an API version.
    api_version: Optional[str] = UNSET
    """ The API version to use for the model. """
110
+
111
+
112
@strawberry.input
class ChatCompletionInput:
    """GraphQL input for a streamed chat completion request."""

    # Conversation to send, in order.
    messages: List[ChatCompletionMessageInput]
    # Provider and model to call.
    model: GenerativeModelInput
    # Converted with jsonify and forwarded as keyword arguments to the completion call.
    invocation_parameters: InvocationParameters
    # Tool definitions (JSON); not sent when omitted or empty.
    tools: Optional[List[JSONScalarType]] = UNSET
    # When provided, message contents are rendered as templates before sending.
    template: Optional[TemplateOptions] = UNSET
    # Secret credential for the provider client; excluded from the input value
    # recorded on the span. Defaults to None (not UNSET).
    api_key: Optional[str] = strawberry.field(default=None)
120
+
121
+
122
def to_openai_chat_completion_param(
    role: ChatCompletionMessageRole, content: JSONScalarType
) -> "ChatCompletionMessageParam":
    """
    Builds the OpenAI chat-completion message param for one playground message.

    Args:
        role: Playground message role. ``AI`` is sent to OpenAI as ``"assistant"``.
        content: Message content, passed through unchanged.

    Returns:
        A user, system, or assistant message param dictionary.

    Raises:
        NotImplementedError: for ``TOOL`` messages, which are not supported.
    """
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
    )

    if role is ChatCompletionMessageRole.TOOL:
        raise NotImplementedError
    if role is ChatCompletionMessageRole.USER:
        return ChatCompletionUserMessageParam(content=content, role="user")
    if role is ChatCompletionMessageRole.SYSTEM:
        return ChatCompletionSystemMessageParam(content=content, role="system")
    if role is ChatCompletionMessageRole.AI:
        return ChatCompletionAssistantMessageParam(content=content, role="assistant")
    assert_never(role)
155
+
156
+
157
@strawberry.type
class Subscription:
    """GraphQL subscription operations for the prompt playground."""

    @strawberry.subscription
    async def chat_completion(
        self, info: Info[Context, None], input: ChatCompletionInput
    ) -> AsyncIterator[ChatCompletionChunk]:
        """
        Streams a chat completion and records the exchange as a playground span.

        Text deltas are yielded as ``TextChunk`` and tool-call deltas as
        ``ToolCallChunk`` while the provider streams its response. After the
        stream ends, one LLM span describing the request and response is inserted
        into the "playground" project (created if missing), and a
        ``SpanInsertEvent`` is queued for that project.

        Raises:
            ValueError: if an Azure OpenAI model is requested without both an
                endpoint and an API version.
            NotImplementedError: if the provider is neither OpenAI nor Azure OpenAI.
        """
        from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI

        client: Union[AsyncAzureOpenAI, AsyncOpenAI]

        provider_key = input.model.provider_key
        if provider_key == GenerativeProviderKey.AZURE_OPENAI:
            # Omitted optional inputs arrive as UNSET (which is falsy), not None, so a
            # truthiness test is needed to catch a missing endpoint or API version.
            if not input.model.endpoint or not input.model.api_version:
                raise ValueError("endpoint and api_version are required for Azure OpenAI models")
            client = AsyncAzureOpenAI(
                api_key=input.api_key,
                azure_endpoint=input.model.endpoint,
                api_version=input.model.api_version,
            )
        elif provider_key == GenerativeProviderKey.OPENAI:
            client = AsyncOpenAI(api_key=input.api_key)
        else:
            # Only OpenAI-compatible clients are wired up here. Fail loudly rather than
            # send another provider's model name and API key to OpenAI.
            raise NotImplementedError(f"Provider {provider_key.value} is not supported yet")

        invocation_parameters = jsonify(input.invocation_parameters)

        messages: List[Tuple[ChatCompletionMessageRole, str]] = [
            (message.role, message.content) for message in input.messages
        ]
        if template_options := input.template:
            messages = list(_formatted_messages(messages, template_options))
        openai_messages = [to_openai_chat_completion_param(*message) for message in messages]

        # Spans are captured in memory only; the finished span is persisted below.
        in_memory_span_exporter = InMemorySpanExporter()
        tracer_provider = TracerProvider()
        tracer_provider.add_span_processor(
            span_processor=SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
        )
        tracer = tracer_provider.get_tracer(__name__)
        span_name = "ChatCompletion"
        with tracer.start_span(
            span_name,
            attributes=dict(
                chain(
                    _llm_span_kind(),
                    _llm_model_name(input.model.name),
                    _llm_tools(input.tools or []),
                    _llm_input_messages(messages),
                    _llm_invocation_parameters(invocation_parameters),
                    _input_value_and_mime_type(input),
                )
            ),
        ) as span:
            response_chunks: List[Any] = []
            text_chunks: List[TextChunk] = []
            tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list)
            role: Optional[str] = None
            async for chunk in await client.chat.completions.create(
                messages=openai_messages,
                model=input.model.name,
                stream=True,
                tools=input.tools or NOT_GIVEN,
                **invocation_parameters,
            ):
                response_chunks.append(chunk)
                # A streamed chunk may carry no choices (e.g. a usage-only chunk or a
                # content-filter notice); there is nothing to stream from it.
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = choice.delta
                if role is None:
                    role = delta.role
                if choice.finish_reason is None:
                    if isinstance(chunk_content := delta.content, str):
                        text_chunk = TextChunk(content=chunk_content)
                        yield text_chunk
                        text_chunks.append(text_chunk)
                    if (tool_calls := delta.tool_calls) is not None:
                        for tool_call in tool_calls:
                            # Each delta names the tool call it extends via its
                            # ``index``. Its position within this list is NOT that
                            # index when the model makes parallel tool calls.
                            tool_call_index = tool_call.index
                            if (function := tool_call.function) is not None:
                                if (tool_call_id := tool_call.id) is None:
                                    # Later deltas may omit the id; reuse the one
                                    # from this tool call's first delta.
                                    first_tool_call_chunk = tool_call_chunks[tool_call_index][0]
                                    tool_call_id = first_tool_call_chunk.id
                                tool_call_chunk = ToolCallChunk(
                                    id=tool_call_id,
                                    function=FunctionCallChunk(
                                        name=function.name or "",
                                        arguments=function.arguments or "",
                                    ),
                                )
                                yield tool_call_chunk
                                tool_call_chunks[tool_call_index].append(tool_call_chunk)
            span.set_status(StatusCode.OK)
            assert role is not None
            span.set_attributes(
                dict(
                    chain(
                        _output_value_and_mime_type(response_chunks),
                        _llm_output_messages(text_chunks, tool_call_chunks),
                    )
                )
            )
        # Bind names outside ``assert`` so they still exist when run with ``python -O``.
        spans = in_memory_span_exporter.get_finished_spans()
        assert len(spans) == 1
        finished_span = spans[0]
        assert finished_span.start_time is not None
        assert finished_span.end_time is not None
        attributes = finished_span.attributes
        assert attributes is not None
        start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
        end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
        trace_id = _hex(finished_span.context.trace_id)
        span_id = _hex(finished_span.context.span_id)
        status = finished_span.status
        async with info.context.db() as session:
            if (
                playground_project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                playground_project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            trace_rowid = await session.scalar(
                insert(models.Trace)
                .returning(models.Trace.id)
                .values(
                    project_rowid=playground_project_id,
                    trace_id=trace_id,
                    start_time=start_time,
                    end_time=end_time,
                )
            )
            await session.execute(
                insert(models.Span).values(
                    trace_rowid=trace_rowid,
                    span_id=span_id,
                    parent_id=None,
                    name=span_name,
                    span_kind=LLM,
                    start_time=start_time,
                    end_time=end_time,
                    attributes=unflatten(attributes.items()),
                    events=finished_span.events,
                    status_code=status.status_code.name,
                    status_message=status.description or "",
                    cumulative_error_count=int(not status.is_ok),
                    cumulative_llm_token_count_prompt=0,
                    cumulative_llm_token_count_completion=0,
                    llm_token_count_prompt=0,
                    llm_token_count_completion=0,
                )
            )
        info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
308
+
309
+
310
def _llm_span_kind() -> Iterator[Tuple[str, Any]]:
    """Yields the single attribute that marks the span as an LLM span."""
    yield from ((OPENINFERENCE_SPAN_KIND, LLM),)
312
+
313
+
314
def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
    """Yields the attribute recording the requested model name."""
    attribute = (LLM_MODEL_NAME, model_name)
    yield attribute
316
+
317
+
318
def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
    """Yields the invocation parameters as a single JSON-string attribute."""
    serialized = safe_json_dumps(invocation_parameters)
    yield LLM_INVOCATION_PARAMETERS, serialized
320
+
321
+
322
def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
    """Yields one JSON-schema attribute per tool definition, keyed by its position."""
    yield from (
        (f"{LLM_TOOLS}.{position}.{TOOL_JSON_SCHEMA}", json.dumps(tool))
        for position, tool in enumerate(tools)
    )
325
+
326
+
327
def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
    """
    Yields the span's input MIME type and its JSON-serialized input value.

    The ``api_key`` field is dropped before serialization so the credential is
    never recorded on the span.
    """
    # Bound outside the ``assert`` so the name still exists when assertions are
    # stripped (``python -O``); previously that raised NameError below.
    api_key = "api_key"
    # Sanity check: if the field were renamed, the filter below would silently stop
    # removing the credential.
    assert any(field.name == api_key for field in fields(ChatCompletionInput))
    yield INPUT_MIME_TYPE, JSON
    yield INPUT_VALUE, safe_json_dumps({k: v for k, v in jsonify(input).items() if k != api_key})
331
+
332
+
333
def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
    """Yields the span's output MIME type and its JSON-serialized output value."""
    yield OUTPUT_MIME_TYPE, JSON
    serialized_output = safe_json_dumps(jsonify(output))
    yield OUTPUT_VALUE, serialized_output
336
+
337
+
338
def _llm_input_messages(
    messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
) -> Iterator[Tuple[str, Any]]:
    """
    Yields flattened role/content attributes for each input message.

    Roles are recorded as the lower-cased enum value.
    """
    for position, (message_role, message_content) in enumerate(messages):
        prefix = f"{LLM_INPUT_MESSAGES}.{position}"
        yield f"{prefix}.{MESSAGE_ROLE}", message_role.value.lower()
        yield f"{prefix}.{MESSAGE_CONTENT}", message_content
344
+
345
+
346
def _llm_output_messages(
    text_chunks: List[TextChunk],
    tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]],
) -> Iterator[Tuple[str, Any]]:
    """
    Yields flattened attributes for the single assistant output message.

    The streamed text chunks are concatenated into the message content. For each
    tool call, the function name comes from its first chunk and the arguments
    are the concatenation of all its chunks' argument slices. Empty values are
    omitted.
    """
    message_prefix = f"{LLM_OUTPUT_MESSAGES}.0"
    yield f"{message_prefix}.{MESSAGE_ROLE}", "assistant"

    full_text = "".join(text_chunk.content for text_chunk in text_chunks)
    if full_text:
        yield f"{message_prefix}.{MESSAGE_CONTENT}", full_text

    for index, chunks_for_call in tool_call_chunks.items():
        call_prefix = f"{message_prefix}.{MESSAGE_TOOL_CALLS}.{index}"
        if chunks_for_call:
            function_name = chunks_for_call[0].function.name
            if function_name:
                yield f"{call_prefix}.{TOOL_CALL_FUNCTION_NAME}", function_name
        joined_arguments = "".join(piece.function.arguments for piece in chunks_for_call)
        if joined_arguments:
            yield f"{call_prefix}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", joined_arguments
364
+
365
+
366
def _hex(number: int) -> str:
    """
    Returns ``number`` as lowercase hexadecimal digits without the ``0x`` prefix.

    The result is not zero-padded to any fixed width.
    """
    prefixed = hex(number)
    return prefixed[2:]
371
+
372
+
373
def _datetime(*, epoch_nanoseconds: float) -> datetime:
    """
    Converts a Unix epoch timestamp in nanoseconds to a timezone-aware UTC datetime.

    The Unix epoch is defined in UTC. Without an explicit ``tz``,
    ``datetime.fromtimestamp`` returns a naive datetime in the server's *local*
    time zone, which is indistinguishable from a UTC value once stored and shifts
    span times by the local offset.
    """
    from datetime import timezone

    epoch_seconds = epoch_nanoseconds / 1e9
    return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)
379
+
380
+
381
def _formatted_messages(
    messages: Iterable[Tuple[ChatCompletionMessageRole, str]], template_options: TemplateOptions
) -> Iterator[Tuple[ChatCompletionMessageRole, str]]:
    """
    Formats the messages using the given template options.

    Each message's content is rendered as a template, with
    ``template_options.variables`` passed as keyword arguments, by the formatter
    selected for ``template_options.language``. Roles pass through unchanged.

    The formatter is chosen immediately. Messages are rendered lazily as the
    returned iterator is consumed. Empty input yields an empty iterator; it no
    longer fails while unpacking ``zip(*messages)``.
    """
    template_formatter = _template_formatter(template_language=template_options.language)
    variables = template_options.variables
    return (
        (role, template_formatter.format(template, **variables)) for role, template in messages
    )
395
+
396
+
397
def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
    """
    Returns a new formatter for the given template language.

    ``F_STRING`` maps to ``FStringTemplateFormatter`` and ``MUSTACHE`` to
    ``MustacheTemplateFormatter``. Any other value is a programming error and is
    rejected by ``assert_never``.
    """
    if template_language is TemplateLanguage.F_STRING:
        return FStringTemplateFormatter()
    if template_language is TemplateLanguage.MUSTACHE:
        return MustacheTemplateFormatter()
    assert_never(template_language)
406
+
407
+
408
# MIME type recorded alongside JSON-serialized input and output values.
JSON = OpenInferenceMimeTypeValues.JSON.value

# Span kind recorded for chat completion spans.
LLM = OpenInferenceSpanKindValues.LLM.value

# Span-level attribute keys.
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS

# Per-message attribute key suffixes.
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

# Per-tool-call attribute key suffixes.
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON

# Per-tool-definition attribute key suffix.
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
@@ -0,0 +1,11 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+
6
@strawberry.enum
class ChatCompletionMessageRole(Enum):
    """Author role of a chat completion message, as exposed over GraphQL."""

    USER = "USER"
    SYSTEM = "SYSTEM"
    TOOL = "TOOL"
    AI = "AI"  # The model's own messages (e.g. the "assistant"); normalized to AI for consistency.
@@ -0,0 +1,9 @@
1
+ import strawberry
2
+
3
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
4
+
5
+
6
@strawberry.type
class GenerativeModel:
    """GraphQL type describing a generative model and the provider that serves it."""

    name: str  # model name
    provider_key: GenerativeProviderKey  # provider that serves the model
@@ -0,0 +1,16 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+
6
+ @strawberry.enum
7
+ class GenerativeProviderKey(Enum):
8
+ OPENAI = "OPENAI"
9
+ ANTHROPIC = "ANTHROPIC"
10
+ AZURE_OPENAI = "AZURE_OPENAI"
11
+
12
+
13
@strawberry.type
class GenerativeProvider:
    """GraphQL type describing a generative model provider."""

    name: str  # display name of the provider
    key: GenerativeProviderKey  # stable identifier of the provider
phoenix/server/app.py CHANGED
@@ -36,22 +36,23 @@ from fastapi.utils import is_body_allowed_for_status_code
36
36
  from sqlalchemy import select
37
37
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
38
38
  from starlette.datastructures import State as StarletteState
39
- from starlette.exceptions import HTTPException
39
+ from starlette.exceptions import HTTPException, WebSocketException
40
40
  from starlette.middleware import Middleware
41
41
  from starlette.middleware.authentication import AuthenticationMiddleware
42
42
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
43
43
  from starlette.requests import Request
44
- from starlette.responses import PlainTextResponse, Response
44
+ from starlette.responses import JSONResponse, PlainTextResponse, Response
45
45
  from starlette.staticfiles import StaticFiles
46
46
  from starlette.status import HTTP_401_UNAUTHORIZED
47
47
  from starlette.templating import Jinja2Templates
48
48
  from starlette.types import Scope, StatefulLifespan
49
+ from starlette.websockets import WebSocket
49
50
  from strawberry.extensions import SchemaExtension
50
51
  from strawberry.fastapi import GraphQLRouter
51
52
  from strawberry.schema import BaseSchema
53
+ from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
52
54
  from typing_extensions import TypeAlias
53
55
 
54
- import phoenix
55
56
  import phoenix.trace.v1 as pb
56
57
  from phoenix.config import (
57
58
  DEFAULT_PROJECT_NAME,
@@ -135,6 +136,7 @@ from phoenix.trace.fixtures import (
135
136
  from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
136
137
  from phoenix.trace.schemas import Span
137
138
  from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER
139
+ from phoenix.version import __version__ as phoenix_version
138
140
 
139
141
  if TYPE_CHECKING:
140
142
  from opentelemetry.trace import TracerProvider
@@ -217,7 +219,7 @@ class Static(StaticFiles):
217
219
  "n_neighbors": self._app_config.n_neighbors,
218
220
  "n_samples": self._app_config.n_samples,
219
221
  "basename": self._sanitize_basename(request.scope.get("root_path", "")),
220
- "platform_version": phoenix.__version__,
222
+ "platform_version": phoenix_version,
221
223
  "request": request,
222
224
  "is_development": self._app_config.is_development,
223
225
  "manifest": self._web_manifest,
@@ -255,7 +257,7 @@ class HeadersMiddleware(BaseHTTPMiddleware):
255
257
  request: Request,
256
258
  call_next: RequestResponseEndpoint,
257
259
  ) -> Response:
258
- from phoenix import __version__ as phoenix_version
260
+ from phoenix.version import __version__ as phoenix_version
259
261
 
260
262
  response = await call_next(request)
261
263
  response.headers["x-colab-notebook-cache-control"] = "no-cache"
@@ -268,7 +270,7 @@ ProjectRowId: TypeAlias = int
268
270
 
269
271
  @router.get("/arize_phoenix_version")
270
272
  async def version() -> PlainTextResponse:
271
- return PlainTextResponse(f"{phoenix.__version__}")
273
+ return PlainTextResponse(f"{phoenix_version}")
272
274
 
273
275
 
274
276
  DB_MUTEX: Optional[asyncio.Lock] = None
@@ -580,6 +582,7 @@ def create_graphql_router(
580
582
  include_in_schema=False,
581
583
  prefix="/graphql",
582
584
  dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
585
+ subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
583
586
  )
584
587
 
585
588
 
@@ -630,6 +633,29 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
630
633
  return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)
631
634
 
632
635
 
636
async def websocket_denial_response_handler(websocket: WebSocket, exc: WebSocketException) -> None:
    """
    Denies a WebSocket upgrade with an HTTP response whose status is ``exc.code``.

    This replaces the default exception handler for ``WebSocketException``. A denied
    connection therefore receives the same status code as the raised exception
    (e.g. 401 for an invalid token), with ``exc.reason`` as the JSON body. It
    relies on the WebSocket Denial Response extension of the ASGI specification:

        "Websocket connections start with the client sending a HTTP request
        containing the appropriate upgrade headers. On receipt of this request a
        server can choose to either upgrade the connection or respond with an HTTP
        response (denying the upgrade). The core ASGI specification does not allow
        for any control over the denial response, instead specifying that the HTTP
        status code 403 should be returned, whereas this extension allows an ASGI
        framework to control the denial response."

    For details, see:
    - https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response
    """
    # Narrows ``exc`` for type checkers; the handler registry is loosely typed.
    assert isinstance(exc, WebSocketException)
    denial_response = JSONResponse(status_code=exc.code, content=exc.reason)
    await websocket.send_denial_response(denial_response)
657
+
658
+
633
659
  def create_app(
634
660
  db: DbSessionFactory,
635
661
  export_path: Path,
@@ -776,7 +802,10 @@ def create_app(
776
802
  scaffolder_config=scaffolder_config,
777
803
  ),
778
804
  middleware=middlewares,
779
- exception_handlers={HTTPException: plain_text_http_exception_handler},
805
+ exception_handlers={
806
+ HTTPException: plain_text_http_exception_handler,
807
+ WebSocketException: websocket_denial_response_handler, # type: ignore[dict-item]
808
+ },
780
809
  debug=debug,
781
810
  swagger_ui_parameters={
782
811
  "defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
@@ -7,10 +7,11 @@ from typing import (
7
7
  Callable,
8
8
  Optional,
9
9
  Tuple,
10
+ cast,
10
11
  )
11
12
 
12
13
  import grpc
13
- from fastapi import HTTPException, Request
14
+ from fastapi import HTTPException, Request, WebSocket, WebSocketException
14
15
  from grpc_interceptor import AsyncServerInterceptor
15
16
  from grpc_interceptor.exceptions import Unauthenticated
16
17
  from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
@@ -116,12 +117,19 @@ class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
116
117
  raise Unauthenticated()
117
118
 
118
119
 
119
- async def is_authenticated(request: Request) -> None:
120
+ async def is_authenticated(
121
+ # fastapi dependencies require non-optional types
122
+ request: Request = cast(Request, None),
123
+ websocket: WebSocket = cast(WebSocket, None),
124
+ ) -> None:
120
125
  """
121
- Raises a 401 if the request is not authenticated.
126
+ Raises a 401 if the request or websocket connection is not authenticated.
122
127
  """
123
- if not isinstance((user := request.user), PhoenixUser):
128
+ assert request or websocket
129
+ if request and not isinstance((user := request.user), PhoenixUser):
124
130
  raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
131
+ if websocket and not isinstance((user := websocket.user), PhoenixUser):
132
+ raise WebSocketException(code=HTTP_401_UNAUTHORIZED, reason="Invalid token")
125
133
  claims = user.claims
126
134
  if claims.status is ClaimSetStatus.EXPIRED:
127
135
  raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
phoenix/server/main.py CHANGED
@@ -3,7 +3,6 @@ import codecs
3
3
  import os
4
4
  import sys
5
5
  from argparse import SUPPRESS, ArgumentParser
6
- from importlib.metadata import version
7
6
  from pathlib import Path
8
7
  from threading import Thread
9
8
  from time import sleep, time
@@ -72,6 +71,7 @@ from phoenix.trace.fixtures import (
72
71
  )
73
72
  from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
74
73
  from phoenix.trace.schemas import Span
74
+ from phoenix.version import __version__ as phoenix_version
75
75
 
76
76
  _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
77
77
 
@@ -351,7 +351,7 @@ def main() -> None:
351
351
  # Print information about the server
352
352
  root_path = urljoin(f"http://{host}:{port}", host_root_path)
353
353
  msg = _WELCOME_MESSAGE.render(
354
- version=version("arize-phoenix"),
354
+ version=phoenix_version,
355
355
  ui_path=root_path,
356
356
  grpc_path=f"http://{host}:{get_env_grpc_port()}",
357
357
  http_path=urljoin(root_path, "v1/traces"),