google-genai 1.53.0__py3-none-any.whl → 1.55.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (324):
  1. google/genai/__init__.py +1 -0
  2. google/genai/_api_client.py +6 -6
  3. google/genai/_interactions/__init__.py +117 -0
  4. google/genai/_interactions/_base_client.py +2019 -0
  5. google/genai/_interactions/_client.py +511 -0
  6. google/genai/_interactions/_compat.py +234 -0
  7. google/genai/_interactions/_constants.py +29 -0
  8. google/genai/_interactions/_exceptions.py +122 -0
  9. google/genai/_interactions/_files.py +139 -0
  10. google/genai/_interactions/_models.py +873 -0
  11. google/genai/_interactions/_qs.py +165 -0
  12. google/genai/_interactions/_resource.py +58 -0
  13. google/genai/_interactions/_response.py +847 -0
  14. google/genai/_interactions/_streaming.py +354 -0
  15. google/genai/_interactions/_types.py +276 -0
  16. google/genai/_interactions/_utils/__init__.py +79 -0
  17. google/genai/_interactions/_utils/_compat.py +61 -0
  18. google/genai/_interactions/_utils/_datetime_parse.py +151 -0
  19. google/genai/_interactions/_utils/_logs.py +40 -0
  20. google/genai/_interactions/_utils/_proxy.py +80 -0
  21. google/genai/_interactions/_utils/_reflection.py +57 -0
  22. google/genai/_interactions/_utils/_resources_proxy.py +39 -0
  23. google/genai/_interactions/_utils/_streams.py +27 -0
  24. google/genai/_interactions/_utils/_sync.py +73 -0
  25. google/genai/_interactions/_utils/_transform.py +472 -0
  26. google/genai/_interactions/_utils/_typing.py +172 -0
  27. google/genai/_interactions/_utils/_utils.py +437 -0
  28. google/genai/_interactions/_version.py +18 -0
  29. google/genai/_interactions/resources/__init__.py +34 -0
  30. google/genai/_interactions/resources/interactions.py +1350 -0
  31. google/genai/_interactions/types/__init__.py +107 -0
  32. google/genai/_interactions/types/allowed_tools.py +33 -0
  33. google/genai/_interactions/types/allowed_tools_param.py +35 -0
  34. google/genai/_interactions/types/annotation.py +42 -0
  35. google/genai/_interactions/types/annotation_param.py +42 -0
  36. google/genai/_interactions/types/audio_content.py +38 -0
  37. google/genai/_interactions/types/audio_content_param.py +45 -0
  38. google/genai/_interactions/types/audio_mime_type.py +25 -0
  39. google/genai/_interactions/types/audio_mime_type_param.py +27 -0
  40. google/genai/_interactions/types/code_execution_call_arguments.py +33 -0
  41. google/genai/_interactions/types/code_execution_call_arguments_param.py +32 -0
  42. google/genai/_interactions/types/code_execution_call_content.py +37 -0
  43. google/genai/_interactions/types/code_execution_call_content_param.py +37 -0
  44. google/genai/_interactions/types/code_execution_result_content.py +42 -0
  45. google/genai/_interactions/types/code_execution_result_content_param.py +41 -0
  46. google/genai/_interactions/types/content_delta.py +358 -0
  47. google/genai/_interactions/types/content_start.py +79 -0
  48. google/genai/_interactions/types/content_stop.py +35 -0
  49. google/genai/_interactions/types/deep_research_agent_config.py +33 -0
  50. google/genai/_interactions/types/deep_research_agent_config_param.py +32 -0
  51. google/genai/_interactions/types/document_content.py +36 -0
  52. google/genai/_interactions/types/document_content_param.py +43 -0
  53. google/genai/_interactions/types/dynamic_agent_config.py +44 -0
  54. google/genai/_interactions/types/dynamic_agent_config_param.py +33 -0
  55. google/genai/_interactions/types/error_event.py +46 -0
  56. google/genai/_interactions/types/file_search_result_content.py +46 -0
  57. google/genai/_interactions/types/file_search_result_content_param.py +46 -0
  58. google/genai/_interactions/types/function.py +38 -0
  59. google/genai/_interactions/types/function_call_content.py +39 -0
  60. google/genai/_interactions/types/function_call_content_param.py +39 -0
  61. google/genai/_interactions/types/function_param.py +37 -0
  62. google/genai/_interactions/types/function_result_content.py +52 -0
  63. google/genai/_interactions/types/function_result_content_param.py +54 -0
  64. google/genai/_interactions/types/generation_config.py +57 -0
  65. google/genai/_interactions/types/generation_config_param.py +59 -0
  66. google/genai/_interactions/types/google_search_call_arguments.py +29 -0
  67. google/genai/_interactions/types/google_search_call_arguments_param.py +31 -0
  68. google/genai/_interactions/types/google_search_call_content.py +37 -0
  69. google/genai/_interactions/types/google_search_call_content_param.py +37 -0
  70. google/genai/_interactions/types/google_search_result.py +35 -0
  71. google/genai/_interactions/types/google_search_result_content.py +43 -0
  72. google/genai/_interactions/types/google_search_result_content_param.py +44 -0
  73. google/genai/_interactions/types/google_search_result_param.py +35 -0
  74. google/genai/_interactions/types/image_content.py +41 -0
  75. google/genai/_interactions/types/image_content_param.py +48 -0
  76. google/genai/_interactions/types/image_mime_type.py +23 -0
  77. google/genai/_interactions/types/image_mime_type_param.py +25 -0
  78. google/genai/_interactions/types/interaction.py +165 -0
  79. google/genai/_interactions/types/interaction_create_params.py +212 -0
  80. google/genai/_interactions/types/interaction_event.py +37 -0
  81. google/genai/_interactions/types/interaction_get_params.py +46 -0
  82. google/genai/_interactions/types/interaction_sse_event.py +32 -0
  83. google/genai/_interactions/types/interaction_status_update.py +37 -0
  84. google/genai/_interactions/types/mcp_server_tool_call_content.py +42 -0
  85. google/genai/_interactions/types/mcp_server_tool_call_content_param.py +42 -0
  86. google/genai/_interactions/types/mcp_server_tool_result_content.py +52 -0
  87. google/genai/_interactions/types/mcp_server_tool_result_content_param.py +54 -0
  88. google/genai/_interactions/types/model.py +36 -0
  89. google/genai/_interactions/types/model_param.py +38 -0
  90. google/genai/_interactions/types/speech_config.py +35 -0
  91. google/genai/_interactions/types/speech_config_param.py +35 -0
  92. google/genai/_interactions/types/text_content.py +37 -0
  93. google/genai/_interactions/types/text_content_param.py +38 -0
  94. google/genai/_interactions/types/thinking_level.py +22 -0
  95. google/genai/_interactions/types/thought_content.py +41 -0
  96. google/genai/_interactions/types/thought_content_param.py +47 -0
  97. google/genai/_interactions/types/tool.py +100 -0
  98. google/genai/_interactions/types/tool_choice.py +26 -0
  99. google/genai/_interactions/types/tool_choice_config.py +28 -0
  100. google/genai/_interactions/types/tool_choice_config_param.py +29 -0
  101. google/genai/_interactions/types/tool_choice_param.py +28 -0
  102. google/genai/_interactions/types/tool_choice_type.py +22 -0
  103. google/genai/_interactions/types/tool_param.py +97 -0
  104. google/genai/_interactions/types/turn.py +76 -0
  105. google/genai/_interactions/types/turn_param.py +73 -0
  106. google/genai/_interactions/types/url_context_call_arguments.py +29 -0
  107. google/genai/_interactions/types/url_context_call_arguments_param.py +31 -0
  108. google/genai/_interactions/types/url_context_call_content.py +37 -0
  109. google/genai/_interactions/types/url_context_call_content_param.py +37 -0
  110. google/genai/_interactions/types/url_context_result.py +33 -0
  111. google/genai/_interactions/types/url_context_result_content.py +43 -0
  112. google/genai/_interactions/types/url_context_result_content_param.py +44 -0
  113. google/genai/_interactions/types/url_context_result_param.py +32 -0
  114. google/genai/_interactions/types/usage.py +106 -0
  115. google/genai/_interactions/types/usage_param.py +106 -0
  116. google/genai/_interactions/types/video_content.py +41 -0
  117. google/genai/_interactions/types/video_content_param.py +48 -0
  118. google/genai/_interactions/types/video_mime_type.py +36 -0
  119. google/genai/_interactions/types/video_mime_type_param.py +38 -0
  120. google/genai/_live_converters.py +34 -3
  121. google/genai/_tokens_converters.py +5 -0
  122. google/genai/batches.py +62 -55
  123. google/genai/client.py +223 -0
  124. google/genai/errors.py +16 -1
  125. google/genai/file_search_stores.py +60 -60
  126. google/genai/files.py +56 -56
  127. google/genai/interactions.py +17 -0
  128. google/genai/live.py +4 -3
  129. google/genai/models.py +15 -3
  130. google/genai/tests/__init__.py +21 -0
  131. google/genai/tests/afc/__init__.py +21 -0
  132. google/genai/tests/afc/test_convert_if_exist_pydantic_model.py +309 -0
  133. google/genai/tests/afc/test_convert_number_values_for_function_call_args.py +63 -0
  134. google/genai/tests/afc/test_find_afc_incompatible_tool_indexes.py +240 -0
  135. google/genai/tests/afc/test_generate_content_stream_afc.py +530 -0
  136. google/genai/tests/afc/test_generate_content_stream_afc_thoughts.py +77 -0
  137. google/genai/tests/afc/test_get_function_map.py +176 -0
  138. google/genai/tests/afc/test_get_function_response_parts.py +277 -0
  139. google/genai/tests/afc/test_get_max_remote_calls_for_afc.py +130 -0
  140. google/genai/tests/afc/test_invoke_function_from_dict_args.py +241 -0
  141. google/genai/tests/afc/test_raise_error_for_afc_incompatible_config.py +159 -0
  142. google/genai/tests/afc/test_should_append_afc_history.py +53 -0
  143. google/genai/tests/afc/test_should_disable_afc.py +214 -0
  144. google/genai/tests/batches/__init__.py +17 -0
  145. google/genai/tests/batches/test_cancel.py +77 -0
  146. google/genai/tests/batches/test_create.py +78 -0
  147. google/genai/tests/batches/test_create_with_bigquery.py +113 -0
  148. google/genai/tests/batches/test_create_with_file.py +82 -0
  149. google/genai/tests/batches/test_create_with_gcs.py +125 -0
  150. google/genai/tests/batches/test_create_with_inlined_requests.py +255 -0
  151. google/genai/tests/batches/test_delete.py +86 -0
  152. google/genai/tests/batches/test_embedding.py +157 -0
  153. google/genai/tests/batches/test_get.py +78 -0
  154. google/genai/tests/batches/test_list.py +79 -0
  155. google/genai/tests/caches/__init__.py +17 -0
  156. google/genai/tests/caches/constants.py +29 -0
  157. google/genai/tests/caches/test_create.py +210 -0
  158. google/genai/tests/caches/test_create_custom_url.py +105 -0
  159. google/genai/tests/caches/test_delete.py +54 -0
  160. google/genai/tests/caches/test_delete_custom_url.py +52 -0
  161. google/genai/tests/caches/test_get.py +94 -0
  162. google/genai/tests/caches/test_get_custom_url.py +52 -0
  163. google/genai/tests/caches/test_list.py +68 -0
  164. google/genai/tests/caches/test_update.py +70 -0
  165. google/genai/tests/caches/test_update_custom_url.py +58 -0
  166. google/genai/tests/chats/__init__.py +1 -0
  167. google/genai/tests/chats/test_get_history.py +597 -0
  168. google/genai/tests/chats/test_send_message.py +844 -0
  169. google/genai/tests/chats/test_validate_response.py +90 -0
  170. google/genai/tests/client/__init__.py +17 -0
  171. google/genai/tests/client/test_async_stream.py +427 -0
  172. google/genai/tests/client/test_client_close.py +197 -0
  173. google/genai/tests/client/test_client_initialization.py +1687 -0
  174. google/genai/tests/client/test_client_requests.py +355 -0
  175. google/genai/tests/client/test_custom_client.py +77 -0
  176. google/genai/tests/client/test_http_options.py +178 -0
  177. google/genai/tests/client/test_replay_client_equality.py +168 -0
  178. google/genai/tests/client/test_retries.py +846 -0
  179. google/genai/tests/client/test_upload_errors.py +136 -0
  180. google/genai/tests/common/__init__.py +17 -0
  181. google/genai/tests/common/test_common.py +954 -0
  182. google/genai/tests/conftest.py +162 -0
  183. google/genai/tests/documents/__init__.py +17 -0
  184. google/genai/tests/documents/test_delete.py +51 -0
  185. google/genai/tests/documents/test_get.py +85 -0
  186. google/genai/tests/documents/test_list.py +72 -0
  187. google/genai/tests/errors/__init__.py +1 -0
  188. google/genai/tests/errors/test_api_error.py +417 -0
  189. google/genai/tests/file_search_stores/__init__.py +17 -0
  190. google/genai/tests/file_search_stores/test_create.py +66 -0
  191. google/genai/tests/file_search_stores/test_delete.py +64 -0
  192. google/genai/tests/file_search_stores/test_get.py +94 -0
  193. google/genai/tests/file_search_stores/test_import_file.py +112 -0
  194. google/genai/tests/file_search_stores/test_list.py +57 -0
  195. google/genai/tests/file_search_stores/test_upload_to_file_search_store.py +141 -0
  196. google/genai/tests/files/__init__.py +17 -0
  197. google/genai/tests/files/test_delete.py +46 -0
  198. google/genai/tests/files/test_download.py +85 -0
  199. google/genai/tests/files/test_get.py +46 -0
  200. google/genai/tests/files/test_list.py +72 -0
  201. google/genai/tests/files/test_upload.py +255 -0
  202. google/genai/tests/imports/test_no_optional_imports.py +28 -0
  203. google/genai/tests/interactions/__init__.py +0 -0
  204. google/genai/tests/interactions/test_integration.py +80 -0
  205. google/genai/tests/live/__init__.py +16 -0
  206. google/genai/tests/live/test_live.py +2177 -0
  207. google/genai/tests/live/test_live_music.py +362 -0
  208. google/genai/tests/live/test_live_response.py +163 -0
  209. google/genai/tests/live/test_send_client_content.py +147 -0
  210. google/genai/tests/live/test_send_realtime_input.py +268 -0
  211. google/genai/tests/live/test_send_tool_response.py +222 -0
  212. google/genai/tests/local_tokenizer/__init__.py +17 -0
  213. google/genai/tests/local_tokenizer/test_local_tokenizer.py +343 -0
  214. google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py +235 -0
  215. google/genai/tests/mcp/__init__.py +17 -0
  216. google/genai/tests/mcp/test_has_mcp_tool_usage.py +89 -0
  217. google/genai/tests/mcp/test_mcp_to_gemini_tools.py +191 -0
  218. google/genai/tests/mcp/test_parse_config_for_mcp_sessions.py +201 -0
  219. google/genai/tests/mcp/test_parse_config_for_mcp_usage.py +130 -0
  220. google/genai/tests/mcp/test_set_mcp_usage_header.py +72 -0
  221. google/genai/tests/models/__init__.py +17 -0
  222. google/genai/tests/models/constants.py +8 -0
  223. google/genai/tests/models/test_compute_tokens.py +120 -0
  224. google/genai/tests/models/test_count_tokens.py +159 -0
  225. google/genai/tests/models/test_delete.py +107 -0
  226. google/genai/tests/models/test_edit_image.py +264 -0
  227. google/genai/tests/models/test_embed_content.py +94 -0
  228. google/genai/tests/models/test_function_call_streaming.py +442 -0
  229. google/genai/tests/models/test_generate_content.py +2502 -0
  230. google/genai/tests/models/test_generate_content_cached_content.py +132 -0
  231. google/genai/tests/models/test_generate_content_config_zero_value.py +103 -0
  232. google/genai/tests/models/test_generate_content_from_apikey.py +44 -0
  233. google/genai/tests/models/test_generate_content_http_options.py +40 -0
  234. google/genai/tests/models/test_generate_content_image_generation.py +143 -0
  235. google/genai/tests/models/test_generate_content_mcp.py +343 -0
  236. google/genai/tests/models/test_generate_content_media_resolution.py +97 -0
  237. google/genai/tests/models/test_generate_content_model.py +139 -0
  238. google/genai/tests/models/test_generate_content_part.py +821 -0
  239. google/genai/tests/models/test_generate_content_thought.py +76 -0
  240. google/genai/tests/models/test_generate_content_tools.py +1761 -0
  241. google/genai/tests/models/test_generate_images.py +191 -0
  242. google/genai/tests/models/test_generate_videos.py +759 -0
  243. google/genai/tests/models/test_get.py +104 -0
  244. google/genai/tests/models/test_list.py +233 -0
  245. google/genai/tests/models/test_recontext_image.py +189 -0
  246. google/genai/tests/models/test_segment_image.py +148 -0
  247. google/genai/tests/models/test_update.py +95 -0
  248. google/genai/tests/models/test_upscale_image.py +157 -0
  249. google/genai/tests/operations/__init__.py +17 -0
  250. google/genai/tests/operations/test_get.py +38 -0
  251. google/genai/tests/public_samples/__init__.py +17 -0
  252. google/genai/tests/public_samples/test_gemini_text_only.py +34 -0
  253. google/genai/tests/pytest_helper.py +229 -0
  254. google/genai/tests/shared/__init__.py +16 -0
  255. google/genai/tests/shared/batches/__init__.py +14 -0
  256. google/genai/tests/shared/batches/test_create_delete.py +57 -0
  257. google/genai/tests/shared/batches/test_create_get_cancel.py +56 -0
  258. google/genai/tests/shared/batches/test_list.py +40 -0
  259. google/genai/tests/shared/caches/__init__.py +14 -0
  260. google/genai/tests/shared/caches/test_create_get_delete.py +67 -0
  261. google/genai/tests/shared/caches/test_create_update_get.py +71 -0
  262. google/genai/tests/shared/caches/test_list.py +40 -0
  263. google/genai/tests/shared/chats/__init__.py +14 -0
  264. google/genai/tests/shared/chats/test_send_message.py +48 -0
  265. google/genai/tests/shared/chats/test_send_message_stream.py +50 -0
  266. google/genai/tests/shared/files/__init__.py +14 -0
  267. google/genai/tests/shared/files/test_list.py +41 -0
  268. google/genai/tests/shared/files/test_upload_get_delete.py +54 -0
  269. google/genai/tests/shared/models/__init__.py +14 -0
  270. google/genai/tests/shared/models/test_compute_tokens.py +41 -0
  271. google/genai/tests/shared/models/test_count_tokens.py +40 -0
  272. google/genai/tests/shared/models/test_edit_image.py +67 -0
  273. google/genai/tests/shared/models/test_embed.py +40 -0
  274. google/genai/tests/shared/models/test_generate_content.py +39 -0
  275. google/genai/tests/shared/models/test_generate_content_stream.py +54 -0
  276. google/genai/tests/shared/models/test_generate_images.py +40 -0
  277. google/genai/tests/shared/models/test_generate_videos.py +38 -0
  278. google/genai/tests/shared/models/test_list.py +37 -0
  279. google/genai/tests/shared/models/test_recontext_image.py +55 -0
  280. google/genai/tests/shared/models/test_segment_image.py +52 -0
  281. google/genai/tests/shared/models/test_upscale_image.py +52 -0
  282. google/genai/tests/shared/tunings/__init__.py +16 -0
  283. google/genai/tests/shared/tunings/test_create.py +46 -0
  284. google/genai/tests/shared/tunings/test_create_get_cancel.py +56 -0
  285. google/genai/tests/shared/tunings/test_list.py +39 -0
  286. google/genai/tests/tokens/__init__.py +16 -0
  287. google/genai/tests/tokens/test_create.py +154 -0
  288. google/genai/tests/transformers/__init__.py +17 -0
  289. google/genai/tests/transformers/test_blobs.py +71 -0
  290. google/genai/tests/transformers/test_bytes.py +15 -0
  291. google/genai/tests/transformers/test_duck_type.py +96 -0
  292. google/genai/tests/transformers/test_function_responses.py +72 -0
  293. google/genai/tests/transformers/test_schema.py +653 -0
  294. google/genai/tests/transformers/test_t_batch.py +286 -0
  295. google/genai/tests/transformers/test_t_content.py +160 -0
  296. google/genai/tests/transformers/test_t_contents.py +398 -0
  297. google/genai/tests/transformers/test_t_part.py +85 -0
  298. google/genai/tests/transformers/test_t_parts.py +87 -0
  299. google/genai/tests/transformers/test_t_tool.py +157 -0
  300. google/genai/tests/transformers/test_t_tools.py +195 -0
  301. google/genai/tests/tunings/__init__.py +16 -0
  302. google/genai/tests/tunings/test_cancel.py +39 -0
  303. google/genai/tests/tunings/test_end_to_end.py +106 -0
  304. google/genai/tests/tunings/test_get.py +67 -0
  305. google/genai/tests/tunings/test_list.py +75 -0
  306. google/genai/tests/tunings/test_tune.py +268 -0
  307. google/genai/tests/types/__init__.py +16 -0
  308. google/genai/tests/types/test_bytes_internal.py +271 -0
  309. google/genai/tests/types/test_bytes_type.py +152 -0
  310. google/genai/tests/types/test_future.py +101 -0
  311. google/genai/tests/types/test_optional_types.py +36 -0
  312. google/genai/tests/types/test_part_type.py +616 -0
  313. google/genai/tests/types/test_schema_from_json_schema.py +417 -0
  314. google/genai/tests/types/test_schema_json_schema.py +468 -0
  315. google/genai/tests/types/test_types.py +2903 -0
  316. google/genai/tunings.py +57 -57
  317. google/genai/types.py +229 -121
  318. google/genai/version.py +1 -1
  319. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/METADATA +4 -2
  320. google_genai-1.55.0.dist-info/RECORD +345 -0
  321. google_genai-1.53.0.dist-info/RECORD +0 -41
  322. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/WHEEL +0 -0
  323. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/licenses/LICENSE +0 -0
  324. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2177 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+
17
+ """Tests for live.py."""
18
+
19
+ import contextlib
20
+ import json
21
+ import os
22
+ import ssl
23
+ import typing
24
+ from typing import Any, AsyncIterator
25
+ from unittest import mock
26
+ from unittest.mock import AsyncMock
27
+ from unittest.mock import Mock
28
+ from unittest.mock import patch
29
+ import warnings
30
+
31
+ import certifi
32
+ from google.oauth2.credentials import Credentials
33
+ import pytest
34
+ from websockets import client
35
+
36
+ from .. import pytest_helper
37
+ from ... import _api_client as api_client
38
+ from ... import _common
39
+ from ... import Client
40
+ from ... import client as gl_client
41
+ from ... import live
42
+ from ... import types
43
+ try:
44
+ import aiohttp
45
+ AIOHTTP_NOT_INSTALLED = False
46
+ except ImportError:
47
+ AIOHTTP_NOT_INSTALLED = True
48
+ aiohttp = mock.MagicMock()
49
+
50
+
51
+ if typing.TYPE_CHECKING:
52
+ from mcp import types as mcp_types
53
+ from mcp import ClientSession as McpClientSession
54
+ else:
55
+ mcp_types: typing.Type = Any
56
+ McpClientSession: typing.Type = Any
57
+ try:
58
+ from mcp import types as mcp_types
59
+ from mcp import ClientSession as McpClientSession
60
+ except ImportError:
61
+ mcp_types = None
62
+ McpClientSession = None
63
+
64
+
65
# Skip marker for tests that need the optional aiohttp dependency;
# AIOHTTP_NOT_INSTALLED is set by the guarded import at the top of the module.
requires_aiohttp = pytest.mark.skipif(
    AIOHTTP_NOT_INSTALLED,
    reason="aiohttp is not installed, skipping test.",
)
68
+
69
# Hand-written function declaration for the weather tool; its name matches
# the `get_current_weather` Python callable defined in this module.
function_declarations = [
    {
        'name': 'get_current_weather',
        'description': 'Get the current weather in a city',
        'parameters': {
            'type': 'OBJECT',
            'properties': {
                'location': {
                    'type': 'STRING',
                    'description': 'The location to get the weather for',
                },
                'unit': {
                    'type': 'STRING',
                    'enum': ['C', 'F'],
                },
            },
        },
    },
]
86
+
87
+
88
def get_current_weather(location: str, unit: str):
    """Get the current weather in a city."""
    # Canned values: `location` is ignored, and any unit other than 'C'
    # yields the Fahrenheit reading.
    if unit == 'C':
        return 15
    return 59
91
+
92
+
93
def mock_api_client(vertexai=False, credentials=None, http_options=None):
    """Build a ``MagicMock`` specced on ``BaseApiClient`` for live tests.

    Args:
      vertexai: When falsy, the mock gets a test API key and no
        project/location. When truthy, the API key is ``None``.
      credentials: Value stored on the mock's ``_credentials`` attribute.
      http_options: Only consulted when ``vertexai`` is truthy. A dict is
        converted to ``types.HttpOptions``; its ``base_url`` becomes the mock's
        ``custom_base_url`` and project/location are cleared.

    Returns:
      The configured mock, whose ``_api_client`` attribute refers to itself.
    """
    fake_client = mock.MagicMock(spec=gl_client.BaseApiClient)

    if vertexai and http_options:
        # Vertex with a custom endpoint: no key, no project/location.
        fake_client.api_key = None
        options = http_options
        if isinstance(options, dict):
            options = types.HttpOptions(**options)
        fake_client.custom_base_url = options.base_url
        fake_client.location = None
        fake_client.project = None
    elif vertexai:
        # Plain Vertex: project/location identify the endpoint.
        fake_client.api_key = None
        fake_client.location = 'us-central1'
        fake_client.project = 'test_project'
        fake_client.custom_base_url = None
    else:
        # API-key based client.
        fake_client.api_key = 'TEST_API_KEY'
        fake_client.location = None
        fake_client.project = None
        fake_client.custom_base_url = None

    fake_client._host = lambda: 'test_host'
    fake_client._credentials = credentials
    # Ensure a headers dict exists on the HTTP options.
    fake_client._http_options = types.HttpOptions.model_validate(
        {'headers': {}}
    )
    fake_client.vertexai = vertexai
    fake_client._api_client = fake_client

    ssl_context = ssl.create_default_context(
        cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
        capath=os.environ.get("SSL_CERT_DIR"),
    )
    fake_client._websocket_ssl_ctx = {'ssl': ssl_context}
    return fake_client
129
+
130
+
131
@pytest.fixture
def mock_websocket():
    """Provide an ``AsyncMock`` websocket specced on ``ClientConnection``.

    ``send``, ``recv`` and ``close`` are async mocks; ``recv`` returns a
    turn-complete server message by default.
    """
    default_reply = '{"serverContent": {"turnComplete": true}}'

    connection = AsyncMock(spec=client.ClientConnection)
    connection.send = AsyncMock()
    connection.recv = AsyncMock(return_value=default_reply)
    connection.close = AsyncMock()
    return connection
140
+
141
+
142
async def get_connect_message(api_client, model, config=None):
    """Connect through a mocked websocket and return the message that was sent.

    ``google.auth.default`` and ``live.ws_connect`` are patched while the
    connection is opened.

    Args:
      api_client: Client object passed to ``live.AsyncLive``.
      model: Model name forwarded to ``connect``.
      config: Optional connect config; ``None`` is replaced by an empty dict.

    Returns:
      The JSON-decoded first positional argument of the single ``send`` call
      made on the mocked websocket.
    """
    connect_config = {} if config is None else config

    setup_reply = b'{\n "setupComplete": {"sessionId": "test_session_id"}\n}\n'
    fake_socket = AsyncMock()
    fake_socket.send = AsyncMock()
    fake_socket.recv = AsyncMock(return_value=setup_reply)

    # Stand-in for google.auth.default(): returns (credentials, project).
    fake_credentials = Mock(token='test_token')
    fake_auth_default = Mock(return_value=(fake_credentials, None))

    @contextlib.asynccontextmanager
    async def fake_connect(uri, additional_headers=None, **kwargs):
        yield fake_socket

    @patch('google.auth.default', new=fake_auth_default)
    @patch.object(live, 'ws_connect', new=fake_connect)
    async def _connect_and_capture():
        live_module = live.AsyncLive(api_client)
        async with live_module.connect(model=model, config=connect_config):
            pass

        fake_socket.send.assert_called_once()
        return json.loads(fake_socket.send.call_args.args[0])

    return await _connect_and_capture()
175
+
176
+
177
async def _async_iterator_to_list(async_iter):
    """Drain an async iterable and return its items as a list, in order."""
    collected = []
    async for item in async_iter:
        collected.append(item)
    return collected
179
+
180
+
181
def test_mldev_from_env(monkeypatch):
    """A client built from GOOGLE_API_KEY exposes a non-Vertex live API client."""
    api_key = 'google_api_key'
    monkeypatch.setenv('GOOGLE_API_KEY', api_key)

    live_api_client = Client().aio.live._api_client

    assert not live_api_client.vertexai
    assert live_api_client.api_key == api_key
    assert isinstance(live_api_client, api_client.BaseApiClient)
    # The key is also sent as a request header.
    assert live_api_client._http_options.headers['x-goog-api-key'] == api_key
191
+
192
+
193
@requires_aiohttp
def test_vertex_from_env(monkeypatch):
    """Vertex settings taken from the environment reach the live API client."""
    project_id = 'fake_project_id'
    location = 'fake-location'
    for name, value in (
        ('GOOGLE_GENAI_USE_VERTEXAI', 'true'),
        ('GOOGLE_CLOUD_PROJECT', project_id),
        ('GOOGLE_CLOUD_LOCATION', location),
    ):
        monkeypatch.setenv(name, value)

    live_api_client = Client().aio.live._api_client

    assert live_api_client.custom_base_url is None
    assert live_api_client.vertexai
    assert live_api_client.project == project_id
    assert isinstance(live_api_client, api_client.BaseApiClient)
    # No API key was configured, so no key header is expected.
    assert 'x-goog-api-key' not in live_api_client._http_options.headers
208
+
209
+
210
def test_vertex_api_key_from_env(monkeypatch):
    """Vertex mode with only an API key in the environment uses that key."""
    api_key = 'google_api_key'
    monkeypatch.setenv('GOOGLE_GENAI_USE_VERTEXAI', 'true')
    monkeypatch.setenv('GOOGLE_API_KEY', api_key)

    # Project/location take precedence over the API key, so blank them out
    # here. Precedence itself is covered comprehensively in
    # client/test_client_initialization.py.
    for name in ("GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_PROJECT"):
        monkeypatch.setenv(name, "")

    live_api_client = Client().aio.live._api_client

    assert live_api_client.vertexai
    assert live_api_client.api_key == api_key
    assert isinstance(live_api_client, api_client.BaseApiClient)
    assert live_api_client._http_options.headers['x-goog-api-key'] == api_key
227
+
228
+
229
def test_websocket_base_url():
    """With an API key, an https base URL is reported with the wss scheme."""
    http_options = {'base_url': 'https://test.com'}
    base_client = gl_client.BaseApiClient(
        api_key='google_api_key',
        http_options=http_options,
    )

    websocket_url = base_client._websocket_base_url()

    assert websocket_url == 'wss://test.com'
236
+
237
+
238
def test_websocket_base_url_no_auth_with_custom_base_url():
    """Without API key or project/location the custom base URL is kept as is."""
    base_url = 'https://test-api-gateway-proxy.com'
    http_options = {
        'base_url': base_url,
        'headers': {'Authorization': 'Bearer test_token'},
    }
    base_client = gl_client.BaseApiClient(
        vertexai=True,
        http_options=http_options,
    )
    # The test environment does have project/location set, so clear them
    # explicitly on the constructed client.
    base_client.project = None
    base_client.location = None

    # Fully pass the custom base url if no API key or project/location.
    assert base_client._websocket_base_url() == base_url
254
+
255
+
256
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_text(mock_websocket, vertexai):
    """A plain string input is sent as a `client_content` message."""
    fake_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(
        api_client=fake_client,
        websocket=mock_websocket,
    )

    await session.send(input='test')

    mock_websocket.send.assert_called_once()
    raw_message = mock_websocket.send.call_args.args[0]
    sent_data = json.loads(raw_message)
    assert 'client_content' in sent_data
266
+
267
+
268
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_content_dict(mock_websocket, vertexai):
    """A client-content style dict is sent as a `client_content` message."""
    fake_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(
        api_client=fake_client,
        websocket=mock_websocket,
    )
    text_part = {'text': 'test'}
    client_content = {
        'content': [{'parts': [text_part]}],
        'turn_complete': True,
    }

    await session.send(input=client_content)

    mock_websocket.send.assert_called_once()
    raw_message = mock_websocket.send.call_args.args[0]
    sent_data = json.loads(raw_message)
    assert 'client_content' in sent_data
284
+
285
+
286
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_content(mock_websocket, vertexai):
    """A ``LiveClientContent`` object is sent as a `client_content` message."""
    fake_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(
        api_client=fake_client,
        websocket=mock_websocket,
    )
    single_turn = types.Content(parts=[types.Part(text='test')])
    client_content = types.LiveClientContent(
        turns=[single_turn],
        turn_complete=True,
    )

    await session.send(input=client_content)

    mock_websocket.send.assert_called_once()
    raw_message = mock_websocket.send.call_args.args[0]
    sent_data = json.loads(raw_message)
    assert 'client_content' in sent_data
301
+
302
+
303
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_bytes(mock_websocket, vertexai):
    """A dict holding raw bytes and a mime type is sent as `realtime_input`."""
    fake_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(
        api_client=fake_client,
        websocket=mock_websocket,
    )
    realtime_input = {'data': b'000000', 'mime_type': 'audio/pcm'}

    await session.send(input=realtime_input)

    mock_websocket.send.assert_called_once()
    raw_message = mock_websocket.send.call_args.args[0]
    sent_data = json.loads(raw_message)
    assert 'realtime_input' in sent_data
317
+
318
+
319
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_blob(mock_websocket, vertexai):
    """A ``types.Blob`` input is sent as a `realtime_input` message."""
    fake_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(
        api_client=fake_client,
        websocket=mock_websocket,
    )
    realtime_input = types.Blob(data=b'000000', mime_type='audio/pcm')

    await session.send(input=realtime_input)

    mock_websocket.send.assert_called_once()
    raw_message = mock_websocket.send.call_args.args[0]
    sent_data = json.loads(raw_message)
    assert 'realtime_input' in sent_data
333
+
334
+
335
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_realtime_input(mock_websocket, vertexai):
    """Send a typed `LiveClientRealtimeInput`; expect a `realtime_input` payload."""
    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)
    chunk = types.Blob(data='MDAwMDAw', mime_type='audio/pcm')

    await session.send(input=types.LiveClientRealtimeInput(media_chunks=[chunk]))

    mock_websocket.send.assert_called_once()
    payload = json.loads(mock_websocket.send.call_args[0][0])
    assert 'realtime_input' in payload
350
+
351
+
352
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_tool_response(mock_websocket, vertexai):
    """Send a `LiveClientToolResponse` and expect a `tool_response` payload."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )

    response_fields = {
        'name': 'get_current_weather',
        'response': {'temperature': 14.5, 'unit': 'C'},
    }
    # Only the non-Vertex variant of this test attaches an explicit id.
    if not vertexai:
        response_fields['id'] = 'some-id'
    tool_response = types.LiveClientToolResponse(
        function_responses=[types.FunctionResponse(**response_fields)]
    )

    await session.send(input=tool_response)

    mock_websocket.send.assert_called_once()
    payload = json.loads(mock_websocket.send.call_args[0][0])
    assert 'tool_response' in payload
384
+
385
+
386
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_input_none(mock_websocket, vertexai):
    """Send `None`; expect a `client_content` payload with a truthy turn_complete."""
    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)

    await session.send(input=None)

    mock_websocket.send.assert_called_once()
    payload = json.loads(mock_websocket.send.call_args[0][0])
    assert 'client_content' in payload
    assert payload['client_content']['turn_complete']
399
+
400
+
401
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_send_error(mock_websocket, vertexai):
    """Expect `ValueError` for unrecognised input, both as a list and as a dict."""
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    invalid_inputs = (
        [{'invalid_key': 'invalid_value'}],
        {'invalid_key': 'invalid_value'},
    )
    for invalid_input in invalid_inputs:
        with pytest.raises(ValueError):
            await session.send(input=invalid_input)
414
+
415
+
416
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive(mock_websocket, vertexai):
    """Check that the first item yielded by `receive()` is a `LiveServerMessage`."""
    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)

    received = await _async_iterator_to_list(session.receive())

    assert isinstance(received[0], types.LiveServerMessage)
425
+
426
+
427
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_error(mock_websocket, vertexai):
    """Expect `ValueError` when the websocket hands back a non-JSON frame."""
    mock_websocket.recv = AsyncMock(return_value='invalid json')
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )
    stream = session.receive()
    with pytest.raises(ValueError):
        await stream.__anext__()
438
+
439
+
440
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_text(mock_websocket, vertexai):
    """Receive a text part followed by a turn-complete frame.

    The websocket mock yields two JSON frames; the first must parse into a
    `LiveServerMessage` carrying the text 'test', and the second must report
    `turn_complete` as the boolean `True`.
    """
    mock_websocket.recv = AsyncMock(
        side_effect=[
            '{"serverContent": {"modelTurn": {"parts":[{"text": "test"}]}}}',
            '{"serverContent": {"turnComplete": true}}',
        ]
    )
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )

    messages = await _async_iterator_to_list(session.receive())

    assert isinstance(messages[0], types.LiveServerMessage)
    assert messages[0].server_content.model_turn.parts[0].text == 'test'
    # Identity check instead of `== True`, so a merely truthy value
    # (e.g. 1) does not satisfy the assertion.
    assert messages[1].server_content.turn_complete is True
459
+
460
+
461
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_audio(mock_websocket, vertexai):
    """Receive an inline audio part and check its mime type and decoded bytes.

    After the mocked frames are used up, a further `receive()` is expected to
    raise `RuntimeError`.
    """
    audio_frame = (
        '{"serverContent": {"modelTurn": {"parts":[{"inlineData":'
        ' {"data": "MDAwMDAw", "mimeType": "audio/pcm" }}]}}}'
    )
    turn_complete_frame = '{"serverContent": {"turnComplete": true}}'
    mock_websocket.recv = AsyncMock(
        side_effect=[audio_frame, turn_complete_frame]
    )
    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)

    messages = await _async_iterator_to_list(session.receive())

    assert isinstance(messages[0], types.LiveServerMessage)
    assert (
        messages[0].server_content.model_turn.parts[0].inline_data.mime_type
        == 'audio/pcm'
    )
    assert (
        messages[0].server_content.model_turn.parts[0].inline_data.data
        == b'000000'
    )

    with pytest.raises(RuntimeError):
        await _async_iterator_to_list(session.receive())
492
+
493
+
494
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_tool_call(mock_websocket, vertexai):
    """Receive a tool-call frame and check the function name and arguments.

    After the mocked frames are used up, a further `receive()` is expected to
    raise `RuntimeError`.
    """
    tool_call_frame = (
        '{"toolCall": {"functionCalls": [{"name":'
        ' "get_current_weather", "args": {"location": "San Francisco",'
        ' "unit": "C"}}]}}'
    )
    turn_complete_frame = '{"serverContent": {"turnComplete": true}}'
    mock_websocket.recv = AsyncMock(
        side_effect=[tool_call_frame, turn_complete_frame]
    )
    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)

    messages = await _async_iterator_to_list(session.receive())

    assert isinstance(messages[0], types.LiveServerMessage)
    assert messages[0].tool_call.function_calls[0].name == 'get_current_weather'
    assert (
        messages[0].tool_call.function_calls[0].args['location']
        == 'San Francisco'
    )
    assert messages[0].tool_call.function_calls[0].args['unit'] == 'C'

    with pytest.raises(RuntimeError):
        await _async_iterator_to_list(session.receive())
524
+
525
+
526
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_transcription(mock_websocket, vertexai):
    """Receive input and output transcription frames and check their fields.

    The first frame carries a finished input transcription, the second an
    unfinished output transcription; the third frame marks the turn complete.
    """
    mock_websocket.recv = AsyncMock(
        side_effect=[
            '{"serverContent": {"inputTranscription": {"text": "test_input", "finished": true}}}',
            '{"serverContent": {"outputTranscription": {"text": "test_output", "finished": false}}}',
            '{"serverContent": {"turnComplete": true}}',
        ]
    )
    session = live.AsyncSession(
        api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
    )

    messages = await _async_iterator_to_list(session.receive())

    # Identity checks instead of `== True` / `== False`, so truthy or falsy
    # stand-ins (1, 0, None) cannot satisfy the assertions by accident.
    assert isinstance(messages[0], types.LiveServerMessage)
    assert messages[0].server_content.input_transcription.text == 'test_input'
    assert messages[0].server_content.input_transcription.finished is True

    assert isinstance(messages[1], types.LiveServerMessage)
    assert messages[1].server_content.output_transcription.text == 'test_output'
    assert messages[1].server_content.output_transcription.finished is False
550
+
551
+
552
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_go_away(mock_websocket, vertexai):
    """Receive a `goAway` frame and compare it with the equivalent typed message."""
    mock_websocket.recv = AsyncMock(
        side_effect=[
            '{"goAway": {"timeLeft": "10s"}}',
            '{"serverContent": {"turnComplete": true}}',
        ]
    )
    go_away = types.LiveServerGoAway(time_left='10s')
    expected_message = types.LiveServerMessage(go_away=go_away)
    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)

    received = await _async_iterator_to_list(session.receive())
    first_message = received[0]

    assert isinstance(first_message, types.LiveServerMessage)
    assert first_message == expected_message
575
+
576
+
577
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_resumption_update(mock_websocket, vertexai):
    """Receive a `sessionResumptionUpdate` frame and compare it with the typed form.

    The frame supplies `resumable` and `lastConsumedClientMessageIndex` as JSON
    strings, while the expected message holds them as `True` and `123456789`.
    """
    resumption_frame = """{
        "sessionResumptionUpdate": {
            "newHandle": "test_handle",
            "resumable": "true",
            "lastConsumedClientMessageIndex": "123456789"
        }
    }"""
    mock_websocket.recv = AsyncMock(
        side_effect=[
            resumption_frame,
            '{"serverContent": {"turnComplete": true}}',
        ]
    )

    update = types.LiveServerSessionResumptionUpdate(
        new_handle='test_handle',
        resumable=True,
        last_consumed_client_message_index=123456789,
    )
    expected_message = types.LiveServerMessage(session_resumption_update=update)

    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)
    received = await _async_iterator_to_list(session.receive())
    first_message = received[0]

    assert isinstance(first_message, types.LiveServerMessage)
    assert first_message == expected_message
612
+
613
+
614
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_start_stream(mock_websocket, vertexai):
    """Feed two byte chunks to `start_stream` and type-check every yielded message."""
    session = live.AsyncSession(
        mock_api_client(vertexai=vertexai), mock_websocket
    )

    async def mock_stream():
        # Async generator standing in for an audio source.
        for chunk in (b'data1', b'data2'):
            yield chunk

    responses = session.start_stream(stream=mock_stream(), mime_type='audio/pcm')
    async for response in responses:
        assert isinstance(response, types.LiveServerMessage)
632
+
633
+
634
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_receive_vad_signal(mock_websocket, vertexai):
    """Receive a voice-activity-detection frame and check its signal type."""
    vad_frame = (
        '{"voiceActivityDetectionSignal": {"vadSignalType": "VAD_SIGNAL_TYPE_SOS"}}'
    )
    # The trailing turn-complete frame ends the receive loop.
    closing_frame = '{"serverContent": {"turnComplete": true}}'
    mock_websocket.recv = mock.AsyncMock(side_effect=[vad_frame, closing_frame])

    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client=api_client, websocket=mock_websocket)

    messages = await _async_iterator_to_list(session.receive())

    # The first message must carry the VAD signal.
    assert len(messages) > 0
    first = messages[0]
    assert isinstance(first, types.LiveServerMessage)
    assert first.voice_activity_detection_signal is not None
    assert (
        first.voice_activity_detection_signal.vad_signal_type
        == types.VadSignalType.VAD_SIGNAL_TYPE_SOS
    )

    # The last message must be the turn-complete marker.
    assert messages[-1].server_content.turn_complete is True
663
+
664
+
665
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_async_session_close(mock_websocket, vertexai):
    """Closing the session must close the websocket exactly once."""
    api_client = mock_api_client(vertexai=vertexai)
    session = live.AsyncSession(api_client, mock_websocket)

    await session.close()

    mock_websocket.close.assert_called_once()
673
+
674
+
675
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_no_config(vertexai):
    """Build the setup message with no config and compare it with the expected dict."""
    if vertexai:
        expected_setup = {
            'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
            'generationConfig': {'responseModalities': ["AUDIO"]},
        }
    else:
        expected_setup = {'model': 'models/test_model'}

    with warnings.catch_warnings():
        # Turn warnings into errors so that any warning caused by default
        # values fails the test.
        warnings.simplefilter('error')
        result = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
        )

    assert result == {'setup': expected_setup}
693
+
694
+
695
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_speech_config(vertexai):
    """Check the setup message built from speech and generation options.

    The same options are supplied first as a plain dict and then as a
    `types.LiveConnectConfig`. Each resulting message is compared with the
    expected one after both are parsed into `types.LiveClientMessage`.
    """

    def _parsed(message):
        return types.LiveClientMessage._from_response(
            response=message, kwargs=None
        )

    prebuilt_voice = {'prebuilt_voice_config': {'voice_name': 'en-default'}}
    if vertexai:
        model_name = (
            'projects/test_project/locations/us-central1/'
            'publishers/google/models/test_model'
        )
        speech_config = {
            'voiceConfig': prebuilt_voice,
            'languageCode': 'en-US',
        }
    else:
        model_name = 'models/test_model'
        speech_config = {
            'voice_config': prebuilt_voice,
            'language_code': 'en-US',
        }

    generation_config = {
        'speechConfig': speech_config,
        'enableAffectiveDialog': True,
        'temperature': 0.7,
        'topP': 0.8,
        'topK': 9.0,
        'maxOutputTokens': 10,
        'mediaResolution': 'MEDIA_RESOLUTION_MEDIUM',
        'seed': 13,
    }
    if vertexai:
        generation_config['responseModalities'] = ['AUDIO']

    expected_result = {
        'setup': {
            'model': model_name,
            'generationConfig': generation_config,
            'proactivity': {'proactive_audio': True},
            'systemInstruction': {
                'parts': [
                    {
                        'text': 'test instruction',
                    },
                ],
                'role': 'user',
            },
        }
    }

    # Variant 1: the config is a plain dict.
    dict_config = {
        'speech_config': {
            'voice_config': {
                'prebuilt_voice_config': {'voice_name': 'en-default'}
            },
            'language_code': 'en-US',
        },
        'enable_affective_dialog': True,
        'proactivity': {'proactive_audio': True},
        'temperature': 0.7,
        'top_p': 0.8,
        'top_k': 9,
        'max_output_tokens': 10,
        'seed': 13,
        'system_instruction': 'test instruction',
        'media_resolution': 'MEDIA_RESOLUTION_MEDIUM',
    }
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai), model='test_model', config=dict_config
    )
    assert _parsed(result) == _parsed(expected_result)

    # Variant 2: the config is a LiveConnectConfig.
    typed_config = types.LiveConnectConfig(
        speech_config=types.SpeechConfig(
            voice_config=types.VoiceConfig(
                prebuilt_voice_config=types.PrebuiltVoiceConfig(
                    voice_name='en-default'
                )
            ),
            language_code='en-US',
        ),
        enable_affective_dialog=True,
        proactivity=types.ProactivityConfig(proactive_audio=True),
        temperature=0.7,
        top_p=0.8,
        top_k=9,
        max_output_tokens=10,
        media_resolution=types.MediaResolution.MEDIA_RESOLUTION_MEDIUM,
        seed=13,
        system_instruction='test instruction',
    )
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=typed_config,
    )
    assert _parsed(result) == _parsed(expected_result)
800
+
801
+
802
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_error_if_multispeaker_voice_config(vertexai):
    """Expect a `ValueError` naming `multi_speaker_voice_config` for such a config."""
    speaker_voice_configs = [
        {
            'speaker': speaker,
            'voice_config': {'prebuilt_voice_config': {'voice_name': voice}},
        }
        for speaker, voice in (('Alice', 'leda'), ('Bob', 'kore'))
    ]
    # The config is passed as a plain dict.
    config_dict = {
        'speech_config': {
            'multi_speaker_voice_config': {
                'speaker_voice_configs': speaker_voice_configs,
            },
        },
        'temperature': 0.7,
        'top_p': 0.8,
        'top_k': 9,
        'max_output_tokens': 10,
        'seed': 13,
        'system_instruction': 'test instruction',
        'media_resolution': 'MEDIA_RESOLUTION_MEDIUM',
    }
    with pytest.raises(ValueError, match='.*multi_speaker_voice_config.*'):
        await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=config_dict,
        )
840
+
841
+
842
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_replicated_voice_config(vertexai):
    """Build a setup message with a replicated voice and check its serialised form.

    The message is built for both backends, but its content is only inspected
    for Vertex AI. The replicated-voice entry is accepted under either the
    camelCase or the snake_case key.
    """
    # The config is passed as a plain dict.
    replicated_voice = {
        'mime_type': 'audio/pcm',
        'voice_sample_audio': bytes([0, 0, 0]),
    }
    config_dict = {
        'speech_config': {
            'voice_config': {'replicated_voice_config': replicated_voice},
        },
    }
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=config_dict,
    )
    if not vertexai:
        return

    voice_config = result['setup']['generationConfig']['speechConfig'][
        'voiceConfig'
    ]
    try:
        serialised = voice_config['replicatedVoiceConfig']
    except KeyError:
        serialised = voice_config['replicated_voice_config']
    assert serialised == {
        'mime_type': 'audio/pcm',
        'voice_sample_audio': 'AAAA',
    }
876
+
877
+
878
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_explicit_vad(vertexai):
    """Check `explicit_vad_signal` in the setup message.

    The call is wrapped in `pytest_helper.exception_if_mldev(..., ValueError)`;
    the resulting message is only inspected for Vertex AI.
    """
    # The config is passed as a plain dict.
    config_dict = {'explicit_vad_signal': True}
    with pytest_helper.exception_if_mldev(
        mock_api_client(vertexai=vertexai), ValueError
    ):
        result = await get_connect_message(
            mock_api_client(vertexai=vertexai), model='test_model', config=config_dict
        )
    if not vertexai:
        return
    # Identity check instead of `== True`, so a merely truthy value does
    # not satisfy the assertion.
    assert result['setup']['explicitVadSignal'] is True
892
+
893
+
894
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_explicit_vad_config(vertexai):
    """Check `explicit_vad_signal` in the setup message.

    The call is wrapped in `pytest_helper.exception_if_mldev(..., ValueError)`;
    the resulting message is only inspected for Vertex AI.
    """
    api_client = mock_api_client(vertexai=vertexai)

    # The config is passed as a plain dict.
    config_dict = {'explicit_vad_signal': True}
    with pytest_helper.exception_if_mldev(api_client, ValueError):
        result = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=config_dict,
        )
    if not vertexai:
        return
    # Identity check instead of `== True`, so a merely truthy value does
    # not satisfy the assertion.
    assert result['setup']['explicitVadSignal'] is True
910
+
911
+
912
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_system_instruction_as_content_type(
    vertexai,
):
    """Pass the system instruction as a content-shaped dict and check the setup."""
    config = types.LiveConnectConfig(
        system_instruction={
            'parts': [{'text': 'test instruction'}],
            'role': 'user',
        },
    )

    expected_setup = {
        'systemInstruction': {
            'parts': [{'text': 'test instruction'}],
            'role': 'user',
        },
    }
    if vertexai:
        expected_setup['model'] = (
            'projects/test_project/locations/us-central1/publishers/google/models/test_model'
        )
        expected_setup['generationConfig'] = {'responseModalities': ['AUDIO']}
    else:
        expected_setup['model'] = 'models/test_model'

    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=config,
    )
    assert result == {'setup': expected_setup}
949
+
950
+
951
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_google_search(vertexai):
    """Check the setup message for a Google Search tool, as dict and typed config."""
    config_dict = {
        'response_modalities': ['TEXT'],
        'system_instruction': 'test instruction',
        'generation_config': {'temperature': 0.7},
        'tools': [{'google_search': {}}],
    }
    typed_config = types.LiveConnectConfig(**config_dict)

    if vertexai:
        model_name = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
    else:
        model_name = 'models/test_model'
    expected_result = {
        'setup': {
            'model': model_name,
            'generationConfig': {
                'temperature': 0.7,
                'responseModalities': ['TEXT'],
            },
            'systemInstruction': {
                'parts': [{'text': 'test instruction'}],
                'role': 'user',
            },
            'tools': [{'googleSearch': {}}],
        }
    }

    # Plain dict first, then the LiveConnectConfig built from it.
    for candidate in (config_dict, typed_config):
        result = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=candidate,
        )
        assert result == expected_result
994
+
995
+
996
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_with_no_mcp(vertexai):
    """Check the setup message when the MCP names in `live` are patched to None.

    `live.McpClientSession` and `live.McpTool` are replaced with `None` for
    the duration of each call. The config is exercised both as a plain dict
    and as a `types.LiveConnectConfig`.
    """
    config_dict = {
        'response_modalities': ['TEXT'],
        'system_instruction': 'test instruction',
        'generation_config': {'temperature': 0.7},
        'tools': [{'google_search': {}}],
    }

    config = types.LiveConnectConfig(**config_dict)
    expected_result = {
        'setup': {
            'generationConfig': {
                'temperature': 0.7,
                'responseModalities': ['TEXT'],
            },
            'systemInstruction': {
                'parts': [{'text': 'test instruction'}],
                'role': 'user',
            },
            'tools': [{'googleSearch': {}}],
        }
    }
    if vertexai:
        expected_result['setup']['model'] = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
    else:
        expected_result['setup']['model'] = 'models/test_model'

    @patch.object(live, "McpClientSession", new=None)
    @patch.object(live, "McpTool", new=None)
    async def get_connect_message_no_mcp(config):
        return await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model', config=config
        )

    # Config is a dict.
    result = await get_connect_message_no_mcp(config_dict)
    assert result == expected_result

    # Config is a LiveConnectConfig. This call previously repeated the dict
    # case, which left the typed config built above untested.
    result = await get_connect_message_no_mcp(config)
    assert result == expected_result
1038
+
1039
+
1040
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_context_window_compression(vertexai):
    """Check the setup message for a config with context-window compression."""
    compression = types.ContextWindowCompressionConfig(
        trigger_tokens=1000,
        sliding_window=types.SlidingWindow(target_tokens=10),
    )
    instruction = types.Content(
        parts=[types.Part(text='test instruction')], role='user'
    )
    config = types.LiveConnectConfig(
        generation_config=types.GenerationConfig(temperature=0.7),
        response_modalities=['TEXT'],
        system_instruction=instruction,
        context_window_compression=compression,
    )

    if vertexai:
        model_name = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
    else:
        model_name = 'models/test_model'
    expected_result = {
        'setup': {
            'model': model_name,
            'generationConfig': {
                'temperature': 0.7,
                'responseModalities': ['TEXT'],
            },
            'systemInstruction': {
                'parts': [{'text': 'test instruction'}],
                'role': 'user',
            },
            'contextWindowCompression': {
                'trigger_tokens': 1000,
                'sliding_window': {'target_tokens': 10},
            },
        }
    }

    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=config,
    )
    assert result == expected_result
1082
+
1083
+
1084
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_function_declaration(
    vertexai
):
    """Check the description of the first function declaration in the setup.

    The message is built twice from the same `LiveConnectConfig`; only the
    `description` of the first declaration is compared each time.
    """

    def _first_description(message):
        return message['setup']['tools'][0]['functionDeclarations'][0][
            'description'
        ]

    config = types.LiveConnectConfig(
        **{
            'generation_config': {'temperature': 0.7},
            'tools': [{'function_declarations': function_declarations}],
        }
    )
    expected_declaration = {
        'parameters': {
            'type': 'OBJECT',
            'properties': {
                'location': {
                    'type': 'STRING',
                    'description': 'The location to get the weather for',
                },
                'unit': {'type': 'STRING', 'enum': ['C', 'F']},
            },
        },
        'name': 'get_current_weather',
        'description': 'Get the current weather in a city',
    }
    expected_result = {
        'setup': {
            'model': 'test_model',
            'tools': [{'functionDeclarations': [expected_declaration]}],
        }
    }

    for _ in range(2):
        result = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=config,
        )
        assert _first_description(result) == _first_description(expected_result)
1142
+
1143
+
1144
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_function_directly(
    vertexai
):
    """Pass a Python callable as a tool and check the declaration's description.

    The message is built twice from the same `LiveConnectConfig`; only the
    `description` of the first declaration is compared each time.
    """

    def _first_description(message):
        return message['setup']['tools'][0]['functionDeclarations'][0][
            'description'
        ]

    config = types.LiveConnectConfig(
        **{
            'generation_config': {'temperature': 0.7},
            'tools': [get_current_weather],
        }
    )
    expected_declaration = {
        'parameters': {
            'type': 'OBJECT',
            'properties': {
                'location': {
                    'type': 'STRING',
                    'description': 'The location to get the weather for',
                },
                'unit': {'type': 'STRING', 'enum': ['C', 'F']},
            },
        },
        'name': 'get_current_weather',
        'description': 'Get the current weather in a city.',
    }
    expected_result = {
        'setup': {
            'model': 'test_model',
            'tools': [{'functionDeclarations': [expected_declaration]}],
        }
    }

    for _ in range(2):
        result = await get_connect_message(
            mock_api_client(vertexai=vertexai),
            model='test_model',
            config=config,
        )
        assert _first_description(result) == _first_description(expected_result)
1202
+
1203
+
1204
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_tools_function_behavior(vertexai):
    """Check how a NON_BLOCKING function declaration appears in the setup message.

    The call is wrapped in `pytest_helper.exception_if_vertex(..., ValueError)`;
    the serialised `behavior` field is only inspected for the non-Vertex case.
    """
    api_client = mock_api_client(vertexai=vertexai)

    non_blocking = types.FunctionDeclaration.from_callable(
        client=api_client, callable=get_current_weather
    )
    non_blocking.behavior = types.Behavior.NON_BLOCKING
    config = types.LiveConnectConfig(
        generation_config={'temperature': 0.7},
        tools=[{'function_declarations': [non_blocking]}],
    )

    with pytest_helper.exception_if_vertex(api_client, ValueError):
        result = await get_connect_message(
            mock_api_client(vertexai=vertexai), model='test_model', config=config
        )
    if vertexai:
        return

    first_declaration = result['setup']['tools'][0]['functionDeclarations'][0]
    assert first_declaration['behavior'] == 'NON_BLOCKING'
1230
+
1231
+
1232
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_mcp_tools(
    vertexai,
):
    """Check the setup message when an MCP `Tool` is passed in `tools`.

    Returns immediately (test passes) when `mcp_types` is None.
    """
    if mcp_types is None:
        return

    expected_tools = [{
        'functionDeclarations': [{
            'parameters': {
                'type': 'OBJECT',
                'properties': {
                    'location': {
                        'type': 'STRING',
                    },
                },
            },
            'name': 'get_weather',
            'description': 'Get the weather in a city.',
        }],
    }]
    if vertexai:
        expected_result = {
            'setup': {
                'generationConfig': {
                    'responseModalities': [
                        'AUDIO',
                    ],
                },
                'model': (
                    'projects/test_project/locations/us-central1/publishers/google/models/test_model'
                ),
                'tools': expected_tools,
            }
        }
    else:
        expected_result = {
            'setup': {
                'model': 'models/test_model',
                'tools': expected_tools,
            }
        }

    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config={
            'tools': [
                mcp_types.Tool(
                    name='get_weather',
                    description='Get the weather in a city.',
                    inputSchema={
                        'type': 'object',
                        'properties': {'location': {'type': 'string'}},
                    },
                )
            ],
        },
    )

    # The expected value is chosen before comparing. The former
    # `assert result == a if vertexai else b` parsed as
    # `(result == a) if vertexai else b`, so the non-Vertex case asserted a
    # non-empty dict and could never fail.
    assert result == expected_result
1307
+
1308
+
1309
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_mcp_session(
    vertexai,
):
    """Check the setup message when an MCP client session is passed in `tools`.

    Returns immediately (test passes) when `mcp_types` is None.
    """
    if mcp_types is None:
        return

    class MockMcpClientSession(McpClientSession):
        """Session stub whose `list_tools` returns a single `get_weather` tool."""

        def __init__(self):
            # Does not call the parent initialiser; only sets the two
            # stream attributes.
            self._read_stream = None
            self._write_stream = None

        async def list_tools(self):
            return mcp_types.ListToolsResult(
                tools=[
                    mcp_types.Tool(
                        name='get_weather',
                        description='Get the weather in a city.',
                        inputSchema={
                            'type': 'object',
                            'properties': {'location': {'type': 'string'}},
                        },
                    ),
                ]
            )

    expected_tools = [{
        'functionDeclarations': [{
            'parameters': {
                'type': 'OBJECT',
                'properties': {
                    'location': {
                        'type': 'STRING',
                    },
                },
            },
            'name': 'get_weather',
            'description': 'Get the weather in a city.',
        }],
    }]
    if vertexai:
        expected_result = {
            'setup': {
                'generationConfig': {
                    'responseModalities': [
                        'AUDIO',
                    ],
                },
                'model': (
                    'projects/test_project/locations/us-central1/publishers/google/models/test_model'
                ),
                'tools': expected_tools,
            }
        }
    else:
        expected_result = {
            'setup': {
                'model': 'models/test_model',
                'tools': expected_tools,
            }
        }

    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config={
            'tools': [MockMcpClientSession()],
        },
    )

    # The expected value is chosen before comparing. The former
    # `assert result == a if vertexai else b` parsed as
    # `(result == a) if vertexai else b`, so the non-Vertex case asserted a
    # non-empty dict and could never fail.
    assert result == expected_result
1395
+
1396
+
1397
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_config_tools_code_execution(
    vertexai
):
  """A `code_execution` tool is emitted as `codeExecution` in the setup.

  Only the first tool entry of the setup message is compared.
  """
  live_config = types.LiveConnectConfig(tools=[{'code_execution': {}}])
  expected_result = {
      'setup': {
          'model': 'test_model',
          'tools': [{
              'codeExecution': {},
          }],
      }
  }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=live_config,
  )

  actual_tool = result['setup']['tools'][0]
  expected_tool = expected_result['setup']['tools'][0]
  assert actual_tool == expected_tool
1420
+
1421
+
1422
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_realtime_input_config(vertexai):
  """`realtime_input_config` is carried into `setup.realtimeInputConfig`.

  The value found in the setup message is compared against the same
  dictionary that was used to build the config.
  """
  realtime_input_config = {
      'automatic_activity_detection': {
          'disabled': True,
          'start_of_speech_sensitivity': 'START_SENSITIVITY_HIGH',
          'end_of_speech_sensitivity': 'END_SENSITIVITY_HIGH',
          'prefix_padding_ms': 20,
          'silence_duration_ms': 100,
      },
      'activity_handling': 'NO_INTERRUPTION',
      'turn_coverage': 'TURN_INCLUDES_ALL_INPUT',
  }
  live_config = types.LiveConnectConfig(
      realtime_input_config=realtime_input_config
  )
  expected_result = {
      'setup': {
          'model': 'test_model',
          'realtimeInputConfig': realtime_input_config,
      }
  }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=live_config,
  )

  actual = result['setup']['realtimeInputConfig']
  assert actual == expected_result['setup']['realtimeInputConfig']
1456
+
1457
+
1458
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_input_transcription(vertexai):
  """An empty `input_audio_transcription` maps to `inputAudioTranscription`."""
  live_config = types.LiveConnectConfig(input_audio_transcription={})
  expected_result = {
      'setup': {
          'model': 'test_model',
          'inputAudioTranscription': {},
      }
  }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=live_config,
  )

  actual = result['setup']['inputAudioTranscription']
  assert actual == expected_result['setup']['inputAudioTranscription']
1480
+
1481
+
1482
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_output_transcription(vertexai):
  """An empty `output_audio_transcription` maps to `outputAudioTranscription`."""
  live_config = types.LiveConnectConfig(output_audio_transcription={})
  expected_result = {
      'setup': {
          'model': 'test_model',
          'outputAudioTranscription': {},
      }
  }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=live_config,
  )

  actual = result['setup']['outputAudioTranscription']
  assert actual == expected_result['setup']['outputAudioTranscription']
1505
+
1506
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_media_resolution(vertexai):
  """`media_resolution` lands under `setup.generationConfig.mediaResolution`."""
  live_config = types.LiveConnectConfig(
      media_resolution='MEDIA_RESOLUTION_LOW'
  )
  expected_result = {
      'setup': {
          'model': 'test_model',
          'generationConfig': {'mediaResolution': 'MEDIA_RESOLUTION_LOW'},
      }
  }

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=live_config,
  )

  actual = result['setup']['generationConfig']['mediaResolution']
  expected = expected_result['setup']['generationConfig']['mediaResolution']
  assert actual == expected
1529
+
1530
+
1531
@pytest.mark.parametrize('vertexai', [True])
@pytest.mark.asyncio
async def test_bidi_setup_publishers(
    vertexai
):
  """A `publishers/...` model name is expanded to a full resource path.

  Runs for the Vertex AI client only; the whole setup message is compared.
  """
  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='publishers/google/models/test_model',
  )

  expected_setup = {
      'generationConfig': {
          'responseModalities': [
              'AUDIO',
          ],
      },
      'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
  }
  assert result == {'setup': expected_setup}
1551
+
1552
+
1553
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_generation_config_warning(
    vertexai
):
  """Passing `generation_config` warns but the value is still forwarded."""
  deprecated_config = {'generation_config': {'temperature': 0.7}}
  expected_message = (
      'Setting `LiveConnectConfig.generation_config` is deprecated'
  )

  with pytest.warns(DeprecationWarning, match=expected_message):
    setup_message = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='models/test_model',
        config=deprecated_config,
    )

  assert setup_message['setup']['generationConfig']['temperature'] == 0.7
1568
+
1569
+
1570
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_session_resumption(vertexai):
  """A session resumption handle is forwarded as `setup.sessionResumption`.

  The complete setup message is compared; the expected model name (and, for
  Vertex AI, the generation config) depends on `vertexai`.
  """
  live_config = types.LiveConnectConfig(
      session_resumption={'handle': 'test_handle'}
  )

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=live_config,
  )

  expected_setup = {
      'sessionResumption': {
          'handle': 'test_handle',
      },
  }
  if vertexai:
    expected_setup['generationConfig'] = {
        'responseModalities': [
            'AUDIO',
        ],
    }
    expected_setup['model'] = 'projects/test_project/locations/us-central1/publishers/google/models/test_model'
  else:
    expected_setup['model'] = 'models/test_model'

  assert result == {'setup': expected_setup}
1600
+
1601
+
1602
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_transparent_session_resumption(vertexai):
  """Transparent session resumption is only checked for Vertex AI.

  The call runs inside `pytest_helper.exception_if_mldev(..., ValueError)`.
  For the non-Vertex client nothing further is asserted once that context
  exits.
  """
  api_client = mock_api_client(vertexai=vertexai)
  live_config = types.LiveConnectConfig(
      session_resumption={'handle': 'test_handle', 'transparent': True}
  )

  with pytest_helper.exception_if_mldev(api_client, ValueError):
    result = await get_connect_message(
        mock_api_client(vertexai=vertexai),
        model='test_model',
        config=live_config,
    )

  if not vertexai:
    # Nothing more to verify outside of Vertex AI.
    return

  expected_result = {
      'setup': {
          'sessionResumption': {
              'handle': 'test_handle',
              'transparent': True,
          },
          'generationConfig': {
              'responseModalities': [
                  'AUDIO',
              ],
          },
          'model': 'projects/test_project/locations/us-central1/publishers/google/models/test_model',
      }
  }
  assert result == expected_result
1637
+
1638
+
1639
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_str(mock_websocket, vertexai):
  """A plain string is parsed into a single-turn user `client_content`."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )

  parsed = live_session._parse_client_message('test')

  expected_content = {
      'turn_complete': False,
      'turns': [{'role': 'user', 'parts': [{'text': 'test'}]}],
  }
  assert 'client_content' in parsed
  assert parsed == {'client_content': expected_content}
  # _parse_client_message returns a TypedDict, so we should be able to
  # construct a LiveClientMessage from it
  assert types.LiveClientMessage(**parsed)
1655
+
1656
+
1657
@pytest.mark.parametrize('vertexai', [True, False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_thinking_config(vertexai):
  """`thinking_config` is placed under `setup.generationConfig.thinkingConfig`.

  The config is passed as a plain dictionary and the complete setup message
  is compared.
  """
  live_config = {
      'thinking_config': {
          'include_thoughts': True,
          'thinking_budget': 1024,
      }
  }

  expected_setup = {
      'generationConfig': {
          'thinkingConfig': {
              'include_thoughts': True,
              'thinking_budget': 1024,
          }
      },
  }
  if vertexai:
    expected_setup['generationConfig']['responseModalities'] = ['AUDIO']
    expected_setup['model'] = (
        'projects/test_project/locations/us-central1/publishers/google/models/test_model'
    )
  else:
    expected_setup['model'] = 'models/test_model'

  result = await get_connect_message(
      mock_api_client(vertexai=vertexai),
      model='test_model',
      config=live_config,
  )

  assert result == {'setup': expected_setup}
1694
+
1695
+
1696
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_blob(mock_websocket, vertexai):
  """A `types.Blob` is parsed into a `realtime_input` media chunk.

  Three zero bytes are expected to appear as the string 'AAAA'.
  """
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  blob = types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')

  parsed = live_session._parse_client_message(blob)

  expected_chunk = {'mime_type': 'text/plain', 'data': 'AAAA'}
  assert 'realtime_input' in parsed
  assert parsed == {
      'realtime_input': {
          'media_chunks': [expected_chunk],
      }
  }
1710
+
1711
+
1712
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_blob_dict(
    mock_websocket, vertexai
):
  """A dumped `types.Blob` dictionary is parsed like the blob itself."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  dumped_blob = types.Blob(
      data=bytes([0, 0, 0]), mime_type='text/plain'
  ).model_dump()

  parsed = live_session._parse_client_message(dumped_blob)

  expected_chunk = {'mime_type': 'text/plain', 'data': 'AAAA'}
  assert 'realtime_input' in parsed
  assert parsed == {
      'realtime_input': {
          'media_chunks': [expected_chunk],
      }
  }
1729
+
1730
+
1731
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content(
    mock_websocket, vertexai
):
  """A `LiveClientContent` object is parsed into a `client_content` dict."""
  live_session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai),
      websocket=mock_websocket,
  )
  user_turn = types.Content(parts=[types.Part(text='test')], role='user')
  client_content = types.LiveClientContent(
      turn_complete=False,
      turns=[user_turn],
  )

  parsed = live_session._parse_client_message(client_content)

  assert 'client_content' in parsed
  assert parsed == {
      'client_content': {
          'turn_complete': False,
          'turns': [{'role': 'user', 'parts': [{'text': 'test'}]}],
      }
  }
1751
+
1752
+
1753
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content_blob(
    mock_websocket, vertexai
):
  """Inline blob data inside client content is parsed into a string.

  Builds a `LiveClientContent` with one user turn holding an inline blob of
  three zero bytes, and checks that the parsed message carries the data as
  the string 'AAAA'.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  client_content = types.LiveClientContent(
      turn_complete=False,
      turns=[
          types.Content(
              parts=[
                  types.Part(
                      inline_data=types.Blob(
                          data=bytes([0, 0, 0]), mime_type='text/plain'
                      )
                  )
              ],
              role='user',
          )
      ],
  )
  result = session._parse_client_message(client_content)
  assert 'client_content' in result

  # The raw bytes must have been turned into a string by the parser.
  inline_data = result['client_content']['turns'][0]['parts'][0]['inline_data']
  assert isinstance(inline_data['data'], str)

  assert result == {
      'client_content': {
          'turn_complete': False,
          'turns': [{
              'role': 'user',
              'parts': [
                  {'inline_data': {'mime_type': 'text/plain', 'data': 'AAAA'}}
              ],
          }],
      }
  }
1796
+
1797
+
1798
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_client_content_dict(
    mock_websocket, vertexai
):
  """A JSON-dumped `LiveClientContent` dictionary is parsed correctly.

  The content (one user turn with an inline blob of three zero bytes) is
  dumped with `mode='json', exclude_none=True` before parsing. The parsed
  message must carry the blob data as the string 'AAAA'.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  client_content = types.LiveClientContent(
      turn_complete=False,
      turns=[
          types.Content(
              parts=[
                  types.Part(
                      inline_data=types.Blob(
                          data=bytes([0, 0, 0]), mime_type='text/plain'
                      )
                  )
              ],
              role='user',
          )
      ],
  )
  result = session._parse_client_message(
      client_content.model_dump(mode='json', exclude_none=True)
  )
  assert 'client_content' in result

  # The blob data must be a string in the parsed message.
  inline_data = result['client_content']['turns'][0]['parts'][0]['inline_data']
  assert isinstance(inline_data['data'], str)

  assert result == {
      'client_content': {
          'turn_complete': False,
          'turns': [{
              'role': 'user',
              'parts': [
                  {'inline_data': {'mime_type': 'text/plain', 'data': 'AAAA'}}
              ],
          }],
      }
  }
1843
+
1844
+
1845
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_input(
    mock_websocket, vertexai
):
  """A `LiveClientRealtimeInput` object is parsed into `realtime_input`.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named `realtime_input` rather than `input` to avoid shadowing the builtin.
  realtime_input = types.LiveClientRealtimeInput(
      media_chunks=[types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')]
  )
  result = session._parse_client_message(realtime_input)
  assert 'realtime_input' in result
  assert result == {
      'realtime_input': {
          'media_chunks': [{'mime_type': 'text/plain', 'data': 'AAAA'}],
      }
  }
1862
+
1863
+
1864
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_input_dict(
    mock_websocket, vertexai
):
  """A JSON-dumped `LiveClientRealtimeInput` is parsed into `realtime_input`.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named `realtime_input` rather than `input` to avoid shadowing the builtin.
  realtime_input = types.LiveClientRealtimeInput(
      media_chunks=[types.Blob(data=bytes([0, 0, 0]), mime_type='text/plain')]
  )
  result = session._parse_client_message(
      realtime_input.model_dump(mode='json', exclude_none=True)
  )
  assert 'realtime_input' in result
  assert result == {
      'realtime_input': {
          'media_chunks': [{'mime_type': 'text/plain', 'data': 'AAAA'}],
      }
  }
1883
+
1884
+
1885
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_tool_response(
    mock_websocket, vertexai
):
  """A `LiveClientToolResponse` object is parsed into `tool_response`.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named `tool_response` rather than `input` to avoid shadowing the builtin.
  tool_response = types.LiveClientToolResponse(
      function_responses=[
          types.FunctionResponse(
              id='test_id',
              name='test_name',
              response={'result': 'test_response'},
          )
      ]
  )
  result = session._parse_client_message(tool_response)
  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                  },
              },
          ],
      }
  }
1916
+
1917
+
1918
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_function_response(
    mock_websocket, vertexai
):
  """A bare `FunctionResponse` is wrapped into a `tool_response` message.

  The `response` payload mixes snake_case and camelCase keys; the expected
  output keeps both exactly as given.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named `function_response` rather than `input` to avoid shadowing the
  # builtin.
  function_response = types.FunctionResponse(
      id='test_id',
      name='test_name',
      response={
          'result': 'test_response',
          'user_name': 'test_user_name',
          'userEmail': 'test_user_email',
      },
  )
  result = session._parse_client_message(function_response)
  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                      'user_name': 'test_user_name',
                      'userEmail': 'test_user_email',
                  },
              },
          ],
      }
  }
1951
+
1952
+
1953
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_tool_response_dict_with_only_response(
    mock_websocket, vertexai
):
  """A plain function-response dictionary is wrapped into `tool_response`.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named `function_response` rather than `input` to avoid shadowing the
  # builtin.
  function_response = {
      'id': 'test_id',
      'name': 'test_name',
      'response': {
          'result': 'test_response',
      },
  }
  result = session._parse_client_message(function_response)
  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                  },
              },
          ],
      }
  }
1982
+
1983
+
1984
@pytest.mark.parametrize('vertexai', [True, False])
def test_parse_client_message_realtime_tool_response(
    mock_websocket, vertexai
):
  """A JSON-dumped `LiveClientToolResponse` is parsed into `tool_response`.

  Args:
    mock_websocket: Websocket fixture handed to the session.
    vertexai: Whether the mocked API client targets Vertex AI.
  """
  session = live.AsyncSession(
      api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket
  )
  # Named `tool_response` rather than `input` to avoid shadowing the builtin.
  tool_response = types.LiveClientToolResponse(
      function_responses=[
          types.FunctionResponse(
              id='test_id',
              name='test_name',
              response={'result': 'test_response'},
          )
      ]
  )

  result = session._parse_client_message(
      tool_response.model_dump(mode='json', exclude_none=True)
  )
  assert 'tool_response' in result
  assert result == {
      'tool_response': {
          'function_responses': [
              {
                  'id': 'test_id',
                  'name': 'test_name',
                  'response': {
                      'result': 'test_response',
                  },
              },
          ],
      }
  }
2018
+
2019
+
2020
@pytest.mark.asyncio
async def test_connect_with_provided_credentials(mock_websocket):
  """Explicit credentials end up as a Bearer token in the connect headers."""
  # Custom OAuth2 credentials handed to the mocked Vertex AI client.
  provided_credentials = Credentials(token='provided_fake_token')
  client = mock_api_client(vertexai=True, credentials=provided_credentials)
  captured = {}

  @contextlib.asynccontextmanager
  async def fake_ws_connect(uri, additional_headers=None, **kwargs):
    # Record the headers the live module tries to connect with.
    captured['headers'] = additional_headers
    yield mock_websocket

  with patch.object(live, 'ws_connect', new=fake_ws_connect):
    async with live.AsyncLive(client).connect(model='test-model'):
      pass

    assert 'Authorization' in captured['headers']
    assert captured['headers']['Authorization'] == 'Bearer provided_fake_token'
2043
+
2044
+
2045
@pytest.mark.asyncio
async def test_connect_with_default_credentials(mock_websocket):
  """Without explicit credentials the token from `google.auth.default` is used."""
  client = mock_api_client(vertexai=True, credentials=None)
  # Stand-in for `google.auth.default`, returning mock credentials and no
  # project.
  default_creds = Mock(token='default_test_token')
  fake_auth_default = Mock(return_value=(default_creds, None))
  captured = {}

  @contextlib.asynccontextmanager
  async def fake_ws_connect(uri, additional_headers=None, **kwargs):
    # Record the headers the live module tries to connect with.
    captured['headers'] = additional_headers
    yield mock_websocket

  with (
      patch('google.auth.default', new=fake_auth_default),
      patch.object(live, 'ws_connect', new=fake_ws_connect),
  ):
    async with live.AsyncLive(client).connect(model='test-model'):
      pass

    assert 'Authorization' in captured['headers']
    assert captured['headers']['Authorization'] == 'Bearer default_test_token'
2071
+
2072
+
2073
@pytest.mark.asyncio
async def test_connect_with_custom_base_url(mock_websocket):
  """A custom base URL and Authorization header are used as given.

  The client is built with `http_options` carrying both values; the URI and
  headers passed to the websocket connect call are then checked.
  """
  custom_http_options = {
      'base_url': 'https://custom-base-url.com',
      'headers': {'Authorization': 'Bearer custom_test_token'},
  }
  client = gl_client.BaseApiClient(
      vertexai=True,
      http_options=custom_http_options,
  )
  # No ADC credentials.
  captured = {}

  @contextlib.asynccontextmanager
  async def fake_ws_connect(uri, additional_headers=None, **kwargs):
    # Record what the live module tries to connect with.
    captured['uri'] = uri
    captured['headers'] = additional_headers
    yield mock_websocket

  with patch.object(live, 'ws_connect', new=fake_ws_connect):
    async with live.AsyncLive(client).connect(model='test-model'):
      pass

    assert 'Authorization' in captured['headers']
    assert captured['headers']['Authorization'] == 'Bearer custom_test_token'
    assert captured['uri'] == 'https://custom-base-url.com'
2103
+
2104
+
2105
@pytest.mark.parametrize('vertexai', [False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_auth_tokens(mock_websocket, vertexai):
  """An `auth_tokens/...` API key is sent as a `Token` Authorization header.

  Also checks that the connection URI contains
  `BidiGenerateContentConstrained`.

  Args:
    mock_websocket: Websocket fixture (requested but not used directly; a
      dedicated `AsyncMock` websocket is yielded instead).
    vertexai: Whether the mocked API client targets Vertex AI (always False).
  """
  api_client_mock = mock_api_client(vertexai=vertexai)
  api_client_mock.api_key = 'auth_tokens/TEST_AUTH_TOKEN'
  # The returned message is not inspected; the call is kept so that an error
  # while building the setup message still fails the test.
  await get_connect_message(api_client_mock, model='test_model')

  mock_ws = AsyncMock()
  mock_ws.send = AsyncMock()
  mock_ws.recv = AsyncMock(
      return_value=(
          b'{\n  "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )
  capture = {}

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None, **kwargs):
    # Record what the live module tries to connect with.
    capture['uri'] = uri
    capture['headers'] = additional_headers
    yield mock_ws

  with patch.object(live, 'ws_connect', new=mock_connect):
    live_module = live.AsyncLive(api_client_mock)
    async with live_module.connect(
        model='test_model',
    ):
      pass

    assert (
        'Authorization' in capture['headers']
    ), 'Authorization key is missing from headers'
    assert (
        capture['headers']['Authorization'] == 'Token auth_tokens/TEST_AUTH_TOKEN'
    )
    assert 'BidiGenerateContentConstrained' in capture['uri']
2141
+
2142
+
2143
@pytest.mark.parametrize('vertexai', [False])
@pytest.mark.asyncio
async def test_bidi_setup_to_api_with_api_key(mock_websocket, vertexai):
  """An `x-goog-api-key` header from the HTTP options reaches the connect call.

  Also checks that the connection URI contains `BidiGenerateContent`.

  Args:
    mock_websocket: Websocket fixture (requested but not used directly; a
      dedicated `AsyncMock` websocket is yielded instead).
    vertexai: Whether the mocked API client targets Vertex AI (always False).
  """
  api_client_mock = mock_api_client(vertexai=vertexai)
  api_client_mock._http_options = types.HttpOptions.model_validate(
      {'headers': {'x-goog-api-key': 'TEST_API_KEY'}}
  )
  # The returned message is not inspected; the call is kept so that an error
  # while building the setup message still fails the test.
  await get_connect_message(api_client_mock, model='test_model')

  mock_ws = AsyncMock()
  mock_ws.send = AsyncMock()
  mock_ws.recv = AsyncMock(
      return_value=(
          b'{\n  "setupComplete": {"sessionId": "test_session_id"}\n}\n'
      )
  )
  capture = {}

  @contextlib.asynccontextmanager
  async def mock_connect(uri, additional_headers=None, **kwargs):
    # Record what the live module tries to connect with.
    capture['uri'] = uri
    capture['headers'] = additional_headers
    yield mock_ws

  with patch.object(live, 'ws_connect', new=mock_connect):
    live_module = live.AsyncLive(api_client_mock)
    async with live_module.connect(
        model='test_model',
    ):
      pass

    assert 'x-goog-api-key' in capture['headers'], "x-goog-api-key is missing from headers"
    assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY'
    assert 'BidiGenerateContent' in capture['uri']
2177
+