apache-airflow-providers-google 17.1.0rc1__py3-none-any.whl → 17.2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +0 -8
- airflow/providers/google/cloud/hooks/cloud_composer.py +6 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +3 -3
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +10 -5
- airflow/providers/google/cloud/hooks/dataflow.py +1 -1
- airflow/providers/google/cloud/hooks/spanner.py +26 -6
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +3 -0
- airflow/providers/google/cloud/openlineage/utils.py +14 -0
- airflow/providers/google/cloud/operators/bigquery.py +9 -1
- airflow/providers/google/cloud/operators/cloud_composer.py +9 -3
- airflow/providers/google/cloud/operators/dataplex.py +12 -12
- airflow/providers/google/cloud/operators/dataproc.py +15 -8
- airflow/providers/google/cloud/operators/pubsub.py +55 -8
- airflow/providers/google/cloud/operators/spanner.py +3 -2
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +9 -2
- airflow/providers/google/cloud/sensors/cloud_composer.py +30 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +72 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +7 -3
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +0 -6
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +8 -2
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +27 -2
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/triggers/bigquery.py +23 -19
- airflow/providers/google/cloud/triggers/cloud_composer.py +48 -10
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +63 -46
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +5 -0
- airflow/providers/google/version_compat.py +6 -0
- {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/METADATA +9 -10
- {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/RECORD +35 -34
- {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-17.1.0rc1.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/entry_points.txt +0 -0
--- a/airflow/providers/google/__init__.py
+++ b/airflow/providers/google/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "17.1.0"
+__version__ = "17.2.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -1743,14 +1743,6 @@ class BigQueryCursor(BigQueryBaseCursor):
                 f" Please only use one or more of the following options: {allowed_schema_update_options}"
             )

-        if schema_update_options:
-            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
-                raise ValueError(
-                    "schema_update_options is only "
-                    "allowed if write_disposition is "
-                    "'WRITE_APPEND' or 'WRITE_TRUNCATE'."
-                )
-
         if destination_dataset_table:
             destination_project, destination_dataset, destination_table = self.hook.split_tablename(
                 table_input=destination_dataset_table, default_project_id=self.project_id
--- a/airflow/providers/google/cloud/hooks/cloud_composer.py
+++ b/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -642,7 +642,12 @@ class CloudComposerAsyncHook(GoogleBaseHook):
                 self.log.exception("Exception occurred while polling CMD result")
                 raise AirflowException(ex)

-
+            try:
+                result_dict = PollAirflowCommandResponse.to_dict(result)
+            except Exception as ex:
+                self.log.exception("Exception occurred while transforming PollAirflowCommandResponse")
+                raise AirflowException(ex)
+
             if result_dict["output_end"]:
                 return result_dict

--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -1175,9 +1175,9 @@ class CloudSQLDatabaseHook(BaseHook):
             raise ValueError("The db_hook should be set")
         if not isinstance(self.db_hook, PostgresHook):
             raise ValueError(f"The db_hook should be PostgresHook and is {type(self.db_hook)}")
-        conn = getattr(self.db_hook, "conn")
-        if conn and conn.notices:
-            for output in conn.notices:
+        conn = getattr(self.db_hook, "conn", None)
+        if conn and hasattr(conn, "notices") and conn.notices:
+            for output in conn.notices:
                 self.log.info(output)

     def reserve_free_tcp_port(self) -> None:
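A minimal sketch (not part of the diff) of why the guarded lookup above is safer: `conn` may be unset on a hook that never connected, and `notices` is psycopg2-specific, so other connection objects may not have it. `_StubHook` below is a hypothetical stand-in.

    # Hypothetical stand-in for a PostgresHook that was created but never connected.
    class _StubHook:
        conn = None

    hook = _StubHook()
    conn = getattr(hook, "conn", None)  # default avoids AttributeError on unusual hooks
    if conn and hasattr(conn, "notices") and conn.notices:  # notices is psycopg2-specific
        for output in conn.notices:
            print(output)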
--- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -36,13 +36,13 @@ from copy import deepcopy
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any

-from google.api_core import protobuf_helpers
 from google.cloud.storage_transfer_v1 import (
     ListTransferJobsRequest,
     StorageTransferServiceAsyncClient,
     TransferJob,
     TransferOperation,
 )
+from google.protobuf.json_format import MessageToDict
 from googleapiclient.discovery import Resource, build
 from googleapiclient.errors import HttpError

@@ -603,7 +603,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
         self,
         request_filter: dict | None = None,
         **kwargs,
-    ) -> list[TransferOperation]:
+    ) -> list[dict[str, Any]]:
         """
         Get a transfer operation in Google Storage Transfer Service.

@@ -660,7 +660,12 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
         )

         transfer_operations = [
-            protobuf_helpers.from_any_pb(TransferOperation, op.metadata) for op in operations
+            MessageToDict(
+                getattr(op, "_pb", op),
+                preserving_proto_field_name=True,
+                use_integers_for_enums=True,
+            )
+            for op in operations
         ]

         return transfer_operations
@@ -677,7 +682,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):

     @staticmethod
     async def operations_contain_expected_statuses(
-        operations: list[TransferOperation], expected_statuses: set[str] | str
+        operations: list[dict[str, Any]], expected_statuses: set[str] | str
     ) -> bool:
         """
         Check whether an operation exists with the expected status.
@@ -696,7 +701,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
         if not operations:
             return False

-        current_statuses = {operation.status.name for operation in operations}
+        current_statuses = {TransferOperation.Status(op["metadata"]["status"]).name for op in operations}

         if len(current_statuses - expected_statuses_set) != len(current_statuses):
             return True
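A hedged sketch of the conversion pattern these hunks introduce, shown with a stock protobuf message so it runs without a GCP client: proto-plus wrappers (such as `TransferOperation` from google.cloud.storage_transfer_v1) keep the raw protobuf under `._pb`, plain messages fall through `getattr` unchanged, and the two keyword flags keep snake_case keys and integer enums — which is why the status check above re-parses the integer with `TransferOperation.Status(...)`.

    from google.protobuf.descriptor_pb2 import FieldDescriptorProto
    from google.protobuf.json_format import MessageToDict

    # Stand-in message for the demo; a proto-plus wrapper would take the ._pb branch.
    msg = FieldDescriptorProto(name="status", number=1, type=FieldDescriptorProto.TYPE_ENUM)
    as_dict = MessageToDict(
        getattr(msg, "_pb", msg),          # raw protobuf for wrapped or plain messages
        preserving_proto_field_name=True,  # keep snake_case keys, e.g. "type_name"
        use_integers_for_enums=True,       # enums as ints: {"type": 14}, not "TYPE_ENUM"
    )
    print(as_dict)  # {'name': 'status', 'number': 1, 'type': 14}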
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -56,8 +56,8 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
+from airflow.providers.google.version_compat import timeout
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.timeout import timeout

 if TYPE_CHECKING:
     from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager
--- a/airflow/providers/google/cloud/hooks/spanner.py
+++ b/airflow/providers/google/cloud/hooks/spanner.py
@@ -19,6 +19,7 @@

 from __future__ import annotations

+from collections import OrderedDict
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, NamedTuple

@@ -388,7 +389,7 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
         database_id: str,
         queries: list[str],
         project_id: str,
-    ) -> None:
+    ) -> list[int]:
         """
         Execute an arbitrary DML query (INSERT, UPDATE, DELETE).

@@ -398,12 +399,31 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
         :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
             database. If set to None or missing, the default project_id from the Google Cloud connection
             is used.
+        :return: list of numbers of affected rows by DML query
         """
-        self._get_client(project_id=project_id).instance(instance_id=instance_id).database(
-            database_id=database_id
-        ).run_in_transaction(partial(self._execute_sql_in_transaction, queries=queries))
+        db = (
+            self._get_client(project_id=project_id)
+            .instance(instance_id=instance_id)
+            .database(database_id=database_id)
+        )
+
+        def _tx_runner(tx: Transaction) -> dict[str, int]:
+            return self._execute_sql_in_transaction(tx, queries)
+
+        result = db.run_in_transaction(_tx_runner)
+
+        result_rows_count_per_query = []
+        for i, (sql, rc) in enumerate(result.items(), start=1):
+            if not sql.startswith("SELECT"):
+                preview = sql if len(sql) <= 300 else sql[:300] + "…"
+                self.log.info("[DML %d/%d] affected rows=%d | %s", i, len(result), rc, preview)
+            result_rows_count_per_query.append(rc)
+        return result_rows_count_per_query

     @staticmethod
-    def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]):
+    def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) -> dict[str, int]:
+        counts: OrderedDict[str, int] = OrderedDict()
         for sql in queries:
-            transaction.execute_update(sql)
+            rc = transaction.execute_update(sql)
+            counts[sql] = rc
+        return counts
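A hypothetical usage sketch of the changed hook API: `execute_dml` now returns one affected-row count per statement instead of `None`. Instance, database, and project names below are placeholders.

    from airflow.providers.google.cloud.hooks.spanner import SpannerHook

    hook = SpannerHook(gcp_conn_id="google_cloud_default")
    row_counts = hook.execute_dml(
        instance_id="my-instance",
        database_id="my-database",
        queries=[
            "UPDATE users SET active = TRUE WHERE signup_ts >= '2024-01-01'",
            "DELETE FROM sessions WHERE expired",
        ],
        project_id="my-project",
    )
    # One count per statement, in execution order, e.g. [42, 7]
    print(row_counts)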
--- a/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -350,6 +350,9 @@ class GenerativeModelHook(GoogleBaseHook):
         :param generation_config: Optional. Generation configuration settings.
         :param safety_settings: Optional. Per request settings for blocking unsafe content.
         """
+        # During run of the system test it was found out that names from xcom, e.g. 3402922389 can be
+        # treated as int and throw an error TypeError: expected string or bytes-like object, got 'int'
+        cached_content_name = str(cached_content_name)
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

         cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)
--- a/airflow/providers/google/cloud/openlineage/utils.py
+++ b/airflow/providers/google/cloud/openlineage/utils.py
@@ -214,7 +214,20 @@ def extract_ds_name_from_gcs_path(path: str) -> str:

 def get_facets_from_bq_table(table: Table) -> dict[str, DatasetFacet]:
     """Get facets from BigQuery table object."""
+    return get_facets_from_bq_table_for_given_fields(table, selected_fields=None)
+
+
+def get_facets_from_bq_table_for_given_fields(
+    table: Table, selected_fields: list[str] | None
+) -> dict[str, DatasetFacet]:
+    """
+    Get facets from BigQuery table object for selected fields only.
+
+    If selected_fields is None, include all fields.
+    """
     facets: dict[str, DatasetFacet] = {}
+    selected_fields_set = set(selected_fields) if selected_fields else None
+
     if table.schema:
         facets["schema"] = SchemaDatasetFacet(
             fields=[
@@ -222,6 +235,7 @@ def get_facets_from_bq_table(table: Table) -> dict[str, DatasetFacet]:
                     name=schema_field.name, type=schema_field.field_type, description=schema_field.description
                 )
                 for schema_field in table.schema
+                if selected_fields_set is None or schema_field.name in selected_fields_set
             ]
         )
     if table.description:
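A sketch of the new helper under assumed inputs (the table is built locally, so no API call is made; it assumes the common.compat OpenLineage facets are installed): the schema facet is filtered down to the columns a transfer actually selects.

    from google.cloud.bigquery import Table

    from airflow.providers.google.cloud.openlineage.utils import (
        get_facets_from_bq_table_for_given_fields,
    )

    table = Table.from_api_repr(
        {
            "tableReference": {"projectId": "p", "datasetId": "d", "tableId": "t"},
            "schema": {
                "fields": [
                    {"name": "id", "type": "INTEGER"},
                    {"name": "email", "type": "STRING"},
                    {"name": "created_at", "type": "TIMESTAMP"},
                ]
            },
        }
    )
    facets = get_facets_from_bq_table_for_given_fields(table, selected_fields=["id", "email"])
    # The schema facet now lists only "id" and "email"; created_at is dropped.
    print([f.name for f in facets["schema"].fields])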
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -2370,11 +2370,19 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryInsertJobOperatorOpenLineageMixin):
         if self.project_id is None:
             self.project_id = hook.project_id

+        # Handle missing logical_date. Example: asset-triggered DAGs (Airflow 3)
+        logical_date = context.get("logical_date")
+        if logical_date is None:
+            # Use dag_run.run_after as fallback when logical_date is not available
+            dag_run = context.get("dag_run")
+            if dag_run and hasattr(dag_run, "run_after"):
+                logical_date = dag_run.run_after
+
         self.job_id = hook.generate_job_id(
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
-            logical_date=context["logical_date"],
+            logical_date=logical_date,
             configuration=self.configuration,
             force_rerun=self.force_rerun,
         )
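The fallback semantics, restated as a standalone sketch with stand-in context objects (this is illustrative, not operator code):

    from datetime import datetime, timezone
    from types import SimpleNamespace

    def resolve_logical_date(context: dict):
        """Prefer logical_date; fall back to dag_run.run_after when it is absent."""
        logical_date = context.get("logical_date")
        if logical_date is None:
            dag_run = context.get("dag_run")
            if dag_run and hasattr(dag_run, "run_after"):
                logical_date = dag_run.run_after
        return logical_date

    # Scheduled run: logical_date is present and wins.
    scheduled = {"logical_date": datetime(2025, 1, 1, tzinfo=timezone.utc)}
    assert resolve_logical_date(scheduled) == scheduled["logical_date"]

    # Asset-triggered run (Airflow 3): only dag_run.run_after is available.
    triggered = {"logical_date": None, "dag_run": SimpleNamespace(run_after=datetime.now(timezone.utc))}
    assert resolve_logical_date(triggered) == triggered["dag_run"].run_after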
--- a/airflow/providers/google/cloud/operators/cloud_composer.py
+++ b/airflow/providers/google/cloud/operators/cloud_composer.py
@@ -764,9 +764,15 @@ class CloudComposerRunAirflowCLICommandOperator(GoogleCloudBaseOperator):
             metadata=self.metadata,
             poll_interval=self.poll_interval,
         )
-        result_str = self._merge_cmd_output_result(result)
-        self.log.info("Command execution result:\n%s", result_str)
-        return result
+        exit_code = result.get("exit_info", {}).get("exit_code")
+        if exit_code == 0:
+            result_str = self._merge_cmd_output_result(result)
+            self.log.info("Command execution result:\n%s", result_str)
+            return result
+
+        error_output = "".join(line["content"] for line in result.get("error", []))
+        message = f"Airflow CLI command failed with exit code {exit_code}.\nError output:\n{error_output}"
+        raise AirflowException(message)

     def execute_complete(self, context: Context, event: dict) -> dict:
         if event and event["status"] == "error":
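A hedged sketch of the poll-result shape the operator now inspects, as implied by the diff (line-structured `output`/`error` plus `exit_info`); the literal dict below is illustrative, not a captured API response.

    result = {
        "output": [{"line_number": 1, "content": "dag_id | is_paused"}],
        "error": [{"line_number": 1, "content": "DAG 'missing_dag' not found"}],
        "exit_info": {"exit_code": 1, "error": ""},
        "output_end": True,
    }

    exit_code = result.get("exit_info", {}).get("exit_code")
    if exit_code != 0:
        error_output = "".join(line["content"] for line in result.get("error", []))
        print(f"Airflow CLI command failed with exit code {exit_code}:\n{error_output}")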
--- a/airflow/providers/google/cloud/operators/dataplex.py
+++ b/airflow/providers/google/cloud/operators/dataplex.py
@@ -1082,11 +1082,11 @@ class DataplexRunDataQualityScanOperator(GoogleCloudBaseOperator):
         """
         job_state = event["job_state"]
         job_id = event["job_id"]
-        if job_state == DataScanJob.State.FAILED:
+        if job_state == DataScanJob.State.FAILED.name:  # type: ignore
             raise AirflowException(f"Job failed:\n{job_id}")
-        if job_state == DataScanJob.State.CANCELLED:
+        if job_state == DataScanJob.State.CANCELLED.name:  # type: ignore
             raise AirflowException(f"Job was cancelled:\n{job_id}")
-        if job_state == DataScanJob.State.SUCCEEDED:
+        if job_state == DataScanJob.State.SUCCEEDED.name:  # type: ignore
             job = event["job"]
             if not job["data_quality_result"]["passed"]:
                 if self.fail_on_dq_failure:
@@ -1260,11 +1260,11 @@ class DataplexGetDataQualityScanResultOperator(GoogleCloudBaseOperator):
         job_state = event["job_state"]
         job_id = event["job_id"]
         job = event["job"]
-        if job_state == DataScanJob.State.FAILED:
+        if job_state == DataScanJob.State.FAILED.name:  # type: ignore
             raise AirflowException(f"Job failed:\n{job_id}")
-        if job_state == DataScanJob.State.CANCELLED:
+        if job_state == DataScanJob.State.CANCELLED.name:  # type: ignore
             raise AirflowException(f"Job was cancelled:\n{job_id}")
-        if job_state == DataScanJob.State.SUCCEEDED:
+        if job_state == DataScanJob.State.SUCCEEDED.name:  # type: ignore
             if not job["data_quality_result"]["passed"]:
                 if self.fail_on_dq_failure:
                     raise AirflowDataQualityScanException(
@@ -1639,12 +1639,12 @@ class DataplexRunDataProfileScanOperator(GoogleCloudBaseOperator):
             result_timeout=self.result_timeout,
         )

-        if job.state == DataScanJob.State.FAILED:
+        if job.state == DataScanJob.State.FAILED.name:  # type: ignore
             raise AirflowException(f"Data Profile job failed: {job_id}")
-        if job.state == DataScanJob.State.SUCCEEDED:
+        if job.state == DataScanJob.State.SUCCEEDED.name:  # type: ignore
             self.log.info("Data Profile job executed successfully.")
         else:
-            self.log.info("Data Profile job execution returned status: %s", job.status)
+            self.log.info("Data Profile job execution returned status: %s", job.state)

         return job_id

@@ -1657,11 +1657,11 @@ class DataplexRunDataProfileScanOperator(GoogleCloudBaseOperator):
         """
         job_state = event["job_state"]
         job_id = event["job_id"]
-        if job_state == DataScanJob.State.FAILED:
+        if job_state == DataScanJob.State.FAILED.name:  # type: ignore
             raise AirflowException(f"Job failed:\n{job_id}")
-        if job_state == DataScanJob.State.CANCELLED:
+        if job_state == DataScanJob.State.CANCELLED.name:  # type: ignore
             raise AirflowException(f"Job was cancelled:\n{job_id}")
-        if job_state == DataScanJob.State.SUCCEEDED:
+        if job_state == DataScanJob.State.SUCCEEDED.name:  # type: ignore
             self.log.info("Data Profile job executed successfully.")
         return job_id

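Why `.name`, in a self-contained sketch: the deferred path hands `job_state` back through a JSON-serialized trigger event, so the proto enum arrives as its string name and can never equal the enum member itself. (`State` here is a plain-enum stand-in for `DataScanJob.State`.)

    import enum

    class State(enum.Enum):  # stand-in for the proto enum DataScanJob.State
        FAILED = 3
        CANCELLED = 4
        SUCCEEDED = 5

    event = {"job_state": "SUCCEEDED"}                 # what survives the JSON round-trip
    assert event["job_state"] != State.SUCCEEDED       # enum comparison: always False
    assert event["job_state"] == State.SUCCEEDED.name  # string comparison: correct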
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -213,6 +213,7 @@ class ClusterGenerator:
     :param secondary_worker_accelerator_type: Type of the accelerator card (GPU) to attach to the secondary workers,
         see https://cloud.google.com/dataproc/docs/reference/rest/v1/InstanceGroupConfig#acceleratorconfig
     :param secondary_worker_accelerator_count: Number of accelerator cards (GPUs) to attach to the secondary workers
+    :param cluster_tier: The tier of the cluster (e.g. "CLUSTER_TIER_STANDARD" / "CLUSTER_TIER_PREMIUM").
     """

     def __init__(
@@ -261,6 +262,8 @@ class ClusterGenerator:
         secondary_worker_instance_flexibility_policy: InstanceFlexibilityPolicy | None = None,
         secondary_worker_accelerator_type: str | None = None,
         secondary_worker_accelerator_count: int | None = None,
+        *,
+        cluster_tier: str | None = None,
         **kwargs,
     ) -> None:
         self.project_id = project_id
@@ -308,6 +311,7 @@ class ClusterGenerator:
         self.secondary_worker_instance_flexibility_policy = secondary_worker_instance_flexibility_policy
         self.secondary_worker_accelerator_type = secondary_worker_accelerator_type
         self.secondary_worker_accelerator_count = secondary_worker_accelerator_count
+        self.cluster_tier = cluster_tier

         if self.custom_image and self.image_version:
             raise ValueError("The custom_image and image_version can't be both set")
@@ -513,6 +517,9 @@ class ClusterGenerator:
         if self.driver_pool_size > 0:
             cluster_data["auxiliary_node_groups"] = [self._build_driver_pool()]

+        if self.cluster_tier:
+            cluster_data["cluster_tier"] = self.cluster_tier
+
         cluster_data = self._build_gce_cluster_config(cluster_data)

         if self.single_node:
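A hypothetical usage sketch of the new keyword-only parameter (all field values are placeholders; required fields depend on your cluster setup):

    from airflow.providers.google.cloud.operators.dataproc import ClusterGenerator

    cluster_config = ClusterGenerator(
        project_id="my-project",
        num_workers=2,
        cluster_tier="CLUSTER_TIER_PREMIUM",  # new in 17.2.0; must be passed by keyword
    ).make()

    assert cluster_config["cluster_tier"] == "CLUSTER_TIER_PREMIUM"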
@@ -1945,9 +1952,9 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
         job_state = event["job_state"]
         job_id = event["job_id"]
         job = event["job"]
-        if job_state == JobStatus.State.ERROR:
+        if job_state == JobStatus.State.ERROR.name:  # type: ignore
             raise AirflowException(f"Job {job_id} failed:\n{job}")
-        if job_state == JobStatus.State.CANCELLED:
+        if job_state == JobStatus.State.CANCELLED.name:  # type: ignore
             raise AirflowException(f"Job {job_id} was cancelled:\n{job}")
         self.log.info("%s completed successfully.", self.task_id)
         return job_id
@@ -2455,7 +2462,7 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
             if not self.hook.check_error_for_resource_is_not_ready_msg(batch.state_message):
                 break

-        self.handle_batch_status(context, batch.state, batch_id, batch.state_message)
+        self.handle_batch_status(context, batch.state.name, batch_id, batch.state_message)
         return Batch.to_dict(batch)

     @cached_property
@@ -2480,19 +2487,19 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
             self.operation.cancel()

     def handle_batch_status(
-        self, context: Context, state: Batch.State, batch_id: str, state_message: str | None = None
+        self, context: Context, state: str, batch_id: str, state_message: str | None = None
     ) -> None:
         # The existing batch may be a number of states other than 'SUCCEEDED'\
         # wait_for_operation doesn't fail if the job is cancelled, so we will check for it here which also
         # finds a cancelling|canceled|unspecified job from wait_for_batch or the deferred trigger
         link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
-        if state == Batch.State.FAILED:
+        if state == Batch.State.FAILED.name:  # type: ignore
             raise AirflowException(
                 f"Batch job {batch_id} failed with error: {state_message}.\nDriver logs: {link}"
             )
-        if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
+        if state in (Batch.State.CANCELLED.name, Batch.State.CANCELLING.name):  # type: ignore
             raise AirflowException(f"Batch job {batch_id} was cancelled.\nDriver logs: {link}")
-        if state == Batch.State.STATE_UNSPECIFIED:
+        if state == Batch.State.STATE_UNSPECIFIED.name:  # type: ignore
             raise AirflowException(f"Batch job {batch_id} unspecified.\nDriver logs: {link}")
         self.log.info("Batch job %s completed.\nDriver logs: %s", batch_id, link)

@@ -2566,7 +2573,7 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
         dag_id = re.sub(r"[.\s]", "_", self.dag_id.lower())
         task_id = re.sub(r"[.\s]", "_", self.task_id.lower())

-        labels_regex = re.compile(r"^[a-z][\w-]{0,63}$")
+        labels_regex = re.compile(r"^[a-z][\w-]{0,62}$")
         if not labels_regex.match(dag_id) or not labels_regex.match(task_id):
             return

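The off-by-one being fixed, checked standalone: GCP label values are capped at 63 characters, and `^[a-z][\w-]{0,63}$` admits 64 (one leading character plus up to 63 more), while `{0,62}` caps the total at 63.

    import re

    labels_regex = re.compile(r"^[a-z][\w-]{0,62}$")
    assert labels_regex.match("a" * 63)      # exactly 63 characters: accepted
    assert not labels_regex.match("a" * 64)  # 64 characters: rejected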
--- a/airflow/providers/google/cloud/operators/pubsub.py
+++ b/airflow/providers/google/cloud/operators/pubsub.py
@@ -26,6 +26,7 @@ This module contains Google PubSub operators.
 from __future__ import annotations

 from collections.abc import Callable, Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -52,6 +53,7 @@ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 if TYPE_CHECKING:
     from google.api_core.retry import Retry

+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context

@@ -359,15 +361,18 @@ class PubSubCreateSubscriptionOperator(GoogleCloudBaseOperator):
         self.timeout = timeout
         self.metadata = metadata
         self.impersonation_chain = impersonation_chain
+        self._resolved_subscription_name: str | None = None

-    def execute(self, context: Context) -> str:
-        hook = PubSubHook(
+    @cached_property
+    def pubsub_hook(self):
+        return PubSubHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )

+    def execute(self, context: Context) -> str:
         self.log.info("Creating subscription for topic %s", self.topic)
-        result = hook.create_subscription(
+        result = self.pubsub_hook.create_subscription(
             project_id=self.project_id,
             topic=self.topic,
             subscription=self.subscription,
@@ -389,13 +394,34 @@ class PubSubCreateSubscriptionOperator(GoogleCloudBaseOperator):
         )

         self.log.info("Created subscription for topic %s", self.topic)
+
+        # Store resolved subscription for Open Lineage
+        self._resolved_subscription_name = self.subscription or result
+
         PubSubSubscriptionLink.persist(
             context=context,
-            subscription_id=self.subscription,
-            project_id=self.project_id or hook.project_id,
+            subscription_id=self._resolved_subscription_name,  # result returns subscription name
+            project_id=self.project_id or self.pubsub_hook.project_id,
         )
         return result

+    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        topic_project_id = self.project_id or self.pubsub_hook.project_id
+        subscription_project_id = self.subscription_project_id or topic_project_id
+
+        return OperatorLineage(
+            inputs=[Dataset(namespace="pubsub", name=f"topic:{topic_project_id}:{self.topic}")],
+            outputs=[
+                Dataset(
+                    namespace="pubsub",
+                    name=f"subscription:{subscription_project_id}:{self._resolved_subscription_name}",
+                )
+            ],
+        )
+

 class PubSubDeleteTopicOperator(GoogleCloudBaseOperator):
     """
@@ -692,17 +718,28 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
         self.enable_message_ordering = enable_message_ordering
         self.impersonation_chain = impersonation_chain

-    def execute(self, context: Context) -> None:
-        hook = PubSubHook(
+    @cached_property
+    def pubsub_hook(self):
+        return PubSubHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
             enable_message_ordering=self.enable_message_ordering,
         )

+    def execute(self, context: Context) -> None:
         self.log.info("Publishing to topic %s", self.topic)
-        hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages)
+        self.pubsub_hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages)
         self.log.info("Published to topic %s", self.topic)

+    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        project_id = self.project_id or self.pubsub_hook.project_id
+        output_dataset = [Dataset(namespace="pubsub", name=f"topic:{project_id}:{self.topic}")]
+
+        return OperatorLineage(outputs=output_dataset)
+

 class PubSubPullOperator(GoogleCloudBaseOperator):
     """
@@ -853,3 +890,13 @@ class PubSubPullOperator(GoogleCloudBaseOperator):
         messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]

         return messages_json
+
+    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        output_dataset = [
+            Dataset(namespace="pubsub", name=f"subscription:{self.project_id}:{self.subscription}")
+        ]
+
+        return OperatorLineage(outputs=output_dataset)
--- a/airflow/providers/google/cloud/operators/spanner.py
+++ b/airflow/providers/google/cloud/operators/spanner.py
@@ -280,8 +280,8 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
             self.instance_id,
             self.database_id,
         )
-        self.log.info(queries)
-        hook.execute_dml(
+        self.log.info("Executing queries: %s", queries)
+        result_rows_count_per_query = hook.execute_dml(
             project_id=self.project_id,
             instance_id=self.instance_id,
             database_id=self.database_id,
@@ -293,6 +293,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
             database_id=self.database_id,
             project_id=self.project_id or hook.project_id,
         )
+        return result_rows_count_per_query

     @staticmethod
     def sanitize_queries(queries: list[str]) -> None:
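A hypothetical DAG snippet showing what returning the counts enables: the operator's return value is pushed to XCom, so a downstream task can read the per-statement counts. Identifiers are placeholders.

    from airflow.providers.google.cloud.operators.spanner import (
        SpannerQueryDatabaseInstanceOperator,
    )

    spanner_update = SpannerQueryDatabaseInstanceOperator(
        task_id="spanner_update",
        instance_id="my-instance",
        database_id="my-database",
        query=["UPDATE users SET active = TRUE WHERE id = 1"],
        project_id="my-project",
    )
    # Downstream: "{{ ti.xcom_pull(task_ids='spanner_update') }}" renders e.g. [1]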
--- a/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -58,7 +58,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     """

-    template_fields = ("location", "project_id", "impersonation_chain", "prompt")
+    template_fields = ("location", "project_id", "impersonation_chain", "prompt", "pretrained_model")

     def __init__(
         self,
@@ -211,7 +211,14 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     """

-    template_fields = ("location", "project_id", "impersonation_chain", "train_dataset", "validation_dataset")
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "train_dataset",
+        "validation_dataset",
+        "source_model",
+    )

     def __init__(
         self,
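A hypothetical sketch of what the widened `template_fields` buys: `source_model` (and `pretrained_model` on the embeddings operator) can now be Jinja-rendered at runtime. All values are placeholders.

    from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
        SupervisedFineTuningTrainOperator,
    )

    sft_train = SupervisedFineTuningTrainOperator(
        task_id="sft_train",
        project_id="my-project",
        location="us-central1",
        source_model="{{ params.source_model }}",  # rendered now that it is templated
        train_dataset="gs://my-bucket/train.jsonl",
    )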
--- a/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ b/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -61,6 +61,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for states from specific time in the
         past till current time execution.
         Default value datetime.timedelta(days=1).
+    :param composer_dag_run_id: The Run ID of the DAG run to wait for; if both are specified, the 'execution_range' param is ignored.
     :param gcp_conn_id: The connection ID to use when fetching connection info.
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -91,6 +92,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         composer_dag_id: str,
         allowed_states: Iterable[str] | None = None,
         execution_range: timedelta | list[datetime] | None = None,
+        composer_dag_run_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
@@ -104,11 +106,17 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         self.composer_dag_id = composer_dag_id
         self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
         self.execution_range = execution_range
+        self.composer_dag_run_id = composer_dag_run_id
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.poll_interval = poll_interval

+        if self.composer_dag_run_id and self.execution_range:
+            self.log.warning(
+                "The composer_dag_run_id and execution_range parameters do not work together; "
+                "execution_range will be ignored and only the specified composer_dag_run_id will be checked."
+            )
+
     def _get_logical_dates(self, context) -> tuple[datetime, datetime]:
         if isinstance(self.execution_range, timedelta):
             if self.execution_range < timedelta(0):
@@ -128,6 +136,20 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):

         dag_runs = self._pull_dag_runs()

+        if len(dag_runs) == 0:
+            self.log.info("Dag runs are empty. Sensor waits for dag runs...")
+            return False
+
+        if self.composer_dag_run_id:
+            self.log.info(
+                "Sensor waits for allowed states %s for specified RunID: %s",
+                self.allowed_states,
+                self.composer_dag_run_id,
+            )
+            composer_dag_run_id_status = self._check_composer_dag_run_id_states(
+                dag_runs=dag_runs,
+            )
+            return composer_dag_run_id_status
         self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
         allowed_states_status = self._check_dag_runs_states(
             dag_runs=dag_runs,
|
@@ -189,16 +211,24 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
|
|
189
211
|
image_version = environment_config["config"]["software_config"]["image_version"]
|
190
212
|
return int(image_version.split("airflow-")[1].split(".")[0])
|
191
213
|
|
214
|
+
def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
|
215
|
+
for dag_run in dag_runs:
|
216
|
+
if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
|
217
|
+
return True
|
218
|
+
return False
|
219
|
+
|
192
220
|
def execute(self, context: Context) -> None:
|
193
221
|
self._composer_airflow_version = self._get_composer_airflow_version()
|
194
222
|
if self.deferrable:
|
195
223
|
start_date, end_date = self._get_logical_dates(context)
|
196
224
|
self.defer(
|
225
|
+
timeout=self.timeout,
|
197
226
|
trigger=CloudComposerDAGRunTrigger(
|
198
227
|
project_id=self.project_id,
|
199
228
|
region=self.region,
|
200
229
|
environment_id=self.environment_id,
|
201
230
|
composer_dag_id=self.composer_dag_id,
|
231
|
+
composer_dag_run_id=self.composer_dag_run_id,
|
202
232
|
start_date=start_date,
|
203
233
|
end_date=end_date,
|
204
234
|
allowed_states=self.allowed_states,
|