apache-airflow-providers-google 17.1.0rc1__py3-none-any.whl → 17.2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/cloud/hooks/bigquery.py +0 -8
  3. airflow/providers/google/cloud/hooks/cloud_composer.py +6 -1
  4. airflow/providers/google/cloud/hooks/cloud_sql.py +3 -3
  5. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +10 -5
  6. airflow/providers/google/cloud/hooks/dataflow.py +1 -1
  7. airflow/providers/google/cloud/hooks/spanner.py +26 -6
  8. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +3 -0
  9. airflow/providers/google/cloud/openlineage/utils.py +14 -0
  10. airflow/providers/google/cloud/operators/bigquery.py +9 -1
  11. airflow/providers/google/cloud/operators/cloud_composer.py +9 -3
  12. airflow/providers/google/cloud/operators/dataplex.py +12 -12
  13. airflow/providers/google/cloud/operators/dataproc.py +15 -8
  14. airflow/providers/google/cloud/operators/pubsub.py +55 -8
  15. airflow/providers/google/cloud/operators/spanner.py +3 -2
  16. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +9 -2
  17. airflow/providers/google/cloud/sensors/cloud_composer.py +30 -0
  18. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +72 -1
  19. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +7 -3
  20. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +0 -6
  21. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +8 -2
  22. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +27 -2
  23. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
  24. airflow/providers/google/cloud/triggers/bigquery.py +23 -19
  25. airflow/providers/google/cloud/triggers/cloud_composer.py +48 -10
  26. airflow/providers/google/cloud/triggers/dataplex.py +14 -2
  27. airflow/providers/google/cloud/triggers/dataproc.py +63 -46
  28. airflow/providers/google/common/utils/get_secret.py +31 -0
  29. airflow/providers/google/suite/hooks/sheets.py +15 -1
  30. airflow/providers/google/suite/operators/sheets.py +5 -0
  31. airflow/providers/google/version_compat.py +6 -0
  32. {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/METADATA +9 -10
  33. {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/RECORD +35 -34
  34. {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/WHEEL +0 -0
  35. {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

  __all__ = ["__version__"]

- __version__ = "17.1.0"
+ __version__ = "17.2.0"

  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
  "2.10.0"
airflow/providers/google/cloud/hooks/bigquery.py
@@ -1743,14 +1743,6 @@ class BigQueryCursor(BigQueryBaseCursor):
  f" Please only use one or more of the following options: {allowed_schema_update_options}"
  )

- if schema_update_options:
- if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
- raise ValueError(
- "schema_update_options is only "
- "allowed if write_disposition is "
- "'WRITE_APPEND' or 'WRITE_TRUNCATE'."
- )
-
  if destination_dataset_table:
  destination_project, destination_dataset, destination_table = self.hook.split_tablename(
  table_input=destination_dataset_table, default_project_id=self.project_id
airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -642,7 +642,12 @@ class CloudComposerAsyncHook(GoogleBaseHook):
  self.log.exception("Exception occurred while polling CMD result")
  raise AirflowException(ex)

- result_dict = PollAirflowCommandResponse.to_dict(result)
+ try:
+ result_dict = PollAirflowCommandResponse.to_dict(result)
+ except Exception as ex:
+ self.log.exception("Exception occurred while transforming PollAirflowCommandResponse")
+ raise AirflowException(ex)
+
  if result_dict["output_end"]:
  return result_dict

airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -1175,9 +1175,9 @@ class CloudSQLDatabaseHook(BaseHook):
  raise ValueError("The db_hook should be set")
  if not isinstance(self.db_hook, PostgresHook):
  raise ValueError(f"The db_hook should be PostgresHook and is {type(self.db_hook)}")
- conn = getattr(self.db_hook, "conn")
- if conn and conn.notices:
- for output in self.db_hook.conn.notices:
+ conn = getattr(self.db_hook, "conn", None)
+ if conn and hasattr(conn, "notices") and conn.notices:
+ for output in conn.notices:
  self.log.info(output)

  def reserve_free_tcp_port(self) -> None:
airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -36,13 +36,13 @@ from copy import deepcopy
  from datetime import timedelta
  from typing import TYPE_CHECKING, Any

- from google.api_core import protobuf_helpers
  from google.cloud.storage_transfer_v1 import (
  ListTransferJobsRequest,
  StorageTransferServiceAsyncClient,
  TransferJob,
  TransferOperation,
  )
+ from google.protobuf.json_format import MessageToDict
  from googleapiclient.discovery import Resource, build
  from googleapiclient.errors import HttpError

@@ -603,7 +603,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
  self,
  request_filter: dict | None = None,
  **kwargs,
- ) -> list[TransferOperation]:
+ ) -> list[dict[str, Any]]:
  """
  Get a transfer operation in Google Storage Transfer Service.

@@ -660,7 +660,12 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
  )

  transfer_operations = [
- protobuf_helpers.from_any_pb(TransferOperation, op.metadata) for op in operations
+ MessageToDict(
+ getattr(op, "_pb", op),
+ preserving_proto_field_name=True,
+ use_integers_for_enums=True,
+ )
+ for op in operations
  ]

  return transfer_operations
@@ -677,7 +682,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):

  @staticmethod
  async def operations_contain_expected_statuses(
- operations: list[TransferOperation], expected_statuses: set[str] | str
+ operations: list[dict[str, Any]], expected_statuses: set[str] | str
  ) -> bool:
  """
  Check whether an operation exists with the expected status.
@@ -696,7 +701,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
  if not operations:
  return False

- current_statuses = {operation.status.name for operation in operations}
+ current_statuses = {TransferOperation.Status(op["metadata"]["status"]).name for op in operations}

  if len(current_statuses - expected_statuses_set) != len(current_statuses):
  return True
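Note: after this change the async hook returns transfer operations as plain dicts (built with MessageToDict and integer enums) instead of TransferOperation messages, so status checks read the nested metadata.status field. A minimal illustrative sketch of that check, using a hand-built operation dict rather than real API output:

    from google.cloud.storage_transfer_v1 import TransferOperation

    # A dict shaped like the MessageToDict(..., use_integers_for_enums=True) output above.
    operations = [{"metadata": {"status": int(TransferOperation.Status.SUCCESS)}}]
    expected_statuses = {"SUCCESS"}

    # Same conversion the hook now performs: integer enum value -> enum name string.
    current_statuses = {TransferOperation.Status(op["metadata"]["status"]).name for op in operations}
    print(current_statuses & expected_statuses)  # {'SUCCESS'}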
airflow/providers/google/cloud/hooks/dataflow.py
@@ -56,8 +56,8 @@ from airflow.providers.google.common.hooks.base_google import (
  GoogleBaseAsyncHook,
  GoogleBaseHook,
  )
+ from airflow.providers.google.version_compat import timeout
  from airflow.utils.log.logging_mixin import LoggingMixin
- from airflow.utils.timeout import timeout

  if TYPE_CHECKING:
  from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager
airflow/providers/google/cloud/hooks/spanner.py
@@ -19,6 +19,7 @@

  from __future__ import annotations

+ from collections import OrderedDict
  from collections.abc import Callable, Sequence
  from typing import TYPE_CHECKING, NamedTuple

@@ -388,7 +389,7 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
  database_id: str,
  queries: list[str],
  project_id: str,
- ) -> None:
+ ) -> list[int]:
  """
  Execute an arbitrary DML query (INSERT, UPDATE, DELETE).

@@ -398,12 +399,31 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
  :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
  database. If set to None or missing, the default project_id from the Google Cloud connection
  is used.
+ :return: list of numbers of affected rows by DML query
  """
- self._get_client(project_id=project_id).instance(instance_id=instance_id).database(
- database_id=database_id
- ).run_in_transaction(lambda transaction: self._execute_sql_in_transaction(transaction, queries))
+ db = (
+ self._get_client(project_id=project_id)
+ .instance(instance_id=instance_id)
+ .database(database_id=database_id)
+ )
+
+ def _tx_runner(tx: Transaction) -> dict[str, int]:
+ return self._execute_sql_in_transaction(tx, queries)
+
+ result = db.run_in_transaction(_tx_runner)
+
+ result_rows_count_per_query = []
+ for i, (sql, rc) in enumerate(result.items(), start=1):
+ if not sql.startswith("SELECT"):
+ preview = sql if len(sql) <= 300 else sql[:300] + "…"
+ self.log.info("[DML %d/%d] affected rows=%d | %s", i, len(result), rc, preview)
+ result_rows_count_per_query.append(rc)
+ return result_rows_count_per_query

  @staticmethod
- def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]):
+ def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) -> dict[str, int]:
+ counts: OrderedDict[str, int] = OrderedDict()
  for sql in queries:
- transaction.execute_update(sql)
+ rc = transaction.execute_update(sql)
+ counts[sql] = rc
+ return counts
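Note: SpannerHook.execute_dml now returns the affected-row count for each DML statement, in execution order, instead of None. A minimal usage sketch; the connection id, instance, database and project names below are placeholders, not taken from the diff:

    from airflow.providers.google.cloud.hooks.spanner import SpannerHook

    hook = SpannerHook(gcp_conn_id="google_cloud_default")  # placeholder connection id
    row_counts = hook.execute_dml(
        instance_id="my-instance",    # placeholder
        database_id="my-database",    # placeholder
        queries=["UPDATE users SET active = TRUE WHERE last_login IS NOT NULL"],
        project_id="my-project",      # placeholder
    )
    # One integer per statement, e.g. [42]
    print(row_counts)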
airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -350,6 +350,9 @@ class GenerativeModelHook(GoogleBaseHook):
  :param generation_config: Optional. Generation configuration settings.
  :param safety_settings: Optional. Per request settings for blocking unsafe content.
  """
+ # During run of the system test it was found out that names from xcom, e.g. 3402922389 can be
+ # treated as int and throw an error TypeError: expected string or bytes-like object, got 'int'
+ cached_content_name = str(cached_content_name)
  vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

  cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)
airflow/providers/google/cloud/openlineage/utils.py
@@ -214,7 +214,20 @@ def extract_ds_name_from_gcs_path(path: str) -> str:

  def get_facets_from_bq_table(table: Table) -> dict[str, DatasetFacet]:
  """Get facets from BigQuery table object."""
+ return get_facets_from_bq_table_for_given_fields(table, selected_fields=None)
+
+
+ def get_facets_from_bq_table_for_given_fields(
+ table: Table, selected_fields: list[str] | None
+ ) -> dict[str, DatasetFacet]:
+ """
+ Get facets from BigQuery table object for selected fields only.
+
+ If selected_fields is None, include all fields.
+ """
  facets: dict[str, DatasetFacet] = {}
+ selected_fields_set = set(selected_fields) if selected_fields else None
+
  if table.schema:
  facets["schema"] = SchemaDatasetFacet(
  fields=[
@@ -222,6 +235,7 @@ def get_facets_from_bq_table(table: Table) -> dict[str, DatasetFacet]:
  name=schema_field.name, type=schema_field.field_type, description=schema_field.description
  )
  for schema_field in table.schema
+ if selected_fields_set is None or schema_field.name in selected_fields_set
  ]
  )
  if table.description:
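Note: get_facets_from_bq_table now delegates to the new get_facets_from_bq_table_for_given_fields, which narrows the schema facet to the selected columns (None keeps every field). A hedged usage sketch; the table id and column names are placeholders and the client is assumed to use application default credentials:

    from google.cloud import bigquery

    from airflow.providers.google.cloud.openlineage.utils import (
        get_facets_from_bq_table_for_given_fields,
    )

    client = bigquery.Client()  # assumes application default credentials
    table = client.get_table("my-project.my_dataset.my_table")  # placeholder table id

    # Only "id" and "name" appear in the schema facet; other columns are filtered out.
    facets = get_facets_from_bq_table_for_given_fields(table, selected_fields=["id", "name"])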
airflow/providers/google/cloud/operators/bigquery.py
@@ -2370,11 +2370,19 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryInsertJobOpera
  if self.project_id is None:
  self.project_id = hook.project_id

+ # Handle missing logical_date. Example: asset-triggered DAGs (Airflow 3)
+ logical_date = context.get("logical_date")
+ if logical_date is None:
+ # Use dag_run.run_after as fallback when logical_date is not available
+ dag_run = context.get("dag_run")
+ if dag_run and hasattr(dag_run, "run_after"):
+ logical_date = dag_run.run_after
+
  self.job_id = hook.generate_job_id(
  job_id=self.job_id,
  dag_id=self.dag_id,
  task_id=self.task_id,
- logical_date=context["logical_date"],
+ logical_date=logical_date,
  configuration=self.configuration,
  force_rerun=self.force_rerun,
  )
airflow/providers/google/cloud/operators/cloud_composer.py
@@ -764,9 +764,15 @@ class CloudComposerRunAirflowCLICommandOperator(GoogleCloudBaseOperator):
  metadata=self.metadata,
  poll_interval=self.poll_interval,
  )
- result_str = self._merge_cmd_output_result(result)
- self.log.info("Command execution result:\n%s", result_str)
- return result
+ exit_code = result.get("exit_info", {}).get("exit_code")
+ if exit_code == 0:
+ result_str = self._merge_cmd_output_result(result)
+ self.log.info("Command execution result:\n%s", result_str)
+ return result
+
+ error_output = "".join(line["content"] for line in result.get("error", []))
+ message = f"Airflow CLI command failed with exit code {exit_code}.\nError output:\n{error_output}"
+ raise AirflowException(message)

  def execute_complete(self, context: Context, event: dict) -> dict:
  if event and event["status"] == "error":
airflow/providers/google/cloud/operators/dataplex.py
@@ -1082,11 +1082,11 @@ class DataplexRunDataQualityScanOperator(GoogleCloudBaseOperator):
  """
  job_state = event["job_state"]
  job_id = event["job_id"]
- if job_state == DataScanJob.State.FAILED:
+ if job_state == DataScanJob.State.FAILED.name: # type: ignore
  raise AirflowException(f"Job failed:\n{job_id}")
- if job_state == DataScanJob.State.CANCELLED:
+ if job_state == DataScanJob.State.CANCELLED.name: # type: ignore
  raise AirflowException(f"Job was cancelled:\n{job_id}")
- if job_state == DataScanJob.State.SUCCEEDED:
+ if job_state == DataScanJob.State.SUCCEEDED.name: # type: ignore
  job = event["job"]
  if not job["data_quality_result"]["passed"]:
  if self.fail_on_dq_failure:
@@ -1260,11 +1260,11 @@ class DataplexGetDataQualityScanResultOperator(GoogleCloudBaseOperator):
  job_state = event["job_state"]
  job_id = event["job_id"]
  job = event["job"]
- if job_state == DataScanJob.State.FAILED:
+ if job_state == DataScanJob.State.FAILED.name: # type: ignore
  raise AirflowException(f"Job failed:\n{job_id}")
- if job_state == DataScanJob.State.CANCELLED:
+ if job_state == DataScanJob.State.CANCELLED.name: # type: ignore
  raise AirflowException(f"Job was cancelled:\n{job_id}")
- if job_state == DataScanJob.State.SUCCEEDED:
+ if job_state == DataScanJob.State.SUCCEEDED.name: # type: ignore
  if not job["data_quality_result"]["passed"]:
  if self.fail_on_dq_failure:
  raise AirflowDataQualityScanException(
@@ -1639,12 +1639,12 @@ class DataplexRunDataProfileScanOperator(GoogleCloudBaseOperator):
  result_timeout=self.result_timeout,
  )

- if job.state == DataScanJob.State.FAILED:
+ if job.state == DataScanJob.State.FAILED.name: # type: ignore
  raise AirflowException(f"Data Profile job failed: {job_id}")
- if job.state == DataScanJob.State.SUCCEEDED:
+ if job.state == DataScanJob.State.SUCCEEDED.name: # type: ignore
  self.log.info("Data Profile job executed successfully.")
  else:
- self.log.info("Data Profile job execution returned status: %s", job.status)
+ self.log.info("Data Profile job execution returned status: %s", job.state)

  return job_id

@@ -1657,11 +1657,11 @@ class DataplexRunDataProfileScanOperator(GoogleCloudBaseOperator):
  """
  job_state = event["job_state"]
  job_id = event["job_id"]
- if job_state == DataScanJob.State.FAILED:
+ if job_state == DataScanJob.State.FAILED.name: # type: ignore
  raise AirflowException(f"Job failed:\n{job_id}")
- if job_state == DataScanJob.State.CANCELLED:
+ if job_state == DataScanJob.State.CANCELLED.name: # type: ignore
  raise AirflowException(f"Job was cancelled:\n{job_id}")
- if job_state == DataScanJob.State.SUCCEEDED:
+ if job_state == DataScanJob.State.SUCCEEDED.name: # type: ignore
  self.log.info("Data Profile job executed successfully.")
  return job_id
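Note: these handlers now compare job_state against the enum's string name rather than the enum member itself, presumably because the deferrable path (see the triggers/dataplex.py change listed above) hands the state back as a plain string. A trivial illustration of why both sides need to be strings:

    from google.cloud.dataplex_v1 import DataScanJob

    # Enum member vs. its serialized name: only the .name form matches a string payload.
    assert DataScanJob.State.SUCCEEDED.name == "SUCCEEDED"
    assert DataScanJob.State.SUCCEEDED != "SUCCEEDED"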
airflow/providers/google/cloud/operators/dataproc.py
@@ -213,6 +213,7 @@ class ClusterGenerator:
  :param secondary_worker_accelerator_type: Type of the accelerator card (GPU) to attach to the secondary workers,
  see https://cloud.google.com/dataproc/docs/reference/rest/v1/InstanceGroupConfig#acceleratorconfig
  :param secondary_worker_accelerator_count: Number of accelerator cards (GPUs) to attach to the secondary workers
+ :param cluster_tier: The tier of the cluster (e.g. "CLUSTER_TIER_STANDARD" / "CLUSTER_TIER_PREMIUM").
  """

  def __init__(
@@ -261,6 +262,8 @@ class ClusterGenerator:
  secondary_worker_instance_flexibility_policy: InstanceFlexibilityPolicy | None = None,
  secondary_worker_accelerator_type: str | None = None,
  secondary_worker_accelerator_count: int | None = None,
+ *,
+ cluster_tier: str | None = None,
  **kwargs,
  ) -> None:
  self.project_id = project_id
@@ -308,6 +311,7 @@ class ClusterGenerator:
  self.secondary_worker_instance_flexibility_policy = secondary_worker_instance_flexibility_policy
  self.secondary_worker_accelerator_type = secondary_worker_accelerator_type
  self.secondary_worker_accelerator_count = secondary_worker_accelerator_count
+ self.cluster_tier = cluster_tier

  if self.custom_image and self.image_version:
  raise ValueError("The custom_image and image_version can't be both set")
@@ -513,6 +517,9 @@ class ClusterGenerator:
  if self.driver_pool_size > 0:
  cluster_data["auxiliary_node_groups"] = [self._build_driver_pool()]

+ if self.cluster_tier:
+ cluster_data["cluster_tier"] = self.cluster_tier
+
  cluster_data = self._build_gce_cluster_config(cluster_data)

  if self.single_node:
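Note: cluster_tier is keyword-only and, when set, is copied verbatim into the generated cluster config. A hedged usage sketch; the project id, machine types and worker count are placeholders, not taken from the diff:

    from airflow.providers.google.cloud.operators.dataproc import ClusterGenerator

    cluster_config = ClusterGenerator(
        project_id="my-project",               # placeholder
        num_workers=2,
        master_machine_type="n1-standard-4",
        worker_machine_type="n1-standard-4",
        cluster_tier="CLUSTER_TIER_STANDARD",  # new keyword-only argument
    ).make()

    # The generated config carries the tier through to DataprocCreateClusterOperator.
    assert cluster_config["cluster_tier"] == "CLUSTER_TIER_STANDARD"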
@@ -1945,9 +1952,9 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
  job_state = event["job_state"]
  job_id = event["job_id"]
  job = event["job"]
- if job_state == JobStatus.State.ERROR:
+ if job_state == JobStatus.State.ERROR.name: # type: ignore
  raise AirflowException(f"Job {job_id} failed:\n{job}")
- if job_state == JobStatus.State.CANCELLED:
+ if job_state == JobStatus.State.CANCELLED.name: # type: ignore
  raise AirflowException(f"Job {job_id} was cancelled:\n{job}")
  self.log.info("%s completed successfully.", self.task_id)
  return job_id
@@ -2455,7 +2462,7 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
  if not self.hook.check_error_for_resource_is_not_ready_msg(batch.state_message):
  break

- self.handle_batch_status(context, batch.state, batch_id, batch.state_message)
+ self.handle_batch_status(context, batch.state.name, batch_id, batch.state_message)
  return Batch.to_dict(batch)

  @cached_property
@@ -2480,19 +2487,19 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
  self.operation.cancel()

  def handle_batch_status(
- self, context: Context, state: Batch.State, batch_id: str, state_message: str | None = None
+ self, context: Context, state: str, batch_id: str, state_message: str | None = None
  ) -> None:
  # The existing batch may be a number of states other than 'SUCCEEDED'\
  # wait_for_operation doesn't fail if the job is cancelled, so we will check for it here which also
  # finds a cancelling|canceled|unspecified job from wait_for_batch or the deferred trigger
  link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
- if state == Batch.State.FAILED:
+ if state == Batch.State.FAILED.name: # type: ignore
  raise AirflowException(
  f"Batch job {batch_id} failed with error: {state_message}.\nDriver logs: {link}"
  )
- if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
+ if state in (Batch.State.CANCELLED.name, Batch.State.CANCELLING.name): # type: ignore
  raise AirflowException(f"Batch job {batch_id} was cancelled.\nDriver logs: {link}")
- if state == Batch.State.STATE_UNSPECIFIED:
+ if state == Batch.State.STATE_UNSPECIFIED.name: # type: ignore
  raise AirflowException(f"Batch job {batch_id} unspecified.\nDriver logs: {link}")
  self.log.info("Batch job %s completed.\nDriver logs: %s", batch_id, link)

@@ -2566,7 +2573,7 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
  dag_id = re.sub(r"[.\s]", "_", self.dag_id.lower())
  task_id = re.sub(r"[.\s]", "_", self.task_id.lower())

- labels_regex = re.compile(r"^[a-z][\w-]{0,63}$")
+ labels_regex = re.compile(r"^[a-z][\w-]{0,62}$")
  if not labels_regex.match(dag_id) or not labels_regex.match(task_id):
  return
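Note: tightening the quantifier from {0,63} to {0,62} caps the sanitized dag_id/task_id label values at 63 characters in total (one leading lowercase letter plus at most 62 more), in line with the 63-character limit on Google Cloud label values. A quick illustrative check:

    import re

    labels_regex = re.compile(r"^[a-z][\w-]{0,62}$")
    print(bool(labels_regex.match("a" * 63)))  # True  - exactly at the 63-character limit
    print(bool(labels_regex.match("a" * 64)))  # False - previously accepted, now rejected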
airflow/providers/google/cloud/operators/pubsub.py
@@ -26,6 +26,7 @@ This module contains Google PubSub operators.
  from __future__ import annotations

  from collections.abc import Callable, Sequence
+ from functools import cached_property
  from typing import TYPE_CHECKING, Any

  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -52,6 +53,7 @@ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
  if TYPE_CHECKING:
  from google.api_core.retry import Retry

+ from airflow.providers.openlineage.extractors import OperatorLineage
  from airflow.utils.context import Context


@@ -359,15 +361,18 @@ class PubSubCreateSubscriptionOperator(GoogleCloudBaseOperator):
  self.timeout = timeout
  self.metadata = metadata
  self.impersonation_chain = impersonation_chain
+ self._resolved_subscription_name: str | None = None

- def execute(self, context: Context) -> str:
- hook = PubSubHook(
+ @cached_property
+ def pubsub_hook(self):
+ return PubSubHook(
  gcp_conn_id=self.gcp_conn_id,
  impersonation_chain=self.impersonation_chain,
  )

+ def execute(self, context: Context) -> str:
  self.log.info("Creating subscription for topic %s", self.topic)
- result = hook.create_subscription(
+ result = self.pubsub_hook.create_subscription(
  project_id=self.project_id,
  topic=self.topic,
  subscription=self.subscription,
@@ -389,13 +394,34 @@ class PubSubCreateSubscriptionOperator(GoogleCloudBaseOperator):
  )

  self.log.info("Created subscription for topic %s", self.topic)
+
+ # Store resolved subscription for Open Lineage
+ self._resolved_subscription_name = self.subscription or result
+
  PubSubSubscriptionLink.persist(
  context=context,
- subscription_id=self.subscription or result, # result returns subscription name
- project_id=self.project_id or hook.project_id,
+ subscription_id=self._resolved_subscription_name, # result returns subscription name
+ project_id=self.project_id or self.pubsub_hook.project_id,
  )
  return result

+ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+ from airflow.providers.common.compat.openlineage.facet import Dataset
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ topic_project_id = self.project_id or self.pubsub_hook.project_id
+ subscription_project_id = self.subscription_project_id or topic_project_id
+
+ return OperatorLineage(
+ inputs=[Dataset(namespace="pubsub", name=f"topic:{topic_project_id}:{self.topic}")],
+ outputs=[
+ Dataset(
+ namespace="pubsub",
+ name=f"subscription:{subscription_project_id}:{self._resolved_subscription_name}",
+ )
+ ],
+ )
+

  class PubSubDeleteTopicOperator(GoogleCloudBaseOperator):
  """
@@ -692,17 +718,28 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
  self.enable_message_ordering = enable_message_ordering
  self.impersonation_chain = impersonation_chain

- def execute(self, context: Context) -> None:
- hook = PubSubHook(
+ @cached_property
+ def pubsub_hook(self):
+ return PubSubHook(
  gcp_conn_id=self.gcp_conn_id,
  impersonation_chain=self.impersonation_chain,
  enable_message_ordering=self.enable_message_ordering,
  )

+ def execute(self, context: Context) -> None:
  self.log.info("Publishing to topic %s", self.topic)
- hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages)
+ self.pubsub_hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages)
  self.log.info("Published to topic %s", self.topic)

+ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+ from airflow.providers.common.compat.openlineage.facet import Dataset
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ project_id = self.project_id or self.pubsub_hook.project_id
+ output_dataset = [Dataset(namespace="pubsub", name=f"topic:{project_id}:{self.topic}")]
+
+ return OperatorLineage(outputs=output_dataset)
+

  class PubSubPullOperator(GoogleCloudBaseOperator):
  """
@@ -853,3 +890,13 @@ class PubSubPullOperator(GoogleCloudBaseOperator):
  messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]

  return messages_json
+
+ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+ from airflow.providers.common.compat.openlineage.facet import Dataset
+ from airflow.providers.openlineage.extractors import OperatorLineage
+
+ output_dataset = [
+ Dataset(namespace="pubsub", name=f"subscription:{self.project_id}:{self.subscription}")
+ ]
+
+ return OperatorLineage(outputs=output_dataset)
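Note: the PubSub create-subscription, publish and pull operators now report OpenLineage datasets in a "pubsub" namespace, named topic:<project>:<topic> or subscription:<project>:<subscription>. A minimal sketch of the lineage object the publish operator would emit, with placeholder ids:

    from airflow.providers.common.compat.openlineage.facet import Dataset
    from airflow.providers.openlineage.extractors import OperatorLineage

    # Equivalent of what PubSubPublishMessageOperator.get_openlineage_facets_on_complete
    # returns for project "my-project" and topic "my-topic" (placeholders).
    lineage = OperatorLineage(
        outputs=[Dataset(namespace="pubsub", name="topic:my-project:my-topic")]
    )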
airflow/providers/google/cloud/operators/spanner.py
@@ -280,8 +280,8 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
  self.instance_id,
  self.database_id,
  )
- self.log.info(queries)
- hook.execute_dml(
+ self.log.info("Executing queries: %s", queries)
+ result_rows_count_per_query = hook.execute_dml(
  project_id=self.project_id,
  instance_id=self.instance_id,
  database_id=self.database_id,
@@ -293,6 +293,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
  database_id=self.database_id,
  project_id=self.project_id or hook.project_id,
  )
+ return result_rows_count_per_query

  @staticmethod
  def sanitize_queries(queries: list[str]) -> None:
airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -58,7 +58,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
  account from the list granting this role to the originating account (templated).
  """

- template_fields = ("location", "project_id", "impersonation_chain", "prompt")
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt", "pretrained_model")

  def __init__(
  self,
@@ -211,7 +211,14 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
  account from the list granting this role to the originating account (templated).
  """

- template_fields = ("location", "project_id", "impersonation_chain", "train_dataset", "validation_dataset")
+ template_fields = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ "train_dataset",
+ "validation_dataset",
+ "source_model",
+ )

  def __init__(
  self,
airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -61,6 +61,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
  Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for states from specific time in the
  past till current time execution.
  Default value datetime.timedelta(days=1).
+ :param composer_dag_run_id: The Run ID of executable task. The 'execution_range' param is ignored, if both specified.
  :param gcp_conn_id: The connection ID to use when fetching connection info.
  :param impersonation_chain: Optional service account to impersonate using short-term
  credentials, or chained list of accounts required to get the access_token
@@ -91,6 +92,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
  composer_dag_id: str,
  allowed_states: Iterable[str] | None = None,
  execution_range: timedelta | list[datetime] | None = None,
+ composer_dag_run_id: str | None = None,
  gcp_conn_id: str = "google_cloud_default",
  impersonation_chain: str | Sequence[str] | None = None,
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
@@ -104,11 +106,17 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
  self.composer_dag_id = composer_dag_id
  self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
  self.execution_range = execution_range
+ self.composer_dag_run_id = composer_dag_run_id
  self.gcp_conn_id = gcp_conn_id
  self.impersonation_chain = impersonation_chain
  self.deferrable = deferrable
  self.poll_interval = poll_interval

+ if self.composer_dag_run_id and self.execution_range:
+ self.log.warning(
+ "The composer_dag_run_id parameter and execution_range parameter do not work together. This run will ignore execution_range parameter and count only specified composer_dag_run_id parameter."
+ )
+
  def _get_logical_dates(self, context) -> tuple[datetime, datetime]:
  if isinstance(self.execution_range, timedelta):
  if self.execution_range < timedelta(0):
@@ -128,6 +136,20 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):

  dag_runs = self._pull_dag_runs()

+ if len(dag_runs) == 0:
+ self.log.info("Dag runs are empty. Sensor waits for dag runs...")
+ return False
+
+ if self.composer_dag_run_id:
+ self.log.info(
+ "Sensor waits for allowed states %s for specified RunID: %s",
+ self.allowed_states,
+ self.composer_dag_run_id,
+ )
+ composer_dag_run_id_status = self._check_composer_dag_run_id_states(
+ dag_runs=dag_runs,
+ )
+ return composer_dag_run_id_status
  self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
  allowed_states_status = self._check_dag_runs_states(
  dag_runs=dag_runs,
@@ -189,16 +211,24 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
  image_version = environment_config["config"]["software_config"]["image_version"]
  return int(image_version.split("airflow-")[1].split(".")[0])

+ def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
+ for dag_run in dag_runs:
+ if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
+ return True
+ return False
+
  def execute(self, context: Context) -> None:
  self._composer_airflow_version = self._get_composer_airflow_version()
  if self.deferrable:
  start_date, end_date = self._get_logical_dates(context)
  self.defer(
+ timeout=self.timeout,
  trigger=CloudComposerDAGRunTrigger(
  project_id=self.project_id,
  region=self.region,
  environment_id=self.environment_id,
  composer_dag_id=self.composer_dag_id,
+ composer_dag_run_id=self.composer_dag_run_id,
  start_date=start_date,
  end_date=end_date,
  allowed_states=self.allowed_states,