arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/dataloaders/table_fields.py
@@ -18,7 +18,7 @@ _AttrStrIdentifier: TypeAlias = str
 
 
 class TableFieldsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: DbSessionFactory, table: type[models.Base]) -> None:
+    def __init__(self, db: DbSessionFactory, table: type[models.HasId]) -> None:
         super().__init__(load_fn=self._load_fn)
         self._db = db
         self._table = table
@@ -37,7 +37,7 @@ class TableFieldsDataLoader(DataLoader[Key, Result]):
 
 def _get_stmt(
     keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
-    table: type[models.Base],
+    table: type[models.HasId],
 ) -> tuple[
     Select[Any],
     dict[_ResultColumnPosition, _AttrStrIdentifier],
phoenix/server/api/dataloaders/token_prices_by_model.py (new file)
@@ -0,0 +1,30 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ModelId: TypeAlias = int
+Key: TypeAlias = ModelId
+Result: TypeAlias = list[models.TokenPrice]
+
+
+class TokenPricesByModelDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        model_ids = keys
+        token_prices: defaultdict[Key, Result] = defaultdict(list)
+
+        async with self._db() as session:
+            async for token_price in await session.stream_scalars(
+                select(models.TokenPrice).where(models.TokenPrice.model_id.in_(model_ids))
+            ):
+                token_prices[token_price.model_id].append(token_price)
+
+        return [token_prices[model_id] for model_id in keys]
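
Both new dataloaders here follow the batching pattern used throughout this package: Strawberry's DataLoader coalesces every key requested during the same event-loop tick into a single _load_fn call, so one SELECT ... WHERE ... IN (...) replaces N per-row queries. A minimal sketch of that behavior; the driver code is illustrative, only the loader class comes from this diff:

# Illustrative driver for TokenPricesByModelDataLoader; "db" stands in for a
# real DbSessionFactory. The three load() calls are awaited together, so the
# DataLoader batches keys [1, 2, 3] into one _load_fn call and hence a single
# SELECT ... WHERE model_id IN (1, 2, 3) query.
import asyncio


async def fetch_prices(db: "DbSessionFactory") -> None:
    loader = TokenPricesByModelDataLoader(db)
    prices_1, prices_2, prices_3 = await asyncio.gather(
        loader.load(1),
        loader.load(2),
        loader.load(3),
    )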
phoenix/server/api/dataloaders/trace_annotations_by_trace.py (new file)
@@ -0,0 +1,27 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db.models import TraceAnnotation
+from phoenix.server.types import DbSessionFactory
+
+TraceRowId: TypeAlias = int
+Key: TypeAlias = TraceRowId
+Result: TypeAlias = list[TraceAnnotation]
+
+
+class TraceAnnotationsByTraceDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
+        async with self._db() as session:
+            async for annotation in await session.stream_scalars(
+                select(TraceAnnotation).where(TraceAnnotation.trace_rowid.in_(keys))
+            ):
+                annotations_by_id[annotation.trace_rowid].append(annotation)
+        return [annotations_by_id[key] for key in keys]
phoenix/server/api/exceptions.py
@@ -1,6 +1,8 @@
 from graphql.error import GraphQLError
 from strawberry.extensions import MaskErrors
 
+from phoenix.config import get_env_mask_internal_server_errors
+
 
 class CustomGraphQLError(Exception):
     """
@@ -51,4 +53,6 @@ def _should_mask_error(error: GraphQLError) -> bool:
     """
     Masks unexpected errors raised from GraphQL resolvers.
     """
-    return not isinstance(error.original_error, CustomGraphQLError)
+    return get_env_mask_internal_server_errors() and not isinstance(
+        error.original_error, CustomGraphQLError
+    )
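
For context, `_should_mask_error` is the predicate that Strawberry's MaskErrors extension consults for each error, and the new guard lets an environment variable turn masking off entirely, which is useful when debugging. A minimal sketch of how such a predicate is wired into a schema; the Query type is a placeholder, not part of this diff:

# Hypothetical wiring of the masking predicate; only MaskErrors and the
# predicate shape come from the diff above, the schema is a placeholder.
import strawberry
from graphql.error import GraphQLError
from strawberry.extensions import MaskErrors


class CustomGraphQLError(Exception):
    """Errors raised deliberately that should be surfaced to API clients."""


def _should_mask_error(error: GraphQLError) -> bool:
    # True means the client sees a generic message instead of the real error.
    return not isinstance(error.original_error, CustomGraphQLError)


@strawberry.type
class Query:
    @strawberry.field
    def ping(self) -> str:
        return "pong"


schema = strawberry.Schema(
    query=Query,
    extensions=[MaskErrors(should_mask_error=_should_mask_error)],
)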
phoenix/server/api/helpers/playground_clients.py
@@ -57,6 +57,7 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 if TYPE_CHECKING:
     import httpx
     from anthropic.types import MessageParam, TextBlockParam, ToolResultBlockParam
+    from botocore.awsrequest import AWSPreparedRequest  # type: ignore[import-untyped]
     from google.generativeai.types import ContentType
     from openai import AsyncAzureOpenAI, AsyncOpenAI
     from openai.types import CompletionUsage
@@ -308,7 +309,6 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=1.0,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -327,6 +327,10 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
                 label="Response Format",
                 canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
             ),
+            JSONInvocationParameter(
+                invocation_name="extra_body",
+                label="Extra Body",
+            ),
         ]
 
     async def chat_completion_create(
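
The new `extra_body` parameter exposes an escape hatch the OpenAI SDK already supports: a JSON object passed as `extra_body` is merged into the request payload verbatim, which is how OpenAI-compatible backends (vLLM, OpenRouter, and similar) receive fields outside the official schema. A sketch of what it translates to at the SDK level; the model name and the extra field are placeholders:

# Sketch of the SDK-level effect of an "extra_body" invocation parameter;
# the model name and the extra field are placeholders.
from openai import AsyncOpenAI


async def stream_completion() -> None:
    client = AsyncOpenAI(api_key="sk-...")
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        # Merged into the JSON body as-is, so OpenAI-compatible servers can
        # accept provider-specific options the SDK has no parameter for.
        extra_body={"repetition_penalty": 1.05},
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")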
@@ -543,7 +547,11 @@ class DeepSeekStreamingClient(OpenAIBaseStreamingClient):
                 raise BadRequest("An API key is required for DeepSeek models")
             api_key = "sk-fake-api-key"
 
-        client = AsyncOpenAI(api_key=api_key, base_url=base_url or "https://api.deepseek.com")
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url or "https://api.deepseek.com",
+            default_headers=model.custom_headers or None,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         # DeepSeek uses OpenAI-compatible API but we'll track it as a separate provider
         # Adding a custom "deepseek" provider value to make it distinguishable in traces
@@ -581,7 +589,11 @@ class XAIStreamingClient(OpenAIBaseStreamingClient):
                 raise BadRequest("An API key is required for xAI models")
             api_key = "sk-fake-api-key"
 
-        client = AsyncOpenAI(api_key=api_key, base_url=base_url or "https://api.x.ai/v1")
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url or "https://api.x.ai/v1",
+            default_headers=model.custom_headers or None,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         # xAI uses OpenAI-compatible API but we'll track it as a separate provider
         # Adding a custom "xai" provider value to make it distinguishable in traces
@@ -618,7 +630,11 @@ class OllamaStreamingClient(OpenAIBaseStreamingClient):
         if not base_url:
             raise BadRequest("An Ollama base URL is required for Ollama models")
         api_key = "ollama"
-        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            default_headers=model.custom_headers or None,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         # Ollama uses OpenAI-compatible API but we'll track it as a separate provider
         # Adding a custom "ollama" provider value to make it distinguishable in traces
@@ -630,13 +646,17 @@
     provider_key=GenerativeProviderKey.AWS,
     model_names=[
         PROVIDER_DEFAULT,
-        "anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "anthropic.claude-opus-4-5-20251101-v1:0",
+        "anthropic.claude-sonnet-4-5-20250929-v1:0",
+        "anthropic.claude-haiku-4-5-20251001-v1:0",
+        "anthropic.claude-opus-4-1-20250805-v1:0",
+        "anthropic.claude-opus-4-20250514-v1:0",
+        "anthropic.claude-sonnet-4-20250514-v1:0",
         "anthropic.claude-3-7-sonnet-20250219-v1:0",
-        "anthropic.claude-3-haiku-20240307-v1:0",
         "anthropic.claude-3-5-sonnet-20241022-v2:0",
+        "anthropic.claude-3-5-sonnet-20240620-v1:0",
         "anthropic.claude-3-5-haiku-20241022-v1:0",
-        "anthropic.claude-opus-4-20250514-v1:0",
-        "anthropic.claude-sonnet-4-20250514-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
         "amazon.titan-embed-text-v2:0",
         "amazon.nova-pro-v1:0",
         "amazon.nova-premier-v1:0:8k",
@@ -671,29 +691,45 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         import boto3  # type: ignore[import-untyped]
 
         super().__init__(model=model, credentials=credentials)
-        self.region = model.region or "us-east-1"
+        region = model.region or "us-east-1"
         self.api = "converse"
-        self.aws_access_key_id = _get_credential_value(credentials, "AWS_ACCESS_KEY_ID") or getenv(
+        custom_headers = model.custom_headers
+        aws_access_key_id = _get_credential_value(credentials, "AWS_ACCESS_KEY_ID") or getenv(
             "AWS_ACCESS_KEY_ID"
         )
-        self.aws_secret_access_key = _get_credential_value(
+        aws_secret_access_key = _get_credential_value(
             credentials, "AWS_SECRET_ACCESS_KEY"
         ) or getenv("AWS_SECRET_ACCESS_KEY")
-        self.aws_session_token = _get_credential_value(credentials, "AWS_SESSION_TOKEN") or getenv(
+        aws_session_token = _get_credential_value(credentials, "AWS_SESSION_TOKEN") or getenv(
            "AWS_SESSION_TOKEN"
         )
         self.model_name = model.name
-        self.client = boto3.client(
-            service_name="bedrock-runtime",
-            region_name="us-east-1",  # match the default region in the UI
-            aws_access_key_id=self.aws_access_key_id,
-            aws_secret_access_key=self.aws_secret_access_key,
-            aws_session_token=self.aws_session_token,
+        session = boto3.Session(
+            region_name=region,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
         )
+        client = session.client(service_name="bedrock-runtime")
+
+        # Add custom headers support via boto3 event system
+        if custom_headers:
+
+            def add_custom_headers(request: "AWSPreparedRequest", **kwargs: Any) -> None:
+                request.headers.update(custom_headers)
+
+            client.meta.events.register("before-send.*", add_custom_headers)
 
+        self.client = client
         self._attributes[LLM_PROVIDER] = "aws"
         self._attributes[LLM_SYSTEM] = "aws"
 
+    @staticmethod
+    def _setup_custom_headers(client: Any, custom_headers: Mapping[str, str]) -> None:
+        """Setup custom headers using boto3's event system."""
+        if not custom_headers:
+            return
+
     @classmethod
     def dependencies(cls) -> list[Dependency]:
         return [Dependency(name="boto3")]
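
boto3 has no `default_headers` equivalent, so the Bedrock client gets the same capability through botocore's event system: a handler registered for `before-send` events can mutate the prepared HTTP request just before it goes on the wire, and the `"before-send.*"` pattern matches every operation on the client. A standalone sketch of the technique; region and header values are placeholders:

# Standalone sketch of header injection via botocore's event system; region
# and header values are placeholders.
import boto3

session = boto3.Session(region_name="us-east-1")
client = session.client("bedrock-runtime")


def add_custom_headers(request, **kwargs):
    # "request" is a botocore AWSPreparedRequest; its headers act like a dict.
    request.headers.update({"X-Proxy-Tenant": "team-a"})


# Fires for every operation this client performs.
client.meta.events.register("before-send.*", add_custom_headers)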
@@ -719,7 +755,6 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=1.0,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -738,18 +773,6 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         tools: list[JSONScalarType],
         **invocation_parameters: Any,
     ) -> AsyncIterator[ChatCompletionChunk]:
-        import boto3
-
-        if (
-            self.client.meta.region_name != self.region
-        ):  # override the region if it's different from the default
-            self.client = boto3.client(
-                "bedrock-runtime",
-                region_name=self.region,
-                aws_access_key_id=self.aws_access_key_id,
-                aws_secret_access_key=self.aws_secret_access_key,
-                aws_session_token=self.aws_session_token,
-            )
         if self.api == "invoke":
             async for chunk in self._handle_invoke_api(messages, tools, invocation_parameters):
                 yield chunk
@@ -771,15 +794,25 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         # Build messages in Converse API format
         converse_messages = self._build_converse_messages(messages)
 
+        inference_config = {}
+        if (
+            "max_tokens" in invocation_parameters
+            and invocation_parameters["max_tokens"] is not None
+        ):
+            inference_config["maxTokens"] = invocation_parameters["max_tokens"]
+        if (
+            "temperature" in invocation_parameters
+            and invocation_parameters["temperature"] is not None
+        ):
+            inference_config["temperature"] = invocation_parameters["temperature"]
+        if "top_p" in invocation_parameters and invocation_parameters["top_p"] is not None:
+            inference_config["topP"] = invocation_parameters["top_p"]
+
         # Build the request parameters for Converse API
         converse_params: dict[str, Any] = {
-            "modelId": f"us.{self.model_name}",
+            "modelId": self.model_name,
             "messages": converse_messages,
-            "inferenceConfig": {
-                "maxTokens": invocation_parameters["max_tokens"],
-                "temperature": invocation_parameters["temperature"],
-                "topP": invocation_parameters["top_p"],
-            },
+            "inferenceConfig": inference_config,
         }
 
         # Add system prompt if available
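
Building `inferenceConfig` conditionally matters because the previous code indexed `invocation_parameters` directly and forwarded whatever was there, including `None`, which boto3's parameter validation would reject; now unset parameters are simply omitted and Bedrock falls back to the model's own defaults. The dropped `us.` prefix likewise stops forcing every model through a US cross-region inference profile. A sketch of the resulting call; model id and prompt are placeholders:

# Sketch of the resulting Converse API call; only parameters the user set
# appear in inferenceConfig. Model id and prompt are placeholders.
import boto3

client = boto3.Session(region_name="us-east-1").client("bedrock-runtime")
response = client.converse_stream(
    modelId="anthropic.claude-3-5-haiku-20241022-v1:0",
    messages=[{"role": "user", "content": [{"text": "hello"}]}],
    inferenceConfig={"maxTokens": 1024},  # temperature/topP omitted, not None
)
for event in response["stream"]:
    print(event)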
@@ -912,16 +945,26 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         bedrock_messages, system_prompt = self._build_bedrock_messages(messages)
         bedrock_params = {
             "anthropic_version": "bedrock-2023-05-31",
-            "max_tokens": invocation_parameters["max_tokens"],
             "messages": bedrock_messages,
             "system": system_prompt,
-            "temperature": invocation_parameters["temperature"],
-            "top_p": invocation_parameters["top_p"],
             "tools": tools,
         }
 
+        if (
+            "max_tokens" in invocation_parameters
+            and invocation_parameters["max_tokens"] is not None
+        ):
+            bedrock_params["max_tokens"] = invocation_parameters["max_tokens"]
+        if (
+            "temperature" in invocation_parameters
+            and invocation_parameters["temperature"] is not None
+        ):
+            bedrock_params["temperature"] = invocation_parameters["temperature"]
+        if "top_p" in invocation_parameters and invocation_parameters["top_p"] is not None:
+            bedrock_params["top_p"] = invocation_parameters["top_p"]
+
         response = self.client.invoke_model_with_response_stream(
-            modelId=f"us.{self.model_name}",  # or another Claude model
+            modelId=self.model_name,
             contentType="application/json",
             accept="application/json",
             body=json.dumps(bedrock_params),
@@ -1134,13 +1177,24 @@ class OpenAIStreamingClient(OpenAIBaseStreamingClient):
                 raise BadRequest("An API key is required for OpenAI models")
             api_key = "sk-fake-api-key"
 
-        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            default_headers=model.custom_headers or None,
+            timeout=30,
+        )
         super().__init__(client=client, model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.OPENAI.value
         self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.OPENAI.value
 
 
 _OPENAI_REASONING_MODELS = [
+    "gpt-5.2",
+    "gpt-5.2-2025-12-11",
+    "gpt-5.2-chat-latest",
+    "gpt-5.1",
+    "gpt-5.1-2025-11-13",
+    "gpt-5.1-chat-latest",
     "gpt-5",
     "gpt-5-mini",
     "gpt-5-nano",
@@ -1194,6 +1248,10 @@ class OpenAIReasoningReasoningModelsMixin:
                 label="Response Format",
                 canonical_name=CanonicalParameterName.RESPONSE_FORMAT,
             ),
+            JSONInvocationParameter(
+                invocation_name="extra_body",
+                label="Extra Body",
+            ),
         ]
 
 
@@ -1289,6 +1347,7 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
                 api_key=api_key,
                 azure_endpoint=endpoint,
                 api_version=api_version,
+                default_headers=model.custom_headers or None,
             )
         else:
             try:
@@ -1306,6 +1365,7 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
                 ),
                 azure_endpoint=endpoint,
                 api_version=api_version,
+                default_headers=model.custom_headers or None,
             )
         super().__init__(client=client, model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.AZURE.value
@@ -1423,13 +1483,8 @@ class AzureOpenAIReasoningNonStreamingClient(
     provider_key=GenerativeProviderKey.ANTHROPIC,
     model_names=[
         PROVIDER_DEFAULT,
-        "claude-3-5-sonnet-latest",
         "claude-3-5-haiku-latest",
-        "claude-3-5-sonnet-20241022",
         "claude-3-5-haiku-20241022",
-        "claude-3-5-sonnet-20240620",
-        "claude-3-opus-latest",
-        "claude-3-sonnet-20240229",
         "claude-3-haiku-20240307",
     ],
 )
@@ -1453,7 +1508,10 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
         if not api_key:
             raise BadRequest("An API key is required for Anthropic models")
 
-        self.client = anthropic.AsyncAnthropic(api_key=api_key)
+        self.client = anthropic.AsyncAnthropic(
+            api_key=api_key,
+            default_headers=model.custom_headers or None,
+        )
         self.model_name = model.name
         self.rate_limiter = PlaygroundRateLimiter(model.provider_key, anthropic.RateLimitError)
         self.client._client = _HttpxClient(self.client._client, self._attributes)
@@ -1489,7 +1547,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=1.0,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -1635,10 +1692,16 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
 @register_llm_client(
     provider_key=GenerativeProviderKey.ANTHROPIC,
     model_names=[
-        "claude-sonnet-4-0",
-        "claude-sonnet-4-20250514",
+        "claude-opus-4-5",
+        "claude-opus-4-5-20251101",
+        "claude-sonnet-4-5",
+        "claude-sonnet-4-5-20250929",
+        "claude-haiku-4-5",
+        "claude-haiku-4-5-20251001",
         "claude-opus-4-1",
         "claude-opus-4-1-20250805",
+        "claude-sonnet-4-0",
+        "claude-sonnet-4-20250514",
         "claude-opus-4-0",
         "claude-opus-4-20250514",
         "claude-3-7-sonnet-latest",
@@ -1663,7 +1726,6 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
     provider_key=GenerativeProviderKey.GOOGLE,
     model_names=[
         PROVIDER_DEFAULT,
-        "gemini-2.5-pro-preview-03-25",
         "gemini-2.0-flash-lite",
         "gemini-2.0-flash-001",
         "gemini-2.0-flash-thinking-exp-01-21",
@@ -1679,7 +1741,7 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         model: GenerativeModelInput,
         credentials: Optional[list[PlaygroundClientCredential]] = None,
     ) -> None:
-        import google.generativeai as google_genai
+        import google.genai as google_genai
 
         super().__init__(model=model, credentials=credentials)
         self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.GOOGLE.value
@@ -1696,12 +1758,12 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         if not api_key:
             raise BadRequest("An API key is required for Gemini models")
 
-        google_genai.configure(api_key=api_key)
+        self.client = google_genai.Client(api_key=api_key)
         self.model_name = model.name
 
     @classmethod
     def dependencies(cls) -> list[Dependency]:
-        return [Dependency(name="google-generativeai", module_name="google.generativeai")]
+        return [Dependency(name="google-genai", module_name="google.genai")]
 
     @classmethod
     def supported_invocation_parameters(cls) -> list[InvocationParameter]:
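
This block is the migration from the deprecated `google-generativeai` SDK to its successor `google-genai`: the old module-level `configure()` plus per-call `GenerativeModel` objects are replaced by a single `Client`, whose `aio` namespace exposes the async API used in the streaming rewrite below. A minimal sketch of the new call shape; model name and prompt are placeholders:

# Minimal google-genai streaming sketch mirroring the new client usage;
# model name and prompt are placeholders.
from google import genai
from google.genai import types


async def stream_gemini() -> None:
    client = genai.Client(api_key="...")
    stream = await client.aio.models.generate_content_stream(
        model="models/gemini-2.5-flash",
        contents=[{"role": "user", "parts": [{"text": "hello"}]}],
        config=types.GenerateContentConfig(temperature=0.7),
    )
    async for event in stream:
        if event.candidates and event.candidates[0].content:
            for part in event.candidates[0].content.parts or []:
                if part.text:
                    print(part.text, end="")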
@@ -1738,7 +1800,6 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_p",
                 canonical_name=CanonicalParameterName.TOP_P,
                 label="Top P",
-                default_value=1.0,
                 min_value=0.0,
                 max_value=1.0,
             ),
@@ -1746,6 +1807,11 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
                 invocation_name="top_k",
                 label="Top K",
             ),
+            JSONInvocationParameter(
+                invocation_name="tool_config",
+                label="Tool Config",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
         ]
 
     async def chat_completion_create(
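
The `tool_config` parameter is mapped to the canonical tool-choice slot; in the google-genai SDK it corresponds to the ToolConfig / FunctionCallingConfig structure that controls whether the model may, must, or must never call one of the declared functions. An illustrative value; the function name is a placeholder:

# Illustrative tool_config value; modes are AUTO (model decides), ANY (must
# call some tool), and NONE (never calls a tool). The function name is made up.
tool_config = {
    "function_calling_config": {
        "mode": "ANY",
        "allowed_function_names": ["get_weather"],
    }
}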
@@ -1756,28 +1822,25 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
         tools: list[JSONScalarType],
         **invocation_parameters: Any,
     ) -> AsyncIterator[ChatCompletionChunk]:
-        import google.generativeai as google_genai
+        from google.genai import types
 
-        google_message_history, current_message, system_prompt = self._build_google_messages(
-            messages
-        )
+        contents, system_prompt = self._build_google_messages(messages)
+
+        config_dict = invocation_parameters.copy()
 
-        model_args = {"model_name": self.model_name}
         if system_prompt:
-            model_args["system_instruction"] = system_prompt
-        client = google_genai.GenerativeModel(**model_args)
+            config_dict["system_instruction"] = system_prompt
 
-        google_config = google_genai.GenerationConfig(
-            **invocation_parameters,
+        if tools:
+            function_declarations = [types.FunctionDeclaration(**tool) for tool in tools]
+            config_dict["tools"] = [types.Tool(function_declarations=function_declarations)]
+
+        config = types.GenerateContentConfig.model_validate(config_dict)
+        stream = await self.client.aio.models.generate_content_stream(
+            model=f"models/{self.model_name}",
+            contents=contents,
+            config=config,
         )
-        google_params = {
-            "content": current_message,
-            "generation_config": google_config,
-            "stream": True,
-        }
-
-        chat = client.start_chat(history=google_message_history)
-        stream = await chat.send_message_async(**google_params)
         async for event in stream:
             self._attributes.update(
                 {
@@ -1786,31 +1849,148 @@ class GoogleStreamingClient(PlaygroundStreamingClient):
                     LLM_TOKEN_COUNT_TOTAL: event.usage_metadata.total_token_count,
                 }
             )
-            yield TextChunk(content=event.text)
+
+            if event.candidates:
+                candidate = event.candidates[0]
+                if candidate.content and candidate.content.parts:
+                    for part in candidate.content.parts:
+                        if function_call := part.function_call:
+                            yield ToolCallChunk(
+                                id=function_call.id or "",
+                                function=FunctionCallChunk(
+                                    name=function_call.name or "",
+                                    arguments=json.dumps(function_call.args or {}),
+                                ),
+                            )
+                        elif text := part.text:
+                            yield TextChunk(content=text)
 
     def _build_google_messages(
         self,
         messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
-    ) -> tuple[list["ContentType"], str, str]:
-        google_message_history: list["ContentType"] = []
+    ) -> tuple[list["ContentType"], str]:
+        """Build Google messages following the standard pattern - process ALL messages."""
+        google_messages: list["ContentType"] = []
         system_prompts = []
         for role, content, _tool_call_id, _tool_calls in messages:
             if role == ChatCompletionMessageRole.USER:
-                google_message_history.append({"role": "user", "parts": content})
+                google_messages.append({"role": "user", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.AI:
-                google_message_history.append({"role": "model", "parts": content})
+                google_messages.append({"role": "model", "parts": [{"text": content}]})
             elif role == ChatCompletionMessageRole.SYSTEM:
                 system_prompts.append(content)
             elif role == ChatCompletionMessageRole.TOOL:
                 raise NotImplementedError
             else:
                 assert_never(role)
-        if google_message_history:
-            prompt = google_message_history.pop()["parts"]
-        else:
-            prompt = ""
 
-        return google_message_history, prompt, "\n".join(system_prompts)
+        return google_messages, "\n".join(system_prompts)
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GOOGLE,
+    model_names=[
+        PROVIDER_DEFAULT,
+        "gemini-2.5-pro",
+        "gemini-2.5-flash",
+        "gemini-2.5-flash-lite",
+        "gemini-2.5-pro-preview-03-25",
+    ],
+)
+class Gemini25GoogleStreamingClient(GoogleStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            BoundedFloatInvocationParameter(
+                invocation_name="temperature",
+                canonical_name=CanonicalParameterName.TEMPERATURE,
+                label="Temperature",
+                default_value=1.0,
+                min_value=0.0,
+                max_value=2.0,
+            ),
+            IntInvocationParameter(
+                invocation_name="max_output_tokens",
+                canonical_name=CanonicalParameterName.MAX_COMPLETION_TOKENS,
+                label="Max Output Tokens",
+            ),
+            StringListInvocationParameter(
+                invocation_name="stop_sequences",
+                canonical_name=CanonicalParameterName.STOP_SEQUENCES,
+                label="Stop Sequences",
+            ),
+            BoundedFloatInvocationParameter(
+                invocation_name="top_p",
+                canonical_name=CanonicalParameterName.TOP_P,
+                label="Top P",
+                min_value=0.0,
+                max_value=1.0,
+            ),
+            FloatInvocationParameter(
+                invocation_name="top_k",
+                label="Top K",
+            ),
+            JSONInvocationParameter(
+                invocation_name="tool_config",
+                label="Tool Choice",
+                canonical_name=CanonicalParameterName.TOOL_CHOICE,
+            ),
+        ]
+
+
+@register_llm_client(
+    provider_key=GenerativeProviderKey.GOOGLE,
+    model_names=[
+        "gemini-3-pro-preview",
+    ],
+)
+class Gemini3GoogleStreamingClient(Gemini25GoogleStreamingClient):
+    @classmethod
+    def supported_invocation_parameters(cls) -> list[InvocationParameter]:
+        return [
+            StringInvocationParameter(
+                invocation_name="thinking_level",
+                label="Thinking Level",
+                canonical_name=CanonicalParameterName.REASONING_EFFORT,
+            ),
+            *super().supported_invocation_parameters(),
+        ]
+
+    async def chat_completion_create(
+        self,
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
+        **invocation_parameters: Any,
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        # Extract thinking_level and construct thinking_config
+        thinking_level = invocation_parameters.pop("thinking_level", None)
+
+        if thinking_level:
+            try:
+                import google.genai
+                from packaging.version import parse as parse_version
+
+                if parse_version(google.genai.__version__) < parse_version("1.50.0"):
+                    raise ImportError
+            except (ImportError, AttributeError):
+                raise BadRequest(
+                    "Reasoning capabilities for Gemini models require `google-genai>=1.50.0` "
+                    "and Python >= 3.10."
+                )
+
+            # NOTE: as of gemini 1.51.0 medium thinking is not supported
+            # but will eventually be added in a future version
+            # we are purposefully allowing users to select medium knowing
+            # it does not work.
+            invocation_parameters["thinking_config"] = {
+                "include_thoughts": True,
+                "thinking_level": thinking_level.upper(),
+            }
+
+        async for chunk in super().chat_completion_create(messages, tools, **invocation_parameters):
+            yield chunk
 
 
 def initialize_playground_clients() -> None:
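
The Gemini 3 subclass exposes `thinking_level` as a flat invocation parameter, then folds it into the nested `thinking_config` structure that `GenerateContentConfig` expects before delegating to the base class. An illustrative final config for a request with thinking_level set to "high"; values are placeholders, and per the guard above this assumes google-genai>=1.50.0:

# Illustrative config_dict handed to GenerateContentConfig.model_validate()
# for a Gemini 3 request; values are placeholders.
config_dict = {
    "temperature": 1.0,
    "thinking_config": {
        "include_thoughts": True,  # surface thought summaries in the stream
        "thinking_level": "HIGH",  # upper-cased from the UI's "high"
    },
}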