arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union
|
|
2
2
|
|
|
3
3
|
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
|
|
4
4
|
|
|
@@ -59,7 +59,7 @@ PLAYGROUND_CLIENT_REGISTRY: PlaygroundClientRegistry = PlaygroundClientRegistry(
|
|
|
59
59
|
|
|
60
60
|
def register_llm_client(
|
|
61
61
|
provider_key: GenerativeProviderKey,
|
|
62
|
-
model_names:
|
|
62
|
+
model_names: Sequence[ModelName],
|
|
63
63
|
) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
|
|
64
64
|
def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
|
|
65
65
|
provider_registry = PLAYGROUND_CLIENT_REGISTRY._registry.setdefault(provider_key, {})
|
|
@@ -222,6 +222,7 @@ def get_db_experiment_run(
|
|
|
222
222
|
*,
|
|
223
223
|
experiment_id: int,
|
|
224
224
|
example_id: int,
|
|
225
|
+
repetition_number: int,
|
|
225
226
|
) -> models.ExperimentRun:
|
|
226
227
|
return models.ExperimentRun(
|
|
227
228
|
experiment_id=experiment_id,
|
|
@@ -230,7 +231,7 @@ def get_db_experiment_run(
|
|
|
230
231
|
output=models.ExperimentRunOutput(
|
|
231
232
|
task_output=get_dataset_example_output(db_span),
|
|
232
233
|
),
|
|
233
|
-
repetition_number=
|
|
234
|
+
repetition_number=repetition_number,
|
|
234
235
|
start_time=db_span.start_time,
|
|
235
236
|
end_time=db_span.end_time,
|
|
236
237
|
error=db_span.status_message or None,
|
|
@@ -263,10 +264,13 @@ def llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
|
|
|
263
264
|
def input_value_and_mime_type(
    input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
) -> Iterator[tuple[str, Any]]:
    """Yield the span input attributes (MIME type, then JSON value) for a chat request.

    The request is serialized with ``jsonify`` and the ``api_key``, ``credentials``
    and ``invocation_parameters`` entries are dropped before the value is emitted,
    so secrets are not written into trace data.
    """
    # Top-level keys that must never be recorded on a span.
    redacted_keys = {"api_key", "credentials", "invocation_parameters"}
    payload = jsonify(input)
    sanitized = {key: value for key, value in payload.items() if key not in redacted_keys}
    # Defensive check that credential material was stripped.
    assert "api_key" not in sanitized
    assert "credentials" not in sanitized
    yield INPUT_MIME_TYPE, JSON
    yield INPUT_VALUE, safe_json_dumps(sanitized)
|
|
272
276
|
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
Optional,
|
|
3
|
+
)
|
|
4
|
+
|
|
5
|
+
from starlette.requests import Request
|
|
6
|
+
from strawberry import Info
|
|
7
|
+
|
|
8
|
+
from phoenix.server.api.context import Context
|
|
9
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_user(info: Info[Context, None]) -> Optional[int]:
    """Best-effort lookup of the ID of the user issuing a playground request.

    Args:
        info: Strawberry resolver info whose context carries the request and user.

    Returns:
        The user's integer ID when an authenticated user can be resolved,
        otherwise ``None``. Errors raised while resolving the user (for example
        when authentication is not available) also yield ``None``.
    """
    try:
        try:
            # Explicit isinstance check instead of `assert`: an assert that binds
            # `request` via the walrus operator is stripped under `python -O`,
            # which would leave `request` unbound.
            request = info.context.request
            if isinstance(request, Request):
                if "user" in request.scope and isinstance(
                    (user := info.context.user), PhoenixUser
                ):
                    return int(user.identity)
                return None
        except AssertionError:
            # Resolving the request/user asserted (e.g. auth is not available);
            # fall through to the context-user lookup below.
            pass
        # No starlette Request is available: try to obtain the user identity
        # directly from the context.
        if info.context.user.is_authenticated:
            return int(info.context.user.identity)
        return None
    except Exception:
        # Deliberately best-effort: any error while resolving the user yields None.
        # Catching `Exception` (rather than `return` inside `finally`) keeps
        # KeyboardInterrupt/SystemExit propagating.
        return None
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
4
|
+
|
|
5
|
+
from typing_extensions import assert_never
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from anthropic.types import (
|
|
9
|
+
ToolChoiceAnyParam,
|
|
10
|
+
ToolChoiceAutoParam,
|
|
11
|
+
ToolChoiceParam,
|
|
12
|
+
ToolChoiceToolParam,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
16
|
+
PromptToolChoiceNone,
|
|
17
|
+
PromptToolChoiceOneOrMore,
|
|
18
|
+
PromptToolChoiceSpecificFunctionTool,
|
|
19
|
+
PromptToolChoiceZeroOrMore,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AwsToolChoiceConversion:
    """Convert between Phoenix prompt tool-choice models and ``ToolChoiceParam`` dicts.

    The dict side is typed with ``anthropic.types`` tool-choice TypedDicts.
    """

    @staticmethod
    def to_aws(
        obj: Union[
            PromptToolChoiceNone,
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
        disable_parallel_tool_use: Optional[bool] = None,
    ) -> ToolChoiceParam:
        """Translate a Phoenix tool choice into a ``ToolChoiceParam`` dict.

        Mapping: none -> "none", zero_or_more -> "auto", one_or_more -> "any",
        specific_function -> "tool" (with ``name``). ``disable_parallel_tool_use``
        is attached to the "auto", "any" and "tool" choices only when it is not
        None, and never to "none".
        """
        if obj.type == "none":
            return {"type": "none"}
        if obj.type == "specific_function":
            tool_choice: ToolChoiceToolParam = {"type": "tool", "name": obj.function_name}
            if disable_parallel_tool_use is not None:
                tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use
            return tool_choice
        if obj.type == "one_or_more":
            any_choice: ToolChoiceAnyParam = {"type": "any"}
            if disable_parallel_tool_use is not None:
                any_choice["disable_parallel_tool_use"] = disable_parallel_tool_use
            return any_choice
        if obj.type == "zero_or_more":
            auto_choice: ToolChoiceAutoParam = {"type": "auto"}
            if disable_parallel_tool_use is not None:
                auto_choice["disable_parallel_tool_use"] = disable_parallel_tool_use
            return auto_choice
        assert_never(obj.type)

    @staticmethod
    def from_aws(
        obj: ToolChoiceParam,
    ) -> Union[
        PromptToolChoiceNone,
        PromptToolChoiceZeroOrMore,
        PromptToolChoiceOneOrMore,
        PromptToolChoiceSpecificFunctionTool,
    ]:
        """Translate a ``ToolChoiceParam`` dict back into a Phoenix tool choice.

        Any ``disable_parallel_tool_use`` flag on the input is not carried over.
        """
        # The module-level import of these models only runs under TYPE_CHECKING,
        # so they are imported here for runtime use.
        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceNone,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        if obj["type"] == "none":
            return PromptToolChoiceNone(type="none")
        if obj["type"] == "tool":
            return PromptToolChoiceSpecificFunctionTool(
                type="specific_function",
                function_name=obj["name"],
            )
        if obj["type"] == "any":
            return PromptToolChoiceOneOrMore(type="one_or_more")
        if obj["type"] == "auto":
            return PromptToolChoiceZeroOrMore(type="zero_or_more")
        assert_never(obj)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Literal, Union
|
|
2
|
+
|
|
3
|
+
from typing_extensions import NotRequired, TypedDict, assert_never
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
7
|
+
PromptToolChoiceNone,
|
|
8
|
+
PromptToolChoiceOneOrMore,
|
|
9
|
+
PromptToolChoiceSpecificFunctionTool,
|
|
10
|
+
PromptToolChoiceZeroOrMore,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GoogleFunctionCallingConfig(TypedDict, total=False):
    """
    Based on https://github.com/googleapis/python-genai/blob/97cc7e4eafbee4fa4035e7420170ab6a2c9da7fb/google/genai/types.py#L4245

    Both keys are optional: ``total=False`` already makes every key
    not-required, so no per-field ``NotRequired`` wrapper is needed.
    """

    # Lowercase function-calling mode.
    mode: Literal["auto", "any", "none"]
    # Function names the model is restricted to.
    allowed_function_names: list[str]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GoogleToolChoice(TypedDict):
    """
    Based on https://github.com/googleapis/python-genai/blob/97cc7e4eafbee4fa4035e7420170ab6a2c9da7fb/google/genai/types.py#L4341

    Tool-choice payload produced by ``GoogleToolChoiceConversion.to_google``.
    Its single key, ``function_calling_config``, is required.
    """

    function_calling_config: GoogleFunctionCallingConfig
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class GoogleToolChoiceConversion:
    """Convert between Phoenix prompt tool-choice models and Google tool configs."""

    @staticmethod
    def to_google(
        obj: Union[
            "PromptToolChoiceNone",
            "PromptToolChoiceZeroOrMore",
            "PromptToolChoiceOneOrMore",
            "PromptToolChoiceSpecificFunctionTool",
        ],
    ) -> GoogleToolChoice:
        """Translate a Phoenix tool choice into a Google ``function_calling_config``.

        Mapping: none -> "none", zero_or_more -> "auto", one_or_more -> "any",
        specific_function -> "any" restricted to that single function name.
        """
        if obj.type == "specific_function":
            # Expressed as "any" mode limited to exactly one allowed function.
            return {
                "function_calling_config": {
                    "mode": "any",
                    "allowed_function_names": [obj.function_name],
                }
            }
        if obj.type == "one_or_more":
            return {"function_calling_config": {"mode": "any"}}
        if obj.type == "zero_or_more":
            return {"function_calling_config": {"mode": "auto"}}
        if obj.type == "none":
            return {"function_calling_config": {"mode": "none"}}
        assert_never(obj)

    @staticmethod
    def from_google(
        obj: Any,
    ) -> Union[
        "PromptToolChoiceNone",
        "PromptToolChoiceZeroOrMore",
        "PromptToolChoiceOneOrMore",
        "PromptToolChoiceSpecificFunctionTool",
    ]:
        """Parse a ``ToolConfig``-shaped value into a Phoenix tool choice.

        A missing mode is treated like "auto".

        Raises:
            ValueError: if ``function_calling_config`` is absent; if more than one
                allowed function name is given; if allowed function names are
                combined with a mode other than "any"; or if the mode is
                not recognized.
        """
        from google.genai.types import ToolConfig

        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceNone,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        config = ToolConfig.model_validate(obj).function_calling_config
        if config is None:
            raise ValueError("function_calling_config is required")
        # Compare modes in lowercase since Google's API is case-insensitive
        # https://github.com/googleapis/python-genai/blob/97cc7e4eafbee4fa4035e7420170ab6a2c9da7fb/google/genai/types.py#L645
        mode = None if config.mode is None else config.mode.value.lower()
        names = config.allowed_function_names

        if names:
            if len(names) != 1:
                raise ValueError("Only one allowed function name is currently supported")
            if mode != "any":
                raise ValueError("allowed function names only supported in 'any' mode")
            return PromptToolChoiceSpecificFunctionTool(
                type="specific_function",
                function_name=names[0],
            )

        if mode is None or mode == "auto":
            return PromptToolChoiceZeroOrMore(type="zero_or_more")
        if mode == "any":
            return PromptToolChoiceOneOrMore(type="one_or_more")
        if mode == "none":
            return PromptToolChoiceNone(type="none")

        raise ValueError(f"Unsupported Google tool choice mode: {mode}")
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
from enum import Enum
|
|
4
2
|
from typing import Any, Literal, Mapping, Optional, Union
|
|
5
3
|
|
|
@@ -9,6 +7,8 @@ from typing_extensions import Annotated, Self, TypeAlias, TypeGuard, assert_neve
|
|
|
9
7
|
from phoenix.db.types.db_models import UNDEFINED, DBBaseModel
|
|
10
8
|
from phoenix.db.types.model_provider import ModelProvider
|
|
11
9
|
from phoenix.server.api.helpers.prompts.conversions.anthropic import AnthropicToolChoiceConversion
|
|
10
|
+
from phoenix.server.api.helpers.prompts.conversions.aws import AwsToolChoiceConversion
|
|
11
|
+
from phoenix.server.api.helpers.prompts.conversions.google import GoogleToolChoiceConversion
|
|
12
12
|
from phoenix.server.api.helpers.prompts.conversions.openai import OpenAIToolChoiceConversion
|
|
13
13
|
|
|
14
14
|
JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]
|
|
@@ -126,11 +126,6 @@ class PromptTemplateRootModel(RootModel[PromptTemplate]):
|
|
|
126
126
|
root: PromptTemplate
|
|
127
127
|
|
|
128
128
|
|
|
129
|
-
class PromptToolFunction(DBBaseModel):
|
|
130
|
-
type: Literal["function"]
|
|
131
|
-
function: PromptToolFunctionDefinition
|
|
132
|
-
|
|
133
|
-
|
|
134
129
|
class PromptToolFunctionDefinition(DBBaseModel):
|
|
135
130
|
name: str
|
|
136
131
|
description: str = UNDEFINED
|
|
@@ -138,14 +133,12 @@ class PromptToolFunctionDefinition(DBBaseModel):
|
|
|
138
133
|
strict: bool = UNDEFINED
|
|
139
134
|
|
|
140
135
|
|
|
141
|
-
|
|
136
|
+
class PromptToolFunction(DBBaseModel):
    """A function tool attached to a prompt.

    ``type`` is the discriminator used by the ``PromptTool`` union.
    """

    type: Literal["function"]
    # The function's definition (name, optional description, etc.).
    function: PromptToolFunctionDefinition
|
|
142
139
|
|
|
143
140
|
|
|
144
|
-
|
|
145
|
-
type: Literal["tools"]
|
|
146
|
-
tools: Annotated[list[PromptTool], Field(..., min_length=1)]
|
|
147
|
-
tool_choice: PromptToolChoice = UNDEFINED
|
|
148
|
-
disable_parallel_tool_calls: bool = UNDEFINED
|
|
141
|
+
# Union of tool kinds, discriminated on "type"; currently only function tools.
PromptTool: TypeAlias = Annotated[Union[PromptToolFunction], Field(..., discriminator="type")]
|
|
149
142
|
|
|
150
143
|
|
|
151
144
|
class PromptToolChoiceNone(DBBaseModel):
|
|
@@ -176,6 +169,13 @@ PromptToolChoice: TypeAlias = Annotated[
|
|
|
176
169
|
]
|
|
177
170
|
|
|
178
171
|
|
|
172
|
+
class PromptTools(DBBaseModel):
    """Tools available to a prompt, plus optional settings for how they may be called."""

    type: Literal["tools"]
    # At least one tool is required (min_length=1).
    tools: Annotated[list[PromptTool], Field(..., min_length=1)]
    # Optional; left as UNDEFINED when not specified.
    tool_choice: PromptToolChoice = UNDEFINED
    # Optional; left as UNDEFINED when not specified.
    disable_parallel_tool_calls: bool = UNDEFINED
|
|
177
|
+
|
|
178
|
+
|
|
179
179
|
class PromptOpenAIJSONSchema(DBBaseModel):
|
|
180
180
|
"""
|
|
181
181
|
Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L13
|
|
@@ -199,11 +199,6 @@ class PromptOpenAIResponseFormatJSONSchema(DBBaseModel):
|
|
|
199
199
|
type: Literal["json_schema"]
|
|
200
200
|
|
|
201
201
|
|
|
202
|
-
class PromptResponseFormatJSONSchema(DBBaseModel):
|
|
203
|
-
type: Literal["json_schema"]
|
|
204
|
-
json_schema: PromptResponseFormatJSONSchemaDefinition
|
|
205
|
-
|
|
206
|
-
|
|
207
202
|
class PromptResponseFormatJSONSchemaDefinition(DBBaseModel):
|
|
208
203
|
name: str
|
|
209
204
|
description: str = UNDEFINED
|
|
@@ -211,6 +206,11 @@ class PromptResponseFormatJSONSchemaDefinition(DBBaseModel):
|
|
|
211
206
|
strict: bool = UNDEFINED
|
|
212
207
|
|
|
213
208
|
|
|
209
|
+
class PromptResponseFormatJSONSchema(DBBaseModel):
    """JSON-schema response format.

    ``type`` is the discriminator used by the ``PromptResponseFormat`` union.
    """

    type: Literal["json_schema"]
    json_schema: PromptResponseFormatJSONSchemaDefinition
|
|
212
|
+
|
|
213
|
+
|
|
214
214
|
# Union of supported response formats, discriminated on "type";
# currently only the JSON-schema format.
PromptResponseFormat: TypeAlias = Annotated[
    Union[PromptResponseFormatJSONSchema], Field(..., discriminator="type")
]
|
|
@@ -312,6 +312,24 @@ class AnthropicToolDefinition(DBBaseModel):
|
|
|
312
312
|
description: str = UNDEFINED
|
|
313
313
|
|
|
314
314
|
|
|
315
|
+
class BedrockToolDefinition(DBBaseModel):
    """
    Based on https://github.com/aws/amazon-bedrock-sdk-python/blob/main/src/bedrock/types/tool_param.py#L12

    The tool specification is stored as a free-form dict; its contents are not
    validated by this model. NOTE(review): verify the reference link above resolves.
    """

    # camelCase name is presumably kept to mirror the Bedrock payload key — TODO confirm
    toolSpec: dict[str, Any]
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class GeminiToolDefinition(DBBaseModel):
    """
    Based on https://github.com/googleapis/python-genai/blob/c0b175a0ca20286db419390031a2239938d0c0b7/google/genai/types.py#L2792

    ``name`` and ``parameters`` are required; ``description`` is optional.
    """

    name: str
    description: str = UNDEFINED
    # Free-form dict, not validated here; presumably a schema of the function's arguments.
    parameters: dict[str, Any]
|
|
331
|
+
|
|
332
|
+
|
|
315
333
|
class PromptOpenAIInvocationParametersContent(DBBaseModel):
|
|
316
334
|
temperature: float = UNDEFINED
|
|
317
335
|
max_tokens: int = UNDEFINED
|
|
@@ -320,7 +338,7 @@ class PromptOpenAIInvocationParametersContent(DBBaseModel):
|
|
|
320
338
|
presence_penalty: float = UNDEFINED
|
|
321
339
|
top_p: float = UNDEFINED
|
|
322
340
|
seed: int = UNDEFINED
|
|
323
|
-
reasoning_effort: Literal["low", "medium", "high"] = UNDEFINED
|
|
341
|
+
reasoning_effort: Literal["none", "minimal", "low", "medium", "high", "xhigh"] = UNDEFINED
|
|
324
342
|
|
|
325
343
|
|
|
326
344
|
class PromptOpenAIInvocationParameters(DBBaseModel):
|
|
@@ -332,11 +350,38 @@ class PromptAzureOpenAIInvocationParametersContent(PromptOpenAIInvocationParamet
|
|
|
332
350
|
pass
|
|
333
351
|
|
|
334
352
|
|
|
353
|
+
class PromptDeepSeekInvocationParametersContent(PromptOpenAIInvocationParametersContent):
|
|
354
|
+
pass
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class PromptXAIInvocationParametersContent(PromptOpenAIInvocationParametersContent):
|
|
358
|
+
pass
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class PromptOllamaInvocationParametersContent(PromptOpenAIInvocationParametersContent):
|
|
362
|
+
pass
|
|
363
|
+
|
|
364
|
+
|
|
335
365
|
class PromptAzureOpenAIInvocationParameters(DBBaseModel):
|
|
336
366
|
type: Literal["azure_openai"]
|
|
337
367
|
azure_openai: PromptAzureOpenAIInvocationParametersContent
|
|
338
368
|
|
|
339
369
|
|
|
370
|
+
class PromptDeepSeekInvocationParameters(DBBaseModel):
|
|
371
|
+
type: Literal["deepseek"]
|
|
372
|
+
deepseek: PromptDeepSeekInvocationParametersContent
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
class PromptXAIInvocationParameters(DBBaseModel):
|
|
376
|
+
type: Literal["xai"]
|
|
377
|
+
xai: PromptXAIInvocationParametersContent
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class PromptOllamaInvocationParameters(DBBaseModel):
|
|
381
|
+
type: Literal["ollama"]
|
|
382
|
+
ollama: PromptOllamaInvocationParametersContent
|
|
383
|
+
|
|
384
|
+
|
|
340
385
|
class PromptAnthropicThinkingConfigDisabled(DBBaseModel):
|
|
341
386
|
type: Literal["disabled"]
|
|
342
387
|
|
|
@@ -370,6 +415,17 @@ class PromptAnthropicInvocationParameters(DBBaseModel):
|
|
|
370
415
|
anthropic: PromptAnthropicInvocationParametersContent
|
|
371
416
|
|
|
372
417
|
|
|
418
|
+
class PromptAwsInvocationParametersContent(DBBaseModel):
|
|
419
|
+
max_tokens: int = UNDEFINED
|
|
420
|
+
temperature: float = UNDEFINED
|
|
421
|
+
top_p: float = UNDEFINED
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class PromptAwsInvocationParameters(DBBaseModel):
|
|
425
|
+
type: Literal["aws"]
|
|
426
|
+
aws: PromptAwsInvocationParametersContent
|
|
427
|
+
|
|
428
|
+
|
|
373
429
|
class PromptGoogleInvocationParametersContent(DBBaseModel):
|
|
374
430
|
temperature: float = UNDEFINED
|
|
375
431
|
max_output_tokens: int = UNDEFINED
|
|
@@ -391,6 +447,10 @@ PromptInvocationParameters: TypeAlias = Annotated[
|
|
|
391
447
|
PromptAzureOpenAIInvocationParameters,
|
|
392
448
|
PromptAnthropicInvocationParameters,
|
|
393
449
|
PromptGoogleInvocationParameters,
|
|
450
|
+
PromptDeepSeekInvocationParameters,
|
|
451
|
+
PromptXAIInvocationParameters,
|
|
452
|
+
PromptOllamaInvocationParameters,
|
|
453
|
+
PromptAwsInvocationParameters,
|
|
394
454
|
],
|
|
395
455
|
Field(..., discriminator="type"),
|
|
396
456
|
]
|
|
@@ -407,6 +467,14 @@ def get_raw_invocation_parameters(
|
|
|
407
467
|
return invocation_parameters.anthropic.model_dump()
|
|
408
468
|
if isinstance(invocation_parameters, PromptGoogleInvocationParameters):
|
|
409
469
|
return invocation_parameters.google.model_dump()
|
|
470
|
+
if isinstance(invocation_parameters, PromptDeepSeekInvocationParameters):
|
|
471
|
+
return invocation_parameters.deepseek.model_dump()
|
|
472
|
+
if isinstance(invocation_parameters, PromptXAIInvocationParameters):
|
|
473
|
+
return invocation_parameters.xai.model_dump()
|
|
474
|
+
if isinstance(invocation_parameters, PromptOllamaInvocationParameters):
|
|
475
|
+
return invocation_parameters.ollama.model_dump()
|
|
476
|
+
if isinstance(invocation_parameters, PromptAwsInvocationParameters):
|
|
477
|
+
return invocation_parameters.aws.model_dump()
|
|
410
478
|
assert_never(invocation_parameters)
|
|
411
479
|
|
|
412
480
|
|
|
@@ -420,6 +488,10 @@ def is_prompt_invocation_parameters(
|
|
|
420
488
|
PromptAzureOpenAIInvocationParameters,
|
|
421
489
|
PromptAnthropicInvocationParameters,
|
|
422
490
|
PromptGoogleInvocationParameters,
|
|
491
|
+
PromptDeepSeekInvocationParameters,
|
|
492
|
+
PromptXAIInvocationParameters,
|
|
493
|
+
PromptOllamaInvocationParameters,
|
|
494
|
+
PromptAwsInvocationParameters,
|
|
423
495
|
),
|
|
424
496
|
)
|
|
425
497
|
|
|
@@ -444,6 +516,13 @@ def validate_invocation_parameters(
|
|
|
444
516
|
invocation_parameters
|
|
445
517
|
),
|
|
446
518
|
)
|
|
519
|
+
elif model_provider is ModelProvider.DEEPSEEK:
|
|
520
|
+
return PromptDeepSeekInvocationParameters(
|
|
521
|
+
type="deepseek",
|
|
522
|
+
deepseek=PromptDeepSeekInvocationParametersContent.model_validate(
|
|
523
|
+
invocation_parameters
|
|
524
|
+
),
|
|
525
|
+
)
|
|
447
526
|
elif model_provider is ModelProvider.ANTHROPIC:
|
|
448
527
|
return PromptAnthropicInvocationParameters(
|
|
449
528
|
type="anthropic",
|
|
@@ -456,6 +535,21 @@ def validate_invocation_parameters(
|
|
|
456
535
|
type="google",
|
|
457
536
|
google=PromptGoogleInvocationParametersContent.model_validate(invocation_parameters),
|
|
458
537
|
)
|
|
538
|
+
elif model_provider is ModelProvider.XAI:
|
|
539
|
+
return PromptXAIInvocationParameters(
|
|
540
|
+
type="xai",
|
|
541
|
+
xai=PromptXAIInvocationParametersContent.model_validate(invocation_parameters),
|
|
542
|
+
)
|
|
543
|
+
elif model_provider is ModelProvider.OLLAMA:
|
|
544
|
+
return PromptOllamaInvocationParameters(
|
|
545
|
+
type="ollama",
|
|
546
|
+
ollama=PromptOllamaInvocationParametersContent.model_validate(invocation_parameters),
|
|
547
|
+
)
|
|
548
|
+
elif model_provider is ModelProvider.AWS:
|
|
549
|
+
return PromptAwsInvocationParameters(
|
|
550
|
+
type="aws",
|
|
551
|
+
aws=PromptAwsInvocationParametersContent.model_validate(invocation_parameters),
|
|
552
|
+
)
|
|
459
553
|
assert_never(model_provider)
|
|
460
554
|
|
|
461
555
|
|
|
@@ -465,18 +559,39 @@ def normalize_tools(
|
|
|
465
559
|
tool_choice: Optional[Union[str, Mapping[str, Any]]] = None,
|
|
466
560
|
) -> PromptTools:
|
|
467
561
|
tools: list[PromptToolFunction]
|
|
468
|
-
if
|
|
562
|
+
if (
|
|
563
|
+
model_provider is ModelProvider.OPENAI
|
|
564
|
+
or model_provider is ModelProvider.AZURE_OPENAI
|
|
565
|
+
or model_provider is ModelProvider.DEEPSEEK
|
|
566
|
+
or model_provider is ModelProvider.XAI
|
|
567
|
+
or model_provider is ModelProvider.OLLAMA
|
|
568
|
+
):
|
|
469
569
|
openai_tools = [OpenAIToolDefinition.model_validate(schema) for schema in schemas]
|
|
470
570
|
tools = [_openai_to_prompt_tool(openai_tool) for openai_tool in openai_tools]
|
|
571
|
+
elif model_provider is ModelProvider.AWS:
|
|
572
|
+
bedrock_tools = [BedrockToolDefinition.model_validate(schema) for schema in schemas]
|
|
573
|
+
tools = [_bedrock_to_prompt_tool(bedrock_tool) for bedrock_tool in bedrock_tools]
|
|
471
574
|
elif model_provider is ModelProvider.ANTHROPIC:
|
|
472
575
|
anthropic_tools = [AnthropicToolDefinition.model_validate(schema) for schema in schemas]
|
|
473
576
|
tools = [_anthropic_to_prompt_tool(anthropic_tool) for anthropic_tool in anthropic_tools]
|
|
577
|
+
elif model_provider is ModelProvider.GOOGLE:
|
|
578
|
+
gemini_tools = [GeminiToolDefinition.model_validate(schema) for schema in schemas]
|
|
579
|
+
tools = [_gemini_to_prompt_tool(gemini_tool) for gemini_tool in gemini_tools]
|
|
474
580
|
else:
|
|
475
581
|
raise ValueError(f"Unsupported model provider: {model_provider}")
|
|
476
582
|
ans = PromptTools(type="tools", tools=tools)
|
|
583
|
+
|
|
477
584
|
if tool_choice is not None:
|
|
478
|
-
if
|
|
585
|
+
if (
|
|
586
|
+
model_provider is ModelProvider.OPENAI
|
|
587
|
+
or model_provider is ModelProvider.AZURE_OPENAI
|
|
588
|
+
or model_provider is ModelProvider.DEEPSEEK
|
|
589
|
+
or model_provider is ModelProvider.XAI
|
|
590
|
+
or model_provider is ModelProvider.OLLAMA
|
|
591
|
+
):
|
|
479
592
|
ans.tool_choice = OpenAIToolChoiceConversion.from_openai(tool_choice) # type: ignore[arg-type]
|
|
593
|
+
elif model_provider is ModelProvider.AWS:
|
|
594
|
+
ans.tool_choice = AwsToolChoiceConversion.from_aws(tool_choice) # type: ignore[arg-type]
|
|
480
595
|
elif model_provider is ModelProvider.ANTHROPIC:
|
|
481
596
|
choice, disable_parallel_tool_calls = AnthropicToolChoiceConversion.from_anthropic(
|
|
482
597
|
tool_choice # type: ignore[arg-type]
|
|
@@ -484,6 +599,8 @@ def normalize_tools(
|
|
|
484
599
|
ans.tool_choice = choice
|
|
485
600
|
if disable_parallel_tool_calls is not None:
|
|
486
601
|
ans.disable_parallel_tool_calls = disable_parallel_tool_calls
|
|
602
|
+
elif model_provider is ModelProvider.GOOGLE:
|
|
603
|
+
ans.tool_choice = GoogleToolChoiceConversion.from_google(tool_choice)
|
|
487
604
|
return ans
|
|
488
605
|
|
|
489
606
|
|
|
@@ -493,14 +610,28 @@ def denormalize_tools(
|
|
|
493
610
|
assert tools.type == "tools"
|
|
494
611
|
denormalized_tools: list[DBBaseModel]
|
|
495
612
|
tool_choice: Optional[Any] = None
|
|
496
|
-
if
|
|
613
|
+
if (
|
|
614
|
+
model_provider is ModelProvider.OPENAI
|
|
615
|
+
or model_provider is ModelProvider.AZURE_OPENAI
|
|
616
|
+
or model_provider is ModelProvider.DEEPSEEK
|
|
617
|
+
or model_provider is ModelProvider.XAI
|
|
618
|
+
or model_provider is ModelProvider.OLLAMA
|
|
619
|
+
):
|
|
497
620
|
denormalized_tools = [_prompt_to_openai_tool(tool) for tool in tools.tools]
|
|
498
621
|
if tools.tool_choice:
|
|
499
622
|
tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
|
|
623
|
+
elif model_provider is ModelProvider.AWS:
|
|
624
|
+
denormalized_tools = [_prompt_to_bedrock_tool(tool) for tool in tools.tools]
|
|
625
|
+
if tools.tool_choice:
|
|
626
|
+
tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
|
|
500
627
|
elif model_provider is ModelProvider.ANTHROPIC:
|
|
501
628
|
denormalized_tools = [_prompt_to_anthropic_tool(tool) for tool in tools.tools]
|
|
502
629
|
if tools.tool_choice and tools.tool_choice.type != "none":
|
|
503
630
|
tool_choice = AnthropicToolChoiceConversion.to_anthropic(tools.tool_choice)
|
|
631
|
+
elif model_provider is ModelProvider.GOOGLE:
|
|
632
|
+
denormalized_tools = [_prompt_to_gemini_tool(tool) for tool in tools.tools]
|
|
633
|
+
if tools.tool_choice:
|
|
634
|
+
tool_choice = GoogleToolChoiceConversion.to_google(tools.tool_choice)
|
|
504
635
|
else:
|
|
505
636
|
raise ValueError(f"Unsupported model provider: {model_provider}")
|
|
506
637
|
return [tool.model_dump() for tool in denormalized_tools], tool_choice
|
|
@@ -540,6 +671,19 @@ def _prompt_to_openai_tool(
|
|
|
540
671
|
)
|
|
541
672
|
|
|
542
673
|
|
|
674
|
+
def _bedrock_to_prompt_tool(
|
|
675
|
+
tool: BedrockToolDefinition,
|
|
676
|
+
) -> PromptToolFunction:
|
|
677
|
+
return PromptToolFunction(
|
|
678
|
+
type="function",
|
|
679
|
+
function=PromptToolFunctionDefinition(
|
|
680
|
+
name=tool.toolSpec["name"],
|
|
681
|
+
description=tool.toolSpec["description"],
|
|
682
|
+
parameters=tool.toolSpec["inputSchema"]["json"],
|
|
683
|
+
),
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
|
|
543
687
|
def _anthropic_to_prompt_tool(
|
|
544
688
|
tool: AnthropicToolDefinition,
|
|
545
689
|
) -> PromptToolFunction:
|
|
@@ -562,3 +706,42 @@ def _prompt_to_anthropic_tool(
|
|
|
562
706
|
name=function.name,
|
|
563
707
|
description=function.description,
|
|
564
708
|
)
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def _prompt_to_bedrock_tool(
|
|
712
|
+
tool: PromptToolFunction,
|
|
713
|
+
) -> BedrockToolDefinition:
|
|
714
|
+
function = tool.function
|
|
715
|
+
return BedrockToolDefinition(
|
|
716
|
+
toolSpec={
|
|
717
|
+
"name": function.name,
|
|
718
|
+
"description": function.description,
|
|
719
|
+
"inputSchema": {
|
|
720
|
+
"json": function.parameters,
|
|
721
|
+
},
|
|
722
|
+
}
|
|
723
|
+
)
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def _gemini_to_prompt_tool(
|
|
727
|
+
tool: GeminiToolDefinition,
|
|
728
|
+
) -> PromptToolFunction:
|
|
729
|
+
return PromptToolFunction(
|
|
730
|
+
type="function",
|
|
731
|
+
function=PromptToolFunctionDefinition(
|
|
732
|
+
name=tool.name,
|
|
733
|
+
description=tool.description,
|
|
734
|
+
parameters=tool.parameters,
|
|
735
|
+
),
|
|
736
|
+
)
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
def _prompt_to_gemini_tool(
|
|
740
|
+
tool: PromptToolFunction,
|
|
741
|
+
) -> GeminiToolDefinition:
|
|
742
|
+
function = tool.function
|
|
743
|
+
return GeminiToolDefinition(
|
|
744
|
+
name=function.name,
|
|
745
|
+
description=function.description,
|
|
746
|
+
parameters=function.parameters,
|
|
747
|
+
)
|