apache-airflow-providers-google 17.2.0__py3-none-any.whl → 18.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of apache-airflow-providers-google might be problematic.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +6 -0
- airflow/providers/google/cloud/hooks/cloud_composer.py +79 -13
- airflow/providers/google/cloud/hooks/cloud_run.py +16 -8
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -173
- airflow/providers/google/cloud/log/gcs_task_handler.py +8 -2
- airflow/providers/google/cloud/operators/cloud_composer.py +84 -1
- airflow/providers/google/cloud/sensors/cloud_composer.py +1 -1
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -66
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +18 -9
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +95 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +11 -0
- airflow/providers/google/cloud/triggers/cloud_composer.py +21 -15
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/marketing_platform/hooks/display_video.py +0 -150
- airflow/providers/google/marketing_platform/operators/display_video.py +0 -510
- airflow/providers/google/marketing_platform/sensors/display_video.py +1 -68
- {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0.dist-info}/METADATA +35 -8
- {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0.dist-info}/RECORD +22 -22
- {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/__init__.py

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "17.2.0"
+__version__ = "18.0.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"
airflow/providers/google/cloud/hooks/bigquery.py

@@ -1702,6 +1702,7 @@ class BigQueryCursor(BigQueryBaseCursor):
         schema_update_options: Iterable | None = None,
         priority: str | None = None,
         time_partitioning: dict | None = None,
+        range_partitioning: dict | None = None,
         api_resource_configs: dict | None = None,
         cluster_fields: list[str] | None = None,
         encryption_configuration: dict | None = None,
@@ -1714,6 +1715,10 @@ class BigQueryCursor(BigQueryBaseCursor):
 
         if time_partitioning is None:
             time_partitioning = {}
+        if range_partitioning is None:
+            range_partitioning = {}
+        if time_partitioning and range_partitioning:
+            raise ValueError("Only one of time_partitioning or range_partitioning can be set.")
 
         if not api_resource_configs:
             api_resource_configs = self.hook.api_resource_configs
@@ -1766,6 +1771,7 @@ class BigQueryCursor(BigQueryBaseCursor):
             (maximum_billing_tier, "maximumBillingTier", None, int),
             (maximum_bytes_billed, "maximumBytesBilled", None, float),
             (time_partitioning, "timePartitioning", {}, dict),
+            (range_partitioning, "rangePartitioning", {}, dict),
             (schema_update_options, "schemaUpdateOptions", None, list),
             (destination_dataset_table, "destinationTable", None, dict),
             (cluster_fields, "clustering", None, dict),
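The new `range_partitioning` argument follows the same path as the existing `time_partitioning` one: it defaults to an empty dict, is validated as mutually exclusive with time-based partitioning, and is copied into the job configuration under BigQuery's `rangePartitioning` key. A minimal sketch of the two dictionary shapes involved (project, table, and column names are hypothetical):

```python
# Integer-range partitioning spec, as accepted under the "rangePartitioning"
# key of a BigQuery query job configuration (hypothetical column name).
range_partitioning = {
    "field": "customer_id",
    "range": {"start": 0, "end": 100000, "interval": 1000},
}

# Time-based partitioning spec, the pre-existing alternative (hypothetical column name).
time_partitioning = {"type": "DAY", "field": "created_at"}

# Passing both to the method above now raises:
# ValueError("Only one of time_partitioning or range_partitioning can be set.")
```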
airflow/providers/google/cloud/hooks/cloud_composer.py

@@ -18,12 +18,15 @@
 from __future__ import annotations
 
 import asyncio
+import json
 import time
 from collections.abc import MutableSequence, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+from urllib.parse import urljoin
 
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.auth.transport.requests import AuthorizedSession
 from google.cloud.orchestration.airflow.service_v1 import (
     EnvironmentsAsyncClient,
     EnvironmentsClient,
@@ -33,7 +36,7 @@ from google.cloud.orchestration.airflow.service_v1 import (
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
@@ -76,6 +79,34 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
             client_options=self.client_options,
         )
 
+    def make_composer_airflow_api_request(
+        self,
+        method: str,
+        airflow_uri: str,
+        path: str,
+        data: Any | None = None,
+        timeout: float | None = None,
+    ):
+        """
+        Make a request to Cloud Composer environment's web server.
+
+        :param method: The request method to use ('GET', 'OPTIONS', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE').
+        :param airflow_uri: The URI of the Apache Airflow Web UI hosted within this environment.
+        :param path: The path to send the request.
+        :param data: Dictionary, list of tuples, bytes, or file-like object to send in the body of the request.
+        :param timeout: The timeout for this request.
+        """
+        authed_session = AuthorizedSession(self.get_credentials())
+
+        resp = authed_session.request(
+            method=method,
+            url=urljoin(airflow_uri, path),
+            data=data,
+            headers={"Content-Type": "application/json"},
+            timeout=timeout,
+        )
+        return resp
+
     def get_operation(self, operation_name):
         return self.get_environment_client().transport.operations_client.get_operation(name=operation_name)
 
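`make_composer_airflow_api_request` authorizes a plain HTTP call with the hook's own Google credentials and resolves the target URL with `urljoin`, so any endpoint of the environment's stable Airflow REST API can be reached through it. A hedged sketch of a read-only call (the environment URI and DAG ID are hypothetical; in practice the URI comes from `Environment.config.airflow_uri`):

```python
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook

hook = CloudComposerHook(gcp_conn_id="google_cloud_default")

# List the five most recent runs of a DAG hosted in the Composer environment.
response = hook.make_composer_airflow_api_request(
    method="GET",
    airflow_uri="https://example-dot-us-central1.composer.googleusercontent.com",  # hypothetical
    path="/api/v1/dags/example_dag/dagRuns?limit=5",
    timeout=60,
)
response.raise_for_status()
print(response.json()["total_entries"])
```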
@@ -408,16 +439,52 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
             self.log.info("Waiting for result...")
             time.sleep(poll_interval)
 
+    def trigger_dag_run(
+        self,
+        composer_airflow_uri: str,
+        composer_dag_id: str,
+        composer_dag_conf: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        """
+        Trigger DAG run for provided Apache Airflow Web UI hosted within Composer environment.
+
+        :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within Composer environment.
+        :param composer_dag_id: The ID of DAG which will be triggered.
+        :param composer_dag_conf: Configuration parameters for the DAG run.
+        :param timeout: The timeout for this request.
+        """
+        response = self.make_composer_airflow_api_request(
+            method="POST",
+            airflow_uri=composer_airflow_uri,
+            path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
+            data=json.dumps(
+                {
+                    "conf": composer_dag_conf or {},
+                }
+            ),
+            timeout=timeout,
+        )
 
-class CloudComposerAsyncHook(GoogleBaseHook):
+        if response.status_code != 200:
+            self.log.error(response.text)
+            response.raise_for_status()
+
+        return response.json()
+
+
+class CloudComposerAsyncHook(GoogleBaseAsyncHook):
     """Hook for Google Cloud Composer async APIs."""
 
+    sync_hook_class = CloudComposerHook
+
     client_options = ClientOptions(api_endpoint="composer.googleapis.com:443")
 
-    def get_environment_client(self) -> EnvironmentsAsyncClient:
+    async def get_environment_client(self) -> EnvironmentsAsyncClient:
         """Retrieve client library object that allow access Environments service."""
+        sync_hook = await self.get_sync_hook()
         return EnvironmentsAsyncClient(
-            credentials=self.get_credentials(),
+            credentials=sync_hook.get_credentials(),
             client_info=CLIENT_INFO,
             client_options=self.client_options,
         )
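`trigger_dag_run` wraps the helper above: it POSTs to `/api/v1/dags/{dag_id}/dagRuns` with the conf serialized as JSON and returns the decoded DAG-run object. A short sketch (URI, DAG ID, and conf are hypothetical):

```python
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook

hook = CloudComposerHook(gcp_conn_id="google_cloud_default")
dag_run = hook.trigger_dag_run(
    composer_airflow_uri="https://example-dot-us-central1.composer.googleusercontent.com",  # hypothetical
    composer_dag_id="target_dag",
    composer_dag_conf={"run_date": "2025-01-01"},
    timeout=60,
)
print(dag_run["dag_run_id"])
```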
@@ -429,9 +496,8 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         return f"projects/{project_id}/locations/{region}"
 
     async def get_operation(self, operation_name):
-        return await self.get_environment_client().transport.operations_client.get_operation(
-            name=operation_name
-        )
+        client = await self.get_environment_client()
+        return await client.transport.operations_client.get_operation(name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
     async def create_environment(
@@ -454,7 +520,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
         return await client.create_environment(
             request={"parent": self.get_parent(project_id, region), "environment": environment},
             retry=retry,
@@ -482,7 +548,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
         name = self.get_environment_name(project_id, region, environment_id)
         return await client.delete_environment(
             request={"name": name}, retry=retry, timeout=timeout, metadata=metadata
@@ -518,7 +584,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
         name = self.get_environment_name(project_id, region, environment_id)
 
         return await client.update_environment(
@@ -556,7 +622,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
 
         return await client.execute_airflow_command(
             request={
@@ -598,7 +664,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
 
         return await client.poll_airflow_command(
             request={
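With `CloudComposerAsyncHook` now built on `GoogleBaseAsyncHook`, the client getter is a coroutine and every method awaits it before issuing the RPC. A minimal sketch of driving the refactored hook from plain asyncio (connection ID and operation name are hypothetical):

```python
import asyncio

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook


async def check_operation(operation_name: str) -> bool:
    # The hook resolves credentials through its sync counterpart (sync_hook_class).
    hook = CloudComposerAsyncHook(gcp_conn_id="google_cloud_default")
    operation = await hook.get_operation(operation_name)
    return operation.done


done = asyncio.run(
    check_operation("projects/my-project/locations/us-central1/operations/operation-123")  # hypothetical
)
print(done)
```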
airflow/providers/google/cloud/hooks/cloud_run.py

@@ -42,7 +42,11 @@ from google.longrunning import operations_pb2
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)
 
 if TYPE_CHECKING:
     from google.api_core import operation
@@ -159,7 +163,7 @@ class CloudRunHook(GoogleBaseHook):
         return list(itertools.islice(jobs, limit))
 
 
-class CloudRunAsyncHook(GoogleBaseHook):
+class CloudRunAsyncHook(GoogleBaseAsyncHook):
     """
     Async hook for the Google Cloud Run service.
 
@@ -174,6 +178,8 @@ class CloudRunAsyncHook(GoogleBaseHook):
         account from the list granting this role to the originating account.
     """
 
+    sync_hook_class = CloudRunHook
+
     def __init__(
         self,
         gcp_conn_id: str = "google_cloud_default",
@@ -183,16 +189,16 @@ class CloudRunAsyncHook(GoogleBaseHook):
         self._client: JobsAsyncClient | None = None
         super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
 
-    def get_conn(self):
+    async def get_conn(self):
         if self._client is None:
-            self._client = JobsAsyncClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+            sync_hook = await self.get_sync_hook()
+            self._client = JobsAsyncClient(credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO)
 
         return self._client
 
     async def get_operation(self, operation_name: str) -> operations_pb2.Operation:
-        return await self.get_conn().get_operation(
-            operations_pb2.GetOperationRequest(name=operation_name), timeout=120
-        )
+        conn = await self.get_conn()
+        return await conn.get_operation(operations_pb2.GetOperationRequest(name=operation_name), timeout=120)
 
 
 class CloudRunServiceHook(GoogleBaseHook):
@@ -258,7 +264,7 @@ class CloudRunServiceHook(GoogleBaseHook):
         return operation.result()
 
 
-class CloudRunServiceAsyncHook(GoogleBaseHook):
+class CloudRunServiceAsyncHook(GoogleBaseAsyncHook):
     """
     Async hook for the Google Cloud Run services.
 
@@ -273,6 +279,8 @@ class CloudRunServiceAsyncHook(GoogleBaseHook):
         account from the list granting this role to the originating account.
     """
 
+    sync_hook_class = CloudRunServiceHook
+
     def __init__(
         self,
        gcp_conn_id: str = "google_cloud_default",
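Both Cloud Run async hooks (like the Composer one above) now follow the same `GoogleBaseAsyncHook` pattern: declare `sync_hook_class`, then await `get_sync_hook()` and reuse its credential resolution instead of calling `self.get_credentials()` directly. A minimal sketch of that pattern for a hypothetical custom hook:

```python
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook


class MyServiceHook(GoogleBaseHook):
    """Hypothetical sync hook; it owns connection and credential handling."""


class MyServiceAsyncHook(GoogleBaseAsyncHook):
    """Hypothetical async wrapper following the pattern used in this release."""

    sync_hook_class = MyServiceHook

    async def get_credentials_for_client(self):
        # get_sync_hook() lazily instantiates and caches the sync hook; its
        # credentials are then handed to whichever async client this hook wraps.
        sync_hook = await self.get_sync_hook()
        return sync_hook.get_credentials()
```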
airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py

@@ -38,7 +38,6 @@ from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.deprecated import deprecated
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
 
@@ -951,178 +950,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
         )
         return model, training_id
 
-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="September 15, 2025",
-        category=AirflowProviderDeprecationWarning,
-        reason="Deprecation of AutoMLText API",
-    )
-    def create_auto_ml_text_training_job(
-        self,
-        project_id: str,
-        region: str,
-        display_name: str,
-        dataset: datasets.TextDataset,
-        prediction_type: str,
-        multi_label: bool = False,
-        sentiment_max: int = 10,
-        labels: dict[str, str] | None = None,
-        training_encryption_spec_key_name: str | None = None,
-        model_encryption_spec_key_name: str | None = None,
-        training_fraction_split: float | None = None,
-        validation_fraction_split: float | None = None,
-        test_fraction_split: float | None = None,
-        training_filter_split: str | None = None,
-        validation_filter_split: str | None = None,
-        test_filter_split: str | None = None,
-        model_display_name: str | None = None,
-        model_labels: dict[str, str] | None = None,
-        sync: bool = True,
-        parent_model: str | None = None,
-        is_default_version: bool | None = None,
-        model_version_aliases: list[str] | None = None,
-        model_version_description: str | None = None,
-    ) -> tuple[models.Model | None, str]:
-        """
-        Create an AutoML Text Training Job.
-
-        WARNING: Text creation API is deprecated since September 15, 2024
-        (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
-
-        :param project_id: Required. Project to run training in.
-        :param region: Required. Location to run training in.
-        :param display_name: Required. The user-defined name of this TrainingPipeline.
-        :param dataset: Required. The dataset within the same Project from which data will be used to train
-            the Model. The Dataset must use schema compatible with Model being trained, and what is
-            compatible should be described in the used TrainingPipeline's [training_task_definition]
-            [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
-        :param prediction_type: The type of prediction the Model is to produce, one of:
-            "classification" - A classification model analyzes text data and returns a list of categories
-            that apply to the text found in the data. Vertex AI offers both single-label and multi-label text
-            classification models.
-            "extraction" - An entity extraction model inspects text data for known entities referenced in the
-            data and labels those entities in the text.
-            "sentiment" - A sentiment analysis model inspects text data and identifies the prevailing
-            emotional opinion within it, especially to determine a writer's attitude as positive, negative,
-            or neutral.
-        :param parent_model: Optional. The resource name or model ID of an existing model.
-            The new model uploaded by this job will be a version of `parent_model`.
-            Only set this field when training a new version of an existing model.
-        :param is_default_version: Optional. When set to True, the newly uploaded model version will
-            automatically have alias "default" included. Subsequent uses of
-            the model produced by this job without a version specified will
-            use this "default" version.
-            When set to False, the "default" alias will not be moved.
-            Actions targeting the model version produced by this job will need
-            to specifically reference this version by ID or alias.
-            New model uploads, i.e. version 1, will always be "default" aliased.
-        :param model_version_aliases: Optional. User provided version aliases so that the model version
-            uploaded by this job can be referenced via alias instead of
-            auto-generated version ID. A default version alias will be created
-            for the first version of the model.
-            The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
-        :param model_version_description: Optional. The description of the model version
-            being uploaded by this job.
-        :param multi_label: Required and only applicable for text classification task. If false, a
-            single-label (multi-class) Model will be trained (i.e. assuming that for each text snippet just
-            up to one annotation may be applicable). If true, a multi-label Model will be trained (i.e.
-            assuming that for each text snippet multiple annotations may be applicable).
-        :param sentiment_max: Required and only applicable for sentiment task. A sentiment is expressed as an
-            integer ordinal, where higher value means a more positive sentiment. The range of sentiments that
-            will be used is between 0 and sentimentMax (inclusive on both ends), and all the values in the
-            range must be represented in the dataset before a model can be created. Only the Annotations with
-            this sentimentMax will be used for training. sentimentMax value must be between 1 and 10
-            (inclusive).
-        :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
-            keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
-            lowercase letters, numeric characters, underscores and dashes. International characters are
-            allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
-        :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-            managed encryption key used to protect the training pipeline. Has the form:
-            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-            The key needs to be in the same region as where the compute resource is created.
-            If set, this TrainingPipeline will be secured by this key.
-            Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
-            is not set separately.
-        :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-            managed encryption key used to protect the model. Has the form:
-            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-            The key needs to be in the same region as where the compute resource is created.
-            If set, the trained Model will be secured by this key.
-        :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
-            the Model. This is ignored if Dataset is not provided.
-        :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
-            validate the Model. This is ignored if Dataset is not provided.
-        :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
-            the Model. This is ignored if Dataset is not provided.
-        :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-            this filter are used to train the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-            this filter are used to validate the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
-            filter are used to test the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
-            up to 128 characters long and can consist of any UTF-8 characters.
-            If not provided upon creation, the job's display_name is used.
-        :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
-            keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
-            lowercase letters, numeric characters, underscores and dashes. International characters are
-            allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
-        :param sync: Whether to execute this method synchronously. If False, this method will be executed in
-            concurrent Future and any downstream object will be immediately returned and synced when the
-            Future has completed.
-        """
-        self._job = AutoMLTextTrainingJob(
-            display_name=display_name,
-            prediction_type=prediction_type,
-            multi_label=multi_label,
-            sentiment_max=sentiment_max,
-            project=project_id,
-            location=region,
-            credentials=self.get_credentials(),
-            labels=labels,
-            training_encryption_spec_key_name=training_encryption_spec_key_name,
-            model_encryption_spec_key_name=model_encryption_spec_key_name,
-        )
-
-        if not self._job:
-            raise AirflowException("AutoMLTextTrainingJob was not created")
-
-        model = self._job.run(
-            dataset=dataset,
-            training_fraction_split=training_fraction_split,
-            validation_fraction_split=validation_fraction_split,
-            test_fraction_split=test_fraction_split,
-            training_filter_split=training_filter_split,
-            validation_filter_split=validation_filter_split,
-            test_filter_split=test_filter_split,
-            model_display_name=model_display_name,
-            model_labels=model_labels,
-            sync=sync,
-            parent_model=parent_model,
-            is_default_version=is_default_version,
-            model_version_aliases=model_version_aliases,
-            model_version_description=model_version_description,
-        )
-        training_id = self.extract_training_id(self._job.resource_name)
-        if model:
-            model.wait()
-        else:
-            self.log.warning(
-                "Training did not produce a Managed Model returning None. AutoML Text Training "
-                "Pipeline is not configured to upload a Model."
-            )
-        return model, training_id
-
     @GoogleBaseHook.fallback_to_default_project_id
     def create_auto_ml_video_training_job(
         self,
airflow/providers/google/cloud/log/gcs_task_handler.py

@@ -213,9 +213,15 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         gcp_keyfile_dict: dict | None = None,
         gcp_scopes: Collection[str] | None = _DEFAULT_SCOPESS,
         project_id: str = PROVIDE_PROJECT_ID,
+        max_bytes: int = 0,
+        backup_count: int = 0,
+        delay: bool = False,
         **kwargs,
-    ):
-        super().__init__(base_log_folder)
+    ) -> None:
+        # support log file size handling of FileTaskHandler
+        super().__init__(
+            base_log_folder=base_log_folder, max_bytes=max_bytes, backup_count=backup_count, delay=delay
+        )
         self.handler: logging.FileHandler | None = None
         self.log_relative_path = ""
         self.closed = False
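Because the constructor now forwards `max_bytes`, `backup_count`, and `delay` to `FileTaskHandler`, the local log file written before upload to GCS can be size-rotated. A hedged instantiation sketch (folder names and sizes are hypothetical, and `gcs_log_folder` is assumed to be the handler's existing remote-folder keyword; in a real deployment these values come from the task-logging configuration):

```python
from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler

handler = GCSTaskHandler(
    base_log_folder="/opt/airflow/logs",             # hypothetical local folder
    gcs_log_folder="gs://example-bucket/task-logs",  # hypothetical remote folder
    max_bytes=10 * 1024 * 1024,  # rotate the local file at ~10 MiB
    backup_count=3,              # keep up to three rotated files
    delay=True,                  # open the local file lazily
)
```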
airflow/providers/google/cloud/operators/cloud_composer.py

@@ -21,7 +21,7 @@ import shlex
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
-from google.api_core.exceptions import AlreadyExists
+from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.orchestration.airflow.service_v1 import ImageVersion
 from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse
@@ -798,3 +798,86 @@ class CloudComposerRunAirflowCLICommandOperator(GoogleCloudBaseOperator):
         """Merge output to one string."""
         result_str = "\n".join(line_dict["content"] for line_dict in result["output"])
         return result_str
+
+
+class CloudComposerTriggerDAGRunOperator(GoogleCloudBaseOperator):
+    """
+    Trigger DAG run for provided Composer environment.
+
+    :param project_id: The ID of the Google Cloud project that the service belongs to.
+    :param region: The ID of the Google Cloud region that the service belongs to.
+    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
+    :param composer_dag_id: The ID of DAG which will be triggered.
+    :param composer_dag_conf: Configuration parameters for the DAG run.
+    :param timeout: The timeout for this request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "composer_dag_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        composer_dag_conf: dict | None = None,
+        timeout: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.composer_dag_conf = composer_dag_conf or {}
+        self.timeout = timeout
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        hook = CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        try:
+            environment = hook.get_environment(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+                timeout=self.timeout,
+            )
+        except NotFound as not_found_err:
+            self.log.info("The Composer environment %s does not exist.", self.environment_id)
+            raise AirflowException(not_found_err)
+        composer_airflow_uri = environment.config.airflow_uri
+
+        self.log.info(
+            "Triggering the DAG %s on the %s environment...", self.composer_dag_id, self.environment_id
+        )
+        dag_run = hook.trigger_dag_run(
+            composer_airflow_uri=composer_airflow_uri,
+            composer_dag_id=self.composer_dag_id,
+            composer_dag_conf=self.composer_dag_conf,
+            timeout=self.timeout,
+        )
+        self.log.info("The DAG %s was triggered with Run ID: %s", self.composer_dag_id, dag_run["dag_run_id"])
+
+        return dag_run
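A usage sketch of the new operator inside a DAG; it looks up the environment, extracts `airflow_uri`, and calls the `CloudComposerHook.trigger_dag_run` method shown earlier (project, region, and IDs are hypothetical):

```python
from airflow import DAG
from airflow.providers.google.cloud.operators.cloud_composer import (
    CloudComposerTriggerDAGRunOperator,
)

with DAG(dag_id="trigger_remote_composer_dag", schedule=None) as dag:
    trigger_remote_dag = CloudComposerTriggerDAGRunOperator(
        task_id="trigger_remote_dag",
        project_id="example-project",
        region="us-central1",
        environment_id="example-environment",
        composer_dag_id="target_dag",
        composer_dag_conf={"source": "orchestrating-environment"},
    )
```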
airflow/providers/google/cloud/sensors/cloud_composer.py

@@ -222,7 +222,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         if self.deferrable:
             start_date, end_date = self._get_logical_dates(context)
             self.defer(
-                timeout=self.timeout,
+                timeout=timedelta(seconds=self.timeout) if self.timeout else None,
                 trigger=CloudComposerDAGRunTrigger(
                     project_id=self.project_id,
                     region=self.region,
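The one-line change fixes the deferral call: `BaseSensorOperator.timeout` is a plain number of seconds, while `defer()` expects a `datetime.timedelta` (or `None`). A tiny illustration of the conversion now applied:

```python
from datetime import timedelta

timeout_seconds = 60 * 60 * 24 * 7  # sensor timeout as configured, in seconds
defer_timeout = timedelta(seconds=timeout_seconds) if timeout_seconds else None
print(defer_timeout)  # 7 days, 0:00:00
```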
airflow/providers/google/cloud/transfers/bigquery_to_mssql.py

@@ -25,13 +25,11 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
 if TYPE_CHECKING:
-    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -112,67 +110,3 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
             project_id=project_id,
             table_id=table_id,
         )
-
-    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
-        from airflow.providers.common.compat.openlineage.facet import Dataset
-        from airflow.providers.google.cloud.openlineage.utils import (
-            BIGQUERY_NAMESPACE,
-            get_facets_from_bq_table_for_given_fields,
-            get_identity_column_lineage_facet,
-        )
-        from airflow.providers.openlineage.extractors import OperatorLineage
-
-        if not self.bigquery_hook:
-            self.bigquery_hook = BigQueryHook(
-                gcp_conn_id=self.gcp_conn_id,
-                location=self.location,
-                impersonation_chain=self.impersonation_chain,
-            )
-
-        try:
-            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
-        except Exception:
-            self.log.debug(
-                "OpenLineage: could not fetch BigQuery table %s",
-                self.source_project_dataset_table,
-                exc_info=True,
-            )
-            return OperatorLineage()
-
-        if self.selected_fields:
-            if isinstance(self.selected_fields, str):
-                bigquery_field_names = list(self.selected_fields)
-            else:
-                bigquery_field_names = self.selected_fields
-        else:
-            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
-
-        input_dataset = Dataset(
-            namespace=BIGQUERY_NAMESPACE,
-            name=self.source_project_dataset_table,
-            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
-        )
-
-        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
-        default_schema = self.mssql_hook.get_openlineage_default_schema()
-        namespace = f"{db_info.scheme}://{db_info.authority}"
-
-        if self.target_table_name and "." in self.target_table_name:
-            schema_name, table_name = self.target_table_name.split(".", 1)
-        else:
-            schema_name = default_schema or ""
-            table_name = self.target_table_name or ""
-
-        if self.database:
-            output_name = f"{self.database}.{schema_name}.{table_name}"
-        else:
-            output_name = f"{schema_name}.{table_name}"
-
-        column_lineage_facet = get_identity_column_lineage_facet(
-            bigquery_field_names, input_datasets=[input_dataset]
-        )
-
-        output_facets = column_lineage_facet or {}
-        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
-
-        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])