apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (89)
  1. airflow/providers/google/__init__.py +3 -3
  2. airflow/providers/google/cloud/hooks/automl.py +1 -1
  3. airflow/providers/google/cloud/hooks/bigquery.py +64 -33
  4. airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
  5. airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
  6. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
  7. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +246 -32
  9. airflow/providers/google/cloud/hooks/dataplex.py +6 -2
  10. airflow/providers/google/cloud/hooks/dlp.py +14 -14
  11. airflow/providers/google/cloud/hooks/gcs.py +6 -2
  12. airflow/providers/google/cloud/hooks/gdm.py +2 -2
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/mlengine.py +8 -4
  15. airflow/providers/google/cloud/hooks/pubsub.py +1 -1
  16. airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
  18. airflow/providers/google/cloud/links/vertex_ai.py +2 -1
  19. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
  20. airflow/providers/google/cloud/operators/automl.py +13 -12
  21. airflow/providers/google/cloud/operators/bigquery.py +36 -22
  22. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
  23. airflow/providers/google/cloud/operators/bigtable.py +7 -6
  24. airflow/providers/google/cloud/operators/cloud_build.py +12 -11
  25. airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
  26. airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
  27. airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
  29. airflow/providers/google/cloud/operators/compute.py +12 -11
  30. airflow/providers/google/cloud/operators/datacatalog.py +21 -20
  31. airflow/providers/google/cloud/operators/dataflow.py +59 -42
  32. airflow/providers/google/cloud/operators/datafusion.py +11 -10
  33. airflow/providers/google/cloud/operators/datapipeline.py +3 -2
  34. airflow/providers/google/cloud/operators/dataprep.py +5 -4
  35. airflow/providers/google/cloud/operators/dataproc.py +19 -16
  36. airflow/providers/google/cloud/operators/datastore.py +8 -7
  37. airflow/providers/google/cloud/operators/dlp.py +31 -30
  38. airflow/providers/google/cloud/operators/functions.py +4 -3
  39. airflow/providers/google/cloud/operators/gcs.py +66 -41
  40. airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
  41. airflow/providers/google/cloud/operators/life_sciences.py +2 -1
  42. airflow/providers/google/cloud/operators/mlengine.py +11 -10
  43. airflow/providers/google/cloud/operators/pubsub.py +6 -5
  44. airflow/providers/google/cloud/operators/spanner.py +7 -6
  45. airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
  46. airflow/providers/google/cloud/operators/stackdriver.py +11 -10
  47. airflow/providers/google/cloud/operators/tasks.py +14 -13
  48. airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
  49. airflow/providers/google/cloud/operators/translate_speech.py +2 -1
  50. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
  51. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
  52. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
  53. airflow/providers/google/cloud/operators/vision.py +13 -12
  54. airflow/providers/google/cloud/operators/workflows.py +10 -9
  55. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  56. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
  57. airflow/providers/google/cloud/sensors/bigtable.py +2 -1
  58. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
  59. airflow/providers/google/cloud/sensors/dataflow.py +239 -52
  60. airflow/providers/google/cloud/sensors/datafusion.py +2 -1
  61. airflow/providers/google/cloud/sensors/dataproc.py +3 -2
  62. airflow/providers/google/cloud/sensors/gcs.py +14 -12
  63. airflow/providers/google/cloud/sensors/tasks.py +2 -1
  64. airflow/providers/google/cloud/sensors/workflows.py +2 -1
  65. airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
  66. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
  67. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
  68. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  69. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
  70. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  71. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
  72. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
  73. airflow/providers/google/cloud/triggers/bigquery.py +14 -3
  74. airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
  75. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
  76. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
  77. airflow/providers/google/cloud/triggers/dataflow.py +504 -4
  78. airflow/providers/google/cloud/triggers/dataproc.py +110 -26
  79. airflow/providers/google/cloud/triggers/mlengine.py +2 -1
  80. airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
  81. airflow/providers/google/common/hooks/base_google.py +45 -7
  82. airflow/providers/google/firebase/hooks/firestore.py +2 -2
  83. airflow/providers/google/firebase/operators/firestore.py +2 -1
  84. airflow/providers/google/get_provider_info.py +3 -2
  85. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
  86. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
  87. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
  88. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
  89. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/__init__.py
@@ -27,7 +27,7 @@ import packaging.version
 
 __all__ = ["__version__"]
 
-__version__ = "10.17.0"
+__version__ = "10.18.0"
 
 try:
     from airflow import __version__ as airflow_version
@@ -35,8 +35,8 @@ except ImportError:
     from airflow.version import version as airflow_version
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
-    "2.6.0"
+    "2.7.0"
 ):
     raise RuntimeError(
-        f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.6.0+"
+        f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.7.0+"
     )
airflow/providers/google/cloud/hooks/automl.py
@@ -529,7 +529,7 @@ class CloudAutoMLHook(GoogleBaseHook):
         self,
         dataset_id: str,
         location: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         filter_: str | None = None,
         page_size: int | None = None,
         retry: Retry | _MethodDefault = DEFAULT,
airflow/providers/google/cloud/hooks/bigquery.py
@@ -59,7 +59,12 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.cloud.utils.bigquery import bq_cast
 from airflow.providers.google.cloud.utils.credentials_provider import _get_scopes
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+    get_field,
+)
 
 try:
     from airflow.utils.hashlib_wrapper import md5
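Note: the recurring signature change in this file, from `project_id: str | None = None` to `project_id: str = PROVIDE_PROJECT_ID`, does not alter runtime behaviour; the sentinel is still a "not provided" marker that `GoogleBaseHook.fallback_to_default_project_id` replaces with the connection's default project, it simply lets the parameter be annotated as a plain `str`. A minimal, self-contained sketch of that pattern follows; the decorator and hook below are simplified stand-ins written for illustration, not the provider's actual implementation:

import functools
from typing import Callable, cast

# Sentinel typed as ``str`` purely for type checkers; at runtime it is None.
PROVIDE_PROJECT_ID: str = cast(str, None)


def fallback_to_default_project_id(func: Callable) -> Callable:
    """Fill in the hook's default project when the caller left project_id unset."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if kwargs.get("project_id") is None:  # sentinel (None) or not passed at all
            kwargs["project_id"] = self.default_project_id
        return func(self, *args, **kwargs)

    return wrapper


class DemoHook:
    default_project_id = "my-default-project"  # placeholder for the project resolved from credentials

    @fallback_to_default_project_id
    def get_dataset(self, dataset_id: str, project_id: str = PROVIDE_PROJECT_ID) -> str:
        return f"{project_id}.{dataset_id}"


print(DemoHook().get_dataset("sales_dataset"))  # -> my-default-project.sales_dataset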
@@ -198,7 +203,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         http_authorized = self._authorize()
         return build("bigquery", "v2", http=http_authorized, cache_discovery=False)
 
-    def get_client(self, project_id: str | None = None, location: str | None = None) -> Client:
+    def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None) -> Client:
         """Get an authenticated BigQuery Client.
 
         :param project_id: Project ID for the project which the client acts on behalf of.
@@ -250,7 +255,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     @staticmethod
     def _resolve_table_reference(
         table_resource: dict[str, Any],
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         dataset_id: str | None = None,
         table_id: str | None = None,
     ) -> dict[str, Any]:
@@ -360,7 +365,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     @GoogleBaseHook.fallback_to_default_project_id
     def create_empty_table(
         self,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         dataset_id: str | None = None,
         table_id: str | None = None,
         table_resource: dict[str, Any] | None = None,
@@ -474,7 +479,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     def create_empty_dataset(
         self,
         dataset_id: str | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str | None = None,
         dataset_reference: dict[str, Any] | None = None,
         exists_ok: bool = True,
@@ -536,7 +541,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     def get_dataset_tables(
         self,
         dataset_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         max_results: int | None = None,
         retry: Retry = DEFAULT_RETRY,
     ) -> list[dict[str, Any]]:
@@ -565,7 +570,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     def delete_dataset(
         self,
         dataset_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         delete_contents: bool = False,
         retry: Retry = DEFAULT_RETRY,
     ) -> None:
@@ -614,7 +619,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         description: str | None = None,
         encryption_configuration: dict | None = None,
         location: str | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
     ) -> Table:
         """Create an external table in the dataset with data from Google Cloud Storage.
 
@@ -750,7 +755,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         fields: list[str] | None = None,
         dataset_id: str | None = None,
         table_id: str | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
     ) -> dict[str, Any]:
         """Change some fields of a table.
 
@@ -796,7 +801,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         self,
         dataset_id: str,
         table_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         description: str | None = None,
         expiration_time: int | None = None,
         external_data_configuration: dict | None = None,
@@ -953,7 +958,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         fields: Sequence[str],
         dataset_resource: dict[str, Any],
         dataset_id: str | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         retry: Retry = DEFAULT_RETRY,
     ) -> Dataset:
         """Change some fields of a dataset.
@@ -999,7 +1004,9 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         ),
         category=AirflowProviderDeprecationWarning,
     )
-    def patch_dataset(self, dataset_id: str, dataset_resource: dict, project_id: str | None = None) -> dict:
+    def patch_dataset(
+        self, dataset_id: str, dataset_resource: dict, project_id: str = PROVIDE_PROJECT_ID
+    ) -> dict:
         """Patches information in an existing dataset.
 
         It only replaces fields that are provided in the submitted dataset resource.
@@ -1047,7 +1054,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     def get_dataset_tables_list(
         self,
         dataset_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         table_prefix: str | None = None,
         max_results: int | None = None,
     ) -> list[dict[str, Any]]:
@@ -1084,7 +1091,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     @GoogleBaseHook.fallback_to_default_project_id
     def get_datasets_list(
         self,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         include_all: bool = False,
         filter_: str | None = None,
         max_results: int | None = None,
@@ -1134,7 +1141,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         return datasets_list
 
     @GoogleBaseHook.fallback_to_default_project_id
-    def get_dataset(self, dataset_id: str, project_id: str | None = None) -> Dataset:
+    def get_dataset(self, dataset_id: str, project_id: str = PROVIDE_PROJECT_ID) -> Dataset:
         """Fetch the dataset referenced by *dataset_id*.
 
         :param dataset_id: The BigQuery Dataset ID
@@ -1158,7 +1165,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         view_dataset: str,
         view_table: str,
         view_project: str | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
     ) -> dict[str, Any]:
         """Grant authorized view access of a dataset to a view table.
 
@@ -1210,7 +1217,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
 
     @GoogleBaseHook.fallback_to_default_project_id
     def run_table_upsert(
-        self, dataset_id: str, table_resource: dict[str, Any], project_id: str | None = None
+        self, dataset_id: str, table_resource: dict[str, Any], project_id: str = PROVIDE_PROJECT_ID
     ) -> dict[str, Any]:
         """Update a table if it exists, otherwise create a new one.
 
@@ -1267,7 +1274,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         self,
         table_id: str,
         not_found_ok: bool = True,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
     ) -> None:
         """Delete an existing table from the dataset.
 
@@ -1334,7 +1341,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         selected_fields: list[str] | str | None = None,
         page_token: str | None = None,
         start_index: int | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str | None = None,
         retry: Retry = DEFAULT_RETRY,
         return_iterator: bool = False,
@@ -1387,7 +1394,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         return list(iterator)
 
     @GoogleBaseHook.fallback_to_default_project_id
-    def get_schema(self, dataset_id: str, table_id: str, project_id: str | None = None) -> dict:
+    def get_schema(self, dataset_id: str, table_id: str, project_id: str = PROVIDE_PROJECT_ID) -> dict:
         """Get the schema for a given dataset and table.
 
         .. seealso:: https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
@@ -1409,7 +1416,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         include_policy_tags: bool,
         dataset_id: str,
         table_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
     ) -> dict[str, Any]:
         """Update fields within a schema for a given dataset and table.
 
@@ -1502,7 +1509,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     def poll_job_complete(
         self,
         job_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str | None = None,
         retry: Retry = DEFAULT_RETRY,
     ) -> bool:
@@ -1532,7 +1539,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     def cancel_job(
         self,
         job_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str | None = None,
     ) -> None:
         """Cancel a job and wait for cancellation to complete.
@@ -1573,10 +1580,11 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
                 time.sleep(5)
 
     @GoogleBaseHook.fallback_to_default_project_id
+    @GoogleBaseHook.refresh_credentials_retry()
     def get_job(
         self,
         job_id: str,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str | None = None,
     ) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
         """Retrieve a BigQuery job.
@@ -1607,7 +1615,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         self,
         configuration: dict,
         job_id: str | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str | None = None,
         nowait: bool = False,
         retry: Retry = DEFAULT_RETRY,
@@ -2849,11 +2857,10 @@ class BigQueryCursor(BigQueryBaseCursor):
                 return None
 
             query_results = self._get_query_result()
-            if "rows" in query_results and query_results["rows"]:
+            if rows := query_results.get("rows"):
                 self.page_token = query_results.get("pageToken")
                 fields = query_results["schema"]["fields"]
                 col_types = [field["type"] for field in fields]
-                rows = query_results["rows"]
 
                 for dict_row in rows:
                     typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])]
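The cursor change above (and the matching get_records change later in this file) replaces a membership test plus a second lookup with an assignment expression: `rows := query_results.get("rows")` binds the value and tests its truthiness in one step, so a missing or empty "rows" key still skips the block. A tiny standalone illustration:

query_results = {"schema": {"fields": [{"name": "x", "type": "INTEGER"}]}}  # no "rows" key

# Equivalent to: if "rows" in query_results and query_results["rows"]: rows = query_results["rows"]; ...
if rows := query_results.get("rows"):
    print(f"processing {len(rows)} rows")
else:
    print("no rows to process")  # taken here, because "rows" is absent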
@@ -3305,7 +3312,7 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
         )
 
     async def _get_job(
-        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+        self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None
     ) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
         """
         Get BigQuery job by its ID, project ID and location.
@@ -3348,7 +3355,7 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
         return hook.get_job(job_id=job_id, project_id=project_id, location=location)
 
     async def get_job_status(
-        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+        self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None
     ) -> dict[str, str]:
         job = await self._get_job(job_id=job_id, project_id=project_id, location=location)
         if job.state == "DONE":
@@ -3360,7 +3367,7 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
     async def get_job_output(
         self,
         job_id: str | None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
     ) -> dict[str, Any]:
         """Get the BigQuery job output for a given job ID asynchronously."""
         async with ClientSession() as session:
@@ -3373,7 +3380,7 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
         self,
         dataset_id: str | None,
         table_id: str | None = None,
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
     ):
         """Create a new job and get the job_id using gcloud-aio."""
         async with ClientSession() as session:
@@ -3389,6 +3396,31 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
             job_query_resp = await job_client.query(query_request, cast(Session, session))
             return job_query_resp["jobReference"]["jobId"]
 
+    async def cancel_job(self, job_id: str, project_id: str | None, location: str | None) -> None:
+        """
+        Cancel a BigQuery job.
+
+        :param job_id: ID of the job to cancel.
+        :param project_id: Google Cloud Project where the job was running.
+        :param location: Location where the job was running.
+        """
+        async with ClientSession() as session:
+            token = await self.get_token(session=session)
+            job = Job(job_id=job_id, project=project_id, location=location, token=token, session=session)  # type: ignore[arg-type]
+
+            self.log.info(
+                "Attempting to cancel BigQuery job: %s in project: %s, location: %s",
+                job_id,
+                project_id,
+                location,
+            )
+            try:
+                await job.cancel()
+                self.log.info("Job %s cancellation requested.", job_id)
+            except Exception as e:
+                self.log.error("Failed to cancel BigQuery job %s: %s", job_id, str(e))
+                raise
+
     def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
         """Convert a response from BigQuery to records.
 
@@ -3396,8 +3428,7 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
         :param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists.
         """
         buffer: list[Any] = []
-        if "rows" in query_results and query_results["rows"]:
-            rows = query_results["rows"]
+        if rows := query_results.get("rows"):
             fields = query_results["schema"]["fields"]
             fields_names = [field["name"] for field in fields]
             col_types = [field["type"] for field in fields]
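BigQueryAsyncHook also gains a `cancel_job` coroutine in this release (added in the hunk above), giving deferred operators and triggers a way to stop a running job. A rough usage sketch, assuming valid Google credentials are configured; the connection ID, job ID, project, and location values are placeholders:

import asyncio

from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook


async def cancel_running_job() -> None:
    hook = BigQueryAsyncHook(gcp_conn_id="google_cloud_default")
    # Signature as introduced above: job_id, project_id, location.
    await hook.cancel_job(
        job_id="example_job_id",       # placeholder
        project_id="example-project",  # placeholder
        location="US",                 # placeholder
    )


if __name__ == "__main__":
    asyncio.run(cancel_running_job())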
airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -17,7 +17,9 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+import asyncio
+import time
+from typing import TYPE_CHECKING, MutableSequence, Sequence
 
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -25,6 +27,7 @@ from google.cloud.orchestration.airflow.service_v1 import (
     EnvironmentsAsyncClient,
     EnvironmentsClient,
     ImageVersionsClient,
+    PollAirflowCommandResponse,
 )
 
 from airflow.exceptions import AirflowException
@@ -42,7 +45,10 @@ if TYPE_CHECKING:
     from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import (
         ListImageVersionsPager,
     )
-    from google.cloud.orchestration.airflow.service_v1.types import Environment
+    from google.cloud.orchestration.airflow.service_v1.types import (
+        Environment,
+        ExecuteAirflowCommandResponse,
+    )
     from google.protobuf.field_mask_pb2 import FieldMask
 
 
@@ -294,6 +300,127 @@ class CloudComposerHook(GoogleBaseHook):
         )
         return result
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def execute_airflow_command(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        command: str,
+        subcommand: str,
+        parameters: MutableSequence[str],
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> ExecuteAirflowCommandResponse:
+        """
+        Execute Airflow command for provided Composer environment.
+
+        :param project_id: The ID of the Google Cloud project that the service belongs to.
+        :param region: The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: The ID of the Google Cloud environment that the service belongs to.
+        :param command: Airflow command.
+        :param subcommand: Airflow subcommand.
+        :param parameters: Parameters for the Airflow command/subcommand as an array of arguments. It may
+            contain positional arguments like ``["my-dag-id"]``, key-value parameters like ``["--foo=bar"]``
+            or ``["--foo","bar"]``, or other flags like ``["-f"]``.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_environment_client()
+        result = client.execute_airflow_command(
+            request={
+                "environment": self.get_environment_name(project_id, region, environment_id),
+                "command": command,
+                "subcommand": subcommand,
+                "parameters": parameters,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def poll_airflow_command(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        execution_id: str,
+        pod: str,
+        pod_namespace: str,
+        next_line_number: int,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> PollAirflowCommandResponse:
+        """
+        Poll Airflow command execution result for provided Composer environment.
+
+        :param project_id: The ID of the Google Cloud project that the service belongs to.
+        :param region: The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: The ID of the Google Cloud environment that the service belongs to.
+        :param execution_id: The unique ID of the command execution.
+        :param pod: The name of the pod where the command is executed.
+        :param pod_namespace: The namespace of the pod where the command is executed.
+        :param next_line_number: Line number from which new logs should be fetched.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_environment_client()
+        result = client.poll_airflow_command(
+            request={
+                "environment": self.get_environment_name(project_id, region, environment_id),
+                "execution_id": execution_id,
+                "pod": pod,
+                "pod_namespace": pod_namespace,
+                "next_line_number": next_line_number,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    def wait_command_execution_result(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        execution_cmd_info: dict,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        poll_interval: int = 10,
+    ) -> dict:
+        while True:
+            try:
+                result = self.poll_airflow_command(
+                    project_id=project_id,
+                    region=region,
+                    environment_id=environment_id,
+                    execution_id=execution_cmd_info["execution_id"],
+                    pod=execution_cmd_info["pod"],
+                    pod_namespace=execution_cmd_info["pod_namespace"],
+                    next_line_number=1,
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata,
+                )
+            except Exception as ex:
+                self.log.exception("Exception occurred while polling CMD result")
+                raise AirflowException(ex)
+
+            result_dict = PollAirflowCommandResponse.to_dict(result)
+            if result_dict["output_end"]:
+                return result_dict
+
+            self.log.info("Waiting for result...")
+            time.sleep(poll_interval)
+
 
 class CloudComposerAsyncHook(GoogleBaseHook):
     """Hook for Google Cloud Composer async APIs."""
@@ -421,3 +548,124 @@ class CloudComposerAsyncHook(GoogleBaseHook):
             timeout=timeout,
             metadata=metadata,
         )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def execute_airflow_command(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        command: str,
+        subcommand: str,
+        parameters: MutableSequence[str],
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        """
+        Execute Airflow command for provided Composer environment.
+
+        :param project_id: The ID of the Google Cloud project that the service belongs to.
+        :param region: The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: The ID of the Google Cloud environment that the service belongs to.
+        :param command: Airflow command.
+        :param subcommand: Airflow subcommand.
+        :param parameters: Parameters for the Airflow command/subcommand as an array of arguments. It may
+            contain positional arguments like ``["my-dag-id"]``, key-value parameters like ``["--foo=bar"]``
+            or ``["--foo","bar"]``, or other flags like ``["-f"]``.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_environment_client()
+
+        return await client.execute_airflow_command(
+            request={
+                "environment": self.get_environment_name(project_id, region, environment_id),
+                "command": command,
+                "subcommand": subcommand,
+                "parameters": parameters,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def poll_airflow_command(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        execution_id: str,
+        pod: str,
+        pod_namespace: str,
+        next_line_number: int,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        """
+        Poll Airflow command execution result for provided Composer environment.
+
+        :param project_id: The ID of the Google Cloud project that the service belongs to.
+        :param region: The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: The ID of the Google Cloud environment that the service belongs to.
+        :param execution_id: The unique ID of the command execution.
+        :param pod: The name of the pod where the command is executed.
+        :param pod_namespace: The namespace of the pod where the command is executed.
+        :param next_line_number: Line number from which new logs should be fetched.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_environment_client()
+
+        return await client.poll_airflow_command(
+            request={
+                "environment": self.get_environment_name(project_id, region, environment_id),
+                "execution_id": execution_id,
+                "pod": pod,
+                "pod_namespace": pod_namespace,
+                "next_line_number": next_line_number,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    async def wait_command_execution_result(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        execution_cmd_info: dict,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        poll_interval: int = 10,
+    ) -> dict:
+        while True:
+            try:
+                result = await self.poll_airflow_command(
+                    project_id=project_id,
+                    region=region,
+                    environment_id=environment_id,
+                    execution_id=execution_cmd_info["execution_id"],
+                    pod=execution_cmd_info["pod"],
+                    pod_namespace=execution_cmd_info["pod_namespace"],
+                    next_line_number=1,
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata,
+                )
+            except Exception as ex:
+                self.log.exception("Exception occurred while polling CMD result")
+                raise AirflowException(ex)
+
+            result_dict = PollAirflowCommandResponse.to_dict(result)
+            if result_dict["output_end"]:
+                return result_dict
+
+            self.log.info("Sleeping for %s seconds.", poll_interval)
+            await asyncio.sleep(poll_interval)
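CloudComposerAsyncHook mirrors the same three calls for use from triggers. A sketch of the asynchronous flow, under the same placeholder assumptions as the synchronous example above:

import asyncio

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse


async def run_airflow_cli_command() -> dict:
    hook = CloudComposerAsyncHook(gcp_conn_id="google_cloud_default")
    execution = await hook.execute_airflow_command(
        project_id="example-project",          # placeholder
        region="us-central1",                  # placeholder
        environment_id="example-environment",  # placeholder
        command="dags",
        subcommand="list",
        parameters=["--output", "json"],
    )
    # Assumes the response converts to the execution_id / pod / pod_namespace dict
    # that wait_command_execution_result() reads.
    return await hook.wait_command_execution_result(
        project_id="example-project",
        region="us-central1",
        environment_id="example-environment",
        execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(execution),
    )


if __name__ == "__main__":
    print(asyncio.run(run_airflow_cli_command()))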