arize-phoenix 5.6.0__py3-none-any.whl → 5.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix might be problematic.
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/METADATA +4 -6
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/RECORD +39 -30
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/WHEEL +1 -1
- phoenix/config.py +58 -0
- phoenix/server/api/helpers/playground_clients.py +758 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +422 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/InvocationParameters.py +155 -13
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/mutations/__init__.py +4 -0
- phoenix/server/api/mutations/chat_mutations.py +355 -0
- phoenix/server/api/queries.py +41 -52
- phoenix/server/api/schema.py +42 -10
- phoenix/server/api/subscriptions.py +378 -595
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/GenerativeProvider.py +27 -3
- phoenix/server/api/types/Span.py +37 -0
- phoenix/server/api/types/TemplateLanguage.py +9 -0
- phoenix/server/app.py +75 -13
- phoenix/server/grpc_server.py +3 -1
- phoenix/server/main.py +14 -1
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-C70HJiXz.js → components-MllbfxfJ.js} +168 -150
- phoenix/server/static/assets/{index-DLe1Oo3l.js → index-BVO2YcT1.js} +2 -2
- phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-BHfC6jnL.js} +464 -310
- phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-BEuNhfwH.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-Bskhzyjm.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-DLlXCf0x.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-CRqhvLYg.js} +1 -1
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +4 -0
- phoenix/session/session.py +15 -1
- phoenix/utilities/template_formatters.py +11 -1
- phoenix/version.py +1 -1
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/LICENSE +0 -0
--- a/phoenix/server/api/subscriptions.py
+++ b/phoenix/server/api/subscriptions.py
@@ -1,419 +1,85 @@
-import
-
-from
-from collections.abc import AsyncIterator,
-from dataclasses import asdict
+import asyncio
+import logging
+from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for
+from collections.abc import AsyncIterator, Iterator
 from datetime import datetime, timezone
-from
-
-
-
+from typing import (
+    Any,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    TypeVar,
+    cast,
+)

 import strawberry
-from openinference.
-from
-
-
-    OpenInferenceSpanKindValues,
-    SpanAttributes,
-    ToolAttributes,
-    ToolCallAttributes,
-)
-from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
-from opentelemetry.trace import StatusCode
-from sqlalchemy import insert, select
-from strawberry import UNSET
-from strawberry.scalars import JSON as JSONScalarType
+from openinference.semconv.trace import SpanAttributes
+from sqlalchemy import and_, func, insert, select
+from sqlalchemy.orm import load_only
+from strawberry.relay.types import GlobalID
 from strawberry.types import Info
 from typing_extensions import TypeAlias, assert_never

 from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
 from phoenix.server.api.context import Context
-from phoenix.server.api.exceptions import BadRequest
-from phoenix.server.api.
-
+from phoenix.server.api.exceptions import BadRequest, NotFound
+from phoenix.server.api.helpers.playground_clients import (
+    PlaygroundStreamingClient,
+    initialize_playground_clients,
+)
+from phoenix.server.api.helpers.playground_registry import (
+    PLAYGROUND_CLIENT_REGISTRY,
+)
+from phoenix.server.api.helpers.playground_spans import (
+    get_db_experiment_run,
+    get_db_span,
+    get_db_trace,
+    streaming_llm_span,
+)
+from phoenix.server.api.input_types.ChatCompletionInput import (
+    ChatCompletionInput,
+    ChatCompletionOverDatasetInput,
+)
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
-from phoenix.server.api.types.
-
-
-
-
-    SpanEvent,
-    SpanException,
+from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+    ChatCompletionSubscriptionError,
+    ChatCompletionSubscriptionExperiment,
+    ChatCompletionSubscriptionPayload,
+    ChatCompletionSubscriptionResult,
 )
-from phoenix.
+from phoenix.server.api.types.Dataset import Dataset
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.Experiment import to_gql_experiment
+from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.Span import to_gql_span
+from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
+from phoenix.server.dml_event import SpanInsertEvent
+from phoenix.server.types import DbSessionFactory
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
     MustacheTemplateFormatter,
     TemplateFormatter,
+    TemplateFormatterError,
 )

-
-    from anthropic.types import MessageParam
-    from openai.types import CompletionUsage
-    from openai.types.chat import (
-        ChatCompletionMessageParam,
-        ChatCompletionMessageToolCallParam,
-    )
-
-PLAYGROUND_PROJECT_NAME = "playground"
-
-ToolCallID: TypeAlias = str
-SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
-
-
-@strawberry.enum
-class TemplateLanguage(Enum):
-    MUSTACHE = "MUSTACHE"
-    F_STRING = "F_STRING"
-
-
-@strawberry.input
-class TemplateOptions:
-    variables: JSONScalarType
-    language: TemplateLanguage
-
-
-@strawberry.type
-class TextChunk:
-    content: str
-
-
-@strawberry.type
-class FunctionCallChunk:
-    name: str
-    arguments: str
-
-
-@strawberry.type
-class ToolCallChunk:
-    id: str
-    function: FunctionCallChunk
-
+GenericType = TypeVar("GenericType")

-
-class ChatCompletionSubscriptionError:
-    message: str
+logger = logging.getLogger(__name__)

+initialize_playground_clients()

-
-
-    span: Span
-
-
-ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
-
-ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-    Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
-    strawberry.union("ChatCompletionSubscriptionPayload"),
+ChatCompletionMessage: TypeAlias = tuple[
+    ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
 ]
-
-
-
-
-
-    name: str
-    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
-    endpoint: Optional[str] = UNSET
-    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
-    api_version: Optional[str] = UNSET
-    """ The API version to use for the model. """
-
-
-@strawberry.input
-class ChatCompletionInput:
-    messages: list[ChatCompletionMessageInput]
-    model: GenerativeModelInput
-    invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
-    tools: Optional[list[JSONScalarType]] = UNSET
-    template: Optional[TemplateOptions] = UNSET
-    api_key: Optional[str] = strawberry.field(default=None)
-
-
-PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
-    GenerativeProviderKey, type["PlaygroundStreamingClient"]
-] = {}
-
-
-def register_llm_client(
-    provider_key: GenerativeProviderKey,
-) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
-    def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
-        PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
-        return cls
-
-    return decorator
-
-
-class PlaygroundStreamingClient(ABC):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        self._set_span_attributes = set_span_attributes
-
-    @abstractmethod
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        # a yield statement is needed to satisfy the type-checker
-        # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
-        yield TextChunk(content="")
-
-
-@register_llm_client(GenerativeProviderKey.OPENAI)
-class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        from openai import AsyncOpenAI
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        self.client = AsyncOpenAI(api_key=api_key)
-        self.model_name = model.name
-
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        from openai import NOT_GIVEN
-        from openai.types.chat import ChatCompletionStreamOptionsParam
-
-        # Convert standard messages to OpenAI messages
-        openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-        tool_call_ids: dict[int, str] = {}
-        token_usage: Optional["CompletionUsage"] = None
-        async for chunk in await self.client.chat.completions.create(
-            messages=openai_messages,
-            model=self.model_name,
-            stream=True,
-            stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
-            tools=tools or NOT_GIVEN,
-            **invocation_parameters,
-        ):
-            if (usage := chunk.usage) is not None:
-                token_usage = usage
-                continue
-            choice = chunk.choices[0]
-            delta = choice.delta
-            if choice.finish_reason is None:
-                if isinstance(chunk_content := delta.content, str):
-                    text_chunk = TextChunk(content=chunk_content)
-                    yield text_chunk
-                if (tool_calls := delta.tool_calls) is not None:
-                    for tool_call_index, tool_call in enumerate(tool_calls):
-                        tool_call_id = (
-                            tool_call.id
-                            if tool_call.id is not None
-                            else tool_call_ids[tool_call_index]
-                        )
-                        tool_call_ids[tool_call_index] = tool_call_id
-                        if (function := tool_call.function) is not None:
-                            tool_call_chunk = ToolCallChunk(
-                                id=tool_call_id,
-                                function=FunctionCallChunk(
-                                    name=function.name or "",
-                                    arguments=function.arguments or "",
-                                ),
-                            )
-                            yield tool_call_chunk
-        if token_usage is not None and self._set_span_attributes:
-            self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
-
-    def to_openai_chat_completion_param(
-        self,
-        role: ChatCompletionMessageRole,
-        content: JSONScalarType,
-        tool_call_id: Optional[str] = None,
-        tool_calls: Optional[list[JSONScalarType]] = None,
-    ) -> "ChatCompletionMessageParam":
-        from openai.types.chat import (
-            ChatCompletionAssistantMessageParam,
-            ChatCompletionSystemMessageParam,
-            ChatCompletionToolMessageParam,
-            ChatCompletionUserMessageParam,
-        )
-
-        if role is ChatCompletionMessageRole.USER:
-            return ChatCompletionUserMessageParam(
-                {
-                    "content": content,
-                    "role": "user",
-                }
-            )
-        if role is ChatCompletionMessageRole.SYSTEM:
-            return ChatCompletionSystemMessageParam(
-                {
-                    "content": content,
-                    "role": "system",
-                }
-            )
-        if role is ChatCompletionMessageRole.AI:
-            if tool_calls is None:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                    }
-                )
-            else:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                        "tool_calls": [
-                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
-                        ],
-                    }
-                )
-        if role is ChatCompletionMessageRole.TOOL:
-            if tool_call_id is None:
-                raise ValueError("tool_call_id is required for tool messages")
-            return ChatCompletionToolMessageParam(
-                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
-            )
-        assert_never(role)
-
-    def to_openai_tool_call_param(
-        self,
-        tool_call: JSONScalarType,
-    ) -> "ChatCompletionMessageToolCallParam":
-        from openai.types.chat import ChatCompletionMessageToolCallParam
-
-        return ChatCompletionMessageToolCallParam(
-            id=tool_call.get("id", ""),
-            function={
-                "name": tool_call.get("function", {}).get("name", ""),
-                "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
-            },
-            type="function",
-        )
-
-    @staticmethod
-    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
-        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
-        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
-
-
-@register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
-class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ):
-        from openai import AsyncAzureOpenAI
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        if model.endpoint is None or model.api_version is None:
-            raise ValueError("endpoint and api_version are required for Azure OpenAI models")
-        self.client = AsyncAzureOpenAI(
-            api_key=api_key,
-            azure_endpoint=model.endpoint,
-            api_version=model.api_version,
-        )
-
-
-@register_llm_client(GenerativeProviderKey.ANTHROPIC)
-class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(
-        self,
-        model: GenerativeModelInput,
-        api_key: Optional[str] = None,
-        set_span_attributes: Optional[SetSpanAttributesFn] = None,
-    ) -> None:
-        import anthropic
-
-        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
-        self.client = anthropic.AsyncAnthropic(api_key=api_key)
-        self.model_name = model.name
-
-    async def chat_completion_create(
-        self,
-        messages: list[
-            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
-        ],
-        tools: list[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        import anthropic.lib.streaming as anthropic_streaming
-        import anthropic.types as anthropic_types
-
-        anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
-
-        anthropic_params = {
-            "messages": anthropic_messages,
-            "model": self.model_name,
-            "system": system_prompt,
-            "max_tokens": 1024,
-            **invocation_parameters,
-        }
-        async with self.client.messages.stream(**anthropic_params) as stream:
-            async for event in stream:
-                if isinstance(event, anthropic_types.RawMessageStartEvent):
-                    if self._set_span_attributes:
-                        self._set_span_attributes(
-                            {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
-                        )
-                elif isinstance(event, anthropic_streaming.TextEvent):
-                    yield TextChunk(content=event.text)
-                elif isinstance(event, anthropic_streaming.MessageStopEvent):
-                    if self._set_span_attributes:
-                        self._set_span_attributes(
-                            {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
-                        )
-                elif isinstance(
-                    event,
-                    (
-                        anthropic_types.RawContentBlockStartEvent,
-                        anthropic_types.RawContentBlockDeltaEvent,
-                        anthropic_types.RawMessageDeltaEvent,
-                        anthropic_streaming.ContentBlockStopEvent,
-                    ),
-                ):
-                    # event types emitted by the stream that don't contain useful information
-                    pass
-                elif isinstance(event, anthropic_streaming.InputJsonEvent):
-                    raise NotImplementedError
-                else:
-                    assert_never(event)
-
-    def _build_anthropic_messages(
-        self,
-        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    ) -> tuple[list["MessageParam"], str]:
-        anthropic_messages: list["MessageParam"] = []
-        system_prompt = ""
-        for role, content, _tool_call_id, _tool_calls in messages:
-            if role == ChatCompletionMessageRole.USER:
-                anthropic_messages.append({"role": "user", "content": content})
-            elif role == ChatCompletionMessageRole.AI:
-                anthropic_messages.append({"role": "assistant", "content": content})
-            elif role == ChatCompletionMessageRole.SYSTEM:
-                system_prompt += content + "\n"
-            elif role == ChatCompletionMessageRole.TOOL:
-                raise NotImplementedError
-            else:
-                assert_never(role)
-
-        return anthropic_messages, system_prompt
+DatasetExampleID: TypeAlias = GlobalID
+ChatCompletionResult: TypeAlias = tuple[
+    DatasetExampleID, Optional[models.Span], models.ExperimentRun
+]
+PLAYGROUND_PROJECT_NAME = "playground"


 @strawberry.type
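The hunk above deletes the module-local provider registry: a dict keyed by `GenerativeProviderKey` plus a `register_llm_client` decorator that each streaming client class applied. In 5.8.0 that machinery, together with the client classes themselves, moves into the new `phoenix/server/api/helpers/playground_clients.py` and `playground_registry.py` files listed at the top, and the resolvers below look up clients via `PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, model_name)`. A minimal sketch of the same decorator-registration pattern; the class and names here are illustrative, not the actual registry implementation, whose internals this diff does not show:

```python
from typing import Callable, Optional

ProviderKey = str  # illustrative; Phoenix keys on its generative-provider enum


class ClientRegistry:
    """Maps a provider key to a streaming-client class."""

    def __init__(self) -> None:
        self._clients: dict[ProviderKey, type] = {}

    def register(self, provider_key: ProviderKey) -> Callable[[type], type]:
        def decorator(cls: type) -> type:
            self._clients[provider_key] = cls  # last registration for a key wins
            return cls

        return decorator

    def get_client(self, provider_key: ProviderKey, model_name: str) -> Optional[type]:
        # Accepting the model name leaves room to route individual models of one
        # provider to different client classes; this sketch ignores it.
        return self._clients.get(provider_key)


PLAYGROUND_CLIENT_REGISTRY = ClientRegistry()


@PLAYGROUND_CLIENT_REGISTRY.register("OPENAI")
class OpenAIStreamingClient:
    """Stand-in for a real client class."""


assert PLAYGROUND_CLIENT_REGISTRY.get_client("OPENAI", "gpt-4o") is OpenAIStreamingClient
```

Returning `None` for an unknown provider, rather than raising inside the registry, lets each resolver translate the miss into its own `BadRequest`, as the next hunk shows.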
@@ -422,15 +88,15 @@ class Subscription:
     async def chat_completion(
         self, info: Info[Context, None], input: ChatCompletionInput
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-        # Determine which LLM client to use based on provider_key
         provider_key = input.model.provider_key
-
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
             raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
         llm_client = llm_client_class(
             model=input.model,
             api_key=input.api_key,
-            set_span_attributes=lambda attrs: attributes.update(attrs),
         )
+
         messages = [
             (
                 message.role,
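For orientation on the resolver shape above and in the hunks below: a Strawberry subscription is simply an async generator method on the schema's `Subscription` type, and each `yield` is delivered to the client as one GraphQL subscription event — this is how the playground streams text and tool-call chunks as they arrive. A self-contained toy example, unrelated to Phoenix's actual schema:

```python
from collections.abc import AsyncIterator

import strawberry


@strawberry.type
class Query:
    @strawberry.field
    def ok(self) -> bool:
        return True


@strawberry.type
class Subscription:
    @strawberry.subscription
    async def countdown(self, start: int = 3) -> AsyncIterator[int]:
        # Each yielded value becomes one event on the subscription stream.
        for value in range(start, 0, -1):
            yield value


schema = strawberry.Schema(query=Query, subscription=Subscription)
```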
@@ -441,70 +107,129 @@ class Subscription:
             for message in input.messages
         ]
         if template_options := input.template:
-            messages = list(
-
-
-
-
-
-            _llm_tools(input.tools or []),
-            _llm_input_messages(messages),
-            _llm_invocation_parameters(invocation_parameters),
-            _input_value_and_mime_type(input),
+            messages = list(
+                _formatted_messages(
+                    messages=messages,
+                    template_language=template_options.language,
+                    template_variables=template_options.variables,
+                )
             )
+        invocation_parameters = llm_client.construct_invocation_parameters(
+            input.invocation_parameters
         )
-
-
-
-
-
-        response_chunks = []
-        text_chunks: list[TextChunk] = []
-        tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
-        try:
-            start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+        async with streaming_llm_span(
+            input=input,
+            messages=messages,
+            invocation_parameters=invocation_parameters,
+        ) as span:
             async for chunk in llm_client.chat_completion_create(
-                messages=messages,
-                tools=input.tools or [],
-                **invocation_parameters,
+                messages=messages, tools=input.tools or [], **invocation_parameters
             ):
-
-
-
-
-
-
-
-
-
-            status_code = StatusCode.OK
-        except Exception as error:
-            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
-            status_code = StatusCode.ERROR
-            status_message = str(error)
-            events.append(
-                SpanException(
-                    timestamp=end_time,
-                    message=status_message,
-                    exception_type=type(error).__name__,
-                    exception_escaped=False,
-                    exception_stacktrace=format_exc(),
+                span.add_response_chunk(chunk)
+                yield chunk
+            span.set_attributes(llm_client.attributes)
+        if span.status_message is not None:
+            yield ChatCompletionSubscriptionError(message=span.status_message)
+        async with info.context.db() as session:
+            if (
+                playground_project_id := await session.scalar(
+                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                 )
-            )
-
-
-
-
-
-
-
+            ) is None:
+                playground_project_id = await session.scalar(
+                    insert(models.Project)
+                    .returning(models.Project.id)
+                    .values(
+                        name=PLAYGROUND_PROJECT_NAME,
+                        description="Traces from prompt playground",
+                    )
                 )
+            db_trace = get_db_trace(span, playground_project_id)
+            db_span = get_db_span(span, db_trace)
+            session.add(db_span)
+            await session.flush()
+        info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
+        yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))
+
+    @strawberry.subscription
+    async def chat_completion_over_dataset(
+        self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
+    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+        provider_key = input.model.provider_key
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
+
+        dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
             )
-
-
-
-        span_id = _generate_span_id()
+            if input.dataset_version_id
+            else None
+        )
         async with info.context.db() as session:
+            if (
+                dataset := await session.scalar(
+                    select(models.Dataset).where(models.Dataset.id == dataset_id)
+                )
+            ) is None:
+                raise NotFound(f"Could not find dataset with ID {dataset_id}")
+            if version_id is None:
+                if (
+                    resolved_version_id := await session.scalar(
+                        select(models.DatasetVersion.id)
+                        .where(models.DatasetVersion.dataset_id == dataset_id)
+                        .order_by(models.DatasetVersion.id.desc())
+                        .limit(1)
+                    )
+                ) is None:
+                    raise NotFound(f"No versions found for dataset with ID {dataset_id}")
+            else:
+                if (
+                    resolved_version_id := await session.scalar(
+                        select(models.DatasetVersion.id).where(
+                            and_(
+                                models.DatasetVersion.dataset_id == dataset_id,
+                                models.DatasetVersion.id == version_id,
+                            )
+                        )
+                    )
+                ) is None:
+                    raise NotFound(f"Could not find dataset version with ID {version_id}")
+            revision_ids = (
+                select(func.max(models.DatasetExampleRevision.id))
+                .join(models.DatasetExample)
+                .where(
+                    and_(
+                        models.DatasetExample.dataset_id == dataset_id,
+                        models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
+                    )
+                )
+                .group_by(models.DatasetExampleRevision.dataset_example_id)
+            )
+            if not (
+                revisions := [
+                    rev
+                    async for rev in await session.stream_scalars(
+                        select(models.DatasetExampleRevision)
+                        .where(
+                            and_(
+                                models.DatasetExampleRevision.id.in_(revision_ids),
+                                models.DatasetExampleRevision.revision_kind != "DELETE",
+                            )
+                        )
+                        .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
+                        .options(
+                            load_only(
+                                models.DatasetExampleRevision.dataset_example_id,
+                                models.DatasetExampleRevision.input,
+                            )
+                        )
+                    )
+                ]
+            ):
+                raise NotFound("No examples found for the given dataset and version")
             if (
                 playground_project_id := await session.scalar(
                     select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
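The revision query added above is a standard "latest row per group" pattern: an inner query selects `max(id)` per `dataset_example_id` among revisions at or below the resolved version, and the outer query keeps those rows only if their `revision_kind` is not `"DELETE"`. A compact sketch of the same shape against a hypothetical mapped class standing in for `models.DatasetExampleRevision`:

```python
from sqlalchemy import and_, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Revision(Base):  # hypothetical stand-in for models.DatasetExampleRevision
    __tablename__ = "revision"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_example_id: Mapped[int]
    dataset_version_id: Mapped[int]
    revision_kind: Mapped[str]


target_version_id = 7  # illustrative

# Inner query: the newest revision id for each example, ignoring later versions.
latest_ids = (
    select(func.max(Revision.id))
    .where(Revision.dataset_version_id <= target_version_id)
    .group_by(Revision.dataset_example_id)
)

# Outer query: of those newest revisions, keep the ones that are not deletions.
live_revisions = select(Revision).where(
    and_(Revision.id.in_(latest_ids), Revision.revision_kind != "DELETE")
)
```

Filtering `revision_kind` on the outer query, not the inner one, is what makes deletion win: if an example's newest revision is a `DELETE`, no earlier revision can resurface.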
@@ -518,138 +243,208 @@ class Subscription:
                         description="Traces from prompt playground",
                     )
                 )
-
-
-
-
-
+            experiment = models.Experiment(
+                dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
+                dataset_version_id=resolved_version_id,
+                name=input.experiment_name or _default_playground_experiment_name(),
+                description=input.experiment_description
+                or _default_playground_experiment_description(dataset_name=dataset.name),
+                repetitions=1,
+                metadata_=input.experiment_metadata
+                or _default_playground_experiment_metadata(
+                    dataset_name=dataset.name,
+                    dataset_id=input.dataset_id,
+                    version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
+                ),
+                project_name=PLAYGROUND_PROJECT_NAME,
             )
-
-                trace_rowid=playground_trace.id,
-                span_id=span_id,
-                parent_id=None,
-                name="ChatCompletion",
-                span_kind=LLM,
-                start_time=start_time,
-                end_time=end_time,
-                attributes=unflatten(attributes.items()),
-                events=[_serialize_event(event) for event in events],
-                status_code=status_code.name,
-                status_message=status_message,
-                cumulative_error_count=int(status_code is StatusCode.ERROR),
-                cumulative_llm_token_count_prompt=prompt_tokens,
-                cumulative_llm_token_count_completion=completion_tokens,
-                llm_token_count_prompt=prompt_tokens,
-                llm_token_count_completion=completion_tokens,
-                trace=playground_trace,
-            )
-            session.add(playground_trace)
-            session.add(playground_span)
+            session.add(experiment)
             await session.flush()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            yield ChatCompletionSubscriptionExperiment(
+                experiment=to_gql_experiment(experiment)
+            )  # eagerly yields experiment so it can be linked by consumers of the subscription
+
+        results_queue: Queue[ChatCompletionResult] = Queue()
+        chat_completion_streams = [
+            _stream_chat_completion_over_dataset_example(
+                input=input,
+                llm_client_class=llm_client_class,
+                revision=revision,
+                results_queue=results_queue,
+                experiment_id=experiment.id,
+                project_id=playground_project_id,
+            )
+            for revision in revisions
+        ]
+        stream_to_async_tasks: dict[
+            AsyncIterator[ChatCompletionSubscriptionPayload],
+            Task[ChatCompletionSubscriptionPayload],
+        ] = {iterator: _create_task_with_timeout(iterator) for iterator in chat_completion_streams}
+        batch_size = 10
+        while stream_to_async_tasks:
+            async_tasks_to_run = [task for task in stream_to_async_tasks.values()]
+            completed_tasks, _ = await wait(async_tasks_to_run, return_when=FIRST_COMPLETED)
+            for task in completed_tasks:
+                iterator = next(it for it, t in stream_to_async_tasks.items() if t == task)
+                try:
+                    yield task.result()
+                except (StopAsyncIteration, asyncio.TimeoutError):
+                    del stream_to_async_tasks[iterator]  # removes exhausted iterator
+                except Exception as error:
+                    del stream_to_async_tasks[iterator]  # removes failed iterator
+                    logger.exception(error)
+                else:
+                    stream_to_async_tasks[iterator] = _create_task_with_timeout(iterator)
+            if results_queue.qsize() >= batch_size:
+                result_iterator = _chat_completion_result_payloads(
+                    db=info.context.db, results=_drain_no_wait(results_queue)
+                )
+                stream_to_async_tasks[result_iterator] = _create_task_with_timeout(
+                    result_iterator
+                )
+        if remaining_results := await _drain(results_queue):
+            async for result_payload in _chat_completion_result_payloads(
+                db=info.context.db, results=remaining_results
+            ):
+                yield result_payload
+
+
+async def _stream_chat_completion_over_dataset_example(
+    *,
+    input: ChatCompletionOverDatasetInput,
+    llm_client_class: type["PlaygroundStreamingClient"],
+    revision: models.DatasetExampleRevision,
+    results_queue: Queue[ChatCompletionResult],
+    experiment_id: int,
+    project_id: int,
+) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
+    llm_client = llm_client_class(
+        model=input.model,
+        api_key=input.api_key,
+    )
+    invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
+    messages = [
+        (
+            message.role,
+            message.content,
+            message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+            message.tool_calls if isinstance(message.tool_calls, list) else None,
+        )
+        for message in input.messages
+    ]
+    try:
+        format_start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+        messages = list(
+            _formatted_messages(
+                messages=messages,
+                template_language=input.template_language,
+                template_variables=revision.input,
+            )
+        )
+    except TemplateFormatterError as error:
+        format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+        yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
+        await results_queue.put(
+            (
+                example_id,
+                None,
+                models.ExperimentRun(
+                    experiment_id=experiment_id,
+                    dataset_example_id=revision.dataset_example_id,
+                    trace_id=None,
+                    output={},
+                    repetition_number=1,
+                    start_time=format_start_time,
+                    end_time=format_end_time,
+                    error=str(error),
+                    trace=None,
+                ),
+            )
+        )
+        return
+    async with streaming_llm_span(
+        input=input,
+        messages=messages,
+        invocation_parameters=invocation_parameters,
+    ) as span:
+        async for chunk in llm_client.chat_completion_create(
+            messages=messages, tools=input.tools or [], **invocation_parameters
+        ):
+            span.add_response_chunk(chunk)
+            chunk.dataset_example_id = example_id
+            yield chunk
+        span.set_attributes(llm_client.attributes)
+    db_trace = get_db_trace(span, project_id)
+    db_span = get_db_span(span, db_trace)
+    db_run = get_db_experiment_run(
+        db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
+    )
+    await results_queue.put((example_id, db_span, db_run))
+    if span.status_message is not None:
+        yield ChatCompletionSubscriptionError(
+            message=span.status_message, dataset_example_id=example_id
+        )


-def
-
-
-    ],
-) ->
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+async def _chat_completion_result_payloads(
+    *,
+    db: DbSessionFactory,
+    results: Sequence[ChatCompletionResult],
+) -> AsyncIterator[ChatCompletionSubscriptionResult]:
+    if not results:
+        return
+    async with db() as session:
+        for _, span, run in results:
+            if span:
+                session.add(span)
+            session.add(run)
+        await session.flush()
+    for example_id, span, run in results:
+        yield ChatCompletionSubscriptionResult(
+            span=to_gql_span(span) if span else None,
+            experiment_run=to_gql_experiment_run(run),
+            dataset_example_id=example_id,
+        )


-def
-
-
-)
-    yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
-    if content := "".join(chunk.content for chunk in text_chunks):
-        yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
-    for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
-        if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
-            yield (
-                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
-                name,
-            )
-        if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
-            yield (
-                f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
-                arguments,
-            )
+def _create_task_with_timeout(
+    iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 60
+) -> Task[GenericType]:
+    return create_task(wait_for(_as_coroutine(iterable), timeout=timeout_in_seconds))


-def
-
-
-
-    return
+async def _drain(queue: Queue[GenericType]) -> list[GenericType]:
+    values: list[GenericType] = []
+    while not queue.empty():
+        values.append(await queue.get())
+    return values


-def
-
-
-
-
+def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]:
+    values: list[GenericType] = []
+    while True:
+        try:
+            values.append(queue.get_nowait())
+        except QueueEmpty:
+            break
+    return values


-def
-
-    Converts an integer to a hexadecimal string.
-    """
-    return hex(number)[2:]
+async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
+    return await iterable.__anext__()


 def _formatted_messages(
-
-
+    *,
+    messages: Iterable[ChatCompletionMessage],
+    template_language: TemplateLanguage,
+    template_variables: Mapping[str, Any],
 ) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
-    template_formatter = _template_formatter(template_language=
+    template_formatter = _template_formatter(template_language=template_language)
     (
         roles,
         templates,
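The helpers added above (`_create_task_with_timeout`, `_as_coroutine`, `_drain`, `_drain_no_wait`) implement a fan-in over many async generators: each generator's `__anext__` is wrapped in a timed task, `wait(..., return_when=FIRST_COMPLETED)` surfaces whichever example yields next, and the winning generator is re-armed with a fresh task until it raises `StopAsyncIteration`. A stripped-down, runnable sketch of that merge loop, without the timeout and result-queue batching:

```python
import asyncio
from collections.abc import AsyncIterator
from typing import TypeVar

T = TypeVar("T")


async def _next(iterator: AsyncIterator[T]) -> T:
    return await iterator.__anext__()


async def merge(*iterators: AsyncIterator[T]) -> AsyncIterator[T]:
    # One in-flight __anext__ task per live iterator, keyed by task.
    tasks = {asyncio.create_task(_next(it)): it for it in iterators}
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            iterator = tasks.pop(task)
            try:
                value = task.result()
            except StopAsyncIteration:
                continue  # this iterator is exhausted; don't re-arm it
            yield value
            tasks[asyncio.create_task(_next(iterator))] = iterator  # re-arm


async def main() -> None:
    async def letters(tag: str) -> AsyncIterator[str]:
        for i in range(3):
            await asyncio.sleep(0.01)
            yield f"{tag}{i}"

    # Items from both generators interleave as they become ready.
    async for item in merge(letters("a"), letters("b")):
        print(item)


asyncio.run(main())
```

This sketch keys the task map by task for O(1) lookup; the diff keys it the other way (iterator to task) and recovers the iterator with a linear `next(...)` scan, which is equivalent at playground scale.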
@@ -657,7 +452,7 @@ def _formatted_messages(
         tool_calls,
     ) = zip(*messages)
     formatted_templates = map(
-        lambda template: template_formatter.format(template, **
+        lambda template: template_formatter.format(template, **template_variables),
         templates,
     )
     formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
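`_formatted_messages` applies one formatter — selected by `_template_formatter` from the `MustacheTemplateFormatter`/`FStringTemplateFormatter` imports — to every message template with the example's variables, and the dataset path above catches the new `TemplateFormatterError` (the companion change to `phoenix/utilities/template_formatters.py`, +11 -1, in the file list). A rough illustration of the `format(template, **variables)` shape that the lambda above relies on; this toy class is not Phoenix's implementation:

```python
import re
from typing import Any


class TemplateFormatterError(Exception):
    """Raised when a template references a variable that was not supplied."""


class MustacheishFormatter:
    """Toy {{ variable }} substitution with the format(template, **variables) shape."""

    _pattern = re.compile(r"{{\s*(\w+)\s*}}")

    def format(self, template: str, **variables: Any) -> str:
        def replace(match: re.Match) -> str:
            name = match.group(1)
            if name not in variables:
                raise TemplateFormatterError(f"missing template variable: {name}")
            return str(variables[name])

        return self._pattern.sub(replace, template)


formatter = MustacheishFormatter()
print(formatter.format("Hello, {{ name }}!", name="world"))  # -> Hello, world!
```

Raising instead of silently leaving placeholders is what lets the subscription record a failed `ExperimentRun` per example rather than sending a malformed prompt to the provider.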
@@ -675,36 +470,24 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
     assert_never(template_language)


-def
-    ""
-    Serializes a SpanEvent to a dictionary.
-    """
-    return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+def _default_playground_experiment_name() -> str:
+    return "playground-experiment"


-
+def _default_playground_experiment_description(dataset_name: str) -> str:
+    return f'Playground experiment for dataset "{dataset_name}"'

-LLM = OpenInferenceSpanKindValues.LLM.value

-
-
-
-
-
-
-
-
-LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
-LLM_TOOLS = SpanAttributes.LLM_TOOLS
-LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
-LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
-LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+def _default_playground_experiment_metadata(
+    dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
+) -> dict[str, Any]:
+    return {
+        "dataset_name": dataset_name,
+        "dataset_id": str(dataset_id),
+        "dataset_version_id": str(version_id),
+    }

-MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
-MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
-MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

-
-
-
-TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
+LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT