arize-phoenix 5.6.0__py3-none-any.whl → 5.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +2 -2
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/RECORD +34 -25
- phoenix/config.py +42 -0
- phoenix/server/api/helpers/playground_clients.py +671 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +325 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/InvocationParameters.py +156 -13
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/mutations/__init__.py +4 -0
- phoenix/server/api/mutations/chat_mutations.py +374 -0
- phoenix/server/api/queries.py +41 -52
- phoenix/server/api/schema.py +42 -10
- phoenix/server/api/subscriptions.py +326 -595
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
- phoenix/server/api/types/GenerativeProvider.py +27 -3
- phoenix/server/api/types/Span.py +37 -0
- phoenix/server/api/types/TemplateLanguage.py +9 -0
- phoenix/server/app.py +61 -13
- phoenix/server/main.py +14 -1
- phoenix/server/static/.vite/manifest.json +9 -9
- phoenix/server/static/assets/{components-C70HJiXz.js → components-Csu8UKOs.js} +114 -114
- phoenix/server/static/assets/{index-DLe1Oo3l.js → index-Bk5C9EA7.js} +1 -1
- phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-UeWaKXNs.js} +328 -268
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +4 -0
- phoenix/session/session.py +15 -1
- phoenix/utilities/template_formatters.py +11 -1
- phoenix/version.py +1 -1
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
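
Most of this release is a new prompt playground backend: provider streaming clients, a client registry, and span-building helpers move into dedicated modules (`playground_clients.py`, `playground_registry.py`, `playground_spans.py`), chat completions gain mutations, richer input types, and a dataset-wide subscription, and `phoenix/server/api/subscriptions.py` is rewritten around those helpers (+326 −595). Its diff follows. Source lines the diff viewer cut off mid-line are left truncated below, and removed lines whose text was lost entirely appear as bare `-` markers.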
phoenix/server/api/subscriptions.py
@@ -1,419 +1,73 @@
-import
-from
-from collections import
-from
-
-
-
-
-
-
+import logging
+from asyncio import FIRST_COMPLETED, Task, create_task, wait
+from collections.abc import Iterator
+from typing import (
+    Any,
+    AsyncIterator,
+    Collection,
+    Iterable,
+    Mapping,
+    Optional,
+    TypeVar,
+)
 
 import strawberry
-from openinference.
-from
-
-
-    OpenInferenceSpanKindValues,
-    SpanAttributes,
-    ToolAttributes,
-    ToolCallAttributes,
-)
-from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
-from opentelemetry.trace import StatusCode
-from sqlalchemy import insert, select
-from strawberry import UNSET
-from strawberry.scalars import JSON as JSONScalarType
+from openinference.semconv.trace import SpanAttributes
+from sqlalchemy import and_, func, insert, select
+from sqlalchemy.orm import load_only
+from strawberry.relay.types import GlobalID
 from strawberry.types import Info
 from typing_extensions import TypeAlias, assert_never
 
-from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
 from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest
-from phoenix.server.api.
-
+from phoenix.server.api.helpers.playground_clients import (
+    PlaygroundStreamingClient,
+    initialize_playground_clients,
+)
+from phoenix.server.api.helpers.playground_registry import (
+    PLAYGROUND_CLIENT_REGISTRY,
+)
+from phoenix.server.api.helpers.playground_spans import streaming_llm_span
+from phoenix.server.api.input_types.ChatCompletionInput import (
+    ChatCompletionInput,
+    ChatCompletionOverDatasetInput,
+)
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
-from phoenix.server.api.types.
-
-
-
-
-    SpanEvent,
-    SpanException,
+from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+    ChatCompletionOverDatasetSubscriptionResult,
+    ChatCompletionSubscriptionError,
+    ChatCompletionSubscriptionPayload,
+    FinishedChatCompletion,
 )
-from phoenix.
+from phoenix.server.api.types.Dataset import Dataset
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.Experiment import to_gql_experiment
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.Span import to_gql_span
+from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
+from phoenix.server.dml_event import SpanInsertEvent
+from phoenix.trace.attributes import get_attribute_value
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
     MustacheTemplateFormatter,
     TemplateFormatter,
+    TemplateFormatterError,
 )
 
-
-from anthropic.types import MessageParam
-from openai.types import CompletionUsage
-from openai.types.chat import (
-    ChatCompletionMessageParam,
-    ChatCompletionMessageToolCallParam,
-)
-
-PLAYGROUND_PROJECT_NAME = "playground"
-
-ToolCallID: TypeAlias = str
-SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
-
-
-@strawberry.enum
-class TemplateLanguage(Enum):
-    MUSTACHE = "MUSTACHE"
-    F_STRING = "F_STRING"
-
+GenericType = TypeVar("GenericType")
 
-
-class TemplateOptions:
-    variables: JSONScalarType
-    language: TemplateLanguage
+logger = logging.getLogger(__name__)
 
+initialize_playground_clients()
 
-
-
-    content: str
-
-
-@strawberry.type
-class FunctionCallChunk:
-    name: str
-    arguments: str
-
-
-@strawberry.type
-class ToolCallChunk:
-    id: str
-    function: FunctionCallChunk
-
-
-@strawberry.type
-class ChatCompletionSubscriptionError:
-    message: str
-
-
-@strawberry.type
-class FinishedChatCompletion:
-    span: Span
-
-
-ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
-
-ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-    Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
-    strawberry.union("ChatCompletionSubscriptionPayload"),
+ChatCompletionMessage: TypeAlias = tuple[
+    ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
 ]
-
-
-@strawberry.input
-class GenerativeModelInput:
-    provider_key: GenerativeProviderKey
-    name: str
-    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
-    endpoint: Optional[str] = UNSET
-    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
-    api_version: Optional[str] = UNSET
-    """ The API version to use for the model. """
-
-
-@strawberry.input
-class ChatCompletionInput:
-    messages: list[ChatCompletionMessageInput]
-    model: GenerativeModelInput
-    invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
-    tools: Optional[list[JSONScalarType]] = UNSET
-    template: Optional[TemplateOptions] = UNSET
-    api_key: Optional[str] = strawberry.field(default=None)
-
-
-PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
-    GenerativeProviderKey, type["PlaygroundStreamingClient"]
-] = {}
-
-
-def register_llm_client(
-    provider_key: GenerativeProviderKey,
-) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
-    def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
-        PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
-        return cls
-
-    return decorator
-
-
-class PlaygroundStreamingClient(ABC):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        self._set_span_attributes = set_span_attributes
-
-    @abstractmethod
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        # a yield statement is needed to satisfy the type-checker
-        # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
-        yield TextChunk(content="")
-
-
-@register_llm_client(GenerativeProviderKey.OPENAI)
-class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        from openai import AsyncOpenAI
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        self.client = AsyncOpenAI(api_key=api_key)
-        self.model_name = model.name
-
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        from openai import NOT_GIVEN
-        from openai.types.chat import ChatCompletionStreamOptionsParam
-
-        # Convert standard messages to OpenAI messages
-        openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-        tool_call_ids: dict[int, str] = {}
-        token_usage: Optional["CompletionUsage"] = None
-        async for chunk in await self.client.chat.completions.create(
-            messages=openai_messages,
-            model=self.model_name,
-            stream=True,
-            stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
-            tools=tools or NOT_GIVEN,
-            **invocation_parameters,
-        ):
-            if (usage := chunk.usage) is not None:
-                token_usage = usage
-                continue
-            choice = chunk.choices[0]
-            delta = choice.delta
-            if choice.finish_reason is None:
-                if isinstance(chunk_content := delta.content, str):
-                    text_chunk = TextChunk(content=chunk_content)
-                    yield text_chunk
-                if (tool_calls := delta.tool_calls) is not None:
-                    for tool_call_index, tool_call in enumerate(tool_calls):
-                        tool_call_id = (
-                            tool_call.id
-                            if tool_call.id is not None
-                            else tool_call_ids[tool_call_index]
-                        )
-                        tool_call_ids[tool_call_index] = tool_call_id
-                        if (function := tool_call.function) is not None:
-                            tool_call_chunk = ToolCallChunk(
-                                id=tool_call_id,
-                                function=FunctionCallChunk(
-                                    name=function.name or "",
-                                    arguments=function.arguments or "",
-                                ),
-                            )
-                            yield tool_call_chunk
-        if token_usage is not None and self._set_span_attributes:
-            self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
-
-    def to_openai_chat_completion_param(
-        self,
-        role: ChatCompletionMessageRole,
-        content: JSONScalarType,
-        tool_call_id: Optional[str] = None,
-        tool_calls: Optional[list[JSONScalarType]] = None,
-    ) -> "ChatCompletionMessageParam":
-        from openai.types.chat import (
-            ChatCompletionAssistantMessageParam,
-            ChatCompletionSystemMessageParam,
-            ChatCompletionToolMessageParam,
-            ChatCompletionUserMessageParam,
-        )
-
-        if role is ChatCompletionMessageRole.USER:
-            return ChatCompletionUserMessageParam(
-                {
-                    "content": content,
-                    "role": "user",
-                }
-            )
-        if role is ChatCompletionMessageRole.SYSTEM:
-            return ChatCompletionSystemMessageParam(
-                {
-                    "content": content,
-                    "role": "system",
-                }
-            )
-        if role is ChatCompletionMessageRole.AI:
-            if tool_calls is None:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                    }
-                )
-            else:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                        "tool_calls": [
-                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
-                        ],
-                    }
-                )
-        if role is ChatCompletionMessageRole.TOOL:
-            if tool_call_id is None:
-                raise ValueError("tool_call_id is required for tool messages")
-            return ChatCompletionToolMessageParam(
-                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
-            )
-        assert_never(role)
-
-    def to_openai_tool_call_param(
-        self,
-        tool_call: JSONScalarType,
-    ) -> "ChatCompletionMessageToolCallParam":
-        from openai.types.chat import ChatCompletionMessageToolCallParam
-
-        return ChatCompletionMessageToolCallParam(
-            id=tool_call.get("id", ""),
-            function={
-                "name": tool_call.get("function", {}).get("name", ""),
-                "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
-            },
-            type="function",
-        )
-
-    @staticmethod
-    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
-        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
-        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
-
-
-@register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
-class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ):
-        from openai import AsyncAzureOpenAI
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        if model.endpoint is None or model.api_version is None:
-            raise ValueError("endpoint and api_version are required for Azure OpenAI models")
-        self.client = AsyncAzureOpenAI(
-            api_key=api_key,
-            azure_endpoint=model.endpoint,
-            api_version=model.api_version,
-        )
-
-
-@register_llm_client(GenerativeProviderKey.ANTHROPIC)
-class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        import anthropic
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        self.client = anthropic.AsyncAnthropic(api_key=api_key)
-        self.model_name = model.name
-
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        import anthropic.lib.streaming as anthropic_streaming
-        import anthropic.types as anthropic_types
-
-        anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
-
-        anthropic_params = {
-            "messages": anthropic_messages,
-            "model": self.model_name,
-            "system": system_prompt,
-            "max_tokens": 1024,
-            **invocation_parameters,
-        }
-        async with self.client.messages.stream(**anthropic_params) as stream:
-            async for event in stream:
-                if isinstance(event, anthropic_types.RawMessageStartEvent):
-                    if self._set_span_attributes:
-                        self._set_span_attributes(
-                            {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
-                        )
-                elif isinstance(event, anthropic_streaming.TextEvent):
-                    yield TextChunk(content=event.text)
-                elif isinstance(event, anthropic_streaming.MessageStopEvent):
-                    if self._set_span_attributes:
-                        self._set_span_attributes(
-                            {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
-                        )
-                elif isinstance(
-                    event,
-                    (
-                        anthropic_types.RawContentBlockStartEvent,
-                        anthropic_types.RawContentBlockDeltaEvent,
-                        anthropic_types.RawMessageDeltaEvent,
-                        anthropic_streaming.ContentBlockStopEvent,
-                    ),
-                ):
-                    # event types emitted by the stream that don't contain useful information
-                    pass
-                elif isinstance(event, anthropic_streaming.InputJsonEvent):
-                    raise NotImplementedError
-                else:
-                    assert_never(event)
-
-    def _build_anthropic_messages(
-        self,
-        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    ) -> tuple[list["MessageParam"], str]:
-        anthropic_messages: list["MessageParam"] = []
-        system_prompt = ""
-        for role, content, _tool_call_id, _tool_calls in messages:
-            if role == ChatCompletionMessageRole.USER:
-                anthropic_messages.append({"role": "user", "content": content})
-            elif role == ChatCompletionMessageRole.AI:
-                anthropic_messages.append({"role": "assistant", "content": content})
-            elif role == ChatCompletionMessageRole.SYSTEM:
-                system_prompt += content + "\n"
-            elif role == ChatCompletionMessageRole.TOOL:
-                raise NotImplementedError
-            else:
-                assert_never(role)
-
-        return anthropic_messages, system_prompt
+DatasetExampleID: TypeAlias = GlobalID
+PLAYGROUND_PROJECT_NAME = "playground"
 
 
 @strawberry.type
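The hunk above removes the in-module registry, a plain dict populated by the `register_llm_client` decorator and keyed only by provider. The subscription below instead resolves clients with `PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)`, so lookup is model-aware. The real registry lives in the new `playground_registry.py` (+70 lines), which this page does not expand; the following is a minimal sketch of such a registry, where every name other than `PLAYGROUND_CLIENT_REGISTRY` and `get_client` is an assumption for illustration:

```python
from typing import Callable, Optional

ProviderKey = str  # stand-in for the real GenerativeProviderKey enum
ClientClass = type  # stand-in for type[PlaygroundStreamingClient]


class PlaygroundClientRegistry:
    """Sketch only: provider defaults plus optional per-model overrides."""

    def __init__(self) -> None:
        self._default_clients: dict[ProviderKey, ClientClass] = {}
        self._model_overrides: dict[tuple[ProviderKey, str], ClientClass] = {}

    def register(
        self, provider_key: ProviderKey, model_name: Optional[str] = None
    ) -> Callable[[ClientClass], ClientClass]:
        # Decorator mirroring the removed register_llm_client, extended with an
        # optional per-model override (an assumption, not shown in this diff).
        def decorator(cls: ClientClass) -> ClientClass:
            if model_name is None:
                self._default_clients[provider_key] = cls
            else:
                self._model_overrides[(provider_key, model_name)] = cls
            return cls

        return decorator

    def get_client(
        self, provider_key: ProviderKey, model_name: str
    ) -> Optional[ClientClass]:
        # Prefer a model-specific client, then fall back to the provider default;
        # returning None lets the caller raise BadRequest, as the subscription does.
        return self._model_overrides.get(
            (provider_key, model_name)
        ) or self._default_clients.get(provider_key)


PLAYGROUND_CLIENT_REGISTRY = PlaygroundClientRegistry()
```

Whatever its internals, the registry plus the module-level `initialize_playground_clients()` call replaces import-time decorator side effects with an explicit registration step.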
@@ -422,15 +76,15 @@ class Subscription:
     async def chat_completion(
         self, info: Info[Context, None], input: ChatCompletionInput
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-        # Determine which LLM client to use based on provider_key
         provider_key = input.model.provider_key
-
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
             raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
         llm_client = llm_client_class(
             model=input.model,
             api_key=input.api_key,
-            set_span_attributes=lambda attrs: attributes.update(attrs),
         )
+
         messages = [
             (
                 message.role,
@@ -441,69 +95,29 @@ class Subscription:
             for message in input.messages
         ]
         if template_options := input.template:
-            messages = list(
-
-
-
-
-
-                    _llm_tools(input.tools or []),
-                    _llm_input_messages(messages),
-                    _llm_invocation_parameters(invocation_parameters),
-                    _input_value_and_mime_type(input),
+            messages = list(
+                _formatted_messages(
+                    messages=messages,
+                    template_language=template_options.language,
+                    template_variables=template_options.variables,
+                )
             )
+        invocation_parameters = llm_client.construct_invocation_parameters(
+            input.invocation_parameters
         )
-
-
-
-
-
-        response_chunks = []
-        text_chunks: list[TextChunk] = []
-        tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
-        try:
-            start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+        async with streaming_llm_span(
+            input=input,
+            messages=messages,
+            invocation_parameters=invocation_parameters,
+        ) as span:
             async for chunk in llm_client.chat_completion_create(
-                messages=messages,
-                tools=input.tools or [],
-                **invocation_parameters,
+                messages=messages, tools=input.tools or [], **invocation_parameters
             ):
-
-
-
-
-
-                    yield chunk
-                    tool_call_chunks[chunk.id].append(chunk)
-                else:
-                    assert_never(chunk)
-            status_code = StatusCode.OK
-        except Exception as error:
-            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
-            status_code = StatusCode.ERROR
-            status_message = str(error)
-            events.append(
-                SpanException(
-                    timestamp=end_time,
-                    message=status_message,
-                    exception_type=type(error).__name__,
-                    exception_escaped=False,
-                    exception_stacktrace=format_exc(),
-                )
-            )
-            yield ChatCompletionSubscriptionError(message=status_message)
-        else:
-            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
-            attributes.update(
-                chain(
-                    _output_value_and_mime_type(response_chunks),
-                    _llm_output_messages(text_chunks, tool_call_chunks),
-                )
-            )
-        prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-        completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-        trace_id = _generate_trace_id()
-        span_id = _generate_span_id()
+                span.add_response_chunk(chunk)
+                yield chunk
+        span.set_attributes(llm_client.attributes)
+        if span.error_message is not None:
+            yield ChatCompletionSubscriptionError(message=span.error_message)
         async with info.context.db() as session:
            if (
                 playground_project_id := await session.scalar(
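The hand-rolled try/except bookkeeping removed above is replaced by the `streaming_llm_span` async context manager from the new `playground_spans.py` (+325 lines, not expanded on this page). Its call sites in this file imply that it times the call, accumulates response chunks, captures exceptions into `error_message` instead of propagating them, and exposes `attributes`, `start_time`, `end_time`, `trace_id`, and `add_to_session(...)`. A minimal sketch of that shape, with all internals assumed (`trace_id` and `add_to_session` omitted for brevity):

```python
from datetime import datetime, timezone
from typing import Any, Optional


class streaming_llm_span:
    """Sketch only: field and method names come from call sites in this file."""

    def __init__(self, *, input: Any, messages: Any, invocation_parameters: Any) -> None:
        self.attributes: dict[str, Any] = {}
        self.error_message: Optional[str] = None
        self.start_time: Optional[datetime] = None
        self.end_time: Optional[datetime] = None
        self._response_chunks: list[Any] = []

    async def __aenter__(self) -> "streaming_llm_span":
        self.start_time = datetime.now(timezone.utc)
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        self.end_time = datetime.now(timezone.utc)
        if exc is not None:
            # Record the failure so the caller can emit a
            # ChatCompletionSubscriptionError payload instead of crashing.
            self.error_message = str(exc)
        return True  # suppress the exception

    def add_response_chunk(self, chunk: Any) -> None:
        self._response_chunks.append(chunk)

    def set_attributes(self, attributes: dict[str, Any]) -> None:
        self.attributes.update(attributes)
```

Because `__aenter__` returns the instance, this shape supports both usages seen in the diff: `async with streaming_llm_span(...) as span:` and constructing the span first, storing it in a dict, then entering it.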
@@ -518,138 +132,262 @@ class Subscription:
                         description="Traces from prompt playground",
                     )
                 )
-
-                project_rowid=playground_project_id,
-                trace_id=trace_id,
-                start_time=start_time,
-                end_time=end_time,
-            )
-            playground_span = models.Span(
-                trace_rowid=playground_trace.id,
-                span_id=span_id,
-                parent_id=None,
-                name="ChatCompletion",
-                span_kind=LLM,
-                start_time=start_time,
-                end_time=end_time,
-                attributes=unflatten(attributes.items()),
-                events=[_serialize_event(event) for event in events],
-                status_code=status_code.name,
-                status_message=status_message,
-                cumulative_error_count=int(status_code is StatusCode.ERROR),
-                cumulative_llm_token_count_prompt=prompt_tokens,
-                cumulative_llm_token_count_completion=completion_tokens,
-                llm_token_count_prompt=prompt_tokens,
-                llm_token_count_completion=completion_tokens,
-                trace=playground_trace,
-            )
-            session.add(playground_trace)
-            session.add(playground_span)
+            db_span = span.add_to_session(session, playground_project_id)
             await session.flush()
-            yield FinishedChatCompletion(span=to_gql_span(
+            yield FinishedChatCompletion(span=to_gql_span(db_span))
         info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
 
+    @strawberry.subscription
+    async def chat_completion_over_dataset(
+        self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
+    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+        provider_key = input.model.provider_key
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+            if input.dataset_version_id
+            else None
+        )
+        revision_ids = (
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(models.DatasetExample)
+            .where(models.DatasetExample.dataset_id == dataset_id)
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
+        )
+        if version_id:
+            version_id_subquery = (
+                select(models.DatasetVersion.id)
+                .where(models.DatasetVersion.dataset_id == dataset_id)
+                .where(models.DatasetVersion.id == version_id)
+                .scalar_subquery()
+            )
+            revision_ids = revision_ids.where(
+                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+            )
+        query = (
+            select(models.DatasetExampleRevision)
+            .where(
+                and_(
+                    models.DatasetExampleRevision.id.in_(revision_ids),
+                    models.DatasetExampleRevision.revision_kind != "DELETE",
+                )
+            )
+            .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
+            .options(
+                load_only(
+                    models.DatasetExampleRevision.dataset_example_id,
+                    models.DatasetExampleRevision.input,
+                )
+            )
+        )
+        async with info.context.db() as session:
+            revisions = [revision async for revision in await session.stream_scalars(query)]
+        if not revisions:
+            raise BadRequest("No examples found for the given dataset and version")
+
+        spans: dict[DatasetExampleID, streaming_llm_span] = {}
+        async for payload in _merge_iterators(
+            [
+                _stream_chat_completion_over_dataset_example(
+                    input=input,
+                    llm_client_class=llm_client_class,
+                    revision=revision,
+                    spans=spans,
+                )
+                for revision in revisions
+            ]
+        ):
+            yield payload
 
-
-
-
-
-) -> Iterator[tuple[str, Any]]:
-    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
-        yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
-        if tool_calls is not None:
-            for tool_call_index, tool_call in enumerate(tool_calls):
-                yield (
-                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                    tool_call["function"]["name"],
+        async with info.context.db() as session:
+            if (
+                playground_project_id := await session.scalar(
+                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                 )
-
-
-
-
+            ) is None:
+                playground_project_id = await session.scalar(
+                    insert(models.Project)
+                    .returning(models.Project.id)
+                    .values(
+                        name=PLAYGROUND_PROJECT_NAME,
+                        description="Traces from prompt playground",
                     )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                )
+            db_spans = {
+                example_id: span.add_to_session(session, playground_project_id)
+                for example_id, span in spans.items()
+            }
+            assert (
+                dataset_name := await session.scalar(
+                    select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+                )
+            ) is not None
+            if version_id is None:
+                resolved_version_id = await session.scalar(
+                    select(models.DatasetVersion.id)
+                    .where(models.DatasetVersion.dataset_id == dataset_id)
+                    .order_by(models.DatasetVersion.id.desc())
+                    .limit(1)
+                )
+            else:
+                resolved_version_id = await session.scalar(
+                    select(models.DatasetVersion.id).where(
+                        and_(
+                            models.DatasetVersion.dataset_id == dataset_id,
+                            models.DatasetVersion.id == version_id,
+                        )
+                    )
+                )
+            assert resolved_version_id is not None
+            resolved_version_node_id = GlobalID(DatasetVersion.__name__, str(resolved_version_id))
+            experiment = models.Experiment(
+                dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
+                dataset_version_id=resolved_version_id,
+                name=input.experiment_name or _DEFAULT_PLAYGROUND_EXPERIMENT_NAME,
+                description=input.experiment_description
+                or _default_playground_experiment_description(dataset_name=dataset_name),
+                repetitions=1,
+                metadata_=input.experiment_metadata
+                or _default_playground_experiment_metadata(
+                    dataset_name=dataset_name,
+                    dataset_id=input.dataset_id,
+                    version_id=resolved_version_node_id,
+                ),
+                project_name=PLAYGROUND_PROJECT_NAME,
+            )
+            session.add(experiment)
+            await session.flush()
+            runs = [
+                models.ExperimentRun(
+                    experiment_id=experiment.id,
+                    dataset_example_id=from_global_id_with_expected_type(
+                        example_id, DatasetExample.__name__
+                    ),
+                    trace_id=span.trace_id,
+                    output=models.ExperimentRunOutput(
+                        task_output=_get_playground_experiment_task_output(span)
+                    ),
+                    repetition_number=1,
+                    start_time=span.start_time,
+                    end_time=span.end_time,
+                    error=error_message
+                    if (error_message := span.error_message) is not None
+                    else None,
+                    prompt_token_count=get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT),
+                    completion_token_count=get_attribute_value(
+                        span.attributes, LLM_TOKEN_COUNT_COMPLETION
+                    ),
+                )
+                for example_id, span in spans.items()
+            ]
+            session.add_all(runs)
+            await session.flush()
+            for example_id in spans:
+                yield FinishedChatCompletion(
+                    span=to_gql_span(db_spans[example_id]),
+                    dataset_example_id=example_id,
                 )
-
-
-
-
+        yield ChatCompletionOverDatasetSubscriptionResult(experiment=to_gql_experiment(experiment))
+
+
+async def _stream_chat_completion_over_dataset_example(
+    *,
+    input: ChatCompletionOverDatasetInput,
+    llm_client_class: type["PlaygroundStreamingClient"],
+    revision: models.DatasetExampleRevision,
+    spans: dict[DatasetExampleID, streaming_llm_span],
+) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
+    llm_client = llm_client_class(
+        model=input.model,
+        api_key=input.api_key,
+    )
+    invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
+    messages = [
+        (
+            message.role,
+            message.content,
+            message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+            message.tool_calls if isinstance(message.tool_calls, list) else None,
+        )
+        for message in input.messages
+    ]
+    try:
+        messages = list(
+            _formatted_messages(
+                messages=messages,
+                template_language=input.template_language,
+                template_variables=revision.input,
             )
+        )
+    except TemplateFormatterError as error:
+        yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
+        return
+    span = streaming_llm_span(
+        input=input,
+        messages=messages,
+        invocation_parameters=invocation_parameters,
+    )
+    spans[example_id] = span
+    async with span:
+        async for chunk in llm_client.chat_completion_create(
+            messages=messages, tools=input.tools or [], **invocation_parameters
+        ):
+            span.add_response_chunk(chunk)
+            chunk.dataset_example_id = example_id
+            yield chunk
+    span.set_attributes(llm_client.attributes)
+    if span.error_message is not None:
+        yield ChatCompletionSubscriptionError(
+            message=span.error_message, dataset_example_id=example_id
+        )
 
 
-def
-
-
-
-
+async def _merge_iterators(
+    iterators: Collection[AsyncIterator[GenericType]],
+) -> AsyncIterator[GenericType]:
+    tasks: dict[AsyncIterator[GenericType], Task[GenericType]] = {
+        iterable: _as_task(iterable) for iterable in iterators
+    }
+    while tasks:
+        completed_tasks, _ = await wait(tasks.values(), return_when=FIRST_COMPLETED)
+        for task in completed_tasks:
+            iterator = next(it for it, t in tasks.items() if t == task)
+            try:
+                yield task.result()
+            except StopAsyncIteration:
+                del tasks[iterator]
+            except Exception as error:
+                del tasks[iterator]
+                logger.exception(error)
+            else:
+                tasks[iterator] = _as_task(iterator)
 
 
-def
-
-    Generates a random span ID in hexadecimal format.
-    """
-    return _hex(DefaultOTelIDGenerator().generate_span_id())
+def _as_task(iterable: AsyncIterator[GenericType]) -> Task[GenericType]:
+    return create_task(_as_coroutine(iterable))
 
 
-def
-
-    Converts an integer to a hexadecimal string.
-    """
-    return hex(number)[2:]
+async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
+    return await iterable.__anext__()
 
 
 def _formatted_messages(
-
-
+    *,
+    messages: Iterable[ChatCompletionMessage],
+    template_language: TemplateLanguage,
+    template_variables: Mapping[str, Any],
 ) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
-    template_formatter = _template_formatter(template_language=
+    template_formatter = _template_formatter(template_language=template_language)
     (
         roles,
         templates,
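`_merge_iterators` above is the fan-in primitive behind `chat_completion_over_dataset`: it keeps one pending `__anext__` task per stream, uses `asyncio.wait(..., return_when=FIRST_COMPLETED)` to pick whichever finishes first, and reschedules that stream until it raises `StopAsyncIteration`; a stream that raises any other exception is logged and dropped while the rest keep going. The runnable demo below copies the three helpers exactly as defined in this diff and drives them with a hypothetical `_ticker` stream standing in for a dataset example's chunk stream:

```python
import asyncio
import logging
from asyncio import FIRST_COMPLETED, Task, create_task, wait
from typing import AsyncIterator, Collection, TypeVar

GenericType = TypeVar("GenericType")
logger = logging.getLogger(__name__)


def _as_task(iterable: AsyncIterator[GenericType]) -> Task[GenericType]:
    return create_task(_as_coroutine(iterable))


async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
    return await iterable.__anext__()


async def _merge_iterators(
    iterators: Collection[AsyncIterator[GenericType]],
) -> AsyncIterator[GenericType]:
    # One pending "next item" task per iterator; yield in completion order.
    tasks = {iterable: _as_task(iterable) for iterable in iterators}
    while tasks:
        completed_tasks, _ = await wait(tasks.values(), return_when=FIRST_COMPLETED)
        for task in completed_tasks:
            iterator = next(it for it, t in tasks.items() if t == task)
            try:
                yield task.result()
            except StopAsyncIteration:
                del tasks[iterator]  # this stream is exhausted
            except Exception as error:
                del tasks[iterator]  # drop a failing stream, keep the rest
                logger.exception(error)
            else:
                tasks[iterator] = _as_task(iterator)  # schedule the next item


async def _ticker(name: str, n: int, delay: float) -> AsyncIterator[str]:
    # Hypothetical stand-in for one dataset example's chunk stream.
    for i in range(n):
        await asyncio.sleep(delay)
        yield f"{name}:{i}"


async def main() -> None:
    # Items from both streams arrive interleaved, in completion order.
    async for item in _merge_iterators([_ticker("a", 3, 0.01), _ticker("b", 3, 0.025)]):
        print(item)


asyncio.run(main())
```

This is why each payload carries a `dataset_example_id`: chunks from different examples interleave on the wire, and the client needs the id to route them back to the right row.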
@@ -657,7 +395,7 @@ def _formatted_messages(
         tool_calls,
     ) = zip(*messages)
     formatted_templates = map(
-        lambda template: template_formatter.format(template, **
+        lambda template: template_formatter.format(template, **template_variables),
         templates,
     )
     formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
@@ -675,36 +413,29 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
     assert_never(template_language)
 
 
-def
-
-
-
-    return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+def _get_playground_experiment_task_output(
+    span: streaming_llm_span,
+) -> Any:
+    return get_attribute_value(span.attributes, LLM_OUTPUT_MESSAGES)
 
 
-
+_DEFAULT_PLAYGROUND_EXPERIMENT_NAME = "playground-experiment"
 
-LLM = OpenInferenceSpanKindValues.LLM.value
 
-
-
-INPUT_VALUE = SpanAttributes.INPUT_VALUE
-OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
-OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
-LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
-LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
-LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
-LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
-LLM_TOOLS = SpanAttributes.LLM_TOOLS
-LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
-LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
-LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+def _default_playground_experiment_description(dataset_name: str) -> str:
+    return f'Playground experiment for dataset "{dataset_name}"'
 
-MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
-MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
-MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
 
-
+def _default_playground_experiment_metadata(
+    dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
+) -> dict[str, Any]:
+    return {
+        "dataset_name": dataset_name,
+        "dataset_id": str(dataset_id),
+        "dataset_version_id": str(version_id),
+    }
 
-
+
+LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT