arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -7,6 +7,7 @@ import json
 import time
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Callable, Iterator
+from dataclasses import dataclass
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Hashable, Mapping, MutableMapping, Optional, Union
 
@@ -19,7 +20,7 @@ from openinference.semconv.trace import (
 )
 from strawberry import UNSET
 from strawberry.scalars import JSON as JSONScalarType
-from typing_extensions import TypeAlias, assert_never
+from typing_extensions import TypeAlias, assert_never, override
 
 from phoenix.config import getenv
 from phoenix.evals.models.rate_limiters import (
@@ -56,6 +57,7 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 if TYPE_CHECKING:
     import httpx
     from anthropic.types import MessageParam, TextBlockParam, ToolResultBlockParam
+    from botocore.awsrequest import AWSPreparedRequest  # type: ignore[import-untyped]
     from google.generativeai.types import ContentType
     from openai import AsyncAzureOpenAI, AsyncOpenAI
     from openai.types import CompletionUsage
@@ -66,6 +68,16 @@ SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
 ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
 
 
+@dataclass
+class PlaygroundClientCredential:
+    """
+    Represents a credential for LLM providers.
+    """
+
+    env_var_name: str
+    value: str
+
+
 class Dependency:
     """
     Set the module_name to the import name if it is different from the install name
@@ -172,9 +184,10 @@ class PlaygroundStreamingClient(ABC):
     def __init__(
         self,
         model: GenerativeModelInput,
-
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
         self._attributes: dict[str, AttributeValue] = dict()
+        self._credentials = credentials or []
 
     @classmethod
     @abstractmethod
@@ -243,11 +256,11 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
         *,
         client: Union["AsyncOpenAI", "AsyncAzureOpenAI"],
         model: GenerativeModelInput,
-
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
         from openai import RateLimitError as OpenAIRateLimitError
 
-        super().__init__(model=model,
+        super().__init__(model=model, credentials=credentials)
         self.client = client
         self.model_name = model.name
         self.rate_limiter = PlaygroundRateLimiter(model.provider_key, OpenAIRateLimitError)
@@ -296,7 +309,6 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=1.0,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -315,6 +327,10 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
                 label="Response Format",
                 canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
             ),
+            JSONInvocationParameter(
+                invocation_name="extra_body",
+                label="Extra Body",
+            ),
         ]
 
     async def chat_completion_create(
@@ -347,7 +363,6 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
         ):
             if (usage := chunk.usage) is not None:
                 token_usage = usage
-                continue
             if not chunk.choices:
                 # for Azure, initial chunk contains the content filter
                 continue
@@ -426,9 +441,9 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
         if role is ChatCompletionMessageRole.TOOL:
             if tool_call_id is None:
                 raise ValueError("tool_call_id is required for tool messages")
-
-
-
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
         assert_never(role)
 
     def to_openai_tool_call_param(
@@ -452,288 +467,272 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
         yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
         yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
 
+        if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details is not None:
+            prompt_details = usage.prompt_tokens_details
+            if (
+                hasattr(prompt_details, "cached_tokens")
+                and prompt_details.cached_tokens is not None
+            ):
+                yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, prompt_details.cached_tokens
+            if hasattr(prompt_details, "audio_tokens") and prompt_details.audio_tokens is not None:
+                yield LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, prompt_details.audio_tokens
+
+        if (
+            hasattr(usage, "completion_tokens_details")
+            and usage.completion_tokens_details is not None
+        ):
+            completion_details = usage.completion_tokens_details
+            if (
+                hasattr(completion_details, "reasoning_tokens")
+                and completion_details.reasoning_tokens is not None
+            ):
+                yield (
+                    LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
+                    completion_details.reasoning_tokens,
+                )
+            if (
+                hasattr(completion_details, "audio_tokens")
+                and completion_details.audio_tokens is not None
+            ):
+                yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, completion_details.audio_tokens
+
+
+def _get_credential_value(
+    credentials: Optional[list[PlaygroundClientCredential]], env_var_name: str
+) -> Optional[str]:
+    """Helper function to extract credential value from credentials list."""
+    if not credentials:
+        return None
+    return next(
+        (credential.value for credential in credentials if credential.env_var_name == env_var_name),
+        None,
+    )
+
+
+def _require_credential(
+    credentials: Optional[list[PlaygroundClientCredential]], env_var_name: str, provider_name: str
+) -> str:
+    """Helper function to require a credential value, raising an exception if not found."""
+    value = _get_credential_value(credentials, env_var_name)
+    if value is None:
+        raise BadRequest(f"Missing required credential '{env_var_name}' for {provider_name}")
+    return value
+
 
 @register_llm_client(
-    provider_key=GenerativeProviderKey.
+    provider_key=GenerativeProviderKey.DEEPSEEK,
     model_names=[
         PROVIDER_DEFAULT,
-        "
-        "
-        "gpt-4.1-nano",
-        "gpt-4.1-2025-04-14",
-        "gpt-4.1-mini-2025-04-14",
-        "gpt-4.1-nano-2025-04-14",
-        "gpt-4o",
-        "gpt-4o-2024-11-20",
-        "gpt-4o-2024-08-06",
-        "gpt-4o-2024-05-13",
-        "chatgpt-4o-latest",
-        "gpt-4o-mini",
-        "gpt-4o-mini-2024-07-18",
-        "gpt-4-turbo",
-        "gpt-4-turbo-2024-04-09",
-        "gpt-4-turbo-preview",
-        "gpt-4-0125-preview",
-        "gpt-4-1106-preview",
-        "gpt-4",
-        "gpt-4-0613",
-        "gpt-3.5-turbo-0125",
-        "gpt-3.5-turbo",
-        "gpt-3.5-turbo-1106",
-        # preview models
-        "gpt-4.5-preview",
+        "deepseek-chat",
+        "deepseek-reasoner",
     ],
 )
-class
+class DeepSeekStreamingClient(OpenAIBaseStreamingClient):
     def __init__(
         self,
         model: GenerativeModelInput,
-
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
         from openai import AsyncOpenAI
 
-        base_url = model.base_url or getenv("
-
+        base_url = model.base_url or getenv("DEEPSEEK_BASE_URL")
+
+        # Try to get API key from credentials first, then fallback to env
+        api_key = _get_credential_value(credentials, "DEEPSEEK_API_KEY") or getenv(
+            "DEEPSEEK_API_KEY"
+        )
+
+        if not api_key:
             if not base_url:
-                raise BadRequest("An API key is required for
+                raise BadRequest("An API key is required for DeepSeek models")
             api_key = "sk-fake-api-key"
-
-
-
+
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url or "https://api.deepseek.com",
+            default_headers=model.custom_headers or None,
+        )
+        super().__init__(client=client, model=model, credentials=credentials)
+        # DeepSeek uses OpenAI-compatible API but we'll track it as a separate provider
+        # Adding a custom "deepseek" provider value to make it distinguishable in traces
+        self._attributes[LLM_PROVIDER] = "deepseek"
         self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
 
 
 @register_llm_client(
-    provider_key=GenerativeProviderKey.
+    provider_key=GenerativeProviderKey.XAI,
     model_names=[
-
-        "
-        "
-        "
-        "
-        "
-        "
-        "o3-mini-2025-01-31",
+        PROVIDER_DEFAULT,
+        "grok-3",
+        "grok-3-fast",
+        "grok-3-mini",
+        "grok-3-mini-fast",
+        "grok-2-1212",
+        "grok-2-vision-1212",
     ],
 )
-class
-
-    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
-        return [
-            StringInvocationParameter(
-                invocation_name="reasoning_effort",
-                label="Reasoning Effort",
-                canonical_name=CanonicalParameterName.REASONING_EFFORT,
-            ),
-            IntInvocationParameter(
-                invocation_name="max_completion_tokens",
-                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
-                label="Max Completion Tokens",
-            ),
-            IntInvocationParameter(
-                invocation_name="seed",
-                canonical_name=CanonicalParameterName.RANDOM_SEED,
-                label="Seed",
-            ),
-            JSONInvocationParameter(
-                invocation_name="tool_choice",
-                label="Tool Choice",
-                canonical_name=CanonicalParameterName.TOOL_CHOICE,
-            ),
-            JSONInvocationParameter(
-                invocation_name="response_format",
-                label="Response Format",
-                canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
-            ),
-        ]
-
-    async def chat_completion_create(
+class XAIStreamingClient(OpenAIBaseStreamingClient):
+    def __init__(
         self,
-
-
-
-
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionChunk]:
-        from openai import NOT_GIVEN
-
-        # Convert standard messages to OpenAI messages
-        openai_messages = []
-        for message in messages:
-            openai_message = self.to_openai_chat_completion_param(*message)
-            if openai_message is not None:
-                openai_messages.append(openai_message)
-
-        throttled_create = self.rate_limiter._alimit(self.client.chat.completions.create)
-        response = await throttled_create(
-            messages=openai_messages,
-            model=self.model_name,
-            stream=False,
-            tools=tools or NOT_GIVEN,
-            **invocation_parameters,
-        )
-
-        if response.usage is not None:
-            self._attributes.update(dict(self._llm_token_counts(response.usage)))
+        model: GenerativeModelInput,
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
+    ) -> None:
+        from openai import AsyncOpenAI
 
-
-        if choice.message.content:
-            yield TextChunk(content=choice.message.content)
+        base_url = model.base_url or getenv("XAI_BASE_URL")
 
-
-
-            yield ToolCallChunk(
-                id=tool_call.id,
-                function=FunctionCallChunk(
-                    name=tool_call.function.name,
-                    arguments=tool_call.function.arguments,
-                ),
-            )
+        # Try to get API key from credentials first, then fallback to env
+        api_key = _get_credential_value(credentials, "XAI_API_KEY") or getenv("XAI_API_KEY")
 
-
-
-
-
-        tool_call_id: Optional[str] = None,
-        tool_calls: Optional[list[JSONScalarType]] = None,
-    ) -> Optional["ChatCompletionMessageParam"]:
-        from openai.types.chat import (
-            ChatCompletionAssistantMessageParam,
-            ChatCompletionDeveloperMessageParam,
-            ChatCompletionToolMessageParam,
-            ChatCompletionUserMessageParam,
-        )
+        if not api_key:
+            if not base_url:
+                raise BadRequest("An API key is required for xAI models")
+            api_key = "sk-fake-api-key"
 
-
-
-
-
-                    "role": "user",
-                }
-            )
-        if role is ChatCompletionMessageRole.SYSTEM:
-            return ChatCompletionDeveloperMessageParam(
-                {
-                    "content": content,
-                    "role": "developer",
-                }
-            )
-        if role is ChatCompletionMessageRole.AI:
-            if tool_calls is None:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                    }
-                )
-            else:
-                return ChatCompletionAssistantMessageParam(
-                    {
-                        "content": content,
-                        "role": "assistant",
-                        "tool_calls": [
-                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
-                        ],
-                    }
-                )
-        if role is ChatCompletionMessageRole.TOOL:
-            if tool_call_id is None:
-                raise ValueError("tool_call_id is required for tool messages")
-            return ChatCompletionToolMessageParam(
-                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url or "https://api.x.ai/v1",
+            default_headers=model.custom_headers or None,
         )
-
-
-
-
-
-        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
+        super().__init__(client=client, model=model, credentials=credentials)
+        # xAI uses OpenAI-compatible API but we'll track it as a separate provider
+        # Adding a custom "xai" provider value to make it distinguishable in traces
+        self._attributes[LLM_PROVIDER] = "xai"
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
 
 
 @register_llm_client(
-    provider_key=GenerativeProviderKey.
+    provider_key=GenerativeProviderKey.OLLAMA,
     model_names=[
         PROVIDER_DEFAULT,
+        "llama3.3",
+        "llama3.2",
+        "llama3.1",
+        "llama3",
+        "llama2",
+        "mistral",
+        "mixtral",
+        "codellama",
+        "phi3",
+        "qwen2.5",
+        "gemma2",
     ],
 )
-class
+class OllamaStreamingClient(OpenAIBaseStreamingClient):
     def __init__(
         self,
         model: GenerativeModelInput,
-
-    ):
-        from openai import
-
-        if not (endpoint := model.endpoint or getenv("AZURE_OPENAI_ENDPOINT")):
-            raise BadRequest("An Azure endpoint is required for Azure OpenAI models")
-        if not (api_version := model.api_version or getenv("OPENAI_API_VERSION")):
-            raise BadRequest("An OpenAI API version is required for Azure OpenAI models")
-        if api_key := api_key or getenv("AZURE_OPENAI_API_KEY"):
-            client = AsyncAzureOpenAI(
-                api_key=api_key,
-                azure_endpoint=endpoint,
-                api_version=api_version,
-            )
-        else:
-            try:
-                from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
-            except ImportError:
-                raise BadRequest(
-                    "Provide an API key for Azure OpenAI models or use azure-identity, see. e.g. "
-                    "https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.environmentcredential?view=azure-python"  # noqa: E501
-                )
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
+    ) -> None:
+        from openai import AsyncOpenAI
 
-
-
-
-
-
-
-
-
-
-
+        base_url = model.base_url or getenv("OLLAMA_BASE_URL")
+        if not base_url:
+            raise BadRequest("An Ollama base URL is required for Ollama models")
+        api_key = "ollama"
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            default_headers=model.custom_headers or None,
+        )
+        super().__init__(client=client, model=model, credentials=credentials)
+        # Ollama uses OpenAI-compatible API but we'll track it as a separate provider
+        # Adding a custom "ollama" provider value to make it distinguishable in traces
+        self._attributes[LLM_PROVIDER] = "ollama"
         self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
 
 
 @register_llm_client(
-    provider_key=GenerativeProviderKey.
+    provider_key=GenerativeProviderKey.AWS,
     model_names=[
         PROVIDER_DEFAULT,
-        "claude-
-        "claude-
-        "claude-
-        "claude-
-        "claude-
-        "claude-
-        "claude-3-
-        "claude-3-
-        "claude-3-sonnet-
-        "claude-3-haiku-
+        "anthropic.claude-opus-4-5-20251101-v1:0",
+        "anthropic.claude-sonnet-4-5-20250929-v1:0",
+        "anthropic.claude-haiku-4-5-20251001-v1:0",
+        "anthropic.claude-opus-4-1-20250805-v1:0",
+        "anthropic.claude-opus-4-20250514-v1:0",
+        "anthropic.claude-sonnet-4-20250514-v1:0",
+        "anthropic.claude-3-7-sonnet-20250219-v1:0",
+        "anthropic.claude-3-5-sonnet-20241022-v2:0",
+        "anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "anthropic.claude-3-5-haiku-20241022-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
+        "amazon.titan-embed-text-v2:0",
+        "amazon.nova-pro-v1:0",
+        "amazon.nova-premier-v1:0:8k",
+        "amazon.nova-premier-v1:0:20k",
+        "amazon.nova-premier-v1:0:1000k",
+        "amazon.nova-premier-v1:0:mm",
+        "amazon.nova-premier-v1:0",
+        "amazon.nova-lite-v1:0",
+        "amazon.nova-micro-v1:0",
+        "deepseek.r1-v1:0",
+        "mistral.pixtral-large-2502-v1:0",
+        "meta.llama3-1-8b-instruct-v1:0:128k",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0:128k",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-1-405b-instruct-v1:0",
+        "meta.llama3-2-11b-instruct-v1:0",
+        "meta.llama3-2-90b-instruct-v1:0",
+        "meta.llama3-2-1b-instruct-v1:0",
+        "meta.llama3-2-3b-instruct-v1:0",
+        "meta.llama3-3-70b-instruct-v1:0",
+        "meta.llama4-scout-17b-instruct-v1:0",
+        "meta.llama4-maverick-17b-instruct-v1:0",
     ],
 )
-class
+class BedrockStreamingClient(PlaygroundStreamingClient):
     def __init__(
         self,
         model: GenerativeModelInput,
-
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
-        import
-
-        super().__init__(model=model,
-
-        self.
-
-
-
+        import boto3  # type: ignore[import-untyped]
+
+        super().__init__(model=model, credentials=credentials)
+        region = model.region or "us-east-1"
+        self.api = "converse"
+        custom_headers = model.custom_headers
+        aws_access_key_id = _get_credential_value(credentials, "AWS_ACCESS_KEY_ID") or getenv(
+            "AWS_ACCESS_KEY_ID"
+        )
+        aws_secret_access_key = _get_credential_value(
+            credentials, "AWS_SECRET_ACCESS_KEY"
+        ) or getenv("AWS_SECRET_ACCESS_KEY")
+        aws_session_token = _get_credential_value(credentials, "AWS_SESSION_TOKEN") or getenv(
+            "AWS_SESSION_TOKEN"
+        )
         self.model_name = model.name
-
-
+        session = boto3.Session(
+            region_name=region,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+        )
+        client = session.client(service_name="bedrock-runtime")
+
+        # Add custom headers support via boto3 event system
+        if custom_headers:
+
+            def add_custom_headers(request: "AWSPreparedRequest", **kwargs: Any) -> None:
+                request.headers.update(custom_headers)
+
+            client.meta.events.register("before-send.*", add_custom_headers)
+
+        self.client = client
+        self._attributes[LLM_PROVIDER] = "aws"
+        self._attributes[LLM_SYSTEM] = "aws"
+
+    @staticmethod
+    def _setup_custom_headers(client: Any, custom_headers: Mapping[str, str]) -> None:
+        """Setup custom headers using boto3's event system."""
+        if not custom_headers:
+            return
 
     @classmethod
     def dependencies(cls) -> list[Dependency]:
-        return [Dependency(name="
+        return [Dependency(name="boto3")]
 
     @classmethod
     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -743,7 +742,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                 canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
                 label="Max Tokens",
                 default_value=1024,
-                required=True,
             ),
             BoundedFloatInvocationParameter(
                 invocation_name="temperature",
@@ -753,16 +751,10 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                 min_value=0.0,
                 max_value=1.0,
             ),
-            StringListInvocationParameter(
-                invocation_name="stop_sequences",
-                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
-                label="Stop Sequences",
-            ),
             BoundedFloatInvocationParameter(
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=1.0,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -781,14 +773,806 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
|
781
773
|
tools: list[JSONScalarType],
|
|
782
774
|
**invocation_parameters: Any,
|
|
783
775
|
) -> AsyncIterator[ChatCompletionChunk]:
|
|
784
|
-
|
|
785
|
-
|
|
776
|
+
if self.api == "invoke":
|
|
777
|
+
async for chunk in self._handle_invoke_api(messages, tools, invocation_parameters):
|
|
778
|
+
yield chunk
|
|
779
|
+
else:
|
|
780
|
+
async for chunk in self._handle_converse_api(messages, tools, invocation_parameters):
|
|
781
|
+
yield chunk
|
|
786
782
|
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
783
|
+
async def _handle_converse_api(
|
|
784
|
+
self,
|
|
785
|
+
messages: list[
|
|
786
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
787
|
+
],
|
|
788
|
+
tools: list[JSONScalarType],
|
|
789
|
+
invocation_parameters: dict[str, Any],
|
|
790
|
+
) -> AsyncIterator[ChatCompletionChunk]:
|
|
791
|
+
"""
|
|
792
|
+
Handle the converse API.
|
|
793
|
+
"""
|
|
794
|
+
# Build messages in Converse API format
|
|
795
|
+
converse_messages = self._build_converse_messages(messages)
|
|
796
|
+
|
|
797
|
+
inference_config = {}
|
|
798
|
+
if (
|
|
799
|
+
"max_tokens" in invocation_parameters
|
|
800
|
+
and invocation_parameters["max_tokens"] is not None
|
|
801
|
+
):
|
|
802
|
+
inference_config["maxTokens"] = invocation_parameters["max_tokens"]
|
|
803
|
+
if (
|
|
804
|
+
"temperature" in invocation_parameters
|
|
805
|
+
and invocation_parameters["temperature"] is not None
|
|
806
|
+
):
|
|
807
|
+
inference_config["temperature"] = invocation_parameters["temperature"]
|
|
808
|
+
if "top_p" in invocation_parameters and invocation_parameters["top_p"] is not None:
|
|
809
|
+
inference_config["topP"] = invocation_parameters["top_p"]
|
|
810
|
+
|
|
811
|
+
# Build the request parameters for Converse API
|
|
812
|
+
converse_params: dict[str, Any] = {
|
|
813
|
+
"modelId": self.model_name,
|
|
814
|
+
"messages": converse_messages,
|
|
815
|
+
"inferenceConfig": inference_config,
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
# Add system prompt if available
|
|
819
|
+
system_prompt = self._extract_system_prompt(messages)
|
|
820
|
+
if system_prompt:
|
|
821
|
+
converse_params["system"] = [{"text": system_prompt}]
|
|
822
|
+
|
|
823
|
+
# Add tools if provided
|
|
824
|
+
if tools:
|
|
825
|
+
converse_params["toolConfig"] = {"tools": tools}
|
|
826
|
+
if (
|
|
827
|
+
"tool_choice" in invocation_parameters
|
|
828
|
+
and invocation_parameters["tool_choice"]["type"] != "none"
|
|
829
|
+
):
|
|
830
|
+
converse_params["toolConfig"]["toolChoice"] = {}
|
|
831
|
+
|
|
832
|
+
if invocation_parameters["tool_choice"]["type"] == "auto":
|
|
833
|
+
converse_params["toolConfig"]["toolChoice"]["auto"] = {}
|
|
834
|
+
elif invocation_parameters["tool_choice"]["type"] == "any":
|
|
835
|
+
converse_params["toolConfig"]["toolChoice"]["any"] = {}
|
|
836
|
+
else:
|
|
837
|
+
converse_params["toolConfig"]["toolChoice"]["tool"] = {
|
|
838
|
+
"name": invocation_parameters["tool_choice"]["name"],
|
|
839
|
+
}
|
|
840
|
+
|
|
841
|
+
# Make the streaming API call
|
|
842
|
+
response = self.client.converse_stream(**converse_params)
|
|
843
|
+
|
|
844
|
+
# Track active tool calls
|
|
845
|
+
active_tool_calls = {} # contentBlockIndex -> {id, name, arguments_buffer}
|
|
846
|
+
|
|
847
|
+
# Process the event stream
|
|
848
|
+
event_stream = response.get("stream")
|
|
849
|
+
|
|
850
|
+
for event in event_stream:
|
|
851
|
+
# Handle content block start events
|
|
852
|
+
if "contentBlockStart" in event:
|
|
853
|
+
content_block_start = event["contentBlockStart"]
|
|
854
|
+
start_event = content_block_start.get("start", {})
|
|
855
|
+
block_index = content_block_start.get(
|
|
856
|
+
"contentBlockIndex", 0
|
|
857
|
+
) # Get the actual index
|
|
858
|
+
|
|
859
|
+
if "toolUse" in start_event:
|
|
860
|
+
tool_use = start_event["toolUse"]
|
|
861
|
+
active_tool_calls[block_index] = { # Use the actual block index
|
|
862
|
+
"id": tool_use.get("toolUseId"),
|
|
863
|
+
"name": tool_use.get("name"),
|
|
864
|
+
"arguments_buffer": "",
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
# Yield initial tool call chunk
|
|
868
|
+
yield ToolCallChunk(
|
|
869
|
+
id=tool_use.get("toolUseId"),
|
|
870
|
+
function=FunctionCallChunk(
|
|
871
|
+
name=tool_use.get("name"),
|
|
872
|
+
arguments="",
|
|
873
|
+
),
|
|
874
|
+
)
|
|
875
|
+
|
|
876
|
+
# Handle content block delta events
|
|
877
|
+
elif "contentBlockDelta" in event:
|
|
878
|
+
content_delta = event["contentBlockDelta"]
|
|
879
|
+
delta = content_delta.get("delta", {})
|
|
880
|
+
delta_index = content_delta.get("contentBlockIndex", 0)
|
|
881
|
+
|
|
882
|
+
# Handle text delta
|
|
883
|
+
if "text" in delta:
|
|
884
|
+
yield TextChunk(content=delta["text"])
|
|
885
|
+
|
|
886
|
+
# Handle tool use delta
|
|
887
|
+
elif "toolUse" in delta:
|
|
888
|
+
tool_delta = delta["toolUse"]
|
|
889
|
+
if "input" in tool_delta and delta_index in active_tool_calls:
|
|
890
|
+
# Accumulate tool arguments
|
|
891
|
+
json_chunk = tool_delta["input"]
|
|
892
|
+
active_tool_calls[delta_index]["arguments_buffer"] += json_chunk
|
|
893
|
+
|
|
894
|
+
# Yield incremental argument update
|
|
895
|
+
yield ToolCallChunk(
|
|
896
|
+
id=active_tool_calls[delta_index]["id"],
|
|
897
|
+
function=FunctionCallChunk(
|
|
898
|
+
name=active_tool_calls[delta_index]["name"],
|
|
899
|
+
arguments=json_chunk,
|
|
900
|
+
),
|
|
901
|
+
)
|
|
902
|
+
|
|
903
|
+
# Handle content block stop events
|
|
904
|
+
elif "contentBlockStop" in event:
|
|
905
|
+
stop_index = event["contentBlockStop"].get("contentBlockIndex", 0)
|
|
906
|
+
if stop_index in active_tool_calls:
|
|
907
|
+
del active_tool_calls[stop_index]
|
|
908
|
+
|
|
909
|
+
elif "metadata" in event:
|
|
910
|
+
self._attributes.update(
|
|
911
|
+
{
|
|
912
|
+
LLM_TOKEN_COUNT_PROMPT: event.get("metadata")
|
|
913
|
+
.get("usage", {})
|
|
914
|
+
.get("inputTokens", 0)
|
|
915
|
+
}
|
|
916
|
+
)
|
|
917
|
+
|
|
918
|
+
self._attributes.update(
|
|
919
|
+
{
|
|
920
|
+
LLM_TOKEN_COUNT_COMPLETION: event.get("metadata")
|
|
921
|
+
.get("usage", {})
|
|
922
|
+
.get("outputTokens", 0)
|
|
923
|
+
}
|
|
924
|
+
)
|
|
925
|
+
|
|
926
|
+
self._attributes.update(
|
|
927
|
+
{
|
|
928
|
+
LLM_TOKEN_COUNT_TOTAL: event.get("metadata")
|
|
929
|
+
.get("usage", {})
|
|
930
|
+
.get("totalTokens", 0)
|
|
931
|
+
}
|
|
932
|
+
)
|
|
933
|
+
|
|
934
|
+
async def _handle_invoke_api(
|
|
935
|
+
self,
|
|
936
|
+
messages: list[
|
|
937
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
938
|
+
],
|
|
939
|
+
tools: list[JSONScalarType],
|
|
940
|
+
invocation_parameters: dict[str, Any],
|
|
941
|
+
) -> AsyncIterator[ChatCompletionChunk]:
|
|
942
|
+
if "anthropic" not in self.model_name:
|
|
943
|
+
raise ValueError("Invoke API is only supported for Anthropic models")
|
|
944
|
+
|
|
945
|
+
bedrock_messages, system_prompt = self._build_bedrock_messages(messages)
|
|
946
|
+
bedrock_params = {
|
|
947
|
+
"anthropic_version": "bedrock-2023-05-31",
|
|
948
|
+
"messages": bedrock_messages,
|
|
949
|
+
"system": system_prompt,
|
|
950
|
+
"tools": tools,
|
|
951
|
+
}
|
|
952
|
+
|
|
953
|
+
if (
|
|
954
|
+
"max_tokens" in invocation_parameters
|
|
955
|
+
and invocation_parameters["max_tokens"] is not None
|
|
956
|
+
):
|
|
957
|
+
bedrock_params["max_tokens"] = invocation_parameters["max_tokens"]
|
|
958
|
+
if (
|
|
959
|
+
"temperature" in invocation_parameters
|
|
960
|
+
and invocation_parameters["temperature"] is not None
|
|
961
|
+
):
|
|
962
|
+
bedrock_params["temperature"] = invocation_parameters["temperature"]
|
|
963
|
+
if "top_p" in invocation_parameters and invocation_parameters["top_p"] is not None:
|
|
964
|
+
bedrock_params["top_p"] = invocation_parameters["top_p"]
|
|
965
|
+
|
|
966
|
+
response = self.client.invoke_model_with_response_stream(
|
|
967
|
+
modelId=self.model_name,
|
|
968
|
+
contentType="application/json",
|
|
969
|
+
accept="application/json",
|
|
970
|
+
body=json.dumps(bedrock_params),
|
|
971
|
+
trace="ENABLED_FULL",
|
|
972
|
+
)
|
|
973
|
+
|
|
974
|
+
# The response['body'] is an EventStream object
|
|
975
|
+
event_stream = response["body"]
|
|
976
|
+
|
|
977
|
+
# Track active tool calls and their accumulating arguments
|
|
978
|
+
active_tool_calls: dict[int, dict[str, Any]] = {} # index -> {id, name, arguments_buffer}
|
|
979
|
+
|
|
980
|
+
for event in event_stream:
|
|
981
|
+
if "chunk" in event:
|
|
982
|
+
chunk_data = json.loads(event["chunk"]["bytes"].decode("utf-8"))
|
|
983
|
+
|
|
984
|
+
# Handle text content
|
|
985
|
+
if chunk_data.get("type") == "content_block_delta":
|
|
986
|
+
delta = chunk_data.get("delta", {})
|
|
987
|
+
index = chunk_data.get("index", 0)
|
|
988
|
+
|
|
989
|
+
if delta.get("type") == "text_delta" and "text" in delta:
|
|
990
|
+
yield TextChunk(content=delta["text"])
|
|
991
|
+
|
|
992
|
+
elif delta.get("type") == "input_json_delta":
|
|
993
|
+
# Accumulate tool arguments
|
|
994
|
+
if index in active_tool_calls:
|
|
995
|
+
active_tool_calls[index]["arguments_buffer"] += delta.get(
|
|
996
|
+
"partial_json", ""
|
|
997
|
+
)
|
|
998
|
+
# Yield incremental argument update
|
|
999
|
+
yield ToolCallChunk(
|
|
1000
|
+
id=active_tool_calls[index]["id"],
|
|
1001
|
+
function=FunctionCallChunk(
|
|
1002
|
+
name=active_tool_calls[index]["name"],
|
|
1003
|
+
arguments=delta.get("partial_json", ""),
|
|
1004
|
+
),
|
|
1005
|
+
)
|
|
1006
|
+
|
|
1007
|
+
# Handle tool call start
|
|
1008
|
+
elif chunk_data.get("type") == "content_block_start":
|
|
1009
|
+
content_block = chunk_data.get("content_block", {})
|
|
1010
|
+
index = chunk_data.get("index", 0)
|
|
1011
|
+
|
|
1012
|
+
if content_block.get("type") == "tool_use":
|
|
1013
|
+
# Initialize tool call tracking
|
|
1014
|
+
active_tool_calls[index] = {
|
|
1015
|
+
"id": content_block.get("id"),
|
|
1016
|
+
"name": content_block.get("name"),
|
|
1017
|
+
"arguments_buffer": "",
|
|
1018
|
+
}
|
|
1019
|
+
|
|
1020
|
+
# Yield initial tool call chunk
|
|
1021
|
+
yield ToolCallChunk(
|
|
1022
|
+
id=content_block.get("id"),
|
|
1023
|
+
function=FunctionCallChunk(
|
|
1024
|
+
name=content_block.get("name"),
|
|
1025
|
+
arguments="", # Start with empty, will be filled by deltas
|
|
1026
|
+
),
|
|
1027
|
+
)
|
|
1028
|
+
|
|
1029
|
+
# Handle content block stop (tool call complete)
|
|
1030
|
+
elif chunk_data.get("type") == "content_block_stop":
|
|
1031
|
+
index = chunk_data.get("index", 0)
|
|
1032
|
+
if index in active_tool_calls:
|
|
1033
|
+
# Tool call is complete, clean up
|
|
1034
|
+
del active_tool_calls[index]
|
|
1035
|
+
|
|
1036
|
+
elif chunk_data.get("type") == "message_stop":
|
|
1037
|
+
self._attributes.update(
|
|
1038
|
+
{
|
|
1039
|
+
LLM_TOKEN_COUNT_COMPLETION: chunk_data.get(
|
|
1040
|
+
"amazon-bedrock-invocationMetrics", {}
|
|
1041
|
+
).get("outputTokenCount", 0)
|
|
1042
|
+
}
|
|
1043
|
+
)
|
|
1044
|
+
|
|
1045
|
+
self._attributes.update(
|
|
1046
|
+
{
|
|
1047
|
+
LLM_TOKEN_COUNT_PROMPT: chunk_data.get(
|
|
1048
|
+
"amazon-bedrock-invocationMetrics", {}
|
|
1049
|
+
).get("inputTokenCount", 0)
|
|
1050
|
+
}
|
|
1051
|
+
)
|
|
1052
|
+
|
|
1053
|
+
def _build_bedrock_messages(
|
|
1054
|
+
self,
|
|
1055
|
+
messages: list[
|
|
1056
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
1057
|
+
],
|
|
1058
|
+
) -> tuple[list[dict[str, Any]], str]:
|
|
1059
|
+
bedrock_messages = []
|
|
1060
|
+
system_prompt = ""
|
|
1061
|
+
for role, content, _, _ in messages:
|
|
1062
|
+
if role == ChatCompletionMessageRole.USER:
|
|
1063
|
+
bedrock_messages.append(
|
|
1064
|
+
{
|
|
1065
|
+
"role": "user",
|
|
1066
|
+
"content": content,
|
|
1067
|
+
}
|
|
1068
|
+
)
|
|
1069
|
+
elif role == ChatCompletionMessageRole.AI:
|
|
1070
|
+
bedrock_messages.append(
|
|
1071
|
+
{
|
|
1072
|
+
"role": "assistant",
|
|
1073
|
+
"content": content,
|
|
1074
|
+
}
|
|
1075
|
+
)
|
|
1076
|
+
elif role == ChatCompletionMessageRole.SYSTEM:
|
|
1077
|
+
system_prompt += content + "\n"
|
|
1078
|
+
return bedrock_messages, system_prompt
|
|
1079
|
+
|
|
1080
|
+
def _extract_system_prompt(
|
|
1081
|
+
self,
|
|
1082
|
+
messages: list[
|
|
1083
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
1084
|
+
],
|
|
1085
|
+
) -> str:
|
|
1086
|
+
"""Extract system prompt from messages."""
|
|
1087
|
+
system_prompts = []
|
|
1088
|
+
for role, content, _, _ in messages:
|
|
1089
|
+
if role == ChatCompletionMessageRole.SYSTEM:
|
|
1090
|
+
system_prompts.append(content)
|
|
1091
|
+
return "\n".join(system_prompts)
|
|
1092
|
+
|
|
1093
|
+
def _build_converse_messages(
|
|
1094
|
+
self,
|
|
1095
|
+
messages: list[
|
|
1096
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
1097
|
+
],
|
|
1098
|
+
) -> list[dict[str, Any]]:
|
|
1099
|
+
"""Convert messages to Converse API format."""
|
|
1100
|
+
converse_messages: list[dict[str, Any]] = []
|
|
1101
|
+
for role, content, _id, tool_calls in messages:
|
|
1102
|
+
if role == ChatCompletionMessageRole.USER:
|
|
1103
|
+
converse_messages.append({"role": "user", "content": [{"text": content}]})
|
|
1104
|
+
elif role == ChatCompletionMessageRole.TOOL:
|
|
1105
|
+
converse_messages.append(
|
|
1106
|
+
{
|
|
1107
|
+
"role": "user",
|
|
1108
|
+
"content": [
|
|
1109
|
+
{
|
|
1110
|
+
"toolResult": {
|
|
1111
|
+
"toolUseId": _id,
|
|
1112
|
+
"content": [{"json": json.loads(content)}],
|
|
1113
|
+
}
|
|
1114
|
+
}
|
|
1115
|
+
],
|
|
1116
|
+
}
|
|
1117
|
+
)
|
|
1118
|
+
|
|
1119
|
+
elif role == ChatCompletionMessageRole.AI:
|
|
1120
|
+
# Handle assistant messages with potential tool calls
|
|
1121
|
+
message: dict[str, Any] = {"role": "assistant", "content": []}
|
|
1122
|
+
if content:
|
|
1123
|
+
message["content"].append({"text": content})
|
|
1124
|
+
if tool_calls:
|
|
1125
|
+
for tool_call in tool_calls:
|
|
1126
|
+
message["content"].append(tool_call)
|
|
1127
|
+
converse_messages.append(message)
|
|
1128
|
+
return converse_messages
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.OPENAI,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gpt-4.1",
+        "gpt-4.1-mini",
+        "gpt-4.1-nano",
+        "gpt-4.1-2025-04-14",
+        "gpt-4.1-mini-2025-04-14",
+        "gpt-4.1-nano-2025-04-14",
+        "gpt-4o",
+        "gpt-4o-2024-11-20",
+        "gpt-4o-2024-08-06",
+        "gpt-4o-2024-05-13",
+        "chatgpt-4o-latest",
+        "gpt-4o-mini",
+        "gpt-4o-mini-2024-07-18",
+        "gpt-4-turbo",
+        "gpt-4-turbo-2024-04-09",
+        "gpt-4-turbo-preview",
+        "gpt-4-0125-preview",
+        "gpt-4-1106-preview",
+        "gpt-4",
+        "gpt-4-0613",
+        "gpt-3.5-turbo-0125",
+        "gpt-3.5-turbo",
+        "gpt-3.5-turbo-1106",
+        # preview models
+        "gpt-4.5-preview",
+    ],
+)
+class OpenAIStreamingClient(OpenAIBaseStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
+    ) -> None:
+        from openai import AsyncOpenAI
+
+        base_url = model.base_url or getenv("OPENAI_BASE_URL")
+
+        # Try to get API key from credentials first, then fallback to env
+        api_key = _get_credential_value(credentials, "OPENAI_API_KEY") or getenv("OPENAI_API_KEY")
+
+        if not api_key:
+            if not base_url:
+                raise BadRequest("An API key is required for OpenAI models")
+            api_key = "sk-fake-api-key"
+
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            default_headers=model.custom_headers or None,
+            timeout=30,
+        )
+        super().__init__(client=client, model=model, credentials=credentials)
+        self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.OPENAI.value
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
+
+
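The credential-then-environment fallback used by `OpenAIStreamingClient.__init__` (and repeated by the other clients below) can be summarized by this standalone sketch; `Credential` and `resolve_api_key` are stand-in names for illustration, not names from this module:

```python
from os import getenv
from typing import NamedTuple, Optional


class Credential(NamedTuple):  # stand-in for PlaygroundClientCredential
    env_var_name: str
    value: str


def resolve_api_key(credentials: Optional[list[Credential]], env_var: str) -> Optional[str]:
    # Prefer a credential supplied through the UI, then fall back to the environment.
    from_credentials = next(
        (c.value for c in (credentials or []) if c.env_var_name == env_var), None
    )
    return from_credentials or getenv(env_var)


# Example: a UI-supplied key wins over OPENAI_API_KEY in the environment.
print(resolve_api_key([Credential("OPENAI_API_KEY", "sk-from-ui")], "OPENAI_API_KEY"))
```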
+_OPENAI_REASONING_MODELS = [
+    "gpt-5.2",
+    "gpt-5.2-2025-12-11",
+    "gpt-5.2-chat-latest",
+    "gpt-5.1",
+    "gpt-5.1-2025-11-13",
+    "gpt-5.1-chat-latest",
+    "gpt-5",
+    "gpt-5-mini",
+    "gpt-5-nano",
+    "gpt-5-chat-latest",
+    "o1",
+    "o1-pro",
+    "o1-2024-12-17",
+    "o1-pro-2025-03-19",
+    "o1-mini",
+    "o1-mini-2024-09-12",
+    "o1-preview",
+    "o1-preview-2024-09-12",
+    "o3",
+    "o3-pro",
+    "o3-2025-04-16",
+    "o3-mini",
+    "o3-mini-2025-01-31",
+    "o4-mini",
+    "o4-mini-2025-04-16",
+]
+
+
+class OpenAIReasoningReasoningModelsMixin:
+    """Mixin class for OpenAI-style reasoning model clients (o1, o3 series)."""
+
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            StringInvocationParameter(
+                invocation_name="reasoning_effort",
+                label="Reasoning Effort",
+                canonical_name=CanonicalParameterName.REASONING_EFFORT,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_completion_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Completion Tokens",
+            ),
+            IntInvocationParameter(
+                invocation_name="seed",
+                canonical_name=CanonicalParameterName.RANDOM_SEED,
+                label="Seed",
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_choice",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+            JSONInvocationParameter(
+                invocation_name="response_format",
+                label="Response Format",
+                canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
+            ),
+            JSONInvocationParameter(
+                invocation_name="extra_body",
+                label="Extra Body",
+            ),
+        ]
+
+
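The mixin only declares which knobs the playground UI may expose; at request time they arrive as plain keyword arguments. An invocation-parameter dict matching the declarations above might look like this (values are illustrative, not defaults from this module):

```python
# Illustrative values only; each key corresponds to one InvocationParameter declared above.
invocation_parameters = {
    "reasoning_effort": "high",                   # StringInvocationParameter
    "max_completion_tokens": 2048,                # IntInvocationParameter
    "seed": 42,                                   # IntInvocationParameter
    "tool_choice": "auto",                        # JSONInvocationParameter
    "response_format": {"type": "json_object"},   # JSONInvocationParameter
    "extra_body": {"metadata": {"run": "demo"}},  # JSONInvocationParameter
}
```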
+@register_llm_client(
+    provider_key=GenerativeProviderKey.OPENAI,
+    model_names=_OPENAI_REASONING_MODELS,
+)
+class OpenAIReasoningNonStreamingClient(
+    OpenAIReasoningReasoningModelsMixin,
+    OpenAIStreamingClient,
+):
+    def to_openai_chat_completion_param(
+        self,
+        role: ChatCompletionMessageRole,
+        content: JSONScalarType,
+        tool_call_id: Optional[str] = None,
+        tool_calls: Optional[list[JSONScalarType]] = None,
+    ) -> Optional["ChatCompletionMessageParam"]:
+        from openai.types.chat import (
+            ChatCompletionAssistantMessageParam,
+            ChatCompletionDeveloperMessageParam,
+            ChatCompletionToolMessageParam,
+            ChatCompletionUserMessageParam,
+        )
+
+        if role is ChatCompletionMessageRole.USER:
+            return ChatCompletionUserMessageParam(
+                {
+                    "content": content,
+                    "role": "user",
+                }
+            )
+        if role is ChatCompletionMessageRole.SYSTEM:
+            return ChatCompletionDeveloperMessageParam(
+                {
+                    "content": content,
+                    "role": "developer",
+                }
+            )
+        if role is ChatCompletionMessageRole.AI:
+            if tool_calls is None:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                    }
+                )
+            else:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                        "tool_calls": [
+                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                        ],
+                    }
+                )
+        if role is ChatCompletionMessageRole.TOOL:
+            if tool_call_id is None:
+                raise ValueError("tool_call_id is required for tool messages")
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
+        assert_never(role)
+
+
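The main difference from the base client is the SYSTEM branch: reasoning models receive system text as a `developer` message. A minimal standalone sketch of that translation (not this module's helper):

```python
def to_reasoning_message(role: str, content: str) -> dict[str, str]:
    # Reasoning models (o1/o3/gpt-5 family) take system text under the "developer" role.
    if role == "system":
        return {"role": "developer", "content": content}
    return {"role": role, "content": content}


assert to_reasoning_message("system", "Answer tersely.")["role"] == "developer"
```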
+@register_llm_client(
+    provider_key=GenerativeProviderKey.AZURE_OPENAI,
+    model_names=[
+        PROVIDER_DEFAULT,
+    ],
+)
+class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
+    ):
+        from openai import AsyncAzureOpenAI
+
+        if not (endpoint := model.endpoint or getenv("AZURE_OPENAI_ENDPOINT")):
+            raise BadRequest("An Azure endpoint is required for Azure OpenAI models")
+        if not (api_version := model.api_version or getenv("OPENAI_API_VERSION")):
+            raise BadRequest("An OpenAI API version is required for Azure OpenAI models")
+
+        # Try to get API key from credentials first, then fallback to env
+        api_key = _get_credential_value(credentials, "AZURE_OPENAI_API_KEY") or getenv(
+            "AZURE_OPENAI_API_KEY"
+        )
+
+        if api_key:
+            client = AsyncAzureOpenAI(
+                api_key=api_key,
+                azure_endpoint=endpoint,
+                api_version=api_version,
+                default_headers=model.custom_headers or None,
+            )
+        else:
+            try:
+                from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
+            except ImportError:
+                raise BadRequest(
+                    "Provide an API key for Azure OpenAI models or use azure-identity, see. e.g. "
+                    "https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.environmentcredential?view=azure-python"  # noqa: E501
+                )
+
+            client = AsyncAzureOpenAI(
+                azure_ad_token_provider=get_bearer_token_provider(
+                    DefaultAzureCredential(),
+                    "https://cognitiveservices.azure.com/.default",
+                ),
+                azure_endpoint=endpoint,
+                api_version=api_version,
+                default_headers=model.custom_headers or None,
+            )
+        super().__init__(client=client, model=model, credentials=credentials)
+        self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.AZURE.value
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
+
+
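When no `AZURE_OPENAI_API_KEY` is available, the client above falls back to Azure AD authentication via azure-identity. A hedged sketch of that keyless path, with placeholder endpoint and API-version values:

```python
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI

client = AsyncAzureOpenAI(
    azure_ad_token_provider=get_bearer_token_provider(
        DefaultAzureCredential(),
        "https://cognitiveservices.azure.com/.default",
    ),
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder endpoint
    api_version="2024-06-01",                               # placeholder API version
)
```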
+@register_llm_client(
+    provider_key=GenerativeProviderKey.AZURE_OPENAI,
+    model_names=_OPENAI_REASONING_MODELS,
+)
+class AzureOpenAIReasoningNonStreamingClient(
+    OpenAIReasoningReasoningModelsMixin,
+    AzureOpenAIStreamingClient,
+):
+    @override
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        from openai import NOT_GIVEN
+
+        # Convert standard messages to OpenAI messages
+        openai_messages = []
+        for message in messages:
+            openai_message = self.to_openai_chat_completion_param(*message)
+            if openai_message is not None:
+                openai_messages.append(openai_message)
+
+        throttled_create = self.rate_limiter._alimit(self.client.chat.completions.create)
+        response = await throttled_create(
+            messages=openai_messages,
+            model=self.model_name,
+            stream=False,
+            tools=tools or NOT_GIVEN,
+            **invocation_parameters,
+        )
+
+        if response.usage is not None:
+            self._attributes.update(dict(self._llm_token_counts(response.usage)))
+
+        choice = response.choices[0]
+        if choice.message.content:
+            yield TextChunk(content=choice.message.content)
+
+        if choice.message.tool_calls:
+            for tool_call in choice.message.tool_calls:
+                yield ToolCallChunk(
+                    id=tool_call.id,
+                    function=FunctionCallChunk(
+                        name=tool_call.function.name,
+                        arguments=tool_call.function.arguments,
+                    ),
+                )
+
+    def to_openai_chat_completion_param(
+        self,
+        role: ChatCompletionMessageRole,
+        content: JSONScalarType,
+        tool_call_id: Optional[str] = None,
+        tool_calls: Optional[list[JSONScalarType]] = None,
+    ) -> Optional["ChatCompletionMessageParam"]:
+        from openai.types.chat import (
+            ChatCompletionAssistantMessageParam,
+            ChatCompletionDeveloperMessageParam,
+            ChatCompletionToolMessageParam,
+            ChatCompletionUserMessageParam,
+        )
+
+        if role is ChatCompletionMessageRole.USER:
+            return ChatCompletionUserMessageParam(
+                {
+                    "content": content,
+                    "role": "user",
+                }
+            )
+        if role is ChatCompletionMessageRole.SYSTEM:
+            return ChatCompletionDeveloperMessageParam(
+                {
+                    "content": content,
+                    "role": "developer",
+                }
+            )
+        if role is ChatCompletionMessageRole.AI:
+            if tool_calls is None:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                    }
+                )
+            else:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                        "tool_calls": [
+                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                        ],
+                    }
+                )
+        if role is ChatCompletionMessageRole.TOOL:
+            if tool_call_id is None:
+                raise ValueError("tool_call_id is required for tool messages")
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
+        assert_never(role)
+
+
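Although registered alongside the streaming clients, the reasoning client above issues a single non-streaming request and fans the parsed message out as chunks. A simplified standalone sketch of that fan-out (the dataclasses are stand-ins for this module's chunk types, not the real definitions):

```python
from dataclasses import dataclass
from typing import Any, Iterator, Union


@dataclass
class Text:  # stand-in for TextChunk
    content: str


@dataclass
class ToolCall:  # stand-in for ToolCallChunk
    id: str
    name: str
    arguments: str


def fan_out(message: dict[str, Any]) -> Iterator[Union[Text, ToolCall]]:
    # One text chunk for the message body, then one chunk per tool call.
    if message.get("content"):
        yield Text(content=message["content"])
    for call in message.get("tool_calls") or []:
        yield ToolCall(
            id=call["id"],
            name=call["function"]["name"],
            arguments=call["function"]["arguments"],
        )


# Example input shaped like a parsed chat-completion message:
chunks = list(fan_out({
    "content": "Looking that up.",
    "tool_calls": [{"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}}],
}))
```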
+@register_llm_client(
+    provider_key=GenerativeProviderKey.ANTHROPIC,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "claude-3-5-haiku-latest",
+        "claude-3-5-haiku-20241022",
+        "claude-3-haiku-20240307",
+    ],
+)
+class AnthropicStreamingClient(PlaygroundStreamingClient):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
+    ) -> None:
+        import anthropic
+
+        super().__init__(model=model, credentials=credentials)
+        self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.ANTHROPIC.value
+        self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.ANTHROPIC.value
+
+        # Try to get API key from credentials first, then fallback to env
+        api_key = _get_credential_value(credentials, "ANTHROPIC_API_KEY") or getenv(
+            "ANTHROPIC_API_KEY"
+        )
+
+        if not api_key:
+            raise BadRequest("An API key is required for Anthropic models")
+
+        self.client = anthropic.AsyncAnthropic(
+            api_key=api_key,
+            default_headers=model.custom_headers or None,
+        )
+        self.model_name = model.name
+        self.rate_limiter = PlaygroundRateLimiter(model.provider_key, anthropic.RateLimitError)
+        self.client._client = _HttpxClient(self.client._client, self._attributes)
+
+    @classmethod
+    def dependencies(cls) -> list[Dependency]:
+        return [Dependency(name="anthropic")]
+
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            IntInvocationParameter(
+                invocation_name="max_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Tokens",
+                default_value=1024,
+                required=True,
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop_sequences",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_choice",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        import anthropic.lib.streaming as anthropic_streaming
+        import anthropic.types as anthropic_types
+
+        anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
+        anthropic_params = {
+            "messages": anthropic_messages,
+            "model": self.model_name,
+            "system": system_prompt,
             "tools": tools,
             **invocation_parameters,
         }
@@ -796,15 +1580,34 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
         async with await throttled_stream(**anthropic_params) as stream:
             async for event in stream:
                 if isinstance(event, anthropic_types.RawMessageStartEvent):
-
-
-
+                    usage = event.message.usage
+
+                    token_counts: dict[str, Any] = {}
+                    if prompt_tokens := (
+                        (usage.input_tokens or 0)
+                        + (getattr(usage, "cache_creation_input_tokens", 0) or 0)
+                        + (getattr(usage, "cache_read_input_tokens", 0) or 0)
+                    ):
+                        token_counts[LLM_TOKEN_COUNT_PROMPT] = prompt_tokens
+                    if cache_creation_tokens := getattr(usage, "cache_creation_input_tokens", None):
+                        if cache_creation_tokens is not None:
+                            token_counts[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE] = (
+                                cache_creation_tokens
+                            )
+                    self._attributes.update(token_counts)
                 elif isinstance(event, anthropic_streaming.TextEvent):
                     yield TextChunk(content=event.text)
                 elif isinstance(event, anthropic_streaming.MessageStopEvent):
-
-
-
+                    usage = event.message.usage
+                    output_token_counts: dict[str, Any] = {}
+                    if usage.output_tokens:
+                        output_token_counts[LLM_TOKEN_COUNT_COMPLETION] = usage.output_tokens
+                    if cache_read_tokens := getattr(usage, "cache_read_input_tokens", None):
+                        if cache_read_tokens is not None:
+                            output_token_counts[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ] = (
+                                cache_read_tokens
+                            )
+                    self._attributes.update(output_token_counts)
                 elif (
                     isinstance(event, anthropic_streaming.ContentBlockStopEvent)
                     and event.content_block.type == "tool_use"
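The hunk above folds Anthropic's three prompt-side counters into a single prompt total and records the cache figures under separate attributes. The arithmetic, as a minimal sketch:

```python
def prompt_token_total(
    input_tokens: int,
    cache_creation_input_tokens: int = 0,
    cache_read_input_tokens: int = 0,
) -> int:
    # Anthropic reports cache writes and cache reads separately from input_tokens,
    # so the prompt total is the sum of all three.
    return (
        (input_tokens or 0)
        + (cache_creation_input_tokens or 0)
        + (cache_read_input_tokens or 0)
    )


# e.g. 12 fresh input tokens plus 400 tokens written to the prompt cache:
assert prompt_token_total(12, cache_creation_input_tokens=400) == 412
```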
@@ -889,6 +1692,18 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
 @register_llm_client(
     provider_key=GenerativeProviderKey.ANTHROPIC,
     model_names=[
+        "claude-opus-4-5",
+        "claude-opus-4-5-20251101",
+        "claude-sonnet-4-5",
+        "claude-sonnet-4-5-20250929",
+        "claude-haiku-4-5",
+        "claude-haiku-4-5-20251001",
+        "claude-opus-4-1",
+        "claude-opus-4-1-20250805",
+        "claude-sonnet-4-0",
+        "claude-sonnet-4-20250514",
+        "claude-opus-4-0",
+        "claude-opus-4-20250514",
         "claude-3-7-sonnet-latest",
         "claude-3-7-sonnet-20250219",
     ],
@@ -911,7 +1726,6 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
     provider_key=GenerativeProviderKey.GOOGLE,
     model_names=[
         PROVIDER_DEFAULT,
-        "gemini-2.5-pro-preview-03-25",
         "gemini-2.0-flash-lite",
         "gemini-2.0-flash-001",
         "gemini-2.0-flash-thinking-exp-01-21",
@@ -925,21 +1739,31 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
     def __init__(
         self,
         model: GenerativeModelInput,
-
+        credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
-        import google.
+        import google.genai as google_genai
 
-        super().__init__(model=model,
+        super().__init__(model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.GOOGLE.value
         self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.VERTEXAI.value
-
+
+        # Try to get API key from credentials first, then fallback to env
+        api_key = (
+            _get_credential_value(credentials, "GEMINI_API_KEY")
+            or _get_credential_value(credentials, "GOOGLE_API_KEY")
+            or getenv("GEMINI_API_KEY")
+            or getenv("GOOGLE_API_KEY")
+        )
+
+        if not api_key:
             raise BadRequest("An API key is required for Gemini models")
-
+
+        self.client = google_genai.Client(api_key=api_key)
         self.model_name = model.name
 
     @classmethod
     def dependencies(cls) -> list[Dependency]:
-        return [Dependency(name="google-
+        return [Dependency(name="google-genai", module_name="google.genai")]
 
     @classmethod
     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
@@ -976,7 +1800,6 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=1.0,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -984,6 +1807,11 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_k",
                 label="Top K",
             ),
+            JSONInvocationParameter(
+                invocation_name="tool_config",
+                label="Tool Config",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
         ]
 
     async def chat_completion_create(
@@ -994,28 +1822,25 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         tools: list[JSONScalarType],
         **invocation_parameters: Any,
     ) -> AsyncIterator[ChatCompletionChunk]:
-
+        from google.genai import types
 
-
-
-        )
+        contents, system_prompt = self._build_google_messages(messages)
+
+        config_dict = invocation_parameters.copy()
 
-        model_args = {"model_name": self.model_name}
         if system_prompt:
-
-            client = google_genai.GenerativeModel(**model_args)
+            config_dict["system_instruction"] = system_prompt
 
-
-            **
-
-            google_params = {
-                "content": current_message,
-                "generation_config": google_config,
-                "stream": True,
-            }
+        if tools:
+            function_declarations = [types.FunctionDeclaration(**tool) for tool in tools]
+            config_dict["tools"] = [types.Tool(function_declarations=function_declarations)]
 
-
-        stream = await
+        config = types.GenerateContentConfig.model_validate(config_dict)
+        stream = await self.client.aio.models.generate_content_stream(
+            model=f"models/{self.model_name}",
+            contents=contents,
+            config=config,
+        )
         async for event in stream:
             self._attributes.update(
                 {
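The new call path above goes through the google-genai async client. A hedged, self-contained sketch of the same pattern (placeholder model name and prompt; assumes the google-genai package is installed and GEMINI_API_KEY is set):

```python
import asyncio
from os import getenv

from google import genai
from google.genai import types


async def main() -> None:
    client = genai.Client(api_key=getenv("GEMINI_API_KEY"))
    config = types.GenerateContentConfig(system_instruction="Reply in one sentence.")
    # Awaiting the call returns an async iterator of response chunks.
    stream = await client.aio.models.generate_content_stream(
        model="models/gemini-2.5-flash",  # placeholder model name
        contents=[{"role": "user", "parts": [{"text": "What is tracing?"}]}],
        config=config,
    )
    async for event in stream:
        if event.text:
            print(event.text, end="")


asyncio.run(main())
```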
@@ -1024,31 +1849,148 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
                     LLM_TOKEN_COUNT_TOTAL: event.usage_metadata.total_token_count,
                 }
             )
-
+
+            if event.candidates:
+                candidate = event.candidates[0]
+                if candidate.content and candidate.content.parts:
+                    for part in candidate.content.parts:
+                        if function_call := part.function_call:
+                            yield ToolCallChunk(
+                                id=function_call.id or "",
+                                function=FunctionCallChunk(
+                                    name=function_call.name or "",
+                                    arguments=json.dumps(function_call.args or {}),
+                                ),
+                            )
+                        elif text := part.text:
+                            yield TextChunk(content=text)
 
     def _build_google_messages(
         self,
         messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    ) -> tuple[list["ContentType"], str
-
+    ) -> tuple[list["ContentType"], str]:
+        """Build Google messages following the standard pattern - process ALL messages."""
+        google_messages: list["ContentType"] = []
         system_prompts = []
         for role, content, _tool_call_id, _tool_calls in messages:
             if role == ChatCompletionMessageRole.USER:
-
+                google_messages.append({"role": "user", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.AI:
-
+                google_messages.append({"role": "model", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.SYSTEM:
                 system_prompts.append(content)
             elif role == ChatCompletionMessageRole.TOOL:
                 raise NotImplementedError
             else:
                 assert_never(role)
-        if google_message_history:
-            prompt = google_message_history.pop()["parts"]
-        else:
-            prompt = ""
 
-        return
+        return google_messages, "\n".join(system_prompts)
+
+
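`_build_google_messages` maps playground roles onto Gemini's two content roles and lifts system turns into a separate prompt string. A minimal standalone sketch of that mapping (not the module's helper):

```python
def to_google_content(role: str, text: str) -> dict:
    # Gemini content only knows "user" and "model"; assistant turns become "model".
    google_role = "model" if role == "assistant" else "user"
    return {"role": google_role, "parts": [{"text": text}]}


assert to_google_content("assistant", "Hi!")["role"] == "model"
assert to_google_content("user", "Hello")["role"] == "user"
```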
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GOOGLE,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gemini-2.5-pro",
+        "gemini-2.5-flash",
+        "gemini-2.5-flash-lite",
+        "gemini-2.5-pro-preview-03-25",
+    ],
+)
+class Gemini25GoogleStreamingClient(GoogleStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=2.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_output_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Output Tokens",
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop_sequences",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            FloatInvocationParameter(
+                invocation_name="top_k",
+                label="Top K",
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_config",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+        ]
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GOOGLE,
+    model_names=[
+        "gemini-3-pro-preview",
+    ],
+)
+class Gemini3GoogleStreamingClient(Gemini25GoogleStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            StringInvocationParameter(
+                invocation_name="thinking_level",
+                label="Thinking Level",
+                canonical_name=CanonicalParameterName.REASONING_EFFORT,
+            ),
+            *super().supported_invocation_parameters(),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        # Extract thinking_level and construct thinking_config
+        thinking_level = invocation_parameters.pop("thinking_level", None)
+
+        if thinking_level:
+            try:
+                import google.genai
+                from packaging.version import parse as parse_version
+
+                if parse_version(google.genai.__version__) < parse_version("1.50.0"):
+                    raise ImportError
+            except (ImportError, AttributeError):
+                raise BadRequest(
+                    "Reasoning capabilities for Gemini models require `google-genai>=1.50.0` "
+                    "and Python >= 3.10."
+                )
+
+            # NOTE: as of gemini 1.51.0 medium thinking is not supported
+            # but will eventually be added in a future version
+            # we are purposefully allowing users to select medium knowing
+            # it does not work.
+            invocation_parameters["thinking_config"] = {
+                "include_thoughts": True,
+                "thinking_level": thinking_level.upper(),
+            }
+
+        async for chunk in super().chat_completion_create(messages, tools, **invocation_parameters):
+            yield chunk
 
 
 def initialize_playground_clients() -> None:
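Before delegating to the parent client, `Gemini3GoogleStreamingClient` rewrites the playground-level `thinking_level` parameter into google-genai's `thinking_config`. The translation in isolation, as a minimal sketch:

```python
from typing import Any


def with_thinking_config(invocation_parameters: dict[str, Any]) -> dict[str, Any]:
    params = dict(invocation_parameters)
    if thinking_level := params.pop("thinking_level", None):
        # include_thoughts surfaces the model's reasoning; the level is upper-cased for the API.
        params["thinking_config"] = {
            "include_thoughts": True,
            "thinking_level": thinking_level.upper(),
        }
    return params


assert (
    with_thinking_config({"thinking_level": "high"})["thinking_config"]["thinking_level"] == "HIGH"
)
```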
@@ -1063,6 +2005,15 @@ LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
 LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
 LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
 LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
+LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = (
+    SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE
+)
+LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO
+LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = (
+    SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
+)
+LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO
 
 
 class _HttpxClient(wrapt.ObjectProxy):  # type: ignore