arize-phoenix 5.6.0__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (34)
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/RECORD +34 -25
  3. phoenix/config.py +42 -0
  4. phoenix/server/api/helpers/playground_clients.py +671 -0
  5. phoenix/server/api/helpers/playground_registry.py +70 -0
  6. phoenix/server/api/helpers/playground_spans.py +325 -0
  7. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  8. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  9. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  10. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  11. phoenix/server/api/mutations/__init__.py +4 -0
  12. phoenix/server/api/mutations/chat_mutations.py +374 -0
  13. phoenix/server/api/queries.py +41 -52
  14. phoenix/server/api/schema.py +42 -10
  15. phoenix/server/api/subscriptions.py +326 -595
  16. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  17. phoenix/server/api/types/GenerativeProvider.py +27 -3
  18. phoenix/server/api/types/Span.py +37 -0
  19. phoenix/server/api/types/TemplateLanguage.py +9 -0
  20. phoenix/server/app.py +61 -13
  21. phoenix/server/main.py +14 -1
  22. phoenix/server/static/.vite/manifest.json +9 -9
  23. phoenix/server/static/assets/{components-C70HJiXz.js → components-Csu8UKOs.js} +114 -114
  24. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-Bk5C9EA7.js} +1 -1
  25. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-UeWaKXNs.js} +328 -268
  26. phoenix/server/templates/index.html +1 -0
  27. phoenix/services.py +4 -0
  28. phoenix/session/session.py +15 -1
  29. phoenix/utilities/template_formatters.py +11 -1
  30. phoenix/version.py +1 -1
  31. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  32. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  33. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  34. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,70 @@
1
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
2
+
3
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
4
+
5
+ if TYPE_CHECKING:
6
+ from phoenix.server.api.helpers.playground_clients import PlaygroundStreamingClient
7
+
8
# A model name, or ``None`` to denote a provider's default (fallback) client entry.
ModelName = Optional[str]
# One registration key: the provider paired with a model name.
ModelKey = tuple[GenerativeProviderKey, ModelName]

# Model name under which a provider's fallback client is registered; ``get_client``
# uses this entry when a requested model has no client of its own.
PROVIDER_DEFAULT = None
12
+
13
+
14
class SingletonMeta(type):
    """
    Metaclass that turns each class using it into a singleton.

    The first call to the class constructs an instance and caches it; every later
    call returns that cached instance and ignores its arguments. No locking is
    performed around the cache.
    """

    # Cached instances, keyed by the class object they belong to.
    _instances: dict[Any, Any] = dict()

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        created = super().__call__(*args, **kwargs)
        registry[cls] = created
        return created
21
+
22
+
23
class PlaygroundClientRegistry(metaclass=SingletonMeta):
    """
    Registry of playground streaming-client classes, keyed by provider and model.

    Each provider maps model names to a client class. A client registered under
    ``PROVIDER_DEFAULT`` acts as that provider's fallback for any model that has no
    dedicated client.
    """

    def __init__(self) -> None:
        # provider key -> {model name (or PROVIDER_DEFAULT) -> client class}
        self._registry: dict[
            GenerativeProviderKey, dict[ModelName, Optional[type["PlaygroundStreamingClient"]]]
        ] = {}

    def get_client(
        self,
        provider_key: GenerativeProviderKey,
        model_name: ModelName,
    ) -> Optional[type["PlaygroundStreamingClient"]]:
        """
        Return the client class registered for ``model_name`` under ``provider_key``.

        Falls back to the provider's ``PROVIDER_DEFAULT`` entry when the model has no
        client of its own. Returns ``None`` when neither exists, including when the
        provider itself is unknown.
        """
        provider_registry = self._registry.get(provider_key, {})
        client_class = provider_registry.get(model_name)
        if client_class is None:
            # Look the fallback up via the named sentinel (not a bare ``None``
            # literal) so this stays correct if PROVIDER_DEFAULT ever changes.
            client_class = provider_registry.get(PROVIDER_DEFAULT)
        return client_class

    def list_all_providers(
        self,
    ) -> list[GenerativeProviderKey]:
        """Return the provider keys present in the registry, in insertion order."""
        return list(self._registry)

    def list_models(self, provider_key: GenerativeProviderKey) -> list[str]:
        """
        Return the explicitly named models registered for ``provider_key``.

        The ``PROVIDER_DEFAULT`` entry is excluded. An unknown provider yields an
        empty list.
        """
        provider_registry = self._registry.get(provider_key, {})
        return [
            model_name
            for model_name in provider_registry.keys()
            if model_name is not PROVIDER_DEFAULT
        ]

    def list_all_models(self) -> list[ModelKey]:
        """
        Return every registered (provider, model name) pair.

        Unlike ``list_models``, this includes ``PROVIDER_DEFAULT`` entries, which
        appear with a model name of ``None``.
        """
        return [
            (provider_key, model_name)
            for provider_key, provider_registry in self._registry.items()
            for model_name in provider_registry.keys()
        ]
55
+
56
+
57
# Shared registry instance that ``register_llm_client`` populates. Because
# PlaygroundClientRegistry uses SingletonMeta, calling PlaygroundClientRegistry()
# again elsewhere returns this same object.
PLAYGROUND_CLIENT_REGISTRY: PlaygroundClientRegistry = PlaygroundClientRegistry()
58
+
59
+
60
def register_llm_client(
    provider_key: GenerativeProviderKey,
    model_names: list[ModelName],
) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
    """
    Build a class decorator that registers a streaming client in the shared registry.

    Every entry of ``model_names`` (which may include ``PROVIDER_DEFAULT``) is mapped
    to the decorated class under ``provider_key``. Any earlier registration for the
    same (provider, model) pair is overwritten. The class itself is returned
    unchanged.
    """

    def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
        models_for_provider = PLAYGROUND_CLIENT_REGISTRY._registry.setdefault(provider_key, {})
        models_for_provider.update(dict.fromkeys(model_names, cls))
        return cls

    return decorator
@@ -0,0 +1,325 @@
1
+ import json
2
+ from collections import defaultdict
3
+ from collections.abc import Mapping
4
+ from dataclasses import asdict
5
+ from datetime import datetime, timezone
6
+ from itertools import chain
7
+ from traceback import format_exc
8
+ from types import TracebackType
9
+ from typing import (
10
+ Any,
11
+ Iterable,
12
+ Iterator,
13
+ Optional,
14
+ Union,
15
+ cast,
16
+ )
17
+
18
+ from openinference.instrumentation import safe_json_dumps
19
+ from openinference.semconv.trace import (
20
+ MessageAttributes,
21
+ OpenInferenceMimeTypeValues,
22
+ OpenInferenceSpanKindValues,
23
+ SpanAttributes,
24
+ ToolAttributes,
25
+ ToolCallAttributes,
26
+ )
27
+ from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
28
+ from opentelemetry.trace import StatusCode
29
+ from sqlalchemy.ext.asyncio import AsyncSession
30
+ from strawberry.scalars import JSON as JSONScalarType
31
+ from typing_extensions import Self, TypeAlias, assert_never
32
+
33
+ from phoenix.datetime_utils import local_now, normalize_datetime
34
+ from phoenix.db import models
35
+ from phoenix.server.api.input_types.ChatCompletionInput import (
36
+ ChatCompletionInput,
37
+ ChatCompletionOverDatasetInput,
38
+ )
39
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
40
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
41
+ TextChunk,
42
+ ToolCallChunk,
43
+ )
44
+ from phoenix.trace.attributes import unflatten
45
+ from phoenix.trace.schemas import (
46
+ SpanEvent,
47
+ SpanException,
48
+ )
49
+ from phoenix.utilities.json import jsonify
50
+
51
# A chat message as passed to ``streaming_llm_span``:
# (role, content, tool-call id, tool calls).
# Tool calls are JSON objects, not strings: ``_llm_input_messages`` reads them as
# ``tool_call["function"]["name"]`` / ``["arguments"]``, so they are typed as JSON
# here to match that function's signature.
ChatCompletionMessage: TypeAlias = tuple[
    ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]
]
# Identifier used to group streamed tool-call chunks that belong to one call.
ToolCallID: TypeAlias = str
55
+
56
+
57
class streaming_llm_span:
    """
    Creates an LLM span for a streaming chat completion.

    Use as an async context manager:
    - entering records the start time;
    - each streamed chunk is passed to ``add_response_chunk``;
    - exiting records the end time, the status and the output attributes.

    After the context has exited, call ``add_to_session`` to create the
    ``models.Trace`` and ``models.Span`` rows.

    An exception raised inside the ``async with`` body is recorded on the span as
    ERROR status plus an exception event. Ordinary ``Exception`` subclasses are then
    suppressed. Other ``BaseException`` subclasses (``asyncio.CancelledError``,
    ``GeneratorExit``, ``KeyboardInterrupt``, ...) are re-raised, so task
    cancellation and async-generator shutdown are not silently swallowed.
    """

    def __init__(
        self,
        *,
        input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
        messages: list[ChatCompletionMessage],
        invocation_parameters: Mapping[str, Any],
        attributes: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            input: The chat-completion request. It is recorded as the span input
                with its ``api_key`` and ``invocation_parameters`` fields removed.
            messages: Messages recorded as the LLM input messages.
            invocation_parameters: Recorded as a JSON string; skipped when empty.
            attributes: Optional initial attributes. A supplied dict is used
                directly (not copied), so this span mutates it.
        """
        self._input = input
        self._attributes: dict[str, Any] = attributes if attributes is not None else {}
        self._attributes.update(
            chain(
                _llm_span_kind(),
                _llm_model_name(input.model.name),
                _llm_tools(input.tools or []),
                _llm_input_messages(messages),
                _llm_invocation_parameters(invocation_parameters),
                _input_value_and_mime_type(input),
            )
        )
        self._events: list[SpanEvent] = []
        # Assigned in __aenter__.
        self._start_time: datetime
        # Assigned in __aexit__.
        self._end_time: datetime
        self._response_chunks: list[Union[TextChunk, ToolCallChunk]] = []
        self._text_chunks: list[TextChunk] = []
        # Tool-call chunks grouped by tool-call id, in first-seen order.
        self._tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
        # Assigned in __aexit__.
        self._status_code: StatusCode
        self._status_message: str
        # Assigned in add_to_session.
        self._db_span: models.Span
        self._db_trace: models.Trace

    async def __aenter__(self) -> Self:
        """Record the span's start time (UTC) and return the span."""
        self._start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        """
        Record end time, status, any exception event, and output attributes.

        Returns:
            True when an ``Exception`` raised in the body should be suppressed.
            False otherwise: either there was no exception, or it was a
            ``BaseException`` (e.g. cancellation) that must propagate.
        """
        # Local import: the ``traceback`` parameter shadows the module name here.
        from traceback import format_exception

        self._end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
        self._status_code = StatusCode.OK
        self._status_message = ""
        suppress = exc_type is not None and issubclass(exc_type, Exception)
        if exc_type is not None:
            self._status_code = StatusCode.ERROR
            self._status_message = str(exc_value)
            self._events.append(
                SpanException(
                    timestamp=self._end_time,
                    message=self._status_message,
                    exception_type=type(exc_value).__name__,
                    # The exception escapes the span only when it is re-raised.
                    exception_escaped=not suppress,
                    # Build the trace from the exception we were given rather than
                    # relying on sys.exc_info() being set during __aexit__.
                    exception_stacktrace="".join(
                        format_exception(exc_type, exc_value, traceback)
                    ),
                )
            )
        if self._response_chunks:
            self._attributes.update(
                chain(
                    _output_value_and_mime_type(self._response_chunks),
                    _llm_output_messages(self._text_chunks, self._tool_call_chunks),
                )
            )
        return suppress

    def set_attributes(self, attributes: Mapping[str, Any]) -> None:
        """Merge ``attributes`` into the span's attributes, overwriting equal keys."""
        self._attributes.update(attributes)

    def add_to_session(
        self,
        session: AsyncSession,
        project_id: int,
    ) -> models.Span:
        """
        Create trace and span rows for this span and add both to ``session``.

        Call this only after the ``async with`` block has exited: it reads the end
        time and status recorded in ``__aexit__``. Prompt/completion token counts
        are taken from the span attributes and default to 0. A fresh random trace
        ID and span ID are generated.

        Returns:
            The new ``models.Span`` (not yet flushed).
        """
        prompt_tokens = self._attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
        completion_tokens = self._attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
        trace_id = _generate_trace_id()
        span_id = _generate_span_id()
        self._db_trace = models.Trace(
            project_rowid=project_id,
            trace_id=trace_id,
            start_time=self._start_time,
            end_time=self._end_time,
        )
        self._db_span = models.Span(
            # NOTE(review): the trace was just constructed and not flushed, so its
            # ``id`` is presumably still unset here; the ``trace=`` relationship
            # below is what links the rows — confirm.
            trace_rowid=self._db_trace.id,
            span_id=span_id,
            parent_id=None,
            name="ChatCompletion",
            span_kind=LLM,
            start_time=self._start_time,
            end_time=self._end_time,
            attributes=unflatten(self._attributes.items()),
            events=[_serialize_event(event) for event in self._events],
            status_code=self._status_code.name,
            status_message=self._status_message,
            cumulative_error_count=int(self._status_code is StatusCode.ERROR),
            cumulative_llm_token_count_prompt=prompt_tokens,
            cumulative_llm_token_count_completion=completion_tokens,
            llm_token_count_prompt=prompt_tokens,
            llm_token_count_completion=completion_tokens,
            trace=self._db_trace,
        )
        session.add(self._db_trace)
        session.add(self._db_span)
        return self._db_span

    def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
        """Record one streamed response chunk: text, or part of a tool call."""
        self._response_chunks.append(chunk)
        if isinstance(chunk, TextChunk):
            self._text_chunks.append(chunk)
        elif isinstance(chunk, ToolCallChunk):
            self._tool_call_chunks[chunk.id].append(chunk)
        else:
            assert_never(chunk)

    @property
    def start_time(self) -> datetime:
        """
        UTC start time recorded on entry.

        This is the same value later written to the span row, and is available
        before ``add_to_session`` is called.
        """
        return self._start_time

    @property
    def end_time(self) -> datetime:
        """
        UTC end time recorded on exit.

        This is the same value later written to the span row, and is available
        before ``add_to_session`` is called.
        """
        return self._end_time

    @property
    def error_message(self) -> Optional[str]:
        """The recorded exception message when the span errored, else ``None``."""
        return self._status_message if self._status_code is StatusCode.ERROR else None

    @property
    def trace_id(self) -> str:
        """Trace ID generated by ``add_to_session``."""
        return self._db_trace.trace_id

    @property
    def attributes(self) -> dict[str, Any]:
        """Unflattened attributes stored on the span row by ``add_to_session``."""
        return self._db_span.attributes
196
+
197
+
198
def _llm_span_kind() -> Iterator[tuple[str, Any]]:
    """Yield the single attribute that marks this span's kind as LLM."""
    yield from ((OPENINFERENCE_SPAN_KIND, LLM),)
200
+
201
+
202
def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
    """Yield the model-name attribute for ``model_name``."""
    model_attribute = (LLM_MODEL_NAME, model_name)
    yield model_attribute
204
+
205
+
206
def _llm_invocation_parameters(
    invocation_parameters: Mapping[str, Any],
) -> Iterator[tuple[str, Any]]:
    """
    Yield the invocation parameters serialized as one JSON string.

    Nothing is yielded when ``invocation_parameters`` is empty.
    """
    if not invocation_parameters:
        return
    yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
211
+
212
+
213
def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
    """Yield one JSON-schema attribute per tool, keyed by the tool's list position."""
    yield from (
        (f"{LLM_TOOLS}.{position}.{TOOL_JSON_SCHEMA}", json.dumps(schema))
        for position, schema in enumerate(tools)
    )
216
+
217
+
218
def _input_value_and_mime_type(input: Any) -> Iterator[tuple[str, Any]]:
    """
    Yield the span's input MIME type and its JSON-serialized input value.

    The ``api_key`` and ``invocation_parameters`` fields are removed before
    serialization, so the API key is never written to the span.

    Raises:
        AssertionError: (unless running with ``-O``) if the serialized input has
            no ``api_key`` field to redact.
    """
    # Never bind variables inside ``assert``: asserts are stripped under
    # ``python -O``, which would leave ``input_data`` undefined.
    input_data = jsonify(input)
    # Sanity check: presumably guards against the field being renamed, which would
    # make the filter below silently stop redacting the key.
    assert "api_key" in input_data
    disallowed_keys = {"api_key", "invocation_parameters"}
    redacted = {k: v for k, v in input_data.items() if k not in disallowed_keys}
    yield INPUT_MIME_TYPE, JSON
    yield INPUT_VALUE, safe_json_dumps(redacted)
225
+
226
+
227
def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
    """Yield the output MIME type, then the JSON-serialized ``output``."""
    mime_attribute = (OUTPUT_MIME_TYPE, JSON)
    yield mime_attribute
    serialized_output = safe_json_dumps(jsonify(output))
    yield OUTPUT_VALUE, serialized_output
230
+
231
+
232
def _llm_input_messages(
    messages: Iterable[
        tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
    ],
) -> Iterator[tuple[str, Any]]:
    """
    Yield flattened attributes for the input messages.

    For each message this yields its lower-cased role and its content. For each of
    its tool calls it yields the function name and, when non-empty, the
    JSON-encoded arguments. The tool-call id (third tuple element) is not recorded.
    """
    for message_index, message in enumerate(messages):
        role, content, _, tool_calls = message
        message_prefix = f"{LLM_INPUT_MESSAGES}.{message_index}"
        yield f"{message_prefix}.{MESSAGE_ROLE}", role.value.lower()
        yield f"{message_prefix}.{MESSAGE_CONTENT}", content
        if tool_calls is None:
            continue
        for call_index, tool_call in enumerate(tool_calls):
            function = tool_call["function"]
            call_prefix = f"{message_prefix}.{MESSAGE_TOOL_CALLS}.{call_index}"
            yield f"{call_prefix}.{TOOL_CALL_FUNCTION_NAME}", function["name"]
            arguments = function["arguments"]
            if arguments:
                yield (
                    f"{call_prefix}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                    safe_json_dumps(jsonify(arguments)),
                )
251
+
252
+
253
def _llm_output_messages(
    text_chunks: list[TextChunk],
    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
) -> Iterator[tuple[str, Any]]:
    """
    Yield flattened attributes for the single assistant output message.

    - Text chunks are concatenated into the message content, which is omitted when
      empty.
    - Tool calls are numbered in the mapping's iteration order.
    - Each tool call yields its function name (from its first chunk, when
      non-empty) and its concatenated argument fragments (when non-empty).
    """
    message_prefix = f"{LLM_OUTPUT_MESSAGES}.0"
    yield f"{message_prefix}.{MESSAGE_ROLE}", "assistant"
    content = "".join(chunk.content for chunk in text_chunks)
    if content:
        yield f"{message_prefix}.{MESSAGE_CONTENT}", content
    for call_index, chunks in enumerate(tool_call_chunks.values()):
        call_prefix = f"{message_prefix}.{MESSAGE_TOOL_CALLS}.{call_index}"
        if chunks:
            function_name = chunks[0].function.name
            if function_name:
                yield f"{call_prefix}.{TOOL_CALL_FUNCTION_NAME}", function_name
        arguments = "".join(chunk.function.arguments for chunk in chunks)
        if arguments:
            yield f"{call_prefix}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", arguments
271
+
272
+
273
def _generate_trace_id() -> str:
    """
    Generate a random trace ID as a 32-digit lowercase hexadecimal string.

    The ID is zero-padded to the full 128-bit width, matching the fixed-length
    W3C / OpenTelemetry trace-ID form. Without padding, IDs with leading zero
    nibbles would come out shorter.
    """
    return _hex(DefaultOTelIDGenerator().generate_trace_id(), width=32)


def _generate_span_id() -> str:
    """
    Generate a random span ID as a 16-digit (64-bit) lowercase hexadecimal string.

    The ID is zero-padded to the full width, like ``_generate_trace_id``.
    """
    return _hex(DefaultOTelIDGenerator().generate_span_id(), width=16)


def _hex(number: int, width: int = 0) -> str:
    """
    Convert a non-negative integer to a lowercase hexadecimal string without ``0x``.

    Args:
        number: The integer to convert.
        width: Minimum number of digits. Shorter results are left-padded with
            zeros; the default of 0 applies no padding.
    """
    return format(number, "x").zfill(width)
292
+
293
+
294
def _serialize_event(event: SpanEvent) -> dict[str, Any]:
    """
    Convert a span-event dataclass into a plain dict.

    Top-level ``datetime`` values are rendered with ``isoformat()``. All other
    values (as produced by ``dataclasses.asdict``) pass through unchanged.
    """
    serialized: dict[str, Any] = {}
    for field_name, field_value in asdict(event).items():
        if isinstance(field_value, datetime):
            field_value = field_value.isoformat()
        serialized[field_name] = field_value
    return serialized
299
+
300
+
301
# Attribute values.
JSON = OpenInferenceMimeTypeValues.JSON.value  # MIME type recorded for input/output
LLM = OpenInferenceSpanKindValues.LLM.value  # span kind recorded for these spans

# Span-level attribute keys.
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
INPUT_VALUE = SpanAttributes.INPUT_VALUE
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION

# Keys nested under an individual message.
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

# Keys nested under an individual tool call.
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON

# Key nested under an individual tool definition.
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
@@ -0,0 +1,38 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry import UNSET
5
+ from strawberry.relay.types import GlobalID
6
+ from strawberry.scalars import JSON
7
+
8
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
9
+
10
+ from .ChatCompletionMessageInput import ChatCompletionMessageInput
11
+ from .GenerativeModelInput import GenerativeModelInput
12
+ from .InvocationParameters import InvocationParameterInput
13
+ from .TemplateOptions import TemplateOptions
14
+
15
+
16
# GraphQL input for a single playground chat completion.
@strawberry.input
class ChatCompletionInput:
    # Chat messages for the completion.
    messages: list[ChatCompletionMessageInput]
    # Provider and model selection.
    model: GenerativeModelInput
    # Invocation parameters; an empty list when omitted.
    invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
    # Tool definitions as raw JSON (recorded as tool JSON schemas on the span);
    # UNSET when omitted.
    tools: Optional[list[JSON]] = UNSET
    # Provider API key; None when omitted. The span helpers strip this field
    # before recording the request as span input.
    api_key: Optional[str] = strawberry.field(default=None)
    # Template variables and language; UNSET when omitted.
    template: Optional[TemplateOptions] = UNSET
+
25
+
26
# GraphQL input for running chat completions against a dataset (presumably one
# completion per dataset example, recorded as an experiment — confirm against the
# resolver).
@strawberry.input
class ChatCompletionOverDatasetInput:
    # Chat messages (templates) for the completion.
    messages: list[ChatCompletionMessageInput]
    # Provider and model selection.
    model: GenerativeModelInput
    # Invocation parameters; an empty list when omitted.
    invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
    # Tool definitions as raw JSON; UNSET when omitted.
    tools: Optional[list[JSON]] = UNSET
    # Provider API key; None when omitted. Stripped from recorded span input.
    api_key: Optional[str] = strawberry.field(default=None)
    # Template language used for the messages (required here, unlike
    # ChatCompletionInput where templating is optional).
    template_language: TemplateLanguage
    # Global ID of the dataset to run over.
    dataset_id: GlobalID
    # Specific dataset version; None when omitted (presumably the latest version —
    # TODO confirm).
    dataset_version_id: Optional[GlobalID] = None
    # Optional experiment name, description, and metadata (metadata defaults to {}).
    experiment_name: Optional[str] = None
    experiment_description: Optional[str] = None
    experiment_metadata: Optional[JSON] = strawberry.field(default_factory=dict)
@@ -0,0 +1,17 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry import UNSET
5
+
6
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
7
+
8
+
9
# GraphQL input identifying the generative provider and model to use.
@strawberry.input
class GenerativeModelInput:
    # Which provider (registry key) the model belongs to.
    provider_key: GenerativeProviderKey
    name: str
    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
    endpoint: Optional[str] = UNSET
    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
    api_version: Optional[str] = UNSET
    """ The API version to use for the model. """
@@ -1,20 +1,163 @@
1
- from typing import Optional
1
+ from enum import Enum
2
+ from typing import Annotated, Any, Mapping, Optional, Union
2
3
 
3
4
  import strawberry
4
5
  from strawberry import UNSET
5
6
  from strawberry.scalars import JSON
6
7
 
7
8
 
9
# Provider-independent names for common invocation parameters. Parameter
# definitions and inputs can carry one of these in their ``canonical_name`` field
# alongside the provider-specific ``invocation_name``.
@strawberry.enum
class CanonicalParameterName(str, Enum):
    TEMPERATURE = "temperature"
    MAX_COMPLETION_TOKENS = "max_completion_tokens"
    STOP_SEQUENCES = "stop_sequences"
    TOP_P = "top_p"
    RANDOM_SEED = "random_seed"
    TOOL_CHOICE = "tool_choice"
    RESPONSE_FORMAT = "response_format"
18
+
19
+
20
# Names of the value fields on InvocationParameterInput. Each parameter definition
# type records, via ``invocation_input_field``, which of these it reads.
# NOTE(review): both ``value_bool`` and ``value_boolean`` exist. The boolean
# parameter type and extract_parameter use ``value_bool``; ``value_boolean``
# appears unused in this module.
@strawberry.enum
class InvocationInputField(str, Enum):
    value_int = "value_int"
    value_float = "value_float"
    value_bool = "value_bool"
    value_string = "value_string"
    value_json = "value_json"
    value_string_list = "value_string_list"
    value_boolean = "value_boolean"
29
+
30
+
8
31
  @strawberry.input
9
- class InvocationParameters:
10
- """
11
- Invocation parameters interface shared between different providers.
12
- """
13
-
14
- temperature: Optional[float] = UNSET
15
- max_completion_tokens: Optional[int] = UNSET
16
- max_tokens: Optional[int] = UNSET
17
- top_p: Optional[float] = UNSET
18
- stop: Optional[list[str]] = UNSET
19
- seed: Optional[int] = UNSET
20
- tool_choice: Optional[JSON] = UNSET
32
class InvocationParameterInput:
    # A single invocation-parameter value supplied by the client.
    # Key under which the value is passed.
    invocation_name: str
    # Optional provider-independent name for this parameter.
    canonical_name: Optional[CanonicalParameterName] = None
    # Typed value slots. extract_parameter reads only the slot matching the
    # parameter definition's type and falls back to the definition's default when
    # that slot is UNSET.
    value_int: Optional[int] = UNSET
    value_float: Optional[float] = UNSET
    value_bool: Optional[bool] = UNSET
    value_string: Optional[str] = UNSET
    value_json: Optional[JSON] = UNSET
    value_string_list: Optional[list[str]] = UNSET
    # NOTE(review): not read by extract_parameter (booleans use value_bool).
    value_boolean: Optional[bool] = UNSET
42
+
43
+
44
# GraphQL interface shared by every invocation-parameter definition type.
@strawberry.interface
class InvocationParameterBase:
    # Key that validate_invocation_parameters looks up in the supplied values.
    invocation_name: str
    # Optional provider-independent name for this parameter.
    canonical_name: Optional[CanonicalParameterName] = None
    # Human-readable label (presumably shown in the UI).
    label: str
    # When True, validate_invocation_parameters raises if the value is missing.
    required: bool = False
    # Hidden flag; its consumer is not visible here (presumably UI visibility).
    hidden: bool = False
51
+
52
+
53
# Integer-valued parameter definition; its value is read from ``value_int``.
@strawberry.type
class IntInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_int
    # Returned by extract_parameter when the input's value_int is UNSET.
    default_value: Optional[int] = UNSET
57
+
58
+
59
# Float-valued parameter definition; its value is read from ``value_float``.
@strawberry.type
class FloatInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_float
    # Returned by extract_parameter when the input's value_float is UNSET.
    default_value: Optional[float] = UNSET
63
+
64
+
65
# Float-valued parameter definition with declared bounds; its value is read from
# ``value_float``.
@strawberry.type
class BoundedFloatInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_float
    # Returned by extract_parameter when the input's value_float is UNSET.
    default_value: Optional[float] = UNSET
    # NOTE(review): extract_parameter does not enforce these bounds; they are only
    # declared here (presumably for client-side validation — confirm).
    min_value: float
    max_value: float
71
+
72
+
73
# String-valued parameter definition; its value is read from ``value_string``.
@strawberry.type
class StringInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_string
    # Returned by extract_parameter when the input's value_string is UNSET.
    default_value: Optional[str] = UNSET
77
+
78
+
79
# JSON-valued parameter definition; its value is read from ``value_json``.
@strawberry.type
class JSONInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_json
    # Returned by extract_parameter when the input's value_json is UNSET.
    default_value: Optional[JSON] = UNSET
83
+
84
+
85
# String-list parameter definition; its value is read from ``value_string_list``.
@strawberry.type
class StringListInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_string_list
    # Returned by extract_parameter when the input's value_string_list is UNSET.
    default_value: Optional[list[str]] = UNSET
89
+
90
+
91
# Boolean parameter definition; its value is read from ``value_bool`` (not
# ``value_boolean``).
@strawberry.type
class BooleanInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_bool
    # Returned by extract_parameter when the input's value_bool is UNSET.
    default_value: Optional[bool] = UNSET
95
+
96
+
97
def extract_parameter(
    param_def: InvocationParameterBase, param_input: InvocationParameterInput
) -> Any:
    """
    Resolve the value of one invocation parameter.

    The value is read from the ``param_input`` slot that matches the type of
    ``param_def`` (e.g. ``value_int`` for ``IntInvocationParameter``):
    - if that slot is ``UNSET``, the definition's ``default_value`` is returned;
    - an explicit ``None`` from the client is returned as-is;
    - ``None`` is returned when ``param_def`` is not one of the known types.
    """
    # (definition type, input slot) pairs, checked in this order.
    slot_by_type = (
        (IntInvocationParameter, "value_int"),
        (FloatInvocationParameter, "value_float"),
        (BoundedFloatInvocationParameter, "value_float"),
        (StringInvocationParameter, "value_string"),
        (JSONInvocationParameter, "value_json"),
        (StringListInvocationParameter, "value_string_list"),
        (BooleanInvocationParameter, "value_bool"),
    )
    for definition_type, slot_name in slot_by_type:
        if not isinstance(param_def, definition_type):
            continue
        provided = getattr(param_input, slot_name)
        if provided is UNSET:
            return param_def.default_value
        return provided
    return None
140
+
141
+
142
def validate_invocation_parameters(
    parameters: list["InvocationParameter"],
    input: Mapping[str, Any],
) -> None:
    """
    Check that every required parameter definition has an entry in ``input``.

    Args:
        parameters: Parameter definitions to check.
        input: Supplied values keyed by invocation name.

    Raises:
        ValueError: For the first required definition (in ``parameters`` order)
            whose ``invocation_name`` is missing from ``input``.
    """
    required_defs = (param_def for param_def in parameters if param_def.required)
    for param_def in required_defs:
        if param_def.invocation_name not in input:
            raise ValueError(f"Required parameter {param_def.invocation_name} not provided")
149
+
150
+
151
# Union of every concrete invocation-parameter definition (output) type,
# registered with strawberry under the GraphQL union name "InvocationParameter".
# validate_invocation_parameters refers to it as a forward reference.
InvocationParameter = Annotated[
    Union[
        IntInvocationParameter,
        FloatInvocationParameter,
        BoundedFloatInvocationParameter,
        StringInvocationParameter,
        JSONInvocationParameter,
        StringListInvocationParameter,
        BooleanInvocationParameter,
    ],
    strawberry.union("InvocationParameter"),
]
@@ -0,0 +1,10 @@
1
+ import strawberry
2
+ from strawberry.scalars import JSON
3
+
4
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
5
+
6
+
7
# GraphQL input carrying template settings for a chat completion.
@strawberry.input
class TemplateOptions:
    # Template variable values as arbitrary JSON (presumably a name -> value
    # mapping — confirm against the formatter that consumes it).
    variables: JSON
    # Template language the messages are written in.
    language: TemplateLanguage
@@ -1,6 +1,9 @@
1
1
  import strawberry
2
2
 
3
3
  from phoenix.server.api.mutations.api_key_mutations import ApiKeyMutationMixin
4
+ from phoenix.server.api.mutations.chat_mutations import (
5
+ ChatCompletionMutationMixin,
6
+ )
4
7
  from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
5
8
  from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
6
9
  from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
@@ -20,5 +23,6 @@ class Mutation(
20
23
  SpanAnnotationMutationMixin,
21
24
  TraceAnnotationMutationMixin,
22
25
  UserMutationMixin,
26
+ ChatCompletionMutationMixin,
23
27
  ):
24
28
  pass