apache-airflow-providers-google 14.0.0__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
- airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/_vendor/__init__.py +0 -0
- airflow/providers/google/_vendor/json_merge_patch.py +91 -0
- airflow/providers/google/ads/hooks/ads.py +52 -43
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
- airflow/providers/google/cloud/hooks/bigquery.py +195 -318
- airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
- airflow/providers/google/cloud/hooks/bigtable.py +3 -2
- airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
- airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
- airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
- airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
- airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
- airflow/providers/google/cloud/hooks/compute.py +7 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
- airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
- airflow/providers/google/cloud/hooks/dataflow.py +87 -242
- airflow/providers/google/cloud/hooks/dataform.py +9 -14
- airflow/providers/google/cloud/hooks/datafusion.py +7 -9
- airflow/providers/google/cloud/hooks/dataplex.py +13 -12
- airflow/providers/google/cloud/hooks/dataprep.py +2 -2
- airflow/providers/google/cloud/hooks/dataproc.py +76 -74
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/hooks/dlp.py +5 -4
- airflow/providers/google/cloud/hooks/gcs.py +144 -33
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kms.py +3 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
- airflow/providers/google/cloud/hooks/mlengine.py +7 -8
- airflow/providers/google/cloud/hooks/natural_language.py +3 -2
- airflow/providers/google/cloud/hooks/os_login.py +3 -2
- airflow/providers/google/cloud/hooks/pubsub.py +6 -6
- airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
- airflow/providers/google/cloud/hooks/spanner.py +75 -10
- airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
- airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
- airflow/providers/google/cloud/hooks/tasks.py +4 -3
- airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
- airflow/providers/google/cloud/hooks/translate.py +8 -17
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
- airflow/providers/google/cloud/hooks/vision.py +7 -7
- airflow/providers/google/cloud/hooks/workflows.py +4 -3
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -7
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -90
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -89
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
- airflow/providers/google/cloud/links/managed_kafka.py +11 -51
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +141 -40
- airflow/providers/google/cloud/openlineage/mixins.py +14 -13
- airflow/providers/google/cloud/openlineage/utils.py +19 -3
- airflow/providers/google/cloud/operators/alloy_db.py +76 -61
- airflow/providers/google/cloud/operators/bigquery.py +104 -667
- airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
- airflow/providers/google/cloud/operators/bigtable.py +38 -7
- airflow/providers/google/cloud/operators/cloud_base.py +22 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
- airflow/providers/google/cloud/operators/cloud_build.py +80 -36
- airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
- airflow/providers/google/cloud/operators/cloud_run.py +39 -20
- airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
- airflow/providers/google/cloud/operators/compute.py +18 -50
- airflow/providers/google/cloud/operators/datacatalog.py +167 -29
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +19 -7
- airflow/providers/google/cloud/operators/datafusion.py +43 -43
- airflow/providers/google/cloud/operators/dataplex.py +212 -126
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +134 -207
- airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +24 -45
- airflow/providers/google/cloud/operators/functions.py +21 -14
- airflow/providers/google/cloud/operators/gcs.py +15 -12
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
- airflow/providers/google/cloud/operators/natural_language.py +5 -3
- airflow/providers/google/cloud/operators/pubsub.py +69 -21
- airflow/providers/google/cloud/operators/spanner.py +53 -45
- airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
- airflow/providers/google/cloud/operators/stackdriver.py +5 -11
- airflow/providers/google/cloud/operators/tasks.py +6 -15
- airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
- airflow/providers/google/cloud/operators/translate.py +46 -20
- airflow/providers/google/cloud/operators/translate_speech.py +4 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
- airflow/providers/google/cloud/operators/vision.py +7 -5
- airflow/providers/google/cloud/operators/workflows.py +24 -19
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
- airflow/providers/google/cloud/sensors/bigtable.py +14 -6
- airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
- airflow/providers/google/cloud/sensors/dataflow.py +27 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +7 -5
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +10 -9
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/sensors/gcs.py +22 -21
- airflow/providers/google/cloud/sensors/looker.py +5 -5
- airflow/providers/google/cloud/sensors/pubsub.py +20 -20
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +6 -4
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
- airflow/providers/google/cloud/triggers/dataflow.py +125 -2
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +16 -3
- airflow/providers/google/cloud/triggers/dataproc.py +124 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +17 -20
- airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
- airflow/providers/google/cloud/utils/bigquery.py +5 -7
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/cloud/utils/validators.py +43 -0
- airflow/providers/google/common/auth_backend/google_openid.py +26 -9
- airflow/providers/google/common/consts.py +2 -1
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +40 -43
- airflow/providers/google/common/hooks/operation_helpers.py +78 -0
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +4 -5
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +61 -216
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/drive.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
- airflow/providers/google/cloud/hooks/automl.py +0 -679
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1360
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -1515
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
- apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
- /airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
airflow/providers/google/cloud/hooks/spanner.py

@@ -19,17 +19,19 @@
 
 from __future__ import annotations
 
-from collections
-from
+from collections import OrderedDict
+from collections.abc import Callable, Sequence
+from typing import TYPE_CHECKING, NamedTuple
 
+from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
+from google.cloud.spanner_v1.client import Client
 from sqlalchemy import create_engine
 
 from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
-from
-from google.cloud.spanner_v1.client import Client
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 if TYPE_CHECKING:
     from google.cloud.spanner_v1.database import Database
@@ -37,6 +39,8 @@ if TYPE_CHECKING:
     from google.cloud.spanner_v1.transaction import Transaction
     from google.longrunning.operations_grpc_pb2 import Operation
 
+    from airflow.models.connection import Connection
+
 
 class SpannerConnectionParams(NamedTuple):
     """Information about Google Spanner connection parameters."""
@@ -388,7 +392,7 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
         database_id: str,
         queries: list[str],
         project_id: str,
-    ) ->
+    ) -> list[int]:
         """
         Execute an arbitrary DML query (INSERT, UPDATE, DELETE).
 
@@ -398,12 +402,73 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
         :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
             database. If set to None or missing, the default project_id from the Google Cloud connection
             is used.
+        :return: list of numbers of affected rows by DML query
         """
-
-
-
+        db = (
+            self._get_client(project_id=project_id)
+            .instance(instance_id=instance_id)
+            .database(database_id=database_id)
+        )
+
+        def _tx_runner(tx: Transaction) -> dict[str, int]:
+            return self._execute_sql_in_transaction(tx, queries)
+
+        result = db.run_in_transaction(_tx_runner)
+
+        result_rows_count_per_query = []
+        for i, (sql, rc) in enumerate(result.items(), start=1):
+            if not sql.startswith("SELECT"):
+                preview = sql if len(sql) <= 300 else sql[:300] + "…"
+                self.log.info("[DML %d/%d] affected rows=%d | %s", i, len(result), rc, preview)
+            result_rows_count_per_query.append(rc)
+        return result_rows_count_per_query
 
     @staticmethod
-    def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]):
+    def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) -> dict[str, int]:
+        counts: OrderedDict[str, int] = OrderedDict()
         for sql in queries:
-            transaction.execute_update(sql)
+            rc = transaction.execute_update(sql)
+            counts[sql] = rc
+        return counts
+
+    def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
+        """Build Spanner-specific authority part for OpenLineage. Returns {project}/{instance}."""
+        extras = connection.extra_dejson
+        project_id = extras.get("project_id")
+        instance_id = extras.get("instance_id")
+
+        if not project_id or not instance_id:
+            return None
+
+        return f"{project_id}/{instance_id}"
+
+    def get_openlineage_database_dialect(self, connection: Connection) -> str:
+        """Return database dialect for OpenLineage."""
+        return "spanner"
+
+    def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
+        """Return Spanner specific information for OpenLineage."""
+        extras = connection.extra_dejson
+        database_id = extras.get("database_id")
+
+        return DatabaseInfo(
+            scheme=self.get_openlineage_database_dialect(connection),
+            authority=self._get_openlineage_authority_part(connection),
+            database=database_id,
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "spanner_type",
+            ],
+        )
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """
+        Spanner expose 'public' or '' schema depending on dialect(Postgres vs GoogleSQL).
+
+        SQLAlchemy dialect for Spanner does not expose default schema, so we return None
+        to follow the same approach.
+        """
+        return None
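A minimal usage sketch of the `execute_dml` change above: callers now receive the per-statement row counts reported by `Transaction.execute_update` instead of nothing. The connection id and all Spanner resource names below are placeholders, not values taken from this diff:

    from airflow.providers.google.cloud.hooks.spanner import SpannerHook

    hook = SpannerHook(gcp_conn_id="google_cloud_default")  # assumed connection id
    affected_rows = hook.execute_dml(
        instance_id="my-instance",    # placeholder instance
        database_id="my-database",    # placeholder database
        queries=["UPDATE users SET active = TRUE WHERE id = 42"],
        project_id="my-project",      # placeholder project
    )
    # 14.0.0 returned None here; 19.x returns one count per submitted query.
    for count in affected_rows:
        print(f"{count} row(s) affected")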
airflow/providers/google/cloud/hooks/speech_to_text.py

@@ -22,12 +22,13 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.speech_v1 import SpeechClient
 from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig
 
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
 if TYPE_CHECKING:
     from google.api_core.retry import Retry
 
airflow/providers/google/cloud/hooks/stackdriver.py

@@ -24,15 +24,15 @@ import json
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
-from googleapiclient.errors import HttpError
-
-from airflow.exceptions import AirflowException
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
 from google.api_core.exceptions import InvalidArgument
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud import monitoring_v3
 from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel
 from google.protobuf.field_mask_pb2 import FieldMask
+from googleapiclient.errors import HttpError
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
 
 if TYPE_CHECKING:
     from google.api_core.retry import Retry
@@ -121,10 +121,9 @@ class StackdriverHook(GoogleBaseHook):
         )
         if format_ == "dict":
             return [AlertPolicy.to_dict(policy) for policy in policies_]
-
+        if format_ == "json":
             return [AlertPolicy.to_json(policy) for policy in policies_]
-
-            return policies_
+        return policies_
 
     @GoogleBaseHook.fallback_to_default_project_id
     def _toggle_policy_status(
@@ -262,8 +261,9 @@ class StackdriverHook(GoogleBaseHook):
         channel_name_map = {}
 
         for channel in channels:
+            # This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
            channel.verification_status = (
-                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED  # type: ignore[assignment]
             )
 
             if channel.name in existing_channels:
@@ -275,7 +275,7 @@ class StackdriverHook(GoogleBaseHook):
                 )
             else:
                 old_name = channel.name
-                channel.name
+                del channel.name
                 new_channel = channel_client.create_notification_channel(
                     request={"name": f"projects/{project_id}", "notification_channel": channel},
                     retry=retry,
@@ -285,8 +285,8 @@ class StackdriverHook(GoogleBaseHook):
                 channel_name_map[old_name] = new_channel.name
 
         for policy in policies_:
-            policy.creation_record
-            policy.mutation_record
+            del policy.creation_record
+            del policy.mutation_record
 
             for i, channel in enumerate(policy.notification_channels):
                 new_channel = channel_name_map.get(channel)
@@ -302,9 +302,9 @@ class StackdriverHook(GoogleBaseHook):
                     metadata=metadata,
                 )
             else:
-                policy.name
+                del policy.name
                 for condition in policy.conditions:
-                    condition.name
+                    del condition.name
                 policy_client.create_alert_policy(
                     request={"name": f"projects/{project_id}", "alert_policy": policy},
                     retry=retry,
@@ -395,10 +395,9 @@ class StackdriverHook(GoogleBaseHook):
         )
         if format_ == "dict":
             return [NotificationChannel.to_dict(channel) for channel in channels]
-
+        if format_ == "json":
             return [NotificationChannel.to_json(channel) for channel in channels]
-
-            return channels
+        return channels
 
     @GoogleBaseHook.fallback_to_default_project_id
     def _toggle_channel_status(
@@ -533,8 +532,9 @@ class StackdriverHook(GoogleBaseHook):
             channels_list.append(NotificationChannel(**channel))
 
         for channel in channels_list:
+            # This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
             channel.verification_status = (
-                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED  # type: ignore[assignment]
             )
 
             if channel.name in existing_channels:
@@ -546,7 +546,7 @@ class StackdriverHook(GoogleBaseHook):
                 )
             else:
                 old_name = channel.name
-                channel.name
+                del channel.name
                 new_channel = channel_client.create_notification_channel(
                     request={"name": f"projects/{project_id}", "notification_channel": channel},
                     retry=retry,
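The StackdriverHook hunks above swap bare attribute access for `del`. On proto-plus messages, a bare statement like `channel.name` only reads the field and discards the value, so the 14.0.0 code never actually cleared the output-only fields before re-creating resources under a new project; `del` invokes the underlying `ClearField`. A small self-contained sketch (resource name is a placeholder):

    from google.cloud import monitoring_v3

    channel = monitoring_v3.NotificationChannel(
        name="projects/old-project/notificationChannels/123",  # placeholder resource name
    )
    channel.name       # no-op: reads the value and discards it (old 14.0.0 behaviour)
    del channel.name   # clears the field so the API can assign a new name on create
    assert channel.name == ""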
airflow/providers/google/cloud/hooks/tasks.py

@@ -22,13 +22,14 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from airflow.exceptions import AirflowException
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.tasks_v2 import CloudTasksClient
 from google.cloud.tasks_v2.types import Queue, Task
 
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
 if TYPE_CHECKING:
     from google.api_core.retry import Retry
     from google.protobuf.field_mask_pb2 import FieldMask
airflow/providers/google/cloud/hooks/text_to_speech.py

@@ -22,8 +22,6 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.texttospeech_v1 import TextToSpeechClient
 from google.cloud.texttospeech_v1.types import (
@@ -33,6 +31,9 @@ from google.cloud.texttospeech_v1.types import (
     VoiceSelectionParams,
 )
 
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
 if TYPE_CHECKING:
     from google.api_core.retry import Retry
 
airflow/providers/google/cloud/hooks/translate.py

@@ -25,9 +25,6 @@ from typing import (
     cast,
 )
 
-from airflow.exceptions import AirflowException
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
 from google.api_core.exceptions import GoogleAPICallError
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.api_core.retry import Retry
@@ -35,9 +32,12 @@ from google.cloud.translate_v2 import Client
 from google.cloud.translate_v3 import TranslationServiceClient
 from google.cloud.translate_v3.types.translation_service import GlossaryInputConfig
 
-
-
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
 
+if TYPE_CHECKING:
     from google.api_core.operation import Operation
     from google.cloud.translate_v3.services.translation_service import pagers
     from google.cloud.translate_v3.types import (
@@ -155,7 +155,7 @@ class CloudTranslateHook(GoogleBaseHook):
         )
 
 
-class TranslateHook(GoogleBaseHook):
+class TranslateHook(GoogleBaseHook, OperationHelper):
     """
     Hook for Google Cloud translation (Advanced) using client version V3.
 
@@ -221,15 +221,6 @@ class TranslateHook(GoogleBaseHook):
             error = operation.exception(timeout=timeout)
             raise AirflowException(error)
 
-    @staticmethod
-    def wait_for_operation_result(operation: Operation, timeout: int | None = None) -> Message:
-        """Wait for long-lasting operation to complete."""
-        try:
-            return operation.result(timeout=timeout)
-        except GoogleAPICallError:
-            error = operation.exception(timeout=timeout)
-            raise AirflowException(error)
-
     @staticmethod
     def extract_object_id(obj: dict) -> str:
         """Return unique id of the object."""
@@ -320,7 +311,7 @@ class TranslateHook(GoogleBaseHook):
             retry=retry,
             metadata=metadata,
         )
-        return cast(dict, type(result).to_dict(result))
+        return cast("dict", type(result).to_dict(result))
 
     def batch_translate_text(
         self,
@@ -438,7 +429,7 @@ class TranslateHook(GoogleBaseHook):
         project_id: str,
         location: str,
         retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | _MethodDefault = DEFAULT,
+        timeout: float | None | _MethodDefault = DEFAULT,
         metadata: Sequence[tuple[str, str]] = (),
     ) -> automl_translation.Dataset:
         """
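TranslateHook here (and AutoMLHook below) drop their per-hook long-running-operation helpers and instead mix in `OperationHelper` from the new airflow/providers/google/common/hooks/operation_helpers.py (+78 lines, not shown in this diff). The shared helper's exact signatures are not visible in this excerpt, so the sketch below only restates the removed staticmethod body as a mixin, with hypothetical class names, to illustrate the consolidation pattern:

    from airflow.exceptions import AirflowException
    from google.api_core.exceptions import GoogleAPICallError


    class OperationHelperSketch:
        """Hypothetical stand-in for the new shared OperationHelper mixin."""

        @staticmethod
        def wait_for_operation_result(operation, timeout=None):
            # Same body as the per-hook staticmethod removed in the hunk above.
            try:
                return operation.result(timeout=timeout)
            except GoogleAPICallError:
                error = operation.exception(timeout=timeout)
                raise AirflowException(error)


    class TranslateHookSketch(OperationHelperSketch):
        # In 19.x the real hooks inherit the helper (e.g. TranslateHook(GoogleBaseHook, OperationHelper)),
        # so the duplicated copies of this method disappear from the individual hook modules.
        pass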
airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py

@@ -23,9 +23,6 @@ import warnings
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
-from airflow.providers.google.common.deprecated import deprecated
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform import (
@@ -39,6 +36,11 @@ from google.cloud.aiplatform import (
 )
 from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
 
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
+
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
     from google.api_core.retry import Retry
@@ -46,7 +48,7 @@ if TYPE_CHECKING:
     from google.cloud.aiplatform_v1.types import TrainingPipeline
 
 
-class AutoMLHook(GoogleBaseHook):
+class AutoMLHook(GoogleBaseHook, OperationHelper):
     """Hook for Google Cloud Vertex AI Auto ML APIs."""
 
     def __init__(
@@ -79,7 +81,7 @@ class AutoMLHook(GoogleBaseHook):
             client_options = ClientOptions()
 
         return PipelineServiceClient(
-            credentials=self.get_credentials(), client_info=
+            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
     def get_job_service_client(
@@ -93,7 +95,7 @@ class AutoMLHook(GoogleBaseHook):
             client_options = ClientOptions()
 
         return JobServiceClient(
-            credentials=self.get_credentials(), client_info=
+            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
     def get_auto_ml_tabular_training_job(
@@ -182,42 +184,6 @@ class AutoMLHook(GoogleBaseHook):
             model_encryption_spec_key_name=model_encryption_spec_key_name,
         )
 
-    @deprecated(
-        planned_removal_date="June 15, 2025",
-        category=AirflowProviderDeprecationWarning,
-        reason="Deprecation of AutoMLText API",
-    )
-    def get_auto_ml_text_training_job(
-        self,
-        display_name: str,
-        prediction_type: str,
-        multi_label: bool = False,
-        sentiment_max: int = 10,
-        project: str | None = None,
-        location: str | None = None,
-        labels: dict[str, str] | None = None,
-        training_encryption_spec_key_name: str | None = None,
-        model_encryption_spec_key_name: str | None = None,
-    ) -> AutoMLTextTrainingJob:
-        """
-        Return AutoMLTextTrainingJob object.
-
-        WARNING: Text creation API is deprecated since September 15, 2024
-        (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
-        """
-        return AutoMLTextTrainingJob(
-            display_name=display_name,
-            prediction_type=prediction_type,
-            multi_label=multi_label,
-            sentiment_max=sentiment_max,
-            project=project,
-            location=location,
-            credentials=self.get_credentials(),
-            labels=labels,
-            training_encryption_spec_key_name=training_encryption_spec_key_name,
-            model_encryption_spec_key_name=model_encryption_spec_key_name,
-        )
-
     def get_auto_ml_video_training_job(
         self,
         display_name: str,
@@ -252,14 +218,6 @@ class AutoMLHook(GoogleBaseHook):
         """Return unique id of the Training pipeline."""
         return resource_name.rpartition("/")[-1]
 
-    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
-        """Wait for long-lasting operation to complete."""
-        try:
-            return operation.result(timeout=timeout)
-        except Exception:
-            error = operation.exception(timeout=timeout)
-            raise AirflowException(error)
-
     def cancel_auto_ml_job(self) -> None:
         """Cancel Auto ML Job for training pipeline."""
         if self._job:
@@ -992,178 +950,6 @@ class AutoMLHook(GoogleBaseHook):
             )
         return model, training_id
 
-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="September 15, 2025",
-        category=AirflowProviderDeprecationWarning,
-        reason="Deprecation of AutoMLText API",
-    )
-    def create_auto_ml_text_training_job(
-        self,
-        project_id: str,
-        region: str,
-        display_name: str,
-        dataset: datasets.TextDataset,
-        prediction_type: str,
-        multi_label: bool = False,
-        sentiment_max: int = 10,
-        labels: dict[str, str] | None = None,
-        training_encryption_spec_key_name: str | None = None,
-        model_encryption_spec_key_name: str | None = None,
-        training_fraction_split: float | None = None,
-        validation_fraction_split: float | None = None,
-        test_fraction_split: float | None = None,
-        training_filter_split: str | None = None,
-        validation_filter_split: str | None = None,
-        test_filter_split: str | None = None,
-        model_display_name: str | None = None,
-        model_labels: dict[str, str] | None = None,
-        sync: bool = True,
-        parent_model: str | None = None,
-        is_default_version: bool | None = None,
-        model_version_aliases: list[str] | None = None,
-        model_version_description: str | None = None,
-    ) -> tuple[models.Model | None, str]:
-        """
-        Create an AutoML Text Training Job.
-
-        WARNING: Text creation API is deprecated since September 15, 2024
-        (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
-
-        :param project_id: Required. Project to run training in.
-        :param region: Required. Location to run training in.
-        :param display_name: Required. The user-defined name of this TrainingPipeline.
-        :param dataset: Required. The dataset within the same Project from which data will be used to train
-            the Model. The Dataset must use schema compatible with Model being trained, and what is
-            compatible should be described in the used TrainingPipeline's [training_task_definition]
-            [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
-        :param prediction_type: The type of prediction the Model is to produce, one of:
-            "classification" - A classification model analyzes text data and returns a list of categories
-            that apply to the text found in the data. Vertex AI offers both single-label and multi-label text
-            classification models.
-            "extraction" - An entity extraction model inspects text data for known entities referenced in the
-            data and labels those entities in the text.
-            "sentiment" - A sentiment analysis model inspects text data and identifies the prevailing
-            emotional opinion within it, especially to determine a writer's attitude as positive, negative,
-            or neutral.
-        :param parent_model: Optional. The resource name or model ID of an existing model.
-            The new model uploaded by this job will be a version of `parent_model`.
-            Only set this field when training a new version of an existing model.
-        :param is_default_version: Optional. When set to True, the newly uploaded model version will
-            automatically have alias "default" included. Subsequent uses of
-            the model produced by this job without a version specified will
-            use this "default" version.
-            When set to False, the "default" alias will not be moved.
-            Actions targeting the model version produced by this job will need
-            to specifically reference this version by ID or alias.
-            New model uploads, i.e. version 1, will always be "default" aliased.
-        :param model_version_aliases: Optional. User provided version aliases so that the model version
-            uploaded by this job can be referenced via alias instead of
-            auto-generated version ID. A default version alias will be created
-            for the first version of the model.
-            The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
-        :param model_version_description: Optional. The description of the model version
-            being uploaded by this job.
-        :param multi_label: Required and only applicable for text classification task. If false, a
-            single-label (multi-class) Model will be trained (i.e. assuming that for each text snippet just
-            up to one annotation may be applicable). If true, a multi-label Model will be trained (i.e.
-            assuming that for each text snippet multiple annotations may be applicable).
-        :param sentiment_max: Required and only applicable for sentiment task. A sentiment is expressed as an
-            integer ordinal, where higher value means a more positive sentiment. The range of sentiments that
-            will be used is between 0 and sentimentMax (inclusive on both ends), and all the values in the
-            range must be represented in the dataset before a model can be created. Only the Annotations with
-            this sentimentMax will be used for training. sentimentMax value must be between 1 and 10
-            (inclusive).
-        :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
-            keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
-            lowercase letters, numeric characters, underscores and dashes. International characters are
-            allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
-        :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-            managed encryption key used to protect the training pipeline. Has the form:
-            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-            The key needs to be in the same region as where the compute resource is created.
-            If set, this TrainingPipeline will be secured by this key.
-            Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
-            is not set separately.
-        :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-            managed encryption key used to protect the model. Has the form:
-            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-            The key needs to be in the same region as where the compute resource is created.
-            If set, the trained Model will be secured by this key.
-        :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
-            the Model. This is ignored if Dataset is not provided.
-        :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
-            validate the Model. This is ignored if Dataset is not provided.
-        :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
-            the Model. This is ignored if Dataset is not provided.
-        :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-            this filter are used to train the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-            this filter are used to validate the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
-            filter are used to test the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
-            up to 128 characters long and can consist of any UTF-8 characters.
-            If not provided upon creation, the job's display_name is used.
-        :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
-            keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
-            lowercase letters, numeric characters, underscores and dashes. International characters are
-            allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
-        :param sync: Whether to execute this method synchronously. If False, this method will be executed in
-            concurrent Future and any downstream object will be immediately returned and synced when the
-            Future has completed.
-        """
-        self._job = AutoMLTextTrainingJob(
-            display_name=display_name,
-            prediction_type=prediction_type,
-            multi_label=multi_label,
-            sentiment_max=sentiment_max,
-            project=project_id,
-            location=region,
-            credentials=self.get_credentials(),
-            labels=labels,
-            training_encryption_spec_key_name=training_encryption_spec_key_name,
-            model_encryption_spec_key_name=model_encryption_spec_key_name,
-        )
-
-        if not self._job:
-            raise AirflowException("AutoMLTextTrainingJob was not created")
-
-        model = self._job.run(
-            dataset=dataset,  # type: ignore[arg-type]
-            training_fraction_split=training_fraction_split,  # type: ignore[call-arg]
-            validation_fraction_split=validation_fraction_split,  # type: ignore[call-arg]
-            test_fraction_split=test_fraction_split,
-            training_filter_split=training_filter_split,
-            validation_filter_split=validation_filter_split,
-            test_filter_split=test_filter_split,  # type: ignore[call-arg]
-            model_display_name=model_display_name,
-            model_labels=model_labels,
-            sync=sync,
-            parent_model=parent_model,
-            is_default_version=is_default_version,
-            model_version_aliases=model_version_aliases,
-            model_version_description=model_version_description,
-        )
-        training_id = self.extract_training_id(self._job.resource_name)
-        if model:
-            model.wait()
-        else:
-            self.log.warning(
-                "Training did not produce a Managed Model returning None. AutoML Text Training "
-                "Pipeline is not configured to upload a Model."
-            )
-        return model, training_id
-
     @GoogleBaseHook.fallback_to_default_project_id
     def create_auto_ml_video_training_job(
         self,