arize-phoenix 5.10.0__py3-none-any.whl → 5.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/METADATA +2 -1
- {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/RECORD +25 -25
- phoenix/config.py +13 -0
- phoenix/db/facilitator.py +3 -2
- phoenix/server/api/helpers/playground_clients.py +64 -77
- phoenix/server/api/helpers/playground_spans.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +81 -36
- phoenix/server/api/subscriptions.py +156 -58
- phoenix/server/api/types/TemplateLanguage.py +1 -0
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-BXIz9ZO8.js → components-72cQL1d1.js} +95 -95
- phoenix/server/static/assets/{index-DTut7g1y.js → index-BowjltW-.js} +1 -1
- phoenix/server/static/assets/{pages-B8FpJuXu.js → pages-DFAkBAUh.js} +339 -271
- phoenix/server/static/assets/{vendor-BX8_Znqy.js → vendor-DexmGnha.js} +150 -150
- phoenix/server/static/assets/{vendor-arizeai-CtHir-Ua.js → vendor-arizeai--Q3ol330.js} +28 -28
- phoenix/server/static/assets/{vendor-codemirror-DLlGiguX.js → vendor-codemirror-B4bYvWa6.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-CJRple0d.js → vendor-recharts-B4ZzJhNh.js} +1 -1
- phoenix/trace/span_evaluations.py +4 -3
- phoenix/utilities/json.py +7 -1
- phoenix/utilities/template_formatters.py +18 -0
- phoenix/version.py +1 -1
- {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from dataclasses import asdict, field
|
|
3
3
|
from datetime import datetime, timezone
|
|
4
|
-
from itertools import chain
|
|
4
|
+
from itertools import chain, islice
|
|
5
5
|
from traceback import format_exc
|
|
6
|
-
from typing import Any, Iterable, Iterator, List, Optional, Union
|
|
6
|
+
from typing import Any, Iterable, Iterator, List, Optional, TypeVar, Union
|
|
7
7
|
|
|
8
8
|
import strawberry
|
|
9
9
|
from openinference.instrumentation import safe_json_dumps
|
|
@@ -25,8 +25,9 @@ from typing_extensions import assert_never
|
|
|
25
25
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
26
26
|
from phoenix.db import models
|
|
27
27
|
from phoenix.db.helpers import get_dataset_example_revisions
|
|
28
|
+
from phoenix.server.api.auth import IsNotReadOnly
|
|
28
29
|
from phoenix.server.api.context import Context
|
|
29
|
-
from phoenix.server.api.exceptions import BadRequest, NotFound
|
|
30
|
+
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
30
31
|
from phoenix.server.api.helpers.playground_clients import (
|
|
31
32
|
PlaygroundStreamingClient,
|
|
32
33
|
initialize_playground_clients,
|
|
@@ -67,6 +68,7 @@ from phoenix.utilities.json import jsonify
|
|
|
67
68
|
from phoenix.utilities.template_formatters import (
|
|
68
69
|
FStringTemplateFormatter,
|
|
69
70
|
MustacheTemplateFormatter,
|
|
71
|
+
NoOpFormatter,
|
|
70
72
|
TemplateFormatter,
|
|
71
73
|
)
|
|
72
74
|
|
|
@@ -117,7 +119,7 @@ class ChatCompletionOverDatasetMutationPayload:
|
|
|
117
119
|
|
|
118
120
|
@strawberry.type
|
|
119
121
|
class ChatCompletionMutationMixin:
|
|
120
|
-
@strawberry.mutation
|
|
122
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
121
123
|
@classmethod
|
|
122
124
|
async def chat_completion_over_dataset(
|
|
123
125
|
cls,
|
|
@@ -127,11 +129,19 @@ class ChatCompletionMutationMixin:
|
|
|
127
129
|
provider_key = input.model.provider_key
|
|
128
130
|
llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
|
|
129
131
|
if llm_client_class is None:
|
|
130
|
-
raise BadRequest(f"
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
132
|
+
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
133
|
+
try:
|
|
134
|
+
llm_client = llm_client_class(
|
|
135
|
+
model=input.model,
|
|
136
|
+
api_key=input.api_key,
|
|
137
|
+
)
|
|
138
|
+
except CustomGraphQLError:
|
|
139
|
+
raise
|
|
140
|
+
except Exception as error:
|
|
141
|
+
raise BadRequest(
|
|
142
|
+
f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
|
|
143
|
+
f"{str(error)}"
|
|
144
|
+
)
|
|
135
145
|
dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
|
|
136
146
|
dataset_version_id = (
|
|
137
147
|
from_global_id_with_expected_type(
|
|
@@ -158,7 +168,9 @@ class ChatCompletionMutationMixin:
|
|
|
158
168
|
revisions = [
|
|
159
169
|
revision
|
|
160
170
|
async for revision in await session.stream_scalars(
|
|
161
|
-
get_dataset_example_revisions(resolved_version_id)
|
|
171
|
+
get_dataset_example_revisions(resolved_version_id).order_by(
|
|
172
|
+
models.DatasetExampleRevision.id
|
|
173
|
+
)
|
|
162
174
|
)
|
|
163
175
|
]
|
|
164
176
|
if not revisions:
|
|
@@ -181,28 +193,32 @@ class ChatCompletionMutationMixin:
|
|
|
181
193
|
session.add(experiment)
|
|
182
194
|
await session.flush()
|
|
183
195
|
|
|
196
|
+
results = []
|
|
197
|
+
batch_size = 3
|
|
184
198
|
start_time = datetime.now(timezone.utc)
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
+
for batch in _get_batches(revisions, batch_size):
|
|
200
|
+
batch_results = await asyncio.gather(
|
|
201
|
+
*(
|
|
202
|
+
cls._chat_completion(
|
|
203
|
+
info,
|
|
204
|
+
llm_client,
|
|
205
|
+
ChatCompletionInput(
|
|
206
|
+
model=input.model,
|
|
207
|
+
api_key=input.api_key,
|
|
208
|
+
messages=input.messages,
|
|
209
|
+
tools=input.tools,
|
|
210
|
+
invocation_parameters=input.invocation_parameters,
|
|
211
|
+
template=TemplateOptions(
|
|
212
|
+
language=input.template_language,
|
|
213
|
+
variables=revision.input,
|
|
214
|
+
),
|
|
199
215
|
),
|
|
200
|
-
)
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
216
|
+
)
|
|
217
|
+
for revision in batch
|
|
218
|
+
),
|
|
219
|
+
return_exceptions=True,
|
|
220
|
+
)
|
|
221
|
+
results.extend(batch_results)
|
|
206
222
|
|
|
207
223
|
payload = ChatCompletionOverDatasetMutationPayload(
|
|
208
224
|
dataset_id=GlobalID(models.Dataset.__name__, str(dataset.id)),
|
|
@@ -258,7 +274,7 @@ class ChatCompletionMutationMixin:
|
|
|
258
274
|
payload.examples.append(example_payload)
|
|
259
275
|
return payload
|
|
260
276
|
|
|
261
|
-
@strawberry.mutation
|
|
277
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
262
278
|
@classmethod
|
|
263
279
|
async def chat_completion(
|
|
264
280
|
cls, info: Info[Context, None], input: ChatCompletionInput
|
|
@@ -266,11 +282,19 @@ class ChatCompletionMutationMixin:
|
|
|
266
282
|
provider_key = input.model.provider_key
|
|
267
283
|
llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
|
|
268
284
|
if llm_client_class is None:
|
|
269
|
-
raise BadRequest(f"
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
285
|
+
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
286
|
+
try:
|
|
287
|
+
llm_client = llm_client_class(
|
|
288
|
+
model=input.model,
|
|
289
|
+
api_key=input.api_key,
|
|
290
|
+
)
|
|
291
|
+
except CustomGraphQLError:
|
|
292
|
+
raise
|
|
293
|
+
except Exception as error:
|
|
294
|
+
raise BadRequest(
|
|
295
|
+
f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
|
|
296
|
+
f"{str(error)}"
|
|
297
|
+
)
|
|
274
298
|
return await cls._chat_completion(info, llm_client, input)
|
|
275
299
|
|
|
276
300
|
@classmethod
|
|
@@ -459,6 +483,8 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
|
|
|
459
483
|
return MustacheTemplateFormatter()
|
|
460
484
|
if template_language is TemplateLanguage.F_STRING:
|
|
461
485
|
return FStringTemplateFormatter()
|
|
486
|
+
if template_language is TemplateLanguage.NONE:
|
|
487
|
+
return NoOpFormatter()
|
|
462
488
|
assert_never(template_language)
|
|
463
489
|
|
|
464
490
|
|
|
@@ -486,6 +512,11 @@ def _llm_output_messages(
|
|
|
486
512
|
if text_content:
|
|
487
513
|
yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", text_content
|
|
488
514
|
for tool_call_index, tool_call in enumerate(tool_calls.values()):
|
|
515
|
+
if tool_call_id := tool_call.id:
|
|
516
|
+
yield (
|
|
517
|
+
f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_ID}",
|
|
518
|
+
tool_call_id,
|
|
519
|
+
)
|
|
489
520
|
yield (
|
|
490
521
|
f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
|
|
491
522
|
tool_call.function.name,
|
|
@@ -513,6 +544,19 @@ def _serialize_event(event: SpanException) -> dict[str, Any]:
|
|
|
513
544
|
return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
|
|
514
545
|
|
|
515
546
|
|
|
547
|
+
_AnyT = TypeVar("_AnyT")
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def _get_batches(
|
|
551
|
+
iterable: Iterable[_AnyT],
|
|
552
|
+
batch_size: int,
|
|
553
|
+
) -> Iterator[list[_AnyT]]:
|
|
554
|
+
"""Splits an iterable into batches not exceeding a specified size."""
|
|
555
|
+
iterator = iter(iterable)
|
|
556
|
+
while batch := list(islice(iterator, batch_size)):
|
|
557
|
+
yield batch
|
|
558
|
+
|
|
559
|
+
|
|
516
560
|
JSON = OpenInferenceMimeTypeValues.JSON.value
|
|
517
561
|
TEXT = OpenInferenceMimeTypeValues.TEXT.value
|
|
518
562
|
LLM = OpenInferenceSpanKindValues.LLM.value
|
|
@@ -534,6 +578,7 @@ MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
|
|
|
534
578
|
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
|
|
535
579
|
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
|
|
536
580
|
|
|
581
|
+
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
|
|
537
582
|
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
|
|
538
583
|
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
|
|
539
584
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
-
from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for
|
|
4
3
|
from collections.abc import AsyncIterator, Iterator
|
|
5
|
-
from datetime import datetime, timezone
|
|
4
|
+
from datetime import datetime, timedelta, timezone
|
|
6
5
|
from typing import (
|
|
7
6
|
Any,
|
|
7
|
+
AsyncGenerator,
|
|
8
|
+
Coroutine,
|
|
8
9
|
Iterable,
|
|
9
10
|
Mapping,
|
|
10
11
|
Optional,
|
|
@@ -23,8 +24,9 @@ from typing_extensions import TypeAlias, assert_never
|
|
|
23
24
|
|
|
24
25
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
25
26
|
from phoenix.db import models
|
|
27
|
+
from phoenix.server.api.auth import IsNotReadOnly
|
|
26
28
|
from phoenix.server.api.context import Context
|
|
27
|
-
from phoenix.server.api.exceptions import BadRequest, NotFound
|
|
29
|
+
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
28
30
|
from phoenix.server.api.helpers.playground_clients import (
|
|
29
31
|
PlaygroundStreamingClient,
|
|
30
32
|
initialize_playground_clients,
|
|
@@ -62,6 +64,7 @@ from phoenix.server.types import DbSessionFactory
|
|
|
62
64
|
from phoenix.utilities.template_formatters import (
|
|
63
65
|
FStringTemplateFormatter,
|
|
64
66
|
MustacheTemplateFormatter,
|
|
67
|
+
NoOpFormatter,
|
|
65
68
|
TemplateFormatter,
|
|
66
69
|
TemplateFormatterError,
|
|
67
70
|
)
|
|
@@ -79,23 +82,32 @@ DatasetExampleID: TypeAlias = GlobalID
|
|
|
79
82
|
ChatCompletionResult: TypeAlias = tuple[
|
|
80
83
|
DatasetExampleID, Optional[models.Span], models.ExperimentRun
|
|
81
84
|
]
|
|
85
|
+
ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
|
|
82
86
|
PLAYGROUND_PROJECT_NAME = "playground"
|
|
83
87
|
|
|
84
88
|
|
|
85
89
|
@strawberry.type
|
|
86
90
|
class Subscription:
|
|
87
|
-
@strawberry.subscription
|
|
91
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
88
92
|
async def chat_completion(
|
|
89
93
|
self, info: Info[Context, None], input: ChatCompletionInput
|
|
90
94
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
91
95
|
provider_key = input.model.provider_key
|
|
92
96
|
llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
|
|
93
97
|
if llm_client_class is None:
|
|
94
|
-
raise BadRequest(f"
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
98
|
+
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
99
|
+
try:
|
|
100
|
+
llm_client = llm_client_class(
|
|
101
|
+
model=input.model,
|
|
102
|
+
api_key=input.api_key,
|
|
103
|
+
)
|
|
104
|
+
except CustomGraphQLError:
|
|
105
|
+
raise
|
|
106
|
+
except Exception as error:
|
|
107
|
+
raise BadRequest(
|
|
108
|
+
f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
|
|
109
|
+
f"{str(error)}"
|
|
110
|
+
)
|
|
99
111
|
|
|
100
112
|
messages = [
|
|
101
113
|
(
|
|
@@ -151,14 +163,26 @@ class Subscription:
|
|
|
151
163
|
info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
|
|
152
164
|
yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))
|
|
153
165
|
|
|
154
|
-
@strawberry.subscription
|
|
166
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
155
167
|
async def chat_completion_over_dataset(
|
|
156
168
|
self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
|
|
157
169
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
158
170
|
provider_key = input.model.provider_key
|
|
159
171
|
llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
|
|
160
172
|
if llm_client_class is None:
|
|
161
|
-
raise BadRequest(f"
|
|
173
|
+
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
174
|
+
try:
|
|
175
|
+
llm_client = llm_client_class(
|
|
176
|
+
model=input.model,
|
|
177
|
+
api_key=input.api_key,
|
|
178
|
+
)
|
|
179
|
+
except CustomGraphQLError:
|
|
180
|
+
raise
|
|
181
|
+
except Exception as error:
|
|
182
|
+
raise BadRequest(
|
|
183
|
+
f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
|
|
184
|
+
f"{str(error)}"
|
|
185
|
+
)
|
|
162
186
|
|
|
163
187
|
dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
|
|
164
188
|
version_id = (
|
|
@@ -264,45 +288,82 @@ class Subscription:
|
|
|
264
288
|
experiment=to_gql_experiment(experiment)
|
|
265
289
|
) # eagerly yields experiment so it can be linked by consumers of the subscription
|
|
266
290
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
291
|
+
results: asyncio.Queue[ChatCompletionResult] = asyncio.Queue()
|
|
292
|
+
not_started: list[tuple[DatasetExampleID, ChatStream]] = [
|
|
293
|
+
(
|
|
294
|
+
GlobalID(DatasetExample.__name__, str(revision.dataset_example_id)),
|
|
295
|
+
_stream_chat_completion_over_dataset_example(
|
|
296
|
+
input=input,
|
|
297
|
+
llm_client=llm_client,
|
|
298
|
+
revision=revision,
|
|
299
|
+
results=results,
|
|
300
|
+
experiment_id=experiment.id,
|
|
301
|
+
project_id=playground_project_id,
|
|
302
|
+
),
|
|
276
303
|
)
|
|
277
304
|
for revision in revisions
|
|
278
305
|
]
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
306
|
+
in_progress: list[
|
|
307
|
+
tuple[
|
|
308
|
+
Optional[DatasetExampleID],
|
|
309
|
+
ChatStream,
|
|
310
|
+
asyncio.Task[ChatCompletionSubscriptionPayload],
|
|
311
|
+
]
|
|
312
|
+
] = []
|
|
313
|
+
max_in_progress = 3
|
|
314
|
+
write_batch_size = 10
|
|
315
|
+
write_interval = timedelta(seconds=10)
|
|
316
|
+
last_write_time = datetime.now()
|
|
317
|
+
while not_started or in_progress:
|
|
318
|
+
while not_started and len(in_progress) < max_in_progress:
|
|
319
|
+
ex_id, stream = not_started.pop()
|
|
320
|
+
task = _create_task_with_timeout(stream)
|
|
321
|
+
in_progress.append((ex_id, stream, task))
|
|
322
|
+
async_tasks_to_run = [task for _, _, task in in_progress]
|
|
323
|
+
completed_tasks, _ = await asyncio.wait(
|
|
324
|
+
async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
|
|
325
|
+
)
|
|
326
|
+
for completed_task in completed_tasks:
|
|
327
|
+
idx = [task for _, _, task in in_progress].index(completed_task)
|
|
328
|
+
example_id, stream, _ = in_progress[idx]
|
|
289
329
|
try:
|
|
290
|
-
yield
|
|
291
|
-
except
|
|
292
|
-
del
|
|
330
|
+
yield completed_task.result()
|
|
331
|
+
except StopAsyncIteration:
|
|
332
|
+
del in_progress[idx] # removes exhausted stream
|
|
333
|
+
except asyncio.TimeoutError:
|
|
334
|
+
del in_progress[idx] # removes timed-out stream
|
|
335
|
+
if example_id is not None:
|
|
336
|
+
yield ChatCompletionSubscriptionError(
|
|
337
|
+
message="Playground task timed out", dataset_example_id=example_id
|
|
338
|
+
)
|
|
293
339
|
except Exception as error:
|
|
294
|
-
del
|
|
340
|
+
del in_progress[idx] # removes failed stream
|
|
341
|
+
if example_id is not None:
|
|
342
|
+
yield ChatCompletionSubscriptionError(
|
|
343
|
+
message="An unexpected error occurred", dataset_example_id=example_id
|
|
344
|
+
)
|
|
295
345
|
logger.exception(error)
|
|
296
346
|
else:
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
347
|
+
task = _create_task_with_timeout(stream)
|
|
348
|
+
in_progress[idx] = (example_id, stream, task)
|
|
349
|
+
|
|
350
|
+
exceeded_write_batch_size = results.qsize() >= write_batch_size
|
|
351
|
+
exceeded_write_interval = datetime.now() - last_write_time > write_interval
|
|
352
|
+
write_already_in_progress = any(
|
|
353
|
+
_is_result_payloads_stream(stream) for _, stream, _ in in_progress
|
|
354
|
+
)
|
|
355
|
+
if (
|
|
356
|
+
not results.empty()
|
|
357
|
+
and (exceeded_write_batch_size or exceeded_write_interval)
|
|
358
|
+
and not write_already_in_progress
|
|
359
|
+
):
|
|
360
|
+
result_payloads_stream = _chat_completion_result_payloads(
|
|
361
|
+
db=info.context.db, results=_drain_no_wait(results)
|
|
304
362
|
)
|
|
305
|
-
|
|
363
|
+
task = _create_task_with_timeout(result_payloads_stream)
|
|
364
|
+
in_progress.append((None, result_payloads_stream, task))
|
|
365
|
+
last_write_time = datetime.now()
|
|
366
|
+
if remaining_results := await _drain(results):
|
|
306
367
|
async for result_payload in _chat_completion_result_payloads(
|
|
307
368
|
db=info.context.db, results=remaining_results
|
|
308
369
|
):
|
|
@@ -312,17 +373,13 @@ class Subscription:
|
|
|
312
373
|
async def _stream_chat_completion_over_dataset_example(
|
|
313
374
|
*,
|
|
314
375
|
input: ChatCompletionOverDatasetInput,
|
|
315
|
-
|
|
376
|
+
llm_client: PlaygroundStreamingClient,
|
|
316
377
|
revision: models.DatasetExampleRevision,
|
|
317
|
-
|
|
378
|
+
results: asyncio.Queue[ChatCompletionResult],
|
|
318
379
|
experiment_id: int,
|
|
319
380
|
project_id: int,
|
|
320
|
-
) ->
|
|
381
|
+
) -> ChatStream:
|
|
321
382
|
example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
|
|
322
|
-
llm_client = llm_client_class(
|
|
323
|
-
model=input.model,
|
|
324
|
-
api_key=input.api_key,
|
|
325
|
-
)
|
|
326
383
|
invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
|
|
327
384
|
messages = [
|
|
328
385
|
(
|
|
@@ -345,7 +402,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
345
402
|
except TemplateFormatterError as error:
|
|
346
403
|
format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
347
404
|
yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
|
|
348
|
-
await
|
|
405
|
+
await results.put(
|
|
349
406
|
(
|
|
350
407
|
example_id,
|
|
351
408
|
None,
|
|
@@ -380,7 +437,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
380
437
|
db_run = get_db_experiment_run(
|
|
381
438
|
db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
|
|
382
439
|
)
|
|
383
|
-
await
|
|
440
|
+
await results.put((example_id, db_span, db_run))
|
|
384
441
|
if span.status_message is not None:
|
|
385
442
|
yield ChatCompletionSubscriptionError(
|
|
386
443
|
message=span.status_message, dataset_example_id=example_id
|
|
@@ -391,7 +448,7 @@ async def _chat_completion_result_payloads(
|
|
|
391
448
|
*,
|
|
392
449
|
db: DbSessionFactory,
|
|
393
450
|
results: Sequence[ChatCompletionResult],
|
|
394
|
-
) ->
|
|
451
|
+
) -> ChatStream:
|
|
395
452
|
if not results:
|
|
396
453
|
return
|
|
397
454
|
async with db() as session:
|
|
@@ -408,25 +465,64 @@ async def _chat_completion_result_payloads(
|
|
|
408
465
|
)
|
|
409
466
|
|
|
410
467
|
|
|
468
|
+
def _is_result_payloads_stream(
|
|
469
|
+
stream: ChatStream,
|
|
470
|
+
) -> bool:
|
|
471
|
+
"""
|
|
472
|
+
Checks if the given generator was instantiated from
|
|
473
|
+
`_chat_completion_result_payloads`
|
|
474
|
+
"""
|
|
475
|
+
return stream.ag_code == _chat_completion_result_payloads.__code__
|
|
476
|
+
|
|
477
|
+
|
|
411
478
|
def _create_task_with_timeout(
|
|
412
|
-
iterable: AsyncIterator[GenericType], timeout_in_seconds: int =
|
|
413
|
-
) -> Task[GenericType]:
|
|
414
|
-
return create_task(
|
|
479
|
+
iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 90
|
|
480
|
+
) -> asyncio.Task[GenericType]:
|
|
481
|
+
return asyncio.create_task(
|
|
482
|
+
_wait_for(
|
|
483
|
+
_as_coroutine(iterable),
|
|
484
|
+
timeout=timeout_in_seconds,
|
|
485
|
+
timeout_message="Playground task timed out",
|
|
486
|
+
)
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
async def _wait_for(
|
|
491
|
+
coro: Coroutine[None, None, GenericType],
|
|
492
|
+
timeout: float,
|
|
493
|
+
timeout_message: Optional[str] = None,
|
|
494
|
+
) -> GenericType:
|
|
495
|
+
"""
|
|
496
|
+
A function that imitates asyncio.wait_for, but allows the task to be
|
|
497
|
+
cancelled with a custom message.
|
|
498
|
+
"""
|
|
499
|
+
task = asyncio.create_task(coro)
|
|
500
|
+
done, pending = await asyncio.wait([task], timeout=timeout)
|
|
501
|
+
assert len(done) + len(pending) == 1
|
|
502
|
+
if done:
|
|
503
|
+
task = done.pop()
|
|
504
|
+
return task.result()
|
|
505
|
+
task = pending.pop()
|
|
506
|
+
task.cancel(msg=timeout_message)
|
|
507
|
+
try:
|
|
508
|
+
return await task
|
|
509
|
+
except asyncio.CancelledError:
|
|
510
|
+
raise asyncio.TimeoutError()
|
|
415
511
|
|
|
416
512
|
|
|
417
|
-
async def _drain(queue: Queue[GenericType]) -> list[GenericType]:
|
|
513
|
+
async def _drain(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
|
|
418
514
|
values: list[GenericType] = []
|
|
419
515
|
while not queue.empty():
|
|
420
516
|
values.append(await queue.get())
|
|
421
517
|
return values
|
|
422
518
|
|
|
423
519
|
|
|
424
|
-
def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]:
|
|
520
|
+
def _drain_no_wait(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
|
|
425
521
|
values: list[GenericType] = []
|
|
426
522
|
while True:
|
|
427
523
|
try:
|
|
428
524
|
values.append(queue.get_nowait())
|
|
429
|
-
except QueueEmpty:
|
|
525
|
+
except asyncio.QueueEmpty:
|
|
430
526
|
break
|
|
431
527
|
return values
|
|
432
528
|
|
|
@@ -467,6 +563,8 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
|
|
|
467
563
|
return MustacheTemplateFormatter()
|
|
468
564
|
if template_language is TemplateLanguage.F_STRING:
|
|
469
565
|
return FStringTemplateFormatter()
|
|
566
|
+
if template_language is TemplateLanguage.NONE:
|
|
567
|
+
return NoOpFormatter()
|
|
470
568
|
assert_never(template_language)
|
|
471
569
|
|
|
472
570
|
|
|
@@ -1,32 +1,32 @@
|
|
|
1
1
|
{
|
|
2
|
-
"_components-
|
|
3
|
-
"file": "assets/components-
|
|
2
|
+
"_components-72cQL1d1.js": {
|
|
3
|
+
"file": "assets/components-72cQL1d1.js",
|
|
4
4
|
"name": "components",
|
|
5
5
|
"imports": [
|
|
6
|
-
"_vendor-
|
|
7
|
-
"_pages-
|
|
8
|
-
"_vendor-arizeai
|
|
9
|
-
"_vendor-codemirror-
|
|
6
|
+
"_vendor-DexmGnha.js",
|
|
7
|
+
"_pages-DFAkBAUh.js",
|
|
8
|
+
"_vendor-arizeai--Q3ol330.js",
|
|
9
|
+
"_vendor-codemirror-B4bYvWa6.js",
|
|
10
10
|
"_vendor-three-DwGkEfCM.js"
|
|
11
11
|
]
|
|
12
12
|
},
|
|
13
|
-
"_pages-
|
|
14
|
-
"file": "assets/pages-
|
|
13
|
+
"_pages-DFAkBAUh.js": {
|
|
14
|
+
"file": "assets/pages-DFAkBAUh.js",
|
|
15
15
|
"name": "pages",
|
|
16
16
|
"imports": [
|
|
17
|
-
"_vendor-
|
|
18
|
-
"_vendor-arizeai
|
|
19
|
-
"_components-
|
|
20
|
-
"_vendor-recharts-
|
|
21
|
-
"_vendor-codemirror-
|
|
17
|
+
"_vendor-DexmGnha.js",
|
|
18
|
+
"_vendor-arizeai--Q3ol330.js",
|
|
19
|
+
"_components-72cQL1d1.js",
|
|
20
|
+
"_vendor-recharts-B4ZzJhNh.js",
|
|
21
|
+
"_vendor-codemirror-B4bYvWa6.js"
|
|
22
22
|
]
|
|
23
23
|
},
|
|
24
24
|
"_vendor-!~{003}~.js": {
|
|
25
25
|
"file": "assets/vendor-DxkFTwjz.css",
|
|
26
26
|
"src": "_vendor-!~{003}~.js"
|
|
27
27
|
},
|
|
28
|
-
"_vendor-
|
|
29
|
-
"file": "assets/vendor-
|
|
28
|
+
"_vendor-DexmGnha.js": {
|
|
29
|
+
"file": "assets/vendor-DexmGnha.js",
|
|
30
30
|
"name": "vendor",
|
|
31
31
|
"imports": [
|
|
32
32
|
"_vendor-three-DwGkEfCM.js"
|
|
@@ -35,25 +35,25 @@
|
|
|
35
35
|
"assets/vendor-DxkFTwjz.css"
|
|
36
36
|
]
|
|
37
37
|
},
|
|
38
|
-
"_vendor-arizeai
|
|
39
|
-
"file": "assets/vendor-arizeai
|
|
38
|
+
"_vendor-arizeai--Q3ol330.js": {
|
|
39
|
+
"file": "assets/vendor-arizeai--Q3ol330.js",
|
|
40
40
|
"name": "vendor-arizeai",
|
|
41
41
|
"imports": [
|
|
42
|
-
"_vendor-
|
|
42
|
+
"_vendor-DexmGnha.js"
|
|
43
43
|
]
|
|
44
44
|
},
|
|
45
|
-
"_vendor-codemirror-
|
|
46
|
-
"file": "assets/vendor-codemirror-
|
|
45
|
+
"_vendor-codemirror-B4bYvWa6.js": {
|
|
46
|
+
"file": "assets/vendor-codemirror-B4bYvWa6.js",
|
|
47
47
|
"name": "vendor-codemirror",
|
|
48
48
|
"imports": [
|
|
49
|
-
"_vendor-
|
|
49
|
+
"_vendor-DexmGnha.js"
|
|
50
50
|
]
|
|
51
51
|
},
|
|
52
|
-
"_vendor-recharts-
|
|
53
|
-
"file": "assets/vendor-recharts-
|
|
52
|
+
"_vendor-recharts-B4ZzJhNh.js": {
|
|
53
|
+
"file": "assets/vendor-recharts-B4ZzJhNh.js",
|
|
54
54
|
"name": "vendor-recharts",
|
|
55
55
|
"imports": [
|
|
56
|
-
"_vendor-
|
|
56
|
+
"_vendor-DexmGnha.js"
|
|
57
57
|
]
|
|
58
58
|
},
|
|
59
59
|
"_vendor-three-DwGkEfCM.js": {
|
|
@@ -61,18 +61,18 @@
|
|
|
61
61
|
"name": "vendor-three"
|
|
62
62
|
},
|
|
63
63
|
"index.tsx": {
|
|
64
|
-
"file": "assets/index-
|
|
64
|
+
"file": "assets/index-BowjltW-.js",
|
|
65
65
|
"name": "index",
|
|
66
66
|
"src": "index.tsx",
|
|
67
67
|
"isEntry": true,
|
|
68
68
|
"imports": [
|
|
69
|
-
"_vendor-
|
|
70
|
-
"_vendor-arizeai
|
|
71
|
-
"_pages-
|
|
72
|
-
"_components-
|
|
69
|
+
"_vendor-DexmGnha.js",
|
|
70
|
+
"_vendor-arizeai--Q3ol330.js",
|
|
71
|
+
"_pages-DFAkBAUh.js",
|
|
72
|
+
"_components-72cQL1d1.js",
|
|
73
73
|
"_vendor-three-DwGkEfCM.js",
|
|
74
|
-
"_vendor-recharts-
|
|
75
|
-
"_vendor-codemirror-
|
|
74
|
+
"_vendor-recharts-B4ZzJhNh.js",
|
|
75
|
+
"_vendor-codemirror-B4bYvWa6.js"
|
|
76
76
|
]
|
|
77
77
|
}
|
|
78
78
|
}
|