arize-phoenix 5.6.0__py3-none-any.whl → 5.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (39)
  1. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/METADATA +4 -6
  2. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/RECORD +39 -30
  3. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/WHEEL +1 -1
  4. phoenix/config.py +58 -0
  5. phoenix/server/api/helpers/playground_clients.py +758 -0
  6. phoenix/server/api/helpers/playground_registry.py +70 -0
  7. phoenix/server/api/helpers/playground_spans.py +422 -0
  8. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  9. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  10. phoenix/server/api/input_types/InvocationParameters.py +155 -13
  11. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  12. phoenix/server/api/mutations/__init__.py +4 -0
  13. phoenix/server/api/mutations/chat_mutations.py +355 -0
  14. phoenix/server/api/queries.py +41 -52
  15. phoenix/server/api/schema.py +42 -10
  16. phoenix/server/api/subscriptions.py +378 -595
  17. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  18. phoenix/server/api/types/GenerativeProvider.py +27 -3
  19. phoenix/server/api/types/Span.py +37 -0
  20. phoenix/server/api/types/TemplateLanguage.py +9 -0
  21. phoenix/server/app.py +75 -13
  22. phoenix/server/grpc_server.py +3 -1
  23. phoenix/server/main.py +14 -1
  24. phoenix/server/static/.vite/manifest.json +31 -31
  25. phoenix/server/static/assets/{components-C70HJiXz.js → components-MllbfxfJ.js} +168 -150
  26. phoenix/server/static/assets/{index-DLe1Oo3l.js → index-BVO2YcT1.js} +2 -2
  27. phoenix/server/static/assets/{pages-C8-Sl7JI.js → pages-BHfC6jnL.js} +464 -310
  28. phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-BEuNhfwH.js} +1 -1
  29. phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-Bskhzyjm.js} +1 -1
  30. phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-DLlXCf0x.js} +1 -1
  31. phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-CRqhvLYg.js} +1 -1
  32. phoenix/server/templates/index.html +1 -0
  33. phoenix/services.py +4 -0
  34. phoenix/session/session.py +15 -1
  35. phoenix/utilities/template_formatters.py +11 -1
  36. phoenix/version.py +1 -1
  37. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/entry_points.txt +0 -0
  38. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/IP_NOTICE +0 -0
  39. {arize_phoenix-5.6.0.dist-info → arize_phoenix-5.8.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/helpers/playground_clients.py (new file)
@@ -0,0 +1,758 @@
+ import asyncio
+ import importlib.util
+ import inspect
+ import json
+ import time
+ from abc import ABC, abstractmethod
+ from collections.abc import AsyncIterator, Callable, Iterator
+ from functools import wraps
+ from typing import TYPE_CHECKING, Any, Hashable, Mapping, Optional, Union
+
+ from openinference.instrumentation import safe_json_dumps
+ from openinference.semconv.trace import SpanAttributes
+ from strawberry import UNSET
+ from strawberry.scalars import JSON as JSONScalarType
+ from typing_extensions import TypeAlias, assert_never
+
+ from phoenix.evals.models.rate_limiters import (
+     AsyncCallable,
+     GenericType,
+     ParameterSpec,
+     RateLimiter,
+     RateLimitError,
+ )
+ from phoenix.server.api.helpers.playground_registry import PROVIDER_DEFAULT, register_llm_client
+ from phoenix.server.api.input_types.GenerativeModelInput import GenerativeModelInput
+ from phoenix.server.api.input_types.InvocationParameters import (
+     BoundedFloatInvocationParameter,
+     CanonicalParameterName,
+     IntInvocationParameter,
+     InvocationParameter,
+     InvocationParameterInput,
+     JSONInvocationParameter,
+     StringListInvocationParameter,
+     extract_parameter,
+     validate_invocation_parameters,
+ )
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+     FunctionCallChunk,
+     TextChunk,
+     ToolCallChunk,
+ )
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
+
+ if TYPE_CHECKING:
+     from anthropic.types import MessageParam
+     from openai.types import CompletionUsage
+     from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCallParam
+
+ DependencyName: TypeAlias = str
+ SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
+ ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
+
+
+ class KeyedSingleton:
+     _instances: dict[Hashable, "KeyedSingleton"] = {}
+
+     def __new__(cls, *args: Any, **kwargs: Any) -> "KeyedSingleton":
+         if "singleton_key" in kwargs:
+             singleton_key = kwargs.pop("singleton_key")
+         elif args:
+             singleton_key = args[0]
+             args = args[1:]
+         else:
+             raise ValueError("singleton_key must be provided")
+
+         instance_key = (cls, singleton_key)
+         if instance_key not in cls._instances:
+             instance = super().__new__(cls)
+             cls._instances[instance_key] = instance
+         return cls._instances[instance_key]
+
+
+ class PlaygroundRateLimiter(RateLimiter, KeyedSingleton):
+     """
+     A rate rate limiter class that will be instantiated once per `singleton_key`.
+     """
+
+     def __init__(self, singleton_key: Hashable, rate_limit_error: Optional[type[BaseException]]):
+         super().__init__(
+             rate_limit_error=rate_limit_error,
+             max_rate_limit_retries=3,
+             initial_per_second_request_rate=2.0,
+             maximum_per_second_request_rate=10.0,
+             enforcement_window_minutes=1,
+             rate_reduction_factor=0.5,
+             rate_increase_factor=0.01,
+             cooldown_seconds=5,
+             verbose=False,
+         )
+
+     # TODO: update the rate limiter class in phoenix.evals to support decorated sync functions
+     def _alimit(
+         self, fn: Callable[ParameterSpec, GenericType]
+     ) -> AsyncCallable[ParameterSpec, GenericType]:
+         @wraps(fn)
+         async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
+             self._initialize_async_primitives()
+             assert self._rate_limit_handling_lock is not None and isinstance(
+                 self._rate_limit_handling_lock, asyncio.Lock
+             )
+             assert self._rate_limit_handling is not None and isinstance(
+                 self._rate_limit_handling, asyncio.Event
+             )
+             try:
+                 try:
+                     await asyncio.wait_for(self._rate_limit_handling.wait(), 120)
+                 except asyncio.TimeoutError:
+                     self._rate_limit_handling.set()  # Set the event as a failsafe
+                 await self._throttler.async_wait_until_ready()
+                 request_start_time = time.time()
+                 if inspect.iscoroutinefunction(fn):
+                     return await fn(*args, **kwargs)  # type: ignore
+                 else:
+                     return fn(*args, **kwargs)
+             except self._rate_limit_error:
+                 async with self._rate_limit_handling_lock:
+                     self._rate_limit_handling.clear()  # prevent new requests from starting
+                     self._throttler.on_rate_limit_error(request_start_time, verbose=self._verbose)
+                     try:
+                         for _attempt in range(self._max_rate_limit_retries):
+                             try:
+                                 request_start_time = time.time()
+                                 await self._throttler.async_wait_until_ready()
+                                 if inspect.iscoroutinefunction(fn):
+                                     return await fn(*args, **kwargs)  # type: ignore
+                                 else:
+                                     return fn(*args, **kwargs)
+                             except self._rate_limit_error:
+                                 self._throttler.on_rate_limit_error(
+                                     request_start_time, verbose=self._verbose
+                                 )
+                                 continue
+                     finally:
+                         self._rate_limit_handling.set()  # allow new requests to start
+             raise RateLimitError(f"Exceeded max ({self._max_rate_limit_retries}) retries")
+
+         return wrapper
+
+
+ class PlaygroundStreamingClient(ABC):
+     def __init__(
+         self,
+         model: GenerativeModelInput,
+         api_key: Optional[str] = None,
+     ) -> None:
+         self._attributes: dict[str, Any] = dict()
+
+     @classmethod
+     @abstractmethod
+     def dependencies(cls) -> list[DependencyName]:
+         # A list of dependency names this client needs to run
+         ...
+
+     @classmethod
+     @abstractmethod
+     def supported_invocation_parameters(cls) -> list[InvocationParameter]: ...
+
+     @abstractmethod
+     async def chat_completion_create(
+         self,
+         messages: list[
+             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+         ],
+         tools: list[JSONScalarType],
+         **invocation_parameters: Any,
+     ) -> AsyncIterator[ChatCompletionChunk]:
+         # a yield statement is needed to satisfy the type-checker
+         # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
+         yield TextChunk(content="")
+
+     @classmethod
+     def construct_invocation_parameters(
+         cls, invocation_parameters: list[InvocationParameterInput]
+     ) -> dict[str, Any]:
+         supported_params = cls.supported_invocation_parameters()
+         params = {param.invocation_name: param for param in supported_params}
+
+         formatted_invocation_parameters = dict()
+
+         for param_input in invocation_parameters:
+             invocation_name = param_input.invocation_name
+             if invocation_name not in params:
+                 raise ValueError(f"Unsupported invocation parameter: {invocation_name}")
+
+             param_def = params[invocation_name]
+             value = extract_parameter(param_def, param_input)
+             if value is not UNSET:
+                 formatted_invocation_parameters[invocation_name] = value
+         validate_invocation_parameters(supported_params, formatted_invocation_parameters)
+         return formatted_invocation_parameters
+
+     @classmethod
+     def dependencies_are_installed(cls) -> bool:
+         try:
+             for dependency in cls.dependencies():
+                 if importlib.util.find_spec(dependency) is None:
+                     return False
+             return True
+         except ValueError:
+             # happens in some cases if the spec is None
+             return False
+
+     @property
+     def attributes(self) -> dict[str, Any]:
+         return self._attributes
+
+
+ @register_llm_client(
+     provider_key=GenerativeProviderKey.OPENAI,
+     model_names=[
+         PROVIDER_DEFAULT,
+         "gpt-4o",
+         "gpt-4o-2024-08-06",
+         "gpt-4o-2024-05-13",
+         "chatgpt-4o-latest",
+         "gpt-4o-mini",
+         "gpt-4o-mini-2024-07-18",
+         "gpt-4-turbo",
+         "gpt-4-turbo-2024-04-09",
+         "gpt-4-turbo-preview",
+         "gpt-4-0125-preview",
+         "gpt-4-1106-preview",
+         "gpt-4",
+         "gpt-4-0613",
+         "gpt-3.5-turbo-0125",
+         "gpt-3.5-turbo",
+         "gpt-3.5-turbo-1106",
+         "gpt-3.5-turbo-instruct",
+     ],
+ )
+ class OpenAIStreamingClient(PlaygroundStreamingClient):
+     def __init__(
+         self,
+         model: GenerativeModelInput,
+         api_key: Optional[str] = None,
+     ) -> None:
+         from openai import AsyncOpenAI
+         from openai import RateLimitError as OpenAIRateLimitError
+
+         super().__init__(model=model, api_key=api_key)
+         self.client = AsyncOpenAI(api_key=api_key)
+         self.model_name = model.name
+         self.rate_limiter = PlaygroundRateLimiter(model.provider_key, OpenAIRateLimitError)
+
+     @classmethod
+     def dependencies(cls) -> list[DependencyName]:
+         return ["openai"]
+
+     @classmethod
+     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+         return [
+             BoundedFloatInvocationParameter(
+                 invocation_name="temperature",
+                 canonical_name=CanonicalParameterName.TEMPERATURE,
+                 label="Temperature",
+                 default_value=0.0,
+                 min_value=0.0,
+                 max_value=2.0,
+             ),
+             IntInvocationParameter(
+                 invocation_name="max_tokens",
+                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                 label="Max Tokens",
+             ),
+             BoundedFloatInvocationParameter(
+                 invocation_name="frequency_penalty",
+                 label="Frequency Penalty",
+                 min_value=-2.0,
+                 max_value=2.0,
+             ),
+             BoundedFloatInvocationParameter(
+                 invocation_name="presence_penalty",
+                 label="Presence Penalty",
+                 min_value=-2.0,
+                 max_value=2.0,
+             ),
+             StringListInvocationParameter(
+                 invocation_name="stop",
+                 canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                 label="Stop Sequences",
+             ),
+             BoundedFloatInvocationParameter(
+                 invocation_name="top_p",
+                 canonical_name=CanonicalParameterName.TOP_P,
+                 label="Top P",
+                 min_value=0.0,
+                 max_value=1.0,
+             ),
+             IntInvocationParameter(
+                 invocation_name="seed",
+                 canonical_name=CanonicalParameterName.RANDOM_SEED,
+                 label="Seed",
+             ),
+             JSONInvocationParameter(
+                 invocation_name="tool_choice",
+                 label="Tool Choice",
+                 canonical_name=CanonicalParameterName.TOOL_CHOICE,
+             ),
+             JSONInvocationParameter(
+                 invocation_name="response_format",
+                 label="Response Format",
+                 canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
+             ),
+         ]
+
+     async def chat_completion_create(
+         self,
+         messages: list[
+             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+         ],
+         tools: list[JSONScalarType],
+         **invocation_parameters: Any,
+     ) -> AsyncIterator[ChatCompletionChunk]:
+         from openai import NOT_GIVEN
+         from openai.types.chat import ChatCompletionStreamOptionsParam
+
+         # Convert standard messages to OpenAI messages
+         openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
+         tool_call_ids: dict[int, str] = {}
+         token_usage: Optional["CompletionUsage"] = None
+         throttled_create = self.rate_limiter.alimit(self.client.chat.completions.create)
+         async for chunk in await throttled_create(
+             messages=openai_messages,
+             model=self.model_name,
+             stream=True,
+             stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
+             tools=tools or NOT_GIVEN,
+             **invocation_parameters,
+         ):
+             if (usage := chunk.usage) is not None:
+                 token_usage = usage
+                 continue
+             choice = chunk.choices[0]
+             delta = choice.delta
+             if choice.finish_reason is None:
+                 if isinstance(chunk_content := delta.content, str):
+                     text_chunk = TextChunk(content=chunk_content)
+                     yield text_chunk
+                 if (tool_calls := delta.tool_calls) is not None:
+                     for tool_call_index, tool_call in enumerate(tool_calls):
+                         tool_call_id = (
+                             tool_call.id
+                             if tool_call.id is not None
+                             else tool_call_ids[tool_call_index]
+                         )
+                         tool_call_ids[tool_call_index] = tool_call_id
+                         if (function := tool_call.function) is not None:
+                             tool_call_chunk = ToolCallChunk(
+                                 id=tool_call_id,
+                                 function=FunctionCallChunk(
+                                     name=function.name or "",
+                                     arguments=function.arguments or "",
+                                 ),
+                             )
+                             yield tool_call_chunk
+         if token_usage is not None:
+             self._attributes.update(dict(self._llm_token_counts(token_usage)))
+
+     def to_openai_chat_completion_param(
+         self,
+         role: ChatCompletionMessageRole,
+         content: JSONScalarType,
+         tool_call_id: Optional[str] = None,
+         tool_calls: Optional[list[JSONScalarType]] = None,
+     ) -> "ChatCompletionMessageParam":
+         from openai.types.chat import (
+             ChatCompletionAssistantMessageParam,
+             ChatCompletionSystemMessageParam,
+             ChatCompletionToolMessageParam,
+             ChatCompletionUserMessageParam,
+         )
+
+         if role is ChatCompletionMessageRole.USER:
+             return ChatCompletionUserMessageParam(
+                 {
+                     "content": content,
+                     "role": "user",
+                 }
+             )
+         if role is ChatCompletionMessageRole.SYSTEM:
+             return ChatCompletionSystemMessageParam(
+                 {
+                     "content": content,
+                     "role": "system",
+                 }
+             )
+         if role is ChatCompletionMessageRole.AI:
+             if tool_calls is None:
+                 return ChatCompletionAssistantMessageParam(
+                     {
+                         "content": content,
+                         "role": "assistant",
+                     }
+                 )
+             else:
+                 return ChatCompletionAssistantMessageParam(
+                     {
+                         "content": content,
+                         "role": "assistant",
+                         "tool_calls": [
+                             self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                         ],
+                     }
+                 )
+         if role is ChatCompletionMessageRole.TOOL:
+             if tool_call_id is None:
+                 raise ValueError("tool_call_id is required for tool messages")
+             return ChatCompletionToolMessageParam(
+                 {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+             )
+         assert_never(role)
+
+     def to_openai_tool_call_param(
+         self,
+         tool_call: JSONScalarType,
+     ) -> "ChatCompletionMessageToolCallParam":
+         from openai.types.chat import ChatCompletionMessageToolCallParam
+
+         return ChatCompletionMessageToolCallParam(
+             id=tool_call.get("id", ""),
+             function={
+                 "name": tool_call.get("function", {}).get("name", ""),
+                 "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
+             },
+             type="function",
+         )
+
+     @staticmethod
+     def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
+         yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+         yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+         yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
+
+
+ @register_llm_client(
+     provider_key=GenerativeProviderKey.OPENAI,
+     model_names=[
+         "o1-preview",
+         "o1-preview-2024-09-12",
+         "o1-mini",
+         "o1-mini-2024-09-12",
+     ],
+ )
+ class OpenAIO1StreamingClient(OpenAIStreamingClient):
+     @classmethod
+     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+         return [
+             IntInvocationParameter(
+                 invocation_name="max_completion_tokens",
+                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                 label="Max Completion Tokens",
+             ),
+             IntInvocationParameter(
+                 invocation_name="seed",
+                 canonical_name=CanonicalParameterName.RANDOM_SEED,
+                 label="Seed",
+             ),
+             JSONInvocationParameter(
+                 invocation_name="tool_choice",
+                 label="Tool Choice",
+                 canonical_name=CanonicalParameterName.TOOL_CHOICE,
+             ),
+         ]
+
+     async def chat_completion_create(
+         self,
+         messages: list[
+             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+         ],
+         tools: list[JSONScalarType],
+         **invocation_parameters: Any,
+     ) -> AsyncIterator[ChatCompletionChunk]:
+         from openai import NOT_GIVEN
+
+         # Convert standard messages to OpenAI messages
+         unfiltered_openai_messages = [
+             self.to_openai_o1_chat_completion_param(*message) for message in messages
+         ]
+
+         # filter out unsupported messages
+         openai_messages: list[ChatCompletionMessageParam] = [
+             message for message in unfiltered_openai_messages if message is not None
+         ]
+
+         tool_call_ids: dict[int, str] = {}
+
+         throttled_create = self.rate_limiter.alimit(self.client.chat.completions.create)
+         response = await throttled_create(
+             messages=openai_messages,
+             model=self.model_name,
+             tools=tools or NOT_GIVEN,
+             **invocation_parameters,
+         )
+
+         choice = response.choices[0]
+         message = choice.message
+         content = message.content
+
+         text_chunk = TextChunk(content=content)
+         yield text_chunk
+
+         if (tool_calls := message.tool_calls) is not None:
+             for tool_call_index, tool_call in enumerate(tool_calls):
+                 tool_call_id = (
+                     tool_call.id
+                     if tool_call.id is not None
+                     else tool_call_ids.get(tool_call_index, f"tool_call_{tool_call_index}")
+                 )
+                 tool_call_ids[tool_call_index] = tool_call_id
+                 if (function := tool_call.function) is not None:
+                     tool_call_chunk = ToolCallChunk(
+                         id=tool_call_id,
+                         function=FunctionCallChunk(
+                             name=function.name or "",
+                             arguments=function.arguments or "",
+                         ),
+                     )
+                     yield tool_call_chunk
+
+         if (usage := response.usage) is not None:
+             self._attributes.update(dict(self._llm_token_counts(usage)))
+
+     def to_openai_o1_chat_completion_param(
+         self,
+         role: ChatCompletionMessageRole,
+         content: JSONScalarType,
+         tool_call_id: Optional[str] = None,
+         tool_calls: Optional[list[JSONScalarType]] = None,
+     ) -> Optional["ChatCompletionMessageParam"]:
+         from openai.types.chat import (
+             ChatCompletionAssistantMessageParam,
+             ChatCompletionToolMessageParam,
+             ChatCompletionUserMessageParam,
+         )
+
+         if role is ChatCompletionMessageRole.USER:
+             return ChatCompletionUserMessageParam(
+                 {
+                     "content": content,
+                     "role": "user",
+                 }
+             )
+         if role is ChatCompletionMessageRole.SYSTEM:
+             return None  # System messages are not supported for o1 models
+         if role is ChatCompletionMessageRole.AI:
+             if tool_calls is None:
+                 return ChatCompletionAssistantMessageParam(
+                     {
+                         "content": content,
+                         "role": "assistant",
+                     }
+                 )
+             else:
+                 return ChatCompletionAssistantMessageParam(
+                     {
+                         "content": content,
+                         "role": "assistant",
+                         "tool_calls": [
+                             self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                         ],
+                     }
+                 )
+         if role is ChatCompletionMessageRole.TOOL:
+             if tool_call_id is None:
+                 raise ValueError("tool_call_id is required for tool messages")
+             return ChatCompletionToolMessageParam(
+                 {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+             )
+         assert_never(role)
+
+     @staticmethod
+     def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
+         yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+         yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+         yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
+
+
+ @register_llm_client(
+     provider_key=GenerativeProviderKey.AZURE_OPENAI,
+     model_names=[
+         PROVIDER_DEFAULT,
+     ],
+ )
+ class AzureOpenAIStreamingClient(OpenAIStreamingClient):
+     def __init__(
+         self,
+         model: GenerativeModelInput,
+         api_key: Optional[str] = None,
+     ):
+         from openai import AsyncAzureOpenAI
+
+         super().__init__(model=model, api_key=api_key)
+         if model.endpoint is None or model.api_version is None:
+             raise ValueError("endpoint and api_version are required for Azure OpenAI models")
+         self.client = AsyncAzureOpenAI(
+             api_key=api_key,
+             azure_endpoint=model.endpoint,
+             api_version=model.api_version,
+         )
+
+
+ @register_llm_client(
+     provider_key=GenerativeProviderKey.ANTHROPIC,
+     model_names=[
+         PROVIDER_DEFAULT,
+         "claude-3-5-sonnet-20240620",
+         "claude-3-opus-20240229",
+         "claude-3-sonnet-20240229",
+         "claude-3-haiku-20240307",
+     ],
+ )
+ class AnthropicStreamingClient(PlaygroundStreamingClient):
+     def __init__(
+         self,
+         model: GenerativeModelInput,
+         api_key: Optional[str] = None,
+     ) -> None:
+         import anthropic
+
+         super().__init__(model=model, api_key=api_key)
+         self.client = anthropic.AsyncAnthropic(api_key=api_key)
+         self.model_name = model.name
+         self.rate_limiter = PlaygroundRateLimiter(model.provider_key, anthropic.RateLimitError)
+
+     @classmethod
+     def dependencies(cls) -> list[DependencyName]:
+         return ["anthropic"]
+
+     @classmethod
+     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+         return [
+             IntInvocationParameter(
+                 invocation_name="max_tokens",
+                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                 label="Max Tokens",
+                 required=True,
+             ),
+             BoundedFloatInvocationParameter(
+                 invocation_name="temperature",
+                 canonical_name=CanonicalParameterName.TEMPERATURE,
+                 label="Temperature",
+                 min_value=0.0,
+                 max_value=1.0,
+             ),
+             StringListInvocationParameter(
+                 invocation_name="stop_sequences",
+                 canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                 label="Stop Sequences",
+             ),
+             BoundedFloatInvocationParameter(
+                 invocation_name="top_p",
+                 canonical_name=CanonicalParameterName.TOP_P,
+                 label="Top P",
+                 min_value=0.0,
+                 max_value=1.0,
+             ),
+             JSONInvocationParameter(
+                 invocation_name="tool_choice",
+                 label="Tool Choice",
+                 canonical_name=CanonicalParameterName.TOOL_CHOICE,
+             ),
+         ]
+
+     async def chat_completion_create(
+         self,
+         messages: list[
+             tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+         ],
+         tools: list[JSONScalarType],
+         **invocation_parameters: Any,
+     ) -> AsyncIterator[ChatCompletionChunk]:
+         import anthropic.lib.streaming as anthropic_streaming
+         import anthropic.types as anthropic_types
+
+         anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
+
+         anthropic_params = {
+             "messages": anthropic_messages,
+             "model": self.model_name,
+             "system": system_prompt,
+             "max_tokens": 1024,
+             "tools": tools,
+             **invocation_parameters,
+         }
+         throttled_stream = self.rate_limiter._alimit(self.client.messages.stream)
+         async with await throttled_stream(**anthropic_params) as stream:
+             async for event in stream:
+                 if isinstance(event, anthropic_types.RawMessageStartEvent):
+                     self._attributes.update(
+                         {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
+                     )
+                 elif isinstance(event, anthropic_streaming.TextEvent):
+                     yield TextChunk(content=event.text)
+                 elif isinstance(event, anthropic_streaming.MessageStopEvent):
+                     self._attributes.update(
+                         {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
+                     )
+                 elif (
+                     isinstance(event, anthropic_streaming.ContentBlockStopEvent)
+                     and event.content_block.type == "tool_use"
+                 ):
+                     tool_call_chunk = ToolCallChunk(
+                         id=event.content_block.id,
+                         function=FunctionCallChunk(
+                             name=event.content_block.name,
+                             arguments=json.dumps(event.content_block.input),
+                         ),
+                     )
+                     yield tool_call_chunk
+                 elif isinstance(
+                     event,
+                     (
+                         anthropic_types.RawContentBlockStartEvent,
+                         anthropic_types.RawContentBlockDeltaEvent,
+                         anthropic_types.RawMessageDeltaEvent,
+                         anthropic_streaming.ContentBlockStopEvent,
+                         anthropic_streaming.InputJsonEvent,
+                     ),
+                 ):
+                     # event types emitted by the stream that don't contain useful information
+                     pass
+                 elif isinstance(event, anthropic_streaming.InputJsonEvent):
+                     raise NotImplementedError
+                 else:
+                     assert_never(event)
+
+     def _build_anthropic_messages(
+         self,
+         messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+     ) -> tuple[list["MessageParam"], str]:
+         anthropic_messages: list["MessageParam"] = []
+         system_prompt = ""
+         for role, content, _tool_call_id, _tool_calls in messages:
+             if role == ChatCompletionMessageRole.USER:
+                 anthropic_messages.append({"role": "user", "content": content})
+             elif role == ChatCompletionMessageRole.AI:
+                 anthropic_messages.append({"role": "assistant", "content": content})
+             elif role == ChatCompletionMessageRole.SYSTEM:
+                 system_prompt += content + "\n"
+             elif role == ChatCompletionMessageRole.TOOL:
+                 raise NotImplementedError
+             else:
+                 assert_never(role)
+
+         return anthropic_messages, system_prompt
+
+
+ def initialize_playground_clients() -> None:
+     """
+     Ensure that all playground clients are registered at import time.
+     """
+     pass
+
+
+ LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
+ LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+ LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
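For orientation, the following is a minimal, illustrative sketch of how the streaming-client interface added in this file might be exercised directly. It is not part of the diff: the model name and API key are placeholders, and it assumes the GenerativeModelInput strawberry input type (defined in the separate new file listed above) can be constructed with the keyword fields this module reads (provider_key, name); only names that appear in the diff are used.

import asyncio

from phoenix.server.api.helpers.playground_clients import OpenAIStreamingClient
from phoenix.server.api.input_types.GenerativeModelInput import GenerativeModelInput
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import TextChunk, ToolCallChunk
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey


async def main() -> None:
    # Hypothetical keyword construction of the input type; "gpt-4o" is one of the
    # model names registered for OpenAIStreamingClient above, and the API key is a placeholder.
    model = GenerativeModelInput(provider_key=GenerativeProviderKey.OPENAI, name="gpt-4o")
    client = OpenAIStreamingClient(model=model, api_key="sk-...")

    # Messages are (role, content, tool_call_id, tool_calls) tuples, per the abstract signature.
    messages = [(ChatCompletionMessageRole.USER, "Hello, world!", None, None)]

    # chat_completion_create yields TextChunk / ToolCallChunk objects as the provider streams.
    async for chunk in client.chat_completion_create(messages=messages, tools=[]):
        if isinstance(chunk, TextChunk):
            print(chunk.content, end="")
        elif isinstance(chunk, ToolCallChunk):
            print(chunk.function.name, chunk.function.arguments)

    # Token counts captured during streaming are exposed as span attributes.
    print(client.attributes)


asyncio.run(main())

In the application itself these clients are not constructed by hand: they are resolved through the playground registry by the new chat mutation and subscription modules listed in the files-changed table above.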