arize-phoenix 5.6.0__py3-none-any.whl → 5.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.
Files changed (39)
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/METADATA +4 -6
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/RECORD +39 -30
  3. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +58 -0
  5. phoenix/server/api/helpers/playground_clients.py +758 -0
  6. phoenix/server/api/helpers/playground_registry.py +70 -0
  7. phoenix/server/api/helpers/playground_spans.py +422 -0
  8. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  9. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  10. phoenix/server/api/input_types/InvocationParameters.py +155 -13
  11. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  12. phoenix/server/api/mutations/__init__.py +4 -0
  13. phoenix/server/api/mutations/chat_mutations.py +355 -0
  14. phoenix/server/api/queries.py +41 -52
  15. phoenix/server/api/schema.py +42 -10
  16. phoenix/server/api/subscriptions.py +378 -595
  17. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  18. phoenix/server/api/types/GenerativeProvider.py +27 -3
  19. phoenix/server/api/types/Span.py +37 -0
  20. phoenix/server/api/types/TemplateLanguage.py +9 -0
  21. phoenix/server/app.py +75 -13
  22. phoenix/server/grpc_server.py +3 -1
  23. phoenix/server/main.py +14 -1
  24. phoenix/server/static/.vite/manifest.json +31 -31
  25. phoenix/server/static/assets/{components-C70HJiXz.js → components-MllbfxfJ.js} +168 -150
  26. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-BVO2YcT1.js} +2 -2
  27. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-BHfC6jnL.js} +464 -310
  28. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-BEuNhfwH.js} +1 -1
  29. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-Bskhzyjm.js} +1 -1
  30. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-DLlXCf0x.js} +1 -1
  31. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-CRqhvLYg.js} +1 -1
  32. phoenix/server/templates/index.html +1 -0
  33. phoenix/services.py +4 -0
  34. phoenix/session/session.py +15 -1
  35. phoenix/utilities/template_formatters.py +11 -1
  36. phoenix/version.py +1 -1
  37. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/entry_points.txt +0 -0
  38. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/IP_NOTICE +0 -0
  39. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/helpers/playground_registry.py
@@ -0,0 +1,70 @@
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
+
+ if TYPE_CHECKING:
+     from phoenix.server.api.helpers.playground_clients import PlaygroundStreamingClient
+
+ ModelName = Union[str, None]
+ ModelKey = tuple[GenerativeProviderKey, ModelName]
+
+ PROVIDER_DEFAULT = None
+
+
+ class SingletonMeta(type):
+     _instances: dict[Any, Any] = dict()
+
+     def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+         if cls not in cls._instances:
+             cls._instances[cls] = super(SingletonMeta, cls).__call__(*args, **kwargs)
+         return cls._instances[cls]
+
+
+ class PlaygroundClientRegistry(metaclass=SingletonMeta):
+     def __init__(self) -> None:
+         self._registry: dict[
+             GenerativeProviderKey, dict[ModelName, Optional[type["PlaygroundStreamingClient"]]]
+         ] = {}
+
+     def get_client(
+         self,
+         provider_key: GenerativeProviderKey,
+         model_name: ModelName,
+     ) -> Optional[type["PlaygroundStreamingClient"]]:
+         provider_registry = self._registry.get(provider_key, {})
+         client_class = provider_registry.get(model_name)
+         if client_class is None and None in provider_registry:
+             client_class = provider_registry[PROVIDER_DEFAULT]  # Fallback to provider default
+         return client_class
+
+     def list_all_providers(
+         self,
+     ) -> list[GenerativeProviderKey]:
+         return [provider_key for provider_key in self._registry]
+
+     def list_models(self, provider_key: GenerativeProviderKey) -> list[str]:
+         provider_registry = self._registry.get(provider_key, {})
+         return [model_name for model_name in provider_registry.keys() if model_name is not None]
+
+     def list_all_models(self) -> list[ModelKey]:
+         return [
+             (provider_key, model_name)
+             for provider_key, provider_registry in self._registry.items()
+             for model_name in provider_registry.keys()
+         ]
+
+
+ PLAYGROUND_CLIENT_REGISTRY: PlaygroundClientRegistry = PlaygroundClientRegistry()
+
+
+ def register_llm_client(
+     provider_key: GenerativeProviderKey,
+     model_names: list[ModelName],
+ ) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
+     def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
+         provider_registry = PLAYGROUND_CLIENT_REGISTRY._registry.setdefault(provider_key, {})
+         for model_name in model_names:
+             provider_registry[model_name] = cls
+         return cls
+
+     return decorator
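
For orientation, here is a minimal usage sketch of the new registry API. It is not part of the diff: the client class is a stand-in for a PlaygroundStreamingClient subclass, and the GenerativeProviderKey.OPENAI member and model names are assumptions.

# sketch.py -- hypothetical registration and lookup against the new registry
from phoenix.server.api.helpers.playground_registry import (
    PLAYGROUND_CLIENT_REGISTRY,
    PROVIDER_DEFAULT,
    register_llm_client,
)
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey


@register_llm_client(
    provider_key=GenerativeProviderKey.OPENAI,  # assumed enum member
    model_names=[PROVIDER_DEFAULT, "gpt-4o"],  # PROVIDER_DEFAULT (None) registers the fallback
)
class ExampleStreamingClient:  # stand-in for a PlaygroundStreamingClient subclass
    ...


# An exact model name resolves to its registered client...
assert (
    PLAYGROUND_CLIENT_REGISTRY.get_client(GenerativeProviderKey.OPENAI, "gpt-4o")
    is ExampleStreamingClient
)
# ...and an unregistered name falls back to the provider's PROVIDER_DEFAULT entry.
assert (
    PLAYGROUND_CLIENT_REGISTRY.get_client(GenerativeProviderKey.OPENAI, "unknown-model")
    is ExampleStreamingClient
)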
phoenix/server/api/helpers/playground_spans.py
@@ -0,0 +1,422 @@
+ import json
+ from collections import defaultdict
+ from collections.abc import Mapping
+ from dataclasses import asdict
+ from datetime import datetime, timezone
+ from itertools import chain
+ from traceback import format_exc
+ from types import TracebackType
+ from typing import (
+     Any,
+     Iterable,
+     Iterator,
+     Optional,
+     Union,
+     cast,
+ )
+
+ from openinference.instrumentation import safe_json_dumps
+ from openinference.semconv.trace import (
+     MessageAttributes,
+     OpenInferenceMimeTypeValues,
+     OpenInferenceSpanKindValues,
+     SpanAttributes,
+     ToolAttributes,
+     ToolCallAttributes,
+ )
+ from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
+ from opentelemetry.trace import StatusCode
+ from strawberry.scalars import JSON as JSONScalarType
+ from typing_extensions import Self, TypeAlias, assert_never
+
+ from phoenix.datetime_utils import local_now, normalize_datetime
+ from phoenix.db import models
+ from phoenix.server.api.input_types.ChatCompletionInput import (
+     ChatCompletionInput,
+     ChatCompletionOverDatasetInput,
+ )
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+     TextChunk,
+     ToolCallChunk,
+ )
+ from phoenix.trace.attributes import get_attribute_value, unflatten
+ from phoenix.trace.schemas import (
+     SpanEvent,
+     SpanException,
+ )
+ from phoenix.utilities.json import jsonify
+
+ ChatCompletionMessage: TypeAlias = tuple[
+     ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
+ ]
+ ToolCallID: TypeAlias = str
+
+
+ class streaming_llm_span:
+     """
+     A context manager that records OpenInference attributes for streaming chat
+     completion LLM spans.
+     """
+
+     def __init__(
+         self,
+         *,
+         input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
+         messages: list[ChatCompletionMessage],
+         invocation_parameters: Mapping[str, Any],
+         attributes: Optional[dict[str, Any]] = None,
+     ) -> None:
+         self._input = input
+         self._attributes: dict[str, Any] = attributes if attributes is not None else {}
+         self._attributes.update(
+             chain(
+                 llm_span_kind(),
+                 llm_model_name(input.model.name),
+                 llm_tools(input.tools or []),
+                 llm_input_messages(messages),
+                 llm_invocation_parameters(invocation_parameters),
+                 input_value_and_mime_type(input),
+             )
+         )
+         self._events: list[SpanEvent] = []
+         self._start_time: Optional[datetime] = None
+         self._end_time: Optional[datetime] = None
+         self._text_chunks: list[TextChunk] = []
+         self._tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
+         self._status_code: StatusCode = StatusCode.UNSET
+         self._status_message: Optional[str] = None
+         self._trace_id = _generate_trace_id()
+         self._span_id = _generate_span_id()
+
+     async def __aenter__(self) -> Self:
+         self._start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+         return self
+
+     async def __aexit__(
+         self,
+         exc_type: Optional[type[BaseException]],
+         exc_value: Optional[BaseException],
+         traceback: Optional[TracebackType],
+     ) -> bool:
+         self._end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+         self._status_code = StatusCode.OK
+         if exc_type is not None:
+             self._status_code = StatusCode.ERROR
+             self._status_message = str(exc_value)
+             self._events.append(
+                 SpanException(
+                     timestamp=self._end_time,
+                     message=self._status_message,
+                     exception_type=type(exc_value).__name__,
+                     exception_escaped=False,
+                     exception_stacktrace=format_exc(),
+                 )
+             )
+         if self._text_chunks or self._tool_call_chunks:
+             self._attributes.update(
+                 chain(
+                     _output_value_and_mime_type(self._text_chunks, self._tool_call_chunks),
+                     _llm_output_messages(self._text_chunks, self._tool_call_chunks),
+                 )
+             )
+         return True
+
+     def set_attributes(self, attributes: Mapping[str, Any]) -> None:
+         self._attributes.update(attributes)
+
+     def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
+         if isinstance(chunk, TextChunk):
+             self._text_chunks.append(chunk)
+         elif isinstance(chunk, ToolCallChunk):
+             self._tool_call_chunks[chunk.id].append(chunk)
+         else:
+             assert_never(chunk)
+
+     @property
+     def span_id(self) -> str:
+         return self._span_id
+
+     @property
+     def trace_id(self) -> str:
+         return self._trace_id
+
+     @property
+     def start_time(self) -> datetime:
+         if self._start_time is None:
+             raise ValueError("Cannot access start time before the context manager is entered")
+         return self._start_time
+
+     @property
+     def end_time(self) -> datetime:
+         if self._end_time is None:
+             raise ValueError("Cannot access end time before the context manager is exited")
+         return self._end_time
+
+     @property
+     def status_code(self) -> StatusCode:
+         return self._status_code
+
+     @property
+     def status_message(self) -> Optional[str]:
+         if self._status_code is StatusCode.UNSET:
+             raise ValueError("Cannot access status message before the context manager is exited")
+         return self._status_message
+
+     @property
+     def events(self) -> list[SpanEvent]:
+         return self._events
+
+     @property
+     def attributes(self) -> dict[str, Any]:
+         return unflatten(self._attributes.items())
+
+
+ def get_db_trace(span: streaming_llm_span, project_id: int) -> models.Trace:
+     return models.Trace(
+         project_rowid=project_id,
+         trace_id=span.trace_id,
+         start_time=span.start_time,
+         end_time=span.end_time,
+     )
+
+
+ def get_db_span(
+     span: streaming_llm_span,
+     db_trace: models.Trace,
+ ) -> models.Span:
+     prompt_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT) or 0
+     completion_tokens = get_attribute_value(span.attributes, LLM_TOKEN_COUNT_COMPLETION) or 0
+     return models.Span(
+         trace_rowid=db_trace.id,
+         span_id=span.span_id,
+         parent_id=None,
+         name="ChatCompletion",
+         span_kind=LLM,
+         start_time=span.start_time,
+         end_time=span.end_time,
+         attributes=span.attributes,
+         events=[_serialize_event(event) for event in span.events],
+         status_code=span.status_code.name,
+         status_message=span.status_message or "",
+         cumulative_error_count=int(span.status_code is StatusCode.ERROR),
+         cumulative_llm_token_count_prompt=prompt_tokens,
+         cumulative_llm_token_count_completion=completion_tokens,
+         llm_token_count_prompt=prompt_tokens,
+         llm_token_count_completion=completion_tokens,
+         trace=db_trace,
+     )
+
+
+ def get_db_experiment_run(
+     db_span: models.Span,
+     db_trace: models.Trace,
+     *,
+     experiment_id: int,
+     example_id: int,
+ ) -> models.ExperimentRun:
+     return models.ExperimentRun(
+         experiment_id=experiment_id,
+         dataset_example_id=example_id,
+         trace_id=db_trace.trace_id,
+         output=models.ExperimentRunOutput(
+             task_output=get_attribute_value(db_span.attributes, LLM_OUTPUT_MESSAGES),
+         ),
+         repetition_number=1,
+         start_time=db_span.start_time,
+         end_time=db_span.end_time,
+         error=db_span.status_message or None,
+         prompt_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_PROMPT),
+         completion_token_count=get_attribute_value(db_span.attributes, LLM_TOKEN_COUNT_COMPLETION),
+         trace=db_trace,
+     )
+
+
+ def llm_span_kind() -> Iterator[tuple[str, Any]]:
+     yield OPENINFERENCE_SPAN_KIND, LLM
+
+
+ def llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
+     yield LLM_MODEL_NAME, model_name
+
+
+ def llm_invocation_parameters(
+     invocation_parameters: Mapping[str, Any],
+ ) -> Iterator[tuple[str, Any]]:
+     if invocation_parameters:
+         yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
+
+
+ def llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
+     for tool_index, tool in enumerate(tools):
+         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
+
+
+ def input_value_and_mime_type(
+     input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
+ ) -> Iterator[tuple[str, Any]]:
+     assert (api_key := "api_key") in (input_data := jsonify(input))
+     disallowed_keys = {"api_key", "invocation_parameters"}
+     input_data = {k: v for k, v in input_data.items() if k not in disallowed_keys}
+     assert api_key not in input_data
+     yield INPUT_MIME_TYPE, JSON
+     yield INPUT_VALUE, safe_json_dumps(input_data)
+
+
+ def _merge_tool_call_chunks(
+     chunks_by_id: defaultdict[str, list[ToolCallChunk]],
+ ) -> list[dict[str, Any]]:
+     merged_tool_calls = []
+
+     for tool_id, chunks in chunks_by_id.items():
+         if not chunks:
+             continue
+         first_chunk = chunks[0]
+         if not first_chunk:
+             continue
+
+         if not hasattr(first_chunk, "function") or not hasattr(first_chunk.function, "name"):
+             continue
+         # Combine all argument chunks
+         merged_arguments = "".join(
+             chunk.function.arguments
+             for chunk in chunks
+             if chunk and hasattr(chunk, "function") and hasattr(chunk.function, "arguments")
+         )
+
+         merged_tool_calls.append(
+             {
+                 "id": tool_id,
+                 # Only the first chunk has the tool name
+                 "function": {
+                     "name": first_chunk.function.name,
+                     "arguments": merged_arguments or "{}",
+                 },
+             }
+         )
+
+     return merged_tool_calls
+
+
+ def _output_value_and_mime_type(
+     text_chunks: list[TextChunk],
+     tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+ ) -> Iterator[tuple[str, Any]]:
+     content = "".join(chunk.content for chunk in text_chunks)
+     merged_tool_calls = _merge_tool_call_chunks(tool_call_chunks)
+     if content and merged_tool_calls:
+         yield OUTPUT_MIME_TYPE, JSON
+         yield (
+             OUTPUT_VALUE,
+             safe_json_dumps(
+                 {
+                     "content": content,
+                     "tool_calls": jsonify(
+                         merged_tool_calls,
+                     ),
+                 }
+             ),
+         )
+     elif merged_tool_calls:
+         yield OUTPUT_MIME_TYPE, JSON
+         yield OUTPUT_VALUE, safe_json_dumps(jsonify(merged_tool_calls))
+     elif content:
+         yield OUTPUT_MIME_TYPE, TEXT
+         yield OUTPUT_VALUE, content
+
+
+ def llm_input_messages(
+     messages: Iterable[
+         tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+     ],
+ ) -> Iterator[tuple[str, Any]]:
+     for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
+         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
+         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
+         if tool_calls is not None:
+             for tool_call_index, tool_call in enumerate(tool_calls):
+                 yield (
+                     f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
+                     tool_call["function"]["name"],
+                 )
+                 if arguments := tool_call["function"]["arguments"]:
+                     yield (
+                         f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                         safe_json_dumps(jsonify(arguments)),
+                     )
+
+
+ def _llm_output_messages(
+     text_chunks: list[TextChunk],
+     tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+ ) -> Iterator[tuple[str, Any]]:
+     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
+     if content := "".join(chunk.content for chunk in text_chunks):
+         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
+     for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
+         if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
+             yield (
+                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
+                 name,
+             )
+         if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
+             yield (
+                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                 arguments,
+             )
+
+
+ def _generate_trace_id() -> str:
+     """
+     Generates a random trace ID in hexadecimal format.
+     """
+     return _hex(DefaultOTelIDGenerator().generate_trace_id())
+
+
+ def _generate_span_id() -> str:
+     """
+     Generates a random span ID in hexadecimal format.
+     """
+     return _hex(DefaultOTelIDGenerator().generate_span_id())
+
+
+ def _hex(number: int) -> str:
+     """
+     Converts an integer to a hexadecimal string.
+     """
+     return hex(number)[2:]
+
+
+ def _serialize_event(event: SpanEvent) -> dict[str, Any]:
+     """
+     Serializes a SpanEvent to a dictionary.
+     """
+     return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+
+
+ JSON = OpenInferenceMimeTypeValues.JSON.value
+ TEXT = OpenInferenceMimeTypeValues.TEXT.value
+
+ LLM = OpenInferenceSpanKindValues.LLM.value
+
+ OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
+ INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
+ INPUT_VALUE = SpanAttributes.INPUT_VALUE
+ OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
+ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
+ LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
+ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+ LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
+ LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
+ LLM_TOOLS = SpanAttributes.LLM_TOOLS
+ LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
+ LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+
+ MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
+ MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
+ MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
+
+ TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
+ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
+
+ TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
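
For orientation, a minimal sketch of the intended lifecycle of streaming_llm_span together with get_db_trace/get_db_span. It is not part of the diff: record_chat_completion, chat_input, messages, and stream are hypothetical stand-ins for objects built elsewhere in the server, and the invocation parameter is an example value.

# sketch.py -- hypothetical driver for the new streaming span helpers
from phoenix.server.api.helpers.playground_spans import (
    get_db_span,
    get_db_trace,
    streaming_llm_span,
)


async def record_chat_completion(chat_input, messages, stream, project_id: int):
    async with streaming_llm_span(
        input=chat_input,  # ChatCompletionInput or ChatCompletionOverDatasetInput
        messages=messages,
        invocation_parameters={"temperature": 0.1},  # assumed example value
    ) as span:
        async for chunk in stream:  # yields TextChunk / ToolCallChunk instances
            span.add_response_chunk(chunk)
    # __aexit__ returns True, so a provider exception is swallowed and recorded
    # as a SpanException event with status ERROR instead of propagating.
    db_trace = get_db_trace(span, project_id)
    db_span = get_db_span(span, db_trace)
    return db_trace, db_span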
phoenix/server/api/input_types/ChatCompletionInput.py
@@ -0,0 +1,38 @@
+ from typing import Optional
+
+ import strawberry
+ from strawberry import UNSET
+ from strawberry.relay.types import GlobalID
+ from strawberry.scalars import JSON
+
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
+
+ from .ChatCompletionMessageInput import ChatCompletionMessageInput
+ from .GenerativeModelInput import GenerativeModelInput
+ from .InvocationParameters import InvocationParameterInput
+ from .TemplateOptions import TemplateOptions
+
+
+ @strawberry.input
+ class ChatCompletionInput:
+     messages: list[ChatCompletionMessageInput]
+     model: GenerativeModelInput
+     invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
+     tools: Optional[list[JSON]] = UNSET
+     api_key: Optional[str] = strawberry.field(default=None)
+     template: Optional[TemplateOptions] = UNSET
+
+
+ @strawberry.input
+ class ChatCompletionOverDatasetInput:
+     messages: list[ChatCompletionMessageInput]
+     model: GenerativeModelInput
+     invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
+     tools: Optional[list[JSON]] = UNSET
+     api_key: Optional[str] = strawberry.field(default=None)
+     template_language: TemplateLanguage
+     dataset_id: GlobalID
+     dataset_version_id: Optional[GlobalID] = None
+     experiment_name: Optional[str] = None
+     experiment_description: Optional[str] = None
+     experiment_metadata: Optional[JSON] = strawberry.field(default_factory=dict)
phoenix/server/api/input_types/GenerativeModelInput.py
@@ -0,0 +1,17 @@
+ from typing import Optional
+
+ import strawberry
+ from strawberry import UNSET
+
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
+
+
+ @strawberry.input
+ class GenerativeModelInput:
+     provider_key: GenerativeProviderKey
+     name: str
+     """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
+     endpoint: Optional[str] = UNSET
+     """ The endpoint to use for the model. Only required for Azure OpenAI models. """
+     api_version: Optional[str] = UNSET
+     """ The API version to use for the model. """