arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
1
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union
2
2
 
3
3
  from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
4
4
 
@@ -59,7 +59,7 @@ PLAYGROUND_CLIENT_REGISTRY: PlaygroundClientRegistry = PlaygroundClientRegistry(
59
59
 
60
60
  def register_llm_client(
61
61
  provider_key: GenerativeProviderKey,
62
- model_names: list[ModelName],
62
+ model_names: Sequence[ModelName],
63
63
  ) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
64
64
  def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
65
65
  provider_registry = PLAYGROUND_CLIENT_REGISTRY._registry.setdefault(provider_key, {})
@@ -222,6 +222,7 @@ def get_db_experiment_run(
222
222
  *,
223
223
  experiment_id: int,
224
224
  example_id: int,
225
+ repetition_number: int,
225
226
  ) -> models.ExperimentRun:
226
227
  return models.ExperimentRun(
227
228
  experiment_id=experiment_id,
@@ -230,7 +231,7 @@ def get_db_experiment_run(
230
231
  output=models.ExperimentRunOutput(
231
232
  task_output=get_dataset_example_output(db_span),
232
233
  ),
233
- repetition_number=1,
234
+ repetition_number=repetition_number,
234
235
  start_time=db_span.start_time,
235
236
  end_time=db_span.end_time,
236
237
  error=db_span.status_message or None,
@@ -263,10 +264,13 @@ def llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
263
264
def input_value_and_mime_type(
    input: Union[ChatCompletionInput, ChatCompletionOverDatasetInput],
) -> Iterator[tuple[str, Any]]:
    """
    Yields the span attributes that record a chat-completion request as JSON input.

    The request is converted with ``jsonify``. Its top-level ``api_key``,
    ``credentials`` and ``invocation_parameters`` entries are dropped before
    serialization, so credentials are never written into trace data.
    """
    excluded_keys = ("api_key", "credentials", "invocation_parameters")
    sanitized: dict[str, Any] = {}
    for key, value in jsonify(input).items():
        if key not in excluded_keys:
            sanitized[key] = value
    # Defensive check: secrets must never reach the recorded span attributes.
    for secret_key in ("api_key", "credentials"):
        assert secret_key not in sanitized
    yield INPUT_MIME_TYPE, JSON
    yield INPUT_VALUE, safe_json_dumps(sanitized)
272
276
 
@@ -0,0 +1,26 @@
1
+ from typing import (
2
+ Optional,
3
+ )
4
+
5
+ from starlette.requests import Request
6
+ from strawberry import Info
7
+
8
+ from phoenix.server.api.context import Context
9
+ from phoenix.server.bearer_auth import PhoenixUser
10
+
11
+
12
def get_user(info: Info[Context, None]) -> Optional[int]:
    """
    Returns the numeric id of the user who issued the GraphQL operation, or None.

    When ``info.context.request`` is a starlette ``Request``, an id is returned
    only if the request scope contains a ``"user"`` entry and the context's user
    is a ``PhoenixUser``. Otherwise the context's user is used when it is
    authenticated. ``None`` is returned when no user can be determined, including
    when accessing the user raises ``AssertionError`` because authentication is
    not available.
    """
    try:
        request = info.context.request
        # Explicit isinstance rather than `assert`, so the branch still works under `python -O`.
        if isinstance(request, Request):
            if "user" in request.scope and isinstance(user := info.context.user, PhoenixUser):
                return int(user.identity)
            return None
        # No HTTP request is available; fall back to the user on the context.
        fallback_user = info.context.user
        if fallback_user.is_authenticated:
            return int(fallback_user.identity)
    except AssertionError:
        # Accessing the user raises AssertionError when auth is not available;
        # treat the caller as anonymous. Only this expected failure is suppressed,
        # so unrelated errors still propagate.
        pass
    return None
@@ -0,0 +1,83 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Optional, Union
4
+
5
+ from typing_extensions import assert_never
6
+
7
+ if TYPE_CHECKING:
8
+ from anthropic.types import (
9
+ ToolChoiceAnyParam,
10
+ ToolChoiceAutoParam,
11
+ ToolChoiceParam,
12
+ ToolChoiceToolParam,
13
+ )
14
+
15
+ from phoenix.server.api.helpers.prompts.models import (
16
+ PromptToolChoiceNone,
17
+ PromptToolChoiceOneOrMore,
18
+ PromptToolChoiceSpecificFunctionTool,
19
+ PromptToolChoiceZeroOrMore,
20
+ )
21
+
22
+
23
class AwsToolChoiceConversion:
    """Converts between Phoenix prompt tool choices and anthropic ``ToolChoiceParam`` dicts."""

    @staticmethod
    def to_aws(
        obj: Union[
            PromptToolChoiceNone,
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
        disable_parallel_tool_use: Optional[bool] = None,
    ) -> ToolChoiceParam:
        """
        Maps a Phoenix tool choice onto the anthropic ``ToolChoiceParam`` shape.

        ``zero_or_more`` -> ``auto``, ``one_or_more`` -> ``any``,
        ``specific_function`` -> ``tool`` (with ``name``), ``none`` -> ``none``.
        ``disable_parallel_tool_use`` is added only when it is not None, and never
        to the ``none`` choice.
        """
        if obj.type == "none":
            return {"type": "none"}
        choice: Union[ToolChoiceAutoParam, ToolChoiceAnyParam, ToolChoiceToolParam]
        if obj.type == "zero_or_more":
            choice = {"type": "auto"}
        elif obj.type == "one_or_more":
            choice = {"type": "any"}
        elif obj.type == "specific_function":
            choice = {"type": "tool", "name": obj.function_name}
        else:
            assert_never(obj.type)
        if disable_parallel_tool_use is not None:
            choice["disable_parallel_tool_use"] = disable_parallel_tool_use
        return choice

    @staticmethod
    def from_aws(
        obj: ToolChoiceParam,
    ) -> Union[
        PromptToolChoiceNone,
        PromptToolChoiceZeroOrMore,
        PromptToolChoiceOneOrMore,
        PromptToolChoiceSpecificFunctionTool,
    ]:
        """
        Converts an anthropic-style tool choice dict into the Phoenix tool choice model.

        Any ``disable_parallel_tool_use`` entry in ``obj`` is not read.
        """
        # Imported here rather than at module level: the prompt models module
        # imports this module, so a top-level import would be circular.
        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceNone,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        if obj["type"] == "none":
            return PromptToolChoiceNone(type="none")
        if obj["type"] == "auto":
            return PromptToolChoiceZeroOrMore(type="zero_or_more")
        if obj["type"] == "any":
            return PromptToolChoiceOneOrMore(type="one_or_more")
        if obj["type"] == "tool":
            return PromptToolChoiceSpecificFunctionTool(
                type="specific_function",
                function_name=obj["name"],
            )
        assert_never(obj)
@@ -0,0 +1,103 @@
1
+ from typing import TYPE_CHECKING, Any, Literal, Union
2
+
3
+ from typing_extensions import NotRequired, TypedDict, assert_never
4
+
5
+ if TYPE_CHECKING:
6
+ from phoenix.server.api.helpers.prompts.models import (
7
+ PromptToolChoiceNone,
8
+ PromptToolChoiceOneOrMore,
9
+ PromptToolChoiceSpecificFunctionTool,
10
+ PromptToolChoiceZeroOrMore,
11
+ )
12
+
13
+
14
class GoogleFunctionCallingConfig(TypedDict, total=False):
    """
    Based on https://github.com/googleapis/python-genai/blob/97cc7e4eafbee4fa4035e7420170ab6a2c9da7fb/google/genai/types.py#L4245
    """

    # Every key is optional: `total=False` already implies this, so the
    # NotRequired wrappers below are redundant but harmless.
    # Function-calling mode, spelled in lowercase.
    mode: NotRequired[Literal["auto", "any", "none"]]
    # Function names the model is limited to (see the Google reference above).
    allowed_function_names: NotRequired[list[str]]
21
+
22
+
23
class GoogleToolChoice(TypedDict):
    """
    Based on https://github.com/googleapis/python-genai/blob/97cc7e4eafbee4fa4035e7420170ab6a2c9da7fb/google/genai/types.py#L4341
    """

    # Single required key holding the function-calling settings.
    function_calling_config: GoogleFunctionCallingConfig
29
+
30
+
31
class GoogleToolChoiceConversion:
    """Converts between Phoenix prompt tool choices and Google GenAI tool configs."""

    @staticmethod
    def to_google(
        obj: Union[
            "PromptToolChoiceNone",
            "PromptToolChoiceZeroOrMore",
            "PromptToolChoiceOneOrMore",
            "PromptToolChoiceSpecificFunctionTool",
        ],
    ) -> GoogleToolChoice:
        """
        Builds a Google tool-config dict for the given Phoenix tool choice.

        ``none`` -> mode ``none``, ``zero_or_more`` -> ``auto``,
        ``one_or_more`` -> ``any``, and ``specific_function`` -> ``any``
        restricted to that single function name.
        """
        config: GoogleFunctionCallingConfig
        if obj.type == "specific_function":
            config = {"mode": "any", "allowed_function_names": [obj.function_name]}
        elif obj.type == "one_or_more":
            config = {"mode": "any"}
        elif obj.type == "zero_or_more":
            config = {"mode": "auto"}
        elif obj.type == "none":
            config = {"mode": "none"}
        else:
            assert_never(obj)
        return {"function_calling_config": config}

    @staticmethod
    def from_google(
        obj: Any,
    ) -> Union[
        "PromptToolChoiceNone",
        "PromptToolChoiceZeroOrMore",
        "PromptToolChoiceOneOrMore",
        "PromptToolChoiceSpecificFunctionTool",
    ]:
        """
        Parses a Google ``ToolConfig``-compatible value into a Phoenix tool choice.

        A missing mode is treated as ``auto``.

        Raises:
            ValueError: if ``function_calling_config`` is missing, if more than one
                allowed function name is given, if allowed function names are used
                with a mode other than ``any``, or if the mode is not
                ``auto``/``any``/``none``.
        """
        # Imported lazily: google-genai is only required when this method runs,
        # and the prompt models module imports this module at load time.
        from google.genai.types import ToolConfig

        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceNone,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        calling_config = ToolConfig.model_validate(obj).function_calling_config
        if calling_config is None:
            raise ValueError("function_calling_config is required")
        # Google's API accepts the mode case-insensitively, so compare lowercased values.
        # https://github.com/googleapis/python-genai/blob/97cc7e4eafbee4fa4035e7420170ab6a2c9da7fb/google/genai/types.py#L645
        mode = None if calling_config.mode is None else calling_config.mode.value.lower()

        names = calling_config.allowed_function_names
        if names:
            if len(names) != 1:
                raise ValueError("Only one allowed function name is currently supported")
            if mode != "any":
                raise ValueError("allowed function names only supported in 'any' mode")
            return PromptToolChoiceSpecificFunctionTool(
                type="specific_function",
                function_name=names[0],
            )

        if mode is None or mode == "auto":
            return PromptToolChoiceZeroOrMore(type="zero_or_more")
        if mode == "none":
            return PromptToolChoiceNone(type="none")
        if mode == "any":
            return PromptToolChoiceOneOrMore(type="one_or_more")
        raise ValueError(f"Unsupported Google tool choice mode: {mode}")
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
1
  from enum import Enum
4
2
  from typing import Any, Literal, Mapping, Optional, Union
5
3
 
@@ -9,6 +7,8 @@ from typing_extensions import Annotated, Self, TypeAlias, TypeGuard, assert_neve
9
7
  from phoenix.db.types.db_models import UNDEFINED, DBBaseModel
10
8
  from phoenix.db.types.model_provider import ModelProvider
11
9
  from phoenix.server.api.helpers.prompts.conversions.anthropic import AnthropicToolChoiceConversion
10
+ from phoenix.server.api.helpers.prompts.conversions.aws import AwsToolChoiceConversion
11
+ from phoenix.server.api.helpers.prompts.conversions.google import GoogleToolChoiceConversion
12
12
  from phoenix.server.api.helpers.prompts.conversions.openai import OpenAIToolChoiceConversion
13
13
 
14
14
  JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]
@@ -126,11 +126,6 @@ class PromptTemplateRootModel(RootModel[PromptTemplate]):
126
126
  root: PromptTemplate
127
127
 
128
128
 
129
- class PromptToolFunction(DBBaseModel):
130
- type: Literal["function"]
131
- function: PromptToolFunctionDefinition
132
-
133
-
134
129
  class PromptToolFunctionDefinition(DBBaseModel):
135
130
  name: str
136
131
  description: str = UNDEFINED
@@ -138,14 +133,12 @@ class PromptToolFunctionDefinition(DBBaseModel):
138
133
  strict: bool = UNDEFINED
139
134
 
140
135
 
141
- PromptTool: TypeAlias = Annotated[Union[PromptToolFunction], Field(..., discriminator="type")]
136
# A function tool attached to a prompt. The literal `type` tag is the
# discriminator used by the PromptTool union. This class must come after
# PromptToolFunctionDefinition because this module evaluates annotations
# at class-creation time (it no longer uses `from __future__ import annotations`).
class PromptToolFunction(DBBaseModel):
    type: Literal["function"]
    function: PromptToolFunctionDefinition
142
139
 
143
140
 
144
- class PromptTools(DBBaseModel):
145
- type: Literal["tools"]
146
- tools: Annotated[list[PromptTool], Field(..., min_length=1)]
147
- tool_choice: PromptToolChoice = UNDEFINED
148
- disable_parallel_tool_calls: bool = UNDEFINED
141
# Discriminated union of prompt tool kinds, keyed on `type`; currently only function tools.
PromptTool: TypeAlias = Annotated[Union[PromptToolFunction], Field(..., discriminator="type")]
149
142
 
150
143
 
151
144
  class PromptToolChoiceNone(DBBaseModel):
@@ -176,6 +169,13 @@ PromptToolChoice: TypeAlias = Annotated[
176
169
  ]
177
170
 
178
171
 
172
# Tool configuration attached to a prompt.
# - At least one tool is required (min_length=1).
# - tool_choice and disable_parallel_tool_calls default to the UNDEFINED sentinel.
# - Placed after PromptToolChoice so its annotations resolve at class creation.
class PromptTools(DBBaseModel):
    type: Literal["tools"]
    tools: Annotated[list[PromptTool], Field(..., min_length=1)]
    tool_choice: PromptToolChoice = UNDEFINED
    disable_parallel_tool_calls: bool = UNDEFINED
177
+
178
+
179
179
  class PromptOpenAIJSONSchema(DBBaseModel):
180
180
  """
181
181
  Based on https://github.com/openai/openai-python/blob/d16e6edde5a155626910b5758a0b939bfedb9ced/src/openai/types/shared/response_format_json_schema.py#L13
@@ -199,11 +199,6 @@ class PromptOpenAIResponseFormatJSONSchema(DBBaseModel):
199
199
  type: Literal["json_schema"]
200
200
 
201
201
 
202
- class PromptResponseFormatJSONSchema(DBBaseModel):
203
- type: Literal["json_schema"]
204
- json_schema: PromptResponseFormatJSONSchemaDefinition
205
-
206
-
207
202
  class PromptResponseFormatJSONSchemaDefinition(DBBaseModel):
208
203
  name: str
209
204
  description: str = UNDEFINED
@@ -211,6 +206,11 @@ class PromptResponseFormatJSONSchemaDefinition(DBBaseModel):
211
206
  strict: bool = UNDEFINED
212
207
 
213
208
 
209
# JSON-schema response format for a prompt. The `type` tag is the discriminator
# used by the PromptResponseFormat union. This class must follow
# PromptResponseFormatJSONSchemaDefinition so its annotation resolves at class creation.
class PromptResponseFormatJSONSchema(DBBaseModel):
    type: Literal["json_schema"]
    json_schema: PromptResponseFormatJSONSchemaDefinition
212
+
213
+
214
214
  PromptResponseFormat: TypeAlias = Annotated[
215
215
  Union[PromptResponseFormatJSONSchema], Field(..., discriminator="type")
216
216
  ]
@@ -312,6 +312,24 @@ class AnthropicToolDefinition(DBBaseModel):
312
312
  description: str = UNDEFINED
313
313
 
314
314
 
315
class BedrockToolDefinition(DBBaseModel):
    """
    Based on https://github.com/aws/amazon-bedrock-sdk-python/blob/main/src/bedrock/types/tool_param.py#L12
    """

    # Tool specification kept as an untyped mapping; its inner keys are not validated here.
    # NOTE(review): the camelCase name presumably mirrors the Bedrock wire format — confirm.
    toolSpec: dict[str, Any]
321
+
322
+
323
class GeminiToolDefinition(DBBaseModel):
    """
    Based on https://github.com/googleapis/python-genai/blob/c0b175a0ca20286db419390031a2239938d0c0b7/google/genai/types.py#L2792
    """

    name: str
    description: str = UNDEFINED
    # Required (no default). Kept as an untyped mapping; presumably a JSON-schema-like
    # parameter description per the Google reference above — confirm.
    parameters: dict[str, Any]
331
+
332
+
315
333
  class PromptOpenAIInvocationParametersContent(DBBaseModel):
316
334
  temperature: float = UNDEFINED
317
335
  max_tokens: int = UNDEFINED
@@ -320,7 +338,7 @@ class PromptOpenAIInvocationParametersContent(DBBaseModel):
320
338
  presence_penalty: float = UNDEFINED
321
339
  top_p: float = UNDEFINED
322
340
  seed: int = UNDEFINED
323
- reasoning_effort: Literal["low", "medium", "high"] = UNDEFINED
341
+ reasoning_effort: Literal["none", "minimal", "low", "medium", "high", "xhigh"] = UNDEFINED
324
342
 
325
343
 
326
344
  class PromptOpenAIInvocationParameters(DBBaseModel):
@@ -332,11 +350,38 @@ class PromptAzureOpenAIInvocationParametersContent(PromptOpenAIInvocationParamet
332
350
  pass
333
351
 
334
352
 
353
+ class PromptDeepSeekInvocationParametersContent(PromptOpenAIInvocationParametersContent):
354
+ pass
355
+
356
+
357
+ class PromptXAIInvocationParametersContent(PromptOpenAIInvocationParametersContent):
358
+ pass
359
+
360
+
361
+ class PromptOllamaInvocationParametersContent(PromptOpenAIInvocationParametersContent):
362
+ pass
363
+
364
+
335
365
  class PromptAzureOpenAIInvocationParameters(DBBaseModel):
336
366
  type: Literal["azure_openai"]
337
367
  azure_openai: PromptAzureOpenAIInvocationParametersContent
338
368
 
339
369
 
370
+ class PromptDeepSeekInvocationParameters(DBBaseModel):
371
+ type: Literal["deepseek"]
372
+ deepseek: PromptDeepSeekInvocationParametersContent
373
+
374
+
375
+ class PromptXAIInvocationParameters(DBBaseModel):
376
+ type: Literal["xai"]
377
+ xai: PromptXAIInvocationParametersContent
378
+
379
+
380
+ class PromptOllamaInvocationParameters(DBBaseModel):
381
+ type: Literal["ollama"]
382
+ ollama: PromptOllamaInvocationParametersContent
383
+
384
+
340
385
  class PromptAnthropicThinkingConfigDisabled(DBBaseModel):
341
386
  type: Literal["disabled"]
342
387
 
@@ -370,6 +415,17 @@ class PromptAnthropicInvocationParameters(DBBaseModel):
370
415
  anthropic: PromptAnthropicInvocationParametersContent
371
416
 
372
417
 
418
+ class PromptAwsInvocationParametersContent(DBBaseModel):
419
+ max_tokens: int = UNDEFINED
420
+ temperature: float = UNDEFINED
421
+ top_p: float = UNDEFINED
422
+
423
+
424
+ class PromptAwsInvocationParameters(DBBaseModel):
425
+ type: Literal["aws"]
426
+ aws: PromptAwsInvocationParametersContent
427
+
428
+
373
429
  class PromptGoogleInvocationParametersContent(DBBaseModel):
374
430
  temperature: float = UNDEFINED
375
431
  max_output_tokens: int = UNDEFINED
@@ -391,6 +447,10 @@ PromptInvocationParameters: TypeAlias = Annotated[
391
447
  PromptAzureOpenAIInvocationParameters,
392
448
  PromptAnthropicInvocationParameters,
393
449
  PromptGoogleInvocationParameters,
450
+ PromptDeepSeekInvocationParameters,
451
+ PromptXAIInvocationParameters,
452
+ PromptOllamaInvocationParameters,
453
+ PromptAwsInvocationParameters,
394
454
  ],
395
455
  Field(..., discriminator="type"),
396
456
  ]
@@ -407,6 +467,14 @@ def get_raw_invocation_parameters(
407
467
  return invocation_parameters.anthropic.model_dump()
408
468
  if isinstance(invocation_parameters, PromptGoogleInvocationParameters):
409
469
  return invocation_parameters.google.model_dump()
470
+ if isinstance(invocation_parameters, PromptDeepSeekInvocationParameters):
471
+ return invocation_parameters.deepseek.model_dump()
472
+ if isinstance(invocation_parameters, PromptXAIInvocationParameters):
473
+ return invocation_parameters.xai.model_dump()
474
+ if isinstance(invocation_parameters, PromptOllamaInvocationParameters):
475
+ return invocation_parameters.ollama.model_dump()
476
+ if isinstance(invocation_parameters, PromptAwsInvocationParameters):
477
+ return invocation_parameters.aws.model_dump()
410
478
  assert_never(invocation_parameters)
411
479
 
412
480
 
@@ -420,6 +488,10 @@ def is_prompt_invocation_parameters(
420
488
  PromptAzureOpenAIInvocationParameters,
421
489
  PromptAnthropicInvocationParameters,
422
490
  PromptGoogleInvocationParameters,
491
+ PromptDeepSeekInvocationParameters,
492
+ PromptXAIInvocationParameters,
493
+ PromptOllamaInvocationParameters,
494
+ PromptAwsInvocationParameters,
423
495
  ),
424
496
  )
425
497
 
@@ -444,6 +516,13 @@ def validate_invocation_parameters(
444
516
  invocation_parameters
445
517
  ),
446
518
  )
519
+ elif model_provider is ModelProvider.DEEPSEEK:
520
+ return PromptDeepSeekInvocationParameters(
521
+ type="deepseek",
522
+ deepseek=PromptDeepSeekInvocationParametersContent.model_validate(
523
+ invocation_parameters
524
+ ),
525
+ )
447
526
  elif model_provider is ModelProvider.ANTHROPIC:
448
527
  return PromptAnthropicInvocationParameters(
449
528
  type="anthropic",
@@ -456,6 +535,21 @@ def validate_invocation_parameters(
456
535
  type="google",
457
536
  google=PromptGoogleInvocationParametersContent.model_validate(invocation_parameters),
458
537
  )
538
+ elif model_provider is ModelProvider.XAI:
539
+ return PromptXAIInvocationParameters(
540
+ type="xai",
541
+ xai=PromptXAIInvocationParametersContent.model_validate(invocation_parameters),
542
+ )
543
+ elif model_provider is ModelProvider.OLLAMA:
544
+ return PromptOllamaInvocationParameters(
545
+ type="ollama",
546
+ ollama=PromptOllamaInvocationParametersContent.model_validate(invocation_parameters),
547
+ )
548
+ elif model_provider is ModelProvider.AWS:
549
+ return PromptAwsInvocationParameters(
550
+ type="aws",
551
+ aws=PromptAwsInvocationParametersContent.model_validate(invocation_parameters),
552
+ )
459
553
  assert_never(model_provider)
460
554
 
461
555
 
@@ -465,18 +559,39 @@ def normalize_tools(
465
559
  tool_choice: Optional[Union[str, Mapping[str, Any]]] = None,
466
560
  ) -> PromptTools:
467
561
  tools: list[PromptToolFunction]
468
- if model_provider is ModelProvider.OPENAI or model_provider is ModelProvider.AZURE_OPENAI:
562
+ if (
563
+ model_provider is ModelProvider.OPENAI
564
+ or model_provider is ModelProvider.AZURE_OPENAI
565
+ or model_provider is ModelProvider.DEEPSEEK
566
+ or model_provider is ModelProvider.XAI
567
+ or model_provider is ModelProvider.OLLAMA
568
+ ):
469
569
  openai_tools = [OpenAIToolDefinition.model_validate(schema) for schema in schemas]
470
570
  tools = [_openai_to_prompt_tool(openai_tool) for openai_tool in openai_tools]
571
+ elif model_provider is ModelProvider.AWS:
572
+ bedrock_tools = [BedrockToolDefinition.model_validate(schema) for schema in schemas]
573
+ tools = [_bedrock_to_prompt_tool(bedrock_tool) for bedrock_tool in bedrock_tools]
471
574
  elif model_provider is ModelProvider.ANTHROPIC:
472
575
  anthropic_tools = [AnthropicToolDefinition.model_validate(schema) for schema in schemas]
473
576
  tools = [_anthropic_to_prompt_tool(anthropic_tool) for anthropic_tool in anthropic_tools]
577
+ elif model_provider is ModelProvider.GOOGLE:
578
+ gemini_tools = [GeminiToolDefinition.model_validate(schema) for schema in schemas]
579
+ tools = [_gemini_to_prompt_tool(gemini_tool) for gemini_tool in gemini_tools]
474
580
  else:
475
581
  raise ValueError(f"Unsupported model provider: {model_provider}")
476
582
  ans = PromptTools(type="tools", tools=tools)
583
+
477
584
  if tool_choice is not None:
478
- if model_provider is ModelProvider.OPENAI or model_provider is ModelProvider.AZURE_OPENAI:
585
+ if (
586
+ model_provider is ModelProvider.OPENAI
587
+ or model_provider is ModelProvider.AZURE_OPENAI
588
+ or model_provider is ModelProvider.DEEPSEEK
589
+ or model_provider is ModelProvider.XAI
590
+ or model_provider is ModelProvider.OLLAMA
591
+ ):
479
592
  ans.tool_choice = OpenAIToolChoiceConversion.from_openai(tool_choice) # type: ignore[arg-type]
593
+ elif model_provider is ModelProvider.AWS:
594
+ ans.tool_choice = AwsToolChoiceConversion.from_aws(tool_choice) # type: ignore[arg-type]
480
595
  elif model_provider is ModelProvider.ANTHROPIC:
481
596
  choice, disable_parallel_tool_calls = AnthropicToolChoiceConversion.from_anthropic(
482
597
  tool_choice # type: ignore[arg-type]
@@ -484,6 +599,8 @@ def normalize_tools(
484
599
  ans.tool_choice = choice
485
600
  if disable_parallel_tool_calls is not None:
486
601
  ans.disable_parallel_tool_calls = disable_parallel_tool_calls
602
+ elif model_provider is ModelProvider.GOOGLE:
603
+ ans.tool_choice = GoogleToolChoiceConversion.from_google(tool_choice)
487
604
  return ans
488
605
 
489
606
 
@@ -493,14 +610,28 @@ def denormalize_tools(
493
610
  assert tools.type == "tools"
494
611
  denormalized_tools: list[DBBaseModel]
495
612
  tool_choice: Optional[Any] = None
496
- if model_provider is ModelProvider.OPENAI or model_provider is ModelProvider.AZURE_OPENAI:
613
+ if (
614
+ model_provider is ModelProvider.OPENAI
615
+ or model_provider is ModelProvider.AZURE_OPENAI
616
+ or model_provider is ModelProvider.DEEPSEEK
617
+ or model_provider is ModelProvider.XAI
618
+ or model_provider is ModelProvider.OLLAMA
619
+ ):
497
620
  denormalized_tools = [_prompt_to_openai_tool(tool) for tool in tools.tools]
498
621
  if tools.tool_choice:
499
622
  tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
623
+ elif model_provider is ModelProvider.AWS:
624
+ denormalized_tools = [_prompt_to_bedrock_tool(tool) for tool in tools.tools]
625
+ if tools.tool_choice:
626
+ tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
500
627
  elif model_provider is ModelProvider.ANTHROPIC:
501
628
  denormalized_tools = [_prompt_to_anthropic_tool(tool) for tool in tools.tools]
502
629
  if tools.tool_choice and tools.tool_choice.type != "none":
503
630
  tool_choice = AnthropicToolChoiceConversion.to_anthropic(tools.tool_choice)
631
+ elif model_provider is ModelProvider.GOOGLE:
632
+ denormalized_tools = [_prompt_to_gemini_tool(tool) for tool in tools.tools]
633
+ if tools.tool_choice:
634
+ tool_choice = GoogleToolChoiceConversion.to_google(tools.tool_choice)
504
635
  else:
505
636
  raise ValueError(f"Unsupported model provider: {model_provider}")
506
637
  return [tool.model_dump() for tool in denormalized_tools], tool_choice
@@ -540,6 +671,19 @@ def _prompt_to_openai_tool(
540
671
  )
541
672
 
542
673
 
674
+ def _bedrock_to_prompt_tool(
675
+ tool: BedrockToolDefinition,
676
+ ) -> PromptToolFunction:
677
+ return PromptToolFunction(
678
+ type="function",
679
+ function=PromptToolFunctionDefinition(
680
+ name=tool.toolSpec["name"],
681
+ description=tool.toolSpec["description"],
682
+ parameters=tool.toolSpec["inputSchema"]["json"],
683
+ ),
684
+ )
685
+
686
+
543
687
  def _anthropic_to_prompt_tool(
544
688
  tool: AnthropicToolDefinition,
545
689
  ) -> PromptToolFunction:
@@ -562,3 +706,42 @@ def _prompt_to_anthropic_tool(
562
706
  name=function.name,
563
707
  description=function.description,
564
708
  )
709
+
710
+
711
+ def _prompt_to_bedrock_tool(
712
+ tool: PromptToolFunction,
713
+ ) -> BedrockToolDefinition:
714
+ function = tool.function
715
+ return BedrockToolDefinition(
716
+ toolSpec={
717
+ "name": function.name,
718
+ "description": function.description,
719
+ "inputSchema": {
720
+ "json": function.parameters,
721
+ },
722
+ }
723
+ )
724
+
725
+
726
+ def _gemini_to_prompt_tool(
727
+ tool: GeminiToolDefinition,
728
+ ) -> PromptToolFunction:
729
+ return PromptToolFunction(
730
+ type="function",
731
+ function=PromptToolFunctionDefinition(
732
+ name=tool.name,
733
+ description=tool.description,
734
+ parameters=tool.parameters,
735
+ ),
736
+ )
737
+
738
+
739
+ def _prompt_to_gemini_tool(
740
+ tool: PromptToolFunction,
741
+ ) -> GeminiToolDefinition:
742
+ function = tool.function
743
+ return GeminiToolDefinition(
744
+ name=function.name,
745
+ description=function.description,
746
+ parameters=function.parameters,
747
+ )