arize-phoenix 5.6.0__py3-none-any.whl → 5.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix has been flagged as potentially problematic; see the package registry's advisory page for details.
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +2 -2
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/RECORD +34 -25
- phoenix/config.py +42 -0
- phoenix/server/api/helpers/playground_clients.py +671 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +325 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/InvocationParameters.py +156 -13
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/mutations/__init__.py +4 -0
- phoenix/server/api/mutations/chat_mutations.py +374 -0
- phoenix/server/api/queries.py +41 -52
- phoenix/server/api/schema.py +42 -10
- phoenix/server/api/subscriptions.py +326 -595
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
- phoenix/server/api/types/GenerativeProvider.py +27 -3
- phoenix/server/api/types/Span.py +37 -0
- phoenix/server/api/types/TemplateLanguage.py +9 -0
- phoenix/server/app.py +61 -13
- phoenix/server/main.py +14 -1
- phoenix/server/static/.vite/manifest.json +9 -9
- phoenix/server/static/assets/{components-C70HJiXz.js → components-Csu8UKOs.js} +114 -114
- phoenix/server/static/assets/{index-DLe1Oo3l.js → index-Bk5C9EA7.js} +1 -1
- phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-UeWaKXNs.js} +328 -268
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +4 -0
- phoenix/session/session.py +15 -1
- phoenix/utilities/template_formatters.py +11 -1
- phoenix/version.py +1 -1
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
2
|
+
|
|
3
|
+
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from phoenix.server.api.helpers.playground_clients import PlaygroundStreamingClient
|
|
7
|
+
|
|
8
|
+
# A model identifier within a provider's registry; ``None`` denotes the
# provider-wide default entry (see ``PROVIDER_DEFAULT`` below).
ModelName = Union[str, None]
# (provider, model name) pair uniquely identifying a registered client.
ModelKey = tuple[GenerativeProviderKey, ModelName]

# Sentinel key under which a provider's fallback client class is stored.
PROVIDER_DEFAULT = None
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SingletonMeta(type):
    """Metaclass that turns each class using it into a process-wide singleton.

    The first instantiation constructs and caches the instance; every later
    call returns the cached instance (constructor arguments of later calls
    are ignored).
    """

    # Maps each class object to its single cached instance.
    _instances: dict[Any, Any] = dict()

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # Lazily construct on first call only; zero-argument super() is the
        # modern equivalent of super(SingletonMeta, cls).
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PlaygroundClientRegistry(metaclass=SingletonMeta):
    """Singleton lookup table from (provider, model name) to a streaming-client class.

    Each provider's inner mapping may contain a ``PROVIDER_DEFAULT`` (``None``)
    key whose value serves as the fallback client for models of that provider
    without a dedicated entry.
    """

    def __init__(self) -> None:
        # provider -> {model name (or PROVIDER_DEFAULT) -> client class}
        self._registry: dict[
            GenerativeProviderKey, dict[ModelName, Optional[type["PlaygroundStreamingClient"]]]
        ] = {}

    def get_client(
        self,
        provider_key: GenerativeProviderKey,
        model_name: ModelName,
    ) -> Optional[type["PlaygroundStreamingClient"]]:
        """Return the client class for the model, falling back to the provider default."""
        provider_registry = self._registry.get(provider_key, {})
        client_class = provider_registry.get(model_name)
        # Use the named sentinel (== None) for the fallback lookup so the
        # intent matches PROVIDER_DEFAULT's definition.
        if client_class is None and PROVIDER_DEFAULT in provider_registry:
            client_class = provider_registry[PROVIDER_DEFAULT]  # Fallback to provider default
        return client_class

    def list_all_providers(
        self,
    ) -> list[GenerativeProviderKey]:
        """Return every provider with at least one registered entry."""
        # list(mapping) is the idiomatic form of [k for k in mapping].
        return list(self._registry)

    def list_models(self, provider_key: GenerativeProviderKey) -> list[str]:
        """Return the explicitly registered model names (excludes the default entry)."""
        provider_registry = self._registry.get(provider_key, {})
        return [model_name for model_name in provider_registry.keys() if model_name is not None]

    def list_all_models(self) -> list[ModelKey]:
        """Return every (provider, model name) pair, including default (None) entries."""
        return [
            (provider_key, model_name)
            for provider_key, provider_registry in self._registry.items()
            for model_name in provider_registry.keys()
        ]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Process-wide registry instance (SingletonMeta guarantees there is only one).
PLAYGROUND_CLIENT_REGISTRY: PlaygroundClientRegistry = PlaygroundClientRegistry()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def register_llm_client(
    provider_key: GenerativeProviderKey,
    model_names: list[ModelName],
) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
    """Class decorator registering a streaming client for every listed model.

    Including ``PROVIDER_DEFAULT`` (``None``) among *model_names* installs the
    decorated class as the provider-wide fallback.
    """

    def _register(
        client_cls: type["PlaygroundStreamingClient"],
    ) -> type["PlaygroundStreamingClient"]:
        models_for_provider = PLAYGROUND_CLIENT_REGISTRY._registry.setdefault(provider_key, {})
        # dict.update with an iterable of pairs replaces the explicit loop.
        models_for_provider.update((name, client_cls) for name in model_names)
        return client_cls

    return _register
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from dataclasses import asdict
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from itertools import chain
|
|
7
|
+
from traceback import format_exc
|
|
8
|
+
from types import TracebackType
|
|
9
|
+
from typing import (
|
|
10
|
+
Any,
|
|
11
|
+
Iterable,
|
|
12
|
+
Iterator,
|
|
13
|
+
Optional,
|
|
14
|
+
Union,
|
|
15
|
+
cast,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from openinference.instrumentation import safe_json_dumps
|
|
19
|
+
from openinference.semconv.trace import (
|
|
20
|
+
MessageAttributes,
|
|
21
|
+
OpenInferenceMimeTypeValues,
|
|
22
|
+
OpenInferenceSpanKindValues,
|
|
23
|
+
SpanAttributes,
|
|
24
|
+
ToolAttributes,
|
|
25
|
+
ToolCallAttributes,
|
|
26
|
+
)
|
|
27
|
+
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
|
|
28
|
+
from opentelemetry.trace import StatusCode
|
|
29
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
30
|
+
from strawberry.scalars import JSON as JSONScalarType
|
|
31
|
+
from typing_extensions import Self, TypeAlias, assert_never
|
|
32
|
+
|
|
33
|
+
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
34
|
+
from phoenix.db import models
|
|
35
|
+
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
36
|
+
ChatCompletionInput,
|
|
37
|
+
ChatCompletionOverDatasetInput,
|
|
38
|
+
)
|
|
39
|
+
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
|
|
40
|
+
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
|
|
41
|
+
TextChunk,
|
|
42
|
+
ToolCallChunk,
|
|
43
|
+
)
|
|
44
|
+
from phoenix.trace.attributes import unflatten
|
|
45
|
+
from phoenix.trace.schemas import (
|
|
46
|
+
SpanEvent,
|
|
47
|
+
SpanException,
|
|
48
|
+
)
|
|
49
|
+
from phoenix.utilities.json import jsonify
|
|
50
|
+
|
|
51
|
+
# (role, content, tool_call_id, tool_calls) tuple describing one chat message.
ChatCompletionMessage: TypeAlias = tuple[
    ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
]
# Identifier grouping streamed chunks that belong to the same tool call.
ToolCallID: TypeAlias = str
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class streaming_llm_span:
    """
    Creates an LLM span for a streaming chat completion.

    Async context manager: it timestamps the streamed completion, collects
    text/tool-call chunks via ``add_response_chunk``, converts any exception
    raised in the body into a span exception event (the exception is
    suppressed, not propagated), and persists the span plus its enclosing
    one-span trace via ``add_to_session``.
    """

    def __init__(
        self,
        *,
        input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
        messages: list[ChatCompletionMessage],
        invocation_parameters: Mapping[str, Any],
        attributes: Optional[dict[str, Any]] = None,
    ) -> None:
        self._input = input
        # Flat (dotted-key) OpenInference attribute mapping; unflattened into a
        # nested dict in add_to_session.
        self._attributes: dict[str, Any] = attributes if attributes is not None else {}
        self._attributes.update(
            chain(
                _llm_span_kind(),
                _llm_model_name(input.model.name),
                _llm_tools(input.tools or []),
                _llm_input_messages(messages),
                _llm_invocation_parameters(invocation_parameters),
                _input_value_and_mime_type(input),
            )
        )
        self._events: list[SpanEvent] = []
        # Declared here, assigned later: timing/status in __aenter__/__aexit__,
        # database rows in add_to_session.
        self._start_time: datetime
        self._end_time: datetime
        self._response_chunks: list[Union[TextChunk, ToolCallChunk]] = []
        self._text_chunks: list[TextChunk] = []
        # Tool-call chunks grouped by tool-call id, preserving arrival order.
        self._tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
        self._status_code: StatusCode
        self._status_message: str
        self._db_span: models.Span
        self._db_trace: models.Trace

    async def __aenter__(self) -> Self:
        # Record the UTC-normalized wall-clock start of the stream.
        self._start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        self._end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
        self._status_code = StatusCode.OK
        self._status_message = ""
        if exc_type is not None:
            # Capture the failure as span status plus an exception event.
            self._status_code = StatusCode.ERROR
            self._status_message = str(exc_value)
            self._events.append(
                SpanException(
                    timestamp=self._end_time,
                    message=self._status_message,
                    exception_type=type(exc_value).__name__,
                    exception_escaped=False,
                    exception_stacktrace=format_exc(),
                )
            )
        if self._response_chunks:
            # Output attributes are attached only when something was streamed back.
            self._attributes.update(
                chain(
                    _output_value_and_mime_type(self._response_chunks),
                    _llm_output_messages(self._text_chunks, self._tool_call_chunks),
                )
            )
        # Returning True suppresses the body's exception: errors surface only
        # through the recorded status code/message and events.
        return True

    def set_attributes(self, attributes: Mapping[str, Any]) -> None:
        # Merge extra flat attributes into the span before persistence.
        self._attributes.update(attributes)

    def add_to_session(
        self,
        session: AsyncSession,
        project_id: int,
    ) -> models.Span:
        # Creates the Trace and Span ORM rows and stages them on the session
        # (rows are added, not flushed).
        prompt_tokens = self._attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
        completion_tokens = self._attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
        trace_id = _generate_trace_id()
        span_id = _generate_span_id()
        self._db_trace = models.Trace(
            project_rowid=project_id,
            trace_id=trace_id,
            start_time=self._start_time,
            end_time=self._end_time,
        )
        self._db_span = models.Span(
            # NOTE(review): self._db_trace.id is unset before flush; presumably
            # the trace= relationship below populates the FK at flush — confirm.
            trace_rowid=self._db_trace.id,
            span_id=span_id,
            parent_id=None,
            name="ChatCompletion",
            span_kind=LLM,
            start_time=self._start_time,
            end_time=self._end_time,
            attributes=unflatten(self._attributes.items()),
            events=[_serialize_event(event) for event in self._events],
            status_code=self._status_code.name,
            status_message=self._status_message,
            # Single-span trace, so cumulative counts equal this span's counts.
            cumulative_error_count=int(self._status_code is StatusCode.ERROR),
            cumulative_llm_token_count_prompt=prompt_tokens,
            cumulative_llm_token_count_completion=completion_tokens,
            llm_token_count_prompt=prompt_tokens,
            llm_token_count_completion=completion_tokens,
            trace=self._db_trace,
        )
        session.add(self._db_trace)
        session.add(self._db_span)
        return self._db_span

    def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
        # Record one streamed chunk, also bucketing it by kind so the output
        # messages can be reassembled in __aexit__.
        self._response_chunks.append(chunk)
        if isinstance(chunk, TextChunk):
            self._text_chunks.append(chunk)
        elif isinstance(chunk, ToolCallChunk):
            self._tool_call_chunks[chunk.id].append(chunk)
        else:
            assert_never(chunk)

    @property
    def start_time(self) -> datetime:
        # Valid only after add_to_session has created the span row.
        return self._db_span.start_time

    @property
    def end_time(self) -> datetime:
        return self._db_span.end_time

    @property
    def error_message(self) -> Optional[str]:
        # None unless the context exited with an error status.
        return self._status_message if self._status_code is StatusCode.ERROR else None

    @property
    def trace_id(self) -> str:
        return self._db_trace.trace_id

    @property
    def attributes(self) -> dict[str, Any]:
        return self._db_span.attributes
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _llm_span_kind() -> Iterator[tuple[str, Any]]:
    # Marks the span as an LLM span per OpenInference semantic conventions.
    yield OPENINFERENCE_SPAN_KIND, LLM
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
    # Records the (possibly provider-specific) model name on the span.
    yield LLM_MODEL_NAME, model_name
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _llm_invocation_parameters(
    invocation_parameters: Mapping[str, Any],
) -> Iterator[tuple[str, Any]]:
    """Yield the serialized invocation parameters; nothing when the mapping is empty."""
    if not invocation_parameters:
        return
    yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
    # One JSON-schema attribute per tool definition, indexed by position.
    for tool_index, tool in enumerate(tools):
        yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _input_value_and_mime_type(input: Any) -> Iterator[tuple[str, Any]]:
    """Yield the span's input payload (JSON) with sensitive keys stripped.

    The serialized input is expected to contain an ``api_key`` entry, which is
    removed (together with ``invocation_parameters``, recorded separately)
    before the payload is attached to the span.
    """
    # The original code bound these names via walrus expressions *inside* the
    # asserts, which breaks under ``python -O`` (asserts stripped -> NameError).
    # Assignments now happen unconditionally; the asserts only check invariants.
    input_data = jsonify(input)
    assert "api_key" in input_data
    disallowed_keys = {"api_key", "invocation_parameters"}
    input_data = {k: v for k, v in input_data.items() if k not in disallowed_keys}
    assert "api_key" not in input_data
    yield INPUT_MIME_TYPE, JSON
    yield INPUT_VALUE, safe_json_dumps(input_data)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
    # Serializes the collected response chunks as the span's output payload.
    yield OUTPUT_MIME_TYPE, JSON
    yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _llm_input_messages(
    messages: Iterable[
        tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
    ],
) -> Iterator[tuple[str, Any]]:
    """Yield flattened OpenInference attributes for each input chat message."""
    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
        # Roles are lower-cased to match OpenInference message-role values.
        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
        if tool_calls is not None:
            # assumes each tool_call has the OpenAI-style shape
            # {"function": {"name": ..., "arguments": ...}} — TODO confirm
            for tool_call_index, tool_call in enumerate(tool_calls):
                yield (
                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
                    tool_call["function"]["name"],
                )
                # Empty/falsy arguments are skipped entirely.
                if arguments := tool_call["function"]["arguments"]:
                    yield (
                        f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                        safe_json_dumps(jsonify(arguments)),
                    )
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _llm_output_messages(
    text_chunks: list[TextChunk],
    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
) -> Iterator[tuple[str, Any]]:
    """Reassemble streamed chunks into a single assistant output message's attributes."""
    yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
    # Text content is the concatenation of all text chunks; skipped when empty.
    if content := "".join(chunk.content for chunk in text_chunks):
        yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
    # Tool calls are indexed by their insertion order in the defaultdict
    # (i.e. order of first appearance in the stream).
    for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
        # The function name is taken from the first chunk of each tool call.
        if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
            yield (
                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
                name,
            )
        # Argument fragments are concatenated across all chunks of the call.
        if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
            yield (
                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                arguments,
            )
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _generate_trace_id() -> str:
    """
    Generates a random trace ID in hexadecimal format.

    NOTE(review): the hex string is not zero-padded to the 32 characters
    conventional for OTel/W3C trace ids — confirm downstream consumers
    accept variable-width ids.
    """
    return _hex(DefaultOTelIDGenerator().generate_trace_id())
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _generate_span_id() -> str:
    """
    Generates a random span ID in hexadecimal format.

    NOTE(review): not zero-padded to the 16 characters conventional for
    OTel/W3C span ids — confirm downstream consumers accept variable width.
    """
    return _hex(DefaultOTelIDGenerator().generate_span_id())
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _hex(number: int) -> str:
|
|
288
|
+
"""
|
|
289
|
+
Converts an integer to a hexadecimal string.
|
|
290
|
+
"""
|
|
291
|
+
return hex(number)[2:]
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _serialize_event(event: SpanEvent) -> dict[str, Any]:
    """
    Serializes a SpanEvent to a dictionary.

    Datetime fields are rendered as ISO-8601 strings; everything else is
    passed through unchanged.
    """
    serialized: dict[str, Any] = {}
    for field_name, value in asdict(event).items():
        serialized[field_name] = value.isoformat() if isinstance(value, datetime) else value
    return serialized
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
# ---- Module-level shorthands for OpenInference constants ----

# Mime-type and span-kind values.
JSON = OpenInferenceMimeTypeValues.JSON.value

LLM = OpenInferenceSpanKindValues.LLM.value

# Span-level attribute keys.
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION

# Message-level attribute keys.
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

# Tool-call attribute keys.
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON

TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry import UNSET
|
|
5
|
+
from strawberry.relay.types import GlobalID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
|
|
8
|
+
from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
|
|
9
|
+
|
|
10
|
+
from .ChatCompletionMessageInput import ChatCompletionMessageInput
|
|
11
|
+
from .GenerativeModelInput import GenerativeModelInput
|
|
12
|
+
from .InvocationParameters import InvocationParameterInput
|
|
13
|
+
from .TemplateOptions import TemplateOptions
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# GraphQL input for a one-off playground chat completion; ``template``
# optionally supplies options for interpolating variables into the messages.
# (Documented with comments rather than a docstring so the GraphQL schema
# description is unchanged.)
@strawberry.input
class ChatCompletionInput:
    messages: list[ChatCompletionMessageInput]
    model: GenerativeModelInput
    invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
    tools: Optional[list[JSON]] = UNSET
    api_key: Optional[str] = strawberry.field(default=None)
    template: Optional[TemplateOptions] = UNSET
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# GraphQL input for running a chat-completion template over a dataset
# (optionally pinned to a specific dataset version), recording results under
# an experiment identified by the experiment_* fields.
@strawberry.input
class ChatCompletionOverDatasetInput:
    messages: list[ChatCompletionMessageInput]
    model: GenerativeModelInput
    invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
    tools: Optional[list[JSON]] = UNSET
    api_key: Optional[str] = strawberry.field(default=None)
    template_language: TemplateLanguage
    dataset_id: GlobalID
    dataset_version_id: Optional[GlobalID] = None
    experiment_name: Optional[str] = None
    experiment_description: Optional[str] = None
    experiment_metadata: Optional[JSON] = strawberry.field(default_factory=dict)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry import UNSET
|
|
5
|
+
|
|
6
|
+
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# GraphQL input selecting which generative model to call. For Azure OpenAI,
# ``name`` is the deployment name and ``endpoint``/``api_version`` apply.
@strawberry.input
class GenerativeModelInput:
    provider_key: GenerativeProviderKey
    name: str
    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
    endpoint: Optional[str] = UNSET
    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
    api_version: Optional[str] = UNSET
    """ The API version to use for the model. """
|
|
@@ -1,20 +1,163 @@
|
|
|
1
|
-
from
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Annotated, Any, Mapping, Optional, Union
|
|
2
3
|
|
|
3
4
|
import strawberry
|
|
4
5
|
from strawberry import UNSET
|
|
5
6
|
from strawberry.scalars import JSON
|
|
6
7
|
|
|
7
8
|
|
|
9
|
+
# Provider-agnostic ("canonical") names for common generation parameters,
# letting equivalent settings be mapped across different LLM providers.
@strawberry.enum
class CanonicalParameterName(str, Enum):
    TEMPERATURE = "temperature"
    MAX_COMPLETION_TOKENS = "max_completion_tokens"
    STOP_SEQUENCES = "stop_sequences"
    TOP_P = "top_p"
    RANDOM_SEED = "random_seed"
    TOOL_CHOICE = "tool_choice"
    RESPONSE_FORMAT = "response_format"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Names of the typed ``value_*`` fields on InvocationParameterInput; each
# parameter definition advertises which field clients should populate.
# NOTE(review): both ``value_bool`` and ``value_boolean`` exist — this mirrors
# the input type's fields and appears intentional, but worth confirming.
@strawberry.enum
class InvocationInputField(str, Enum):
    value_int = "value_int"
    value_float = "value_float"
    value_bool = "value_bool"
    value_string = "value_string"
    value_json = "value_json"
    value_string_list = "value_string_list"
    value_boolean = "value_boolean"
|
|
29
|
+
|
|
30
|
+
|
|
8
31
|
# GraphQL input carrying one invocation parameter; callers are expected to
# set exactly one of the typed ``value_*`` fields (the rest remain UNSET,
# which extract_parameter distinguishes from an explicit None).
@strawberry.input
class InvocationParameterInput:
    invocation_name: str
    canonical_name: Optional[CanonicalParameterName] = None
    value_int: Optional[int] = UNSET
    value_float: Optional[float] = UNSET
    value_bool: Optional[bool] = UNSET
    value_string: Optional[str] = UNSET
    value_json: Optional[JSON] = UNSET
    value_string_list: Optional[list[str]] = UNSET
    value_boolean: Optional[bool] = UNSET
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Fields shared by every invocation-parameter definition exposed in the schema.
@strawberry.interface
class InvocationParameterBase:
    invocation_name: str
    canonical_name: Optional[CanonicalParameterName] = None
    label: str
    required: bool = False
    hidden: bool = False
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Integer-valued parameter definition (clients populate ``value_int``).
@strawberry.type
class IntInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_int
    default_value: Optional[int] = UNSET
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Unbounded float-valued parameter definition (clients populate ``value_float``).
@strawberry.type
class FloatInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_float
    default_value: Optional[float] = UNSET
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Float parameter with an inclusive [min_value, max_value] range advertised
# to the UI (clients still populate ``value_float``).
@strawberry.type
class BoundedFloatInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_float
    default_value: Optional[float] = UNSET
    min_value: float
    max_value: float
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# String-valued parameter definition (clients populate ``value_string``).
@strawberry.type
class StringInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_string
    default_value: Optional[str] = UNSET
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Arbitrary-JSON parameter definition (clients populate ``value_json``).
@strawberry.type
class JSONInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_json
    default_value: Optional[JSON] = UNSET
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# String-list parameter definition (clients populate ``value_string_list``).
@strawberry.type
class StringListInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_string_list
    default_value: Optional[list[str]] = UNSET
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Boolean parameter definition (clients populate ``value_bool``).
@strawberry.type
class BooleanInvocationParameter(InvocationParameterBase):
    invocation_input_field: InvocationInputField = InvocationInputField.value_bool
    default_value: Optional[bool] = UNSET
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def extract_parameter(
    param_def: InvocationParameterBase, param_input: InvocationParameterInput
) -> Any:
    """Pull the typed value for *param_def* out of *param_input*.

    Falls back to the definition's ``default_value`` when the matching
    ``value_*`` field was not supplied (is UNSET); returns ``None`` for an
    unrecognized parameter-definition type.
    """
    # Ordered (definition type, input attribute) pairs; isinstance checks in
    # this order reproduce the original if/elif chain's matching semantics.
    dispatch: tuple[tuple[type, str], ...] = (
        (IntInvocationParameter, "value_int"),
        (FloatInvocationParameter, "value_float"),
        (BoundedFloatInvocationParameter, "value_float"),
        (StringInvocationParameter, "value_string"),
        (JSONInvocationParameter, "value_json"),
        (StringListInvocationParameter, "value_string_list"),
        (BooleanInvocationParameter, "value_bool"),
    )
    for definition_type, attr_name in dispatch:
        if isinstance(param_def, definition_type):
            supplied = getattr(param_input, attr_name)
            return supplied if supplied is not UNSET else param_def.default_value
    return None
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def validate_invocation_parameters(
    parameters: list["InvocationParameter"],
    input: Mapping[str, Any],
) -> None:
    """Ensure every required invocation parameter appears in *input*.

    Raises:
        ValueError: for the first required parameter missing from *input*.
    """
    for definition in parameters:
        is_missing = definition.required and definition.invocation_name not in input
        if is_missing:
            raise ValueError(f"Required parameter {definition.invocation_name} not provided")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# Create the union for output types
# Annotated union over every concrete invocation-parameter definition;
# strawberry exposes it in the GraphQL schema as ``InvocationParameter``.
InvocationParameter = Annotated[
    Union[
        IntInvocationParameter,
        FloatInvocationParameter,
        BoundedFloatInvocationParameter,
        StringInvocationParameter,
        JSONInvocationParameter,
        StringListInvocationParameter,
        BooleanInvocationParameter,
    ],
    strawberry.union("InvocationParameter"),
]
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
import strawberry
|
|
2
2
|
|
|
3
3
|
from phoenix.server.api.mutations.api_key_mutations import ApiKeyMutationMixin
|
|
4
|
+
from phoenix.server.api.mutations.chat_mutations import (
|
|
5
|
+
ChatCompletionMutationMixin,
|
|
6
|
+
)
|
|
4
7
|
from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
|
|
5
8
|
from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
|
|
6
9
|
from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
|
|
@@ -20,5 +23,6 @@ class Mutation(
|
|
|
20
23
|
SpanAnnotationMutationMixin,
|
|
21
24
|
TraceAnnotationMutationMixin,
|
|
22
25
|
UserMutationMixin,
|
|
26
|
+
ChatCompletionMutationMixin,
|
|
23
27
|
):
|
|
24
28
|
pass
|