arize-phoenix 5.7.0__py3-none-any.whl → 5.8.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/METADATA +3 -5
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/RECORD +24 -24
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/WHEEL +1 -1
- phoenix/config.py +19 -3
- phoenix/server/api/helpers/playground_clients.py +123 -36
- phoenix/server/api/helpers/playground_spans.py +173 -76
- phoenix/server/api/input_types/InvocationParameters.py +7 -8
- phoenix/server/api/mutations/chat_mutations.py +46 -65
- phoenix/server/api/subscriptions.py +210 -158
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
- phoenix/server/app.py +14 -0
- phoenix/server/grpc_server.py +3 -1
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-Csu8UKOs.js → components-MllbfxfJ.js} +168 -150
- phoenix/server/static/assets/{index-Bk5C9EA7.js → index-BVO2YcT1.js} +2 -2
- phoenix/server/static/assets/{pages-UeWaKXNs.js → pages-BHfC6jnL.js} +394 -300
- phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-BEuNhfwH.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-Bskhzyjm.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-DLlXCf0x.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-CRqhvLYg.js} +1 -1
- phoenix/version.py +1 -1
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/helpers/playground_spans.py

```diff
@@ -26,7 +26,6 @@ from openinference.semconv.trace import (
 )
 from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
-from sqlalchemy.ext.asyncio import AsyncSession
 from strawberry.scalars import JSON as JSONScalarType
 from typing_extensions import Self, TypeAlias, assert_never
 
@@ -41,7 +40,7 @@ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
     TextChunk,
     ToolCallChunk,
 )
-from phoenix.trace.attributes import unflatten
+from phoenix.trace.attributes import get_attribute_value, unflatten
 from phoenix.trace.schemas import (
     SpanEvent,
     SpanException,
@@ -56,7 +55,8 @@ ToolCallID: TypeAlias = str
 
 class streaming_llm_span:
     """
-    …
+    A context manager that records OpenInference attributes for streaming chat
+    completion LLM spans.
     """
 
     def __init__(
@@ -71,24 +71,23 @@ class streaming_llm_span:
         self._attributes: dict[str, Any] = attributes if attributes is not None else {}
         self._attributes.update(
             chain(
-                …
-                …
-                …
-                …
-                …
-                …
+                llm_span_kind(),
+                llm_model_name(input.model.name),
+                llm_tools(input.tools or []),
+                llm_input_messages(messages),
+                llm_invocation_parameters(invocation_parameters),
+                input_value_and_mime_type(input),
             )
         )
         self._events: list[SpanEvent] = []
-        self._start_time: datetime
-        self._end_time: datetime
-        self._response_chunks: list[Union[TextChunk, ToolCallChunk]] = []
+        self._start_time: Optional[datetime] = None
+        self._end_time: Optional[datetime] = None
         self._text_chunks: list[TextChunk] = []
         self._tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
-        self._status_code: StatusCode
-        self._status_message: str
-        self.…
-        self.…
+        self._status_code: StatusCode = StatusCode.UNSET
+        self._status_message: Optional[str] = None
+        self._trace_id = _generate_trace_id()
+        self._span_id = _generate_span_id()
 
     async def __aenter__(self) -> Self:
         self._start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
@@ -102,7 +101,6 @@ class streaming_llm_span:
     ) -> bool:
         self._end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
         self._status_code = StatusCode.OK
-        self._status_message = ""
         if exc_type is not None:
             self._status_code = StatusCode.ERROR
             self._status_message = str(exc_value)
@@ -115,10 +113,10 @@ class streaming_llm_span:
                     exception_stacktrace=format_exc(),
                 )
             )
-        if self.…
+        if self._text_chunks or self._tool_call_chunks:
             self._attributes.update(
                 chain(
-                    _output_value_and_mime_type(self.…
+                    _output_value_and_mime_type(self._text_chunks, self._tool_call_chunks),
                     _llm_output_messages(self._text_chunks, self._tool_call_chunks),
                 )
             )
@@ -127,46 +125,7 @@ class streaming_llm_span:
     def set_attributes(self, attributes: Mapping[str, Any]) -> None:
         self._attributes.update(attributes)
 
-    def add_to_session(
-        self,
-        session: AsyncSession,
-        project_id: int,
-    ) -> models.Span:
-        prompt_tokens = self._attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-        completion_tokens = self._attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-        trace_id = _generate_trace_id()
-        span_id = _generate_span_id()
-        self._db_trace = models.Trace(
-            project_rowid=project_id,
-            trace_id=trace_id,
-            start_time=self._start_time,
-            end_time=self._end_time,
-        )
-        self._db_span = models.Span(
-            trace_rowid=self._db_trace.id,
-            span_id=span_id,
-            parent_id=None,
-            name="ChatCompletion",
-            span_kind=LLM,
-            start_time=self._start_time,
-            end_time=self._end_time,
-            attributes=unflatten(self._attributes.items()),
-            events=[_serialize_event(event) for event in self._events],
-            status_code=self._status_code.name,
-            status_message=self._status_message,
-            cumulative_error_count=int(self._status_code is StatusCode.ERROR),
-            cumulative_llm_token_count_prompt=prompt_tokens,
-            cumulative_llm_token_count_completion=completion_tokens,
-            llm_token_count_prompt=prompt_tokens,
-            llm_token_count_completion=completion_tokens,
-            trace=self._db_trace,
-        )
-        session.add(self._db_trace)
-        session.add(self._db_span)
-        return self._db_span
-
     def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
-        self._response_chunks.append(chunk)
         if isinstance(chunk, TextChunk):
             self._text_chunks.append(chunk)
         elif isinstance(chunk, ToolCallChunk):
@@ -174,48 +133,128 @@ class streaming_llm_span:
         else:
             assert_never(chunk)
 
+    @property
+    def span_id(self) -> str:
+        return self._span_id
+
+    @property
+    def trace_id(self) -> str:
+        return self._trace_id
+
     @property
     def start_time(self) -> datetime:
-        …
+        if self._start_time is None:
+            raise ValueError("Cannot access start time before the context manager is entered")
+        return self._start_time
 
     @property
     def end_time(self) -> datetime:
-        …
+        if self._end_time is None:
+            raise ValueError("Cannot access end time before the context manager is exited")
+        return self._end_time
 
     @property
-    def …
-        return self.…
+    def status_code(self) -> StatusCode:
+        return self._status_code
 
     @property
-    def …
-        …
+    def status_message(self) -> Optional[str]:
+        if self._status_code is StatusCode.UNSET:
+            raise ValueError("Cannot access status message before the context manager is exited")
+        return self._status_message
 
     @property
-    def …
-        return self.…
-        …
+    def events(self) -> list[SpanEvent]:
+        return self._events
 
-def …
+    @property
+    def attributes(self) -> dict[str, Any]:
+        return unflatten(self._attributes.items())
+
+
+def get_db_trace(span: streaming_llm_span, project_id: int) -> models.Trace:
+    return models.Trace(
+        project_rowid=project_id,
+        trace_id=span.trace_id,
+        start_time=span.start_time,
+        end_time=span.end_time,
+    )
+
+
+def get_db_span(
+    span: streaming_llm_span,
+    db_trace: models.Trace,
+) -> models.Span:
+    prompt_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT) or 0
+    completion_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_COMPLETION) or 0
+    return models.Span(
+        trace_rowid=db_trace.id,
+        span_id=span.span_id,
+        parent_id=None,
+        name="ChatCompletion",
+        span_kind=LLM,
+        start_time=span.start_time,
+        end_time=span.end_time,
+        attributes=span.attributes,
+        events=[_serialize_event(event) for event in span.events],
+        status_code=span.status_code.name,
+        status_message=span.status_message or "",
+        cumulative_error_count=int(span.status_code is StatusCode.ERROR),
+        cumulative_llm_token_count_prompt=prompt_tokens,
+        cumulative_llm_token_count_completion=completion_tokens,
+        llm_token_count_prompt=prompt_tokens,
+        llm_token_count_completion=completion_tokens,
+        trace=db_trace,
+    )
+
+
+def get_db_experiment_run(
+    db_span: models.Span,
+    db_trace: models.Trace,
+    *,
+    experiment_id: int,
+    example_id: int,
+) -> models.ExperimentRun:
+    return models.ExperimentRun(
+        experiment_id=experiment_id,
+        dataset_example_id=example_id,
+        trace_id=db_trace.trace_id,
+        output=models.ExperimentRunOutput(
+            task_output=get_attribute_value(db_span.attributes, LLM_OUTPUT_MESSAGES),
+        ),
+        repetition_number=1,
+        start_time=db_span.start_time,
+        end_time=db_span.end_time,
+        error=db_span.status_message or None,
+        prompt_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_PROMPT),
+        completion_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_COMPLETION),
+        trace=db_trace,
+    )
+
+
+def llm_span_kind() -> Iterator[tuple[str, Any]]:
     yield OPENINFERENCE_SPAN_KIND, LLM
 
 
-def …
+def llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
     yield LLM_MODEL_NAME, model_name
 
 
-def …
+def llm_invocation_parameters(
     invocation_parameters: Mapping[str, Any],
 ) -> Iterator[tuple[str, Any]]:
     if invocation_parameters:
         yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
 
 
-def …
+def llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
     for tool_index, tool in enumerate(tools):
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
 
 
-def …
+def input_value_and_mime_type(
+    input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
+) -> Iterator[tuple[str, Any]]:
     assert (api_key := "api_key") in (input_data := jsonify(input))
     disallowed_keys = {"api_key", "invocation_parameters"}
     input_data = {k: v for k, v in input_data.items() if k not in disallowed_keys}
```
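The refactor above replaces the span's `add_to_session` method with free functions (`get_db_trace`, `get_db_span`, `get_db_experiment_run`) that read the finished span's properties, so the span only records and the persistence layer decides what to store. A minimal standalone sketch of that shape, using stand-in names (`RecordingSpan`, `get_record`) rather than Phoenix's actual types:

```python
import asyncio
from datetime import datetime, timezone
from typing import Any, Optional


class RecordingSpan:
    """Records timing for an async block; persistence happens elsewhere."""

    def __init__(self) -> None:
        self._start_time: Optional[datetime] = None
        self._end_time: Optional[datetime] = None

    async def __aenter__(self) -> "RecordingSpan":
        self._start_time = datetime.now(timezone.utc)
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        self._end_time = datetime.now(timezone.utc)
        return False  # never suppress exceptions

    @property
    def start_time(self) -> datetime:
        # Same guard style as the diff: fail loudly if read before entry.
        if self._start_time is None:
            raise ValueError("Cannot access start time before the context manager is entered")
        return self._start_time

    @property
    def end_time(self) -> datetime:
        if self._end_time is None:
            raise ValueError("Cannot access end time before the context manager is exited")
        return self._end_time


def get_record(span: RecordingSpan) -> dict[str, datetime]:
    # Analogous to get_db_trace/get_db_span: a pure function that reads the
    # finished span, instead of the span writing to the database itself.
    return {"start_time": span.start_time, "end_time": span.end_time}


async def main() -> None:
    async with RecordingSpan() as span:
        await asyncio.sleep(0)  # streaming work would happen here
    print(get_record(span))


asyncio.run(main())
```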
phoenix/server/api/helpers/playground_spans.py (continued)

```diff
@@ -224,12 +263,69 @@ def _input_value_and_mime_type(input: Any) -> Iterator[tuple[str, Any]]:
     yield INPUT_VALUE, safe_json_dumps(input_data)
 
 
-def …
-…
-…
+def _merge_tool_call_chunks(
+    chunks_by_id: defaultdict[str, list[ToolCallChunk]],
+) -> list[dict[str, Any]]:
+    merged_tool_calls = []
+
+    for tool_id, chunks in chunks_by_id.items():
+        if not chunks:
+            continue
+        first_chunk = chunks[0]
+        if not first_chunk:
+            continue
+
+        if not hasattr(first_chunk, "function") or not hasattr(first_chunk.function, "name"):
+            continue
+        # Combine all argument chunks
+        merged_arguments = "".join(
+            chunk.function.arguments
+            for chunk in chunks
+            if chunk and hasattr(chunk, "function") and hasattr(chunk.function, "arguments")
+        )
+
+        merged_tool_calls.append(
+            {
+                "id": tool_id,
+                # Only the first chunk has the tool name
+                "function": {
+                    "name": first_chunk.function.name,
+                    "arguments": merged_arguments or "{}",
+                },
+            }
+        )
+
+    return merged_tool_calls
+
+
+def _output_value_and_mime_type(
+    text_chunks: list[TextChunk],
+    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+) -> Iterator[tuple[str, Any]]:
+    content = "".join(chunk.content for chunk in text_chunks)
+    merged_tool_calls = _merge_tool_call_chunks(tool_call_chunks)
+    if content and merged_tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield (
+            OUTPUT_VALUE,
+            safe_json_dumps(
+                {
+                    "content": content,
+                    "tool_calls": jsonify(
+                        merged_tool_calls,
+                    ),
+                }
+            ),
+        )
+    elif merged_tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield OUTPUT_VALUE, safe_json_dumps(jsonify(merged_tool_calls))
+    elif content:
+        yield OUTPUT_MIME_TYPE, TEXT
+        yield OUTPUT_VALUE, content
 
 
-def …
+def llm_input_messages(
     messages: Iterable[
         tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
     ],
```
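The new `_merge_tool_call_chunks` assembles one tool call per id from streamed fragments: only the first chunk carries the function name, and every chunk contributes a slice of the JSON arguments string. A self-contained sketch of that idea, with stand-in dataclasses (`Chunk`, `FunctionChunk`) in place of Phoenix's `ToolCallChunk`:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any


@dataclass
class FunctionChunk:
    name: str  # empty on continuation chunks
    arguments: str  # a fragment of the JSON arguments string


@dataclass
class Chunk:
    id: str
    function: FunctionChunk


def merge(chunks_by_id: dict[str, list[Chunk]]) -> list[dict[str, Any]]:
    merged = []
    for tool_id, chunks in chunks_by_id.items():
        if not chunks:
            continue
        # Only the first chunk carries the tool name; the rest carry
        # argument fragments that concatenate into one JSON string.
        merged.append(
            {
                "id": tool_id,
                "function": {
                    "name": chunks[0].function.name,
                    "arguments": "".join(c.function.arguments for c in chunks) or "{}",
                },
            }
        )
    return merged


chunks: defaultdict[str, list[Chunk]] = defaultdict(list)
chunks["call_1"].append(Chunk("call_1", FunctionChunk("get_weather", '{"city": ')))
chunks["call_1"].append(Chunk("call_1", FunctionChunk("", '"Paris"}')))
print(merge(chunks))  # one tool call with fully assembled arguments
```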
phoenix/server/api/helpers/playground_spans.py (continued)

```diff
@@ -299,6 +395,7 @@ def _serialize_event(event: SpanEvent) -> dict[str, Any]:
 
 
 JSON = OpenInferenceMimeTypeValues.JSON.value
+TEXT = OpenInferenceMimeTypeValues.TEXT.value
 
 LLM = OpenInferenceSpanKindValues.LLM.value
 
```
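The new module-level `TEXT` constant feeds the three-way branch in `_output_value_and_mime_type`: output containing tool calls is serialized as JSON, while text-only output is now recorded under a plain-text MIME type. A simplified sketch of that branch; the literal attribute keys and MIME strings stand in for the OpenInference constants (`OUTPUT_VALUE`, `OUTPUT_MIME_TYPE`, `OpenInferenceMimeTypeValues`) used in the real code:

```python
import json
from typing import Any, Iterator


def output_value_and_mime_type(
    content: str, tool_calls: list[dict[str, Any]]
) -> Iterator[tuple[str, Any]]:
    if content and tool_calls:
        # Mixed output: wrap both parts in one JSON object.
        yield "output.mime_type", "application/json"
        yield "output.value", json.dumps({"content": content, "tool_calls": tool_calls})
    elif tool_calls:
        yield "output.mime_type", "application/json"
        yield "output.value", json.dumps(tool_calls)
    elif content:
        # Text-only responses are recorded as plain text, not JSON.
        yield "output.mime_type", "text/plain"
        yield "output.value", content


print(dict(output_value_and_mime_type("Hello!", [])))
# {'output.mime_type': 'text/plain', 'output.value': 'Hello!'}
```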
phoenix/server/api/input_types/InvocationParameters.py

```diff
@@ -47,25 +47,24 @@ class InvocationParameterBase:
     canonical_name: Optional[CanonicalParameterName] = None
     label: str
     required: bool = False
-    hidden: bool = False
 
 
 @strawberry.type
 class IntInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_int
-    default_value: Optional[int] = …
+    default_value: Optional[int] = None
 
 
 @strawberry.type
 class FloatInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_float
-    default_value: Optional[float] = …
+    default_value: Optional[float] = None
 
 
 @strawberry.type
 class BoundedFloatInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_float
-    default_value: Optional[float] = …
+    default_value: Optional[float] = None
     min_value: float
     max_value: float
 
@@ -73,25 +72,25 @@ class BoundedFloatInvocationParameter(InvocationParameterBase):
 @strawberry.type
 class StringInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_string
-    default_value: Optional[str] = …
+    default_value: Optional[str] = None
 
 
 @strawberry.type
 class JSONInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_json
-    default_value: Optional[JSON] = …
+    default_value: Optional[JSON] = None
 
 
 @strawberry.type
 class StringListInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_string_list
-    default_value: Optional[list[str]] = …
+    default_value: Optional[list[str]] = None
 
 
 @strawberry.type
 class BooleanInvocationParameter(InvocationParameterBase):
     invocation_input_field: InvocationInputField = InvocationInputField.value_bool
-    default_value: Optional[bool] = …
+    default_value: Optional[bool] = None
 
 
 def extract_parameter(
```
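Giving every `default_value` an explicit `None` default makes the field optional both at construction time and in the generated GraphQL schema. A minimal strawberry sketch of the same pattern, with an illustrative `FloatParameter` type rather than the ones above:

```python
from typing import Optional

import strawberry


@strawberry.type
class FloatParameter:
    label: str
    required: bool = False
    default_value: Optional[float] = None  # explicit default -> nullable, optional field


@strawberry.type
class Query:
    @strawberry.field
    def parameter(self) -> FloatParameter:
        # default_value can be omitted because it defaults to None.
        return FloatParameter(label="temperature")


schema = strawberry.Schema(query=Query)
print(schema)  # the printed SDL shows defaultValue as a nullable Float
```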
phoenix/server/api/mutations/chat_mutations.py

```diff
@@ -6,6 +6,7 @@ from traceback import format_exc
 from typing import Any, Iterable, Iterator, List, Optional
 
 import strawberry
+from openinference.instrumentation import safe_json_dumps
 from openinference.semconv.trace import (
     MessageAttributes,
     OpenInferenceMimeTypeValues,
@@ -26,6 +27,14 @@ from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
 from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
+from phoenix.server.api.helpers.playground_spans import (
+    input_value_and_mime_type,
+    llm_input_messages,
+    llm_invocation_parameters,
+    llm_model_name,
+    llm_span_kind,
+    llm_tools,
+)
 from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
 from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -38,6 +47,7 @@ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
 from phoenix.server.dml_event import SpanInsertEvent
 from phoenix.trace.attributes import unflatten
 from phoenix.trace.schemas import SpanException
+from phoenix.utilities.json import jsonify
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
     MustacheTemplateFormatter,
@@ -94,7 +104,6 @@ class ChatCompletionMutationMixin:
             )
             for message in input.messages
         ]
-
         if template_options := input.template:
             messages = list(_formatted_messages(messages, template_options))
 
@@ -103,16 +112,16 @@ class ChatCompletionMutationMixin:
         )
 
         text_content = ""
-        tool_calls = …
+        tool_calls: dict[str, ChatCompletionToolCall] = {}
         events = []
         attributes.update(
             chain(
-                …
-                …
-                …
-                …
-                …
-                …
+                llm_span_kind(),
+                llm_model_name(input.model.name),
+                llm_tools(input.tools or []),
+                llm_input_messages(messages),
+                llm_invocation_parameters(invocation_parameters),
+                input_value_and_mime_type(input),
                 **llm_client.attributes,
             )
         )
@@ -128,14 +137,16 @@ class ChatCompletionMutationMixin:
                 if isinstance(chunk, TextChunk):
                     text_content += chunk.content
                 elif isinstance(chunk, ToolCallChunk):
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
-                    …
+                    if chunk.id not in tool_calls:
+                        tool_calls[chunk.id] = ChatCompletionToolCall(
+                            id=chunk.id,
+                            function=ChatCompletionFunctionCall(
+                                name=chunk.function.name,
+                                arguments=chunk.function.arguments,
+                            ),
+                        )
+                    else:
+                        tool_calls[chunk.id].function.arguments += chunk.function.arguments
                 else:
                     assert_never(chunk)
             except Exception as e:
```
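Keying the accumulator by chunk id, rather than appending chunks to a list, is what lets the mutation build each tool call incrementally: the first chunk for an id creates the entry, and later chunks only extend its arguments string. This is the eager counterpart of the post-hoc merge sketched earlier, again with stand-in classes rather than Phoenix's `ChatCompletionToolCall` / `ChatCompletionFunctionCall`:

```python
from dataclasses import dataclass


@dataclass
class FunctionCall:
    name: str
    arguments: str


@dataclass
class ToolCall:
    id: str
    function: FunctionCall


def accumulate(stream: list[ToolCall]) -> list[ToolCall]:
    calls: dict[str, ToolCall] = {}
    for chunk in stream:
        if chunk.id not in calls:
            # The first chunk for this id establishes the call and its name.
            calls[chunk.id] = ToolCall(
                chunk.id, FunctionCall(chunk.function.name, chunk.function.arguments)
            )
        else:
            # Later chunks only extend the JSON arguments string.
            calls[chunk.id].function.arguments += chunk.function.arguments
    return list(calls.values())


stream = [
    ToolCall("call_1", FunctionCall("search", '{"query": ')),
    ToolCall("call_1", FunctionCall("", '"phoenix"}')),
]
print(accumulate(stream))  # one fully assembled tool call
```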
phoenix/server/api/mutations/chat_mutations.py (continued)

```diff
@@ -159,7 +170,7 @@ class ChatCompletionMutationMixin:
         if text_content or tool_calls:
             attributes.update(
                 chain(
-                    _output_value_and_mime_type(…
+                    _output_value_and_mime_type(text_content, tool_calls),
                     _llm_output_messages(text_content, tool_calls),
                 )
             )
@@ -225,7 +236,7 @@ class ChatCompletionMutationMixin:
         else:
             return ChatCompletionMutationPayload(
                 content=text_content if text_content else None,
-                tool_calls=tool_calls,
+                tool_calls=list(tool_calls.values()),
                 span=gql_span,
                 error_message=None,
             )
@@ -264,61 +275,30 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
     assert_never(template_language)
 
 
-def …
-    …
-
-
-def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
-    yield LLM_MODEL_NAME, model_name
-
-
-def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
-    yield LLM_INVOCATION_PARAMETERS, json.dumps(invocation_parameters)
-
-
-def _llm_tools(tools: List[Any]) -> Iterator[tuple[str, Any]]:
-    for tool_index, tool in enumerate(tools):
-        yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
-
-
-def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
-    input_data = input.__dict__.copy()
-    input_data.pop("api_key", None)
-    yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, json.dumps(input_data)
-
-
-def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
-    yield OUTPUT_MIME_TYPE, JSON
-    yield OUTPUT_VALUE, json.dumps(output)
-
-
-def _llm_input_messages(
-    messages: Iterable[ChatCompletionMessage],
+def _output_value_and_mime_type(
+    text: str, tool_calls: dict[str, ChatCompletionToolCall]
 ) -> Iterator[tuple[str, Any]]:
-    …
-    yield …
-    yield …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-            json.dumps(arguments),
-        )
+    if text and tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield (
+            OUTPUT_VALUE,
+            safe_json_dumps({"content": text, "tool_calls": jsonify(list(tool_calls.values()))}),
+        )
+    elif tool_calls:
+        yield OUTPUT_MIME_TYPE, JSON
+        yield OUTPUT_VALUE, safe_json_dumps(jsonify(list(tool_calls.values())))
+    elif text:
+        yield OUTPUT_MIME_TYPE, TEXT
+        yield OUTPUT_VALUE, text
 
 
 def _llm_output_messages(
-    text_content: str, tool_calls: …
+    text_content: str, tool_calls: dict[str, ChatCompletionToolCall]
 ) -> Iterator[tuple[str, Any]]:
     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
     if text_content:
         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", text_content
-    for tool_call_index, tool_call in enumerate(tool_calls):
+    for tool_call_index, tool_call in enumerate(tool_calls.values()):
         yield (
             f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
             tool_call.function.name,
@@ -347,6 +327,7 @@ def _serialize_event(event: SpanException) -> dict[str, Any]:
 
 
 JSON = OpenInferenceMimeTypeValues.JSON.value
+TEXT = OpenInferenceMimeTypeValues.TEXT.value
 LLM = OpenInferenceSpanKindValues.LLM.value
 
 OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
```
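Both files record output messages through flattened, dot-separated OpenInference keys, which `unflatten` later reshapes into nested attributes before storage. A toy illustration of the convention; this simplified `unflatten` keeps numeric indices as dict keys, unlike Phoenix's real `phoenix.trace.attributes.unflatten`, which turns them into lists:

```python
from typing import Any, Iterable, Iterator


def llm_output_message(text: str) -> Iterator[tuple[str, Any]]:
    # Flattened keys in the style of _llm_output_messages:
    # f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}" etc.
    yield "llm.output_messages.0.message.role", "assistant"
    yield "llm.output_messages.0.message.content", text


def unflatten(items: Iterable[tuple[str, Any]]) -> dict:
    root: dict = {}
    for key, value in items:
        parts = key.split(".")
        node = root
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return root


print(unflatten(llm_output_message("Hello!")))
# {'llm': {'output_messages': {'0': {'message': {'role': 'assistant',
#                                                'content': 'Hello!'}}}}}
```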