arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/dataloaders/table_fields.py
CHANGED

@@ -18,7 +18,7 @@ _AttrStrIdentifier: TypeAlias = str
 
 
 class TableFieldsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: DbSessionFactory, table: type[models.
+    def __init__(self, db: DbSessionFactory, table: type[models.HasId]) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
         self._table = table
@@ -37,7 +37,7 @@ class TableFieldsDataLoader(DataLoader[Key, Result]):
 
     def _get_stmt(
         keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
-        table: type[models.
+        table: type[models.HasId],
     ) -> tuple[
         Select[Any],
        dict[_ResultColumnPosition, _AttrStrIdentifier],
phoenix/server/api/dataloaders/token_prices_by_model.py
ADDED

@@ -0,0 +1,30 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ModelId: TypeAlias = int
+Key: TypeAlias = ModelId
+Result: TypeAlias = list[models.TokenPrice]
+
+
+class TokenPricesByModelDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        model_ids = keys
+        token_prices: defaultdict[Key, Result] = defaultdict(list)
+
+        async with self._db() as session:
+            async for token_price in await session.stream_scalars(
+                select(models.TokenPrice).where(models.TokenPrice.model_id.in_(model_ids))
+            ):
+                token_prices[token_price.model_id].append(token_price)
+
+        return [token_prices[model_id] for model_id in keys]
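Note: the two new dataloaders above and below (`TokenPricesByModelDataLoader`, `TraceAnnotationsByTraceDataLoader`) follow the same pattern: Strawberry's `DataLoader` coalesces all keys requested during a resolver pass into a single `_load_fn` call, turning N per-field lookups into one batched query, with results returned positionally per key. A minimal, self-contained sketch of that contract, using in-memory rows in place of the real SQLAlchemy query:

```python
import asyncio
from collections import defaultdict

from strawberry.dataloader import DataLoader

# Stand-in rows; the real loader streams models.TokenPrice from the database.
FAKE_PRICES = [
    {"model_id": 1, "kind": "input", "usd_per_token": 3e-06},
    {"model_id": 1, "kind": "output", "usd_per_token": 1.5e-05},
    {"model_id": 2, "kind": "input", "usd_per_token": 2.5e-07},
]


async def load_token_prices(keys: list[int]) -> list[list[dict]]:
    # One batched "query" regardless of how many keys were requested.
    by_model: defaultdict[int, list[dict]] = defaultdict(list)
    for row in FAKE_PRICES:
        if row["model_id"] in keys:
            by_model[row["model_id"]].append(row)
    # DataLoader contract: results align positionally with keys, and a
    # model with no prices yields an empty list rather than an error.
    return [by_model[k] for k in keys]


async def main() -> None:
    loader = DataLoader(load_fn=load_token_prices)
    # Both awaits are served by a single load_fn invocation.
    prices_1, prices_2 = await asyncio.gather(loader.load(1), loader.load(2))
    print(len(prices_1), len(prices_2))  # 2 1


asyncio.run(main())
```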
phoenix/server/api/dataloaders/trace_annotations_by_trace.py
ADDED

@@ -0,0 +1,27 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db.models import TraceAnnotation
+from phoenix.server.types import DbSessionFactory
+
+TraceRowId: TypeAlias = int
+Key: TypeAlias = TraceRowId
+Result: TypeAlias = list[TraceAnnotation]
+
+
+class TraceAnnotationsByTraceDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
+        async with self._db() as session:
+            async for annotation in await session.stream_scalars(
+                select(TraceAnnotation).where(TraceAnnotation.trace_rowid.in_(keys))
+            ):
+                annotations_by_id[annotation.trace_rowid].append(annotation)
+        return [annotations_by_id[key] for key in keys]
phoenix/server/api/exceptions.py
CHANGED

@@ -1,6 +1,8 @@
 from graphql.error import GraphQLError
 from strawberry.extensions import MaskErrors
 
+from phoenix.config import get_env_mask_internal_server_errors
+
 
 class CustomGraphQLError(Exception):
     """
@@ -51,4 +53,6 @@ def _should_mask_error(error: GraphQLError) -> bool:
     """
     Masks unexpected errors raised from GraphQL resolvers.
     """
-    return not isinstance(
+    return get_env_mask_internal_server_errors() and not isinstance(
+        error.original_error, CustomGraphQLError
+    )
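Note: `_should_mask_error` is the predicate handed to Strawberry's `MaskErrors` extension, so this change makes masking of internal server errors switchable via configuration rather than unconditional. A hedged sketch of how such a predicate wires into a schema (the `Query` type here is illustrative, not Phoenix's actual wiring):

```python
import strawberry
from graphql.error import GraphQLError
from strawberry.extensions import MaskErrors


class CustomGraphQLError(Exception):
    """Errors that are safe to surface to API consumers verbatim."""


def should_mask_error(error: GraphQLError) -> bool:
    # Mask everything except deliberately raised, user-facing errors;
    # Phoenix additionally gates this on an environment variable.
    return not isinstance(error.original_error, CustomGraphQLError)


@strawberry.type
class Query:
    @strawberry.field
    def boom(self) -> str:
        raise RuntimeError("internal detail that should not leak")


schema = strawberry.Schema(
    query=Query,
    extensions=[MaskErrors(should_mask_error=should_mask_error)],
)

result = schema.execute_sync("{ boom }")
# With masking on, the client sees MaskErrors' generic message
# instead of the RuntimeError text.
print(result.errors[0].message)
```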
phoenix/server/api/helpers/playground_clients.py
CHANGED

@@ -57,6 +57,7 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 if TYPE_CHECKING:
     import httpx
     from anthropic.types import MessageParam, TextBlockParam, ToolResultBlockParam
+    from botocore.awsrequest import AWSPreparedRequest  # type: ignore[import-untyped]
     from google.generativeai.types import ContentType
     from openai import AsyncAzureOpenAI, AsyncOpenAI
     from openai.types import CompletionUsage
@@ -308,7 +309,6 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
             invocation_name="top_p",
             canonical_name=CanonicalParameterName.TOP_P,
             label="Top P",
-            default_value=1.0,
             min_value=0.0,
             max_value=1.0,
         ),
@@ -327,6 +327,10 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
             label="Response Format",
             canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
         ),
+        JSONInvocationParameter(
+            invocation_name="extra_body",
+            label="Extra Body",
+        ),
     ]
 
     async def chat_completion_create(
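Note: the new `extra_body` invocation parameter lines up with the escape hatch the OpenAI Python SDK already exposes: any mapping passed as `extra_body` is merged into the JSON request payload, which is how OpenAI-compatible backends receive fields the SDK does not model. A hedged sketch (field name is illustrative):

```python
from openai import AsyncOpenAI

client = AsyncOpenAI(api_key="sk-...")


async def create_with_extra_body() -> None:
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        # Merged into the request body as a top-level field.
        extra_body={"custom_field": "value"},  # illustrative field name
    )
    async for chunk in stream:
        ...
```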
@@ -543,7 +547,11 @@ class DeepSeekStreamingClient(OpenAIBaseStreamingClient):
             raise BadRequest("An API key is required for DeepSeek models")
         api_key = "sk-fake-api-key"
 
-        client = AsyncOpenAI(
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url or "https://api.deepseek.com",
+            default_headers=model.custom_headers or None,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         # DeepSeek uses OpenAI-compatible API but we'll track it as a separate provider
         # Adding a custom "deepseek" provider value to make it distinguishable in traces
@@ -581,7 +589,11 @@ class XAIStreamingClient(OpenAIBaseStreamingClient):
             raise BadRequest("An API key is required for xAI models")
         api_key = "sk-fake-api-key"
 
-        client = AsyncOpenAI(
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url or "https://api.x.ai/v1",
+            default_headers=model.custom_headers or None,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         # xAI uses OpenAI-compatible API but we'll track it as a separate provider
         # Adding a custom "xai" provider value to make it distinguishable in traces
@@ -618,7 +630,11 @@ class OllamaStreamingClient(OpenAIBaseStreamingClient):
         if not base_url:
             raise BadRequest("An Ollama base URL is required for Ollama models")
         api_key = "ollama"
-        client = AsyncOpenAI(
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            default_headers=model.custom_headers or None,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         # Ollama uses OpenAI-compatible API but we'll track it as a separate provider
         # Adding a custom "ollama" provider value to make it distinguishable in traces
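Note: DeepSeek, xAI, and Ollama all reuse `AsyncOpenAI` against an OpenAI-compatible endpoint; the recurring change here threads `model.custom_headers` through as `default_headers`, which the SDK attaches to every request. A hedged sketch of the pattern (URL and header are illustrative):

```python
from openai import AsyncOpenAI

# Any OpenAI-compatible server can be targeted by swapping base_url;
# default_headers ride along on every request the client sends.
client = AsyncOpenAI(
    api_key="ollama",  # Ollama ignores the key, but the SDK requires one
    base_url="http://localhost:11434/v1",
    default_headers={"X-Proxy-Token": "example"},  # illustrative header
)
```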
@@ -630,13 +646,17 @@ class OllamaStreamingClient(OpenAIBaseStreamingClient):
     provider_key=GenerativeProviderKey.AWS,
     model_names=[
         PROVIDER_DEFAULT,
-        "anthropic.claude-
+        "anthropic.claude-opus-4-5-20251101-v1:0",
+        "anthropic.claude-sonnet-4-5-20250929-v1:0",
+        "anthropic.claude-haiku-4-5-20251001-v1:0",
+        "anthropic.claude-opus-4-1-20250805-v1:0",
+        "anthropic.claude-opus-4-20250514-v1:0",
+        "anthropic.claude-sonnet-4-20250514-v1:0",
         "anthropic.claude-3-7-sonnet-20250219-v1:0",
-        "anthropic.claude-3-haiku-20240307-v1:0",
         "anthropic.claude-3-5-sonnet-20241022-v2:0",
+        "anthropic.claude-3-5-sonnet-20240620-v1:0",
         "anthropic.claude-3-5-haiku-20241022-v1:0",
-        "anthropic.claude-
-        "anthropic.claude-sonnet-4-20250514-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
         "amazon.titan-embed-text-v2:0",
         "amazon.nova-pro-v1:0",
         "amazon.nova-premier-v1:0:8k",
@@ -671,29 +691,45 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         import boto3  # type: ignore[import-untyped]
 
         super().__init__(model=model, credentials=credentials)
-
+        region = model.region or "us-east-1"
         self.api = "converse"
-
+        custom_headers = model.custom_headers
+        aws_access_key_id = _get_credential_value(credentials, "AWS_ACCESS_KEY_ID") or getenv(
             "AWS_ACCESS_KEY_ID"
         )
-
+        aws_secret_access_key = _get_credential_value(
             credentials, "AWS_SECRET_ACCESS_KEY"
         ) or getenv("AWS_SECRET_ACCESS_KEY")
-
+        aws_session_token = _get_credential_value(credentials, "AWS_SESSION_TOKEN") or getenv(
             "AWS_SESSION_TOKEN"
         )
         self.model_name = model.name
-
-
-
-
-
-            aws_session_token=self.aws_session_token,
+        session = boto3.Session(
+            region_name=region,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
         )
+        client = session.client(service_name="bedrock-runtime")
+
+        # Add custom headers support via boto3 event system
+        if custom_headers:
+
+            def add_custom_headers(request: "AWSPreparedRequest", **kwargs: Any) -> None:
+                request.headers.update(custom_headers)
+
+            client.meta.events.register("before-send.*", add_custom_headers)
 
+        self.client = client
         self._attributes[LLM_PROVIDER] = "aws"
         self._attributes[LLM_SYSTEM] = "aws"
 
+    @staticmethod
+    def _setup_custom_headers(client: Any, custom_headers: Mapping[str, str]) -> None:
+        """Setup custom headers using boto3's event system."""
+        if not custom_headers:
+            return
+
     @classmethod
     def dependencies(cls) -> list[Dependency]:
         return [Dependency(name="boto3")]
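Note: boto3 has no `default_headers` option, so the rewritten Bedrock client injects custom headers through botocore's event system instead, mutating each prepared request just before it goes on the wire. A hedged sketch of that mechanism in isolation (header values are illustrative; real calls need AWS credentials):

```python
from typing import Any

import boto3

custom_headers = {"X-Proxy-Token": "example"}  # illustrative

session = boto3.Session(region_name="us-east-1")
client = session.client(service_name="bedrock-runtime")


def add_custom_headers(request: Any, **kwargs: Any) -> None:
    # "before-send" fires with the fully prepared HTTP request,
    # immediately before it is transmitted.
    request.headers.update(custom_headers)


# The "before-send.*" pattern matches every operation on this client.
client.meta.events.register("before-send.*", add_custom_headers)
```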
@@ -719,7 +755,6 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
             invocation_name="top_p",
             canonical_name=CanonicalParameterName.TOP_P,
             label="Top P",
-            default_value=1.0,
             min_value=0.0,
             max_value=1.0,
         ),
@@ -738,18 +773,6 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         tools: list[JSONScalarType],
         **invocation_parameters: Any,
     ) -> AsyncIterator[ChatCompletionChunk]:
-        import boto3
-
-        if (
-            self.client.meta.region_name != self.region
-        ):  # override the region if it's different from the default
-            self.client = boto3.client(
-                "bedrock-runtime",
-                region_name=self.region,
-                aws_access_key_id=self.aws_access_key_id,
-                aws_secret_access_key=self.aws_secret_access_key,
-                aws_session_token=self.aws_session_token,
-            )
         if self.api == "invoke":
             async for chunk in self._handle_invoke_api(messages, tools, invocation_parameters):
                 yield chunk
@@ -771,15 +794,25 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         # Build messages in Converse API format
         converse_messages = self._build_converse_messages(messages)
 
+        inference_config = {}
+        if (
+            "max_tokens" in invocation_parameters
+            and invocation_parameters["max_tokens"] is not None
+        ):
+            inference_config["maxTokens"] = invocation_parameters["max_tokens"]
+        if (
+            "temperature" in invocation_parameters
+            and invocation_parameters["temperature"] is not None
+        ):
+            inference_config["temperature"] = invocation_parameters["temperature"]
+        if "top_p" in invocation_parameters and invocation_parameters["top_p"] is not None:
+            inference_config["topP"] = invocation_parameters["top_p"]
+
         # Build the request parameters for Converse API
         converse_params: dict[str, Any] = {
-            "modelId":
+            "modelId": self.model_name,
             "messages": converse_messages,
-            "inferenceConfig":
-            "maxTokens": invocation_parameters["max_tokens"],
-            "temperature": invocation_parameters["temperature"],
-            "topP": invocation_parameters["top_p"],
-            },
+            "inferenceConfig": inference_config,
         }
 
         # Add system prompt if available
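Note: the Converse path now adds `maxTokens`, `temperature`, and `topP` only when the user actually set them, so unset parameters fall back to Bedrock's server-side model defaults instead of being sent explicitly. Assuming the `"converse"` api string ultimately maps onto boto3's `converse_stream` call (the streaming entry point for this API), the resulting request shape looks roughly like this (model id and values are illustrative):

```python
import boto3

client = boto3.Session(region_name="us-east-1").client("bedrock-runtime")

user_params = {"max_tokens": 1024, "temperature": None, "top_p": None}
inference_config = {}
if user_params.get("max_tokens") is not None:
    inference_config["maxTokens"] = user_params["max_tokens"]
# temperature / top_p are omitted entirely rather than sent as nulls.

response = client.converse_stream(
    modelId="anthropic.claude-3-5-haiku-20241022-v1:0",
    messages=[{"role": "user", "content": [{"text": "hello"}]}],
    inferenceConfig=inference_config,
)
for event in response["stream"]:
    ...  # text deltas, tool-use events, usage metadata
```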
@@ -912,16 +945,26 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         bedrock_messages, system_prompt = self._build_bedrock_messages(messages)
         bedrock_params = {
             "anthropic_version": "bedrock-2023-05-31",
-            "max_tokens": invocation_parameters["max_tokens"],
             "messages": bedrock_messages,
             "system": system_prompt,
-            "temperature": invocation_parameters["temperature"],
-            "top_p": invocation_parameters["top_p"],
             "tools": tools,
         }
 
+        if (
+            "max_tokens" in invocation_parameters
+            and invocation_parameters["max_tokens"] is not None
+        ):
+            bedrock_params["max_tokens"] = invocation_parameters["max_tokens"]
+        if (
+            "temperature" in invocation_parameters
+            and invocation_parameters["temperature"] is not None
+        ):
+            bedrock_params["temperature"] = invocation_parameters["temperature"]
+        if "top_p" in invocation_parameters and invocation_parameters["top_p"] is not None:
+            bedrock_params["top_p"] = invocation_parameters["top_p"]
+
         response = self.client.invoke_model_with_response_stream(
-            modelId=
+            modelId=self.model_name,
             contentType="application/json",
             accept="application/json",
             body=json.dumps(bedrock_params),
@@ -1134,13 +1177,24 @@ class OpenAIStreamingClient(OpenAIBaseStreamingClient):
             raise BadRequest("An API key is required for OpenAI models")
         api_key = "sk-fake-api-key"
 
-        client = AsyncOpenAI(
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            default_headers=model.custom_headers or None,
+            timeout=30,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.OPENAI.value
         self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
 
 
 _OPENAI_REASONING_MODELS = [
+    "gpt-5.2",
+    "gpt-5.2-2025-12-11",
+    "gpt-5.2-chat-latest",
+    "gpt-5.1",
+    "gpt-5.1-2025-11-13",
+    "gpt-5.1-chat-latest",
     "gpt-5",
     "gpt-5-mini",
     "gpt-5-nano",
@@ -1194,6 +1248,10 @@ class OpenAIReasoningReasoningModelsMixin:
             label="Response Format",
             canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
         ),
+        JSONInvocationParameter(
+            invocation_name="extra_body",
+            label="Extra Body",
+        ),
     ]
 
 
@@ -1289,6 +1347,7 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
                 api_key=api_key,
                 azure_endpoint=endpoint,
                 api_version=api_version,
+                default_headers=model.custom_headers or None,
             )
         else:
             try:
@@ -1306,6 +1365,7 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
                 ),
                 azure_endpoint=endpoint,
                 api_version=api_version,
+                default_headers=model.custom_headers or None,
             )
         super().__init__(client=client, model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.AZURE.value
@@ -1423,13 +1483,8 @@ class AzureOpenAIReasoningNonStreamingClient(
     provider_key=GenerativeProviderKey.ANTHROPIC,
     model_names=[
         PROVIDER_DEFAULT,
-        "claude-3-5-sonnet-latest",
         "claude-3-5-haiku-latest",
-        "claude-3-5-sonnet-20241022",
         "claude-3-5-haiku-20241022",
-        "claude-3-5-sonnet-20240620",
-        "claude-3-opus-latest",
-        "claude-3-sonnet-20240229",
         "claude-3-haiku-20240307",
     ],
 )
@@ -1453,7 +1508,10 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
         if not api_key:
             raise BadRequest("An API key is required for Anthropic models")
 
-        self.client = anthropic.AsyncAnthropic(
+        self.client = anthropic.AsyncAnthropic(
+            api_key=api_key,
+            default_headers=model.custom_headers or None,
+        )
         self.model_name = model.name
         self.rate_limiter = PlaygroundRateLimiter(model.provider_key, anthropic.RateLimitError)
         self.client._client = _HttpxClient(self.client._client, self._attributes)
@@ -1489,7 +1547,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
             invocation_name="top_p",
             canonical_name=CanonicalParameterName.TOP_P,
             label="Top P",
-            default_value=1.0,
             min_value=0.0,
             max_value=1.0,
         ),
@@ -1635,10 +1692,16 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
 @register_llm_client(
     provider_key=GenerativeProviderKey.ANTHROPIC,
     model_names=[
-        "claude-
-        "claude-
+        "claude-opus-4-5",
+        "claude-opus-4-5-20251101",
+        "claude-sonnet-4-5",
+        "claude-sonnet-4-5-20250929",
+        "claude-haiku-4-5",
+        "claude-haiku-4-5-20251001",
         "claude-opus-4-1",
         "claude-opus-4-1-20250805",
+        "claude-sonnet-4-0",
+        "claude-sonnet-4-20250514",
         "claude-opus-4-0",
         "claude-opus-4-20250514",
         "claude-3-7-sonnet-latest",
@@ -1663,7 +1726,6 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
     provider_key=GenerativeProviderKey.GOOGLE,
     model_names=[
         PROVIDER_DEFAULT,
-        "gemini-2.5-pro-preview-03-25",
         "gemini-2.0-flash-lite",
         "gemini-2.0-flash-001",
         "gemini-2.0-flash-thinking-exp-01-21",
@@ -1679,7 +1741,7 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         model: GenerativeModelInput,
         credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
-        import google.
+        import google.genai as google_genai
 
         super().__init__(model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.GOOGLE.value
@@ -1696,12 +1758,12 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         if not api_key:
             raise BadRequest("An API key is required for Gemini models")
 
-        google_genai.
+        self.client = google_genai.Client(api_key=api_key)
         self.model_name = model.name
 
     @classmethod
     def dependencies(cls) -> list[Dependency]:
-        return [Dependency(name="google-
+        return [Dependency(name="google-genai", module_name="google.genai")]
 
     @classmethod
     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
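Note: this hunk and the ones around it migrate the Google client from the legacy `google-generativeai` SDK (`GenerativeModel` / `start_chat` / `send_message_async`) to the newer `google-genai` package, where a single `Client` carries sync, async (`client.aio`), and streaming surfaces. A hedged sketch of the new SDK's basic streaming call:

```python
from google import genai
from google.genai import types

client = genai.Client(api_key="...")


async def stream_reply() -> None:
    config = types.GenerateContentConfig(
        temperature=0.7,
        system_instruction="Answer briefly.",
    )
    # client.aio exposes the async API; the non-streaming equivalent
    # is client.models.generate_content(...).
    stream = await client.aio.models.generate_content_stream(
        model="models/gemini-2.5-flash",
        contents=[{"role": "user", "parts": [{"text": "hello"}]}],
        config=config,
    )
    async for event in stream:
        if event.candidates and event.candidates[0].content:
            ...  # inspect text and function_call parts
```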
@@ -1738,7 +1800,6 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
             invocation_name="top_p",
             canonical_name=CanonicalParameterName.TOP_P,
             label="Top P",
-            default_value=1.0,
             min_value=0.0,
             max_value=1.0,
         ),
@@ -1746,6 +1807,11 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
             invocation_name="top_k",
             label="Top K",
         ),
+        JSONInvocationParameter(
+            invocation_name="tool_config",
+            label="Tool Config",
+            canonical_name=CanonicalParameterName.TOOL_CHOICE,
+        ),
     ]
 
     async def chat_completion_create(
@@ -1756,28 +1822,25 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         tools: list[JSONScalarType],
         **invocation_parameters: Any,
     ) -> AsyncIterator[ChatCompletionChunk]:
-
+        from google.genai import types
 
-
-
-        )
+        contents, system_prompt = self._build_google_messages(messages)
+
+        config_dict = invocation_parameters.copy()
 
-        model_args = {"model_name": self.model_name}
         if system_prompt:
-
-            client = google_genai.GenerativeModel(**model_args)
+            config_dict["system_instruction"] = system_prompt
 
-
-            **
+        if tools:
+            function_declarations = [types.FunctionDeclaration(**tool) for tool in tools]
+            config_dict["tools"] = [types.Tool(function_declarations=function_declarations)]
+
+        config = types.GenerateContentConfig.model_validate(config_dict)
+        stream = await self.client.aio.models.generate_content_stream(
+            model=f"models/{self.model_name}",
+            contents=contents,
+            config=config,
         )
-        google_params = {
-            "content": current_message,
-            "generation_config": google_config,
-            "stream": True,
-        }
-
-        chat = client.start_chat(history=google_message_history)
-        stream = await chat.send_message_async(**google_params)
         async for event in stream:
             self._attributes.update(
                 {
@@ -1786,31 +1849,148 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
                     LLM_TOKEN_COUNT_TOTAL: event.usage_metadata.total_token_count,
                 }
             )
-
+
+            if event.candidates:
+                candidate = event.candidates[0]
+                if candidate.content and candidate.content.parts:
+                    for part in candidate.content.parts:
+                        if function_call := part.function_call:
+                            yield ToolCallChunk(
+                                id=function_call.id or "",
+                                function=FunctionCallChunk(
+                                    name=function_call.name or "",
+                                    arguments=json.dumps(function_call.args or {}),
+                                ),
+                            )
+                        elif text := part.text:
+                            yield TextChunk(content=text)
 
     def _build_google_messages(
         self,
         messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    ) -> tuple[list["ContentType"], str
-
+    ) -> tuple[list["ContentType"], str]:
+        """Build Google messages following the standard pattern - process ALL messages."""
+        google_messages: list["ContentType"] = []
         system_prompts = []
         for role, content, _tool_call_id, _tool_calls in messages:
             if role == ChatCompletionMessageRole.USER:
-
+                google_messages.append({"role": "user", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.AI:
-
+                google_messages.append({"role": "model", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.SYSTEM:
                 system_prompts.append(content)
             elif role == ChatCompletionMessageRole.TOOL:
                 raise NotImplementedError
             else:
                 assert_never(role)
-        if google_message_history:
-            prompt = google_message_history.pop()["parts"]
-        else:
-            prompt = ""
 
-        return
+        return google_messages, "\n".join(system_prompts)
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GOOGLE,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gemini-2.5-pro",
+        "gemini-2.5-flash",
+        "gemini-2.5-flash-lite",
+        "gemini-2.5-pro-preview-03-25",
+    ],
+)
+class Gemini25GoogleStreamingClient(GoogleStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=2.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_output_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Output Tokens",
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop_sequences",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            FloatInvocationParameter(
+                invocation_name="top_k",
+                label="Top K",
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_config",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+        ]
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GOOGLE,
+    model_names=[
+        "gemini-3-pro-preview",
+    ],
+)
+class Gemini3GoogleStreamingClient(Gemini25GoogleStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            StringInvocationParameter(
+                invocation_name="thinking_level",
+                label="Thinking Level",
+                canonical_name=CanonicalParameterName.REASONING_EFFORT,
+            ),
+            *super().supported_invocation_parameters(),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        # Extract thinking_level and construct thinking_config
+        thinking_level = invocation_parameters.pop("thinking_level", None)
+
+        if thinking_level:
+            try:
+                import google.genai
+                from packaging.version import parse as parse_version
+
+                if parse_version(google.genai.__version__) < parse_version("1.50.0"):
+                    raise ImportError
+            except (ImportError, AttributeError):
+                raise BadRequest(
+                    "Reasoning capabilities for Gemini models require `google-genai>=1.50.0` "
+                    "and Python >= 3.10."
+                )
+
+            # NOTE: as of gemini 1.51.0 medium thinking is not supported
+            # but will eventually be added in a future version
+            # we are purposefully allowing users to select medium knowing
+            # it does not work.
+            invocation_parameters["thinking_config"] = {
+                "include_thoughts": True,
+                "thinking_level": thinking_level.upper(),
+            }
+
+        async for chunk in super().chat_completion_create(messages, tools, **invocation_parameters):
+            yield chunk
 
 
 def initialize_playground_clients() -> None: