arize-phoenix 5.7.0__py3-none-any.whl → 5.9.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/METADATA +3 -5
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/RECORD +31 -31
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/WHEEL +1 -1
- phoenix/config.py +19 -3
- phoenix/db/helpers.py +55 -1
- phoenix/server/api/helpers/playground_clients.py +283 -44
- phoenix/server/api/helpers/playground_spans.py +173 -76
- phoenix/server/api/input_types/InvocationParameters.py +7 -8
- phoenix/server/api/mutations/chat_mutations.py +244 -76
- phoenix/server/api/queries.py +5 -1
- phoenix/server/api/routers/v1/spans.py +25 -1
- phoenix/server/api/subscriptions.py +210 -158
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
- phoenix/server/api/types/ExperimentRun.py +38 -1
- phoenix/server/api/types/GenerativeProvider.py +2 -1
- phoenix/server/app.py +21 -2
- phoenix/server/grpc_server.py +3 -1
- phoenix/server/static/.vite/manifest.json +32 -32
- phoenix/server/static/assets/{components-Csu8UKOs.js → components-DU-8CYbi.js} +370 -329
- phoenix/server/static/assets/{index-Bk5C9EA7.js → index-D9E16vvV.js} +2 -2
- phoenix/server/static/assets/pages-t09OI1rC.js +3966 -0
- phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-D04tenE6.js} +181 -181
- phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-D3NxMQw0.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-XTiZSlqq.js} +5 -5
- phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-p0L0neVs.js} +1 -1
- phoenix/session/client.py +27 -7
- phoenix/utilities/json.py +31 -1
- phoenix/version.py +1 -1
- phoenix/server/static/assets/pages-UeWaKXNs.js +0 -3737
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/LICENSE +0 -0
--- a/phoenix/server/api/helpers/playground_clients.py
+++ b/phoenix/server/api/helpers/playground_clients.py
@@ -1,13 +1,12 @@
+import asyncio
 import importlib.util
+import inspect
+import json
+import time
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Callable, Iterator
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Mapping,
-    Optional,
-    Union,
-)
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Hashable, Mapping, Optional, Union
 
 from openinference.instrumentation import safe_json_dumps
 from openinference.semconv.trace import SpanAttributes
@@ -15,14 +14,19 @@ from strawberry import UNSET
 from strawberry.scalars import JSON as JSONScalarType
 from typing_extensions import TypeAlias, assert_never
 
-from phoenix.server.api.helpers.playground_registry import (
-    PROVIDER_DEFAULT,
-    register_llm_client,
+from phoenix.evals.models.rate_limiters import (
+    AsyncCallable,
+    GenericType,
+    ParameterSpec,
+    RateLimiter,
+    RateLimitError,
 )
+from phoenix.server.api.helpers.playground_registry import PROVIDER_DEFAULT, register_llm_client
 from phoenix.server.api.input_types.GenerativeModelInput import GenerativeModelInput
 from phoenix.server.api.input_types.InvocationParameters import (
     BoundedFloatInvocationParameter,
     CanonicalParameterName,
+    FloatInvocationParameter,
     IntInvocationParameter,
     InvocationParameter,
     InvocationParameterInput,
@@ -41,17 +45,114 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 
 if TYPE_CHECKING:
     from anthropic.types import MessageParam
+    from google.generativeai.types import ContentType
     from openai.types import CompletionUsage
-    from openai.types.chat import (
-        ChatCompletionMessageParam,
-        ChatCompletionMessageToolCallParam,
-    )
+    from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCallParam
 
-DependencyName: TypeAlias = str
 SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
 ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
 
 
+class Dependency:
+    """
+    Set the module_name to the import name if it is different from the install name
+    """
+
+    def __init__(self, name: str, module_name: Optional[str] = None):
+        self.name = name
+        self.module_name = module_name
+
+    @property
+    def import_name(self) -> str:
+        return self.module_name or self.name
+
+
+class KeyedSingleton:
+    _instances: dict[Hashable, "KeyedSingleton"] = {}
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> "KeyedSingleton":
+        if "singleton_key" in kwargs:
+            singleton_key = kwargs.pop("singleton_key")
+        elif args:
+            singleton_key = args[0]
+            args = args[1:]
+        else:
+            raise ValueError("singleton_key must be provided")
+
+        instance_key = (cls, singleton_key)
+        if instance_key not in cls._instances:
+            instance = super().__new__(cls)
+            cls._instances[instance_key] = instance
+        return cls._instances[instance_key]
+
+
+class PlaygroundRateLimiter(RateLimiter, KeyedSingleton):
+    """
+    A rate limiter class that will be instantiated once per `singleton_key`.
+    """
+
+    def __init__(self, singleton_key: Hashable, rate_limit_error: Optional[type[BaseException]]):
+        super().__init__(
+            rate_limit_error=rate_limit_error,
+            max_rate_limit_retries=3,
+            initial_per_second_request_rate=2.0,
+            maximum_per_second_request_rate=10.0,
+            enforcement_window_minutes=1,
+            rate_reduction_factor=0.5,
+            rate_increase_factor=0.01,
+            cooldown_seconds=5,
+            verbose=False,
+        )
+
+    # TODO: update the rate limiter class in phoenix.evals to support decorated sync functions
+    def _alimit(
+        self, fn: Callable[ParameterSpec, GenericType]
+    ) -> AsyncCallable[ParameterSpec, GenericType]:
+        @wraps(fn)
+        async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
+            self._initialize_async_primitives()
+            assert self._rate_limit_handling_lock is not None and isinstance(
+                self._rate_limit_handling_lock, asyncio.Lock
+            )
+            assert self._rate_limit_handling is not None and isinstance(
+                self._rate_limit_handling, asyncio.Event
+            )
+            try:
+                try:
+                    await asyncio.wait_for(self._rate_limit_handling.wait(), 120)
+                except asyncio.TimeoutError:
+                    self._rate_limit_handling.set()  # Set the event as a failsafe
+                await self._throttler.async_wait_until_ready()
+                request_start_time = time.time()
+                if inspect.iscoroutinefunction(fn):
+                    return await fn(*args, **kwargs)  # type: ignore
+                else:
+                    return fn(*args, **kwargs)
+            except self._rate_limit_error:
+                async with self._rate_limit_handling_lock:
+                    self._rate_limit_handling.clear()  # prevent new requests from starting
+                    self._throttler.on_rate_limit_error(request_start_time, verbose=self._verbose)
+                    try:
+                        for _attempt in range(self._max_rate_limit_retries):
+                            try:
+                                request_start_time = time.time()
+                                await self._throttler.async_wait_until_ready()
+                                if inspect.iscoroutinefunction(fn):
+                                    return await fn(*args, **kwargs)  # type: ignore
+                                else:
+                                    return fn(*args, **kwargs)
+                            except self._rate_limit_error:
+                                self._throttler.on_rate_limit_error(
+                                    request_start_time, verbose=self._verbose
+                                )
+                                continue
+                    finally:
+                        self._rate_limit_handling.set()  # allow new requests to start
+                raise RateLimitError(f"Exceeded max ({self._max_rate_limit_retries}) retries")
+
+        return wrapper
+
+
 class PlaygroundStreamingClient(ABC):
     def __init__(
         self,
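Note: the `KeyedSingleton` base above caches one instance per `(subclass, key)` pair, so every playground client for a given provider shares a single rate limiter. A minimal runnable sketch of that behavior (a condensed version of the class above; the `Limiter` subclass and key strings are illustrative, not part of the patch):

from typing import Any, Hashable

class KeyedSingleton:
    _instances: dict[Hashable, "KeyedSingleton"] = {}

    def __new__(cls, *args: Any, **kwargs: Any) -> "KeyedSingleton":
        if "singleton_key" in kwargs:
            singleton_key = kwargs.pop("singleton_key")
        elif args:
            singleton_key = args[0]
        else:
            raise ValueError("singleton_key must be provided")
        instance_key = (cls, singleton_key)  # keyed per subclass and per key
        if instance_key not in cls._instances:
            cls._instances[instance_key] = super().__new__(cls)
        return cls._instances[instance_key]

class Limiter(KeyedSingleton):  # illustrative subclass
    def __init__(self, singleton_key: Hashable) -> None:
        self.key = singleton_key

assert Limiter("OPENAI") is Limiter("OPENAI")         # same key: same instance
assert Limiter("OPENAI") is not Limiter("ANTHROPIC")  # different key: new instance

One caveat of this pattern: `__init__` re-runs on every lookup of a cached instance, which is harmless here because the limiter's configuration is constant.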
@@ -62,8 +163,8 @@ class PlaygroundStreamingClient(ABC):
 
     @classmethod
     @abstractmethod
-    def dependencies(cls) -> list[DependencyName]:
-        # A list of
+    def dependencies(cls) -> list[Dependency]:
+        # A list of dependencies this client needs to run
         ...
 
     @classmethod
@@ -108,7 +209,8 @@ class PlaygroundStreamingClient(ABC):
     def dependencies_are_installed(cls) -> bool:
         try:
             for dependency in cls.dependencies():
-                if importlib.util.find_spec(dependency) is None:
+                import_name = dependency.import_name
+                if importlib.util.find_spec(import_name) is None:
                     return False
             return True
         except ValueError:
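Note: `dependencies_are_installed` now probes `importlib.util.find_spec` with `Dependency.import_name`, which matters when a package's PyPI name differs from its import path (the new `google-generativeai` dependency imports as `google.generativeai`). A small sketch of the probe; the `installed` helper is local to this example:

import importlib.util

def installed(import_name: str) -> bool:
    try:
        return importlib.util.find_spec(import_name) is not None
    except ModuleNotFoundError:
        # find_spec("a.b") imports the parent package "a"; a missing parent raises
        return False

print(installed("openai"))               # install name and import name coincide
print(installed("google.generativeai"))  # installed as "google-generativeai" on PyPI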
@@ -150,14 +252,16 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
         api_key: Optional[str] = None,
     ) -> None:
         from openai import AsyncOpenAI
+        from openai import RateLimitError as OpenAIRateLimitError
 
         super().__init__(model=model, api_key=api_key)
         self.client = AsyncOpenAI(api_key=api_key)
         self.model_name = model.name
+        self.rate_limiter = PlaygroundRateLimiter(model.provider_key, OpenAIRateLimitError)
 
     @classmethod
-    def dependencies(cls) -> list[DependencyName]:
-        return ["openai"]
+    def dependencies(cls) -> list[Dependency]:
+        return [Dependency(name="openai")]
 
     @classmethod
     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -174,19 +278,16 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                 invocation_name="max_tokens",
                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
                 label="Max Tokens",
-                default_value=UNSET,
             ),
             BoundedFloatInvocationParameter(
                 invocation_name="frequency_penalty",
                 label="Frequency Penalty",
-                default_value=UNSET,
                 min_value=-2.0,
                 max_value=2.0,
             ),
             BoundedFloatInvocationParameter(
                 invocation_name="presence_penalty",
                 label="Presence Penalty",
-                default_value=UNSET,
                 min_value=-2.0,
                 max_value=2.0,
             ),
@@ -194,13 +295,11 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                 invocation_name="stop",
                 canonical_name=CanonicalParameterName.STOP_SEQUENCES,
                 label="Stop Sequences",
-                default_value=UNSET,
             ),
             BoundedFloatInvocationParameter(
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=UNSET,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -208,20 +307,16 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
                 invocation_name="seed",
                 canonical_name=CanonicalParameterName.RANDOM_SEED,
                 label="Seed",
-                default_value=UNSET,
             ),
             JSONInvocationParameter(
                 invocation_name="tool_choice",
                 label="Tool Choice",
                 canonical_name=CanonicalParameterName.TOOL_CHOICE,
-                default_value=UNSET,
-                hidden=True,
             ),
             JSONInvocationParameter(
                 invocation_name="response_format",
                 label="Response Format",
                 canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
-                default_value=UNSET,
             ),
         ]
 
@@ -240,7 +335,8 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
         openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
         tool_call_ids: dict[int, str] = {}
         token_usage: Optional["CompletionUsage"] = None
-        async for chunk in await self.client.chat.completions.create(
+        throttled_create = self.rate_limiter.alimit(self.client.chat.completions.create)
+        async for chunk in await throttled_create(
             messages=openai_messages,
             model=self.model_name,
             stream=True,
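Note: the rate-limited call is awaited first and the awaited value is then iterated, because a streaming `create` returns a coroutine that resolves to an async iterator of chunks; the limiter therefore throttles only the request, not the consumption of the stream. A toy sketch of that call shape (every name here is a stand-in, not the patch's API):

import asyncio

async def fake_create(**_: object):
    async def chunks():
        for piece in ("hel", "lo"):
            yield piece
    return chunks()  # what a streaming create resolves to: an async iterator

def alimit(fn):
    async def wrapper(*args, **kwargs):
        await asyncio.sleep(0)  # stand-in for the throttler's wait
        return await fn(*args, **kwargs)
    return wrapper

async def main() -> None:
    throttled_create = alimit(fake_create)
    async for chunk in await throttled_create(stream=True):
        print(chunk)

asyncio.run(main())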
@@ -251,6 +347,9 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
             if (usage := chunk.usage) is not None:
                 token_usage = usage
                 continue
+            if not chunk.choices:
+                # for Azure, initial chunk contains the content filter
+                continue
             choice = chunk.choices[0]
             delta = choice.delta
             if choice.finish_reason is None:
@@ -370,20 +469,16 @@ class OpenAIO1StreamingClient(OpenAIStreamingClient):
                 invocation_name="max_completion_tokens",
                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
                 label="Max Completion Tokens",
-                default_value=UNSET,
             ),
             IntInvocationParameter(
                 invocation_name="seed",
                 canonical_name=CanonicalParameterName.RANDOM_SEED,
                 label="Seed",
-                default_value=UNSET,
             ),
             JSONInvocationParameter(
                 invocation_name="tool_choice",
                 label="Tool Choice",
                 canonical_name=CanonicalParameterName.TOOL_CHOICE,
-                default_value=UNSET,
-                hidden=True,
             ),
         ]
 
@@ -409,7 +504,8 @@ class OpenAIO1StreamingClient(OpenAIStreamingClient):
 
         tool_call_ids: dict[int, str] = {}
 
-        response = await self.client.chat.completions.create(
+        throttled_create = self.rate_limiter.alimit(self.client.chat.completions.create)
+        response = await throttled_create(
             messages=openai_messages,
             model=self.model_name,
             tools=tools or NOT_GIVEN,
@@ -544,10 +640,11 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
         super().__init__(model=model, api_key=api_key)
         self.client = anthropic.AsyncAnthropic(api_key=api_key)
         self.model_name = model.name
+        self.rate_limiter = PlaygroundRateLimiter(model.provider_key, anthropic.RateLimitError)
 
     @classmethod
-    def dependencies(cls) -> list[DependencyName]:
-        return ["anthropic"]
+    def dependencies(cls) -> list[Dependency]:
+        return [Dependency(name="anthropic")]
 
     @classmethod
     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -556,14 +653,12 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                 invocation_name="max_tokens",
                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
                 label="Max Tokens",
-                default_value=UNSET,
                 required=True,
             ),
             BoundedFloatInvocationParameter(
                 invocation_name="temperature",
                 canonical_name=CanonicalParameterName.TEMPERATURE,
                 label="Temperature",
-                default_value=UNSET,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -571,13 +666,11 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                 invocation_name="stop_sequences",
                 canonical_name=CanonicalParameterName.STOP_SEQUENCES,
                 label="Stop Sequences",
-                default_value=UNSET,
             ),
             BoundedFloatInvocationParameter(
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=UNSET,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -585,8 +678,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                 invocation_name="tool_choice",
                 label="Tool Choice",
                 canonical_name=CanonicalParameterName.TOOL_CHOICE,
-                default_value=UNSET,
-                hidden=True,
             ),
         ]
 
@@ -608,9 +699,11 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
             "model": self.model_name,
             "system": system_prompt,
             "max_tokens": 1024,
+            "tools": tools,
             **invocation_parameters,
         }
-        async with self.client.messages.stream(**anthropic_params) as stream:
+        throttled_stream = self.rate_limiter._alimit(self.client.messages.stream)
+        async with await throttled_stream(**anthropic_params) as stream:
             async for event in stream:
                 if isinstance(event, anthropic_types.RawMessageStartEvent):
                     self._attributes.update(
@@ -622,6 +715,18 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                     self._attributes.update(
                         {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
                     )
+                elif (
+                    isinstance(event, anthropic_streaming.ContentBlockStopEvent)
+                    and event.content_block.type == "tool_use"
+                ):
+                    tool_call_chunk = ToolCallChunk(
+                        id=event.content_block.id,
+                        function=FunctionCallChunk(
+                            name=event.content_block.name,
+                            arguments=json.dumps(event.content_block.input),
+                        ),
+                    )
+                    yield tool_call_chunk
                 elif isinstance(
                     event,
                     (
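Note: Anthropic delivers tool arguments as incremental JSON deltas, so the assembled `input` dict is complete only when the content block stops; the new branch re-serializes it with `json.dumps` so the chunk carries arguments as a JSON string like the other clients. A reduced sketch with a hypothetical event shape:

import json
from dataclasses import dataclass, field

@dataclass
class ContentBlock:  # hypothetical stand-in for an Anthropic tool_use block
    type: str
    id: str
    name: str
    input: dict = field(default_factory=dict)

block = ContentBlock(type="tool_use", id="toolu_123", name="get_weather",
                     input={"city": "Tokyo"})
if block.type == "tool_use":
    arguments = json.dumps(block.input)  # dict -> JSON string for the chunk
    print(block.id, block.name, arguments)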
@@ -629,6 +734,7 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                         anthropic_types.RawContentBlockDeltaEvent,
                         anthropic_types.RawMessageDeltaEvent,
                         anthropic_streaming.ContentBlockStopEvent,
+                        anthropic_streaming.InputJsonEvent,
                     ),
                 ):
                     # event types emitted by the stream that don't contain useful information
@@ -659,6 +765,139 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
         return anthropic_messages, system_prompt
 
 
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GEMINI,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gemini-1.5-flash",
+        "gemini-1.5-flash-8b",
+        "gemini-1.5-pro",
+        "gemini-1.0-pro",
+    ],
+)
+class GeminiStreamingClient(PlaygroundStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+    ) -> None:
+        import google.generativeai as google_genai
+
+        super().__init__(model=model, api_key=api_key)
+        google_genai.configure(api_key=api_key)
+        self.model_name = model.name
+
+    @classmethod
+    def dependencies(cls) -> list[Dependency]:
+        return [Dependency(name="google-generativeai", module_name="google.generativeai")]
+
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=0.0,
+                min_value=0.0,
+                max_value=2.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_output_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Output Tokens",
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            FloatInvocationParameter(
+                invocation_name="presence_penalty",
+                label="Presence Penalty",
+            ),
+            FloatInvocationParameter(
+                invocation_name="frequency_penalty",
+                label="Frequency Penalty",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_k",
+                label="Top K",
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="seed",
+                canonical_name=CanonicalParameterName.RANDOM_SEED,
+                label="Seed",
+            ),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        import google.generativeai as google_genai
+
+        gemini_message_history, current_message, system_prompt = self._build_gemini_messages(
+            messages
+        )
+
+        model_args = {"model_name": self.model_name}
+        if system_prompt:
+            model_args["system_instruction"] = system_prompt
+        client = google_genai.GenerativeModel(**model_args)
+
+        gemini_config = google_genai.GenerationConfig(
+            **invocation_parameters,
+        )
+        gemini_params = {
+            "content": current_message,
+            "generation_config": gemini_config,
+            "stream": True,
+        }
+
+        chat = client.start_chat(history=gemini_message_history)
+        stream = await chat.send_message_async(**gemini_params)
+        async for event in stream:
+            yield TextChunk(content=event.text)
+
+    def _build_gemini_messages(
+        self,
+        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    ) -> tuple[list["ContentType"], str, str]:
+        gemini_message_history: list["ContentType"] = []
+        system_prompts = []
+        for role, content, _tool_call_id, _tool_calls in messages:
+            if role == ChatCompletionMessageRole.USER:
+                gemini_message_history.append({"role": "user", "parts": content})
+            elif role == ChatCompletionMessageRole.AI:
+                gemini_message_history.append({"role": "model", "parts": content})
+            elif role == ChatCompletionMessageRole.SYSTEM:
+                system_prompts.append(content)
+            elif role == ChatCompletionMessageRole.TOOL:
+                raise NotImplementedError
+            else:
+                assert_never(role)
+        if gemini_message_history:
+            prompt = gemini_message_history.pop()["parts"]
+        else:
+            prompt = ""
+
+        return gemini_message_history, prompt, "\n".join(system_prompts)
+
+
 def initialize_playground_clients() -> None:
     """
     Ensure that all playground clients are registered at import time.