arize-phoenix 5.10.0__py3-none-any.whl → 5.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See the registry's release notice for more details.

Files changed (25)
  1. {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/METADATA +2 -1
  2. {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/RECORD +25 -25
  3. phoenix/config.py +13 -0
  4. phoenix/db/facilitator.py +3 -2
  5. phoenix/server/api/helpers/playground_clients.py +64 -77
  6. phoenix/server/api/helpers/playground_spans.py +6 -0
  7. phoenix/server/api/mutations/chat_mutations.py +81 -36
  8. phoenix/server/api/subscriptions.py +156 -58
  9. phoenix/server/api/types/TemplateLanguage.py +1 -0
  10. phoenix/server/static/.vite/manifest.json +31 -31
  11. phoenix/server/static/assets/{components-BXIz9ZO8.js → components-72cQL1d1.js} +95 -95
  12. phoenix/server/static/assets/{index-DTut7g1y.js → index-BowjltW-.js} +1 -1
  13. phoenix/server/static/assets/{pages-B8FpJuXu.js → pages-DFAkBAUh.js} +339 -271
  14. phoenix/server/static/assets/{vendor-BX8_Znqy.js → vendor-DexmGnha.js} +150 -150
  15. phoenix/server/static/assets/{vendor-arizeai-CtHir-Ua.js → vendor-arizeai--Q3ol330.js} +28 -28
  16. phoenix/server/static/assets/{vendor-codemirror-DLlGiguX.js → vendor-codemirror-B4bYvWa6.js} +1 -1
  17. phoenix/server/static/assets/{vendor-recharts-CJRple0d.js → vendor-recharts-B4ZzJhNh.js} +1 -1
  18. phoenix/trace/span_evaluations.py +4 -3
  19. phoenix/utilities/json.py +7 -1
  20. phoenix/utilities/template_formatters.py +18 -0
  21. phoenix/version.py +1 -1
  22. {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/WHEEL +0 -0
  23. {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/entry_points.txt +0 -0
  24. {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/licenses/IP_NOTICE +0 -0
  25. {arize_phoenix-5.10.0.dist-info → arize_phoenix-5.12.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,9 +1,9 @@
1
1
  import asyncio
2
2
  from dataclasses import asdict, field
3
3
  from datetime import datetime, timezone
4
- from itertools import chain
4
+ from itertools import chain, islice
5
5
  from traceback import format_exc
6
- from typing import Any, Iterable, Iterator, List, Optional, Union
6
+ from typing import Any, Iterable, Iterator, List, Optional, TypeVar, Union
7
7
 
8
8
  import strawberry
9
9
  from openinference.instrumentation import safe_json_dumps
@@ -25,8 +25,9 @@ from typing_extensions import assert_never
25
25
  from phoenix.datetime_utils import local_now, normalize_datetime
26
26
  from phoenix.db import models
27
27
  from phoenix.db.helpers import get_dataset_example_revisions
28
+ from phoenix.server.api.auth import IsNotReadOnly
28
29
  from phoenix.server.api.context import Context
29
- from phoenix.server.api.exceptions import BadRequest, NotFound
30
+ from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
30
31
  from phoenix.server.api.helpers.playground_clients import (
31
32
  PlaygroundStreamingClient,
32
33
  initialize_playground_clients,
@@ -67,6 +68,7 @@ from phoenix.utilities.json import jsonify
67
68
  from phoenix.utilities.template_formatters import (
68
69
  FStringTemplateFormatter,
69
70
  MustacheTemplateFormatter,
71
+ NoOpFormatter,
70
72
  TemplateFormatter,
71
73
  )
72
74
 
@@ -117,7 +119,7 @@ class ChatCompletionOverDatasetMutationPayload:
117
119
 
118
120
  @strawberry.type
119
121
  class ChatCompletionMutationMixin:
120
- @strawberry.mutation
122
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
121
123
  @classmethod
122
124
  async def chat_completion_over_dataset(
123
125
  cls,
@@ -127,11 +129,19 @@ class ChatCompletionMutationMixin:
127
129
  provider_key = input.model.provider_key
128
130
  llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
129
131
  if llm_client_class is None:
130
- raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
131
- llm_client = llm_client_class(
132
- model=input.model,
133
- api_key=input.api_key,
134
- )
132
+ raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
133
+ try:
134
+ llm_client = llm_client_class(
135
+ model=input.model,
136
+ api_key=input.api_key,
137
+ )
138
+ except CustomGraphQLError:
139
+ raise
140
+ except Exception as error:
141
+ raise BadRequest(
142
+ f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
143
+ f"{str(error)}"
144
+ )
135
145
  dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
136
146
  dataset_version_id = (
137
147
  from_global_id_with_expected_type(
@@ -158,7 +168,9 @@ class ChatCompletionMutationMixin:
158
168
  revisions = [
159
169
  revision
160
170
  async for revision in await session.stream_scalars(
161
- get_dataset_example_revisions(resolved_version_id)
171
+ get_dataset_example_revisions(resolved_version_id).order_by(
172
+ models.DatasetExampleRevision.id
173
+ )
162
174
  )
163
175
  ]
164
176
  if not revisions:
@@ -181,28 +193,32 @@ class ChatCompletionMutationMixin:
181
193
  session.add(experiment)
182
194
  await session.flush()
183
195
 
196
+ results = []
197
+ batch_size = 3
184
198
  start_time = datetime.now(timezone.utc)
185
- results = await asyncio.gather(
186
- *(
187
- cls._chat_completion(
188
- info,
189
- llm_client,
190
- ChatCompletionInput(
191
- model=input.model,
192
- api_key=input.api_key,
193
- messages=input.messages,
194
- tools=input.tools,
195
- invocation_parameters=input.invocation_parameters,
196
- template=TemplateOptions(
197
- language=input.template_language,
198
- variables=revision.input,
199
+ for batch in _get_batches(revisions, batch_size):
200
+ batch_results = await asyncio.gather(
201
+ *(
202
+ cls._chat_completion(
203
+ info,
204
+ llm_client,
205
+ ChatCompletionInput(
206
+ model=input.model,
207
+ api_key=input.api_key,
208
+ messages=input.messages,
209
+ tools=input.tools,
210
+ invocation_parameters=input.invocation_parameters,
211
+ template=TemplateOptions(
212
+ language=input.template_language,
213
+ variables=revision.input,
214
+ ),
199
215
  ),
200
- ),
201
- )
202
- for revision in revisions
203
- ),
204
- return_exceptions=True,
205
- )
216
+ )
217
+ for revision in batch
218
+ ),
219
+ return_exceptions=True,
220
+ )
221
+ results.extend(batch_results)
206
222
 
207
223
  payload = ChatCompletionOverDatasetMutationPayload(
208
224
  dataset_id=GlobalID(models.Dataset.__name__, str(dataset.id)),
@@ -258,7 +274,7 @@ class ChatCompletionMutationMixin:
258
274
  payload.examples.append(example_payload)
259
275
  return payload
260
276
 
261
- @strawberry.mutation
277
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
262
278
  @classmethod
263
279
  async def chat_completion(
264
280
  cls, info: Info[Context, None], input: ChatCompletionInput
@@ -266,11 +282,19 @@ class ChatCompletionMutationMixin:
266
282
  provider_key = input.model.provider_key
267
283
  llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
268
284
  if llm_client_class is None:
269
- raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
270
- llm_client = llm_client_class(
271
- model=input.model,
272
- api_key=input.api_key,
273
- )
285
+ raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
286
+ try:
287
+ llm_client = llm_client_class(
288
+ model=input.model,
289
+ api_key=input.api_key,
290
+ )
291
+ except CustomGraphQLError:
292
+ raise
293
+ except Exception as error:
294
+ raise BadRequest(
295
+ f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
296
+ f"{str(error)}"
297
+ )
274
298
  return await cls._chat_completion(info, llm_client, input)
275
299
 
276
300
  @classmethod
@@ -459,6 +483,8 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
459
483
  return MustacheTemplateFormatter()
460
484
  if template_language is TemplateLanguage.F_STRING:
461
485
  return FStringTemplateFormatter()
486
+ if template_language is TemplateLanguage.NONE:
487
+ return NoOpFormatter()
462
488
  assert_never(template_language)
463
489
 
464
490
 
@@ -486,6 +512,11 @@ def _llm_output_messages(
486
512
  if text_content:
487
513
  yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", text_content
488
514
  for tool_call_index, tool_call in enumerate(tool_calls.values()):
515
+ if tool_call_id := tool_call.id:
516
+ yield (
517
+ f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_ID}",
518
+ tool_call_id,
519
+ )
489
520
  yield (
490
521
  f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
491
522
  tool_call.function.name,
@@ -513,6 +544,19 @@ def _serialize_event(event: SpanException) -> dict[str, Any]:
513
544
  return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
514
545
 
515
546
 
547
+ _AnyT = TypeVar("_AnyT")
548
+
549
+
550
+ def _get_batches(
551
+ iterable: Iterable[_AnyT],
552
+ batch_size: int,
553
+ ) -> Iterator[list[_AnyT]]:
554
+ """Splits an iterable into batches not exceeding a specified size."""
555
+ iterator = iter(iterable)
556
+ while batch := list(islice(iterator, batch_size)):
557
+ yield batch
558
+
559
+
516
560
  JSON = OpenInferenceMimeTypeValues.JSON.value
517
561
  TEXT = OpenInferenceMimeTypeValues.TEXT.value
518
562
  LLM = OpenInferenceSpanKindValues.LLM.value
@@ -534,6 +578,7 @@ MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
534
578
  MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
535
579
  MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
536
580
 
581
+ TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
537
582
  TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
538
583
  TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
539
584
 
@@ -1,10 +1,11 @@
1
1
  import asyncio
2
2
  import logging
3
- from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for
4
3
  from collections.abc import AsyncIterator, Iterator
5
- from datetime import datetime, timezone
4
+ from datetime import datetime, timedelta, timezone
6
5
  from typing import (
7
6
  Any,
7
+ AsyncGenerator,
8
+ Coroutine,
8
9
  Iterable,
9
10
  Mapping,
10
11
  Optional,
@@ -23,8 +24,9 @@ from typing_extensions import TypeAlias, assert_never
23
24
 
24
25
  from phoenix.datetime_utils import local_now, normalize_datetime
25
26
  from phoenix.db import models
27
+ from phoenix.server.api.auth import IsNotReadOnly
26
28
  from phoenix.server.api.context import Context
27
- from phoenix.server.api.exceptions import BadRequest, NotFound
29
+ from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
28
30
  from phoenix.server.api.helpers.playground_clients import (
29
31
  PlaygroundStreamingClient,
30
32
  initialize_playground_clients,
@@ -62,6 +64,7 @@ from phoenix.server.types import DbSessionFactory
62
64
  from phoenix.utilities.template_formatters import (
63
65
  FStringTemplateFormatter,
64
66
  MustacheTemplateFormatter,
67
+ NoOpFormatter,
65
68
  TemplateFormatter,
66
69
  TemplateFormatterError,
67
70
  )
@@ -79,23 +82,32 @@ DatasetExampleID: TypeAlias = GlobalID
79
82
  ChatCompletionResult: TypeAlias = tuple[
80
83
  DatasetExampleID, Optional[models.Span], models.ExperimentRun
81
84
  ]
85
+ ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
82
86
  PLAYGROUND_PROJECT_NAME = "playground"
83
87
 
84
88
 
85
89
  @strawberry.type
86
90
  class Subscription:
87
- @strawberry.subscription
91
+ @strawberry.subscription(permission_classes=[IsNotReadOnly]) # type: ignore
88
92
  async def chat_completion(
89
93
  self, info: Info[Context, None], input: ChatCompletionInput
90
94
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
91
95
  provider_key = input.model.provider_key
92
96
  llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
93
97
  if llm_client_class is None:
94
- raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
95
- llm_client = llm_client_class(
96
- model=input.model,
97
- api_key=input.api_key,
98
- )
98
+ raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
99
+ try:
100
+ llm_client = llm_client_class(
101
+ model=input.model,
102
+ api_key=input.api_key,
103
+ )
104
+ except CustomGraphQLError:
105
+ raise
106
+ except Exception as error:
107
+ raise BadRequest(
108
+ f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
109
+ f"{str(error)}"
110
+ )
99
111
 
100
112
  messages = [
101
113
  (
@@ -151,14 +163,26 @@ class Subscription:
151
163
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
152
164
  yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))
153
165
 
154
- @strawberry.subscription
166
+ @strawberry.subscription(permission_classes=[IsNotReadOnly]) # type: ignore
155
167
  async def chat_completion_over_dataset(
156
168
  self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
157
169
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
158
170
  provider_key = input.model.provider_key
159
171
  llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
160
172
  if llm_client_class is None:
161
- raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
173
+ raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
174
+ try:
175
+ llm_client = llm_client_class(
176
+ model=input.model,
177
+ api_key=input.api_key,
178
+ )
179
+ except CustomGraphQLError:
180
+ raise
181
+ except Exception as error:
182
+ raise BadRequest(
183
+ f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
184
+ f"{str(error)}"
185
+ )
162
186
 
163
187
  dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
164
188
  version_id = (
@@ -264,45 +288,82 @@ class Subscription:
264
288
  experiment=to_gql_experiment(experiment)
265
289
  ) # eagerly yields experiment so it can be linked by consumers of the subscription
266
290
 
267
- results_queue: Queue[ChatCompletionResult] = Queue()
268
- chat_completion_streams = [
269
- _stream_chat_completion_over_dataset_example(
270
- input=input,
271
- llm_client_class=llm_client_class,
272
- revision=revision,
273
- results_queue=results_queue,
274
- experiment_id=experiment.id,
275
- project_id=playground_project_id,
291
+ results: asyncio.Queue[ChatCompletionResult] = asyncio.Queue()
292
+ not_started: list[tuple[DatasetExampleID, ChatStream]] = [
293
+ (
294
+ GlobalID(DatasetExample.__name__, str(revision.dataset_example_id)),
295
+ _stream_chat_completion_over_dataset_example(
296
+ input=input,
297
+ llm_client=llm_client,
298
+ revision=revision,
299
+ results=results,
300
+ experiment_id=experiment.id,
301
+ project_id=playground_project_id,
302
+ ),
276
303
  )
277
304
  for revision in revisions
278
305
  ]
279
- stream_to_async_tasks: dict[
280
- AsyncIterator[ChatCompletionSubscriptionPayload],
281
- Task[ChatCompletionSubscriptionPayload],
282
- ] = {iterator: _create_task_with_timeout(iterator) for iterator in chat_completion_streams}
283
- batch_size = 10
284
- while stream_to_async_tasks:
285
- async_tasks_to_run = [task for task in stream_to_async_tasks.values()]
286
- completed_tasks, _ = await wait(async_tasks_to_run, return_when=FIRST_COMPLETED)
287
- for task in completed_tasks:
288
- iterator = next(it for it, t in stream_to_async_tasks.items() if t == task)
306
+ in_progress: list[
307
+ tuple[
308
+ Optional[DatasetExampleID],
309
+ ChatStream,
310
+ asyncio.Task[ChatCompletionSubscriptionPayload],
311
+ ]
312
+ ] = []
313
+ max_in_progress = 3
314
+ write_batch_size = 10
315
+ write_interval = timedelta(seconds=10)
316
+ last_write_time = datetime.now()
317
+ while not_started or in_progress:
318
+ while not_started and len(in_progress) < max_in_progress:
319
+ ex_id, stream = not_started.pop()
320
+ task = _create_task_with_timeout(stream)
321
+ in_progress.append((ex_id, stream, task))
322
+ async_tasks_to_run = [task for _, _, task in in_progress]
323
+ completed_tasks, _ = await asyncio.wait(
324
+ async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
325
+ )
326
+ for completed_task in completed_tasks:
327
+ idx = [task for _, _, task in in_progress].index(completed_task)
328
+ example_id, stream, _ = in_progress[idx]
289
329
  try:
290
- yield task.result()
291
- except (StopAsyncIteration, asyncio.TimeoutError):
292
- del stream_to_async_tasks[iterator] # removes exhausted iterator
330
+ yield completed_task.result()
331
+ except StopAsyncIteration:
332
+ del in_progress[idx] # removes exhausted stream
333
+ except asyncio.TimeoutError:
334
+ del in_progress[idx] # removes timed-out stream
335
+ if example_id is not None:
336
+ yield ChatCompletionSubscriptionError(
337
+ message="Playground task timed out", dataset_example_id=example_id
338
+ )
293
339
  except Exception as error:
294
- del stream_to_async_tasks[iterator] # removes failed iterator
340
+ del in_progress[idx] # removes failed stream
341
+ if example_id is not None:
342
+ yield ChatCompletionSubscriptionError(
343
+ message="An unexpected error occurred", dataset_example_id=example_id
344
+ )
295
345
  logger.exception(error)
296
346
  else:
297
- stream_to_async_tasks[iterator] = _create_task_with_timeout(iterator)
298
- if results_queue.qsize() >= batch_size:
299
- result_iterator = _chat_completion_result_payloads(
300
- db=info.context.db, results=_drain_no_wait(results_queue)
301
- )
302
- stream_to_async_tasks[result_iterator] = _create_task_with_timeout(
303
- result_iterator
347
+ task = _create_task_with_timeout(stream)
348
+ in_progress[idx] = (example_id, stream, task)
349
+
350
+ exceeded_write_batch_size = results.qsize() >= write_batch_size
351
+ exceeded_write_interval = datetime.now() - last_write_time > write_interval
352
+ write_already_in_progress = any(
353
+ _is_result_payloads_stream(stream) for _, stream, _ in in_progress
354
+ )
355
+ if (
356
+ not results.empty()
357
+ and (exceeded_write_batch_size or exceeded_write_interval)
358
+ and not write_already_in_progress
359
+ ):
360
+ result_payloads_stream = _chat_completion_result_payloads(
361
+ db=info.context.db, results=_drain_no_wait(results)
304
362
  )
305
- if remaining_results := await _drain(results_queue):
363
+ task = _create_task_with_timeout(result_payloads_stream)
364
+ in_progress.append((None, result_payloads_stream, task))
365
+ last_write_time = datetime.now()
366
+ if remaining_results := await _drain(results):
306
367
  async for result_payload in _chat_completion_result_payloads(
307
368
  db=info.context.db, results=remaining_results
308
369
  ):
@@ -312,17 +373,13 @@ class Subscription:
312
373
  async def _stream_chat_completion_over_dataset_example(
313
374
  *,
314
375
  input: ChatCompletionOverDatasetInput,
315
- llm_client_class: type["PlaygroundStreamingClient"],
376
+ llm_client: PlaygroundStreamingClient,
316
377
  revision: models.DatasetExampleRevision,
317
- results_queue: Queue[ChatCompletionResult],
378
+ results: asyncio.Queue[ChatCompletionResult],
318
379
  experiment_id: int,
319
380
  project_id: int,
320
- ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
381
+ ) -> ChatStream:
321
382
  example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
322
- llm_client = llm_client_class(
323
- model=input.model,
324
- api_key=input.api_key,
325
- )
326
383
  invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
327
384
  messages = [
328
385
  (
@@ -345,7 +402,7 @@ async def _stream_chat_completion_over_dataset_example(
345
402
  except TemplateFormatterError as error:
346
403
  format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
347
404
  yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
348
- await results_queue.put(
405
+ await results.put(
349
406
  (
350
407
  example_id,
351
408
  None,
@@ -380,7 +437,7 @@ async def _stream_chat_completion_over_dataset_example(
380
437
  db_run = get_db_experiment_run(
381
438
  db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
382
439
  )
383
- await results_queue.put((example_id, db_span, db_run))
440
+ await results.put((example_id, db_span, db_run))
384
441
  if span.status_message is not None:
385
442
  yield ChatCompletionSubscriptionError(
386
443
  message=span.status_message, dataset_example_id=example_id
@@ -391,7 +448,7 @@ async def _chat_completion_result_payloads(
391
448
  *,
392
449
  db: DbSessionFactory,
393
450
  results: Sequence[ChatCompletionResult],
394
- ) -> AsyncIterator[ChatCompletionSubscriptionResult]:
451
+ ) -> ChatStream:
395
452
  if not results:
396
453
  return
397
454
  async with db() as session:
@@ -408,25 +465,64 @@ async def _chat_completion_result_payloads(
408
465
  )
409
466
 
410
467
 
468
+ def _is_result_payloads_stream(
469
+ stream: ChatStream,
470
+ ) -> bool:
471
+ """
472
+ Checks if the given generator was instantiated from
473
+ `_chat_completion_result_payloads`
474
+ """
475
+ return stream.ag_code == _chat_completion_result_payloads.__code__
476
+
477
+
411
478
  def _create_task_with_timeout(
412
- iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 60
413
- ) -> Task[GenericType]:
414
- return create_task(wait_for(_as_coroutine(iterable), timeout=timeout_in_seconds))
479
+ iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 90
480
+ ) -> asyncio.Task[GenericType]:
481
+ return asyncio.create_task(
482
+ _wait_for(
483
+ _as_coroutine(iterable),
484
+ timeout=timeout_in_seconds,
485
+ timeout_message="Playground task timed out",
486
+ )
487
+ )
488
+
489
+
490
+ async def _wait_for(
491
+ coro: Coroutine[None, None, GenericType],
492
+ timeout: float,
493
+ timeout_message: Optional[str] = None,
494
+ ) -> GenericType:
495
+ """
496
+ A function that imitates asyncio.wait_for, but allows the task to be
497
+ cancelled with a custom message.
498
+ """
499
+ task = asyncio.create_task(coro)
500
+ done, pending = await asyncio.wait([task], timeout=timeout)
501
+ assert len(done) + len(pending) == 1
502
+ if done:
503
+ task = done.pop()
504
+ return task.result()
505
+ task = pending.pop()
506
+ task.cancel(msg=timeout_message)
507
+ try:
508
+ return await task
509
+ except asyncio.CancelledError:
510
+ raise asyncio.TimeoutError()
415
511
 
416
512
 
417
- async def _drain(queue: Queue[GenericType]) -> list[GenericType]:
513
+ async def _drain(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
418
514
  values: list[GenericType] = []
419
515
  while not queue.empty():
420
516
  values.append(await queue.get())
421
517
  return values
422
518
 
423
519
 
424
- def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]:
520
+ def _drain_no_wait(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
425
521
  values: list[GenericType] = []
426
522
  while True:
427
523
  try:
428
524
  values.append(queue.get_nowait())
429
- except QueueEmpty:
525
+ except asyncio.QueueEmpty:
430
526
  break
431
527
  return values
432
528
 
@@ -467,6 +563,8 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
467
563
  return MustacheTemplateFormatter()
468
564
  if template_language is TemplateLanguage.F_STRING:
469
565
  return FStringTemplateFormatter()
566
+ if template_language is TemplateLanguage.NONE:
567
+ return NoOpFormatter()
470
568
  assert_never(template_language)
471
569
 
472
570
 
@@ -5,5 +5,6 @@ import strawberry
5
5
 
6
6
  @strawberry.enum
7
7
  class TemplateLanguage(Enum):
8
+ NONE = "NONE"
8
9
  MUSTACHE = "MUSTACHE"
9
10
  F_STRING = "F_STRING"
@@ -1,32 +1,32 @@
1
1
  {
2
- "_components-BXIz9ZO8.js": {
3
- "file": "assets/components-BXIz9ZO8.js",
2
+ "_components-72cQL1d1.js": {
3
+ "file": "assets/components-72cQL1d1.js",
4
4
  "name": "components",
5
5
  "imports": [
6
- "_vendor-BX8_Znqy.js",
7
- "_pages-B8FpJuXu.js",
8
- "_vendor-arizeai-CtHir-Ua.js",
9
- "_vendor-codemirror-DLlGiguX.js",
6
+ "_vendor-DexmGnha.js",
7
+ "_pages-DFAkBAUh.js",
8
+ "_vendor-arizeai--Q3ol330.js",
9
+ "_vendor-codemirror-B4bYvWa6.js",
10
10
  "_vendor-three-DwGkEfCM.js"
11
11
  ]
12
12
  },
13
- "_pages-B8FpJuXu.js": {
14
- "file": "assets/pages-B8FpJuXu.js",
13
+ "_pages-DFAkBAUh.js": {
14
+ "file": "assets/pages-DFAkBAUh.js",
15
15
  "name": "pages",
16
16
  "imports": [
17
- "_vendor-BX8_Znqy.js",
18
- "_vendor-arizeai-CtHir-Ua.js",
19
- "_components-BXIz9ZO8.js",
20
- "_vendor-recharts-CJRple0d.js",
21
- "_vendor-codemirror-DLlGiguX.js"
17
+ "_vendor-DexmGnha.js",
18
+ "_vendor-arizeai--Q3ol330.js",
19
+ "_components-72cQL1d1.js",
20
+ "_vendor-recharts-B4ZzJhNh.js",
21
+ "_vendor-codemirror-B4bYvWa6.js"
22
22
  ]
23
23
  },
24
24
  "_vendor-!~{003}~.js": {
25
25
  "file": "assets/vendor-DxkFTwjz.css",
26
26
  "src": "_vendor-!~{003}~.js"
27
27
  },
28
- "_vendor-BX8_Znqy.js": {
29
- "file": "assets/vendor-BX8_Znqy.js",
28
+ "_vendor-DexmGnha.js": {
29
+ "file": "assets/vendor-DexmGnha.js",
30
30
  "name": "vendor",
31
31
  "imports": [
32
32
  "_vendor-three-DwGkEfCM.js"
@@ -35,25 +35,25 @@
35
35
  "assets/vendor-DxkFTwjz.css"
36
36
  ]
37
37
  },
38
- "_vendor-arizeai-CtHir-Ua.js": {
39
- "file": "assets/vendor-arizeai-CtHir-Ua.js",
38
+ "_vendor-arizeai--Q3ol330.js": {
39
+ "file": "assets/vendor-arizeai--Q3ol330.js",
40
40
  "name": "vendor-arizeai",
41
41
  "imports": [
42
- "_vendor-BX8_Znqy.js"
42
+ "_vendor-DexmGnha.js"
43
43
  ]
44
44
  },
45
- "_vendor-codemirror-DLlGiguX.js": {
46
- "file": "assets/vendor-codemirror-DLlGiguX.js",
45
+ "_vendor-codemirror-B4bYvWa6.js": {
46
+ "file": "assets/vendor-codemirror-B4bYvWa6.js",
47
47
  "name": "vendor-codemirror",
48
48
  "imports": [
49
- "_vendor-BX8_Znqy.js"
49
+ "_vendor-DexmGnha.js"
50
50
  ]
51
51
  },
52
- "_vendor-recharts-CJRple0d.js": {
53
- "file": "assets/vendor-recharts-CJRple0d.js",
52
+ "_vendor-recharts-B4ZzJhNh.js": {
53
+ "file": "assets/vendor-recharts-B4ZzJhNh.js",
54
54
  "name": "vendor-recharts",
55
55
  "imports": [
56
- "_vendor-BX8_Znqy.js"
56
+ "_vendor-DexmGnha.js"
57
57
  ]
58
58
  },
59
59
  "_vendor-three-DwGkEfCM.js": {
@@ -61,18 +61,18 @@
61
61
  "name": "vendor-three"
62
62
  },
63
63
  "index.tsx": {
64
- "file": "assets/index-DTut7g1y.js",
64
+ "file": "assets/index-BowjltW-.js",
65
65
  "name": "index",
66
66
  "src": "index.tsx",
67
67
  "isEntry": true,
68
68
  "imports": [
69
- "_vendor-BX8_Znqy.js",
70
- "_vendor-arizeai-CtHir-Ua.js",
71
- "_pages-B8FpJuXu.js",
72
- "_components-BXIz9ZO8.js",
69
+ "_vendor-DexmGnha.js",
70
+ "_vendor-arizeai--Q3ol330.js",
71
+ "_pages-DFAkBAUh.js",
72
+ "_components-72cQL1d1.js",
73
73
  "_vendor-three-DwGkEfCM.js",
74
- "_vendor-recharts-CJRple0d.js",
75
- "_vendor-codemirror-DLlGiguX.js"
74
+ "_vendor-recharts-B4ZzJhNh.js",
75
+ "_vendor-codemirror-B4bYvWa6.js"
76
76
  ]
77
77
  }
78
78
  }