arize-phoenix 5.6.0__py3-none-any.whl → 5.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See the package registry page for more details.

Files changed (39):
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/METADATA +4 -6
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/RECORD +39 -30
  3. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +58 -0
  5. phoenix/server/api/helpers/playground_clients.py +758 -0
  6. phoenix/server/api/helpers/playground_registry.py +70 -0
  7. phoenix/server/api/helpers/playground_spans.py +422 -0
  8. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  9. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  10. phoenix/server/api/input_types/InvocationParameters.py +155 -13
  11. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  12. phoenix/server/api/mutations/__init__.py +4 -0
  13. phoenix/server/api/mutations/chat_mutations.py +355 -0
  14. phoenix/server/api/queries.py +41 -52
  15. phoenix/server/api/schema.py +42 -10
  16. phoenix/server/api/subscriptions.py +378 -595
  17. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  18. phoenix/server/api/types/GenerativeProvider.py +27 -3
  19. phoenix/server/api/types/Span.py +37 -0
  20. phoenix/server/api/types/TemplateLanguage.py +9 -0
  21. phoenix/server/app.py +75 -13
  22. phoenix/server/grpc_server.py +3 -1
  23. phoenix/server/main.py +14 -1
  24. phoenix/server/static/.vite/manifest.json +31 -31
  25. phoenix/server/static/assets/{components-C70HJiXz.js → components-MllbfxfJ.js} +168 -150
  26. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-BVO2YcT1.js} +2 -2
  27. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-BHfC6jnL.js} +464 -310
  28. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-BEuNhfwH.js} +1 -1
  29. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-Bskhzyjm.js} +1 -1
  30. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-DLlXCf0x.js} +1 -1
  31. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-CRqhvLYg.js} +1 -1
  32. phoenix/server/templates/index.html +1 -0
  33. phoenix/services.py +4 -0
  34. phoenix/session/session.py +15 -1
  35. phoenix/utilities/template_formatters.py +11 -1
  36. phoenix/version.py +1 -1
  37. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/entry_points.txt +0 -0
  38. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/IP_NOTICE +0 -0
  39. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,20 +1,162 @@
1
- from typing import Optional
1
+ from enum import Enum
2
+ from typing import Annotated, Any, Mapping, Optional, Union
2
3
 
3
4
  import strawberry
4
5
  from strawberry import UNSET
5
6
  from strawberry.scalars import JSON
6
7
 
7
8
 
9
@strawberry.enum
class CanonicalParameterName(str, Enum):
    """
    Canonical names for common invocation parameters.

    A parameter definition may carry one of these in ``canonical_name`` next to its
    provider-specific ``invocation_name`` (NOTE(review): presumably so that equivalent
    parameters can be recognized across providers).
    """

    TEMPERATURE = "temperature"
    MAX_COMPLETION_TOKENS = "max_completion_tokens"
    STOP_SEQUENCES = "stop_sequences"
    TOP_P = "top_p"
    RANDOM_SEED = "random_seed"
    TOOL_CHOICE = "tool_choice"
    RESPONSE_FORMAT = "response_format"


@strawberry.enum
class InvocationInputField(str, Enum):
    """
    Names of the value fields on ``InvocationParameterInput``.

    Each concrete invocation-parameter type records in ``invocation_input_field`` which of
    these fields carries its value.
    """

    value_int = "value_int"
    value_float = "value_float"
    value_bool = "value_bool"
    value_string = "value_string"
    value_json = "value_json"
    value_string_list = "value_string_list"
    # NOTE(review): overlaps with `value_bool`; no parameter type in this module defaults
    # to this field. Confirm which one clients actually use.
    value_boolean = "value_boolean"
29
+
30
+
8
31
@strawberry.input
class InvocationParameterInput:
    """
    A single invocation-parameter value supplied by a client.

    Exactly which ``value_*`` field is read depends on the type of the matching parameter
    definition. Fields left at ``UNSET`` were not provided by the client.
    """

    # Name of the parameter as passed to the provider.
    invocation_name: str
    # Optional canonical alias of the parameter.
    canonical_name: Optional[CanonicalParameterName] = None
    value_int: Optional[int] = UNSET
    value_float: Optional[float] = UNSET
    value_bool: Optional[bool] = UNSET
    value_string: Optional[str] = UNSET
    value_json: Optional[JSON] = UNSET
    value_string_list: Optional[list[str]] = UNSET
    # NOTE(review): overlaps with `value_bool`; confirm which one clients send.
    value_boolean: Optional[bool] = UNSET
42
+
43
+
44
@strawberry.interface
class InvocationParameterBase:
    """
    Fields shared by every invocation-parameter definition exposed through GraphQL.
    """

    # Parameter name as it appears in the invocation-parameters mapping; this is the key
    # that `validate_invocation_parameters` looks for.
    invocation_name: str
    # Optional canonical alias for the parameter.
    canonical_name: Optional[CanonicalParameterName] = None
    # Display label. NOTE(review): presumably shown in the playground UI.
    label: str
    # When True, `validate_invocation_parameters` raises if the parameter is missing.
    required: bool = False


@strawberry.type
class IntInvocationParameter(InvocationParameterBase):
    """Integer-valued parameter definition."""

    # The InvocationParameterInput field that carries this parameter's value.
    invocation_input_field: InvocationInputField = InvocationInputField.value_int
    # Value used when the client does not supply one.
    default_value: Optional[int] = None


@strawberry.type
class FloatInvocationParameter(InvocationParameterBase):
    """Float-valued parameter definition without bounds."""

    invocation_input_field: InvocationInputField = InvocationInputField.value_float
    default_value: Optional[float] = None


@strawberry.type
class BoundedFloatInvocationParameter(InvocationParameterBase):
    """Float-valued parameter definition that advertises a minimum and maximum."""

    invocation_input_field: InvocationInputField = InvocationInputField.value_float
    default_value: Optional[float] = None
    # Advertised range. NOTE(review): `extract_parameter` in this module does not
    # enforce these bounds.
    min_value: float
    max_value: float


@strawberry.type
class StringInvocationParameter(InvocationParameterBase):
    """String-valued parameter definition."""

    invocation_input_field: InvocationInputField = InvocationInputField.value_string
    default_value: Optional[str] = None


@strawberry.type
class JSONInvocationParameter(InvocationParameterBase):
    """Parameter definition whose value is arbitrary JSON."""

    invocation_input_field: InvocationInputField = InvocationInputField.value_json
    default_value: Optional[JSON] = None


@strawberry.type
class StringListInvocationParameter(InvocationParameterBase):
    """Parameter definition whose value is a list of strings."""

    invocation_input_field: InvocationInputField = InvocationInputField.value_string_list
    default_value: Optional[list[str]] = None


@strawberry.type
class BooleanInvocationParameter(InvocationParameterBase):
    """Boolean-valued parameter definition."""

    invocation_input_field: InvocationInputField = InvocationInputField.value_bool
    default_value: Optional[bool] = None
94
+
95
+
96
def extract_parameter(
    param_def: InvocationParameterBase, param_input: InvocationParameterInput
) -> Any:
    """
    Return the value for ``param_def`` taken from ``param_input``.

    The ``value_*`` field read from ``param_input`` depends on the concrete type of
    ``param_def``. If that field is ``UNSET`` (the client did not send it), the
    definition's ``default_value`` is returned instead. An explicit ``None`` sent by the
    client is returned as-is.

    Args:
        param_def: The parameter definition (one of the concrete parameter types).
        param_input: The client-supplied value container.

    Returns:
        The extracted value or the definition's default. Returns None for a definition
        type this function does not recognize.
    """
    if isinstance(param_def, IntInvocationParameter):
        value = param_input.value_int
    elif isinstance(param_def, (FloatInvocationParameter, BoundedFloatInvocationParameter)):
        # Bounded floats share the float input field; their bounds are not enforced here.
        value = param_input.value_float
    elif isinstance(param_def, StringInvocationParameter):
        value = param_input.value_string
    elif isinstance(param_def, JSONInvocationParameter):
        value = param_input.value_json
    elif isinstance(param_def, StringListInvocationParameter):
        value = param_input.value_string_list
    elif isinstance(param_def, BooleanInvocationParameter):
        value = param_input.value_bool
        if value is UNSET:
            # The input type also exposes `value_boolean`; honor it when `value_bool`
            # was not sent so that such values are not silently replaced by the default.
            value = param_input.value_boolean
    else:
        # Unrecognized definition type: there is no value field or default to consult.
        return None
    return value if value is not UNSET else param_def.default_value
139
+
140
+
141
def validate_invocation_parameters(
    parameters: list["InvocationParameter"],
    input: Mapping[str, Any],
) -> None:
    """
    Ensure every required parameter definition has an entry in ``input``.

    Args:
        parameters: Parameter definitions, checked in order.
        input: Mapping from invocation name to the supplied value.

    Raises:
        ValueError: For the first required definition (in ``parameters`` order) whose
            ``invocation_name`` is not a key of ``input``.
    """
    required_names = (
        definition.invocation_name for definition in parameters if definition.required
    )
    for name in required_names:
        if name not in input:
            raise ValueError(f"Required parameter {name} not provided")
148
+
149
+
150
# GraphQL union ("InvocationParameter") of every concrete invocation-parameter definition
# type. It is the element type returned by `Query.model_invocation_parameters` and the
# element type accepted by `validate_invocation_parameters`.
InvocationParameter = Annotated[
    Union[
        IntInvocationParameter,
        FloatInvocationParameter,
        BoundedFloatInvocationParameter,
        StringInvocationParameter,
        JSONInvocationParameter,
        StringListInvocationParameter,
        BooleanInvocationParameter,
    ],
    strawberry.union("InvocationParameter"),
]
@@ -0,0 +1,10 @@
1
+ import strawberry
2
+ from strawberry.scalars import JSON
3
+
4
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
5
+
6
+
7
@strawberry.input
class TemplateOptions:
    """
    Options for rendering message templates before they are sent to the model.
    """

    # Template variables. NOTE(review): the chat mutation expands this with `**`, so it is
    # presumably expected to be a JSON object with string keys.
    variables: JSON
    # Template syntax to apply (e.g. mustache or f-string).
    language: TemplateLanguage
@@ -1,6 +1,9 @@
1
1
  import strawberry
2
2
 
3
3
  from phoenix.server.api.mutations.api_key_mutations import ApiKeyMutationMixin
4
+ from phoenix.server.api.mutations.chat_mutations import (
5
+ ChatCompletionMutationMixin,
6
+ )
4
7
  from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
5
8
  from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
6
9
  from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
@@ -20,5 +23,6 @@ class Mutation(
20
23
  SpanAnnotationMutationMixin,
21
24
  TraceAnnotationMutationMixin,
22
25
  UserMutationMixin,
26
+ ChatCompletionMutationMixin,
23
27
  ):
24
28
  pass
@@ -0,0 +1,355 @@
1
+ import json
2
+ from dataclasses import asdict
3
+ from datetime import datetime, timezone
4
+ from itertools import chain
5
+ from traceback import format_exc
6
+ from typing import Any, Iterable, Iterator, List, Optional
7
+
8
+ import strawberry
9
+ from openinference.instrumentation import safe_json_dumps
10
+ from openinference.semconv.trace import (
11
+ MessageAttributes,
12
+ OpenInferenceMimeTypeValues,
13
+ OpenInferenceSpanKindValues,
14
+ SpanAttributes,
15
+ ToolAttributes,
16
+ ToolCallAttributes,
17
+ )
18
+ from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
19
+ from opentelemetry.trace import StatusCode
20
+ from sqlalchemy import insert, select
21
+ from strawberry.types import Info
22
+ from typing_extensions import assert_never
23
+
24
+ from phoenix.datetime_utils import local_now, normalize_datetime
25
+ from phoenix.db import models
26
+ from phoenix.server.api.context import Context
27
+ from phoenix.server.api.exceptions import BadRequest
28
+ from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
29
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
30
+ from phoenix.server.api.helpers.playground_spans import (
31
+ input_value_and_mime_type,
32
+ llm_input_messages,
33
+ llm_invocation_parameters,
34
+ llm_model_name,
35
+ llm_span_kind,
36
+ llm_tools,
37
+ )
38
+ from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
39
+ from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
40
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
41
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
42
+ TextChunk,
43
+ ToolCallChunk,
44
+ )
45
+ from phoenix.server.api.types.Span import Span, to_gql_span
46
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
47
+ from phoenix.server.dml_event import SpanInsertEvent
48
+ from phoenix.trace.attributes import unflatten
49
+ from phoenix.trace.schemas import SpanException
50
+ from phoenix.utilities.json import jsonify
51
+ from phoenix.utilities.template_formatters import (
52
+ FStringTemplateFormatter,
53
+ MustacheTemplateFormatter,
54
+ TemplateFormatter,
55
+ )
56
+
57
# Runs at import time. NOTE(review): presumably registers the playground LLM clients so
# that PLAYGROUND_CLIENT_REGISTRY lookups in this module can find them.
initialize_playground_clients()

# A chat message as handed to the LLM client: (role, content, tool_call_id, tool_calls).
ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
60
+
61
+
62
@strawberry.type
class ChatCompletionFunctionCall:
    """Function invoked by a tool call."""

    name: str
    # Argument string as streamed by the provider, concatenated across chunks.
    # NOTE(review): presumably JSON-encoded.
    arguments: str


@strawberry.type
class ChatCompletionToolCall:
    """A tool call emitted by the model, identified by the provider's tool-call id."""

    id: str
    function: ChatCompletionFunctionCall


@strawberry.type
class ChatCompletionMutationPayload:
    """Result of the `chat_completion` mutation."""

    # Aggregated text of the response; None on error or when no text was produced.
    content: Optional[str]
    # Aggregated tool calls; empty on error.
    tool_calls: List[ChatCompletionToolCall]
    # The span recorded for this completion.
    span: Span
    # Error message when the provider call failed, otherwise None.
    error_message: Optional[str]
80
+
81
+
82
@strawberry.type
class ChatCompletionMutationMixin:
    """GraphQL mutations for running prompt-playground chat completions."""

    @strawberry.mutation
    async def chat_completion(
        self, info: Info[Context, None], input: ChatCompletionInput
    ) -> ChatCompletionMutationPayload:
        """
        Run a chat completion and record it as an LLM span in the playground project.

        The provider response is streamed and aggregated into text content and tool
        calls. Provider failures do not raise. Instead they are recorded on the span
        (ERROR status plus an exception event) and returned in ``error_message``.

        Args:
            info: Resolver info; supplies the DB session factory and the event queue.
            input: Model, messages, optional template options, tools and invocation
                parameters.

        Returns:
            The aggregated completion together with the persisted span.

        Raises:
            BadRequest: If no client is registered for the requested provider and model.
        """
        provider_key = input.model.provider_key
        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
        if llm_client_class is None:
            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
        llm_client = llm_client_class(
            model=input.model,
            api_key=input.api_key,
        )

        # Normalize each message into (role, content, tool_call_id, tool_calls). Values that
        # are not of the expected type (e.g. UNSET) become None.
        messages = [
            (
                message.role,
                message.content,
                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
                message.tool_calls if isinstance(message.tool_calls, list) else None,
            )
            for message in input.messages
        ]
        if template_options := input.template:
            messages = list(_formatted_messages(messages, template_options))

        invocation_parameters = llm_client.construct_invocation_parameters(
            input.invocation_parameters
        )

        attributes: dict[str, Any] = {}
        attributes.update(
            chain(
                llm_span_kind(),
                llm_model_name(input.model.name),
                llm_tools(input.tools or []),
                llm_input_messages(messages),
                llm_invocation_parameters(invocation_parameters),
                input_value_and_mime_type(input),
            )
        )

        text_content = ""
        tool_calls: dict[str, ChatCompletionToolCall] = {}
        events: list[SpanException] = []
        start_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
        status_code = StatusCode.OK
        status_message = ""
        try:
            async for chunk in llm_client.chat_completion_create(
                messages=messages, tools=input.tools or [], **invocation_parameters
            ):
                if isinstance(chunk, TextChunk):
                    text_content += chunk.content
                elif isinstance(chunk, ToolCallChunk):
                    # Tool-call arguments arrive in fragments keyed by tool-call id. The
                    # first fragment creates the entry; later ones are appended to it.
                    if chunk.id not in tool_calls:
                        tool_calls[chunk.id] = ChatCompletionToolCall(
                            id=chunk.id,
                            function=ChatCompletionFunctionCall(
                                name=chunk.function.name,
                                arguments=chunk.function.arguments,
                            ),
                        )
                    else:
                        tool_calls[chunk.id].function.arguments += chunk.function.arguments
                else:
                    assert_never(chunk)
        except Exception as e:
            # Deliberate catch-all: any provider failure is recorded on the span and
            # reported in the payload instead of failing the whole mutation.
            status_code = StatusCode.ERROR
            status_message = str(e)
            end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
            assert end_time is not None
            events.append(
                SpanException(
                    timestamp=end_time,
                    message=status_message,
                    exception_type=type(e).__name__,
                    exception_escaped=False,
                    exception_stacktrace=format_exc(),
                )
            )
        else:
            end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)

        # Merge client-specific attributes only after the stream has been consumed, so
        # anything the client recorded while streaming is captured. NOTE(review): this
        # presumably includes token counts. `dict.update` is used because
        # `itertools.chain` accepts no keyword arguments, so `chain(**mapping)` raises
        # TypeError.
        attributes.update(llm_client.attributes)

        if text_content or tool_calls:
            attributes.update(
                chain(
                    _output_value_and_mime_type(text_content, tool_calls),
                    _llm_output_messages(text_content, tool_calls),
                )
            )

        trace_id = _generate_trace_id()
        span_id = _generate_span_id()
        async with info.context.db() as session:
            # Get or create the project that playground spans are written to.
            if (
                project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            trace = models.Trace(
                project_rowid=project_id,
                trace_id=trace_id,
                start_time=start_time,
                end_time=end_time,
            )
            span = models.Span(
                # `trace.id` is still None here; the `trace=trace` relationship below is
                # what links the rows when the session flushes.
                trace_rowid=trace.id,
                span_id=span_id,
                parent_id=None,
                name="ChatCompletion",
                span_kind=LLM,
                start_time=start_time,
                end_time=end_time,
                attributes=unflatten(attributes.items()),
                events=[_serialize_event(event) for event in events],
                status_code=status_code.name,
                status_message=status_message,
                cumulative_error_count=int(status_code is StatusCode.ERROR),
                cumulative_llm_token_count_prompt=attributes.get(LLM_TOKEN_COUNT_PROMPT, 0),
                cumulative_llm_token_count_completion=attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0),
                llm_token_count_prompt=attributes.get(LLM_TOKEN_COUNT_PROMPT, 0),
                llm_token_count_completion=attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0),
                trace=trace,
            )
            session.add(trace)
            session.add(span)
            await session.flush()
            # Convert while the session is still open so the span's attributes are loaded.
            gql_span = to_gql_span(span)

        # Notify listeners only once the session block has exited, i.e. after the write.
        info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))

        if status_code is StatusCode.ERROR:
            return ChatCompletionMutationPayload(
                content=None,
                tool_calls=[],
                span=gql_span,
                error_message=status_message,
            )
        return ChatCompletionMutationPayload(
            content=text_content or None,
            tool_calls=list(tool_calls.values()),
            span=gql_span,
            error_message=None,
        )
243
+
244
+
245
def _formatted_messages(
    messages: Iterable[ChatCompletionMessage],
    template_options: TemplateOptions,
) -> Iterator[ChatCompletionMessage]:
    """
    Render each message's content as a template using ``template_options``.

    Role, tool-call id and tool calls are passed through unchanged. An empty
    ``messages`` yields nothing. (Unpacking ``zip(*messages)`` into four names would
    raise ValueError for an empty input.)

    Args:
        messages: (role, content, tool_call_id, tool_calls) tuples.
        template_options: Template language and the variables to substitute.

    Returns:
        A lazy iterator of formatted message tuples.
    """
    # The formatter is created eagerly so an unsupported language fails at call time.
    template_formatter = _template_formatter(template_language=template_options.language)
    variables = template_options.variables
    return (
        (role, template_formatter.format(template, **variables), tool_call_id, tool_calls)
        for role, template, tool_call_id, tool_calls in messages
    )
265
+
266
+
267
def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
    """
    Return a new formatter for the given template language.

    Any language not handled below is a programming error, surfaced by ``assert_never``.
    """
    if template_language is TemplateLanguage.F_STRING:
        return FStringTemplateFormatter()
    elif template_language is TemplateLanguage.MUSTACHE:
        return MustacheTemplateFormatter()
    assert_never(template_language)
276
+
277
+
278
def _output_value_and_mime_type(
    text: str, tool_calls: dict[str, ChatCompletionToolCall]
) -> Iterator[tuple[str, Any]]:
    """
    Yield the span's output MIME type followed by its output value.

    Any tool calls (with or without text) produce a JSON output. Text alone produces a
    plain-text output. Nothing is yielded when both are empty.
    """
    if not tool_calls:
        if text:
            yield OUTPUT_MIME_TYPE, TEXT
            yield OUTPUT_VALUE, text
        return
    yield OUTPUT_MIME_TYPE, JSON
    serialized_tool_calls = jsonify(list(tool_calls.values()))
    if text:
        payload: Any = {"content": text, "tool_calls": serialized_tool_calls}
    else:
        payload = serialized_tool_calls
    yield OUTPUT_VALUE, safe_json_dumps(payload)
293
+
294
+
295
def _llm_output_messages(
    text_content: str, tool_calls: dict[str, ChatCompletionToolCall]
) -> Iterator[tuple[str, Any]]:
    """
    Yield flattened span attributes describing the single assistant output message.

    Args:
        text_content: Aggregated response text (may be empty, in which case no content
            attribute is emitted).
        tool_calls: Aggregated tool calls keyed by tool-call id, in insertion order.

    Yields:
        (attribute key, value) pairs for the role, the content and each tool call's
        function name and arguments.
    """
    message_prefix = f"{LLM_OUTPUT_MESSAGES}.0"
    yield f"{message_prefix}.{MESSAGE_ROLE}", "assistant"
    if text_content:
        yield f"{message_prefix}.{MESSAGE_CONTENT}", text_content
    for tool_call_index, tool_call in enumerate(tool_calls.values()):
        tool_call_prefix = f"{message_prefix}.{MESSAGE_TOOL_CALLS}.{tool_call_index}"
        yield f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_NAME}", tool_call.function.name
        if arguments := tool_call.function.arguments:
            # `arguments` is the raw argument string streamed by the provider.
            # NOTE(review): it is expected to already be JSON, per the `..._ARGUMENTS_JSON`
            # key. Record it verbatim; `json.dumps` would wrap it in a second layer of
            # string quoting.
            yield f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", arguments
311
+
312
+
313
+ def _generate_trace_id() -> str:
314
+ return _hex(DefaultOTelIDGenerator().generate_trace_id())
315
+
316
+
317
+ def _generate_span_id() -> str:
318
+ return _hex(DefaultOTelIDGenerator().generate_span_id())
319
+
320
+
321
+ def _hex(number: int) -> str:
322
+ return hex(number)[2:]
323
+
324
+
325
def _serialize_event(event: SpanException) -> dict[str, Any]:
    """
    Convert a span event dataclass into a plain dict for storage.

    ``datetime`` values are rendered as ISO-8601 strings; all other values are kept as
    produced by ``dataclasses.asdict``.
    """
    serialized: dict[str, Any] = {}
    for field_name, field_value in asdict(event).items():
        if isinstance(field_value, datetime):
            serialized[field_name] = field_value.isoformat()
        else:
            serialized[field_name] = field_value
    return serialized
327
+
328
+
329
# OpenInference MIME-type and span-kind values used when recording the playground span.
# Note that `JSON` here is the MIME-type string, not a GraphQL scalar.
JSON = OpenInferenceMimeTypeValues.JSON.value
TEXT = OpenInferenceMimeTypeValues.TEXT.value
LLM = OpenInferenceSpanKindValues.LLM.value

# Span attribute keys (OpenInference semantic conventions). Not all of them are
# referenced in this module.
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION

# Per-message attribute suffixes.
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

# Per-tool-call attribute suffixes.
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON

TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA

# Project that playground spans are written to; the chat_completion mutation creates it
# on first use if it does not exist.
PLAYGROUND_PROJECT_NAME = "playground"
@@ -37,12 +37,17 @@ from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
37
37
  from phoenix.server.api.context import Context
38
38
  from phoenix.server.api.exceptions import NotFound, Unauthorized
39
39
  from phoenix.server.api.helpers import ensure_list
40
+ from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
41
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
40
42
  from phoenix.server.api.input_types.ClusterInput import ClusterInput
41
43
  from phoenix.server.api.input_types.Coordinates import (
42
44
  InputCoordinate2D,
43
45
  InputCoordinate3D,
44
46
  )
45
47
  from phoenix.server.api.input_types.DatasetSort import DatasetSort
48
+ from phoenix.server.api.input_types.InvocationParameters import (
49
+ InvocationParameter,
50
+ )
46
51
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
47
52
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
48
53
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -80,78 +85,62 @@ from phoenix.server.api.types.User import User, to_gql_user
80
85
  from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
81
86
  from phoenix.server.api.types.UserRole import UserRole
82
87
 
88
+ initialize_playground_clients()
89
+
83
90
 
84
91
@strawberry.input
class ModelsInput:
    """
    Selector for model-related queries.

    ``Query.models`` filters by ``provider_key``. ``Query.model_invocation_parameters``
    uses both fields to pick the client whose parameters are listed.
    """

    # Provider to restrict results to; None means no provider filter.
    provider_key: Optional[GenerativeProviderKey]
    # Optional model name; only consulted by `model_invocation_parameters`.
    model_name: Optional[str] = None
87
95
 
88
96
 
89
97
  @strawberry.type
90
98
  class Query:
91
99
  @strawberry.field
92
100
  async def model_providers(self) -> list[GenerativeProvider]:
101
+ available_providers = PLAYGROUND_CLIENT_REGISTRY.list_all_providers()
93
102
  return [
94
103
  GenerativeProvider(
95
- name="OpenAI",
96
- key=GenerativeProviderKey.OPENAI,
97
- ),
98
- GenerativeProvider(
99
- name="Azure OpenAI",
100
- key=GenerativeProviderKey.AZURE_OPENAI,
101
- ),
102
- GenerativeProvider(
103
- name="Anthropic",
104
- key=GenerativeProviderKey.ANTHROPIC,
105
- ),
104
+ name=provider_key.value,
105
+ key=provider_key,
106
+ )
107
+ for provider_key in available_providers
106
108
  ]
107
109
 
108
110
  @strawberry.field
109
111
  async def models(self, input: Optional[ModelsInput] = None) -> list[GenerativeModel]:
110
- openai_models = [
111
- "o1-preview",
112
- "o1-preview-2024-09-12",
113
- "o1-mini",
114
- "o1-mini-2024-09-12",
115
- "gpt-4o",
116
- "gpt-4o-2024-08-06",
117
- "gpt-4o-2024-05-13",
118
- "chatgpt-4o-latest",
119
- "gpt-4o-mini",
120
- "gpt-4o-mini-2024-07-18",
121
- "gpt-4-turbo",
122
- "gpt-4-turbo-2024-04-09",
123
- "gpt-4-turbo-preview",
124
- "gpt-4-0125-preview",
125
- "gpt-4-1106-preview",
126
- "gpt-4",
127
- "gpt-4-0613",
128
- "gpt-3.5-turbo-0125",
129
- "gpt-3.5-turbo",
130
- "gpt-3.5-turbo-1106",
131
- "gpt-3.5-turbo-instruct",
132
- ]
133
- anthropic_models = [
134
- "claude-3-5-sonnet-20240620",
135
- "claude-3-opus-20240229",
136
- "claude-3-sonnet-20240229",
137
- "claude-3-haiku-20240307",
138
- ]
139
- openai_generative_models = [
140
- GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.OPENAI)
141
- for model_name in openai_models
142
- ]
143
- anthropic_generative_models = [
144
- GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.ANTHROPIC)
145
- for model_name in anthropic_models
146
- ]
147
-
148
- all_models = openai_generative_models + anthropic_generative_models
149
-
150
112
  if input is not None and input.provider_key is not None:
151
- return [model for model in all_models if model.provider_key == input.provider_key]
113
+ supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
114
+ supported_models = [
115
+ GenerativeModel(name=model_name, provider_key=input.provider_key)
116
+ for model_name in supported_model_names
117
+ ]
118
+ return supported_models
152
119
 
120
+ registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
121
+ all_models: list[GenerativeModel] = []
122
+ for provider_key, model_name in registered_models:
123
+ if model_name is not None and provider_key is not None:
124
+ all_models.append(GenerativeModel(name=model_name, provider_key=provider_key))
153
125
  return all_models
154
126
 
127
+ @strawberry.field
128
+ async def model_invocation_parameters(
129
+ self, input: Optional[ModelsInput] = None
130
+ ) -> list[InvocationParameter]:
131
+ if input is None:
132
+ return []
133
+ provider_key = input.provider_key
134
+ model_name = input.model_name
135
+ if provider_key is not None:
136
+ client = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, model_name)
137
+ if client is None:
138
+ return []
139
+ invocation_parameters = client.supported_invocation_parameters()
140
+ return invocation_parameters
141
+ else:
142
+ return []
143
+
155
144
  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
156
145
  async def users(
157
146
  self,