arize-phoenix 5.7.0__py3-none-any.whl → 5.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (24)
  1. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/METADATA +3 -5
  2. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/RECORD +24 -24
  3. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +19 -3
  5. phoenix/server/api/helpers/playground_clients.py +123 -36
  6. phoenix/server/api/helpers/playground_spans.py +173 -76
  7. phoenix/server/api/input_types/InvocationParameters.py +7 -8
  8. phoenix/server/api/mutations/chat_mutations.py +46 -65
  9. phoenix/server/api/subscriptions.py +210 -158
  10. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
  11. phoenix/server/app.py +14 -0
  12. phoenix/server/grpc_server.py +3 -1
  13. phoenix/server/static/.vite/manifest.json +31 -31
  14. phoenix/server/static/assets/{components-Csu8UKOs.js → components-MllbfxfJ.js} +168 -150
  15. phoenix/server/static/assets/{index-Bk5C9EA7.js → index-BVO2YcT1.js} +2 -2
  16. phoenix/server/static/assets/{pages-UeWaKXNs.js → pages-BHfC6jnL.js} +394 -300
  17. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-BEuNhfwH.js} +1 -1
  18. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-Bskhzyjm.js} +1 -1
  19. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-DLlXCf0x.js} +1 -1
  20. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-CRqhvLYg.js} +1 -1
  21. phoenix/version.py +1 -1
  22. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/entry_points.txt +0 -0
  23. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/IP_NOTICE +0 -0
  24. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/helpers/playground_spans.py

@@ -26,7 +26,6 @@ from openinference.semconv.trace import (
 )
 from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.scalars import JSON as JSONScalarType
 from typing_extensions import Self, TypeAlias, assert_never
 
@@ -41,7 +40,7 @@ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
     TextChunk,
     ToolCallChunk,
 )
-from phoenix.trace.attributes import unflatten
+from phoenix.trace.attributes import get_attribute_value, unflatten
 from phoenix.trace.schemas import (
     SpanEvent,
     SpanException,
@@ -56,7 +55,8 @@ ToolCallID: TypeAlias = str
 
 class streaming_llm_span:
     """
-    Creates an LLM span for a streaming chat completion.
+    A context manager that records OpenInference attributes for streaming chat
+    completion LLM spans.
     """
 
     def __init__(
@@ -71,24 +71,23 @@ class streaming_llm_span:
         self._attributes: dict[str, Any] = attributes if attributes is not None else {}
         self._attributes.update(
             chain(
-                _llm_span_kind(),
-                _llm_model_name(input.model.name),
-                _llm_tools(input.tools or []),
-                _llm_input_messages(messages),
-                _llm_invocation_parameters(invocation_parameters),
-                _input_value_and_mime_type(input),
+                llm_span_kind(),
+                llm_model_name(input.model.name),
+                llm_tools(input.tools or []),
+                llm_input_messages(messages),
+                llm_invocation_parameters(invocation_parameters),
+                input_value_and_mime_type(input),
             )
         )
         self._events: list[SpanEvent] = []
-        self._start_time: datetime
-        self._end_time: datetime
-        self._response_chunks: list[Union[TextChunk, ToolCallChunk]] = []
+        self._start_time: Optional[datetime] = None
+        self._end_time: Optional[datetime] = None
         self._text_chunks: list[TextChunk] = []
         self._tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
-        self._status_code: StatusCode
-        self._status_message: str
-        self._db_span: models.Span
-        self._db_trace: models.Trace
+        self._status_code: StatusCode = StatusCode.UNSET
+        self._status_message: Optional[str] = None
+        self._trace_id = _generate_trace_id()
+        self._span_id = _generate_span_id()
 
     async def __aenter__(self) -> Self:
         self._start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
@@ -102,7 +101,6 @@ class streaming_llm_span:
     ) -> bool:
         self._end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
         self._status_code = StatusCode.OK
-        self._status_message = ""
         if exc_type is not None:
             self._status_code = StatusCode.ERROR
             self._status_message = str(exc_value)
@@ -115,10 +113,10 @@ class streaming_llm_span:
                     exception_stacktrace=format_exc(),
                 )
             )
-        if self._response_chunks:
+        if self._text_chunks or self._tool_call_chunks:
             self._attributes.update(
                 chain(
-                    _output_value_and_mime_type(self._response_chunks),
+                    _output_value_and_mime_type(self._text_chunks, self._tool_call_chunks),
                     _llm_output_messages(self._text_chunks, self._tool_call_chunks),
                 )
             )
@@ -127,46 +125,7 @@ class streaming_llm_span:
     def set_attributes(self, attributes: Mapping[str, Any]) -> None:
         self._attributes.update(attributes)
 
-    def add_to_session(
-        self,
-        session: AsyncSession,
-        project_id: int,
-    ) -> models.Span:
-        prompt_tokens = self._attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-        completion_tokens = self._attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-        trace_id = _generate_trace_id()
-        span_id = _generate_span_id()
-        self._db_trace = models.Trace(
-            project_rowid=project_id,
-            trace_id=trace_id,
-            start_time=self._start_time,
-            end_time=self._end_time,
-        )
-        self._db_span = models.Span(
-            trace_rowid=self._db_trace.id,
-            span_id=span_id,
-            parent_id=None,
-            name="ChatCompletion",
-            span_kind=LLM,
-            start_time=self._start_time,
-            end_time=self._end_time,
-            attributes=unflatten(self._attributes.items()),
-            events=[_serialize_event(event) for event in self._events],
-            status_code=self._status_code.name,
-            status_message=self._status_message,
-            cumulative_error_count=int(self._status_code is StatusCode.ERROR),
-            cumulative_llm_token_count_prompt=prompt_tokens,
-            cumulative_llm_token_count_completion=completion_tokens,
-            llm_token_count_prompt=prompt_tokens,
-            llm_token_count_completion=completion_tokens,
-            trace=self._db_trace,
-        )
-        session.add(self._db_trace)
-        session.add(self._db_span)
-        return self._db_span
-
     def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
-        self._response_chunks.append(chunk)
         if isinstance(chunk, TextChunk):
             self._text_chunks.append(chunk)
         elif isinstance(chunk, ToolCallChunk):
@@ -174,48 +133,128 @@
         else:
             assert_never(chunk)
 
+    @property
+    def span_id(self) -> str:
+        return self._span_id
+
+    @property
+    def trace_id(self) -> str:
+        return self._trace_id
+
     @property
     def start_time(self) -> datetime:
-        return self._db_span.start_time
+        if self._start_time is None:
+            raise ValueError("Cannot access start time before the context manager is entered")
+        return self._start_time
 
     @property
     def end_time(self) -> datetime:
-        return self._db_span.end_time
+        if self._end_time is None:
+            raise ValueError("Cannot access end time before the context manager is exited")
+        return self._end_time
 
     @property
-    def error_message(self) -> Optional[str]:
-        return self._status_message if self._status_code is StatusCode.ERROR else None
+    def status_code(self) -> StatusCode:
+        return self._status_code
 
     @property
-    def trace_id(self) -> str:
-        return self._db_trace.trace_id
+    def status_message(self) -> Optional[str]:
+        if self._status_code is StatusCode.UNSET:
+            raise ValueError("Cannot access status message before the context manager is exited")
+        return self._status_message
 
     @property
-    def attributes(self) -> dict[str, Any]:
-        return self._db_span.attributes
-
+    def events(self) -> list[SpanEvent]:
+        return self._events
 
-def _llm_span_kind() -> Iterator[tuple[str, Any]]:
+    @property
+    def attributes(self) -> dict[str, Any]:
+        return unflatten(self._attributes.items())
+
+
+def get_db_trace(span: streaming_llm_span, project_id: int) -> models.Trace:
+    return models.Trace(
+        project_rowid=project_id,
+        trace_id=span.trace_id,
+        start_time=span.start_time,
+        end_time=span.end_time,
+    )
+
+
+def get_db_span(
+    span: streaming_llm_span,
+    db_trace: models.Trace,
+) -> models.Span:
+    prompt_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT) or 0
+    completion_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_COMPLETION) or 0
+    return models.Span(
+        trace_rowid=db_trace.id,
+        span_id=span.span_id,
+        parent_id=None,
+        name="ChatCompletion",
+        span_kind=LLM,
+        start_time=span.start_time,
+        end_time=span.end_time,
+        attributes=span.attributes,
+        events=[_serialize_event(event) for event in span.events],
+        status_code=span.status_code.name,
+        status_message=span.status_message or "",
+        cumulative_error_count=int(span.status_code is StatusCode.ERROR),
+        cumulative_llm_token_count_prompt=prompt_tokens,
+        cumulative_llm_token_count_completion=completion_tokens,
+        llm_token_count_prompt=prompt_tokens,
+        llm_token_count_completion=completion_tokens,
+        trace=db_trace,
+    )
+
+
+def get_db_experiment_run(
+    db_span: models.Span,
+    db_trace: models.Trace,
+    *,
+    experiment_id: int,
+    example_id: int,
+) -> models.ExperimentRun:
+    return models.ExperimentRun(
+        experiment_id=experiment_id,
+        dataset_example_id=example_id,
+        trace_id=db_trace.trace_id,
+        output=models.ExperimentRunOutput(
+            task_output=get_attribute_value(db_span.attributes, LLM_OUTPUT_MESSAGES),
+        ),
+        repetition_number=1,
+        start_time=db_span.start_time,
+        end_time=db_span.end_time,
+        error=db_span.status_message or None,
+        prompt_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_PROMPT),
+        completion_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_COMPLETION),
+        trace=db_trace,
+    )
+
+
+def llm_span_kind() -> Iterator[tuple[str, Any]]:
     yield OPENINFERENCE_SPAN_KIND, LLM
 
 
-def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
+def llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
     yield LLM_MODEL_NAME, model_name
 
 
-def _llm_invocation_parameters(
+def llm_invocation_parameters(
     invocation_parameters: Mapping[str, Any],
 ) -> Iterator[tuple[str, Any]]:
     if invocation_parameters:
         yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
 
 
-def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
+def llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
     for tool_index, tool in enumerate(tools):
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
 
 
-def _input_value_and_mime_type(input: Any) -> Iterator[tuple[str, Any]]:
+def input_value_and_mime_type(
+    input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
+) -> Iterator[tuple[str, Any]]:
     assert (api_key := "api_key") in (input_data := jsonify(input))
     disallowed_keys = {"api_key", "invocation_parameters"}
     input_data = {k: v for k, v in input_data.items() if k not in disallowed_keys}
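
The span above no longer writes to the database itself; the caller assembles ORM rows from its read-only properties via the get_db_* helpers. A minimal sketch of the intended composition (hedged, not an actual Phoenix call site: `session`, `project_id`, `input`, `messages`, `invocation_parameters`, and the `llm_client` streaming call are placeholder names, and the constructor keywords are assumed to mirror the attribute names above):

    # Sketch only: stream chunks into the span, then persist separately.
    async with streaming_llm_span(
        input=input,
        messages=messages,
        invocation_parameters=invocation_parameters,
    ) as span:
        async for chunk in llm_client.chat_completion_create(input=input):
            span.add_response_chunk(chunk)
    db_trace = get_db_trace(span, project_id)  # trace row from the span's id and timing
    db_span = get_db_span(span, db_trace)      # span row from attributes, events, status
    session.add_all([db_trace, db_span])       # persistence stays with the caller
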
@@ -224,12 +263,69 @@ def _input_value_and_mime_type(input: Any) -> Iterator[tuple[str, Any]]:
     yield INPUT_VALUE, safe_json_dumps(input_data)
 
 
-def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
-    yield OUTPUT_MIME_TYPE, JSON
-    yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
+def _merge_tool_call_chunks(
+    chunks_by_id: defaultdict[str, list[ToolCallChunk]],
+) -> list[dict[str, Any]]:
+    merged_tool_calls = []
+
+    for tool_id, chunks in chunks_by_id.items():
+        if not chunks:
+            continue
+        first_chunk = chunks[0]
+        if not first_chunk:
+            continue
+
+        if not hasattr(first_chunk, "function") or not hasattr(first_chunk.function, "name"):
+            continue
+        # Combine all argument chunks
+        merged_arguments = "".join(
+            chunk.function.arguments
+            for chunk in chunks
+            if chunk and hasattr(chunk, "function") and hasattr(chunk.function, "arguments")
+        )
+
+        merged_tool_calls.append(
+            {
+                "id": tool_id,
+                # Only the first chunk has the tool name
+                "function": {
+                    "name": first_chunk.function.name,
+                    "arguments": merged_arguments or "{}",
+                },
+            }
+        )
+
+    return merged_tool_calls
+
+
+def _output_value_and_mime_type(
+    text_chunks: list[TextChunk],
+    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+) -> Iterator[tuple[str, Any]]:
+    content = "".join(chunk.content for chunk in text_chunks)
+    merged_tool_calls = _merge_tool_call_chunks(tool_call_chunks)
+    if content and merged_tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield (
+            OUTPUT_VALUE,
+            safe_json_dumps(
+                {
+                    "content": content,
+                    "tool_calls": jsonify(
+                        merged_tool_calls,
+                    ),
+                }
+            ),
+        )
+    elif merged_tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield OUTPUT_VALUE, safe_json_dumps(jsonify(merged_tool_calls))
+    elif content:
+        yield OUTPUT_MIME_TYPE, TEXT
+        yield OUTPUT_VALUE, content
 
 
-def _llm_input_messages(
+def llm_input_messages(
     messages: Iterable[
         tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
     ],
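
For concreteness, a worked example of the merge semantics above (a hedged, self-contained sketch; the values are made up and SimpleNamespace stands in for the real ToolCallChunk type):

    # Fragments for one tool-call id are joined; only the first fragment
    # carries the function name, later ones only extend the arguments string.
    from collections import defaultdict
    from types import SimpleNamespace

    def _chunk(name, arguments):
        return SimpleNamespace(function=SimpleNamespace(name=name, arguments=arguments))

    chunks_by_id = defaultdict(list)
    chunks_by_id["call_1"] = [_chunk("get_weather", '{"city": '), _chunk(None, '"Paris"}')]
    merged = _merge_tool_call_chunks(chunks_by_id)
    # [{"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}]
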
@@ -299,6 +395,7 @@ def _serialize_event(event: SpanEvent) -> dict[str, Any]:
 
 
 JSON = OpenInferenceMimeTypeValues.JSON.value
+TEXT = OpenInferenceMimeTypeValues.TEXT.value
 
 LLM = OpenInferenceSpanKindValues.LLM.value
 
phoenix/server/api/input_types/InvocationParameters.py

@@ -47,25 +47,24 @@ class InvocationParameterBase:
     canonical_name: Optional[CanonicalParameterName] = None
     label: str
     required: bool = False
-    hidden: bool = False
 
 
 @strawberry.type
 class IntInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_int
-    default_value: Optional[int] = UNSET
+    default_value: Optional[int] = None
 
 
 @strawberry.type
 class FloatInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_float
-    default_value: Optional[float] = UNSET
+    default_value: Optional[float] = None
 
 
 @strawberry.type
 class BoundedFloatInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_float
-    default_value: Optional[float] = UNSET
+    default_value: Optional[float] = None
     min_value: float
     max_value: float
 
@@ -73,25 +72,25 @@ class BoundedFloatInvocationParameter(InvocationParameterBase):
 @strawberry.type
 class StringInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_string
-    default_value: Optional[str] = UNSET
+    default_value: Optional[str] = None
 
 
 @strawberry.type
 class JSONInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_json
-    default_value: Optional[JSON] = UNSET
+    default_value: Optional[JSON] = None
 
 
 @strawberry.type
 class StringListInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_string_list
-    default_value: Optional[list[str]] = UNSET
+    default_value: Optional[list[str]] = None
 
 
 @strawberry.type
 class BooleanInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_bool
-    default_value: Optional[bool] = UNSET
+    default_value: Optional[bool] = None
 
 
 def extract_parameter(
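
For context on the UNSET-to-None switch above, a small sketch of the distinction (simplified, not Phoenix code; assumes Strawberry's documented UNSET semantics):

    from typing import Optional

    import strawberry

    # UNSET is a falsy sentinel distinct from None, meant to signal "value not
    # provided" on inputs; a None default on an output field serializes as null.
    assert strawberry.UNSET is not None and not strawberry.UNSET

    @strawberry.type
    class ExampleParameter:
        default_value: Optional[int] = None  # rendered as null when absent
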
phoenix/server/api/mutations/chat_mutations.py

@@ -6,6 +6,7 @@ from traceback import format_exc
 from typing import Any, Iterable, Iterator, List, Optional
 
 import strawberry
+from openinference.instrumentation import safe_json_dumps
 from openinference.semconv.trace import (
     MessageAttributes,
     OpenInferenceMimeTypeValues,
@@ -26,6 +27,14 @@ from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
 from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
+from phoenix.server.api.helpers.playground_spans import (
+    input_value_and_mime_type,
+    llm_input_messages,
+    llm_invocation_parameters,
+    llm_model_name,
+    llm_span_kind,
+    llm_tools,
+)
 from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
 from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -38,6 +47,7 @@ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
 from phoenix.server.dml_event import SpanInsertEvent
 from phoenix.trace.attributes import unflatten
 from phoenix.trace.schemas import SpanException
+from phoenix.utilities.json import jsonify
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
     MustacheTemplateFormatter,
@@ -94,7 +104,6 @@ class ChatCompletionMutationMixin:
             )
             for message in input.messages
         ]
-
        if template_options := input.template:
            messages = list(_formatted_messages(messages, template_options))
 
@@ -103,16 +112,16 @@ class ChatCompletionMutationMixin:
         )
 
         text_content = ""
-        tool_calls = []
+        tool_calls: dict[str, ChatCompletionToolCall] = {}
         events = []
         attributes.update(
             chain(
-                _llm_span_kind(),
-                _llm_model_name(input.model.name),
-                _llm_tools(input.tools or []),
-                _llm_input_messages(messages),
-                _llm_invocation_parameters(invocation_parameters),
-                _input_value_and_mime_type(input),
+                llm_span_kind(),
+                llm_model_name(input.model.name),
+                llm_tools(input.tools or []),
+                llm_input_messages(messages),
+                llm_invocation_parameters(invocation_parameters),
+                input_value_and_mime_type(input),
                 **llm_client.attributes,
             )
         )
@@ -128,14 +137,16 @@ class ChatCompletionMutationMixin:
                 if isinstance(chunk, TextChunk):
                     text_content += chunk.content
                 elif isinstance(chunk, ToolCallChunk):
-                    tool_call = ChatCompletionToolCall(
-                        id=chunk.id,
-                        function=ChatCompletionFunctionCall(
-                            name=chunk.function.name,
-                            arguments=chunk.function.arguments,
-                        ),
-                    )
-                    tool_calls.append(tool_call)
+                    if chunk.id not in tool_calls:
+                        tool_calls[chunk.id] = ChatCompletionToolCall(
+                            id=chunk.id,
+                            function=ChatCompletionFunctionCall(
+                                name=chunk.function.name,
+                                arguments=chunk.function.arguments,
+                            ),
+                        )
+                    else:
+                        tool_calls[chunk.id].function.arguments += chunk.function.arguments
                 else:
                     assert_never(chunk)
             except Exception as e:
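
Reduced to a self-contained sketch (dataclasses stand in for the GraphQL types), the accumulation above behaves like this:

    from dataclasses import dataclass

    @dataclass
    class _Function:
        name: str
        arguments: str

    calls: dict[str, _Function] = {}
    for call_id, name, fragment in [
        ("call_1", "get_weather", '{"city": '),
        ("call_1", "", '"Paris"}'),
    ]:
        if call_id not in calls:            # first fragment creates the entry
            calls[call_id] = _Function(name, fragment)
        else:                               # later fragments extend the arguments
            calls[call_id].arguments += fragment

    assert calls["call_1"].arguments == '{"city": "Paris"}'
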
@@ -159,7 +170,7 @@ class ChatCompletionMutationMixin:
         if text_content or tool_calls:
             attributes.update(
                 chain(
-                    _output_value_and_mime_type({"text": text_content, "tool_calls": tool_calls}),
+                    _output_value_and_mime_type(text_content, tool_calls),
                     _llm_output_messages(text_content, tool_calls),
                 )
             )
@@ -225,7 +236,7 @@ class ChatCompletionMutationMixin:
         else:
             return ChatCompletionMutationPayload(
                 content=text_content if text_content else None,
-                tool_calls=tool_calls,
+                tool_calls=list(tool_calls.values()),
                 span=gql_span,
                 error_message=None,
             )
@@ -264,61 +275,30 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
     assert_never(template_language)
 
 
-def _llm_span_kind() -> Iterator[tuple[str, Any]]:
-    yield OPENINFERENCE_SPAN_KIND, LLM
-
-
-def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
-    yield LLM_MODEL_NAME, model_name
-
-
-def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
-    yield LLM_INVOCATION_PARAMETERS, json.dumps(invocation_parameters)
-
-
-def _llm_tools(tools: List[Any]) -> Iterator[tuple[str, Any]]:
-    for tool_index, tool in enumerate(tools):
-        yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
-
-
-def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
-    input_data = input.__dict__.copy()
-    input_data.pop("api_key", None)
-    yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, json.dumps(input_data)
-
-
-def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
-    yield OUTPUT_MIME_TYPE, JSON
-    yield OUTPUT_VALUE, json.dumps(output)
-
-
-def _llm_input_messages(
-    messages: Iterable[ChatCompletionMessage],
+def _output_value_and_mime_type(
+    text: str, tool_calls: dict[str, ChatCompletionToolCall]
 ) -> Iterator[tuple[str, Any]]:
-    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
-        if tool_calls:
-            for tool_call_index, tool_call in enumerate(tool_calls):
-                yield (
-                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                    tool_call["function"]["name"],
-                )
-                if arguments := tool_call["function"]["arguments"]:
-                    yield (
-                        f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                        json.dumps(arguments),
-                    )
+    if text and tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield (
+            OUTPUT_VALUE,
+            safe_json_dumps({"content": text, "tool_calls": jsonify(list(tool_calls.values()))}),
+        )
+    elif tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield OUTPUT_VALUE, safe_json_dumps(jsonify(list(tool_calls.values())))
+    elif text:
+        yield OUTPUT_MIME_TYPE, TEXT
+        yield OUTPUT_VALUE, text
 
 
 def _llm_output_messages(
-    text_content: str, tool_calls: List[ChatCompletionToolCall]
+    text_content: str, tool_calls: dict[str, ChatCompletionToolCall]
 ) -> Iterator[tuple[str, Any]]:
     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
     if text_content:
         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", text_content
-    for tool_call_index, tool_call in enumerate(tool_calls):
+    for tool_call_index, tool_call in enumerate(tool_calls.values()):
         yield (
             f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
             tool_call.function.name,
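
The branching above amounts to a small decision table; a hedged restatement, with string labels standing in for the OpenInference mime-type constants:

    from typing import Optional

    def _mime_type_for(text: str, has_tool_calls: bool) -> Optional[str]:
        if has_tool_calls:   # tool calls, with or without text -> JSON payload
            return "JSON"
        if text:             # text only -> plain TEXT
            return "TEXT"
        return None          # neither -> no output attributes are yielded

    assert _mime_type_for("hi", False) == "TEXT"
    assert _mime_type_for("", True) == "JSON"
    assert _mime_type_for("hi", True) == "JSON"
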
@@ -347,6 +327,7 @@ def _serialize_event(event: SpanException) -> dict[str, Any]:
 
 
 JSON = OpenInferenceMimeTypeValues.JSON.value
+TEXT = OpenInferenceMimeTypeValues.TEXT.value
 LLM = OpenInferenceSpanKindValues.LLM.value
 
 OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND