arize-phoenix 5.7.0-py3-none-any.whl → 5.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of arize-phoenix was flagged as a potentially problematic release.

Files changed (32)
  1. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/METADATA +3 -5
  2. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/RECORD +31 -31
  3. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +19 -3
  5. phoenix/db/helpers.py +55 -1
  6. phoenix/server/api/helpers/playground_clients.py +283 -44
  7. phoenix/server/api/helpers/playground_spans.py +173 -76
  8. phoenix/server/api/input_types/InvocationParameters.py +7 -8
  9. phoenix/server/api/mutations/chat_mutations.py +244 -76
  10. phoenix/server/api/queries.py +5 -1
  11. phoenix/server/api/routers/v1/spans.py +25 -1
  12. phoenix/server/api/subscriptions.py +210 -158
  13. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
  14. phoenix/server/api/types/ExperimentRun.py +38 -1
  15. phoenix/server/api/types/GenerativeProvider.py +2 -1
  16. phoenix/server/app.py +21 -2
  17. phoenix/server/grpc_server.py +3 -1
  18. phoenix/server/static/.vite/manifest.json +32 -32
  19. phoenix/server/static/assets/{components-Csu8UKOs.js → components-DU-8CYbi.js} +370 -329
  20. phoenix/server/static/assets/{index-Bk5C9EA7.js → index-D9E16vvV.js} +2 -2
  21. phoenix/server/static/assets/pages-t09OI1rC.js +3966 -0
  22. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-D04tenE6.js} +181 -181
  23. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-D3NxMQw0.js} +2 -2
  24. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-XTiZSlqq.js} +5 -5
  25. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-p0L0neVs.js} +1 -1
  26. phoenix/session/client.py +27 -7
  27. phoenix/utilities/json.py +31 -1
  28. phoenix/version.py +1 -1
  29. phoenix/server/static/assets/pages-UeWaKXNs.js +0 -3737
  30. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/entry_points.txt +0 -0
  31. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  32. {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/helpers/playground_clients.py (excerpt)

@@ -1,13 +1,12 @@
+ import asyncio
  import importlib.util
+ import inspect
+ import json
+ import time
  from abc import ABC, abstractmethod
  from collections.abc import AsyncIterator, Callable, Iterator
- from typing import (
-     TYPE_CHECKING,
-     Any,
-     Mapping,
-     Optional,
-     Union,
- )
+ from functools import wraps
+ from typing import TYPE_CHECKING, Any, Hashable, Mapping, Optional, Union
 
  from openinference.instrumentation import safe_json_dumps
  from openinference.semconv.trace import SpanAttributes
@@ -15,14 +14,19 @@ from strawberry import UNSET
  from strawberry.scalars import JSON as JSONScalarType
  from typing_extensions import TypeAlias, assert_never
 
- from phoenix.server.api.helpers.playground_registry import (
-     PROVIDER_DEFAULT,
-     register_llm_client,
+ from phoenix.evals.models.rate_limiters import (
+     AsyncCallable,
+     GenericType,
+     ParameterSpec,
+     RateLimiter,
+     RateLimitError,
  )
+ from phoenix.server.api.helpers.playground_registry import PROVIDER_DEFAULT, register_llm_client
  from phoenix.server.api.input_types.GenerativeModelInput import GenerativeModelInput
  from phoenix.server.api.input_types.InvocationParameters import (
      BoundedFloatInvocationParameter,
      CanonicalParameterName,
+     FloatInvocationParameter,
      IntInvocationParameter,
      InvocationParameter,
      InvocationParameterInput,
@@ -41,17 +45,114 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 
  if TYPE_CHECKING:
      from anthropic.types import MessageParam
+     from google.generativeai.types import ContentType
      from openai.types import CompletionUsage
-     from openai.types.chat import (
-         ChatCompletionMessageParam,
-         ChatCompletionMessageToolCallParam,
-     )
+     from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCallParam
 
- DependencyName: TypeAlias = str
  SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
  ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
 
 
+ class Dependency:
+     """
+     Set the module_name to the import name if it is different from the install name
+     """
+
+     def __init__(self, name: str, module_name: Optional[str] = None):
+         self.name = name
+         self.module_name = module_name
+
+     @property
+     def import_name(self) -> str:
+         return self.module_name or self.name
+
+
+ class KeyedSingleton:
+     _instances: dict[Hashable, "KeyedSingleton"] = {}
+
+     def __new__(cls, *args: Any, **kwargs: Any) -> "KeyedSingleton":
+         if "singleton_key" in kwargs:
+             singleton_key = kwargs.pop("singleton_key")
+         elif args:
+             singleton_key = args[0]
+             args = args[1:]
+         else:
+             raise ValueError("singleton_key must be provided")
+
+         instance_key = (cls, singleton_key)
+         if instance_key not in cls._instances:
+             instance = super().__new__(cls)
+             cls._instances[instance_key] = instance
+         return cls._instances[instance_key]
+
+
+ class PlaygroundRateLimiter(RateLimiter, KeyedSingleton):
+     """
+     A rate rate limiter class that will be instantiated once per `singleton_key`.
+     """
+
+     def __init__(self, singleton_key: Hashable, rate_limit_error: Optional[type[BaseException]]):
+         super().__init__(
+             rate_limit_error=rate_limit_error,
+             max_rate_limit_retries=3,
+             initial_per_second_request_rate=2.0,
+             maximum_per_second_request_rate=10.0,
+             enforcement_window_minutes=1,
+             rate_reduction_factor=0.5,
+             rate_increase_factor=0.01,
+             cooldown_seconds=5,
+             verbose=False,
+         )
+
+     # TODO: update the rate limiter class in phoenix.evals to support decorated sync functions
+     def _alimit(
+         self, fn: Callable[ParameterSpec, GenericType]
+     ) -> AsyncCallable[ParameterSpec, GenericType]:
+         @wraps(fn)
+         async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
+             self._initialize_async_primitives()
+             assert self._rate_limit_handling_lock is not None and isinstance(
+                 self._rate_limit_handling_lock, asyncio.Lock
+             )
+             assert self._rate_limit_handling is not None and isinstance(
+                 self._rate_limit_handling, asyncio.Event
+             )
+             try:
+                 try:
+                     await asyncio.wait_for(self._rate_limit_handling.wait(), 120)
+                 except asyncio.TimeoutError:
+                     self._rate_limit_handling.set() # Set the event as a failsafe
+                 await self._throttler.async_wait_until_ready()
+                 request_start_time = time.time()
+                 if inspect.iscoroutinefunction(fn):
+                     return await fn(*args, **kwargs) # type: ignore
+                 else:
+                     return fn(*args, **kwargs)
+             except self._rate_limit_error:
+                 async with self._rate_limit_handling_lock:
+                     self._rate_limit_handling.clear() # prevent new requests from starting
+                     self._throttler.on_rate_limit_error(request_start_time, verbose=self._verbose)
+                     try:
+                         for _attempt in range(self._max_rate_limit_retries):
+                             try:
+                                 request_start_time = time.time()
+                                 await self._throttler.async_wait_until_ready()
+                                 if inspect.iscoroutinefunction(fn):
+                                     return await fn(*args, **kwargs) # type: ignore
+                                 else:
+                                     return fn(*args, **kwargs)
+                             except self._rate_limit_error:
+                                 self._throttler.on_rate_limit_error(
+                                     request_start_time, verbose=self._verbose
+                                 )
+                                 continue
+                     finally:
+                         self._rate_limit_handling.set() # allow new requests to start
+             raise RateLimitError(f"Exceeded max ({self._max_rate_limit_retries}) retries")
+
+         return wrapper
+
+
  class PlaygroundStreamingClient(ABC):
      def __init__(
          self,
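
The KeyedSingleton base above is what makes PlaygroundRateLimiter shared per provider: constructing it twice with the same key returns the same object, so every playground request for a given provider funnels through one adaptive limiter. A trimmed, standalone sketch of that behavior (plain Python, not an import from phoenix; FakeLimiter is a hypothetical stand-in):

    from typing import Any, Hashable

    class KeyedSingleton:
        # one instance per (class, key) pair, mirroring the class added in the diff above
        _instances: dict[Hashable, "KeyedSingleton"] = {}

        def __new__(cls, singleton_key: Hashable, *args: Any, **kwargs: Any) -> "KeyedSingleton":
            instance_key = (cls, singleton_key)
            if instance_key not in cls._instances:
                cls._instances[instance_key] = super().__new__(cls)
            return cls._instances[instance_key]

    class FakeLimiter(KeyedSingleton):
        # hypothetical stand-in for PlaygroundRateLimiter
        def __init__(self, singleton_key: Hashable) -> None:
            self.key = singleton_key

    assert FakeLimiter("openai") is FakeLimiter("openai")         # same key -> shared instance
    assert FakeLimiter("openai") is not FakeLimiter("anthropic")  # new key -> new instance
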
@@ -62,8 +163,8 @@ class PlaygroundStreamingClient(ABC):
 
      @classmethod
      @abstractmethod
-     def dependencies(cls) -> list[DependencyName]:
-         # A list of dependency names this client needs to run
+     def dependencies(cls) -> list[Dependency]:
+         # A list of dependencies this client needs to run
          ...
 
      @classmethod
@@ -108,7 +209,8 @@ class PlaygroundStreamingClient(ABC):
      def dependencies_are_installed(cls) -> bool:
          try:
              for dependency in cls.dependencies():
-                 if importlib.util.find_spec(dependency) is None:
+                 import_name = dependency.import_name
+                 if importlib.util.find_spec(import_name) is None:
                      return False
              return True
          except ValueError:
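
The switch from plain strings to Dependency objects matters because a PyPI distribution name is not always the module you import. A self-contained illustration of how import_name feeds importlib.util.find_spec (the helper below is hypothetical, not phoenix API):

    import importlib.util

    def is_importable(module_name: str) -> bool:
        # find_spec returns None for a missing top-level module and can raise
        # (e.g. ModuleNotFoundError for a dotted name whose parent package is absent)
        try:
            return importlib.util.find_spec(module_name) is not None
        except (ImportError, ValueError):
            return False

    # install name "google-generativeai" vs. import name "google.generativeai",
    # which is exactly the gap Dependency.module_name covers
    print(is_importable("importlib"))            # True: stdlib
    print(is_importable("google.generativeai"))  # True only if google-generativeai is installed
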
@@ -150,14 +252,16 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
          api_key: Optional[str] = None,
      ) -> None:
          from openai import AsyncOpenAI
+         from openai import RateLimitError as OpenAIRateLimitError
 
          super().__init__(model=model, api_key=api_key)
          self.client = AsyncOpenAI(api_key=api_key)
          self.model_name = model.name
+         self.rate_limiter = PlaygroundRateLimiter(model.provider_key, OpenAIRateLimitError)
 
      @classmethod
-     def dependencies(cls) -> list[DependencyName]:
-         return ["openai"]
+     def dependencies(cls) -> list[Dependency]:
+         return [Dependency(name="openai")]
 
      @classmethod
      def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -174,19 +278,16 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                  invocation_name="max_tokens",
                  canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
                  label="Max Tokens",
-                 default_value=UNSET,
              ),
              BoundedFloatInvocationParameter(
                  invocation_name="frequency_penalty",
                  label="Frequency Penalty",
-                 default_value=UNSET,
                  min_value=-2.0,
                  max_value=2.0,
              ),
              BoundedFloatInvocationParameter(
                  invocation_name="presence_penalty",
                  label="Presence Penalty",
-                 default_value=UNSET,
                  min_value=-2.0,
                  max_value=2.0,
              ),
@@ -194,13 +295,11 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                  invocation_name="stop",
                  canonical_name=CanonicalParameterName.STOP_SEQUENCES,
                  label="Stop Sequences",
-                 default_value=UNSET,
              ),
              BoundedFloatInvocationParameter(
                  invocation_name="top_p",
                  canonical_name=CanonicalParameterName.TOP_P,
                  label="Top P",
-                 default_value=UNSET,
                  min_value=0.0,
                  max_value=1.0,
              ),
@@ -208,20 +307,16 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                  invocation_name="seed",
                  canonical_name=CanonicalParameterName.RANDOM_SEED,
                  label="Seed",
-                 default_value=UNSET,
              ),
              JSONInvocationParameter(
                  invocation_name="tool_choice",
                  label="Tool Choice",
                  canonical_name=CanonicalParameterName.TOOL_CHOICE,
-                 default_value=UNSET,
-                 hidden=True,
              ),
              JSONInvocationParameter(
                  invocation_name="response_format",
                  label="Response Format",
                  canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
-                 default_value=UNSET,
              ),
          ]
 
@@ -240,7 +335,8 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
          openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
          tool_call_ids: dict[int, str] = {}
          token_usage: Optional["CompletionUsage"] = None
-         async for chunk in await self.client.chat.completions.create(
+         throttled_create = self.rate_limiter.alimit(self.client.chat.completions.create)
+         async for chunk in await throttled_create(
              messages=openai_messages,
              model=self.model_name,
              stream=True,
@@ -251,6 +347,9 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
              if (usage := chunk.usage) is not None:
                  token_usage = usage
                  continue
+             if not chunk.choices:
+                 # for Azure, initial chunk contains the content filter
+                 continue
              choice = chunk.choices[0]
              delta = choice.delta
              if choice.finish_reason is None:
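
Two things change in the OpenAI streaming path above: the create call is wrapped by the shared rate limiter, and chunks without choices (such as Azure OpenAI's leading content-filter chunk and the final usage-only chunk) are skipped before choices[0] is read. A sketch of that consumption pattern (hypothetical helper, assuming an OpenAI-style async stream of chat completion chunks):

    async def collect_text(stream) -> str:
        # hypothetical consumer of an OpenAI-style async chunk stream
        parts: list[str] = []
        async for chunk in stream:
            if not chunk.choices:
                # usage-only chunks and Azure's content-filter chunk carry no choices
                continue
            delta = chunk.choices[0].delta
            if delta.content:
                parts.append(delta.content)
        return "".join(parts)
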
@@ -370,20 +469,16 @@ class OpenAIO1StreamingClient(OpenAIStreamingClient):
                  invocation_name="max_completion_tokens",
                  canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
                  label="Max Completion Tokens",
-                 default_value=UNSET,
              ),
              IntInvocationParameter(
                  invocation_name="seed",
                  canonical_name=CanonicalParameterName.RANDOM_SEED,
                  label="Seed",
-                 default_value=UNSET,
              ),
              JSONInvocationParameter(
                  invocation_name="tool_choice",
                  label="Tool Choice",
                  canonical_name=CanonicalParameterName.TOOL_CHOICE,
-                 default_value=UNSET,
-                 hidden=True,
              ),
          ]
 
@@ -409,7 +504,8 @@ class OpenAIO1StreamingClient(OpenAIStreamingClient):
 
          tool_call_ids: dict[int, str] = {}
 
-         response = await self.client.chat.completions.create(
+         throttled_create = self.rate_limiter.alimit(self.client.chat.completions.create)
+         response = await throttled_create(
              messages=openai_messages,
              model=self.model_name,
              tools=tools or NOT_GIVEN,
@@ -544,10 +640,11 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
          super().__init__(model=model, api_key=api_key)
          self.client = anthropic.AsyncAnthropic(api_key=api_key)
          self.model_name = model.name
+         self.rate_limiter = PlaygroundRateLimiter(model.provider_key, anthropic.RateLimitError)
 
      @classmethod
-     def dependencies(cls) -> list[DependencyName]:
-         return ["anthropic"]
+     def dependencies(cls) -> list[Dependency]:
+         return [Dependency(name="anthropic")]
 
      @classmethod
      def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -556,14 +653,12 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                  invocation_name="max_tokens",
                  canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
                  label="Max Tokens",
-                 default_value=UNSET,
                  required=True,
              ),
              BoundedFloatInvocationParameter(
                  invocation_name="temperature",
                  canonical_name=CanonicalParameterName.TEMPERATURE,
                  label="Temperature",
-                 default_value=UNSET,
                  min_value=0.0,
                  max_value=1.0,
              ),
@@ -571,13 +666,11 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                  invocation_name="stop_sequences",
                  canonical_name=CanonicalParameterName.STOP_SEQUENCES,
                  label="Stop Sequences",
-                 default_value=UNSET,
              ),
              BoundedFloatInvocationParameter(
                  invocation_name="top_p",
                  canonical_name=CanonicalParameterName.TOP_P,
                  label="Top P",
-                 default_value=UNSET,
                  min_value=0.0,
                  max_value=1.0,
              ),
@@ -585,8 +678,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                  invocation_name="tool_choice",
                  label="Tool Choice",
                  canonical_name=CanonicalParameterName.TOOL_CHOICE,
-                 default_value=UNSET,
-                 hidden=True,
              ),
          ]
 
@@ -608,9 +699,11 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
              "model": self.model_name,
              "system": system_prompt,
              "max_tokens": 1024,
+             "tools": tools,
              **invocation_parameters,
          }
-         async with self.client.messages.stream(**anthropic_params) as stream:
+         throttled_stream = self.rate_limiter._alimit(self.client.messages.stream)
+         async with await throttled_stream(**anthropic_params) as stream:
              async for event in stream:
                  if isinstance(event, anthropic_types.RawMessageStartEvent):
                      self._attributes.update(
@@ -622,6 +715,18 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                      self._attributes.update(
                          {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
                      )
+                 elif (
+                     isinstance(event, anthropic_streaming.ContentBlockStopEvent)
+                     and event.content_block.type == "tool_use"
+                 ):
+                     tool_call_chunk = ToolCallChunk(
+                         id=event.content_block.id,
+                         function=FunctionCallChunk(
+                             name=event.content_block.name,
+                             arguments=json.dumps(event.content_block.input),
+                         ),
+                     )
+                     yield tool_call_chunk
                  elif isinstance(
                      event,
                      (
@@ -629,6 +734,7 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                          anthropic_types.RawContentBlockDeltaEvent,
                          anthropic_types.RawMessageDeltaEvent,
                          anthropic_streaming.ContentBlockStopEvent,
+                         anthropic_streaming.InputJsonEvent,
                      ),
                  ):
                      # event types emitted by the stream that don't contain useful information
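
The new Anthropic branches above forward tool use to the playground: when a ContentBlockStopEvent closes a tool_use block, its accumulated input is serialized with json.dumps and emitted as a single tool-call chunk, while the incremental InputJsonEvent deltas are ignored. A rough illustration with plain dicts (example values; phoenix's ToolCallChunk/FunctionCallChunk types are replaced with literals):

    import json

    # shape of a finished tool_use content block as Anthropic reports it (made-up values)
    content_block = {"type": "tool_use", "id": "toolu_abc", "name": "get_weather", "input": {"city": "SF"}}

    if content_block["type"] == "tool_use":
        chunk = {
            "id": content_block["id"],
            "function": {
                "name": content_block["name"],
                "arguments": json.dumps(content_block["input"]),  # arguments travel as a JSON string
            },
        }
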
@@ -659,6 +765,139 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
          return anthropic_messages, system_prompt
 
 
+ @register_llm_client(
+     provider_key=GenerativeProviderKey.GEMINI,
+     model_names=[
+         PROVIDER_DEFAULT,
+         "gemini-1.5-flash",
+         "gemini-1.5-flash-8b",
+         "gemini-1.5-pro",
+         "gemini-1.0-pro",
+     ],
+ )
+ class GeminiStreamingClient(PlaygroundStreamingClient):
+     def __init__(
+         self,
+         model: GenerativeModelInput,
+         api_key: Optional[str] = None,
+     ) -> None:
+         import google.generativeai as google_genai
+
+         super().__init__(model=model, api_key=api_key)
+         google_genai.configure(api_key=api_key)
+         self.model_name = model.name
+
+     @classmethod
+     def dependencies(cls) -> list[Dependency]:
+         return [Dependency(name="google-generativeai", module_name="google.generativeai")]
+
+     @classmethod
+     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+         return [
+             BoundedFloatInvocationParameter(
+                 invocation_name="temperature",
+                 canonical_name=CanonicalParameterName.TEMPERATURE,
+                 label="Temperature",
+                 default_value=0.0,
+                 min_value=0.0,
+                 max_value=2.0,
+             ),
+             IntInvocationParameter(
+                 invocation_name="max_output_tokens",
+                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                 label="Max Output Tokens",
+             ),
+             StringListInvocationParameter(
+                 invocation_name="stop",
+                 canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                 label="Stop Sequences",
+             ),
+             FloatInvocationParameter(
+                 invocation_name="presence_penalty",
+                 label="Presence Penalty",
+             ),
+             FloatInvocationParameter(
+                 invocation_name="frequency_penalty",
+                 label="Frequency Penalty",
+             ),
+             BoundedFloatInvocationParameter(
+                 invocation_name="top_p",
+                 canonical_name=CanonicalParameterName.TOP_P,
+                 label="Top P",
+                 min_value=0.0,
+                 max_value=1.0,
+             ),
+             BoundedFloatInvocationParameter(
+                 invocation_name="top_k",
+                 label="Top K",
+                 min_value=0.0,
+                 max_value=1.0,
+             ),
+             IntInvocationParameter(
+                 invocation_name="seed",
+                 canonical_name=CanonicalParameterName.RANDOM_SEED,
+                 label="Seed",
+             ),
+         ]
+
+     async def chat_completion_create(
+         self,
+         messages: list[
+             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+         ],
+         tools: list[JSONScalarType],
+         **invocation_parameters: Any,
+     ) -> AsyncIterator[ChatCompletionChunk]:
+         import google.generativeai as google_genai
+
+         gemini_message_history, current_message, system_prompt = self._build_gemini_messages(
+             messages
+         )
+
+         model_args = {"model_name": self.model_name}
+         if system_prompt:
+             model_args["system_instruction"] = system_prompt
+         client = google_genai.GenerativeModel(**model_args)
+
+         gemini_config = google_genai.GenerationConfig(
+             **invocation_parameters,
+         )
+         gemini_params = {
+             "content": current_message,
+             "generation_config": gemini_config,
+             "stream": True,
+         }
+
+         chat = client.start_chat(history=gemini_message_history)
+         stream = await chat.send_message_async(**gemini_params)
+         async for event in stream:
+             yield TextChunk(content=event.text)
+
+     def _build_gemini_messages(
+         self,
+         messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+     ) -> tuple[list["ContentType"], str, str]:
+         gemini_message_history: list["ContentType"] = []
+         system_prompts = []
+         for role, content, _tool_call_id, _tool_calls in messages:
+             if role == ChatCompletionMessageRole.USER:
+                 gemini_message_history.append({"role": "user", "parts": content})
+             elif role == ChatCompletionMessageRole.AI:
+                 gemini_message_history.append({"role": "model", "parts": content})
+             elif role == ChatCompletionMessageRole.SYSTEM:
+                 system_prompts.append(content)
+             elif role == ChatCompletionMessageRole.TOOL:
+                 raise NotImplementedError
+             else:
+                 assert_never(role)
+         if gemini_message_history:
+             prompt = gemini_message_history.pop()["parts"]
+         else:
+             prompt = ""
+
+         return gemini_message_history, prompt, "\n".join(system_prompts)
+
+
  def initialize_playground_clients() -> None:
      """
      Ensure that all playground clients are registered at import time.
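
For context on the new Gemini client, _build_gemini_messages splits the playground messages three ways: system messages are joined into one instruction, the final non-system message becomes the prompt sent via send_message_async, and everything before it becomes the chat history. A standalone sketch of that transformation, using plain role strings instead of phoenix's ChatCompletionMessageRole enum:

    messages = [
        ("SYSTEM", "You are terse."),
        ("USER", "Hi"),
        ("AI", "Hello."),
        ("USER", "What is Phoenix?"),
    ]

    history, system_prompts = [], []
    for role, content in messages:
        if role == "USER":
            history.append({"role": "user", "parts": content})
        elif role == "AI":
            history.append({"role": "model", "parts": content})
        elif role == "SYSTEM":
            system_prompts.append(content)

    prompt = history.pop()["parts"] if history else ""
    system_instruction = "\n".join(system_prompts)
    # history -> [{'role': 'user', 'parts': 'Hi'}, {'role': 'model', 'parts': 'Hello.'}]
    # prompt -> 'What is Phoenix?'; system_instruction -> 'You are terse.'
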