apache-airflow-providers-google 17.1.0__py3-none-any.whl → 17.2.0rc1__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +0 -8
- airflow/providers/google/cloud/hooks/cloud_composer.py +6 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +3 -3
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +10 -5
- airflow/providers/google/cloud/hooks/dataflow.py +1 -1
- airflow/providers/google/cloud/hooks/spanner.py +26 -6
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +3 -0
- airflow/providers/google/cloud/openlineage/utils.py +14 -0
- airflow/providers/google/cloud/operators/bigquery.py +9 -1
- airflow/providers/google/cloud/operators/cloud_composer.py +9 -3
- airflow/providers/google/cloud/operators/dataplex.py +12 -12
- airflow/providers/google/cloud/operators/dataproc.py +15 -8
- airflow/providers/google/cloud/operators/pubsub.py +55 -8
- airflow/providers/google/cloud/operators/spanner.py +3 -2
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +9 -2
- airflow/providers/google/cloud/sensors/cloud_composer.py +30 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +72 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +7 -3
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +0 -6
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +8 -2
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +27 -2
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/triggers/bigquery.py +23 -19
- airflow/providers/google/cloud/triggers/cloud_composer.py +48 -10
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +63 -46
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +5 -0
- airflow/providers/google/version_compat.py +6 -0
- {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/METADATA +17 -18
- {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/RECORD +35 -34
- {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/transfers/bigquery_to_mssql.py

@@ -21,14 +21,17 @@ from __future__ import annotations
 
 import warnings
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -94,9 +97,13 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
         self.mssql_conn_id = mssql_conn_id
         self.source_project_dataset_table = source_project_dataset_table
 
-    def get_sql_hook(self) -> MsSqlHook:
+    @cached_property
+    def mssql_hook(self) -> MsSqlHook:
         return MsSqlHook(schema=self.database, mssql_conn_id=self.mssql_conn_id)
 
+    def get_sql_hook(self) -> MsSqlHook:
+        return self.mssql_hook
+
     def persist_links(self, context: Context) -> None:
         project_id, dataset_id, table_id = self.source_project_dataset_table.split(".")
         BigQueryTableLink.persist(
@@ -105,3 +112,67 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
             project_id=project_id,
             table_id=table_id,
         )
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table_for_given_fields,
+            get_identity_column_lineage_facet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if not self.bigquery_hook:
+            self.bigquery_hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        try:
+            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
+        except Exception:
+            self.log.debug(
+                "OpenLineage: could not fetch BigQuery table %s",
+                self.source_project_dataset_table,
+                exc_info=True,
+            )
+            return OperatorLineage()
+
+        if self.selected_fields:
+            if isinstance(self.selected_fields, str):
+                bigquery_field_names = list(self.selected_fields)
+            else:
+                bigquery_field_names = self.selected_fields
+        else:
+            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
+
+        input_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=self.source_project_dataset_table,
+            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
+        )
+
+        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
+        default_schema = self.mssql_hook.get_openlineage_default_schema()
+        namespace = f"{db_info.scheme}://{db_info.authority}"
+
+        if self.target_table_name and "." in self.target_table_name:
+            schema_name, table_name = self.target_table_name.split(".", 1)
+        else:
+            schema_name = default_schema or ""
+            table_name = self.target_table_name or ""
+
+        if self.database:
+            output_name = f"{self.database}.{schema_name}.{table_name}"
+        else:
+            output_name = f"{schema_name}.{table_name}"
+
+        column_lineage_facet = get_identity_column_lineage_facet(
+            bigquery_field_names, input_datasets=[input_dataset]
+        )
+
+        output_facets = column_lineage_facet or {}
+        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
+
+        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
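With the new `get_openlineage_facets_on_complete`, BigQuery-to-MSSQL transfers report input and output lineage automatically when the OpenLineage provider is installed; the operator's arguments do not change. A minimal usage sketch (the DAG scaffolding, connection IDs, and table names are illustrative, not part of this diff):

```python
# Illustrative only: connection IDs, table names, and the DAG scaffolding are
# assumptions; the operator arguments shown are the ones referenced in this diff.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator

with DAG(dag_id="bq_to_mssql_example", start_date=datetime(2025, 1, 1), schedule=None):
    transfer = BigQueryToMsSqlOperator(
        task_id="bq_to_mssql",
        source_project_dataset_table="my-project.my_dataset.my_table",  # "project.dataset.table"
        target_table_name="dbo.my_table",  # a dot lets OpenLineage split schema and table
        mssql_conn_id="mssql_default",
    )
```

As the new method shows, a dotted `target_table_name` is split into schema and table for the output dataset name, and `database`, when set, is prefixed as well.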
airflow/providers/google/cloud/transfers/bigquery_to_sql.py

@@ -21,6 +21,7 @@ from __future__ import annotations
 
 import abc
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
@@ -113,19 +114,22 @@ class BigQueryToSqlBaseOperator(BaseOperator):
     def persist_links(self, context: Context) -> None:
         """Persist the connection to the SQL provider."""
 
-    def execute(self, context: Context) -> None:
-        big_query_hook = BigQueryHook(
+    @cached_property
+    def bigquery_hook(self) -> BigQueryHook:
+        return BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
+
+    def execute(self, context: Context) -> None:
         self.persist_links(context)
         sql_hook = self.get_sql_hook()
         for rows in bigquery_get_data(
             self.log,
             self.dataset_id,
             self.table_id,
-            big_query_hook,
+            self.bigquery_hook,
             self.batch_size,
             self.selected_fields,
         ):
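For subclasses, the contract is now: reuse the base class's cached `bigquery_hook` for reading, and return the target database hook from `get_sql_hook()`. A hypothetical subclass mirroring the pattern applied to `BigQueryToMsSqlOperator` above (the MySQL hook and connection ID are assumptions for illustration only):

```python
# Hypothetical subclass; the MySQL hook and connection ID are assumptions used
# only to illustrate the cached-hook pattern this release introduces.
from functools import cached_property

from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook


class ExampleBigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
    """Copy BigQuery rows into MySQL using the hooks supplied by the base class."""

    @cached_property
    def mysql_hook(self) -> MySqlHook:
        # self.database comes from the base operator's constructor
        return MySqlHook(schema=self.database, mysql_conn_id="mysql_default")

    def get_sql_hook(self) -> MySqlHook:
        return self.mysql_hook
```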
airflow/providers/google/cloud/transfers/gcs_to_bigquery.py

@@ -635,12 +635,6 @@ class GCSToBigQueryOperator(BaseOperator):
             self.configuration["load"]["schema"] = {"fields": self.schema_fields}
 
         if self.schema_update_options:
-            if self.write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
-                raise ValueError(
-                    "schema_update_options is only "
-                    "allowed if write_disposition is "
-                    "'WRITE_APPEND' or 'WRITE_TRUNCATE'."
-                )
             # To provide backward compatibility
             self.schema_update_options = list(self.schema_update_options or [])
             self.log.info("Adding experimental 'schemaUpdateOptions': %s", self.schema_update_options)
airflow/providers/google/cloud/transfers/gcs_to_sftp.py

@@ -80,7 +80,9 @@ class GCSToSFTPOperator(BaseOperator):
     :param destination_path: The sftp remote path. This is the specified directory path for
         uploading to the SFTP server.
     :param keep_directory_structure: (Optional) When set to False the path of the file
-         on the bucket is recreated within path passed in destination_path.
+        on the bucket is recreated within path passed in destination_path.
+    :param create_intermediate_dirs: (Optional) When set to True the intermediate directories
+        in the specified file path will be created.
     :param move_object: When move object is True, the object is moved instead
         of copied to the new location. This is the equivalent of a mv command
         as opposed to a cp command.
@@ -112,6 +114,7 @@ class GCSToSFTPOperator(BaseOperator):
         source_object: str,
         destination_path: str,
         keep_directory_structure: bool = True,
+        create_intermediate_dirs: bool = True,
         move_object: bool = False,
         gcp_conn_id: str = "google_cloud_default",
         sftp_conn_id: str = "ssh_default",
@@ -124,6 +127,7 @@ class GCSToSFTPOperator(BaseOperator):
         self.source_object = source_object
         self.destination_path = destination_path
         self.keep_directory_structure = keep_directory_structure
+        self.create_intermediate_dirs = create_intermediate_dirs
        self.move_object = move_object
         self.gcp_conn_id = gcp_conn_id
         self.sftp_conn_id = sftp_conn_id
@@ -190,7 +194,9 @@ class GCSToSFTPOperator(BaseOperator):
         )
 
         dir_path = os.path.dirname(destination_path)
-        sftp_hook.create_directory(dir_path)
+
+        if self.create_intermediate_dirs:
+            sftp_hook.create_directory(dir_path)
 
         with NamedTemporaryFile("w") as tmp:
             gcs_hook.download(
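`create_intermediate_dirs` defaults to `True`, which keeps the `create_directory` call that previously ran unconditionally; setting it to `False` skips that call for servers where the remote directory must already exist. A minimal sketch (bucket, object, and connection values are placeholders):

```python
# Placeholder bucket/object/connection values; create_intermediate_dirs is the
# parameter added in this diff (its default of True keeps the previous behaviour).
from airflow.providers.google.cloud.transfers.gcs_to_sftp import GCSToSFTPOperator

upload = GCSToSFTPOperator(
    task_id="gcs_to_sftp",
    source_bucket="my-bucket",
    source_object="reports/2025-01-01.csv",
    destination_path="/upload/reports",
    create_intermediate_dirs=False,  # do not create /upload/reports; it must already exist
    sftp_conn_id="ssh_default",
)
```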
airflow/providers/google/cloud/transfers/oracle_to_gcs.py

@@ -21,12 +21,17 @@ import base64
 import calendar
 from datetime import date, datetime, timedelta
 from decimal import Decimal
+from functools import cached_property
+from typing import TYPE_CHECKING
 
 import oracledb
 
 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
 from airflow.providers.oracle.hooks.oracle import OracleHook
 
+if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
+
 
 class OracleToGCSOperator(BaseSQLToGCSOperator):
     """
@@ -62,10 +67,13 @@ class OracleToGCSOperator(BaseSQLToGCSOperator):
         self.ensure_utc = ensure_utc
         self.oracle_conn_id = oracle_conn_id
 
+    @cached_property
+    def db_hook(self) -> OracleHook:
+        return OracleHook(oracle_conn_id=self.oracle_conn_id)
+
     def query(self):
         """Query Oracle and returns a cursor to the results."""
-        oracle = OracleHook(oracle_conn_id=self.oracle_conn_id)
-        conn = oracle.get_conn()
+        conn = self.db_hook.get_conn()
         cursor = conn.cursor()
         if self.ensure_utc:
             # Ensure TIMESTAMP results are in UTC
@@ -121,3 +129,20 @@ class OracleToGCSOperator(BaseSQLToGCSOperator):
         else:
             value = base64.standard_b64encode(value).decode("ascii")
         return value
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        sql_parsing_result = get_openlineage_facets_with_sql(
+            hook=self.db_hook,
+            sql=self.sql,
+            conn_id=self.oracle_conn_id,
+            database=self.db_hook.service_name or self.db_hook.sid,
+        )
+        gcs_output_datasets = self._get_openlineage_output_datasets()
+        if sql_parsing_result:
+            sql_parsing_result.outputs = gcs_output_datasets
+            return sql_parsing_result
+        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
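The Oracle connection is now built once through the cached `db_hook` property, and `get_openlineage_facets_on_start` derives Oracle input datasets from the SQL when the OpenLineage provider is installed; no operator arguments change. A minimal usage sketch (SQL, bucket, and connection values are placeholders; `bucket` and `filename` come from `BaseSQLToGCSOperator` and are unchanged by this diff):

```python
# Placeholder SQL, bucket, and connection values; lineage collection now happens
# automatically via get_openlineage_facets_on_start when OpenLineage is enabled.
from airflow.providers.google.cloud.transfers.oracle_to_gcs import OracleToGCSOperator

extract = OracleToGCSOperator(
    task_id="oracle_to_gcs",
    sql="SELECT id, updated_at FROM orders",
    bucket="my-bucket",
    filename="exports/orders_{}.json",
    oracle_conn_id="oracle_default",
)
```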
airflow/providers/google/cloud/transfers/postgres_to_gcs.py

@@ -31,7 +31,7 @@ import pendulum
 from slugify import slugify
 
 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
-from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.providers.postgres.hooks.postgres import USE_PSYCOPG3, PostgresHook
 
 if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors import OperatorLineage
@@ -52,9 +52,20 @@ class _PostgresServerSideCursorDecorator:
         self.initialized = False
 
     def __iter__(self):
+        """Make the cursor iterable."""
         return self
 
     def __next__(self):
+        """Fetch next row from the cursor."""
+        if USE_PSYCOPG3:
+            if self.rows:
+                return self.rows.pop()
+            self.initialized = True
+            row = self.cursor.fetchone()
+            if row is None:
+                raise StopIteration
+            return row
+        # psycopg2
         if self.rows:
             return self.rows.pop()
         self.initialized = True
@@ -141,13 +152,29 @@ class PostgresToGCSOperator(BaseSQLToGCSOperator):
         return PostgresHook(postgres_conn_id=self.postgres_conn_id)
 
     def query(self):
-        """
+        """Execute the query and return a cursor."""
         conn = self.db_hook.get_conn()
-        cursor = conn.cursor(name=self._unique_name())
-        cursor.execute(self.sql, self.parameters)
-        if self.use_server_side_cursor:
-            cursor.itersize = self.cursor_itersize
-            return _PostgresServerSideCursorDecorator(cursor)
+
+        if USE_PSYCOPG3:
+            from psycopg.types.json import register_default_adapters
+
+            # Register JSON handlers for this connection if not already done
+            register_default_adapters(conn)
+
+            if self.use_server_side_cursor:
+                cursor_name = f"airflow_{self.task_id.replace('-', '_')}_{uuid.uuid4().hex}"[:63]
+                cursor = conn.cursor(name=cursor_name)
+                cursor.itersize = self.cursor_itersize
+                cursor.execute(self.sql, self.parameters)
+                return _PostgresServerSideCursorDecorator(cursor)
+            cursor = conn.cursor()
+            cursor.execute(self.sql, self.parameters)
+        else:
+            cursor = conn.cursor(name=self._unique_name())
+            cursor.execute(self.sql, self.parameters)
+            if self.use_server_side_cursor:
+                cursor.itersize = self.cursor_itersize
+                return _PostgresServerSideCursorDecorator(cursor)
         return cursor
 
     def field_to_bigquery(self, field) -> dict[str, str]:
@@ -182,8 +209,14 @@ class PostgresToGCSOperator(BaseSQLToGCSOperator):
                 hours=formatted_time.tm_hour, minutes=formatted_time.tm_min, seconds=formatted_time.tm_sec
             )
             return str(time_delta)
-        if stringify_dict and isinstance(value, dict):
-            return json.dumps(value)
+        if stringify_dict:
+            if USE_PSYCOPG3:
+                from psycopg.types.json import Json
+
+                if isinstance(value, (dict, Json)):
+                    return json.dumps(value)
+            elif isinstance(value, dict):
+                return json.dumps(value)
         if isinstance(value, Decimal):
             return float(value)
         return value
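The `query` path now branches on `USE_PSYCOPG3`: with psycopg3 the server-side cursor gets an explicitly generated name and JSON adapters are registered on the connection, while the psycopg2 branch keeps the previous behaviour, so callers do not need to change anything. A minimal sketch (SQL, bucket, and connection values are placeholders):

```python
# Placeholder SQL, bucket, and connection values; use_server_side_cursor and
# cursor_itersize are existing parameters whose handling this diff extends to psycopg3.
from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator

export = PostgresToGCSOperator(
    task_id="pg_to_gcs",
    sql="SELECT * FROM events WHERE created_at >= %(since)s",
    parameters={"since": "2025-01-01"},
    bucket="my-bucket",
    filename="exports/events_{}.json",
    postgres_conn_id="postgres_default",
    use_server_side_cursor=True,  # streamed via a named cursor on both psycopg2 and psycopg3
)
```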
airflow/providers/google/cloud/triggers/bigquery.py

@@ -25,16 +25,18 @@ from aiohttp.client_exceptions import ClientResponseError
 from asgiref.sync import sync_to_async
 
 from airflow.exceptions import AirflowException
-from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
 from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+if not AIRFLOW_V_3_0_PLUS:
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.session import provide_session
+
 
 class BigQueryInsertJobTrigger(BaseTrigger):
     """
@@ -99,24 +101,26 @@ class BigQueryInsertJobTrigger(BaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
-        )
-        task_instance = query.one_or_none()
-        if task_instance is None:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
             )
-        return task_instance
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
 
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
airflow/providers/google/cloud/triggers/cloud_composer.py

@@ -145,10 +145,23 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger):
             )
             return
 
+        exit_code = result.get("exit_info", {}).get("exit_code")
+
+        if exit_code == 0:
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "result": result,
+                }
+            )
+            return
+
+        error_output = "".join(line["content"] for line in result.get("error", []))
+        message = f"Airflow CLI command failed with exit code {exit_code}.\nError output:\n{error_output}"
         yield TriggerEvent(
             {
-                "status": "success",
-                "result": result,
+                "status": "error",
+                "message": message,
             }
         )
         return
@@ -166,6 +179,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         start_date: datetime,
         end_date: datetime,
         allowed_states: list[str],
+        composer_dag_run_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         poll_interval: int = 10,
@@ -179,6 +193,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         self.start_date = start_date
         self.end_date = end_date
         self.allowed_states = allowed_states
+        self.composer_dag_run_id = composer_dag_run_id
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.poll_interval = poll_interval
@@ -200,6 +215,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 "start_date": self.start_date,
                 "end_date": self.end_date,
                 "allowed_states": self.allowed_states,
+                "composer_dag_run_id": self.composer_dag_run_id,
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "poll_interval": self.poll_interval,
@@ -248,20 +264,42 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 return False
         return True
 
+    def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
+        for dag_run in dag_runs:
+            if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
+                return True
+        return False
+
     async def run(self):
         try:
             while True:
                 if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp():
                     dag_runs = await self._pull_dag_runs()
 
-                    self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
-                    if self._check_dag_runs_states(
-                        dag_runs=dag_runs,
-                        start_date=self.start_date,
-                        end_date=self.end_date,
-                    ):
-                        yield TriggerEvent({"status": "success"})
-                        return
+                    if len(dag_runs) == 0:
+                        self.log.info("Dag runs are empty. Sensor waits for dag runs...")
+                        self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                        await asyncio.sleep(self.poll_interval)
+                        continue
+
+                    if self.composer_dag_run_id:
+                        self.log.info(
+                            "Sensor waits for allowed states %s for specified RunID: %s",
+                            self.allowed_states,
+                            self.composer_dag_run_id,
+                        )
+                        if self._check_composer_dag_run_id_states(dag_runs=dag_runs):
+                            yield TriggerEvent({"status": "success"})
+                            return
+                    else:
+                        self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+                        if self._check_dag_runs_states(
+                            dag_runs=dag_runs,
+                            start_date=self.start_date,
+                            end_date=self.end_date,
+                        ):
+                            yield TriggerEvent({"status": "success"})
+                            return
                 self.log.info("Sleeping for %s seconds.", self.poll_interval)
                 await asyncio.sleep(self.poll_interval)
         except AirflowException as ex:
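The new `composer_dag_run_id` argument narrows the success condition from "some DAG run in the window reached an allowed state" to "this specific run reached an allowed state". A standalone sketch of the matching rule `_check_composer_dag_run_id_states` applies (the payload below is fabricated purely to illustrate the `run_id`/`state` keys the trigger reads):

```python
# Standalone sketch of the matching rule added in this diff; the dag_runs payload
# below is fabricated purely to illustrate the run_id/state keys the trigger reads.
def check_composer_dag_run_id_states(
    dag_runs: list[dict], composer_dag_run_id: str, allowed_states: list[str]
) -> bool:
    for dag_run in dag_runs:
        if dag_run["run_id"] == composer_dag_run_id and dag_run["state"] in allowed_states:
            return True
    return False


dag_runs = [
    {"run_id": "manual__2025-01-01T00:00:00+00:00", "state": "running"},
    {"run_id": "scheduled__2025-01-01T00:00:00+00:00", "state": "success"},
]
assert check_composer_dag_run_id_states(dag_runs, "scheduled__2025-01-01T00:00:00+00:00", ["success"])
assert not check_composer_dag_run_id_states(dag_runs, "manual__2025-01-01T00:00:00+00:00", ["success"])
```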
airflow/providers/google/cloud/triggers/dataplex.py

@@ -103,7 +103,13 @@ class DataplexDataQualityJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent(
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
@@ -185,7 +191,13 @@ class DataplexDataProfileJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent(
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""