arize-phoenix 5.6.0__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (34)
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/RECORD +34 -25
  3. phoenix/config.py +42 -0
  4. phoenix/server/api/helpers/playground_clients.py +671 -0
  5. phoenix/server/api/helpers/playground_registry.py +70 -0
  6. phoenix/server/api/helpers/playground_spans.py +325 -0
  7. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  8. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  9. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  10. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  11. phoenix/server/api/mutations/__init__.py +4 -0
  12. phoenix/server/api/mutations/chat_mutations.py +374 -0
  13. phoenix/server/api/queries.py +41 -52
  14. phoenix/server/api/schema.py +42 -10
  15. phoenix/server/api/subscriptions.py +326 -595
  16. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  17. phoenix/server/api/types/GenerativeProvider.py +27 -3
  18. phoenix/server/api/types/Span.py +37 -0
  19. phoenix/server/api/types/TemplateLanguage.py +9 -0
  20. phoenix/server/app.py +61 -13
  21. phoenix/server/main.py +14 -1
  22. phoenix/server/static/.vite/manifest.json +9 -9
  23. phoenix/server/static/assets/{components-C70HJiXz.js → components-Csu8UKOs.js} +114 -114
  24. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-Bk5C9EA7.js} +1 -1
  25. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-UeWaKXNs.js} +328 -268
  26. phoenix/server/templates/index.html +1 -0
  27. phoenix/services.py +4 -0
  28. phoenix/session/session.py +15 -1
  29. phoenix/utilities/template_formatters.py +11 -1
  30. phoenix/version.py +1 -1
  31. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  32. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  33. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  34. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,374 @@
1
+ import json
2
+ from dataclasses import asdict
3
+ from datetime import datetime, timezone
4
+ from itertools import chain
5
+ from traceback import format_exc
6
+ from typing import Any, Iterable, Iterator, List, Optional
7
+
8
+ import strawberry
9
+ from openinference.semconv.trace import (
10
+ MessageAttributes,
11
+ OpenInferenceMimeTypeValues,
12
+ OpenInferenceSpanKindValues,
13
+ SpanAttributes,
14
+ ToolAttributes,
15
+ ToolCallAttributes,
16
+ )
17
+ from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
18
+ from opentelemetry.trace import StatusCode
19
+ from sqlalchemy import insert, select
20
+ from strawberry.types import Info
21
+ from typing_extensions import assert_never
22
+
23
+ from phoenix.datetime_utils import local_now, normalize_datetime
24
+ from phoenix.db import models
25
+ from phoenix.server.api.context import Context
26
+ from phoenix.server.api.exceptions import BadRequest
27
+ from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
28
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
29
+ from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
30
+ from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
31
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
32
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
33
+ TextChunk,
34
+ ToolCallChunk,
35
+ )
36
+ from phoenix.server.api.types.Span import Span, to_gql_span
37
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
38
+ from phoenix.server.dml_event import SpanInsertEvent
39
+ from phoenix.trace.attributes import unflatten
40
+ from phoenix.trace.schemas import SpanException
41
+ from phoenix.utilities.template_formatters import (
42
+ FStringTemplateFormatter,
43
+ MustacheTemplateFormatter,
44
+ TemplateFormatter,
45
+ )
46
+
47
# Runs at import time. Judging by its name it presumably registers the playground
# LLM client classes that PLAYGROUND_CLIENT_REGISTRY.get_client() returns below.
# NOTE(review): other modules also call this at import; confirm repeated calls are harmless.
initialize_playground_clients()

# One chat message as built from ChatCompletionInput.messages in chat_completion:
# (role, content, tool_call_id or None, tool_calls list or None).
ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
50
+
51
+
52
# Function portion of a tool call produced by the LLM during chat_completion.
@strawberry.type
class ChatCompletionFunctionCall:
    # Name of the function the model asked to call.
    name: str
    # Arguments exactly as received from the streamed ToolCallChunk (a string,
    # presumably JSON text from the provider).
    arguments: str
56
+
57
+
58
# A single tool call returned in ChatCompletionMutationPayload.tool_calls.
@strawberry.type
class ChatCompletionToolCall:
    # Tool call ID taken from the streamed ToolCallChunk.
    id: str
    # Called function's name and raw arguments.
    function: ChatCompletionFunctionCall
62
+
63
+
64
# Result of the chat_completion mutation.
@strawberry.type
class ChatCompletionMutationPayload:
    # Concatenated text chunks; None when the model produced no text or the call failed.
    content: Optional[str]
    # Tool calls collected from the stream; empty when the call failed.
    tool_calls: List[ChatCompletionToolCall]
    # The LLM span recorded in the playground project for this completion.
    span: Span
    # Exception message when the provider call failed, otherwise None.
    error_message: Optional[str]
70
+
71
+
72
@strawberry.type
class ChatCompletionMutationMixin:
    @strawberry.mutation
    async def chat_completion(
        self, info: Info[Context, None], input: ChatCompletionInput
    ) -> ChatCompletionMutationPayload:
        # Runs a chat completion with the registered playground client for the
        # requested provider/model. The streamed chunks are collected into text and
        # tool calls, and an LLM span is persisted in the "playground" project
        # (created on first use). Returns the output together with the span.
        # Provider failures are reported through `error_message` rather than
        # raised; BadRequest is raised when no client is registered.
        provider_key = input.model.provider_key
        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
        if llm_client_class is None:
            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
        attributes: dict[str, Any] = {}
        llm_client = llm_client_class(
            model=input.model,
            api_key=input.api_key,
        )

        messages = [
            (
                message.role,
                message.content,
                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
                message.tool_calls if isinstance(message.tool_calls, list) else None,
            )
            for message in input.messages
        ]

        if template_options := input.template:
            messages = list(_formatted_messages(messages, template_options))

        invocation_parameters = llm_client.construct_invocation_parameters(
            input.invocation_parameters
        )

        text_content = ""
        tool_calls: List[ChatCompletionToolCall] = []
        events: List[SpanException] = []
        attributes.update(
            chain(
                _llm_span_kind(),
                _llm_model_name(input.model.name),
                _llm_tools(input.tools or []),
                _llm_input_messages(messages),
                _llm_invocation_parameters(invocation_parameters),
                _input_value_and_mime_type(input),
                # `chain` only accepts positional iterables. Unpacking the mapping
                # with ** raised "chain() takes no keyword arguments" whenever the
                # client reported any attributes.
                llm_client.attributes.items(),
            )
        )

        start_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
        status_code = StatusCode.OK
        status_message = ""
        try:
            async for chunk in llm_client.chat_completion_create(
                messages=messages, tools=input.tools or [], **invocation_parameters
            ):
                if isinstance(chunk, TextChunk):
                    text_content += chunk.content
                elif isinstance(chunk, ToolCallChunk):
                    tool_calls.append(
                        ChatCompletionToolCall(
                            id=chunk.id,
                            function=ChatCompletionFunctionCall(
                                name=chunk.function.name,
                                arguments=chunk.function.arguments,
                            ),
                        )
                    )
                else:
                    assert_never(chunk)
        except Exception as e:
            # Provider failures are recorded on the span as an exception event and
            # surfaced through `error_message` instead of propagating.
            status_code = StatusCode.ERROR
            status_message = str(e)
            end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
            assert end_time is not None
            events.append(
                SpanException(
                    timestamp=end_time,
                    message=status_message,
                    exception_type=type(e).__name__,
                    exception_escaped=False,
                    exception_stacktrace=format_exc(),
                )
            )
        else:
            end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)

        if text_content or tool_calls:
            attributes.update(
                chain(
                    _output_value_and_mime_type(
                        {
                            "text": text_content,
                            # The tool calls are strawberry (dataclass) objects, which
                            # json.dumps cannot serialize; convert them to plain dicts.
                            "tool_calls": [asdict(tool_call) for tool_call in tool_calls],
                        }
                    ),
                    _llm_output_messages(text_content, tool_calls),
                )
            )

        # Persist the span (and its trace) in the playground project.
        trace_id = _generate_trace_id()
        span_id = _generate_span_id()
        async with info.context.db() as session:
            if (
                project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            trace = models.Trace(
                project_rowid=project_id,
                trace_id=trace_id,
                start_time=start_time,
                end_time=end_time,
            )
            span = models.Span(
                trace_rowid=trace.id,
                span_id=span_id,
                parent_id=None,
                name="ChatCompletion",
                span_kind=LLM,
                start_time=start_time,
                end_time=end_time,
                attributes=unflatten(attributes.items()),
                events=[_serialize_event(event) for event in events],
                status_code=status_code.name,
                status_message=status_message,
                cumulative_error_count=int(status_code is StatusCode.ERROR),
                cumulative_llm_token_count_prompt=attributes.get(LLM_TOKEN_COUNT_PROMPT, 0),
                cumulative_llm_token_count_completion=attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0),
                llm_token_count_prompt=attributes.get(LLM_TOKEN_COUNT_PROMPT, 0),
                llm_token_count_completion=attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0),
                trace=trace,
            )
            session.add(trace)
            session.add(span)
            await session.flush()

        gql_span = to_gql_span(span)

        info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))

        if status_code is StatusCode.ERROR:
            return ChatCompletionMutationPayload(
                content=None,
                tool_calls=[],
                span=gql_span,
                error_message=status_message,
            )
        else:
            return ChatCompletionMutationPayload(
                content=text_content if text_content else None,
                tool_calls=tool_calls,
                span=gql_span,
                error_message=None,
            )
232
+
233
+
234
def _formatted_messages(
    messages: Iterable[ChatCompletionMessage],
    template_options: TemplateOptions,
) -> Iterator[ChatCompletionMessage]:
    """
    Formats the content of each message using the given template options.

    Only the content (second tuple element) is rendered, using
    `template_options.variables`. Role, tool call ID and tool calls pass through
    unchanged. An empty `messages` iterable yields nothing; the previous
    `zip(*messages)` unpacking raised ValueError in that case.
    """
    # Resolve the formatter eagerly (at call time), as before; formatting itself is lazy.
    template_formatter = _template_formatter(template_language=template_options.language)
    return (
        (
            role,
            template_formatter.format(template, **template_options.variables),
            tool_call_id,
            tool_calls,
        )
        for role, template, tool_call_id, tool_calls in messages
    )
254
+
255
+
256
def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
    """
    Returns a new formatter instance for `template_language`.

    MUSTACHE maps to MustacheTemplateFormatter and F_STRING to
    FStringTemplateFormatter. Any other value is rejected through `assert_never`.
    """
    is_mustache = template_language is TemplateLanguage.MUSTACHE
    if not is_mustache and template_language is not TemplateLanguage.F_STRING:
        assert_never(template_language)
    return MustacheTemplateFormatter() if is_mustache else FStringTemplateFormatter()
265
+
266
+
267
def _llm_span_kind() -> Iterator[tuple[str, Any]]:
    """Yields the single attribute that marks the span as an LLM span."""
    yield from ((OPENINFERENCE_SPAN_KIND, LLM),)
269
+
270
+
271
def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
    """Yields the attribute recording which model was invoked."""
    model_attribute = (LLM_MODEL_NAME, model_name)
    yield model_attribute
273
+
274
+
275
def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
    """Yields the invocation parameters serialized as one JSON-string attribute."""
    key = LLM_INVOCATION_PARAMETERS
    yield key, json.dumps(invocation_parameters)
277
+
278
+
279
def _llm_tools(tools: List[Any]) -> Iterator[tuple[str, Any]]:
    """Yields one JSON-schema attribute per tool, keyed by the tool's position in `tools`."""
    yield from (
        (f"{LLM_TOOLS}.{position}.{TOOL_JSON_SCHEMA}", json.dumps(tool_schema))
        for position, tool_schema in enumerate(tools)
    )
282
+
283
+
284
def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
    """
    Yields the input MIME type (JSON) and a JSON rendering of the mutation input.

    The API key is removed so it is never stored in span attributes. The input
    holds nested strawberry input objects (model, messages, ...) and enum values
    such as message roles. Plain `json.dumps` raises TypeError on those, so a
    `default` hook converts dataclass instances (strawberry types are
    dataclasses) to dicts and enums to their values.
    """
    from dataclasses import fields, is_dataclass
    from enum import Enum

    def _to_jsonable(obj: Any) -> Any:
        if is_dataclass(obj) and not isinstance(obj, type):
            return {field.name: getattr(obj, field.name) for field in fields(obj)}
        if isinstance(obj, Enum):
            return obj.value
        # NOTE(review): any other non-JSON value (e.g. strawberry's UNSET sentinel)
        # is stored as its string form rather than failing the whole mutation.
        return str(obj)

    input_data = input.__dict__.copy()
    input_data.pop("api_key", None)
    yield INPUT_MIME_TYPE, JSON
    yield INPUT_VALUE, json.dumps(input_data, default=_to_jsonable)
289
+
290
+
291
def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
    """Yields the output MIME type (JSON), then `output` encoded with json.dumps."""
    mime_attribute = (OUTPUT_MIME_TYPE, JSON)
    yield mime_attribute
    yield OUTPUT_VALUE, json.dumps(output)
294
+
295
+
296
def _llm_input_messages(
    messages: Iterable[ChatCompletionMessage],
) -> Iterator[tuple[str, Any]]:
    """
    Flattens input chat messages into OpenInference input-message attributes.

    For each message this yields its lower-cased role and its content. For each
    tool call the message carries, it also yields the function name and arguments.
    Tool calls are read as mappings shaped like
    ``{"function": {"name": ..., "arguments": ...}}``.

    Arguments that are already a string are assumed to be JSON text and are
    recorded unchanged; re-encoding them with json.dumps would double-quote the
    JSON. Other argument values are JSON-encoded. A missing or empty
    ``arguments`` entry is skipped.
    """
    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
        message_prefix = f"{LLM_INPUT_MESSAGES}.{i}"
        yield f"{message_prefix}.{MESSAGE_ROLE}", role.value.lower()
        yield f"{message_prefix}.{MESSAGE_CONTENT}", content
        for tool_call_index, tool_call in enumerate(tool_calls or []):
            function = tool_call["function"]
            tool_call_prefix = f"{message_prefix}.{MESSAGE_TOOL_CALLS}.{tool_call_index}"
            yield f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_NAME}", function["name"]
            if arguments := function.get("arguments"):
                yield (
                    f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                    arguments if isinstance(arguments, str) else json.dumps(arguments),
                )
313
+
314
+
315
def _llm_output_messages(
    text_content: str, tool_calls: List[ChatCompletionToolCall]
) -> Iterator[tuple[str, Any]]:
    """
    Yields OpenInference attributes for the single assistant output message.

    The role is always "assistant". The content is emitted only when non-empty,
    followed by the name and (non-empty) arguments of every tool call.
    """
    message_prefix = f"{LLM_OUTPUT_MESSAGES}.0"
    yield f"{message_prefix}.{MESSAGE_ROLE}", "assistant"
    if text_content:
        yield f"{message_prefix}.{MESSAGE_CONTENT}", text_content
    for tool_call_index, tool_call in enumerate(tool_calls):
        tool_call_prefix = f"{message_prefix}.{MESSAGE_TOOL_CALLS}.{tool_call_index}"
        yield f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_NAME}", tool_call.function.name
        if arguments := tool_call.function.arguments:
            # `arguments` is already a string (presumably the provider's JSON text).
            # json.dumps would wrap it in quotes and double-encode the JSON.
            yield f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", arguments
331
+
332
+
333
def _generate_trace_id() -> str:
    """
    Returns a new random trace ID as exactly 32 lowercase hex digits.

    The value is zero-padded to the full 128-bit width. A bare ``hex()[2:]``
    drops leading zeros, which yields shorter IDs for a fraction of the
    generated values.
    """
    return f"{DefaultOTelIDGenerator().generate_trace_id():032x}"
335
+
336
+
337
def _generate_span_id() -> str:
    """
    Returns a new random span ID as exactly 16 lowercase hex digits.

    The value is zero-padded to the full 64-bit width. A bare ``hex()[2:]``
    drops leading zeros, which yields shorter IDs for a fraction of the
    generated values.
    """
    return f"{DefaultOTelIDGenerator().generate_span_id():016x}"
339
+
340
+
341
+ def _hex(number: int) -> str:
342
+ return hex(number)[2:]
343
+
344
+
345
+ def _serialize_event(event: SpanException) -> dict[str, Any]:
346
+ return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
347
+
348
+
349
# Attribute values.
JSON = OpenInferenceMimeTypeValues.JSON.value
LLM = OpenInferenceSpanKindValues.LLM.value

# Span-level attribute keys.
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
INPUT_MIME_TYPE, INPUT_VALUE = SpanAttributes.INPUT_MIME_TYPE, SpanAttributes.INPUT_VALUE
OUTPUT_MIME_TYPE, OUTPUT_VALUE = SpanAttributes.OUTPUT_MIME_TYPE, SpanAttributes.OUTPUT_VALUE
LLM_INPUT_MESSAGES, LLM_OUTPUT_MESSAGES = (
    SpanAttributes.LLM_INPUT_MESSAGES,
    SpanAttributes.LLM_OUTPUT_MESSAGES,
)
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS
LLM_TOKEN_COUNT_PROMPT, LLM_TOKEN_COUNT_COMPLETION = (
    SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
    SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
)

# Message-level keys, appended after an input/output message prefix and index.
MESSAGE_CONTENT, MESSAGE_ROLE, MESSAGE_TOOL_CALLS = (
    MessageAttributes.MESSAGE_CONTENT,
    MessageAttributes.MESSAGE_ROLE,
    MessageAttributes.MESSAGE_TOOL_CALLS,
)

# Tool-call and tool keys.
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA

# Project that chat_completion stores playground spans in (created on first use).
PLAYGROUND_PROJECT_NAME = "playground"
@@ -37,12 +37,17 @@ from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
37
37
  from phoenix.server.api.context import Context
38
38
  from phoenix.server.api.exceptions import NotFound, Unauthorized
39
39
  from phoenix.server.api.helpers import ensure_list
40
+ from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
41
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
40
42
  from phoenix.server.api.input_types.ClusterInput import ClusterInput
41
43
  from phoenix.server.api.input_types.Coordinates import (
42
44
  InputCoordinate2D,
43
45
  InputCoordinate3D,
44
46
  )
45
47
  from phoenix.server.api.input_types.DatasetSort import DatasetSort
48
+ from phoenix.server.api.input_types.InvocationParameters import (
49
+ InvocationParameter,
50
+ )
46
51
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
47
52
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
48
53
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -80,78 +85,62 @@ from phoenix.server.api.types.User import User, to_gql_user
80
85
  from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
81
86
  from phoenix.server.api.types.UserRole import UserRole
82
87
 
88
+ initialize_playground_clients()
89
+
83
90
 
84
91
# Filter/selector for the model-related queries on Query.
@strawberry.input
class ModelsInput:
    # Provider to restrict `models` to, and to look up a client for in
    # `model_invocation_parameters`; None means "not specified".
    provider_key: Optional[GenerativeProviderKey]
    # Model name passed with provider_key to PLAYGROUND_CLIENT_REGISTRY.get_client.
    model_name: Optional[str] = None
87
95
 
88
96
 
89
97
  @strawberry.type
90
98
  class Query:
91
99
  @strawberry.field
92
100
  async def model_providers(self) -> list[GenerativeProvider]:
101
+ available_providers = PLAYGROUND_CLIENT_REGISTRY.list_all_providers()
93
102
  return [
94
103
  GenerativeProvider(
95
- name="OpenAI",
96
- key=GenerativeProviderKey.OPENAI,
97
- ),
98
- GenerativeProvider(
99
- name="Azure OpenAI",
100
- key=GenerativeProviderKey.AZURE_OPENAI,
101
- ),
102
- GenerativeProvider(
103
- name="Anthropic",
104
- key=GenerativeProviderKey.ANTHROPIC,
105
- ),
104
+ name=provider_key.value,
105
+ key=provider_key,
106
+ )
107
+ for provider_key in available_providers
106
108
  ]
107
109
 
108
110
  @strawberry.field
109
111
  async def models(self, input: Optional[ModelsInput] = None) -> list[GenerativeModel]:
110
- openai_models = [
111
- "o1-preview",
112
- "o1-preview-2024-09-12",
113
- "o1-mini",
114
- "o1-mini-2024-09-12",
115
- "gpt-4o",
116
- "gpt-4o-2024-08-06",
117
- "gpt-4o-2024-05-13",
118
- "chatgpt-4o-latest",
119
- "gpt-4o-mini",
120
- "gpt-4o-mini-2024-07-18",
121
- "gpt-4-turbo",
122
- "gpt-4-turbo-2024-04-09",
123
- "gpt-4-turbo-preview",
124
- "gpt-4-0125-preview",
125
- "gpt-4-1106-preview",
126
- "gpt-4",
127
- "gpt-4-0613",
128
- "gpt-3.5-turbo-0125",
129
- "gpt-3.5-turbo",
130
- "gpt-3.5-turbo-1106",
131
- "gpt-3.5-turbo-instruct",
132
- ]
133
- anthropic_models = [
134
- "claude-3-5-sonnet-20240620",
135
- "claude-3-opus-20240229",
136
- "claude-3-sonnet-20240229",
137
- "claude-3-haiku-20240307",
138
- ]
139
- openai_generative_models = [
140
- GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.OPENAI)
141
- for model_name in openai_models
142
- ]
143
- anthropic_generative_models = [
144
- GenerativeModel(name=model_name, provider_key=GenerativeProviderKey.ANTHROPIC)
145
- for model_name in anthropic_models
146
- ]
147
-
148
- all_models = openai_generative_models + anthropic_generative_models
149
-
150
112
  if input is not None and input.provider_key is not None:
151
- return [model for model in all_models if model.provider_key == input.provider_key]
113
+ supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
114
+ supported_models = [
115
+ GenerativeModel(name=model_name, provider_key=input.provider_key)
116
+ for model_name in supported_model_names
117
+ ]
118
+ return supported_models
152
119
 
120
+ registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
121
+ all_models: list[GenerativeModel] = []
122
+ for provider_key, model_name in registered_models:
123
+ if model_name is not None and provider_key is not None:
124
+ all_models.append(GenerativeModel(name=model_name, provider_key=provider_key))
153
125
  return all_models
154
126
 
127
+ @strawberry.field
128
+ async def model_invocation_parameters(
129
+ self, input: Optional[ModelsInput] = None
130
+ ) -> list[InvocationParameter]:
131
+ if input is None:
132
+ return []
133
+ provider_key = input.provider_key
134
+ model_name = input.model_name
135
+ if provider_key is not None:
136
+ client = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, model_name)
137
+ if client is None:
138
+ return []
139
+ invocation_parameters = client.supported_invocation_parameters()
140
+ return invocation_parameters
141
+ else:
142
+ return []
143
+
155
144
  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
156
145
  async def users(
157
146
  self,
@@ -1,17 +1,49 @@
1
+ from itertools import chain
2
+ from typing import Any, Iterable, Iterator, Optional, Union
3
+
1
4
  import strawberry
5
+ from strawberry.extensions import SchemaExtension
6
+ from strawberry.types.base import StrawberryObjectDefinition, StrawberryType
2
7
 
3
8
  from phoenix.server.api.exceptions import get_mask_errors_extension
4
9
  from phoenix.server.api.mutations import Mutation
5
10
  from phoenix.server.api.queries import Query
6
11
  from phoenix.server.api.subscriptions import Subscription
7
-
8
- # This is the schema for generating `schema.graphql`.
9
- # See https://strawberry.rocks/docs/guides/schema-export
10
- # It should be kept in sync with the server's runtime-initialized
11
- # instance. To do so, search for the usage of `strawberry.Schema(...)`.
12
- schema = strawberry.Schema(
13
- query=Query,
14
- mutation=Mutation,
15
- extensions=[get_mask_errors_extension()],
16
- subscription=Subscription,
12
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
13
+ ChatCompletionSubscriptionPayload,
17
14
  )
15
+
16
+
17
def build_graphql_schema(
    extensions: Optional[Iterable[Union[type[SchemaExtension], SchemaExtension]]] = None,
) -> strawberry.Schema:
    """
    Builds the strawberry schema (query, mutation and subscription roots).

    Caller-supplied `extensions` run before the error-masking extension, which
    is always appended. The types implementing the ChatCompletionSubscriptionPayload
    interface are registered explicitly.

    Both sequences are materialized into lists. Strawberry keeps the
    `extensions` argument and may iterate it again for each operation, so a
    one-shot `chain` iterator could be exhausted after first use. That would
    silently drop the error-masking extension.
    """
    return strawberry.Schema(
        query=Query,
        mutation=Mutation,
        extensions=list(chain(extensions or [], [get_mask_errors_extension()])),
        subscription=Subscription,
        types=list(_implementing_types(ChatCompletionSubscriptionPayload)),
    )
30
+
31
+
32
def _implementing_types(interface: Any) -> Iterator[StrawberryType]:
    """
    Iterates over strawberry types implementing the given strawberry interface.

    Only direct subclasses of `interface` that carry a strawberry object
    definition are yielded.

    Raises:
        TypeError: if `interface` is not a strawberry interface type. This is
            checked explicitly instead of with `assert`, so the validation is
            not stripped under `python -O`.
    """
    strawberry_definition = getattr(interface, "__strawberry_definition__", None)
    if not isinstance(strawberry_definition, StrawberryObjectDefinition):
        raise TypeError(f"{interface!r} is not a strawberry type")
    if not strawberry_definition.is_interface:
        raise TypeError(f"{interface!r} is not a strawberry interface")
    for subcls in interface.__subclasses__():
        if isinstance(
            getattr(subcls, "__strawberry_definition__", None),
            StrawberryObjectDefinition,
        ):
            yield subcls
47
+
48
+
49
# Module-level schema built at import time with no extra extensions, so tooling
# can export it (e.g. to schema.graphql). NOTE(review): confirm the export
# tooling references this name.
_EXPORTED_GRAPHQL_SCHEMA = build_graphql_schema()  # used to export the GraphQL schema to file