google-genai 1.53.0__py3-none-any.whl → 1.55.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (324)
  1. google/genai/__init__.py +1 -0
  2. google/genai/_api_client.py +6 -6
  3. google/genai/_interactions/__init__.py +117 -0
  4. google/genai/_interactions/_base_client.py +2019 -0
  5. google/genai/_interactions/_client.py +511 -0
  6. google/genai/_interactions/_compat.py +234 -0
  7. google/genai/_interactions/_constants.py +29 -0
  8. google/genai/_interactions/_exceptions.py +122 -0
  9. google/genai/_interactions/_files.py +139 -0
  10. google/genai/_interactions/_models.py +873 -0
  11. google/genai/_interactions/_qs.py +165 -0
  12. google/genai/_interactions/_resource.py +58 -0
  13. google/genai/_interactions/_response.py +847 -0
  14. google/genai/_interactions/_streaming.py +354 -0
  15. google/genai/_interactions/_types.py +276 -0
  16. google/genai/_interactions/_utils/__init__.py +79 -0
  17. google/genai/_interactions/_utils/_compat.py +61 -0
  18. google/genai/_interactions/_utils/_datetime_parse.py +151 -0
  19. google/genai/_interactions/_utils/_logs.py +40 -0
  20. google/genai/_interactions/_utils/_proxy.py +80 -0
  21. google/genai/_interactions/_utils/_reflection.py +57 -0
  22. google/genai/_interactions/_utils/_resources_proxy.py +39 -0
  23. google/genai/_interactions/_utils/_streams.py +27 -0
  24. google/genai/_interactions/_utils/_sync.py +73 -0
  25. google/genai/_interactions/_utils/_transform.py +472 -0
  26. google/genai/_interactions/_utils/_typing.py +172 -0
  27. google/genai/_interactions/_utils/_utils.py +437 -0
  28. google/genai/_interactions/_version.py +18 -0
  29. google/genai/_interactions/resources/__init__.py +34 -0
  30. google/genai/_interactions/resources/interactions.py +1350 -0
  31. google/genai/_interactions/types/__init__.py +107 -0
  32. google/genai/_interactions/types/allowed_tools.py +33 -0
  33. google/genai/_interactions/types/allowed_tools_param.py +35 -0
  34. google/genai/_interactions/types/annotation.py +42 -0
  35. google/genai/_interactions/types/annotation_param.py +42 -0
  36. google/genai/_interactions/types/audio_content.py +38 -0
  37. google/genai/_interactions/types/audio_content_param.py +45 -0
  38. google/genai/_interactions/types/audio_mime_type.py +25 -0
  39. google/genai/_interactions/types/audio_mime_type_param.py +27 -0
  40. google/genai/_interactions/types/code_execution_call_arguments.py +33 -0
  41. google/genai/_interactions/types/code_execution_call_arguments_param.py +32 -0
  42. google/genai/_interactions/types/code_execution_call_content.py +37 -0
  43. google/genai/_interactions/types/code_execution_call_content_param.py +37 -0
  44. google/genai/_interactions/types/code_execution_result_content.py +42 -0
  45. google/genai/_interactions/types/code_execution_result_content_param.py +41 -0
  46. google/genai/_interactions/types/content_delta.py +358 -0
  47. google/genai/_interactions/types/content_start.py +79 -0
  48. google/genai/_interactions/types/content_stop.py +35 -0
  49. google/genai/_interactions/types/deep_research_agent_config.py +33 -0
  50. google/genai/_interactions/types/deep_research_agent_config_param.py +32 -0
  51. google/genai/_interactions/types/document_content.py +36 -0
  52. google/genai/_interactions/types/document_content_param.py +43 -0
  53. google/genai/_interactions/types/dynamic_agent_config.py +44 -0
  54. google/genai/_interactions/types/dynamic_agent_config_param.py +33 -0
  55. google/genai/_interactions/types/error_event.py +46 -0
  56. google/genai/_interactions/types/file_search_result_content.py +46 -0
  57. google/genai/_interactions/types/file_search_result_content_param.py +46 -0
  58. google/genai/_interactions/types/function.py +38 -0
  59. google/genai/_interactions/types/function_call_content.py +39 -0
  60. google/genai/_interactions/types/function_call_content_param.py +39 -0
  61. google/genai/_interactions/types/function_param.py +37 -0
  62. google/genai/_interactions/types/function_result_content.py +52 -0
  63. google/genai/_interactions/types/function_result_content_param.py +54 -0
  64. google/genai/_interactions/types/generation_config.py +57 -0
  65. google/genai/_interactions/types/generation_config_param.py +59 -0
  66. google/genai/_interactions/types/google_search_call_arguments.py +29 -0
  67. google/genai/_interactions/types/google_search_call_arguments_param.py +31 -0
  68. google/genai/_interactions/types/google_search_call_content.py +37 -0
  69. google/genai/_interactions/types/google_search_call_content_param.py +37 -0
  70. google/genai/_interactions/types/google_search_result.py +35 -0
  71. google/genai/_interactions/types/google_search_result_content.py +43 -0
  72. google/genai/_interactions/types/google_search_result_content_param.py +44 -0
  73. google/genai/_interactions/types/google_search_result_param.py +35 -0
  74. google/genai/_interactions/types/image_content.py +41 -0
  75. google/genai/_interactions/types/image_content_param.py +48 -0
  76. google/genai/_interactions/types/image_mime_type.py +23 -0
  77. google/genai/_interactions/types/image_mime_type_param.py +25 -0
  78. google/genai/_interactions/types/interaction.py +165 -0
  79. google/genai/_interactions/types/interaction_create_params.py +212 -0
  80. google/genai/_interactions/types/interaction_event.py +37 -0
  81. google/genai/_interactions/types/interaction_get_params.py +46 -0
  82. google/genai/_interactions/types/interaction_sse_event.py +32 -0
  83. google/genai/_interactions/types/interaction_status_update.py +37 -0
  84. google/genai/_interactions/types/mcp_server_tool_call_content.py +42 -0
  85. google/genai/_interactions/types/mcp_server_tool_call_content_param.py +42 -0
  86. google/genai/_interactions/types/mcp_server_tool_result_content.py +52 -0
  87. google/genai/_interactions/types/mcp_server_tool_result_content_param.py +54 -0
  88. google/genai/_interactions/types/model.py +36 -0
  89. google/genai/_interactions/types/model_param.py +38 -0
  90. google/genai/_interactions/types/speech_config.py +35 -0
  91. google/genai/_interactions/types/speech_config_param.py +35 -0
  92. google/genai/_interactions/types/text_content.py +37 -0
  93. google/genai/_interactions/types/text_content_param.py +38 -0
  94. google/genai/_interactions/types/thinking_level.py +22 -0
  95. google/genai/_interactions/types/thought_content.py +41 -0
  96. google/genai/_interactions/types/thought_content_param.py +47 -0
  97. google/genai/_interactions/types/tool.py +100 -0
  98. google/genai/_interactions/types/tool_choice.py +26 -0
  99. google/genai/_interactions/types/tool_choice_config.py +28 -0
  100. google/genai/_interactions/types/tool_choice_config_param.py +29 -0
  101. google/genai/_interactions/types/tool_choice_param.py +28 -0
  102. google/genai/_interactions/types/tool_choice_type.py +22 -0
  103. google/genai/_interactions/types/tool_param.py +97 -0
  104. google/genai/_interactions/types/turn.py +76 -0
  105. google/genai/_interactions/types/turn_param.py +73 -0
  106. google/genai/_interactions/types/url_context_call_arguments.py +29 -0
  107. google/genai/_interactions/types/url_context_call_arguments_param.py +31 -0
  108. google/genai/_interactions/types/url_context_call_content.py +37 -0
  109. google/genai/_interactions/types/url_context_call_content_param.py +37 -0
  110. google/genai/_interactions/types/url_context_result.py +33 -0
  111. google/genai/_interactions/types/url_context_result_content.py +43 -0
  112. google/genai/_interactions/types/url_context_result_content_param.py +44 -0
  113. google/genai/_interactions/types/url_context_result_param.py +32 -0
  114. google/genai/_interactions/types/usage.py +106 -0
  115. google/genai/_interactions/types/usage_param.py +106 -0
  116. google/genai/_interactions/types/video_content.py +41 -0
  117. google/genai/_interactions/types/video_content_param.py +48 -0
  118. google/genai/_interactions/types/video_mime_type.py +36 -0
  119. google/genai/_interactions/types/video_mime_type_param.py +38 -0
  120. google/genai/_live_converters.py +34 -3
  121. google/genai/_tokens_converters.py +5 -0
  122. google/genai/batches.py +62 -55
  123. google/genai/client.py +223 -0
  124. google/genai/errors.py +16 -1
  125. google/genai/file_search_stores.py +60 -60
  126. google/genai/files.py +56 -56
  127. google/genai/interactions.py +17 -0
  128. google/genai/live.py +4 -3
  129. google/genai/models.py +15 -3
  130. google/genai/tests/__init__.py +21 -0
  131. google/genai/tests/afc/__init__.py +21 -0
  132. google/genai/tests/afc/test_convert_if_exist_pydantic_model.py +309 -0
  133. google/genai/tests/afc/test_convert_number_values_for_function_call_args.py +63 -0
  134. google/genai/tests/afc/test_find_afc_incompatible_tool_indexes.py +240 -0
  135. google/genai/tests/afc/test_generate_content_stream_afc.py +530 -0
  136. google/genai/tests/afc/test_generate_content_stream_afc_thoughts.py +77 -0
  137. google/genai/tests/afc/test_get_function_map.py +176 -0
  138. google/genai/tests/afc/test_get_function_response_parts.py +277 -0
  139. google/genai/tests/afc/test_get_max_remote_calls_for_afc.py +130 -0
  140. google/genai/tests/afc/test_invoke_function_from_dict_args.py +241 -0
  141. google/genai/tests/afc/test_raise_error_for_afc_incompatible_config.py +159 -0
  142. google/genai/tests/afc/test_should_append_afc_history.py +53 -0
  143. google/genai/tests/afc/test_should_disable_afc.py +214 -0
  144. google/genai/tests/batches/__init__.py +17 -0
  145. google/genai/tests/batches/test_cancel.py +77 -0
  146. google/genai/tests/batches/test_create.py +78 -0
  147. google/genai/tests/batches/test_create_with_bigquery.py +113 -0
  148. google/genai/tests/batches/test_create_with_file.py +82 -0
  149. google/genai/tests/batches/test_create_with_gcs.py +125 -0
  150. google/genai/tests/batches/test_create_with_inlined_requests.py +255 -0
  151. google/genai/tests/batches/test_delete.py +86 -0
  152. google/genai/tests/batches/test_embedding.py +157 -0
  153. google/genai/tests/batches/test_get.py +78 -0
  154. google/genai/tests/batches/test_list.py +79 -0
  155. google/genai/tests/caches/__init__.py +17 -0
  156. google/genai/tests/caches/constants.py +29 -0
  157. google/genai/tests/caches/test_create.py +210 -0
  158. google/genai/tests/caches/test_create_custom_url.py +105 -0
  159. google/genai/tests/caches/test_delete.py +54 -0
  160. google/genai/tests/caches/test_delete_custom_url.py +52 -0
  161. google/genai/tests/caches/test_get.py +94 -0
  162. google/genai/tests/caches/test_get_custom_url.py +52 -0
  163. google/genai/tests/caches/test_list.py +68 -0
  164. google/genai/tests/caches/test_update.py +70 -0
  165. google/genai/tests/caches/test_update_custom_url.py +58 -0
  166. google/genai/tests/chats/__init__.py +1 -0
  167. google/genai/tests/chats/test_get_history.py +597 -0
  168. google/genai/tests/chats/test_send_message.py +844 -0
  169. google/genai/tests/chats/test_validate_response.py +90 -0
  170. google/genai/tests/client/__init__.py +17 -0
  171. google/genai/tests/client/test_async_stream.py +427 -0
  172. google/genai/tests/client/test_client_close.py +197 -0
  173. google/genai/tests/client/test_client_initialization.py +1687 -0
  174. google/genai/tests/client/test_client_requests.py +355 -0
  175. google/genai/tests/client/test_custom_client.py +77 -0
  176. google/genai/tests/client/test_http_options.py +178 -0
  177. google/genai/tests/client/test_replay_client_equality.py +168 -0
  178. google/genai/tests/client/test_retries.py +846 -0
  179. google/genai/tests/client/test_upload_errors.py +136 -0
  180. google/genai/tests/common/__init__.py +17 -0
  181. google/genai/tests/common/test_common.py +954 -0
  182. google/genai/tests/conftest.py +162 -0
  183. google/genai/tests/documents/__init__.py +17 -0
  184. google/genai/tests/documents/test_delete.py +51 -0
  185. google/genai/tests/documents/test_get.py +85 -0
  186. google/genai/tests/documents/test_list.py +72 -0
  187. google/genai/tests/errors/__init__.py +1 -0
  188. google/genai/tests/errors/test_api_error.py +417 -0
  189. google/genai/tests/file_search_stores/__init__.py +17 -0
  190. google/genai/tests/file_search_stores/test_create.py +66 -0
  191. google/genai/tests/file_search_stores/test_delete.py +64 -0
  192. google/genai/tests/file_search_stores/test_get.py +94 -0
  193. google/genai/tests/file_search_stores/test_import_file.py +112 -0
  194. google/genai/tests/file_search_stores/test_list.py +57 -0
  195. google/genai/tests/file_search_stores/test_upload_to_file_search_store.py +141 -0
  196. google/genai/tests/files/__init__.py +17 -0
  197. google/genai/tests/files/test_delete.py +46 -0
  198. google/genai/tests/files/test_download.py +85 -0
  199. google/genai/tests/files/test_get.py +46 -0
  200. google/genai/tests/files/test_list.py +72 -0
  201. google/genai/tests/files/test_upload.py +255 -0
  202. google/genai/tests/imports/test_no_optional_imports.py +28 -0
  203. google/genai/tests/interactions/__init__.py +0 -0
  204. google/genai/tests/interactions/test_integration.py +80 -0
  205. google/genai/tests/live/__init__.py +16 -0
  206. google/genai/tests/live/test_live.py +2177 -0
  207. google/genai/tests/live/test_live_music.py +362 -0
  208. google/genai/tests/live/test_live_response.py +163 -0
  209. google/genai/tests/live/test_send_client_content.py +147 -0
  210. google/genai/tests/live/test_send_realtime_input.py +268 -0
  211. google/genai/tests/live/test_send_tool_response.py +222 -0
  212. google/genai/tests/local_tokenizer/__init__.py +17 -0
  213. google/genai/tests/local_tokenizer/test_local_tokenizer.py +343 -0
  214. google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py +235 -0
  215. google/genai/tests/mcp/__init__.py +17 -0
  216. google/genai/tests/mcp/test_has_mcp_tool_usage.py +89 -0
  217. google/genai/tests/mcp/test_mcp_to_gemini_tools.py +191 -0
  218. google/genai/tests/mcp/test_parse_config_for_mcp_sessions.py +201 -0
  219. google/genai/tests/mcp/test_parse_config_for_mcp_usage.py +130 -0
  220. google/genai/tests/mcp/test_set_mcp_usage_header.py +72 -0
  221. google/genai/tests/models/__init__.py +17 -0
  222. google/genai/tests/models/constants.py +8 -0
  223. google/genai/tests/models/test_compute_tokens.py +120 -0
  224. google/genai/tests/models/test_count_tokens.py +159 -0
  225. google/genai/tests/models/test_delete.py +107 -0
  226. google/genai/tests/models/test_edit_image.py +264 -0
  227. google/genai/tests/models/test_embed_content.py +94 -0
  228. google/genai/tests/models/test_function_call_streaming.py +442 -0
  229. google/genai/tests/models/test_generate_content.py +2502 -0
  230. google/genai/tests/models/test_generate_content_cached_content.py +132 -0
  231. google/genai/tests/models/test_generate_content_config_zero_value.py +103 -0
  232. google/genai/tests/models/test_generate_content_from_apikey.py +44 -0
  233. google/genai/tests/models/test_generate_content_http_options.py +40 -0
  234. google/genai/tests/models/test_generate_content_image_generation.py +143 -0
  235. google/genai/tests/models/test_generate_content_mcp.py +343 -0
  236. google/genai/tests/models/test_generate_content_media_resolution.py +97 -0
  237. google/genai/tests/models/test_generate_content_model.py +139 -0
  238. google/genai/tests/models/test_generate_content_part.py +821 -0
  239. google/genai/tests/models/test_generate_content_thought.py +76 -0
  240. google/genai/tests/models/test_generate_content_tools.py +1761 -0
  241. google/genai/tests/models/test_generate_images.py +191 -0
  242. google/genai/tests/models/test_generate_videos.py +759 -0
  243. google/genai/tests/models/test_get.py +104 -0
  244. google/genai/tests/models/test_list.py +233 -0
  245. google/genai/tests/models/test_recontext_image.py +189 -0
  246. google/genai/tests/models/test_segment_image.py +148 -0
  247. google/genai/tests/models/test_update.py +95 -0
  248. google/genai/tests/models/test_upscale_image.py +157 -0
  249. google/genai/tests/operations/__init__.py +17 -0
  250. google/genai/tests/operations/test_get.py +38 -0
  251. google/genai/tests/public_samples/__init__.py +17 -0
  252. google/genai/tests/public_samples/test_gemini_text_only.py +34 -0
  253. google/genai/tests/pytest_helper.py +229 -0
  254. google/genai/tests/shared/__init__.py +16 -0
  255. google/genai/tests/shared/batches/__init__.py +14 -0
  256. google/genai/tests/shared/batches/test_create_delete.py +57 -0
  257. google/genai/tests/shared/batches/test_create_get_cancel.py +56 -0
  258. google/genai/tests/shared/batches/test_list.py +40 -0
  259. google/genai/tests/shared/caches/__init__.py +14 -0
  260. google/genai/tests/shared/caches/test_create_get_delete.py +67 -0
  261. google/genai/tests/shared/caches/test_create_update_get.py +71 -0
  262. google/genai/tests/shared/caches/test_list.py +40 -0
  263. google/genai/tests/shared/chats/__init__.py +14 -0
  264. google/genai/tests/shared/chats/test_send_message.py +48 -0
  265. google/genai/tests/shared/chats/test_send_message_stream.py +50 -0
  266. google/genai/tests/shared/files/__init__.py +14 -0
  267. google/genai/tests/shared/files/test_list.py +41 -0
  268. google/genai/tests/shared/files/test_upload_get_delete.py +54 -0
  269. google/genai/tests/shared/models/__init__.py +14 -0
  270. google/genai/tests/shared/models/test_compute_tokens.py +41 -0
  271. google/genai/tests/shared/models/test_count_tokens.py +40 -0
  272. google/genai/tests/shared/models/test_edit_image.py +67 -0
  273. google/genai/tests/shared/models/test_embed.py +40 -0
  274. google/genai/tests/shared/models/test_generate_content.py +39 -0
  275. google/genai/tests/shared/models/test_generate_content_stream.py +54 -0
  276. google/genai/tests/shared/models/test_generate_images.py +40 -0
  277. google/genai/tests/shared/models/test_generate_videos.py +38 -0
  278. google/genai/tests/shared/models/test_list.py +37 -0
  279. google/genai/tests/shared/models/test_recontext_image.py +55 -0
  280. google/genai/tests/shared/models/test_segment_image.py +52 -0
  281. google/genai/tests/shared/models/test_upscale_image.py +52 -0
  282. google/genai/tests/shared/tunings/__init__.py +16 -0
  283. google/genai/tests/shared/tunings/test_create.py +46 -0
  284. google/genai/tests/shared/tunings/test_create_get_cancel.py +56 -0
  285. google/genai/tests/shared/tunings/test_list.py +39 -0
  286. google/genai/tests/tokens/__init__.py +16 -0
  287. google/genai/tests/tokens/test_create.py +154 -0
  288. google/genai/tests/transformers/__init__.py +17 -0
  289. google/genai/tests/transformers/test_blobs.py +71 -0
  290. google/genai/tests/transformers/test_bytes.py +15 -0
  291. google/genai/tests/transformers/test_duck_type.py +96 -0
  292. google/genai/tests/transformers/test_function_responses.py +72 -0
  293. google/genai/tests/transformers/test_schema.py +653 -0
  294. google/genai/tests/transformers/test_t_batch.py +286 -0
  295. google/genai/tests/transformers/test_t_content.py +160 -0
  296. google/genai/tests/transformers/test_t_contents.py +398 -0
  297. google/genai/tests/transformers/test_t_part.py +85 -0
  298. google/genai/tests/transformers/test_t_parts.py +87 -0
  299. google/genai/tests/transformers/test_t_tool.py +157 -0
  300. google/genai/tests/transformers/test_t_tools.py +195 -0
  301. google/genai/tests/tunings/__init__.py +16 -0
  302. google/genai/tests/tunings/test_cancel.py +39 -0
  303. google/genai/tests/tunings/test_end_to_end.py +106 -0
  304. google/genai/tests/tunings/test_get.py +67 -0
  305. google/genai/tests/tunings/test_list.py +75 -0
  306. google/genai/tests/tunings/test_tune.py +268 -0
  307. google/genai/tests/types/__init__.py +16 -0
  308. google/genai/tests/types/test_bytes_internal.py +271 -0
  309. google/genai/tests/types/test_bytes_type.py +152 -0
  310. google/genai/tests/types/test_future.py +101 -0
  311. google/genai/tests/types/test_optional_types.py +36 -0
  312. google/genai/tests/types/test_part_type.py +616 -0
  313. google/genai/tests/types/test_schema_from_json_schema.py +417 -0
  314. google/genai/tests/types/test_schema_json_schema.py +468 -0
  315. google/genai/tests/types/test_types.py +2903 -0
  316. google/genai/tunings.py +57 -57
  317. google/genai/types.py +229 -121
  318. google/genai/version.py +1 -1
  319. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/METADATA +4 -2
  320. google_genai-1.55.0.dist-info/RECORD +345 -0
  321. google_genai-1.53.0.dist-info/RECORD +0 -41
  322. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/WHEEL +0 -0
  323. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/licenses/LICENSE +0 -0
  324. {google_genai-1.53.0.dist-info → google_genai-1.55.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2903 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+
17
+ import copy
18
+ import json
19
+ import sys
20
+ import typing
21
+ from typing import Optional, assert_never
22
+ import PIL.Image
23
+ import pydantic
24
+ import pytest
25
+ from ... import types
26
+
27
+ _is_mcp_imported = False
28
+ if typing.TYPE_CHECKING:
29
+ from mcp import types as mcp_types
30
+
31
+ _is_mcp_imported = True
32
+ else:
33
+ try:
34
+ from mcp import types as mcp_types
35
+
36
+ _is_mcp_imported = True
37
+ except ImportError:
38
+ mcp_types = None
39
+
40
+
41
class SubPart(types.Part):
  # Trivial subclass: the factory tests assert that Part's class-level
  # constructors return an instance of the calling subclass, not plain Part.
  ...
43
+
44
+
45
class SubFunctionResponsePart(types.FunctionResponsePart):
  # Trivial subclass: the factory tests assert that FunctionResponsePart's
  # class-level constructors return an instance of the calling subclass.
  ...
47
+
48
+
49
def test_factory_method_from_uri_part():
  """from_uri keeps the URI and explicit MIME type and returns the subclass."""
  scones_uri = 'gs://generativeai-downloads/images/scones.jpg'

  part = SubPart.from_uri(file_uri=scones_uri, mime_type='image/jpeg')

  file_data = part.file_data
  assert file_data.file_uri == scones_uri
  assert file_data.mime_type == 'image/jpeg'
  assert isinstance(part, SubPart)
61
+
62
+
63
def test_factory_method_from_uri_inferred_mime_type_part():
  """from_uri without mime_type should yield 'image/jpeg' for this .jpg URI."""
  scones_uri = 'gs://generativeai-downloads/images/scones.jpg'

  part = SubPart.from_uri(file_uri=scones_uri)

  file_data = part.file_data
  assert file_data.file_uri == scones_uri
  assert file_data.mime_type == 'image/jpeg'
  assert isinstance(part, SubPart)
74
+
75
+
76
def test_factory_method_from_text_part():
  """from_text stores the given text and returns the subclass."""
  prompt = 'What is your name?'
  part = SubPart.from_text(text=prompt)
  assert part.text == prompt
  assert isinstance(part, SubPart)
80
+
81
+
82
def test_factory_method_from_bytes_part():
  """from_bytes places the raw bytes and MIME type into inline_data."""
  payload = b'123'
  part = SubPart.from_bytes(data=payload, mime_type='text/plain')

  blob = part.inline_data
  assert blob.data == payload
  assert blob.mime_type == 'text/plain'
  assert isinstance(part, SubPart)
87
+
88
+
89
def test_factory_method_from_function_call_part():
  """from_function_call records the function name and its arguments."""
  part = SubPart.from_function_call(name='func', args={'arg': 'value'})

  call = part.function_call
  assert call.name == 'func'
  assert call.args == {'arg': 'value'}
  assert isinstance(part, SubPart)
94
+
95
+
96
def test_factory_method_from_function_response_part():
  """from_function_response records the function name and response dict."""
  part = SubPart.from_function_response(
      name='func',
      response={'response': 'value'},
  )

  function_response = part.function_response
  assert function_response.name == 'func'
  assert function_response.response == {'response': 'value'}
  assert isinstance(part, SubPart)
103
+
104
+
105
def test_factory_method_part_from_function_response_with_multi_modal_parts():
  """from_function_response accepts dict-shaped multimodal parts."""
  image_part = {'inline_data': {'data': b'123', 'mime_type': 'image/png'}}

  part = SubPart.from_function_response(
      name='func',
      response={'response': 'value'},
      parts=[image_part],
  )

  function_response = part.function_response
  assert function_response.name == 'func'
  assert function_response.response == {'response': 'value'}
  # The dict entry should have been converted into a typed part.
  first_blob = function_response.parts[0].inline_data
  assert first_blob.data == b'123'
  assert first_blob.mime_type == 'image/png'
  assert isinstance(part, SubPart)
116
+
117
+
118
def test_factory_method_function_response_part_from_bytes():
  """FunctionResponsePart.from_bytes fills inline_data and keeps the subclass."""
  part = SubFunctionResponsePart.from_bytes(data=b'123', mime_type='image/png')

  blob = part.inline_data
  assert blob.data == b'123'
  assert blob.mime_type == 'image/png'
  assert isinstance(part, SubFunctionResponsePart)
125
+
126
+
127
def test_factory_method_function_response_part_from_uri():
  """FunctionResponsePart.from_uri fills file_data and keeps the subclass."""
  scones_uri = 'gs://generativeai-downloads/images/scones.jpg'

  part = SubFunctionResponsePart.from_uri(
      file_uri=scones_uri, mime_type='image/jpeg'
  )

  file_data = part.file_data
  assert file_data.file_uri == scones_uri
  assert file_data.mime_type == 'image/jpeg'
  assert isinstance(part, SubFunctionResponsePart)
138
+
139
+
140
def test_factory_method_from_executable_code_part():
  """from_executable_code records the source code and language."""
  source = 'print("hello")'
  part = SubPart.from_executable_code(code=source, language='PYTHON')

  executable = part.executable_code
  assert executable.code == source
  assert executable.language == 'PYTHON'
  assert isinstance(part, SubPart)
147
+
148
+
149
def test_factory_method_from_code_execution_result_part():
  """from_code_execution_result records the outcome and captured output."""
  captured = 'print("hello")'
  part = SubPart.from_code_execution_result(
      outcome='OUTCOME_OK', output=captured
  )

  result = part.code_execution_result
  assert result.outcome == 'OUTCOME_OK'
  assert result.output == captured
  assert isinstance(part, SubPart)
156
+
157
+
158
def test_factory_method_from_mcp_call_tool_function_response_on_error():
  """An MCP CallToolResult with isError=True maps to an 'error' response."""
  if not _is_mcp_imported:
    # Report the test as skipped rather than silently passing when the
    # optional `mcp` dependency is not installed.
    pytest.skip('mcp is not installed')

  call_tool_result = mcp_types.CallToolResult(
      content=[],
      isError=True,
  )
  my_function_response = types.FunctionResponse.from_mcp_response(
      name='func_name', response=call_tool_result
  )
  assert my_function_response.name == 'func_name'
  assert my_function_response.response == {'error': 'MCP response is error.'}
  assert isinstance(my_function_response, types.FunctionResponse)
172
+
173
+
174
def test_factory_method_from_mcp_call_tool_function_response_text():
  """Text content from an MCP CallToolResult is returned under 'result'."""
  if not _is_mcp_imported:
    # Report the test as skipped rather than silently passing when the
    # optional `mcp` dependency is not installed.
    pytest.skip('mcp is not installed')

  call_tool_result = mcp_types.CallToolResult(
      content=[
          mcp_types.TextContent(type='text', text='hello'),
          mcp_types.TextContent(type='text', text=' world'),
      ],
  )
  my_function_response = types.FunctionResponse.from_mcp_response(
      name='func_name', response=call_tool_result
  )
  assert my_function_response.name == 'func_name'
  assert my_function_response.response == {
      'result': [
          mcp_types.TextContent(type='text', text='hello'),
          mcp_types.TextContent(type='text', text=' world'),
      ]
  }
  assert isinstance(my_function_response, types.FunctionResponse)
195
+
196
+
197
def test_factory_method_from_mcp_call_tool_function_response_inline_data():
  """Image content from an MCP CallToolResult is returned under 'result'."""
  if not _is_mcp_imported:
    # Report the test as skipped rather than silently passing when the
    # optional `mcp` dependency is not installed.
    pytest.skip('mcp is not installed')

  call_tool_result = mcp_types.CallToolResult(
      content=[
          mcp_types.ImageContent(
              type='image',
              data='MTIz',
              mimeType='text/plain',
          ),
          mcp_types.ImageContent(
              type='image',
              data='NDU2',
              mimeType='text/plain',
          ),
      ],
  )
  my_function_response = types.FunctionResponse.from_mcp_response(
      name='func_name', response=call_tool_result
  )
  assert my_function_response.name == 'func_name'
  assert my_function_response.response == {
      'result': [
          mcp_types.ImageContent(
              type='image',
              data='MTIz',
              mimeType='text/plain',
          ),
          mcp_types.ImageContent(
              type='image',
              data='NDU2',
              mimeType='text/plain',
          ),
      ]
  }
  assert isinstance(my_function_response, types.FunctionResponse)
234
+
235
+
236
def test_factory_method_from_mcp_call_tool_function_response_combined_content():
  """Mixed text and image MCP content is returned, in order, under 'result'."""
  if not _is_mcp_imported:
    # Report the test as skipped rather than silently passing when the
    # optional `mcp` dependency is not installed.
    pytest.skip('mcp is not installed')

  call_tool_result = mcp_types.CallToolResult(
      content=[
          mcp_types.TextContent(
              type='text',
              text='Hello',
          ),
          mcp_types.ImageContent(
              type='image',
              data='NDU2',
              mimeType='text/plain',
          ),
      ],
  )
  my_function_response = types.FunctionResponse.from_mcp_response(
      name='func_name', response=call_tool_result
  )
  assert my_function_response.name == 'func_name'
  assert my_function_response.response == {
      'result': [
          mcp_types.TextContent(
              type='text',
              text='Hello',
          ),
          mcp_types.ImageContent(
              type='image',
              data='NDU2',
              mimeType='text/plain',
          ),
      ]
  }
  assert isinstance(my_function_response, types.FunctionResponse)
271
+
272
+
273
def test_factory_method_from_mcp_call_tool_function_response_embedded_resource():
  """Embedded-resource MCP content is returned under 'result'."""
  if not _is_mcp_imported:
    # Report the test as skipped rather than silently passing when the
    # optional `mcp` dependency is not installed.
    pytest.skip('mcp is not installed')

  call_tool_result = mcp_types.CallToolResult(
      content=[
          mcp_types.EmbeddedResource(
              type='resource',
              resource=mcp_types.TextResourceContents(
                  uri='https://generativelanguage.googleapis.com/v1beta/files/ansa0kyotrsw',
                  text='hello',
              ),
          ),
      ],
  )
  my_function_response = types.FunctionResponse.from_mcp_response(
      name='func_name', response=call_tool_result
  )
  assert my_function_response.name == 'func_name'
  assert my_function_response.response == {
      'result': [
          mcp_types.EmbeddedResource(
              type='resource',
              resource=mcp_types.TextResourceContents(
                  uri='https://generativelanguage.googleapis.com/v1beta/files/ansa0kyotrsw',
                  text='hello',
              ),
          ),
      ]
  }
  assert isinstance(my_function_response, types.FunctionResponse)
304
+
305
+
306
def test_part_constructor_with_string_value():
  """A bare positional string sets text and leaves file/inline data unset."""
  part = types.Part('hello')
  assert part.text == 'hello'
  assert all(
      field is None for field in (part.file_data, part.inline_data)
  )
311
+
312
+
313
def test_part_constructor_with_part_value():
  """Passing an existing Part positionally carries over its text."""
  source_part = types.Part(text='hello from other part')
  rebuilt = types.Part(source_part)
  assert rebuilt.text == 'hello from other part'
317
+
318
+
319
def test_part_constructor_with_part_dict_value():
  """A positional dict with Part field names is accepted."""
  assert types.Part({'text': 'hello from dict'}).text == 'hello from dict'
322
+
323
+
324
def test_part_constructor_with_file_data_dict_value():
  """A positional dict shaped like FileData populates Part.file_data."""
  file_data = types.Part(
      {'file_uri': 'gs://my-bucket/file-data', 'mime_type': 'text/plain'}
  ).file_data
  assert file_data.file_uri == 'gs://my-bucket/file-data'
  assert file_data.mime_type == 'text/plain'
330
+
331
+
332
def test_part_constructor_with_kwargs_and_value_fails():
  """Mixing a positional value with keyword fields raises ValueError."""
  expected_message = 'Positional and keyword arguments can not be combined'
  with pytest.raises(ValueError, match=expected_message):
    types.Part('hello', text='world')
337
+
338
+
339
def test_part_constructor_with_file_value():
  """A File passed positionally maps uri, mime_type and display_name."""
  uploaded = types.File(
      uri='gs://my-bucket/my-file',
      mime_type='text/plain',
      display_name='test file',
  )

  file_data = types.Part(uploaded).file_data

  assert file_data.file_uri == 'gs://my-bucket/my-file'
  assert file_data.mime_type == 'text/plain'
  assert file_data.display_name == 'test file'
349
+
350
+
351
def test_part_constructor_with_pil_image():
  """A PIL image becomes ``inline_data`` with JPEG mime type and raw bytes."""
  red_pixel = PIL.Image.new('RGB', (1, 1), color='red')
  inline = types.Part(red_pixel).inline_data
  assert inline.mime_type == 'image/jpeg'
  assert isinstance(inline.data, bytes)
356
+
357
+
358
class FakeClient:
  """Minimal client stand-in that only carries a ``vertexai`` flag.

  It is passed as ``client=`` to ``FunctionDeclaration.from_callable``.
  Presumably only ``vertexai`` is read there; confirm against types.py.
  """

  def __init__(self, vertexai: bool = False) -> None:
    self.vertexai = vertexai
362
+
363
+
364
# Shared fake clients used by the from_callable tests below: one with the
# vertexai flag off (named "mldev") and one with it on.
mldev_client = FakeClient(vertexai=False)
vertex_client = FakeClient(vertexai=True)
366
+
367
+
368
def test_empty_function():
  def func_under_test():
    """test empty function."""

  expected_mldev = types.FunctionDeclaration(
      name='func_under_test',
      description='test empty function.',
  )
  # Both backends are expected to yield the same declaration; the Vertex
  # expectation is an independent deep copy of the mldev one.
  cases = (
      (mldev_client, expected_mldev),
      (vertex_client, copy.deepcopy(expected_mldev)),
  )
  for client, expected in cases:
    actual = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual == expected
388
+
389
+
390
def test_built_in_primitives_and_compounds():

  def func_under_test(
      a: int,
      b: float,
      c: bool,
      d: str,
      e: list,
      f: dict,
  ):
    """test built in primitives and compounds."""

  # Parameter name -> expected schema type for each built-in annotation.
  schema_type_by_param = {
      'a': 'INTEGER',
      'b': 'NUMBER',
      'c': 'BOOLEAN',
      'd': 'STRING',
      'e': 'ARRAY',
      'f': 'OBJECT',
  }
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              param: types.Schema(type=schema_type)
              for param, schema_type in schema_type_by_param.items()
          },
          # No parameter has a default, so all of them are required.
          required=list(schema_type_by_param),
      ),
      description='test built in primitives and compounds.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
429
+
430
+
431
def test_default_value_built_in_type_not_compatible():
  """Smoke test: defaults that do not match their built-in annotations.

  This test was previously named ``test_default_value_built_in_type``. A later
  test in this module uses the same name, so this definition was shadowed and
  pytest never collected it. The test checks only that ``from_callable``
  returns without raising. It does not inspect the generated schema.
  """

  def func_under_test(a: str, b: int = '1', c: list = []):
    """test default value not compatible built in type."""
    pass

  types.FunctionDeclaration.from_callable(
      client=mldev_client, callable=func_under_test
  )
  types.FunctionDeclaration.from_callable(
      client=vertex_client, callable=func_under_test
  )
442
+
443
+
444
def test_default_value_built_in_type():
  def func_under_test(a: str, b: int = 1, c: list = []):
    """test default value."""

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': types.Schema(type='STRING'),
              'b': types.Schema(type='INTEGER', default=1),
              'c': types.Schema(type='ARRAY', default=[]),
          },
          # Only 'a' has no default, so it is the only required parameter.
          required=['a'],
      ),
      description='test default value.',
  )

  for client in (vertex_client, mldev_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
472
+
473
+
474
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_built_in_primitives_compounds():
  # Smoke test: each annotation below must be accepted by from_callable for
  # both fake clients. The resulting schemas are not inspected.
  def func_under_test1(a: bytes):
    ...

  def func_under_test2(a: set):
    ...

  def func_under_test3(a: frozenset):
    ...

  def func_under_test4(a: type(None)):
    ...

  def func_under_test5(a: int | bytes):
    ...

  def func_under_test6(a: int | set):
    ...

  def func_under_test7(a: int | frozenset):
    ...

  def func_under_test8(a: typing.Union[int, bytes]):
    ...

  def func_under_test9(a: typing.Union[int, set]):
    ...

  def func_under_test10(a: typing.Union[int, frozenset]):
    ...

  candidates = (
      func_under_test1,
      func_under_test2,
      func_under_test3,
      func_under_test4,
      func_under_test5,
      func_under_test6,
      func_under_test7,
      func_under_test8,
      func_under_test9,
      func_under_test10,
  )
  for candidate in candidates:
    for client in (mldev_client, vertex_client):
      types.FunctionDeclaration.from_callable(client=client, callable=candidate)
528
+
529
+
530
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_built_in_union_type():

  def func_under_test(
      a: int | str | float | bool,
      b: list | dict,
  ):
    """test built in union type."""

  def union_schema(*member_types):
    # Unions are expected as an OBJECT schema listing each member in any_of.
    return types.Schema(
        type='OBJECT',
        any_of=[types.Schema(type=member) for member in member_types],
    )

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': union_schema('INTEGER', 'STRING', 'NUMBER', 'BOOLEAN'),
              'b': union_schema('ARRAY', 'OBJECT'),
          },
          required=['a', 'b'],
      ),
      description='test built in union type.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
579
+
580
+
581
def test_built_in_union_type_all_py_versions():

  def func_under_test(
      a: typing.Union[int, str, float, bool],
      b: typing.Union[list, dict],
  ):
    """test built in union type."""

  scalar_members = ['INTEGER', 'STRING', 'NUMBER', 'BOOLEAN']
  container_members = ['ARRAY', 'OBJECT']
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': types.Schema(
                  type='OBJECT',
                  any_of=[types.Schema(type=t) for t in scalar_members],
              ),
              'b': types.Schema(
                  type='OBJECT',
                  any_of=[types.Schema(type=t) for t in container_members],
              ),
          },
          required=['a', 'b'],
      ),
      description='test built in union type.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
626
+
627
+
628
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_default_value_built_in_union_type_not_compatible():
  """Smoke test: a float default on an ``int | str`` parameter.

  This test was previously named ``test_default_value_built_in_union_type``.
  A later test in this module redefines that name, so this definition was
  shadowed and pytest never collected it. The test checks only that
  ``from_callable`` returns without raising. It does not inspect the schema.
  """

  def func_under_test(
      a: int | str = 1.1,
  ):
    """test default value not compatible built in union type."""
    pass

  types.FunctionDeclaration.from_callable(
      client=mldev_client, callable=func_under_test
  )
  types.FunctionDeclaration.from_callable(
      client=vertex_client, callable=func_under_test
  )
645
+
646
+
647
def test_default_value_built_in_union_type_not_compatible_all_py_versions():
  """Smoke test: a float default on a ``typing.Union[int, str]`` parameter.

  This test was previously named
  ``test_default_value_built_in_union_type_all_py_versions``. A later test in
  this module redefines that name, so this definition was shadowed and pytest
  never collected it. The test checks only that ``from_callable`` returns
  without raising. It does not inspect the schema.
  """

  def func_under_test(
      a: typing.Union[int, str] = 1.1,
  ):
    """test default value not compatible built in union type."""
    pass

  types.FunctionDeclaration.from_callable(
      client=mldev_client, callable=func_under_test
  )
  types.FunctionDeclaration.from_callable(
      client=vertex_client, callable=func_under_test
  )
660
+
661
+
662
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_default_value_built_in_union_type():

  def func_under_test(
      a: int | str = '1',
      b: list | dict = [],
      c: list | dict = {},
  ):
    """test default value built in union type."""

  def container_union(default):
    # Expected schema for a `list | dict` parameter with the given default.
    return types.Schema(
        type='OBJECT',
        any_of=[types.Schema(type='ARRAY'), types.Schema(type='OBJECT')],
        default=default,
    )

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': types.Schema(
                  type='OBJECT',
                  any_of=[
                      types.Schema(type='INTEGER'),
                      types.Schema(type='STRING'),
                  ],
                  default='1',
              ),
              'b': container_union([]),
              'c': container_union({}),
          },
          # Every parameter has a default, so none are required.
          required=[],
      ),
      description='test default value built in union type.',
  )

  for client in (vertex_client, mldev_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
720
+
721
+
722
def test_default_value_built_in_union_type_all_py_versions():

  def func_under_test(
      a: typing.Union[int, str] = '1',
      b: typing.Union[list, dict] = [],
      c: typing.Union[list, dict] = {},
  ):
    """test default value built in union type."""

  def union_with_default(member_types, default):
    # Expected OBJECT/any_of schema for a Union parameter with a default.
    return types.Schema(
        type='OBJECT',
        any_of=[types.Schema(type=t) for t in member_types],
        default=default,
    )

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': union_with_default(('INTEGER', 'STRING'), '1'),
              'b': union_with_default(('ARRAY', 'OBJECT'), []),
              'c': union_with_default(('ARRAY', 'OBJECT'), {}),
          },
          required=[],
      ),
      description='test default value built in union type.',
  )

  for client in (vertex_client, mldev_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
776
+
777
+
778
def test_generic_alias_literal():

  def func_under_test(a: typing.Literal['a', 'b', 'c']):
    """test generic alias literal."""

  # A string Literal is expected to become a STRING schema with an enum.
  literal_schema = types.Schema(type='STRING', enum=['a', 'b', 'c'])
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': literal_schema},
          required=['a'],
      ),
      description='test generic alias literal.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
808
+
809
+
810
def test_default_value_generic_alias_literal():

  def func_under_test(a: typing.Literal['1', '2', '3'] = '1'):
    """test default value generic alias literal."""

  literal_schema = types.Schema(
      type='STRING',
      enum=['1', '2', '3'],
      default='1',
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': literal_schema},
          # 'a' has a default, so nothing is required.
          required=[],
      ),
      description='test default value generic alias literal.',
  )

  for client in (vertex_client, mldev_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
841
+
842
+
843
def test_generic_alias_literal_not_compatible():
  """Smoke test: a Literal that mixes str and int members.

  This test was previously named ``test_default_value_generic_alias_literal``.
  An earlier test in this module already uses that name. Because this
  definition came later, it rebound the name, so the earlier test (which
  asserts the generated schema) was never collected. This test checks only
  that ``from_callable`` returns without raising.
  """

  def func_under_test(a: typing.Literal['1', '2', 3]):
    """test default value generic alias literal not compatible."""
    pass

  types.FunctionDeclaration.from_callable(
      client=mldev_client, callable=func_under_test
  )
  types.FunctionDeclaration.from_callable(
      client=vertex_client, callable=func_under_test
  )
854
+
855
+
856
def test_default_value_generic_alias_literal_with_str_default():
  # Smoke test: the default 'd' is outside the Literal; from_callable must
  # still return for both clients. The schema is not inspected.
  def func_under_test(a: typing.Literal['a', 'b', 'c'] = 'd'):
    """test default value not compatible generic alias literal."""

  for client in (mldev_client, vertex_client):
    types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
867
+
868
+
869
def test_generic_alias_array():

  def func_under_test(
      a: typing.List[int],
  ):
    """test generic alias array."""

  int_array = types.Schema(type='ARRAY', items=types.Schema(type='INTEGER'))
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': int_array},
          required=['a'],
      ),
      description='test generic alias array.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
900
+
901
+
902
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_generic_alias_complex_array():

  def func_under_test(
      a: typing.List[int | str | float | bool],
      b: typing.List[list | dict],
  ):
    """test generic alias complex array."""

  def array_of_union(*member_types):
    # ARRAY whose items are an OBJECT schema with one any_of entry per member.
    return types.Schema(
        type='ARRAY',
        items=types.Schema(
            type='OBJECT',
            any_of=[types.Schema(type=t) for t in member_types],
        ),
    )

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': array_of_union('INTEGER', 'STRING', 'NUMBER', 'BOOLEAN'),
              'b': array_of_union('ARRAY', 'OBJECT'),
          },
          required=['a', 'b'],
      ),
      description='test generic alias complex array.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
956
+
957
+
958
def test_generic_alias_complex_array_all_py_versions():

  def func_under_test(
      a: typing.List[typing.Union[int, str, float, bool]],
      b: typing.List[typing.Union[list, dict]],
  ):
    """test generic alias complex array."""

  scalar_items = types.Schema(
      type='OBJECT',
      any_of=[
          types.Schema(type=t)
          for t in ('INTEGER', 'STRING', 'NUMBER', 'BOOLEAN')
      ],
  )
  container_items = types.Schema(
      type='OBJECT',
      any_of=[types.Schema(type=t) for t in ('ARRAY', 'OBJECT')],
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': types.Schema(type='ARRAY', items=scalar_items),
              'b': types.Schema(type='ARRAY', items=container_items),
          },
          required=['a', 'b'],
      ),
      description='test generic alias complex array.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
1008
+
1009
+
1010
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_generic_alias_complex_array_with_default_value():

  def func_under_test(
      a: typing.List[int | str | float | bool] = [
          1,
          'a',
          1.1,
          True,
      ],
      b: list[int | str | float | bool] = [
          11,
          'aa',
          1.11,
          False,
      ],
      c: typing.List[typing.List[int] | int] = [[1], 2],
  ):
    """test generic alias complex array with default value."""

  def scalar_items():
    # Item schema for int | str | float | bool.
    return types.Schema(
        type='OBJECT',
        any_of=[
            types.Schema(type=t)
            for t in ('INTEGER', 'STRING', 'NUMBER', 'BOOLEAN')
        ],
    )

  nested_items = types.Schema(
      type='OBJECT',
      any_of=[
          types.Schema(type='ARRAY', items=types.Schema(type='INTEGER')),
          types.Schema(type='INTEGER'),
      ],
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': types.Schema(
                  type='ARRAY',
                  items=scalar_items(),
                  default=[1, 'a', 1.1, True],
              ),
              'b': types.Schema(
                  type='ARRAY',
                  items=scalar_items(),
                  default=[11, 'aa', 1.11, False],
              ),
              'c': types.Schema(
                  type='ARRAY',
                  items=nested_items,
                  default=[[1], 2],
              ),
          },
          required=[],
      ),
      description='test generic alias complex array with default value.',
  )

  for client in (vertex_client, mldev_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
1094
+
1095
+
1096
def test_generic_alias_complex_array_with_default_value_all_py_versions():

  def func_under_test(
      a: typing.List[typing.Union[int, str, float, bool]] = [
          1,
          'a',
          1.1,
          True,
      ],
      b: list[typing.Union[int, str, float, bool]] = [
          11,
          'aa',
          1.11,
          False,
      ],
      c: typing.List[typing.Union[typing.List[int], int]] = [[1], 2],
  ):
    """test generic alias complex array with default value."""

  def array_schema(items, default):
    # Expected ARRAY schema with the given item schema and default.
    return types.Schema(type='ARRAY', items=items, default=default)

  def scalar_items():
    return types.Schema(
        type='OBJECT',
        any_of=[
            types.Schema(type=t)
            for t in ('INTEGER', 'STRING', 'NUMBER', 'BOOLEAN')
        ],
    )

  nested_items = types.Schema(
      type='OBJECT',
      any_of=[
          types.Schema(type='ARRAY', items=types.Schema(type='INTEGER')),
          types.Schema(type='INTEGER'),
      ],
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': array_schema(scalar_items(), [1, 'a', 1.1, True]),
              'b': array_schema(scalar_items(), [11, 'aa', 1.11, False]),
              'c': array_schema(nested_items, [[1], 2]),
          },
          required=[],
      ),
      description='test generic alias complex array with default value.',
  )

  for client in (vertex_client, mldev_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
1176
+
1177
+
1178
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_generic_alias_complex_array_with_default_value_not_compatible():
  # Smoke test: the trailing [] in each default is not a union member;
  # from_callable must still return. The schema is not inspected.

  def func_under_test1(
      a: typing.List[int | str | float | bool] = [1, 'a', 1.1, True, []],
  ):
    """test generic alias complex array with default value not compatible."""

  def func_under_test2(
      a: list[int | str | float | bool] = [1, 'a', 1.1, True, []],
  ):
    """test generic alias complex array with default value not compatible."""

  for candidate in (func_under_test1, func_under_test2):
    for client in (mldev_client, vertex_client):
      types.FunctionDeclaration.from_callable(client=client, callable=candidate)
1203
+
1204
+
1205
def test_generic_alias_complex_array_with_default_value_not_compatible_all_py_versions():
  # Smoke test: typing.Union variant of the incompatible-default case; only
  # checks that from_callable returns for both clients.

  def func_under_test1(
      a: typing.List[typing.Union[int, str, float, bool]] = [
          1,
          'a',
          1.1,
          True,
          [],
      ],
  ):
    """test generic alias complex array with default value not compatible."""

  def func_under_test2(
      a: list[typing.Union[int, str, float, bool]] = [1, 'a', 1.1, True, []],
  ):
    """test generic alias complex array with default value not compatible."""

  for candidate in (func_under_test1, func_under_test2):
    for client in (mldev_client, vertex_client):
      types.FunctionDeclaration.from_callable(client=client, callable=candidate)
1232
+
1233
+
1234
def test_generic_alias_object():

  def func_under_test(
      a: typing.Dict[str, int],
  ):
    """test generic alias object."""

  # Dict value types are not reflected in the expected schema; only OBJECT.
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': types.Schema(type='OBJECT')},
          required=['a'],
      ),
      description='test generic alias object.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
1263
+
1264
+
1265
def test_supported_uncommon_generic_alias_object():
  # Smoke test: each of these typing aliases must be accepted by
  # from_callable for both clients. The schemas are not inspected.
  def func_under_test1(a: typing.OrderedDict[str, int]):
    """test uncommon generic alias object."""

  def func_under_test2(a: typing.MutableMapping[str, int]):
    """test uncommon generic alias object."""

  def func_under_test3(a: typing.MutableSequence[int]):
    """test uncommon generic alias object."""

  def func_under_test4(a: typing.MutableSet[int]):
    """test uncommon generic alias object."""

  def func_under_test5(a: typing.Counter[int]):
    """test uncommon generic alias object."""

  def func_under_test6(a: typing.Iterable[int]):
    """test uncommon generic alias object."""

  def func_under_test7(a: typing.DefaultDict[int, int]):
    """test uncommon generic alias object."""

  candidates = (
      func_under_test1,
      func_under_test2,
      func_under_test3,
      func_under_test4,
      func_under_test5,
      func_under_test6,
      func_under_test7,
  )
  for candidate in candidates:
    for client in (mldev_client, vertex_client):
      types.FunctionDeclaration.from_callable(client=client, callable=candidate)
1311
+
1312
+
1313
def test_unsupported_uncommon_generic_alias_object():
  # Each of these typing aliases is expected to be rejected with ValueError
  # by from_callable, for both clients.

  def func_under_test1(a: typing.Collection[int]):
    """test uncommon generic alias object."""

  def func_under_test2(a: typing.Iterator[int]):
    """test uncommon generic alias object."""

  def func_under_test3(a: typing.Container[int]):
    """test uncommon generic alias object."""

  def func_under_test4(a: typing.ChainMap[int, int]):
    """test uncommon generic alias object."""

  candidates = (
      func_under_test1,
      func_under_test2,
      func_under_test3,
      func_under_test4,
  )
  for candidate in candidates:
    for client in (mldev_client, vertex_client):
      with pytest.raises(ValueError):
        types.FunctionDeclaration.from_callable(
            client=client, callable=candidate
        )
1347
+
1348
+
1349
def test_generic_alias_object_with_default_value():
  def func_under_test(a: typing.Dict[str, int] = {'a': 1}):
    """test generic alias object with default value."""

  dict_schema = types.Schema(type='OBJECT', default={'a': 1})
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': dict_schema},
          required=[],
      ),
      description='test generic alias object with default value.',
  )

  for client in (vertex_client, mldev_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
1378
+
1379
+
1380
def test_generic_alias_object_with_default_value_not_compatible():
  # Smoke test: a str default on a Dict parameter; only checks that
  # from_callable returns for both clients.
  def func_under_test(a: typing.Dict[str, int] = 'a'):
    """test generic alias object with default value not compatible."""

  for client in (mldev_client, vertex_client):
    types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
1391
+
1392
+
1393
def test_pydantic_model():
  class MySimplePydanticModel(pydantic.BaseModel):
    a_simple: int
    b_simple: str

  class MyComplexPydanticModel(pydantic.BaseModel):
    a_complex: MySimplePydanticModel
    b_complex: list[MySimplePydanticModel]

  def func_under_test(
      a: MySimplePydanticModel,
      b: MyComplexPydanticModel,
  ):
    """test pydantic model."""

  def simple_model_schema():
    # Expected schema for MySimplePydanticModel wherever it appears.
    return types.Schema(
        type='OBJECT',
        properties={
            'a_simple': types.Schema(type='INTEGER'),
            'b_simple': types.Schema(type='STRING'),
        },
        required=['a_simple', 'b_simple'],
    )

  complex_model_schema = types.Schema(
      type='OBJECT',
      properties={
          'a_complex': simple_model_schema(),
          'b_complex': types.Schema(type='ARRAY', items=simple_model_schema()),
      },
      required=['a_complex', 'b_complex'],
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': simple_model_schema(),
              'b': complex_model_schema,
          },
          required=['a', 'b'],
      ),
      description='test pydantic model.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
1462
+
1463
+
1464
def test_pydantic_model_in_list_type():
  class MySimplePydanticModel(pydantic.BaseModel):
    a_simple: int
    b_simple: str

  def func_under_test(
      a: list[MySimplePydanticModel],
  ):
    """test pydantic model in list type."""

  model_schema = types.Schema(
      type='OBJECT',
      properties={
          'a_simple': types.Schema(type='INTEGER'),
          'b_simple': types.Schema(type='STRING'),
      },
      required=['a_simple', 'b_simple'],
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': types.Schema(type='ARRAY', items=model_schema)},
          required=['a'],
      ),
      description='test pydantic model in list type.',
  )

  for client in (mldev_client, vertex_client):
    actual_schema = types.FunctionDeclaration.from_callable(
        client=client, callable=func_under_test
    )
    assert actual_schema == expected_schema
1506
+
1507
+
1508
def test_pydantic_model_in_union_type():
  """A Union of pydantic models maps to an OBJECT with one any_of per model."""

  class CatInformationObject(pydantic.BaseModel):
    name: str
    age: int
    like_purring: bool

  class DogInformationObject(pydantic.BaseModel):
    name: str
    age: int
    like_barking: bool

  def func_under_test(
      animal: typing.Union[CatInformationObject, DogInformationObject],
  ):
    """test pydantic model in union type."""

  def animal_schema(flag_field):
    # The two models differ only in their boolean attribute.
    return types.Schema(
        type='OBJECT',
        properties={
            'name': types.Schema(type='STRING'),
            'age': types.Schema(type='INTEGER'),
            flag_field: types.Schema(type='BOOLEAN'),
        },
        required=['name', 'age', flag_field],
    )

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test pydantic model in union type.',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'animal': types.Schema(
                  type='OBJECT',
                  any_of=[
                      animal_schema('like_purring'),
                      animal_schema('like_barking'),
                  ],
              ),
          },
          required=['animal'],
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
1576
+
1577
+
1578
def test_pydantic_model_with_default_value():
  """A pydantic-model default value is carried into the parameter schema."""

  class MySimplePydanticModel(pydantic.BaseModel):
    a_simple: Optional[int]
    b_simple: Optional[str]

  default_model = MySimplePydanticModel(a_simple=1, b_simple='a')

  def func_under_test(a: MySimplePydanticModel = default_model):
    """test pydantic model with default value."""

  model_schema = types.Schema(
      default=MySimplePydanticModel(a_simple=1, b_simple='a'),
      type='OBJECT',
      properties={
          'a_simple': types.Schema(nullable=True, type='INTEGER'),
          'b_simple': types.Schema(nullable=True, type='STRING'),
      },
      required=[],
  )
  expected_schema = types.FunctionDeclaration(
      description='test pydantic model with default value.',
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT', properties={'a': model_schema}, required=[]
      ),
  )

  # Vertex is built first, then the Gemini API client.
  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (vertex_client, mldev_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
1624
+
1625
+
1626
def test_custom_class():
  """A plain (non-pydantic) class annotation is rejected with ValueError."""

  class MyClass:
    a: int
    b: str

    def __init__(self, a: int):
      self.a = a
      self.b = str(a)

  def func_under_test(a: MyClass):
    """test custom class."""

  for client in (mldev_client, vertex_client):
    with pytest.raises(ValueError):
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
1648
+
1649
+
1650
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_type_union():
  """typing.Union and PEP 604 `X | Y` parameters both map to any_of schemas."""

  def func_under_test(
      a: typing.Union[int, str],
      b: typing.Union[list, dict],
      c: typing.Union[typing.List[typing.Union[int, float]], dict],
      d: list | dict,
  ):
    """test type union."""

  def union_of(*members):
    # Each union is expected as an OBJECT schema listing its members in any_of.
    return types.Schema(type='OBJECT', any_of=list(members))

  def leaf(type_name):
    return types.Schema(type=type_name)

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test type union.',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': union_of(leaf('INTEGER'), leaf('STRING')),
              'b': union_of(leaf('ARRAY'), leaf('OBJECT')),
              'c': union_of(
                  types.Schema(
                      type='ARRAY',
                      items=union_of(leaf('INTEGER'), leaf('NUMBER')),
                  ),
                  leaf('OBJECT'),
              ),
              'd': union_of(leaf('ARRAY'), leaf('OBJECT')),
          },
          required=['a', 'b', 'c', 'd'],
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
1724
+
1725
+
1726
def test_type_union_all_py_versions():
  """typing.Union parameters map to any_of schemas on every Python version."""

  def func_under_test(
      a: typing.Union[int, str],
      b: typing.Union[list, dict],
      c: typing.Union[typing.List[typing.Union[int, float]], dict],
  ):
    """test type union."""

  int_or_number = types.Schema(
      type='OBJECT',
      any_of=[types.Schema(type='INTEGER'), types.Schema(type='NUMBER')],
  )
  properties = {
      'a': types.Schema(
          type='OBJECT',
          any_of=[types.Schema(type='INTEGER'), types.Schema(type='STRING')],
      ),
      'b': types.Schema(
          type='OBJECT',
          any_of=[types.Schema(type='ARRAY'), types.Schema(type='OBJECT')],
      ),
      'c': types.Schema(
          type='OBJECT',
          any_of=[
              types.Schema(type='ARRAY', items=int_or_number),
              types.Schema(type='OBJECT'),
          ],
      ),
  }
  # Every parameter lacks a default, so all of them are required (in order).
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test type union.',
      parameters=types.Schema(
          type='OBJECT', properties=properties, required=list(properties)
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
1788
+
1789
+
1790
def test_type_optional_with_list():
  """Optional[list[str]] defaulting to None is nullable and not required."""

  def func_under_test(
      a: str,
      b: typing.Optional[list[str]] = None,
  ):
    """test type optional with list."""

  optional_list_schema = types.Schema(
      type='ARRAY',
      items=types.Schema(type='STRING'),
      nullable=True,
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test type optional with list.',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': types.Schema(type='STRING'),
              'b': optional_list_schema,
          },
          required=['a'],
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
1823
+
1824
+
1825
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_type_union_with_default_value():
  """Union parameters with defaults keep the default and are not required."""

  def func_under_test(
      a: typing.Union[int, str] = 1,
      b: typing.Union[list, dict] = [1],
      c: typing.Union[typing.List[typing.Union[int, float]], dict] = {},
      d: list | dict = [1, 2, 3],
  ):
    """test type union with default value."""

  def union_of(*members, **extra):
    # A union is expected as an OBJECT schema with its members under any_of;
    # `extra` carries per-parameter fields such as `default`.
    return types.Schema(type='OBJECT', any_of=list(members), **extra)

  def leaf(type_name):
    return types.Schema(type=type_name)

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test type union with default value.',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': union_of(leaf('INTEGER'), leaf('STRING'), default=1),
              'b': union_of(leaf('ARRAY'), leaf('OBJECT'), default=[1]),
              'c': union_of(
                  types.Schema(
                      type='ARRAY',
                      items=union_of(leaf('INTEGER'), leaf('NUMBER')),
                  ),
                  leaf('OBJECT'),
                  default={},
              ),
              'd': union_of(leaf('ARRAY'), leaf('OBJECT'), default=[1, 2, 3]),
          },
          required=[],
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (vertex_client, mldev_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
1903
+
1904
+
1905
def test_type_union_with_default_value_all_py_versions():
  """typing.Union parameters with defaults keep the default, none required."""

  def func_under_test(
      a: typing.Union[int, str] = 1,
      b: typing.Union[list, dict] = [1],
      c: typing.Union[typing.List[typing.Union[int, float]], dict] = {},
  ):
    """test type union with default value."""

  def union_of(*members, **extra):
    # A union is expected as an OBJECT schema with its members under any_of.
    return types.Schema(type='OBJECT', any_of=list(members), **extra)

  def leaf(type_name):
    return types.Schema(type=type_name)

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test type union with default value.',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': union_of(leaf('INTEGER'), leaf('STRING'), default=1),
              'b': union_of(leaf('ARRAY'), leaf('OBJECT'), default=[1]),
              'c': union_of(
                  types.Schema(
                      type='ARRAY',
                      items=union_of(leaf('INTEGER'), leaf('NUMBER')),
                  ),
                  leaf('OBJECT'),
                  default={},
              ),
          },
          required=[],
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (vertex_client, mldev_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
1970
+
1971
+
1972
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is only supported in Python 3.10 and above.',
)
def test_type_union_with_default_value_not_compatible():
  """Union parameters whose default matches no member are still accepted.

  NOTE: this test must not be named ``test_type_union_with_default_value``.
  A second module-level ``def`` with that name rebinds it, so pytest would
  collect only one of the two tests and the other would silently never run.
  """

  def func_under_test1(
      a: typing.Union[typing.List[typing.Union[int, float]], dict] = 1,
  ):
    """test type union with default value not compatible."""
    pass

  def func_under_test2(
      a: list | dict = 1,
  ):
    """test type union with default value not compatible."""
    pass

  all_func_under_test = [func_under_test1, func_under_test2]

  # Only checks that building the declaration does not raise for either client.
  for func_under_test in all_func_under_test:
    types.FunctionDeclaration.from_callable(
        client=mldev_client, callable=func_under_test
    )
    types.FunctionDeclaration.from_callable(
        client=vertex_client, callable=func_under_test
    )
1999
+
2000
+
2001
def test_type_union_with_default_value_not_compatible_all_py_versions():
  """Union defaults matching no member are accepted by both clients."""

  def func_under_test1(
      a: typing.Union[typing.List[typing.Union[int, float]], dict] = 1,
  ):
    """test type union with default value not compatible."""

  def func_under_test2(
      a: typing.Union[list, dict] = 1,
  ):
    """test type union with default value not compatible."""

  # Only checks that building the declaration does not raise.
  for candidate in (func_under_test1, func_under_test2):
    for client in (mldev_client, vertex_client):
      types.FunctionDeclaration.from_callable(
          client=client, callable=candidate
      )
2024
+
2025
+
2026
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is not supported in Python 3.9',
)
def test_type_nullable():
  """A None member (via `| None`, Union, or Optional) marks schemas nullable."""

  def func_under_test(
      a: int | float | None,
      b: typing.Union[list, None],
      c: typing.Union[list, dict, None],
      d: typing.Optional[int] = None,
  ):
    """test type nullable."""

  def nullable_union(*type_names):
    return types.Schema(
        type='OBJECT',
        any_of=[types.Schema(type=name) for name in type_names],
        nullable=True,
    )

  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test type nullable.',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'a': nullable_union('INTEGER', 'NUMBER'),
              'b': types.Schema(type='ARRAY', nullable=True),
              'c': nullable_union('ARRAY', 'OBJECT'),
              'd': types.Schema(type='INTEGER', nullable=True, default=None),
          },
          required=[],
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
2086
+
2087
+
2088
def test_type_nullable_all_py_versions():
  """A None member via Union or Optional marks the schema nullable."""

  def func_under_test(
      b: typing.Union[list, None],
      c: typing.Union[list, dict, None],
      d: typing.Optional[int] = None,
  ):
    """test type nullable."""

  list_or_dict = types.Schema(
      type='OBJECT',
      any_of=[types.Schema(type='ARRAY'), types.Schema(type='OBJECT')],
      nullable=True,
  )
  expected_schema = types.FunctionDeclaration(
      name='func_under_test',
      description='test type nullable.',
      parameters=types.Schema(
          type='OBJECT',
          properties={
              'b': types.Schema(type='ARRAY', nullable=True),
              'c': list_or_dict,
              'd': types.Schema(type='INTEGER', nullable=True, default=None),
          },
          required=[],
      ),
  )

  actual_schemas = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  ]
  for actual_schema in actual_schemas:
    assert actual_schema == expected_schema
2135
+
2136
+
2137
def test_empty_function_with_return_type():
  """Only the Vertex declaration records the return type as `response`."""

  def func_under_test() -> int:
    """test empty function with return type."""
    return 1

  shared_fields = dict(
      name='func_under_test',
      description='test empty function with return type.',
  )
  expected_schema_mldev = types.FunctionDeclaration(**shared_fields)
  expected_schema_vertex = types.FunctionDeclaration(
      **shared_fields, response=types.Schema(type='INTEGER')
  )

  actual_schema_mldev, actual_schema_vertex = (
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  )

  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2158
+
2159
+
2160
def test_simple_function_with_return_type():
  """Parameters appear for both clients; `response` only for Vertex."""

  def func_under_test(a: int) -> str:
    """test return type."""
    return ''

  def declaration(**extra):
    return types.FunctionDeclaration(
        name='func_under_test',
        description='test return type.',
        parameters=types.Schema(
            type='OBJECT',
            properties={'a': types.Schema(type='INTEGER')},
            required=['a'],
        ),
        **extra,
    )

  expected_schema_mldev = declaration()
  expected_schema_vertex = declaration(response=types.Schema(type='STRING'))

  actual_schema_mldev, actual_schema_vertex = (
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  )

  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2188
+
2189
+
2190
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason='| is not supported in Python 3.9',
)
def test_builtin_union_return_type():
  """A PEP 604 union return type is exposed via Vertex response_json_schema."""

  def func_under_test() -> int | str | float | bool | list | dict | None:
    """test builtin union return type."""

  member_types = ('INTEGER', 'STRING', 'NUMBER', 'BOOLEAN', 'ARRAY', 'OBJECT')
  expected_schema_mldev = types.FunctionDeclaration(
      name='func_under_test',
      description='test builtin union return type.',
  )
  expected_schema_vertex = types.FunctionDeclaration(
      name='func_under_test',
      description='test builtin union return type.',
      response_json_schema=types.Schema(
          type='OBJECT',
          any_of=[types.Schema(type=member) for member in member_types],
          nullable=True,
      ),
  )

  actual_schema_mldev, actual_schema_vertex = (
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  )

  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2227
+
2228
+
2229
def test_builtin_union_return_type_all_py_versions():
  """A typing.Union return type is exposed via Vertex response_json_schema."""

  def func_under_test() -> (
      typing.Union[int, str, float, bool, list, dict, None]
  ):
    """test builtin union return type."""

  member_types = ('INTEGER', 'STRING', 'NUMBER', 'BOOLEAN', 'ARRAY', 'OBJECT')
  expected_schema_mldev = types.FunctionDeclaration(
      name='func_under_test',
      description='test builtin union return type.',
  )
  expected_schema_vertex = types.FunctionDeclaration(
      name='func_under_test',
      description='test builtin union return type.',
      response_json_schema=types.Schema(
          type='OBJECT',
          any_of=[types.Schema(type=member) for member in member_types],
          nullable=True,
      ),
  )

  actual_schema_mldev, actual_schema_vertex = (
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  )

  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2264
+
2265
+
2266
def test_typing_union_return_type():
  """A typing.Union return type is exposed via Vertex response_json_schema."""

  def func_under_test() -> (
      typing.Union[int, str, float, bool, list, dict, None]
  ):
    """test typing union return type."""

  base_fields = {
      'name': 'func_under_test',
      'description': 'test typing union return type.',
  }
  union_schema = types.Schema(
      type='OBJECT',
      any_of=[
          types.Schema(type=member)
          for member in (
              'INTEGER',
              'STRING',
              'NUMBER',
              'BOOLEAN',
              'ARRAY',
              'OBJECT',
          )
      ],
      nullable=True,
  )
  expected_schema_mldev = types.FunctionDeclaration(**base_fields)
  expected_schema_vertex = types.FunctionDeclaration(
      **base_fields, response_json_schema=union_schema
  )

  actual_schema_mldev, actual_schema_vertex = (
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  )

  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2301
+
2302
+
2303
def test_return_type_optional():
  """Optional[int] return becomes a nullable INTEGER response on Vertex."""

  def func_under_test() -> typing.Optional[int]:
    """test return type optional."""

  base_fields = dict(
      name='func_under_test',
      description='test return type optional.',
  )
  expected_schema_mldev = types.FunctionDeclaration(**base_fields)
  expected_schema_vertex = types.FunctionDeclaration(
      **base_fields,
      response=types.Schema(type='INTEGER', nullable=True),
  )

  actual_schema_mldev, actual_schema_vertex = (
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  )

  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2327
+
2328
+
2329
def test_return_type_pydantic_model():
  """A nested pydantic return model becomes an OBJECT response on Vertex."""

  class MySimplePydanticModel(pydantic.BaseModel):
    a_simple: int
    b_simple: str

  class MyComplexPydanticModel(pydantic.BaseModel):
    a_complex: MySimplePydanticModel
    b_complex: list[MySimplePydanticModel]

  def func_under_test() -> MyComplexPydanticModel:
    """test return type pydantic model."""

  def simple_model_schema():
    return types.Schema(
        type='OBJECT',
        properties={
            'a_simple': types.Schema(type='INTEGER'),
            'b_simple': types.Schema(type='STRING'),
        },
        required=['a_simple', 'b_simple'],
    )

  base_fields = dict(
      name='func_under_test',
      description='test return type pydantic model.',
  )
  expected_schema_mldev = types.FunctionDeclaration(**base_fields)
  expected_schema_vertex = types.FunctionDeclaration(
      **base_fields,
      response=types.Schema(
          type='OBJECT',
          properties={
              'a_complex': simple_model_schema(),
              'b_complex': types.Schema(
                  type='ARRAY', items=simple_model_schema()
              ),
          },
          required=['a_complex', 'b_complex'],
      ),
  )

  actual_schema_mldev, actual_schema_vertex = (
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  )

  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2382
+
2383
+
2384
def test_function_with_return_type():
  """The Gemini API ignores these return annotations; Vertex must not raise."""

  def func_under_test1() -> set:
    pass

  def func_under_test2() -> frozenset[int]:
    pass

  def func_under_test3() -> typing.Set[int]:
    pass

  def func_under_test4() -> typing.FrozenSet[int]:
    pass

  def func_under_test5() -> typing.Iterable[int]:
    pass

  def func_under_test6() -> bytes:
    pass

  def func_under_test7() -> typing.OrderedDict[str, int]:
    pass

  def func_under_test8() -> typing.MutableMapping[str, int]:
    pass

  def func_under_test9() -> typing.MutableSequence[int]:
    pass

  def func_under_test10() -> typing.MutableSet[int]:
    pass

  def func_under_test11() -> typing.Counter[int]:
    pass

  candidates = (
      func_under_test1,
      func_under_test2,
      func_under_test3,
      func_under_test4,
      func_under_test5,
      func_under_test6,
      func_under_test7,
      func_under_test8,
      func_under_test9,
      func_under_test10,
      func_under_test11,
  )
  for index, candidate in enumerate(candidates, start=1):
    expected = types.FunctionDeclaration(
        name=f'func_under_test{index}',
        description=None,
    )
    assert (
        types.FunctionDeclaration.from_callable(
            client=mldev_client, callable=candidate
        )
        == expected
    )
    # For Vertex the only requirement is that building does not raise.
    types.FunctionDeclaration.from_callable(
        client=vertex_client, callable=candidate
    )
2445
+
2446
+
2447
def test_function_with_tuple_return_type():
  """A tuple return type yields a JSON-schema response on Vertex only.

  The Gemini API declaration carries no response at all. The Vertex one
  describes the fixed-length tuple with prefixItems, min/maxItems, and
  unevaluatedItems=False.
  """

  def func_under_test() -> tuple[int, str, str]:
    pass

  # Plain literals: there is nothing to interpolate, so no f-prefix (F541).
  expected_schema_mldev = types.FunctionDeclaration(
      name='func_under_test',
      description=None,
  )
  actual_schema_mldev = types.FunctionDeclaration.from_callable(
      client=mldev_client, callable=func_under_test
  )

  expected_schema_vertex = types.FunctionDeclaration(
      name='func_under_test',
      description=None,
      response_json_schema={
          'maxItems': 3,
          'minItems': 3,
          'prefixItems': [
              {'type': 'integer'},
              {'type': 'string'},
              {'type': 'string'},
          ],
          'type': 'array',
          'unevaluatedItems': False,
      },
  )
  actual_schema_vertex = types.FunctionDeclaration.from_callable(
      client=vertex_client, callable=func_under_test
  )
  assert actual_schema_mldev == expected_schema_mldev
  assert actual_schema_vertex == expected_schema_vertex
2479
+
2480
+
2481
def test_function_with_return_type_not_supported():
  """Gemini API ignores these return types; Vertex raises ValueError."""

  def func_under_test1() -> typing.Collection[int]:
    pass

  def func_under_test2() -> typing.Iterator[int]:
    pass

  def func_under_test3() -> typing.Container[int]:
    pass

  class MyClass:
    a: int
    b: str

  def func_under_test4() -> MyClass:
    pass

  candidates = (
      func_under_test1,
      func_under_test2,
      func_under_test3,
      func_under_test4,
  )
  for index, candidate in enumerate(candidates, start=1):
    expected = types.FunctionDeclaration(
        name=f'func_under_test{index}',
        description=None,
    )
    actual = types.FunctionDeclaration.from_callable(
        client=mldev_client, callable=candidate
    )
    assert actual == expected
    with pytest.raises(ValueError):
      types.FunctionDeclaration.from_callable(
          client=vertex_client, callable=candidate
      )
2518
+
2519
def test_function_with_tuple_contains_unevaluated_items():
  """Tuple parameters are described through parameters_json_schema.

  The expected JSON schema uses prefixItems and unevaluatedItems=False.
  """

  def func_under_test(a: tuple[int, int]) -> str:
    """test return type."""
    return ''

  expected_parameters_json_schema = {
      'a': {
          'type': 'array',
          'prefixItems': [{'type': 'integer'}, {'type': 'integer'}],
          'minItems': 2,
          'maxItems': 2,
          'unevaluatedItems': False,
      }
  }

  declarations = [
      types.FunctionDeclaration.from_callable(
          client=client, callable=func_under_test
      )
      for client in (mldev_client, vertex_client)
  ]
  for declaration in declarations:
    assert (
        declaration.parameters_json_schema == expected_parameters_json_schema
    )
2543
+
2544
+
2545
def test_function_gemini_api(monkeypatch):
  """A Gemini API declaration carries parameters but no response schema."""
  monkeypatch.setenv('GOOGLE_API_KEY', 'google_api_key')

  def func_under_test(a: int) -> str:
    """test return type."""
    return ''

  integer_param = types.Schema(type='INTEGER')
  expected_schema_mldev = types.FunctionDeclaration(
      name='func_under_test',
      description='test return type.',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': integer_param},
          required=['a'],
      ),
  )

  assert (
      types.FunctionDeclaration.from_callable(
          client=mldev_client, callable=func_under_test
      )
      == expected_schema_mldev
  )
2570
+
2571
+
2572
def test_function_with_option_gemini_api(monkeypatch):
  """api_option='GEMINI_API' yields a declaration without a response."""

  def func_under_test(a: int) -> str:
    """test return type."""
    return ''

  parameters = types.Schema(
      type='OBJECT',
      properties={'a': types.Schema(type='INTEGER')},
      required=['a'],
  )
  expected_schema_mldev = types.FunctionDeclaration(
      name='func_under_test',
      description='test return type.',
      parameters=parameters,
  )

  actual = types.FunctionDeclaration.from_callable_with_api_option(
      callable=func_under_test, api_option='GEMINI_API'
  )
  assert actual == expected_schema_mldev
2595
+
2596
+
2597
def test_function_with_option_unset(monkeypatch):
  """Without api_option the result matches the Gemini API declaration."""

  def func_under_test(a: int) -> str:
    """test return type."""
    return ''

  parameters = types.Schema(
      type='OBJECT',
      properties={'a': types.Schema(type='INTEGER')},
      required=['a'],
  )
  expected_schema_mldev = types.FunctionDeclaration(
      name='func_under_test',
      description='test return type.',
      parameters=parameters,
  )

  actual = types.FunctionDeclaration.from_callable_with_api_option(
      callable=func_under_test
  )
  assert actual == expected_schema_mldev
2620
+
2621
+
2622
def test_function_with_option_unsupported_api_option():
  """An unrecognized api_option is rejected with ValueError."""

  def func_under_test(a: int) -> str:
    """test return type."""
    return ''

  bogus_option = 'UNSUPPORTED_API_OPTION'
  with pytest.raises(ValueError):
    types.FunctionDeclaration.from_callable_with_api_option(
        callable=func_under_test, api_option=bogus_option
    )
2632
+
2633
+
2634
def test_function_vertex():
  """from_callable with the Vertex AI client also records the return schema."""

  def func_under_test(a: int) -> str:
    """test return type."""
    return ''

  # Built in one step rather than copying a base declaration and mutating it;
  # the resulting field values are identical.
  expected_schema_vertex = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': types.Schema(type='INTEGER')},
          required=['a'],
      ),
      description='test return type.',
      # The `-> str` annotation surfaces as a STRING response schema here.
      response=types.Schema(type='STRING'),
  )

  actual_schema_vertex = types.FunctionDeclaration.from_callable(
      client=vertex_client, callable=func_under_test
  )
  assert actual_schema_vertex == expected_schema_vertex
2659
+
2660
+
2661
def test_function_with_option_vertex(monkeypatch):
  """api_option='VERTEX_AI' yields the Vertex-style declaration with a response."""

  def func_under_test(a: int) -> str:
    """test return type."""
    return ''

  # Built in one step rather than copying a base declaration and mutating it;
  # the resulting field values are identical.
  expected_schema_vertex = types.FunctionDeclaration(
      name='func_under_test',
      parameters=types.Schema(
          type='OBJECT',
          properties={'a': types.Schema(type='INTEGER')},
          required=['a'],
      ),
      description='test return type.',
      # The `-> str` annotation surfaces as a STRING response schema here.
      response=types.Schema(type='STRING'),
  )

  actual_schema_vertex = types.FunctionDeclaration.from_callable_with_api_option(
      callable=func_under_test, api_option='VERTEX_AI'
  )
  assert actual_schema_vertex == expected_schema_vertex
2688
+
2689
+
2690
def test_case_insensitive_enum():
  """types.Type resolves its values regardless of letter case."""
  for raw_value in ('STRING', 'string'):
    assert types.Type(raw_value) == types.Type.STRING
2693
+
2694
+
2695
def test_case_insensitive_enum_with_pydantic_model():
  """Case-insensitive enum lookup also applies when pydantic validates a field."""

  class TestModel(pydantic.BaseModel):
    test_enum: types.Type

  for raw_value in ('STRING', 'string'):
    parsed = TestModel(test_enum=raw_value)
    assert parsed.test_enum == types.Type.STRING
2701
+
2702
+
2703
def test_unknown_enum_value():
  """An unrecognized enum value warns but still produces a member."""
  with pytest.warns(Warning, match='is not a valid'):
    unknown = types.Type('float')
  # Both the member name and its value echo the raw input.
  assert (unknown.name, unknown.value) == ('float', 'float')
2708
+
2709
+
2710
def test_unknown_enum_value_in_nested_dict():
  """Unknown enum values inside a response payload are kept verbatim."""
  payload = {'category': 'NEW_CATEGORY'}
  rating = types.SafetyRating._from_response(response=payload, kwargs=None)

  category = rating.category
  assert category.name == 'NEW_CATEGORY'
  assert category.value == 'NEW_CATEGORY'
2716
+
2717
+
2718
# Tests that TypedDict types from types.py are compatible with pydantic.
# pydantic requires TypedDict from typing_extensions for Python <3.12.
def test_typed_dict_pydantic_field():
  """A types.py TypedDict can be used as, and validated through, a pydantic field.

  Defining the model class only proves pydantic can build a schema for the
  TypedDict. Validating an instance also exercises that schema at runtime.
  """
  from pydantic import BaseModel

  class MyConfig(BaseModel):
    config: types.GenerationConfigDict

  # pydantic validates a TypedDict field into a plain dict holding only the
  # keys that were supplied.
  validated = MyConfig(config={'temperature': 0.5})
  assert validated.config == {'temperature': 0.5}
2725
+
2726
+
2727
def test_model_content_list_part_from_uri():
  """ModelContent accepts a mixed list of a string and a URI-backed Part."""
  prompt = 'what is this image about?'
  image_uri = 'gs://generativeai-downloads/images/scones.jpg'
  image_mime = 'image/jpeg'

  expected_model_content = types.Content(
      role='model',
      parts=[
          types.Part(text=prompt),
          types.Part(
              file_data=types.FileData(file_uri=image_uri, mime_type=image_mime)
          ),
      ],
  )
  actual_model_content = types.ModelContent(
      parts=[
          prompt,
          types.Part.from_uri(file_uri=image_uri, mime_type=image_mime),
      ]
  )

  # Compare JSON dumps rather than the models themselves, ignoring None fields.
  expected_json = expected_model_content.model_dump_json(exclude_none=True)
  actual_json = actual_model_content.model_dump_json(exclude_none=True)
  assert expected_json == actual_json
2754
+
2755
+
2756
def test_model_content_part_from_uri():
  """ModelContent accepts a single bare Part (not wrapped in a list)."""
  image_uri = 'gs://generativeai-downloads/images/scones.jpg'
  image_mime = 'image/jpeg'

  expected_model_content = types.Content(
      role='model',
      parts=[
          types.Part(
              file_data=types.FileData(file_uri=image_uri, mime_type=image_mime)
          )
      ],
  )
  actual_model_content = types.ModelContent(
      parts=types.Part.from_uri(file_uri=image_uri, mime_type=image_mime)
  )

  # Compare JSON dumps rather than the models themselves, ignoring None fields.
  expected_json = expected_model_content.model_dump_json(exclude_none=True)
  actual_json = actual_model_content.model_dump_json(exclude_none=True)
  assert expected_json == actual_json
2779
+
2780
+
2781
def test_model_content_from_string():
  """A positional string becomes a single text Part with role 'model'."""
  question = 'why is the sky blue?'
  expected_model_content = types.Content(
      role='model',
      parts=[types.Part(text=question)],
  )
  actual_model_content = types.ModelContent(question)

  expected_json = expected_model_content.model_dump_json(exclude_none=True)
  actual_json = actual_model_content.model_dump_json(exclude_none=True)
  assert expected_json == actual_json
2792
+
2793
+
2794
def test_model_content_unsupported_type():
  """A non-string, non-Part value is rejected with ValueError."""
  not_a_part = 123
  with pytest.raises(ValueError):
    types.ModelContent(not_a_part)
2797
+
2798
+
2799
def test_model_content_empty_list():
  """An empty parts list is rejected with ValueError."""
  no_parts = []
  with pytest.raises(ValueError):
    types.ModelContent(no_parts)
2802
+
2803
+
2804
def test_model_content_unsupported_type_in_list():
  """One invalid element anywhere in the list is rejected with ValueError."""
  mixed_parts = ['hi', 123]
  with pytest.raises(ValueError):
    types.ModelContent(mixed_parts)
2807
+
2808
+
2809
def test_model_content_unsupported_role():
  """Passing a role to ModelContent is a TypeError."""
  with pytest.raises(TypeError):
    types.ModelContent(parts=['hi'], role='user')
2812
+
2813
+
2814
def test_model_content_modify_role():
  """Reassigning role on an existing ModelContent fails pydantic validation."""
  content = types.ModelContent(['hi'])
  with pytest.raises(pydantic.ValidationError):
    content.role = 'user'
2818
+
2819
+
2820
def test_model_content_modify_parts():
  """parts stays assignable on a ModelContent; the role is unaffected."""
  expected_model_content = types.Content(
      role='model',
      parts=[types.Part(text='hello')],
  )

  content = types.ModelContent(['hi'])
  content.parts = [types.Part(text='hello')]

  expected_json = expected_model_content.model_dump_json(exclude_none=True)
  actual_json = content.model_dump_json(exclude_none=True)
  assert expected_json == actual_json
2831
+
2832
+
2833
def test_user_content_unsupported_type():
  """A non-string, non-Part value is rejected with ValueError."""
  not_a_part = 123
  with pytest.raises(ValueError):
    types.UserContent(not_a_part)
2836
+
2837
+
2838
def test_user_content_modify_role():
  """Reassigning role on an existing UserContent fails pydantic validation."""
  content = types.UserContent(['hi'])
  with pytest.raises(pydantic.ValidationError):
    content.role = 'model'
2842
+
2843
+
2844
def test_user_content_modify_parts():
  """parts stays assignable on a UserContent; the role is unaffected."""
  expected_user_content = types.Content(
      role='user',
      parts=[types.Part(text='hello')],
  )

  content = types.UserContent(['hi'])
  content.parts = [types.Part(text='hello')]

  expected_json = expected_user_content.model_dump_json(exclude_none=True)
  actual_json = content.model_dump_json(exclude_none=True)
  assert expected_json == actual_json
2855
+
2856
+
2857
def test_user_content_empty_list():
  """An empty parts list is rejected with ValueError."""
  no_parts = []
  with pytest.raises(ValueError):
    types.UserContent(no_parts)
2860
+
2861
+
2862
def test_user_content_unsupported_type_in_list():
  """One invalid element anywhere in the list is rejected with ValueError."""
  mixed_parts = ['hi', 123]
  with pytest.raises(ValueError):
    types.UserContent(mixed_parts)
2865
+
2866
+
2867
def test_user_content_unsupported_role():
  """Passing a role to UserContent is a TypeError."""
  with pytest.raises(TypeError):
    types.UserContent(parts=['hi'], role='model')
2870
+
2871
+
2872
+ def test_instantiate_response_from_batch_json():
2873
+ test_batch_json = json.dumps({
2874
+ 'candidates': [{
2875
+ 'citationMetadata': {
2876
+ 'citationSources': [{
2877
+ 'endIndex': 2009,
2878
+ 'startIndex': 1880,
2879
+ 'uri': 'http://someurl.com',
2880
+ }]
2881
+ },
2882
+ 'content': {
2883
+ 'parts': [{
2884
+ 'text': (
2885
+ 'This recipe makes a moist and delicious banana bread!'
2886
+ )
2887
+ }],
2888
+ 'role': 'model',
2889
+ },
2890
+ 'finishReason': 'STOP',
2891
+ }],
2892
+ 'modelVersion': 'gemini-1.5-flash-002@default',
2893
+ })
2894
+ parsed = types.GenerateContentResponse.model_validate_json(test_batch_json)
2895
+ assert isinstance(parsed, types.GenerateContentResponse)
2896
+ assert isinstance(parsed.candidates[0].citation_metadata, types.CitationMetadata)
2897
+ assert isinstance(
2898
+ parsed.candidates[0].citation_metadata.citations[0], types.Citation
2899
+ )
2900
+ assert(
2901
+ parsed.candidates[0].citation_metadata.citations[0].uri
2902
+ == 'http://someurl.com'
2903
+ )