arize-phoenix 5.7.0__py3-none-any.whl → 5.9.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as they appear in their respective public registries and is provided for informational purposes only.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (32)
  1. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/METADATA +3 -5
  2. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/RECORD +31 -31
  3. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +19 -3
  5. phoenix/db/helpers.py +55 -1
  6. phoenix/server/api/helpers/playground_clients.py +283 -44
  7. phoenix/server/api/helpers/playground_spans.py +173 -76
  8. phoenix/server/api/input_types/InvocationParameters.py +7 -8
  9. phoenix/server/api/mutations/chat_mutations.py +244 -76
  10. phoenix/server/api/queries.py +5 -1
  11. phoenix/server/api/routers/v1/spans.py +25 -1
  12. phoenix/server/api/subscriptions.py +210 -158
  13. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
  14. phoenix/server/api/types/ExperimentRun.py +38 -1
  15. phoenix/server/api/types/GenerativeProvider.py +2 -1
  16. phoenix/server/app.py +21 -2
  17. phoenix/server/grpc_server.py +3 -1
  18. phoenix/server/static/.vite/manifest.json +32 -32
  19. phoenix/server/static/assets/{components-Csu8UKOs.js → components-DU-8CYbi.js} +370 -329
  20. phoenix/server/static/assets/{index-Bk5C9EA7.js → index-D9E16vvV.js} +2 -2
  21. phoenix/server/static/assets/pages-t09OI1rC.js +3966 -0
  22. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-D04tenE6.js} +181 -181
  23. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-D3NxMQw0.js} +2 -2
  24. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-XTiZSlqq.js} +5 -5
  25. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-p0L0neVs.js} +1 -1
  26. phoenix/session/client.py +27 -7
  27. phoenix/utilities/json.py +31 -1
  28. phoenix/version.py +1 -1
  29. phoenix/server/static/assets/pages-UeWaKXNs.js +0 -3737
  30. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/entry_points.txt +0 -0
  31. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  32. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/helpers/playground_spans.py
@@ -26,7 +26,6 @@ from openinference.semconv.trace import (
 )
 from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.scalars import JSON as JSONScalarType
 from typing_extensions import Self, TypeAlias, assert_never
 
@@ -41,7 +40,7 @@ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
     TextChunk,
     ToolCallChunk,
 )
-from phoenix.trace.attributes import unflatten
+from phoenix.trace.attributes import get_attribute_value, unflatten
 from phoenix.trace.schemas import (
     SpanEvent,
     SpanException,
@@ -56,7 +55,8 @@ ToolCallID: TypeAlias = str
 
 class streaming_llm_span:
     """
-    Creates an LLM span for a streaming chat completion.
+    A context manager that records OpenInference attributes for streaming chat
+    completion LLM spans.
     """
 
     def __init__(
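As the revised docstring says, the class is used as an async context manager. A hedged usage sketch follows; the constructor argument names are inferred from the __init__ body in the next hunk, and the chunk stream stands in for a playground client call that this diff does not show:

# Hypothetical usage; argument names are inferred from this diff rather
# than taken from a fully shown signature.
async with streaming_llm_span(
    input=input,  # a ChatCompletionInput
    messages=messages,
    invocation_parameters=invocation_parameters,
) as span:
    async for chunk in stream:  # TextChunk / ToolCallChunk items (assumed source)
        span.add_response_chunk(chunk)
# On exit, the span finalizes its timestamps, status, and output attributes.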
@@ -71,24 +71,23 @@ class streaming_llm_span:
         self._attributes: dict[str, Any] = attributes if attributes is not None else {}
         self._attributes.update(
             chain(
-                _llm_span_kind(),
-                _llm_model_name(input.model.name),
-                _llm_tools(input.tools or []),
-                _llm_input_messages(messages),
-                _llm_invocation_parameters(invocation_parameters),
-                _input_value_and_mime_type(input),
+                llm_span_kind(),
+                llm_model_name(input.model.name),
+                llm_tools(input.tools or []),
+                llm_input_messages(messages),
+                llm_invocation_parameters(invocation_parameters),
+                input_value_and_mime_type(input),
             )
         )
         self._events: list[SpanEvent] = []
-        self._start_time: datetime
-        self._end_time: datetime
-        self._response_chunks: list[Union[TextChunk, ToolCallChunk]] = []
+        self._start_time: Optional[datetime] = None
+        self._end_time: Optional[datetime] = None
         self._text_chunks: list[TextChunk] = []
         self._tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
-        self._status_code: StatusCode
-        self._status_message: str
-        self._db_span: models.Span
-        self._db_trace: models.Trace
+        self._status_code: StatusCode = StatusCode.UNSET
+        self._status_message: Optional[str] = None
+        self._trace_id = _generate_trace_id()
+        self._span_id = _generate_span_id()
 
     async def __aenter__(self) -> Self:
         self._start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
@@ -102,7 +101,6 @@ class streaming_llm_span:
     ) -> bool:
         self._end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
         self._status_code = StatusCode.OK
-        self._status_message = ""
         if exc_type is not None:
             self._status_code = StatusCode.ERROR
             self._status_message = str(exc_value)
@@ -115,10 +113,10 @@ class streaming_llm_span:
                     exception_stacktrace=format_exc(),
                 )
             )
-        if self._response_chunks:
+        if self._text_chunks or self._tool_call_chunks:
             self._attributes.update(
                 chain(
-                    _output_value_and_mime_type(self._response_chunks),
+                    _output_value_and_mime_type(self._text_chunks, self._tool_call_chunks),
                     _llm_output_messages(self._text_chunks, self._tool_call_chunks),
                 )
             )
@@ -127,46 +125,7 @@ class streaming_llm_span:
     def set_attributes(self, attributes: Mapping[str, Any]) -> None:
         self._attributes.update(attributes)
 
-    def add_to_session(
-        self,
-        session: AsyncSession,
-        project_id: int,
-    ) -> models.Span:
-        prompt_tokens = self._attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-        completion_tokens = self._attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-        trace_id = _generate_trace_id()
-        span_id = _generate_span_id()
-        self._db_trace = models.Trace(
-            project_rowid=project_id,
-            trace_id=trace_id,
-            start_time=self._start_time,
-            end_time=self._end_time,
-        )
-        self._db_span = models.Span(
-            trace_rowid=self._db_trace.id,
-            span_id=span_id,
-            parent_id=None,
-            name="ChatCompletion",
-            span_kind=LLM,
-            start_time=self._start_time,
-            end_time=self._end_time,
-            attributes=unflatten(self._attributes.items()),
-            events=[_serialize_event(event) for event in self._events],
-            status_code=self._status_code.name,
-            status_message=self._status_message,
-            cumulative_error_count=int(self._status_code is StatusCode.ERROR),
-            cumulative_llm_token_count_prompt=prompt_tokens,
-            cumulative_llm_token_count_completion=completion_tokens,
-            llm_token_count_prompt=prompt_tokens,
-            llm_token_count_completion=completion_tokens,
-            trace=self._db_trace,
-        )
-        session.add(self._db_trace)
-        session.add(self._db_span)
-        return self._db_span
-
     def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
-        self._response_chunks.append(chunk)
         if isinstance(chunk, TextChunk):
             self._text_chunks.append(chunk)
         elif isinstance(chunk, ToolCallChunk):
@@ -174,48 +133,128 @@ class streaming_llm_span:
         else:
             assert_never(chunk)
 
+    @property
+    def span_id(self) -> str:
+        return self._span_id
+
+    @property
+    def trace_id(self) -> str:
+        return self._trace_id
+
     @property
     def start_time(self) -> datetime:
-        return self._db_span.start_time
+        if self._start_time is None:
+            raise ValueError("Cannot access start time before the context manager is entered")
+        return self._start_time
 
     @property
     def end_time(self) -> datetime:
-        return self._db_span.end_time
+        if self._end_time is None:
+            raise ValueError("Cannot access end time before the context manager is exited")
+        return self._end_time
 
     @property
-    def error_message(self) -> Optional[str]:
-        return self._status_message if self._status_code is StatusCode.ERROR else None
+    def status_code(self) -> StatusCode:
+        return self._status_code
 
     @property
-    def trace_id(self) -> str:
-        return self._db_trace.trace_id
+    def status_message(self) -> Optional[str]:
+        if self._status_code is StatusCode.UNSET:
+            raise ValueError("Cannot access status message before the context manager is exited")
+        return self._status_message
 
     @property
-    def attributes(self) -> dict[str, Any]:
-        return self._db_span.attributes
-
+    def events(self) -> list[SpanEvent]:
+        return self._events
 
-def _llm_span_kind() -> Iterator[tuple[str, Any]]:
+    @property
+    def attributes(self) -> dict[str, Any]:
+        return unflatten(self._attributes.items())
+
+
+def get_db_trace(span: streaming_llm_span, project_id: int) -> models.Trace:
+    return models.Trace(
+        project_rowid=project_id,
+        trace_id=span.trace_id,
+        start_time=span.start_time,
+        end_time=span.end_time,
+    )
+
+
+def get_db_span(
+    span: streaming_llm_span,
+    db_trace: models.Trace,
+) -> models.Span:
+    prompt_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT) or 0
+    completion_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_COMPLETION) or 0
+    return models.Span(
+        trace_rowid=db_trace.id,
+        span_id=span.span_id,
+        parent_id=None,
+        name="ChatCompletion",
+        span_kind=LLM,
+        start_time=span.start_time,
+        end_time=span.end_time,
+        attributes=span.attributes,
+        events=[_serialize_event(event) for event in span.events],
+        status_code=span.status_code.name,
+        status_message=span.status_message or "",
+        cumulative_error_count=int(span.status_code is StatusCode.ERROR),
+        cumulative_llm_token_count_prompt=prompt_tokens,
+        cumulative_llm_token_count_completion=completion_tokens,
+        llm_token_count_prompt=prompt_tokens,
+        llm_token_count_completion=completion_tokens,
+        trace=db_trace,
+    )
+
+
+def get_db_experiment_run(
+    db_span: models.Span,
+    db_trace: models.Trace,
+    *,
+    experiment_id: int,
+    example_id: int,
+) -> models.ExperimentRun:
+    return models.ExperimentRun(
+        experiment_id=experiment_id,
+        dataset_example_id=example_id,
+        trace_id=db_trace.trace_id,
+        output=models.ExperimentRunOutput(
+            task_output=get_attribute_value(db_span.attributes, LLM_OUTPUT_MESSAGES),
+        ),
+        repetition_number=1,
+        start_time=db_span.start_time,
+        end_time=db_span.end_time,
+        error=db_span.status_message or None,
+        prompt_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_PROMPT),
+        completion_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_COMPLETION),
+        trace=db_trace,
+    )
+
+
+def llm_span_kind() -> Iterator[tuple[str, Any]]:
     yield OPENINFERENCE_SPAN_KIND, LLM
 
 
-def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
+def llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
     yield LLM_MODEL_NAME, model_name
 
 
-def _llm_invocation_parameters(
+def llm_invocation_parameters(
     invocation_parameters: Mapping[str, Any],
 ) -> Iterator[tuple[str, Any]]:
     if invocation_parameters:
         yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
 
 
-def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
+def llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
     for tool_index, tool in enumerate(tools):
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
 
 
-def _input_value_and_mime_type(input: Any) -> Iterator[tuple[str, Any]]:
+def input_value_and_mime_type(
+    input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
+) -> Iterator[tuple[str, Any]]:
     assert (api_key := "api_key") in (input_data := jsonify(input))
     disallowed_keys = {"api_key", "invocation_parameters"}
     input_data = {k: v for k, v in input_data.items() if k not in disallowed_keys}
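With add_to_session removed from the class, persisting a playground span is now the caller's job, composed from the module-level helpers above. A minimal sketch of that composition, assuming the caller already holds a SQLAlchemy AsyncSession (the session handling itself no longer lives in this file, and the helper name here is hypothetical):

from sqlalchemy.ext.asyncio import AsyncSession

async def persist_playground_span(  # hypothetical helper, not in this diff
    session: AsyncSession,
    span: streaming_llm_span,
    project_id: int,
) -> models.Span:
    # Build the ORM rows from the finished span's recorded state.
    db_trace = get_db_trace(span, project_id)
    db_span = get_db_span(span, db_trace)
    # Queue both rows for insertion, as the removed add_to_session did.
    session.add(db_trace)
    session.add(db_span)
    return db_span

For playground runs over a dataset, get_db_experiment_run can additionally derive an ExperimentRun row from the persisted span and trace.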
@@ -224,12 +263,69 @@ def _input_value_and_mime_type(input: Any) -> Iterator[tuple[str, Any]]:
     yield INPUT_VALUE, safe_json_dumps(input_data)
 
 
-def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
-    yield OUTPUT_MIME_TYPE, JSON
-    yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
+def _merge_tool_call_chunks(
+    chunks_by_id: defaultdict[str, list[ToolCallChunk]],
+) -> list[dict[str, Any]]:
+    merged_tool_calls = []
+
+    for tool_id, chunks in chunks_by_id.items():
+        if not chunks:
+            continue
+        first_chunk = chunks[0]
+        if not first_chunk:
+            continue
+
+        if not hasattr(first_chunk, "function") or not hasattr(first_chunk.function, "name"):
+            continue
+        # Combine all argument chunks
+        merged_arguments = "".join(
+            chunk.function.arguments
+            for chunk in chunks
+            if chunk and hasattr(chunk, "function") and hasattr(chunk.function, "arguments")
+        )
+
+        merged_tool_calls.append(
+            {
+                "id": tool_id,
+                # Only the first chunk has the tool name
+                "function": {
+                    "name": first_chunk.function.name,
+                    "arguments": merged_arguments or "{}",
+                },
+            }
+        )
+
+    return merged_tool_calls
+
+
+def _output_value_and_mime_type(
+    text_chunks: list[TextChunk],
+    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+) -> Iterator[tuple[str, Any]]:
+    content = "".join(chunk.content for chunk in text_chunks)
+    merged_tool_calls = _merge_tool_call_chunks(tool_call_chunks)
+    if content and merged_tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield (
+            OUTPUT_VALUE,
+            safe_json_dumps(
+                {
+                    "content": content,
+                    "tool_calls": jsonify(
+                        merged_tool_calls,
+                    ),
+                }
+            ),
+        )
+    elif merged_tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield OUTPUT_VALUE, safe_json_dumps(jsonify(merged_tool_calls))
+    elif content:
+        yield OUTPUT_MIME_TYPE, TEXT
+        yield OUTPUT_VALUE, content
 
 
-def _llm_input_messages(
+def llm_input_messages(
     messages: Iterable[
         tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
     ],
@@ -299,6 +395,7 @@ def _serialize_event(event: SpanEvent) -> dict[str, Any]:
 
 
 JSON = OpenInferenceMimeTypeValues.JSON.value
+TEXT = OpenInferenceMimeTypeValues.TEXT.value
 
 LLM = OpenInferenceSpanKindValues.LLM.value
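The TEXT constant added above pairs with the rewritten _output_value_and_mime_type: text-only output is recorded with the TEXT mime type, while anything involving tool calls is serialized as JSON. A standalone sketch of the chunk-merging semantics, with stand-in dataclasses in place of phoenix's real ToolCallChunk (they mimic only the function.name / function.arguments shape the helper reads):

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class Function:  # stand-in for the real chunk's function payload
    name: str
    arguments: str

@dataclass
class Chunk:  # stand-in for phoenix's ToolCallChunk
    function: Function

chunks_by_id: defaultdict[str, list[Chunk]] = defaultdict(list)
# Streaming providers typically send the tool name once, then argument fragments.
chunks_by_id["call_1"] = [
    Chunk(Function("get_weather", '{"city": ')),
    Chunk(Function("get_weather", '"Paris"}')),
]
# _merge_tool_call_chunks(chunks_by_id) would return:
# [{"id": "call_1",
#   "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}]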
 
phoenix/server/api/input_types/InvocationParameters.py
@@ -47,25 +47,24 @@ class InvocationParameterBase:
     canonical_name: Optional[CanonicalParameterName] = None
     label: str
     required: bool = False
-    hidden: bool = False
 
 
 @strawberry.type
 class IntInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_int
-    default_value: Optional[int] = UNSET
+    default_value: Optional[int] = None
 
 
 @strawberry.type
 class FloatInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_float
-    default_value: Optional[float] = UNSET
+    default_value: Optional[float] = None
 
 
 @strawberry.type
 class BoundedFloatInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_float
-    default_value: Optional[float] = UNSET
+    default_value: Optional[float] = None
     min_value: float
     max_value: float
 
@@ -73,25 +72,25 @@ class BoundedFloatInvocationParameter(InvocationParameterBase):
 @strawberry.type
 class StringInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_string
-    default_value: Optional[str] = UNSET
+    default_value: Optional[str] = None
 
 
 @strawberry.type
 class JSONInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_json
-    default_value: Optional[JSON] = UNSET
+    default_value: Optional[JSON] = None
 
 
 @strawberry.type
 class StringListInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_string_list
-    default_value: Optional[list[str]] = UNSET
+    default_value: Optional[list[str]] = None
 
 
 @strawberry.type
 class BooleanInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_bool
-    default_value: Optional[bool] = UNSET
+    default_value: Optional[bool] = None
 
 
 def extract_parameter(
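The UNSET-to-None switch on default_value is more than cosmetic: strawberry's UNSET is a sentinel for telling an omitted input argument apart from an explicit null, and it is not a value an output field can serialize. A minimal sketch of the distinction, using hypothetical types not taken from this diff:

from typing import Optional

import strawberry

@strawberry.type
class OutputParam:
    # Output field: an absent default now serializes cleanly as JSON null.
    default_value: Optional[int] = None

@strawberry.input
class InputParam:
    # Input field: UNSET lets a resolver distinguish "argument omitted"
    # from "argument explicitly set to null".
    default_value: Optional[int] = strawberry.UNSET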