apache-airflow-providers-google 14.0.0__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
  2. airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
  3. airflow/providers/google/__init__.py +3 -3
  4. airflow/providers/google/_vendor/__init__.py +0 -0
  5. airflow/providers/google/_vendor/json_merge_patch.py +91 -0
  6. airflow/providers/google/ads/hooks/ads.py +52 -43
  7. airflow/providers/google/ads/operators/ads.py +2 -2
  8. airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
  9. airflow/providers/google/assets/gcs.py +1 -11
  10. airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
  11. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  12. airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
  13. airflow/providers/google/cloud/hooks/bigquery.py +195 -318
  14. airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
  15. airflow/providers/google/cloud/hooks/bigtable.py +3 -2
  16. airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
  17. airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
  18. airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
  19. airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
  20. airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
  21. airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
  22. airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
  23. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
  24. airflow/providers/google/cloud/hooks/compute.py +7 -6
  25. airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
  26. airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
  27. airflow/providers/google/cloud/hooks/dataflow.py +87 -242
  28. airflow/providers/google/cloud/hooks/dataform.py +9 -14
  29. airflow/providers/google/cloud/hooks/datafusion.py +7 -9
  30. airflow/providers/google/cloud/hooks/dataplex.py +13 -12
  31. airflow/providers/google/cloud/hooks/dataprep.py +2 -2
  32. airflow/providers/google/cloud/hooks/dataproc.py +76 -74
  33. airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
  34. airflow/providers/google/cloud/hooks/dlp.py +5 -4
  35. airflow/providers/google/cloud/hooks/gcs.py +144 -33
  36. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  37. airflow/providers/google/cloud/hooks/kms.py +3 -2
  38. airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
  39. airflow/providers/google/cloud/hooks/looker.py +6 -1
  40. airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
  41. airflow/providers/google/cloud/hooks/mlengine.py +7 -8
  42. airflow/providers/google/cloud/hooks/natural_language.py +3 -2
  43. airflow/providers/google/cloud/hooks/os_login.py +3 -2
  44. airflow/providers/google/cloud/hooks/pubsub.py +6 -6
  45. airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
  46. airflow/providers/google/cloud/hooks/spanner.py +75 -10
  47. airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
  48. airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
  49. airflow/providers/google/cloud/hooks/tasks.py +4 -3
  50. airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
  51. airflow/providers/google/cloud/hooks/translate.py +8 -17
  52. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
  53. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
  54. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
  55. airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
  56. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
  57. airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
  58. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
  59. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
  60. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
  61. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
  62. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
  63. airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
  64. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
  65. airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
  66. airflow/providers/google/cloud/hooks/vision.py +7 -7
  67. airflow/providers/google/cloud/hooks/workflows.py +4 -3
  68. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  69. airflow/providers/google/cloud/links/base.py +77 -7
  70. airflow/providers/google/cloud/links/bigquery.py +0 -47
  71. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  72. airflow/providers/google/cloud/links/bigtable.py +0 -48
  73. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  74. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  75. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  76. airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
  77. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  78. airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
  79. airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
  80. airflow/providers/google/cloud/links/compute.py +0 -58
  81. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  82. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  83. airflow/providers/google/cloud/links/dataflow.py +0 -34
  84. airflow/providers/google/cloud/links/dataform.py +0 -64
  85. airflow/providers/google/cloud/links/datafusion.py +1 -90
  86. airflow/providers/google/cloud/links/dataplex.py +0 -154
  87. airflow/providers/google/cloud/links/dataprep.py +0 -24
  88. airflow/providers/google/cloud/links/dataproc.py +11 -89
  89. airflow/providers/google/cloud/links/datastore.py +0 -31
  90. airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
  91. airflow/providers/google/cloud/links/managed_kafka.py +11 -51
  92. airflow/providers/google/cloud/links/mlengine.py +0 -70
  93. airflow/providers/google/cloud/links/pubsub.py +0 -32
  94. airflow/providers/google/cloud/links/spanner.py +0 -33
  95. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  96. airflow/providers/google/cloud/links/translate.py +17 -187
  97. airflow/providers/google/cloud/links/vertex_ai.py +28 -195
  98. airflow/providers/google/cloud/links/workflows.py +0 -52
  99. airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
  100. airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
  101. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  102. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  103. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  104. airflow/providers/google/cloud/openlineage/facets.py +141 -40
  105. airflow/providers/google/cloud/openlineage/mixins.py +14 -13
  106. airflow/providers/google/cloud/openlineage/utils.py +19 -3
  107. airflow/providers/google/cloud/operators/alloy_db.py +76 -61
  108. airflow/providers/google/cloud/operators/bigquery.py +104 -667
  109. airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
  110. airflow/providers/google/cloud/operators/bigtable.py +38 -7
  111. airflow/providers/google/cloud/operators/cloud_base.py +22 -1
  112. airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
  113. airflow/providers/google/cloud/operators/cloud_build.py +80 -36
  114. airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
  115. airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
  116. airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
  117. airflow/providers/google/cloud/operators/cloud_run.py +39 -20
  118. airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
  119. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
  120. airflow/providers/google/cloud/operators/compute.py +18 -50
  121. airflow/providers/google/cloud/operators/datacatalog.py +167 -29
  122. airflow/providers/google/cloud/operators/dataflow.py +38 -15
  123. airflow/providers/google/cloud/operators/dataform.py +19 -7
  124. airflow/providers/google/cloud/operators/datafusion.py +43 -43
  125. airflow/providers/google/cloud/operators/dataplex.py +212 -126
  126. airflow/providers/google/cloud/operators/dataprep.py +1 -5
  127. airflow/providers/google/cloud/operators/dataproc.py +134 -207
  128. airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
  129. airflow/providers/google/cloud/operators/datastore.py +22 -6
  130. airflow/providers/google/cloud/operators/dlp.py +24 -45
  131. airflow/providers/google/cloud/operators/functions.py +21 -14
  132. airflow/providers/google/cloud/operators/gcs.py +15 -12
  133. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  134. airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
  135. airflow/providers/google/cloud/operators/looker.py +1 -1
  136. airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
  137. airflow/providers/google/cloud/operators/natural_language.py +5 -3
  138. airflow/providers/google/cloud/operators/pubsub.py +69 -21
  139. airflow/providers/google/cloud/operators/spanner.py +53 -45
  140. airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
  141. airflow/providers/google/cloud/operators/stackdriver.py +5 -11
  142. airflow/providers/google/cloud/operators/tasks.py +6 -15
  143. airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
  144. airflow/providers/google/cloud/operators/translate.py +46 -20
  145. airflow/providers/google/cloud/operators/translate_speech.py +4 -3
  146. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
  147. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
  148. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
  149. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
  150. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
  151. airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
  152. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
  153. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
  154. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
  155. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
  156. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
  157. airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
  158. airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
  159. airflow/providers/google/cloud/operators/vision.py +7 -5
  160. airflow/providers/google/cloud/operators/workflows.py +24 -19
  161. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  162. airflow/providers/google/cloud/sensors/bigquery.py +2 -2
  163. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
  164. airflow/providers/google/cloud/sensors/bigtable.py +14 -6
  165. airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
  166. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
  167. airflow/providers/google/cloud/sensors/dataflow.py +27 -10
  168. airflow/providers/google/cloud/sensors/dataform.py +2 -2
  169. airflow/providers/google/cloud/sensors/datafusion.py +4 -4
  170. airflow/providers/google/cloud/sensors/dataplex.py +7 -5
  171. airflow/providers/google/cloud/sensors/dataprep.py +2 -2
  172. airflow/providers/google/cloud/sensors/dataproc.py +10 -9
  173. airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
  174. airflow/providers/google/cloud/sensors/gcs.py +22 -21
  175. airflow/providers/google/cloud/sensors/looker.py +5 -5
  176. airflow/providers/google/cloud/sensors/pubsub.py +20 -20
  177. airflow/providers/google/cloud/sensors/tasks.py +2 -2
  178. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
  179. airflow/providers/google/cloud/sensors/workflows.py +6 -4
  180. airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
  181. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
  182. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
  183. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
  184. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
  185. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
  186. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
  187. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
  188. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
  189. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  190. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
  191. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
  192. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
  193. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
  194. airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
  195. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
  196. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
  197. airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
  198. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  199. airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
  200. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  201. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
  202. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
  203. airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
  204. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
  205. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
  206. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
  207. airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
  208. airflow/providers/google/cloud/triggers/bigquery.py +75 -34
  209. airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
  210. airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
  211. airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
  212. airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
  213. airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
  214. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
  215. airflow/providers/google/cloud/triggers/dataflow.py +125 -2
  216. airflow/providers/google/cloud/triggers/datafusion.py +1 -1
  217. airflow/providers/google/cloud/triggers/dataplex.py +16 -3
  218. airflow/providers/google/cloud/triggers/dataproc.py +124 -53
  219. airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
  220. airflow/providers/google/cloud/triggers/mlengine.py +1 -1
  221. airflow/providers/google/cloud/triggers/pubsub.py +17 -20
  222. airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
  223. airflow/providers/google/cloud/utils/bigquery.py +5 -7
  224. airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
  225. airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
  226. airflow/providers/google/cloud/utils/dataform.py +1 -1
  227. airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
  228. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  229. airflow/providers/google/cloud/utils/validators.py +43 -0
  230. airflow/providers/google/common/auth_backend/google_openid.py +26 -9
  231. airflow/providers/google/common/consts.py +2 -1
  232. airflow/providers/google/common/deprecated.py +2 -1
  233. airflow/providers/google/common/hooks/base_google.py +40 -43
  234. airflow/providers/google/common/hooks/operation_helpers.py +78 -0
  235. airflow/providers/google/common/links/storage.py +0 -22
  236. airflow/providers/google/common/utils/get_secret.py +31 -0
  237. airflow/providers/google/common/utils/id_token_credentials.py +4 -5
  238. airflow/providers/google/firebase/operators/firestore.py +2 -2
  239. airflow/providers/google/get_provider_info.py +61 -216
  240. airflow/providers/google/go_module_utils.py +35 -3
  241. airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
  242. airflow/providers/google/leveldb/operators/leveldb.py +2 -2
  243. airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
  244. airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
  245. airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
  246. airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
  247. airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
  248. airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
  249. airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
  250. airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
  251. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
  252. airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
  253. airflow/providers/google/suite/hooks/calendar.py +1 -1
  254. airflow/providers/google/suite/hooks/drive.py +2 -2
  255. airflow/providers/google/suite/hooks/sheets.py +15 -1
  256. airflow/providers/google/suite/operators/sheets.py +8 -3
  257. airflow/providers/google/suite/sensors/drive.py +2 -2
  258. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
  259. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  260. airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
  261. airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
  262. airflow/providers/google/version_compat.py +15 -1
  263. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
  264. apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
  265. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
  266. apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
  267. airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
  268. airflow/providers/google/cloud/hooks/automl.py +0 -679
  269. airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
  270. airflow/providers/google/cloud/links/automl.py +0 -193
  271. airflow/providers/google/cloud/operators/automl.py +0 -1360
  272. airflow/providers/google/cloud/operators/life_sciences.py +0 -119
  273. airflow/providers/google/cloud/operators/mlengine.py +0 -1515
  274. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
  275. apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
  276. /airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
  277. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
  278. {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -19,17 +19,19 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
- from collections.abc import Sequence
23
- from typing import TYPE_CHECKING, Callable, NamedTuple
22
+ from collections import OrderedDict
23
+ from collections.abc import Callable, Sequence
24
+ from typing import TYPE_CHECKING, NamedTuple
24
25
 
26
+ from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
27
+ from google.cloud.spanner_v1.client import Client
25
28
  from sqlalchemy import create_engine
26
29
 
27
30
  from airflow.exceptions import AirflowException
28
31
  from airflow.providers.common.sql.hooks.sql import DbApiHook
29
32
  from airflow.providers.google.common.consts import CLIENT_INFO
30
33
  from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
31
- from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
32
- from google.cloud.spanner_v1.client import Client
34
+ from airflow.providers.openlineage.sqlparser import DatabaseInfo
33
35
 
34
36
  if TYPE_CHECKING:
35
37
  from google.cloud.spanner_v1.database import Database
@@ -37,6 +39,8 @@ if TYPE_CHECKING:
37
39
  from google.cloud.spanner_v1.transaction import Transaction
38
40
  from google.longrunning.operations_grpc_pb2 import Operation
39
41
 
42
+ from airflow.models.connection import Connection
43
+
40
44
 
41
45
  class SpannerConnectionParams(NamedTuple):
42
46
  """Information about Google Spanner connection parameters."""
@@ -388,7 +392,7 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
388
392
  database_id: str,
389
393
  queries: list[str],
390
394
  project_id: str,
391
- ) -> None:
395
+ ) -> list[int]:
392
396
  """
393
397
  Execute an arbitrary DML query (INSERT, UPDATE, DELETE).
394
398
 
@@ -398,12 +402,73 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
398
402
  :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
399
403
  database. If set to None or missing, the default project_id from the Google Cloud connection
400
404
  is used.
405
+ :return: list of numbers of affected rows by DML query
401
406
  """
402
- self._get_client(project_id=project_id).instance(instance_id=instance_id).database(
403
- database_id=database_id
404
- ).run_in_transaction(lambda transaction: self._execute_sql_in_transaction(transaction, queries))
407
+ db = (
408
+ self._get_client(project_id=project_id)
409
+ .instance(instance_id=instance_id)
410
+ .database(database_id=database_id)
411
+ )
412
+
413
+ def _tx_runner(tx: Transaction) -> dict[str, int]:
414
+ return self._execute_sql_in_transaction(tx, queries)
415
+
416
+ result = db.run_in_transaction(_tx_runner)
417
+
418
+ result_rows_count_per_query = []
419
+ for i, (sql, rc) in enumerate(result.items(), start=1):
420
+ if not sql.startswith("SELECT"):
421
+ preview = sql if len(sql) <= 300 else sql[:300] + "…"
422
+ self.log.info("[DML %d/%d] affected rows=%d | %s", i, len(result), rc, preview)
423
+ result_rows_count_per_query.append(rc)
424
+ return result_rows_count_per_query
405
425
 
406
426
  @staticmethod
407
- def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]):
427
+ def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) -> dict[str, int]:
428
+ counts: OrderedDict[str, int] = OrderedDict()
408
429
  for sql in queries:
409
- transaction.execute_update(sql)
430
+ rc = transaction.execute_update(sql)
431
+ counts[sql] = rc
432
+ return counts
433
+
434
+ def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
435
+ """Build Spanner-specific authority part for OpenLineage. Returns {project}/{instance}."""
436
+ extras = connection.extra_dejson
437
+ project_id = extras.get("project_id")
438
+ instance_id = extras.get("instance_id")
439
+
440
+ if not project_id or not instance_id:
441
+ return None
442
+
443
+ return f"{project_id}/{instance_id}"
444
+
445
+ def get_openlineage_database_dialect(self, connection: Connection) -> str:
446
+ """Return database dialect for OpenLineage."""
447
+ return "spanner"
448
+
449
+ def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
450
+ """Return Spanner specific information for OpenLineage."""
451
+ extras = connection.extra_dejson
452
+ database_id = extras.get("database_id")
453
+
454
+ return DatabaseInfo(
455
+ scheme=self.get_openlineage_database_dialect(connection),
456
+ authority=self._get_openlineage_authority_part(connection),
457
+ database=database_id,
458
+ information_schema_columns=[
459
+ "table_schema",
460
+ "table_name",
461
+ "column_name",
462
+ "ordinal_position",
463
+ "spanner_type",
464
+ ],
465
+ )
466
+
467
+ def get_openlineage_default_schema(self) -> str | None:
468
+ """
469
+ Spanner expose 'public' or '' schema depending on dialect(Postgres vs GoogleSQL).
470
+
471
+ SQLAlchemy dialect for Spanner does not expose default schema, so we return None
472
+ to follow the same approach.
473
+ """
474
+ return None
@@ -22,12 +22,13 @@ from __future__ import annotations
22
22
  from collections.abc import Sequence
23
23
  from typing import TYPE_CHECKING
24
24
 
25
- from airflow.providers.google.common.consts import CLIENT_INFO
26
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
27
25
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
28
26
  from google.cloud.speech_v1 import SpeechClient
29
27
  from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig
30
28
 
29
+ from airflow.providers.google.common.consts import CLIENT_INFO
30
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
31
+
31
32
  if TYPE_CHECKING:
32
33
  from google.api_core.retry import Retry
33
34
 
@@ -24,15 +24,15 @@ import json
24
24
  from collections.abc import Sequence
25
25
  from typing import TYPE_CHECKING, Any
26
26
 
27
- from googleapiclient.errors import HttpError
28
-
29
- from airflow.exceptions import AirflowException
30
- from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
31
27
  from google.api_core.exceptions import InvalidArgument
32
28
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
33
29
  from google.cloud import monitoring_v3
34
30
  from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel
35
31
  from google.protobuf.field_mask_pb2 import FieldMask
32
+ from googleapiclient.errors import HttpError
33
+
34
+ from airflow.exceptions import AirflowException
35
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
36
36
 
37
37
  if TYPE_CHECKING:
38
38
  from google.api_core.retry import Retry
@@ -121,10 +121,9 @@ class StackdriverHook(GoogleBaseHook):
121
121
  )
122
122
  if format_ == "dict":
123
123
  return [AlertPolicy.to_dict(policy) for policy in policies_]
124
- elif format_ == "json":
124
+ if format_ == "json":
125
125
  return [AlertPolicy.to_jsoon(policy) for policy in policies_]
126
- else:
127
- return policies_
126
+ return policies_
128
127
 
129
128
  @GoogleBaseHook.fallback_to_default_project_id
130
129
  def _toggle_policy_status(
@@ -262,8 +261,9 @@ class StackdriverHook(GoogleBaseHook):
262
261
  channel_name_map = {}
263
262
 
264
263
  for channel in channels:
264
+ # This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
265
265
  channel.verification_status = (
266
- monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
266
+ monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED # type: ignore[assignment]
267
267
  )
268
268
 
269
269
  if channel.name in existing_channels:
@@ -275,7 +275,7 @@ class StackdriverHook(GoogleBaseHook):
275
275
  )
276
276
  else:
277
277
  old_name = channel.name
278
- channel.name = None
278
+ del channel.name
279
279
  new_channel = channel_client.create_notification_channel(
280
280
  request={"name": f"projects/{project_id}", "notification_channel": channel},
281
281
  retry=retry,
@@ -285,8 +285,8 @@ class StackdriverHook(GoogleBaseHook):
285
285
  channel_name_map[old_name] = new_channel.name
286
286
 
287
287
  for policy in policies_:
288
- policy.creation_record = None
289
- policy.mutation_record = None
288
+ del policy.creation_record
289
+ del policy.mutation_record
290
290
 
291
291
  for i, channel in enumerate(policy.notification_channels):
292
292
  new_channel = channel_name_map.get(channel)
@@ -302,9 +302,9 @@ class StackdriverHook(GoogleBaseHook):
302
302
  metadata=metadata,
303
303
  )
304
304
  else:
305
- policy.name = None
305
+ del policy.name
306
306
  for condition in policy.conditions:
307
- condition.name = None
307
+ del condition.name
308
308
  policy_client.create_alert_policy(
309
309
  request={"name": f"projects/{project_id}", "alert_policy": policy},
310
310
  retry=retry,
@@ -395,10 +395,9 @@ class StackdriverHook(GoogleBaseHook):
395
395
  )
396
396
  if format_ == "dict":
397
397
  return [NotificationChannel.to_dict(channel) for channel in channels]
398
- elif format_ == "json":
398
+ if format_ == "json":
399
399
  return [NotificationChannel.to_json(channel) for channel in channels]
400
- else:
401
- return channels
400
+ return channels
402
401
 
403
402
  @GoogleBaseHook.fallback_to_default_project_id
404
403
  def _toggle_channel_status(
@@ -533,8 +532,9 @@ class StackdriverHook(GoogleBaseHook):
533
532
  channels_list.append(NotificationChannel(**channel))
534
533
 
535
534
  for channel in channels_list:
535
+ # This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
536
536
  channel.verification_status = (
537
- monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
537
+ monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED # type: ignore[assignment]
538
538
  )
539
539
 
540
540
  if channel.name in existing_channels:
@@ -546,7 +546,7 @@ class StackdriverHook(GoogleBaseHook):
546
546
  )
547
547
  else:
548
548
  old_name = channel.name
549
- channel.name = None
549
+ del channel.name
550
550
  new_channel = channel_client.create_notification_channel(
551
551
  request={"name": f"projects/{project_id}", "notification_channel": channel},
552
552
  retry=retry,
@@ -22,13 +22,14 @@ from __future__ import annotations
22
22
  from collections.abc import Sequence
23
23
  from typing import TYPE_CHECKING
24
24
 
25
- from airflow.exceptions import AirflowException
26
- from airflow.providers.google.common.consts import CLIENT_INFO
27
- from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
28
25
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
29
26
  from google.cloud.tasks_v2 import CloudTasksClient
30
27
  from google.cloud.tasks_v2.types import Queue, Task
31
28
 
29
+ from airflow.exceptions import AirflowException
30
+ from airflow.providers.google.common.consts import CLIENT_INFO
31
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
32
+
32
33
  if TYPE_CHECKING:
33
34
  from google.api_core.retry import Retry
34
35
  from google.protobuf.field_mask_pb2 import FieldMask
@@ -22,8 +22,6 @@ from __future__ import annotations
22
22
  from collections.abc import Sequence
23
23
  from typing import TYPE_CHECKING
24
24
 
25
- from airflow.providers.google.common.consts import CLIENT_INFO
26
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
27
25
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
28
26
  from google.cloud.texttospeech_v1 import TextToSpeechClient
29
27
  from google.cloud.texttospeech_v1.types import (
@@ -33,6 +31,9 @@ from google.cloud.texttospeech_v1.types import (
33
31
  VoiceSelectionParams,
34
32
  )
35
33
 
34
+ from airflow.providers.google.common.consts import CLIENT_INFO
35
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
36
+
36
37
  if TYPE_CHECKING:
37
38
  from google.api_core.retry import Retry
38
39
 
@@ -25,9 +25,6 @@ from typing import (
25
25
  cast,
26
26
  )
27
27
 
28
- from airflow.exceptions import AirflowException
29
- from airflow.providers.google.common.consts import CLIENT_INFO
30
- from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
31
28
  from google.api_core.exceptions import GoogleAPICallError
32
29
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
33
30
  from google.api_core.retry import Retry
@@ -35,9 +32,12 @@ from google.cloud.translate_v2 import Client
35
32
  from google.cloud.translate_v3 import TranslationServiceClient
36
33
  from google.cloud.translate_v3.types.translation_service import GlossaryInputConfig
37
34
 
38
- if TYPE_CHECKING:
39
- from proto import Message
35
+ from airflow.exceptions import AirflowException
36
+ from airflow.providers.google.common.consts import CLIENT_INFO
37
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
38
+ from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
40
39
 
40
+ if TYPE_CHECKING:
41
41
  from google.api_core.operation import Operation
42
42
  from google.cloud.translate_v3.services.translation_service import pagers
43
43
  from google.cloud.translate_v3.types import (
@@ -155,7 +155,7 @@ class CloudTranslateHook(GoogleBaseHook):
155
155
  )
156
156
 
157
157
 
158
- class TranslateHook(GoogleBaseHook):
158
+ class TranslateHook(GoogleBaseHook, OperationHelper):
159
159
  """
160
160
  Hook for Google Cloud translation (Advanced) using client version V3.
161
161
 
@@ -221,15 +221,6 @@ class TranslateHook(GoogleBaseHook):
221
221
  error = operation.exception(timeout=timeout)
222
222
  raise AirflowException(error)
223
223
 
224
- @staticmethod
225
- def wait_for_operation_result(operation: Operation, timeout: int | None = None) -> Message:
226
- """Wait for long-lasting operation to complete."""
227
- try:
228
- return operation.result(timeout=timeout)
229
- except GoogleAPICallError:
230
- error = operation.exception(timeout=timeout)
231
- raise AirflowException(error)
232
-
233
224
  @staticmethod
234
225
  def extract_object_id(obj: dict) -> str:
235
226
  """Return unique id of the object."""
@@ -320,7 +311,7 @@ class TranslateHook(GoogleBaseHook):
320
311
  retry=retry,
321
312
  metadata=metadata,
322
313
  )
323
- return cast(dict, type(result).to_dict(result))
314
+ return cast("dict", type(result).to_dict(result))
324
315
 
325
316
  def batch_translate_text(
326
317
  self,
@@ -438,7 +429,7 @@ class TranslateHook(GoogleBaseHook):
438
429
  project_id: str,
439
430
  location: str,
440
431
  retry: Retry | _MethodDefault = DEFAULT,
441
- timeout: float | _MethodDefault = DEFAULT,
432
+ timeout: float | None | _MethodDefault = DEFAULT,
442
433
  metadata: Sequence[tuple[str, str]] = (),
443
434
  ) -> automl_translation.Dataset:
444
435
  """
@@ -23,9 +23,6 @@ import warnings
23
23
  from collections.abc import Sequence
24
24
  from typing import TYPE_CHECKING
25
25
 
26
- from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
27
- from airflow.providers.google.common.deprecated import deprecated
28
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
29
26
  from google.api_core.client_options import ClientOptions
30
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
31
28
  from google.cloud.aiplatform import (
@@ -39,6 +36,11 @@ from google.cloud.aiplatform import (
39
36
  )
40
37
  from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
41
38
 
39
+ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
40
+ from airflow.providers.google.common.consts import CLIENT_INFO
41
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
42
+ from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
43
+
42
44
  if TYPE_CHECKING:
43
45
  from google.api_core.operation import Operation
44
46
  from google.api_core.retry import Retry
@@ -46,7 +48,7 @@ if TYPE_CHECKING:
46
48
  from google.cloud.aiplatform_v1.types import TrainingPipeline
47
49
 
48
50
 
49
- class AutoMLHook(GoogleBaseHook):
51
+ class AutoMLHook(GoogleBaseHook, OperationHelper):
50
52
  """Hook for Google Cloud Vertex AI Auto ML APIs."""
51
53
 
52
54
  def __init__(
@@ -79,7 +81,7 @@ class AutoMLHook(GoogleBaseHook):
79
81
  client_options = ClientOptions()
80
82
 
81
83
  return PipelineServiceClient(
82
- credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
84
+ credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
83
85
  )
84
86
 
85
87
  def get_job_service_client(
@@ -93,7 +95,7 @@ class AutoMLHook(GoogleBaseHook):
93
95
  client_options = ClientOptions()
94
96
 
95
97
  return JobServiceClient(
96
- credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
98
+ credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
97
99
  )
98
100
 
99
101
  def get_auto_ml_tabular_training_job(
@@ -182,42 +184,6 @@ class AutoMLHook(GoogleBaseHook):
182
184
  model_encryption_spec_key_name=model_encryption_spec_key_name,
183
185
  )
184
186
 
185
- @deprecated(
186
- planned_removal_date="June 15, 2025",
187
- category=AirflowProviderDeprecationWarning,
188
- reason="Deprecation of AutoMLText API",
189
- )
190
- def get_auto_ml_text_training_job(
191
- self,
192
- display_name: str,
193
- prediction_type: str,
194
- multi_label: bool = False,
195
- sentiment_max: int = 10,
196
- project: str | None = None,
197
- location: str | None = None,
198
- labels: dict[str, str] | None = None,
199
- training_encryption_spec_key_name: str | None = None,
200
- model_encryption_spec_key_name: str | None = None,
201
- ) -> AutoMLTextTrainingJob:
202
- """
203
- Return AutoMLTextTrainingJob object.
204
-
205
- WARNING: Text creation API is deprecated since September 15, 2024
206
- (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
207
- """
208
- return AutoMLTextTrainingJob(
209
- display_name=display_name,
210
- prediction_type=prediction_type,
211
- multi_label=multi_label,
212
- sentiment_max=sentiment_max,
213
- project=project,
214
- location=location,
215
- credentials=self.get_credentials(),
216
- labels=labels,
217
- training_encryption_spec_key_name=training_encryption_spec_key_name,
218
- model_encryption_spec_key_name=model_encryption_spec_key_name,
219
- )
220
-
221
187
  def get_auto_ml_video_training_job(
222
188
  self,
223
189
  display_name: str,
@@ -252,14 +218,6 @@ class AutoMLHook(GoogleBaseHook):
252
218
  """Return unique id of the Training pipeline."""
253
219
  return resource_name.rpartition("/")[-1]
254
220
 
255
- def wait_for_operation(self, operation: Operation, timeout: float | None = None):
256
- """Wait for long-lasting operation to complete."""
257
- try:
258
- return operation.result(timeout=timeout)
259
- except Exception:
260
- error = operation.exception(timeout=timeout)
261
- raise AirflowException(error)
262
-
263
221
  def cancel_auto_ml_job(self) -> None:
264
222
  """Cancel Auto ML Job for training pipeline."""
265
223
  if self._job:
@@ -992,178 +950,6 @@ class AutoMLHook(GoogleBaseHook):
992
950
  )
993
951
  return model, training_id
994
952
 
995
- @GoogleBaseHook.fallback_to_default_project_id
996
- @deprecated(
997
- planned_removal_date="September 15, 2025",
998
- category=AirflowProviderDeprecationWarning,
999
- reason="Deprecation of AutoMLText API",
1000
- )
1001
- def create_auto_ml_text_training_job(
1002
- self,
1003
- project_id: str,
1004
- region: str,
1005
- display_name: str,
1006
- dataset: datasets.TextDataset,
1007
- prediction_type: str,
1008
- multi_label: bool = False,
1009
- sentiment_max: int = 10,
1010
- labels: dict[str, str] | None = None,
1011
- training_encryption_spec_key_name: str | None = None,
1012
- model_encryption_spec_key_name: str | None = None,
1013
- training_fraction_split: float | None = None,
1014
- validation_fraction_split: float | None = None,
1015
- test_fraction_split: float | None = None,
1016
- training_filter_split: str | None = None,
1017
- validation_filter_split: str | None = None,
1018
- test_filter_split: str | None = None,
1019
- model_display_name: str | None = None,
1020
- model_labels: dict[str, str] | None = None,
1021
- sync: bool = True,
1022
- parent_model: str | None = None,
1023
- is_default_version: bool | None = None,
1024
- model_version_aliases: list[str] | None = None,
1025
- model_version_description: str | None = None,
1026
- ) -> tuple[models.Model | None, str]:
1027
- """
1028
- Create an AutoML Text Training Job.
1029
-
1030
- WARNING: Text creation API is deprecated since September 15, 2024
1031
- (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
1032
-
1033
- :param project_id: Required. Project to run training in.
1034
- :param region: Required. Location to run training in.
1035
- :param display_name: Required. The user-defined name of this TrainingPipeline.
1036
- :param dataset: Required. The dataset within the same Project from which data will be used to train
1037
- the Model. The Dataset must use schema compatible with Model being trained, and what is
1038
- compatible should be described in the used TrainingPipeline's [training_task_definition]
1039
- [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
1040
- :param prediction_type: The type of prediction the Model is to produce, one of:
1041
- "classification" - A classification model analyzes text data and returns a list of categories
1042
- that apply to the text found in the data. Vertex AI offers both single-label and multi-label text
1043
- classification models.
1044
- "extraction" - An entity extraction model inspects text data for known entities referenced in the
1045
- data and labels those entities in the text.
1046
- "sentiment" - A sentiment analysis model inspects text data and identifies the prevailing
1047
- emotional opinion within it, especially to determine a writer's attitude as positive, negative,
1048
- or neutral.
1049
- :param parent_model: Optional. The resource name or model ID of an existing model.
1050
- The new model uploaded by this job will be a version of `parent_model`.
1051
- Only set this field when training a new version of an existing model.
1052
- :param is_default_version: Optional. When set to True, the newly uploaded model version will
1053
- automatically have alias "default" included. Subsequent uses of
1054
- the model produced by this job without a version specified will
1055
- use this "default" version.
1056
- When set to False, the "default" alias will not be moved.
1057
- Actions targeting the model version produced by this job will need
1058
- to specifically reference this version by ID or alias.
1059
- New model uploads, i.e. version 1, will always be "default" aliased.
1060
- :param model_version_aliases: Optional. User provided version aliases so that the model version
1061
- uploaded by this job can be referenced via alias instead of
1062
- auto-generated version ID. A default version alias will be created
1063
- for the first version of the model.
1064
- The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
1065
- :param model_version_description: Optional. The description of the model version
1066
- being uploaded by this job.
1067
- :param multi_label: Required and only applicable for text classification task. If false, a
1068
- single-label (multi-class) Model will be trained (i.e. assuming that for each text snippet just
1069
- up to one annotation may be applicable). If true, a multi-label Model will be trained (i.e.
1070
- assuming that for each text snippet multiple annotations may be applicable).
1071
- :param sentiment_max: Required and only applicable for sentiment task. A sentiment is expressed as an
1072
- integer ordinal, where higher value means a more positive sentiment. The range of sentiments that
1073
- will be used is between 0 and sentimentMax (inclusive on both ends), and all the values in the
1074
- range must be represented in the dataset before a model can be created. Only the Annotations with
1075
- this sentimentMax will be used for training. sentimentMax value must be between 1 and 10
1076
- (inclusive).
1077
- :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
1078
- keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
1079
- lowercase letters, numeric characters, underscores and dashes. International characters are
1080
- allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
1081
- :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
1082
- managed encryption key used to protect the training pipeline. Has the form:
1083
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1084
- The key needs to be in the same region as where the compute resource is created.
1085
- If set, this TrainingPipeline will be secured by this key.
1086
- Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
1087
- is not set separately.
1088
- :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
1089
- managed encryption key used to protect the model. Has the form:
1090
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1091
- The key needs to be in the same region as where the compute resource is created.
1092
- If set, the trained Model will be secured by this key.
1093
- :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
1094
- the Model. This is ignored if Dataset is not provided.
1095
- :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
1096
- validate the Model. This is ignored if Dataset is not provided.
1097
- :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
1098
- the Model. This is ignored if Dataset is not provided.
1099
- :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
1100
- this filter are used to train the Model. A filter with same syntax as the one used in
1101
- DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
1102
- FilterSplit filters, then it is assigned to the first set that applies to it in the training,
1103
- validation, test order. This is ignored if Dataset is not provided.
1104
- :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
1105
- this filter are used to validate the Model. A filter with same syntax as the one used in
1106
- DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
1107
- FilterSplit filters, then it is assigned to the first set that applies to it in the training,
1108
- validation, test order. This is ignored if Dataset is not provided.
1109
- :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
1110
- filter are used to test the Model. A filter with same syntax as the one used in
1111
- DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
1112
- FilterSplit filters, then it is assigned to the first set that applies to it in the training,
1113
- validation, test order. This is ignored if Dataset is not provided.
1114
- :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
1115
- up to 128 characters long and can consist of any UTF-8 characters.
1116
- If not provided upon creation, the job's display_name is used.
1117
- :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
1118
- keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
1119
- lowercase letters, numeric characters, underscores and dashes. International characters are
1120
- allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
1121
- :param sync: Whether to execute this method synchronously. If False, this method will be executed in
1122
- concurrent Future and any downstream object will be immediately returned and synced when the
1123
- Future has completed.
1124
- """
1125
- self._job = AutoMLTextTrainingJob(
1126
- display_name=display_name,
1127
- prediction_type=prediction_type,
1128
- multi_label=multi_label,
1129
- sentiment_max=sentiment_max,
1130
- project=project_id,
1131
- location=region,
1132
- credentials=self.get_credentials(),
1133
- labels=labels,
1134
- training_encryption_spec_key_name=training_encryption_spec_key_name,
1135
- model_encryption_spec_key_name=model_encryption_spec_key_name,
1136
- )
1137
-
1138
- if not self._job:
1139
- raise AirflowException("AutoMLTextTrainingJob was not created")
1140
-
1141
- model = self._job.run(
1142
- dataset=dataset, # type: ignore[arg-type]
1143
- training_fraction_split=training_fraction_split, # type: ignore[call-arg]
1144
- validation_fraction_split=validation_fraction_split, # type: ignore[call-arg]
1145
- test_fraction_split=test_fraction_split,
1146
- training_filter_split=training_filter_split,
1147
- validation_filter_split=validation_filter_split,
1148
- test_filter_split=test_filter_split, # type: ignore[call-arg]
1149
- model_display_name=model_display_name,
1150
- model_labels=model_labels,
1151
- sync=sync,
1152
- parent_model=parent_model,
1153
- is_default_version=is_default_version,
1154
- model_version_aliases=model_version_aliases,
1155
- model_version_description=model_version_description,
1156
- )
1157
- training_id = self.extract_training_id(self._job.resource_name)
1158
- if model:
1159
- model.wait()
1160
- else:
1161
- self.log.warning(
1162
- "Training did not produce a Managed Model returning None. AutoML Text Training "
1163
- "Pipeline is not configured to upload a Model."
1164
- )
1165
- return model, training_id
1166
-
1167
953
  @GoogleBaseHook.fallback_to_default_project_id
1168
954
  def create_auto_ml_video_training_job(
1169
955
  self,