apache-airflow-providers-google 17.1.0__py3-none-any.whl → 17.2.0__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (35)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/cloud/hooks/bigquery.py +0 -8
  3. airflow/providers/google/cloud/hooks/cloud_composer.py +6 -1
  4. airflow/providers/google/cloud/hooks/cloud_sql.py +3 -3
  5. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +10 -5
  6. airflow/providers/google/cloud/hooks/dataflow.py +1 -1
  7. airflow/providers/google/cloud/hooks/spanner.py +26 -6
  8. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +3 -0
  9. airflow/providers/google/cloud/openlineage/utils.py +14 -0
  10. airflow/providers/google/cloud/operators/bigquery.py +9 -1
  11. airflow/providers/google/cloud/operators/cloud_composer.py +9 -3
  12. airflow/providers/google/cloud/operators/dataplex.py +12 -12
  13. airflow/providers/google/cloud/operators/dataproc.py +15 -8
  14. airflow/providers/google/cloud/operators/pubsub.py +55 -8
  15. airflow/providers/google/cloud/operators/spanner.py +3 -2
  16. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +9 -2
  17. airflow/providers/google/cloud/sensors/cloud_composer.py +30 -0
  18. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +72 -1
  19. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +7 -3
  20. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +0 -6
  21. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +8 -2
  22. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +27 -2
  23. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
  24. airflow/providers/google/cloud/triggers/bigquery.py +23 -19
  25. airflow/providers/google/cloud/triggers/cloud_composer.py +48 -10
  26. airflow/providers/google/cloud/triggers/dataplex.py +14 -2
  27. airflow/providers/google/cloud/triggers/dataproc.py +63 -46
  28. airflow/providers/google/common/utils/get_secret.py +31 -0
  29. airflow/providers/google/suite/hooks/sheets.py +15 -1
  30. airflow/providers/google/suite/operators/sheets.py +5 -0
  31. airflow/providers/google/version_compat.py +6 -0
  32. {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0.dist-info}/METADATA +9 -10
  33. {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0.dist-info}/RECORD +35 -34
  34. {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0.dist-info}/WHEEL +0 -0
  35. {apache_airflow_providers_google-17.1.0.dist-info → apache_airflow_providers_google-17.2.0.dist-info}/entry_points.txt +0 -0

airflow/providers/google/cloud/transfers/bigquery_to_mssql.py

@@ -21,14 +21,17 @@ from __future__ import annotations
 
 import warnings
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -94,9 +97,13 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
         self.mssql_conn_id = mssql_conn_id
         self.source_project_dataset_table = source_project_dataset_table
 
-    def get_sql_hook(self) -> MsSqlHook:
+    @cached_property
+    def mssql_hook(self) -> MsSqlHook:
         return MsSqlHook(schema=self.database, mssql_conn_id=self.mssql_conn_id)
 
+    def get_sql_hook(self) -> MsSqlHook:
+        return self.mssql_hook
+
     def persist_links(self, context: Context) -> None:
         project_id, dataset_id, table_id = self.source_project_dataset_table.split(".")
         BigQueryTableLink.persist(
@@ -105,3 +112,67 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
             project_id=project_id,
             table_id=table_id,
         )
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table_for_given_fields,
+            get_identity_column_lineage_facet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if not self.bigquery_hook:
+            self.bigquery_hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        try:
+            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
+        except Exception:
+            self.log.debug(
+                "OpenLineage: could not fetch BigQuery table %s",
+                self.source_project_dataset_table,
+                exc_info=True,
+            )
+            return OperatorLineage()
+
+        if self.selected_fields:
+            if isinstance(self.selected_fields, str):
+                bigquery_field_names = list(self.selected_fields)
+            else:
+                bigquery_field_names = self.selected_fields
+        else:
+            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
+
+        input_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=self.source_project_dataset_table,
+            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
+        )
+
+        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
+        default_schema = self.mssql_hook.get_openlineage_default_schema()
+        namespace = f"{db_info.scheme}://{db_info.authority}"
+
+        if self.target_table_name and "." in self.target_table_name:
+            schema_name, table_name = self.target_table_name.split(".", 1)
+        else:
+            schema_name = default_schema or ""
+            table_name = self.target_table_name or ""
+
+        if self.database:
+            output_name = f"{self.database}.{schema_name}.{table_name}"
+        else:
+            output_name = f"{schema_name}.{table_name}"
+
+        column_lineage_facet = get_identity_column_lineage_facet(
+            bigquery_field_names, input_datasets=[input_dataset]
+        )
+
+        output_facets = column_lineage_facet or {}
+        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
+
+        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
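
A minimal usage sketch of the operator above, assuming a DAG context and placeholder connection IDs, project, and table names. With apache-airflow-providers-openlineage installed, the listener invokes get_openlineage_facets_on_complete() after execute(), emitting one BigQuery input dataset and one MSSQL output dataset with identity column lineage.

# Hypothetical DAG snippet; names and connection IDs are placeholders.
from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator

transfer = BigQueryToMsSqlOperator(
    task_id="bq_to_mssql",
    source_project_dataset_table="my-project.my_dataset.my_table",  # hypothetical source table
    target_table_name="dbo.orders",  # hypothetical "schema.table" target
    mssql_conn_id="mssql_default",
    gcp_conn_id="google_cloud_default",
)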

airflow/providers/google/cloud/transfers/bigquery_to_sql.py

@@ -21,6 +21,7 @@ from __future__ import annotations
 
 import abc
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
@@ -113,19 +114,22 @@ class BigQueryToSqlBaseOperator(BaseOperator):
     def persist_links(self, context: Context) -> None:
         """Persist the connection to the SQL provider."""
 
-    def execute(self, context: Context) -> None:
-        big_query_hook = BigQueryHook(
+    @cached_property
+    def bigquery_hook(self) -> BigQueryHook:
+        return BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
+
+    def execute(self, context: Context) -> None:
         self.persist_links(context)
         sql_hook = self.get_sql_hook()
         for rows in bigquery_get_data(
             self.log,
             self.dataset_id,
             self.table_id,
-            big_query_hook,
+            self.bigquery_hook,
             self.batch_size,
             self.selected_fields,
         ):
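
The refactor above turns the BigQuery hook into a cached_property on the base operator, so subclasses only supply the destination hook. A hypothetical subclass sketch; the class name, MySQL destination, and connection ID below are illustrative and not part of this diff.

from functools import cached_property

from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook  # assumed destination hook


class BigQueryToMySqlExampleOperator(BigQueryToSqlBaseOperator):
    @cached_property
    def mysql_hook(self) -> MySqlHook:
        # Destination hook; self.database is supplied by the base operator.
        return MySqlHook(schema=self.database, mysql_conn_id="mysql_default")

    def get_sql_hook(self) -> MySqlHook:
        return self.mysql_hook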

airflow/providers/google/cloud/transfers/gcs_to_bigquery.py

@@ -635,12 +635,6 @@ class GCSToBigQueryOperator(BaseOperator):
             self.configuration["load"]["schema"] = {"fields": self.schema_fields}
 
         if self.schema_update_options:
-            if self.write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
-                raise ValueError(
-                    "schema_update_options is only "
-                    "allowed if write_disposition is "
-                    "'WRITE_APPEND' or 'WRITE_TRUNCATE'."
-                )
             # To provide backward compatibility
             self.schema_update_options = list(self.schema_update_options or [])
             self.log.info("Adding experimental 'schemaUpdateOptions': %s", self.schema_update_options)

airflow/providers/google/cloud/transfers/gcs_to_sftp.py

@@ -80,7 +80,9 @@ class GCSToSFTPOperator(BaseOperator):
     :param destination_path: The sftp remote path. This is the specified directory path for
         uploading to the SFTP server.
     :param keep_directory_structure: (Optional) When set to False the path of the file
-        on the bucket is recreated within path passed in destination_path.
+        on the bucket is recreated within path passed in destination_path.
+    :param create_intermediate_dirs: (Optional) When set to True the intermediate directories
+        in the specified file path will be created.
     :param move_object: When move object is True, the object is moved instead
         of copied to the new location. This is the equivalent of a mv command
         as opposed to a cp command.
@@ -112,6 +114,7 @@
         source_object: str,
         destination_path: str,
         keep_directory_structure: bool = True,
+        create_intermediate_dirs: bool = True,
         move_object: bool = False,
         gcp_conn_id: str = "google_cloud_default",
         sftp_conn_id: str = "ssh_default",
@@ -124,6 +127,7 @@
         self.source_object = source_object
         self.destination_path = destination_path
         self.keep_directory_structure = keep_directory_structure
+        self.create_intermediate_dirs = create_intermediate_dirs
         self.move_object = move_object
         self.gcp_conn_id = gcp_conn_id
         self.sftp_conn_id = sftp_conn_id
@@ -190,7 +194,9 @@
             )
 
         dir_path = os.path.dirname(destination_path)
-        sftp_hook.create_directory(dir_path)
+
+        if self.create_intermediate_dirs:
+            sftp_hook.create_directory(dir_path)
 
         with NamedTemporaryFile("w") as tmp:
             gcs_hook.download(
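
A usage sketch for the new create_intermediate_dirs flag, with placeholder bucket, object, and connection IDs. Setting it to False skips the create_directory call and lets the transfer fail when the destination directory does not already exist on the SFTP server.

from airflow.providers.google.cloud.transfers.gcs_to_sftp import GCSToSFTPOperator

copy_file = GCSToSFTPOperator(
    task_id="gcs_to_sftp",
    source_bucket="my-bucket",  # hypothetical bucket
    source_object="reports/2024/data.csv",
    destination_path="/upload/reports",  # remote directory on the SFTP server
    create_intermediate_dirs=True,  # default; create missing parent directories
    sftp_conn_id="ssh_default",
    gcp_conn_id="google_cloud_default",
)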

airflow/providers/google/cloud/transfers/oracle_to_gcs.py

@@ -21,12 +21,17 @@ import base64
 import calendar
 from datetime import date, datetime, timedelta
 from decimal import Decimal
+from functools import cached_property
+from typing import TYPE_CHECKING
 
 import oracledb
 
 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
 from airflow.providers.oracle.hooks.oracle import OracleHook
 
+if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
+
 
 class OracleToGCSOperator(BaseSQLToGCSOperator):
     """
@@ -62,10 +67,13 @@ class OracleToGCSOperator(BaseSQLToGCSOperator):
         self.ensure_utc = ensure_utc
         self.oracle_conn_id = oracle_conn_id
 
+    @cached_property
+    def db_hook(self) -> OracleHook:
+        return OracleHook(oracle_conn_id=self.oracle_conn_id)
+
     def query(self):
         """Query Oracle and returns a cursor to the results."""
-        oracle = OracleHook(oracle_conn_id=self.oracle_conn_id)
-        conn = oracle.get_conn()
+        conn = self.db_hook.get_conn()
         cursor = conn.cursor()
         if self.ensure_utc:
             # Ensure TIMESTAMP results are in UTC
@@ -121,3 +129,20 @@ class OracleToGCSOperator(BaseSQLToGCSOperator):
         else:
             value = base64.standard_b64encode(value).decode("ascii")
         return value
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        sql_parsing_result = get_openlineage_facets_with_sql(
+            hook=self.db_hook,
+            sql=self.sql,
+            conn_id=self.oracle_conn_id,
+            database=self.db_hook.service_name or self.db_hook.sid,
+        )
+        gcs_output_datasets = self._get_openlineage_output_datasets()
+        if sql_parsing_result:
+            sql_parsing_result.outputs = gcs_output_datasets
+            return sql_parsing_result
+        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
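
A usage sketch with placeholder SQL, bucket, and connection IDs. When the OpenLineage provider is installed, get_openlineage_facets_on_start() above parses self.sql against the cached db_hook and attaches the generated GCS files as output datasets.

from airflow.providers.google.cloud.transfers.oracle_to_gcs import OracleToGCSOperator

export = OracleToGCSOperator(
    task_id="oracle_to_gcs",
    sql="SELECT * FROM sales",  # hypothetical query
    bucket="my-bucket",  # hypothetical bucket
    filename="exports/sales_{}.json",
    oracle_conn_id="oracle_default",
    gcp_conn_id="google_cloud_default",
)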

airflow/providers/google/cloud/transfers/postgres_to_gcs.py

@@ -31,7 +31,7 @@ import pendulum
 from slugify import slugify
 
 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
-from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.providers.postgres.hooks.postgres import USE_PSYCOPG3, PostgresHook
 
 if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors import OperatorLineage
@@ -52,9 +52,20 @@ class _PostgresServerSideCursorDecorator:
         self.initialized = False
 
     def __iter__(self):
+        """Make the cursor iterable."""
         return self
 
     def __next__(self):
+        """Fetch next row from the cursor."""
+        if USE_PSYCOPG3:
+            if self.rows:
+                return self.rows.pop()
+            self.initialized = True
+            row = self.cursor.fetchone()
+            if row is None:
+                raise StopIteration
+            return row
+        # psycopg2
         if self.rows:
             return self.rows.pop()
         self.initialized = True
@@ -141,13 +152,29 @@ class PostgresToGCSOperator(BaseSQLToGCSOperator):
         return PostgresHook(postgres_conn_id=self.postgres_conn_id)
 
     def query(self):
-        """Query Postgres and returns a cursor to the results."""
+        """Execute the query and return a cursor."""
         conn = self.db_hook.get_conn()
-        cursor = conn.cursor(name=self._unique_name())
-        cursor.execute(self.sql, self.parameters)
-        if self.use_server_side_cursor:
-            cursor.itersize = self.cursor_itersize
-            return _PostgresServerSideCursorDecorator(cursor)
+
+        if USE_PSYCOPG3:
+            from psycopg.types.json import register_default_adapters
+
+            # Register JSON handlers for this connection if not already done
+            register_default_adapters(conn)
+
+            if self.use_server_side_cursor:
+                cursor_name = f"airflow_{self.task_id.replace('-', '_')}_{uuid.uuid4().hex}"[:63]
+                cursor = conn.cursor(name=cursor_name)
+                cursor.itersize = self.cursor_itersize
+                cursor.execute(self.sql, self.parameters)
+                return _PostgresServerSideCursorDecorator(cursor)
+            cursor = conn.cursor()
+            cursor.execute(self.sql, self.parameters)
+        else:
+            cursor = conn.cursor(name=self._unique_name())
+            cursor.execute(self.sql, self.parameters)
+            if self.use_server_side_cursor:
+                cursor.itersize = self.cursor_itersize
+                return _PostgresServerSideCursorDecorator(cursor)
         return cursor
 
     def field_to_bigquery(self, field) -> dict[str, str]:
@@ -182,8 +209,14 @@ class PostgresToGCSOperator(BaseSQLToGCSOperator):
                 hours=formatted_time.tm_hour, minutes=formatted_time.tm_min, seconds=formatted_time.tm_sec
             )
             return str(time_delta)
-        if stringify_dict and isinstance(value, dict):
-            return json.dumps(value)
+        if stringify_dict:
+            if USE_PSYCOPG3:
+                from psycopg.types.json import Json
+
+                if isinstance(value, (dict, Json)):
+                    return json.dumps(value)
+            elif isinstance(value, dict):
+                return json.dumps(value)
         if isinstance(value, Decimal):
             return float(value)
         return value
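
DAG code is unchanged by the psycopg2/psycopg3 branching above; the driver is selected via USE_PSYCOPG3 from the postgres provider. A usage sketch with placeholder query, bucket, and connection IDs:

from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator

export = PostgresToGCSOperator(
    task_id="pg_to_gcs",
    sql="SELECT * FROM public.orders",  # hypothetical query
    bucket="my-bucket",  # hypothetical bucket
    filename="exports/orders_{}.ndjson",
    use_server_side_cursor=True,  # streams rows through a named cursor on either driver
    postgres_conn_id="postgres_default",
    gcp_conn_id="google_cloud_default",
)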

airflow/providers/google/cloud/triggers/bigquery.py

@@ -25,16 +25,18 @@ from aiohttp.client_exceptions import ClientResponseError
 from asgiref.sync import sync_to_async
 
 from airflow.exceptions import AirflowException
-from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
 from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+if not AIRFLOW_V_3_0_PLUS:
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.session import provide_session
+
 
 class BigQueryInsertJobTrigger(BaseTrigger):
     """
@@ -99,24 +101,26 @@ class BigQueryInsertJobTrigger(BaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
-        )
-        task_instance = query.one_or_none()
-        if task_instance is None:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
             )
-        return task_instance
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
 
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

airflow/providers/google/cloud/triggers/cloud_composer.py

@@ -145,10 +145,23 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger):
             )
             return
 
+        exit_code = result.get("exit_info", {}).get("exit_code")
+
+        if exit_code == 0:
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "result": result,
+                }
+            )
+            return
+
+        error_output = "".join(line["content"] for line in result.get("error", []))
+        message = f"Airflow CLI command failed with exit code {exit_code}.\nError output:\n{error_output}"
         yield TriggerEvent(
             {
-                "status": "success",
-                "result": result,
+                "status": "error",
+                "message": message,
             }
         )
         return
@@ -166,6 +179,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         start_date: datetime,
         end_date: datetime,
         allowed_states: list[str],
+        composer_dag_run_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         poll_interval: int = 10,
@@ -179,6 +193,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         self.start_date = start_date
         self.end_date = end_date
         self.allowed_states = allowed_states
+        self.composer_dag_run_id = composer_dag_run_id
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.poll_interval = poll_interval
@@ -200,6 +215,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 "start_date": self.start_date,
                 "end_date": self.end_date,
                 "allowed_states": self.allowed_states,
+                "composer_dag_run_id": self.composer_dag_run_id,
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "poll_interval": self.poll_interval,
@@ -248,20 +264,42 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 return False
         return True
 
+    def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
+        for dag_run in dag_runs:
+            if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
+                return True
+        return False
+
     async def run(self):
         try:
             while True:
                 if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp():
                     dag_runs = await self._pull_dag_runs()
 
-                    self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
-                    if self._check_dag_runs_states(
-                        dag_runs=dag_runs,
-                        start_date=self.start_date,
-                        end_date=self.end_date,
-                    ):
-                        yield TriggerEvent({"status": "success"})
-                        return
+                    if len(dag_runs) == 0:
+                        self.log.info("Dag runs are empty. Sensor waits for dag runs...")
+                        self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                        await asyncio.sleep(self.poll_interval)
+                        continue
+
+                    if self.composer_dag_run_id:
+                        self.log.info(
+                            "Sensor waits for allowed states %s for specified RunID: %s",
+                            self.allowed_states,
+                            self.composer_dag_run_id,
+                        )
+                        if self._check_composer_dag_run_id_states(dag_runs=dag_runs):
+                            yield TriggerEvent({"status": "success"})
+                            return
+                    else:
+                        self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+                        if self._check_dag_runs_states(
+                            dag_runs=dag_runs,
+                            start_date=self.start_date,
+                            end_date=self.end_date,
+                        ):
+                            yield TriggerEvent({"status": "success"})
+                            return
                 self.log.info("Sleeping for %s seconds.", self.poll_interval)
                 await asyncio.sleep(self.poll_interval)
         except AirflowException as ex:
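
A sketch of constructing the trigger with the new composer_dag_run_id filter, which narrows the wait to a single DAG run instead of every run in the date window. The environment-related arguments (project_id, region, environment_id, composer_dag_id) are not shown in this hunk and are assumed here; all values are placeholders.

from datetime import datetime, timedelta, timezone

from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerDAGRunTrigger

trigger = CloudComposerDAGRunTrigger(
    project_id="my-project",  # assumed parameter, placeholder value
    region="us-central1",  # assumed parameter
    environment_id="my-composer-env",  # assumed parameter
    composer_dag_id="downstream_dag",  # assumed parameter
    start_date=datetime.now(timezone.utc) - timedelta(hours=1),
    end_date=datetime.now(timezone.utc),
    allowed_states=["success"],
    composer_dag_run_id="manual__2024-01-01T00:00:00+00:00",  # hypothetical run_id
)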

airflow/providers/google/cloud/triggers/dataplex.py

@@ -103,7 +103,13 @@ class DataplexDataQualityJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": self._convert_to_dict(job)})
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
@@ -185,7 +191,13 @@ class DataplexDataProfileJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": self._convert_to_dict(job)})
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""