apache-airflow-providers-google 15.0.1rc1__py3-none-any.whl → 15.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +3 -5
- airflow/providers/google/cloud/hooks/cloud_batch.py +3 -4
- airflow/providers/google/cloud/hooks/cloud_sql.py +34 -41
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +1 -1
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -3
- airflow/providers/google/cloud/hooks/dataflow.py +11 -6
- airflow/providers/google/cloud/hooks/datafusion.py +3 -4
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dlp.py +1 -1
- airflow/providers/google/cloud/hooks/gcs.py +5 -6
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +1 -2
- airflow/providers/google/cloud/hooks/managed_kafka.py +1 -1
- airflow/providers/google/cloud/hooks/mlengine.py +4 -6
- airflow/providers/google/cloud/hooks/stackdriver.py +4 -6
- airflow/providers/google/cloud/hooks/vision.py +1 -2
- airflow/providers/google/cloud/openlineage/mixins.py +2 -3
- airflow/providers/google/cloud/operators/alloy_db.py +1 -2
- airflow/providers/google/cloud/operators/automl.py +5 -5
- airflow/providers/google/cloud/operators/bigquery.py +24 -26
- airflow/providers/google/cloud/operators/cloud_batch.py +13 -15
- airflow/providers/google/cloud/operators/cloud_build.py +1 -2
- airflow/providers/google/cloud/operators/cloud_composer.py +24 -28
- airflow/providers/google/cloud/operators/cloud_run.py +12 -13
- airflow/providers/google/cloud/operators/cloud_sql.py +42 -49
- airflow/providers/google/cloud/operators/compute.py +9 -10
- airflow/providers/google/cloud/operators/dataproc.py +23 -26
- airflow/providers/google/cloud/operators/functions.py +5 -7
- airflow/providers/google/cloud/operators/kubernetes_engine.py +1 -2
- airflow/providers/google/cloud/operators/spanner.py +29 -33
- airflow/providers/google/cloud/sensors/cloud_composer.py +3 -5
- airflow/providers/google/cloud/sensors/dataflow.py +1 -1
- airflow/providers/google/cloud/sensors/dataproc.py +5 -5
- airflow/providers/google/cloud/sensors/gcs.py +15 -16
- airflow/providers/google/cloud/sensors/looker.py +3 -3
- airflow/providers/google/cloud/sensors/pubsub.py +13 -14
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +7 -8
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +16 -20
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -2
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +14 -16
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +2 -3
- airflow/providers/google/cloud/utils/bigquery.py +5 -7
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/common/hooks/base_google.py +1 -1
- airflow/providers/google/common/hooks/operation_helpers.py +1 -2
- airflow/providers/google/common/utils/id_token_credentials.py +1 -1
- airflow/providers/google/leveldb/hooks/leveldb.py +4 -5
- {apache_airflow_providers_google-15.0.1rc1.dist-info → apache_airflow_providers_google-15.1.0.dist-info}/METADATA +9 -9
- {apache_airflow_providers_google-15.0.1rc1.dist-info → apache_airflow_providers_google-15.1.0.dist-info}/RECORD +51 -51
- {apache_airflow_providers_google-15.0.1rc1.dist-info → apache_airflow_providers_google-15.1.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.0.1rc1.dist-info → apache_airflow_providers_google-15.1.0.dist-info}/entry_points.txt +0 -0

airflow/providers/google/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "15.0.1"
+__version__ = "15.1.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.9.0"

airflow/providers/google/cloud/hooks/bigquery.py
@@ -123,7 +123,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         from airflow.providers.google.cloud.utils.validators import ValidJson
 
         connection_form_widgets = super().get_connection_form_widgets()
-        connection_form_widgets["use_legacy_sql"] = BooleanField(lazy_gettext("Use Legacy SQL")…
+        connection_form_widgets["use_legacy_sql"] = BooleanField(lazy_gettext("Use Legacy SQL"))
         connection_form_widgets["location"] = StringField(
             lazy_gettext("Location"), widget=BS3TextFieldWidget()
         )

@@ -1376,8 +1376,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         def var_print(var_name):
             if var_name is None:
                 return ""
-            else:
-                return f"Format exception for {var_name}: "
+            return f"Format exception for {var_name}: "
 
         if table_input.count(".") + table_input.count(":") > 3:
             raise ValueError(f"{var_print(var_name)}Use either : or . to specify project got {table_input}")

@@ -1955,8 +1954,7 @@ def split_tablename(
     def var_print(var_name):
         if var_name is None:
             return ""
-        else:
-            return f"Format exception for {var_name}: "
+        return f"Format exception for {var_name}: "
 
     if table_input.count(".") + table_input.count(":") > 3:
         raise ValueError(f"{var_print(var_name)}Use either : or . to specify project got {table_input}")

airflow/providers/google/cloud/hooks/cloud_batch.py
@@ -147,19 +147,18 @@ class CloudBatchHook(GoogleBaseHook):
                 status: JobStatus.State = job.status.state
                 if status == JobStatus.State.SUCCEEDED:
                     return job
-                elif status == JobStatus.State.FAILED:
+                if status == JobStatus.State.FAILED:
                     message = (
                         "Unexpected error in the operation: "
                         "Batch job with name {job_name} has failed its execution."
                     )
                     raise AirflowException(message)
-                elif status == JobStatus.State.DELETION_IN_PROGRESS:
+                if status == JobStatus.State.DELETION_IN_PROGRESS:
                     message = (
                         "Unexpected error in the operation: Batch job with name {job_name} is being deleted."
                     )
                     raise AirflowException(message)
-                else:
-                    time.sleep(polling_period_seconds)
+                time.sleep(polling_period_seconds)
             except Exception as e:
                 self.log.exception("Exception occurred while checking for job completion.")
                 raise e

airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -634,35 +634,32 @@ class CloudSqlProxyRunner(LoggingMixin):
         self._download_sql_proxy_if_needed()
         if self.sql_proxy_process:
             raise AirflowException(f"The sql proxy is already running: {self.sql_proxy_process}")
-        … (old lines 637-662 not captured in the source rendering)
-                    raise AirflowException(f"Error when starting the cloud_sql_proxy {line}!")
-                if "Ready for new connections" in line:
-                    return
+        command_to_run = [self.sql_proxy_path]
+        command_to_run.extend(self.command_line_parameters)
+        self.log.info("Creating directory %s", self.cloud_sql_proxy_socket_directory)
+        Path(self.cloud_sql_proxy_socket_directory).mkdir(parents=True, exist_ok=True)
+        command_to_run.extend(self._get_credential_parameters())
+        self.log.info("Running the command: `%s`", " ".join(command_to_run))
+
+        self.sql_proxy_process = Popen(command_to_run, stdin=PIPE, stdout=PIPE, stderr=PIPE)
+        self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid)
+        while True:
+            line = (
+                self.sql_proxy_process.stderr.readline().decode("utf-8")
+                if self.sql_proxy_process.stderr
+                else ""
+            )
+            return_code = self.sql_proxy_process.poll()
+            if line == "" and return_code is not None:
+                self.sql_proxy_process = None
+                raise AirflowException(f"The cloud_sql_proxy finished early with return code {return_code}!")
+            if line != "":
+                self.log.info(line)
+            if "googleapi: Error" in line or "invalid instance name:" in line:
+                self.stop_proxy()
+                raise AirflowException(f"Error when starting the cloud_sql_proxy {line}!")
+            if "Ready for new connections" in line:
+                return
 
     def stop_proxy(self) -> None:
         """

@@ -672,10 +669,9 @@ class CloudSqlProxyRunner(LoggingMixin):
         """
         if not self.sql_proxy_process:
             raise AirflowException("The sql proxy is not started yet")
-        else:
-            self.log.info("Stopping the cloud_sql_proxy pid: %s", self.sql_proxy_process.pid)
-            self.sql_proxy_process.kill()
-            self.sql_proxy_process = None
+        self.log.info("Stopping the cloud_sql_proxy pid: %s", self.sql_proxy_process.pid)
+        self.sql_proxy_process.kill()
+        self.sql_proxy_process = None
         # Cleanup!
         self.log.info("Removing the socket directory: %s", self.cloud_sql_proxy_socket_directory)
         shutil.rmtree(self.cloud_sql_proxy_socket_directory, ignore_errors=True)

@@ -704,8 +700,7 @@ class CloudSqlProxyRunner(LoggingMixin):
         matched = re.search("[Vv]ersion (.*?);", result)
         if matched:
             return matched.group(1)
-        else:
-            return None
+        return None
 
     def get_socket_path(self) -> str:
         """

@@ -908,10 +903,9 @@ class CloudSQLDatabaseHook(BaseHook):
         secret_data = json.loads(base64.b64decode(secret.payload.data))
         if cert_name in secret_data:
             return secret_data[cert_name]
-        else:
-            raise AirflowException(
-                "Invalid secret format. Expected dictionary with keys: `sslcert`, `sslkey`, `sslrootcert`"
-            )
+        raise AirflowException(
+            "Invalid secret format. Expected dictionary with keys: `sslcert`, `sslkey`, `sslrootcert`"
+        )
 
     def _set_temporary_ssl_file(
         self, cert_name: str, cert_path: str | None = None, cert_value: str | None = None

@@ -1205,8 +1199,7 @@ class CloudSQLDatabaseHook(BaseHook):
 
         if self.database_type == "postgres":
             return self.cloudsql_connection.login.split(".gserviceaccount.com")[0]
-        else:
-            return self.cloudsql_connection.login.split("@")[0]
+        return self.cloudsql_connection.login.split("@")[0]
 
     def _generate_login_token(self, service_account) -> str:
         """Generate an IAM login token for Cloud SQL and return the token."""

airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -218,7 +218,7 @@ class CloudDataTransferServiceHook(GoogleBaseHook):
                     return (
                         self.get_conn().transferJobs().create(body=body).execute(num_retries=self.num_retries)
                     )
-                elif transfer_job.get(STATUS) == GcpTransferJobsStatus.DISABLED:
+                if transfer_job.get(STATUS) == GcpTransferJobsStatus.DISABLED:
                     return self.enable_transfer_job(job_name=job_name, project_id=body.get(PROJECT_ID))
             else:
                 raise e

airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -148,8 +148,7 @@ class ComputeEngineSSHHook(SSHHook):
             return ComputeEngineHook(
                 gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
             )
-        else:
-            return ComputeEngineHook(gcp_conn_id=self.gcp_conn_id)
+        return ComputeEngineHook(gcp_conn_id=self.gcp_conn_id)
 
     def _load_connection_config(self):
         def _boolify(value):

@@ -158,7 +157,7 @@ class ComputeEngineSSHHook(SSHHook):
             if isinstance(value, str):
                 if value.lower() == "false":
                     return False
-                elif value.lower() == "true":
+                if value.lower() == "true":
                     return True
             return False
 

airflow/providers/google/cloud/hooks/dataflow.py
@@ -51,6 +51,7 @@ from googleapiclient.discovery import Resource, build
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
+from airflow.providers.google.common.deprecated import deprecated
 from airflow.providers.google.common.hooks.base_google import (
     PROVIDE_PROJECT_ID,
     GoogleBaseAsyncHook,

@@ -261,15 +262,14 @@ class _DataflowJobsController(LoggingMixin):
         """
         if not self._multiple_jobs and self._job_id:
             return [self.fetch_job_by_id(self._job_id)]
-        elif self._jobs:
+        if self._jobs:
             return [self.fetch_job_by_id(job["id"]) for job in self._jobs]
-        elif self._job_name:
+        if self._job_name:
             jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())
             if len(jobs) == 1:
                 self._job_id = jobs[0]["id"]
             return jobs
-        else:
-            raise ValueError("Missing both dataflow job ID and name.")
+        raise ValueError("Missing both dataflow job ID and name.")
 
     def fetch_job_by_id(self, job_id: str) -> dict[str, str]:
         """

@@ -434,12 +434,12 @@ class _DataflowJobsController(LoggingMixin):
                 f"'{current_expected_state}' is invalid."
                 f" The value should be any of the following: {terminal_states}"
             )
-        elif is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DONE:
+        if is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DONE:
             raise AirflowException(
                 "Google Cloud Dataflow job's expected terminal state cannot be "
                 "JOB_STATE_DONE while it is a streaming job"
             )
-        elif not is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DRAINED:
+        if not is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DRAINED:
             raise AirflowException(
                 "Google Cloud Dataflow job's expected terminal state cannot be "
                 "JOB_STATE_DRAINED while it is a batch job"

@@ -1063,6 +1063,11 @@ class DataflowHook(GoogleBaseHook):
         )
         jobs_controller.cancel()
 
+    @deprecated(
+        planned_removal_date="July 01, 2025",
+        use_instead="airflow.providers.google.cloud.hooks.dataflow.DataflowHook.launch_beam_yaml_job",
+        category=AirflowProviderDeprecationWarning,
+    )
     @GoogleBaseHook.fallback_to_default_project_id
     def start_sql_job(
         self,
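
The hunk above marks `DataflowHook.start_sql_job` for planned removal and points callers at `DataflowHook.launch_beam_yaml_job`. As a rough, illustrative analogue of what such a decorator does (this is not the provider's actual `deprecated` helper, whose internals are not part of this diff), a planned-removal warning can be emitted with the standard `warnings` module:

```python
import functools
import warnings


def deprecated_sketch(*, planned_removal_date: str, use_instead: str, category=DeprecationWarning):
    """Illustrative stand-in for the provider's `deprecated` decorator (assumed behaviour)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__qualname__} is deprecated and is planned for removal after "
                f"{planned_removal_date}; use {use_instead} instead.",
                category,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```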

airflow/providers/google/cloud/hooks/datafusion.py
@@ -170,9 +170,9 @@ class DataFusionHook(GoogleBaseHook):
     def _check_response_status_and_data(response, message: str) -> None:
         if response.status == 404:
             raise AirflowNotFoundException(message)
-        elif response.status == 409:
+        if response.status == 409:
             raise ConflictException("Conflict: Resource is still in use.")
-        elif response.status != 200:
+        if response.status != 200:
             raise AirflowException(message)
         if response.data is None:
             raise AirflowException(

@@ -572,8 +572,7 @@ class DataFusionAsyncHook(GoogleBaseAsyncHook):
             raise
         if pipeline:
             return pipeline
-        else:
-            raise AirflowException("Could not retrieve pipeline. Aborting.")
+        raise AirflowException("Could not retrieve pipeline. Aborting.")
 
     async def get_pipeline(
         self,

airflow/providers/google/cloud/hooks/dataprep.py
@@ -31,7 +31,7 @@ from tenacity import retry, stop_after_attempt, wait_exponential
 from airflow.hooks.base import BaseHook
 
 
-def _get_field(extras: dict, field_name: str):
+def _get_field(extras: dict, field_name: str) -> str | None:
     """Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
     backcompat_prefix = "extra__dataprep__"
     if field_name.startswith("extra__"):
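
For context, the annotated `_get_field` helper resolves a Dataprep connection extra by its short name first and only then falls back to the legacy `extra__dataprep__` prefix, as its docstring states. A minimal sketch of that lookup order (standalone illustration, not the hook's exact body):

```python
def get_field_sketch(extras: dict, field_name: str) -> str | None:
    """Return a connection extra, preferring the short key over the prefixed one."""
    backcompat_prefix = "extra__dataprep__"
    if field_name.startswith("extra__"):
        raise ValueError(f"Got prefixed name {field_name}; please use the short name instead.")
    if field_name in extras:
        return extras[field_name]
    # Fall back to the old prefixed key for backwards compatibility.
    return extras.get(f"{backcompat_prefix}{field_name}")


assert get_field_sketch({"token": "abc"}, "token") == "abc"
assert get_field_sketch({"extra__dataprep__token": "abc"}, "token") == "abc"
```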

airflow/providers/google/cloud/hooks/gcs.py
@@ -358,11 +358,10 @@ class GCSHook(GoogleBaseHook):
                     )
                     self.log.info("File downloaded to %s", filename)
                     return filename
-                else:
-                    get_hook_lineage_collector().add_input_asset(
-                        context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
-                    )
-                    return blob.download_as_bytes()
+                get_hook_lineage_collector().add_input_asset(
+                    context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
+                )
+                return blob.download_as_bytes()
 
             except GoogleCloudError:
                 if attempt == num_max_attempts - 1:

@@ -556,7 +555,7 @@ class GCSHook(GoogleBaseHook):
                 "specify a single parameter, either 'filename' for "
                 "local file uploads or 'data' for file content uploads."
             )
-        elif filename:
+        if filename:
             if not mime_type:
                 mime_type = "application/octet-stream"
             if gzip:

airflow/providers/google/cloud/hooks/managed_kafka.py
@@ -120,7 +120,7 @@ class ManagedKafkaHook(GoogleBaseHook):
             error = operation.exception(timeout=timeout)
             raise AirflowException(error)
 
-    def get_confluent_token(self):
+    def get_confluent_token(self, config_str: str):
         """Get the authentication token for confluent client."""
         token_provider = ManagedKafkaTokenProvider(credentials=self.get_credentials())
         token = token_provider.confluent_token()
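
`get_confluent_token` gains a positional `config_str` parameter. In confluent-kafka-python, an `oauth_cb` callable is invoked with the `sasl.oauthbearer.config` string and is expected to return a `(token, expiry)` pair, so the extra argument lets the hook method be passed straight through as that callback. A hedged wiring sketch (placeholder bootstrap address; assumes the hook's return value already matches the callback contract):

```python
from confluent_kafka import Producer

from airflow.providers.google.cloud.hooks.managed_kafka import ManagedKafkaHook

hook = ManagedKafkaHook(gcp_conn_id="google_cloud_default")

producer = Producer(
    {
        "bootstrap.servers": "<bootstrap-address>:9092",  # placeholder
        "security.protocol": "SASL_SSL",
        "sasl.mechanisms": "OAUTHBEARER",
        # librdkafka calls this with the oauthbearer config string,
        # which is why the hook method now takes `config_str`.
        "oauth_cb": hook.get_confluent_token,
    }
)
```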

airflow/providers/google/cloud/hooks/mlengine.py
@@ -78,8 +78,7 @@ def _poll_with_exponential_delay(
             if e.resp.status != 429:
                 log.info("Something went wrong. Not retrying: %s", format(e))
                 raise
-            else:
-                time.sleep((2**i) + random.random())
+            time.sleep((2**i) + random.random())
 
     raise ValueError(f"Connection could not be established after {max_n} retries.")
 

@@ -219,12 +218,11 @@ class MLEngineHook(GoogleBaseHook):
             if e.resp.status == 404:
                 self.log.error("Job with job_id %s does not exist. ", job_id)
                 raise
-            elif e.resp.status == 400:
+            if e.resp.status == 400:
                 self.log.info("Job with job_id %s is already complete, cancellation aborted.", job_id)
                 return {}
-            else:
-                self.log.error("Failed to cancel MLEngine job: %s", e)
-                raise
+            self.log.error("Failed to cancel MLEngine job: %s", e)
+            raise
 
     def get_job(self, project_id: str, job_id: str) -> dict:
         """

airflow/providers/google/cloud/hooks/stackdriver.py
@@ -121,10 +121,9 @@ class StackdriverHook(GoogleBaseHook):
         )
         if format_ == "dict":
             return [AlertPolicy.to_dict(policy) for policy in policies_]
-        elif format_ == "json":
+        if format_ == "json":
             return [AlertPolicy.to_jsoon(policy) for policy in policies_]
-        else:
-            return policies_
+        return policies_
 
     @GoogleBaseHook.fallback_to_default_project_id
     def _toggle_policy_status(

@@ -395,10 +394,9 @@ class StackdriverHook(GoogleBaseHook):
         )
         if format_ == "dict":
             return [NotificationChannel.to_dict(channel) for channel in channels]
-        elif format_ == "json":
+        if format_ == "json":
             return [NotificationChannel.to_json(channel) for channel in channels]
-        else:
-            return channels
+        return channels
 
     @GoogleBaseHook.fallback_to_default_project_id
     def _toggle_channel_status(

airflow/providers/google/cloud/hooks/vision.py
@@ -107,8 +107,7 @@ class NameDeterminer:
         # Not enough parameters to construct the name. Trying to use the name from Product / ProductSet.
         if explicit_name:
            return entity
-        else:
-            raise AirflowException(ERR_UNABLE_TO_CREATE.format(label=self.label, id_label=self.id_label))
+        raise AirflowException(ERR_UNABLE_TO_CREATE.format(label=self.label, id_label=self.id_label))
 
 
 class CloudVisionHook(GoogleBaseHook):

airflow/providers/google/cloud/openlineage/mixins.py
@@ -207,15 +207,14 @@ class _BigQueryInsertJobOperatorOpenLineageMixin:
                 name=dataset_name,
                 facets=dataset_facets,
             )
-        elif dataset_type == "output":
+        if dataset_type == "output":
             # Logic specific to creating OutputDataset (if needed)
             return OutputDataset(
                 namespace=BIGQUERY_NAMESPACE,
                 name=dataset_name,
                 facets=dataset_facets,
             )
-        else:
-            raise ValueError("Invalid dataset_type. Must be 'input' or 'output'")
+        raise ValueError("Invalid dataset_type. Must be 'input' or 'output'")
 
     def _get_table_facets_safely(self, table_name: str) -> dict[str, DatasetFacet]:
         try:

airflow/providers/google/cloud/operators/alloy_db.py
@@ -145,8 +145,7 @@ class AlloyDBWriteBaseOperator(AlloyDBBaseOperator):
         if self.validate_request:
             # Validation requests are only validated and aren't executed, thus no operation result is expected
             return None
-        else:
-            return self.hook.wait_for_operation(timeout=self.timeout, operation=operation)
+        return self.hook.wait_for_operation(timeout=self.timeout, operation=operation)
 
 
 class AlloyDBCreateClusterOperator(AlloyDBWriteBaseOperator):

airflow/providers/google/cloud/operators/automl.py
@@ -259,11 +259,11 @@ class AutoMLPredictOperator(GoogleCloudBaseOperator):
                 gcp_conn_id=self.gcp_conn_id,
                 impersonation_chain=self.impersonation_chain,
             )
-        else:  # endpoint_id defined
-            return PredictionServiceHook(
-                gcp_conn_id=self.gcp_conn_id,
-                impersonation_chain=self.impersonation_chain,
-            )
+        # endpoint_id defined
+        return PredictionServiceHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
 
     @cached_property
     def model(self) -> Model | None:

airflow/providers/google/cloud/operators/bigquery.py
@@ -287,7 +287,7 @@ class BigQueryCheckOperator(
     def _validate_records(self, records) -> None:
         if not records:
             raise AirflowException(f"The following query returned zero rows: {self.sql}")
-        elif not all(records):
+        if not all(records):
             self._raise_exception(  # type: ignore[attr-defined]
                 f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
             )
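
For reference, the check above fails a `BigQueryCheckOperator` run either when the query returns no rows or when any value in the first returned row is falsy; a standalone illustration of that logic:

```python
records_ok = [1, "2024-01-01", True]   # all truthy -> check passes
records_empty = []                     # no rows -> AirflowException
records_falsy = [0, "2024-01-01"]      # falsy value -> test failure

assert all(records_ok)
assert not records_empty
assert not all(records_falsy)
```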

@@ -2976,14 +2976,13 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryInsertJobOpera
                     f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
                 )
 
-            else:
-                # Job already reached state DONE
-                if job.state == "DONE":
-                    raise AirflowException("Job is already in state DONE. Can not reattach to this job.")
+            # Job already reached state DONE
+            if job.state == "DONE":
+                raise AirflowException("Job is already in state DONE. Can not reattach to this job.")
 
-                # We are reattaching to a job
-                self.log.info("Reattaching to existing Job in state %s", job.state)
-                self._handle_job_error(job)
+            # We are reattaching to a job
+            self.log.info("Reattaching to existing Job in state %s", job.state)
+            self._handle_job_error(job)
 
         job_types = {
             LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"],

@@ -3036,24 +3035,23 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryInsertJobOpera
             self._handle_job_error(job)
 
             return self.job_id
-        else:
-            if job.running():
-                self.defer(
-                    timeout=self.execution_timeout,
-                    trigger=BigQueryInsertJobTrigger(
-                        conn_id=self.gcp_conn_id,
-                        job_id=self.job_id,
-                        project_id=self.project_id,
-                        location=self.location or hook.location,
-                        poll_interval=self.poll_interval,
-                        impersonation_chain=self.impersonation_chain,
-                        cancel_on_kill=self.cancel_on_kill,
-                    ),
-                    method_name="execute_complete",
-                )
-            self.log.info("Current state of job %s is %s", job.job_id, job.state)
-            self._handle_job_error(job)
-            return self.job_id
+        if job.running():
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=BigQueryInsertJobTrigger(
+                    conn_id=self.gcp_conn_id,
+                    job_id=self.job_id,
+                    project_id=self.project_id,
+                    location=self.location or hook.location,
+                    poll_interval=self.poll_interval,
+                    impersonation_chain=self.impersonation_chain,
+                    cancel_on_kill=self.cancel_on_kill,
+                ),
+                method_name="execute_complete",
+            )
+        self.log.info("Current state of job %s is %s", job.job_id, job.state)
+        self._handle_job_error(job)
+        return self.job_id
 
     def execute_complete(self, context: Context, event: dict[str, Any]) -> str | None:
         """
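
The two hunks above only flatten the control flow of `BigQueryInsertJobOperator.execute`; a still-running job continues to defer to `BigQueryInsertJobTrigger`. A minimal usage sketch (SQL and location are placeholders):

```python
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

insert_job = BigQueryInsertJobOperator(
    task_id="example_bq_insert_job",
    configuration={
        "query": {
            "query": "SELECT 1",
            "useLegacySql": False,
        }
    },
    location="US",  # placeholder
    deferrable=True,  # a running job defers via BigQueryInsertJobTrigger
)
```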

airflow/providers/google/cloud/operators/cloud_batch.py
@@ -100,19 +100,18 @@ class CloudBatchSubmitJobOperator(GoogleCloudBaseOperator):
 
             return Job.to_dict(completed_job)
 
-        else:
-            self.defer(
-                trigger=CloudBatchJobFinishedTrigger(
-                    job_name=job.name,
-                    project_id=self.project_id,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    location=self.region,
-                    polling_period_seconds=self.polling_period_seconds,
-                    timeout=self.timeout_seconds,
-                ),
-                method_name="execute_complete",
-            )
+        self.defer(
+            trigger=CloudBatchJobFinishedTrigger(
+                job_name=job.name,
+                project_id=self.project_id,
+                gcp_conn_id=self.gcp_conn_id,
+                impersonation_chain=self.impersonation_chain,
+                location=self.region,
+                polling_period_seconds=self.polling_period_seconds,
+                timeout=self.timeout_seconds,
+            ),
+            method_name="execute_complete",
+        )
 
     def execute_complete(self, context: Context, event: dict):
         job_status = event["status"]

@@ -120,8 +119,7 @@ class CloudBatchSubmitJobOperator(GoogleCloudBaseOperator):
             hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
             job = hook.get_job(job_name=event["job_name"])
             return Job.to_dict(job)
-        else:
-            raise AirflowException(f"Unexpected error in the operation: {event['message']}")
+        raise AirflowException(f"Unexpected error in the operation: {event['message']}")
 
 
 class CloudBatchDeleteJobOperator(GoogleCloudBaseOperator):

airflow/providers/google/cloud/operators/cloud_build.py
@@ -275,8 +275,7 @@ class CloudBuildCreateBuildOperator(GoogleCloudBaseOperator):
                 build_id=event["id_"],
             )
             return event["instance"]
-        else:
-            raise AirflowException(f"Unexpected error in the operation: {event['message']}")
+        raise AirflowException(f"Unexpected error in the operation: {event['message']}")
 
 
 class CloudBuildCreateBuildTriggerOperator(GoogleCloudBaseOperator):