apache-airflow-providers-google 17.2.0__py3-none-any.whl → 18.0.0rc1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (22)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/cloud/hooks/bigquery.py +6 -0
  3. airflow/providers/google/cloud/hooks/cloud_composer.py +79 -13
  4. airflow/providers/google/cloud/hooks/cloud_run.py +16 -8
  5. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -173
  6. airflow/providers/google/cloud/log/gcs_task_handler.py +8 -2
  7. airflow/providers/google/cloud/operators/cloud_composer.py +84 -1
  8. airflow/providers/google/cloud/sensors/cloud_composer.py +1 -1
  9. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -66
  10. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
  11. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +18 -9
  12. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +95 -0
  13. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +11 -0
  14. airflow/providers/google/cloud/triggers/cloud_composer.py +21 -15
  15. airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
  16. airflow/providers/google/marketing_platform/hooks/display_video.py +0 -150
  17. airflow/providers/google/marketing_platform/operators/display_video.py +0 -510
  18. airflow/providers/google/marketing_platform/sensors/display_video.py +1 -68
  19. {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0rc1.dist-info}/METADATA +43 -16
  20. {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0rc1.dist-info}/RECORD +22 -22
  21. {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0rc1.dist-info}/WHEEL +0 -0
  22. {apache_airflow_providers_google-17.2.0.dist-info → apache_airflow_providers_google-18.0.0rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/google/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "17.2.0"
+__version__ = "18.0.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"

airflow/providers/google/cloud/hooks/bigquery.py
@@ -1702,6 +1702,7 @@ class BigQueryCursor(BigQueryBaseCursor):
         schema_update_options: Iterable | None = None,
         priority: str | None = None,
         time_partitioning: dict | None = None,
+        range_partitioning: dict | None = None,
         api_resource_configs: dict | None = None,
         cluster_fields: list[str] | None = None,
         encryption_configuration: dict | None = None,
@@ -1714,6 +1715,10 @@ class BigQueryCursor(BigQueryBaseCursor):
 
         if time_partitioning is None:
             time_partitioning = {}
+        if range_partitioning is None:
+            range_partitioning = {}
+        if time_partitioning and range_partitioning:
+            raise ValueError("Only one of time_partitioning or range_partitioning can be set.")
 
         if not api_resource_configs:
             api_resource_configs = self.hook.api_resource_configs
@@ -1766,6 +1771,7 @@ class BigQueryCursor(BigQueryBaseCursor):
             (maximum_billing_tier, "maximumBillingTier", None, int),
             (maximum_bytes_billed, "maximumBytesBilled", None, float),
             (time_partitioning, "timePartitioning", {}, dict),
+            (range_partitioning, "rangePartitioning", {}, dict),
             (schema_update_options, "schemaUpdateOptions", None, list),
             (destination_dataset_table, "destinationTable", None, dict),
             (cluster_fields, "clustering", None, dict),
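
The new range_partitioning argument mirrors BigQuery's rangePartitioning job configuration and is mutually exclusive with time_partitioning. A minimal sketch of the two configuration shapes the query-building method now accepts; the column name and bounds are hypothetical:

# Hypothetical integer-range partitioning spec; BigQuery requires the partitioning
# column to be a top-level INT64 field of the destination table.
range_partitioning = {
    "field": "customer_id",
    "range": {"start": 0, "end": 100000, "interval": 1000},
}

# Existing time-based partitioning spec, unchanged by this release.
time_partitioning = {"type": "DAY", "field": "created_at"}

# Passing both specs now raises:
# ValueError("Only one of time_partitioning or range_partitioning can be set.")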

airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -18,12 +18,15 @@
 from __future__ import annotations
 
 import asyncio
+import json
 import time
 from collections.abc import MutableSequence, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+from urllib.parse import urljoin
 
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.auth.transport.requests import AuthorizedSession
 from google.cloud.orchestration.airflow.service_v1 import (
     EnvironmentsAsyncClient,
     EnvironmentsClient,
@@ -33,7 +36,7 @@ from google.cloud.orchestration.airflow.service_v1 import (
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
@@ -76,6 +79,34 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
             client_options=self.client_options,
         )
 
+    def make_composer_airflow_api_request(
+        self,
+        method: str,
+        airflow_uri: str,
+        path: str,
+        data: Any | None = None,
+        timeout: float | None = None,
+    ):
+        """
+        Make a request to Cloud Composer environment's web server.
+
+        :param method: The request method to use ('GET', 'OPTIONS', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE').
+        :param airflow_uri: The URI of the Apache Airflow Web UI hosted within this environment.
+        :param path: The path to send the request.
+        :param data: Dictionary, list of tuples, bytes, or file-like object to send in the body of the request.
+        :param timeout: The timeout for this request.
+        """
+        authed_session = AuthorizedSession(self.get_credentials())
+
+        resp = authed_session.request(
+            method=method,
+            url=urljoin(airflow_uri, path),
+            data=data,
+            headers={"Content-Type": "application/json"},
+            timeout=timeout,
+        )
+        return resp
+
     def get_operation(self, operation_name):
         return self.get_environment_client().transport.operations_client.get_operation(name=operation_name)
 
@@ -408,16 +439,52 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
             self.log.info("Waiting for result...")
             time.sleep(poll_interval)
 
+    def trigger_dag_run(
+        self,
+        composer_airflow_uri: str,
+        composer_dag_id: str,
+        composer_dag_conf: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        """
+        Trigger DAG run for provided Apache Airflow Web UI hosted within Composer environment.
+
+        :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within Composer environment.
+        :param composer_dag_id: The ID of DAG which will be triggered.
+        :param composer_dag_conf: Configuration parameters for the DAG run.
+        :param timeout: The timeout for this request.
+        """
+        response = self.make_composer_airflow_api_request(
+            method="POST",
+            airflow_uri=composer_airflow_uri,
+            path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
+            data=json.dumps(
+                {
+                    "conf": composer_dag_conf or {},
+                }
+            ),
+            timeout=timeout,
+        )
 
-class CloudComposerAsyncHook(GoogleBaseHook):
+        if response.status_code != 200:
+            self.log.error(response.text)
+            response.raise_for_status()
+
+        return response.json()
+
+
+class CloudComposerAsyncHook(GoogleBaseAsyncHook):
     """Hook for Google Cloud Composer async APIs."""
 
+    sync_hook_class = CloudComposerHook
+
     client_options = ClientOptions(api_endpoint="composer.googleapis.com:443")
 
-    def get_environment_client(self) -> EnvironmentsAsyncClient:
+    async def get_environment_client(self) -> EnvironmentsAsyncClient:
         """Retrieve client library object that allow access Environments service."""
+        sync_hook = await self.get_sync_hook()
         return EnvironmentsAsyncClient(
-            credentials=self.get_credentials(),
+            credentials=sync_hook.get_credentials(),
             client_info=CLIENT_INFO,
             client_options=self.client_options,
         )
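
The new trigger_dag_run() wraps make_composer_airflow_api_request() to POST against the environment web server's stable REST API. A minimal sketch, assuming the Airflow web-server URL has already been looked up (for example via get_environment(...).config.airflow_uri); the URI and DAG id are hypothetical:

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook

hook = CloudComposerHook(gcp_conn_id="google_cloud_default")
dag_run = hook.trigger_dag_run(
    composer_airflow_uri="https://example-dot-us-central1.composer.googleusercontent.com",
    composer_dag_id="daily_reporting",
    composer_dag_conf={"run_date": "2024-01-01"},
    timeout=60,
)
# Non-200 responses are logged and re-raised via raise_for_status(); on success the
# decoded JSON body of the created DAG run is returned.
print(dag_run["dag_run_id"])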
@@ -429,9 +496,8 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         return f"projects/{project_id}/locations/{region}"
 
     async def get_operation(self, operation_name):
-        return await self.get_environment_client().transport.operations_client.get_operation(
-            name=operation_name
-        )
+        client = await self.get_environment_client()
+        return await client.transport.operations_client.get_operation(name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
     async def create_environment(
@@ -454,7 +520,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
         return await client.create_environment(
             request={"parent": self.get_parent(project_id, region), "environment": environment},
             retry=retry,
@@ -482,7 +548,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
         name = self.get_environment_name(project_id, region, environment_id)
         return await client.delete_environment(
             request={"name": name}, retry=retry, timeout=timeout, metadata=metadata
@@ -518,7 +584,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
         name = self.get_environment_name(project_id, region, environment_id)
 
         return await client.update_environment(
@@ -556,7 +622,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
 
        return await client.execute_airflow_command(
            request={
@@ -598,7 +664,7 @@ class CloudComposerAsyncHook(GoogleBaseHook):
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request as metadata.
         """
-        client = self.get_environment_client()
+        client = await self.get_environment_client()
 
         return await client.poll_airflow_command(
             request={
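
With CloudComposerAsyncHook now based on GoogleBaseAsyncHook, the client factory and the methods that use it are coroutines, and credentials are resolved through the wrapped sync hook (sync_hook_class = CloudComposerHook). A minimal calling sketch from an async context; the operation name is hypothetical:

import asyncio

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook

async def operation_done(operation_name: str) -> bool:
    hook = CloudComposerAsyncHook(gcp_conn_id="google_cloud_default")
    # get_environment_client() and get_operation() must now be awaited.
    operation = await hook.get_operation(operation_name)
    return operation.done

asyncio.run(operation_done("projects/my-project/locations/us-central1/operations/operation-123"))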

airflow/providers/google/cloud/hooks/cloud_run.py
@@ -42,7 +42,11 @@ from google.longrunning import operations_pb2
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)
 
 if TYPE_CHECKING:
     from google.api_core import operation
@@ -159,7 +163,7 @@ class CloudRunHook(GoogleBaseHook):
         return list(itertools.islice(jobs, limit))
 
 
-class CloudRunAsyncHook(GoogleBaseHook):
+class CloudRunAsyncHook(GoogleBaseAsyncHook):
     """
     Async hook for the Google Cloud Run service.
 
@@ -174,6 +178,8 @@ class CloudRunAsyncHook(GoogleBaseHook):
         account from the list granting this role to the originating account.
     """
 
+    sync_hook_class = CloudRunHook
+
     def __init__(
         self,
         gcp_conn_id: str = "google_cloud_default",
@@ -183,16 +189,16 @@ class CloudRunAsyncHook(GoogleBaseHook):
         self._client: JobsAsyncClient | None = None
         super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
 
-    def get_conn(self):
+    async def get_conn(self):
         if self._client is None:
-            self._client = JobsAsyncClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+            sync_hook = await self.get_sync_hook()
+            self._client = JobsAsyncClient(credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO)
 
         return self._client
 
     async def get_operation(self, operation_name: str) -> operations_pb2.Operation:
-        return await self.get_conn().get_operation(
-            operations_pb2.GetOperationRequest(name=operation_name), timeout=120
-        )
+        conn = await self.get_conn()
+        return await conn.get_operation(operations_pb2.GetOperationRequest(name=operation_name), timeout=120)
 
 
 class CloudRunServiceHook(GoogleBaseHook):
@@ -258,7 +264,7 @@ class CloudRunServiceHook(GoogleBaseHook):
         return operation.result()
 
 
-class CloudRunServiceAsyncHook(GoogleBaseHook):
+class CloudRunServiceAsyncHook(GoogleBaseAsyncHook):
     """
     Async hook for the Google Cloud Run services.
 
@@ -273,6 +279,8 @@ class CloudRunServiceAsyncHook(GoogleBaseHook):
         account from the list granting this role to the originating account.
     """
 
+    sync_hook_class = CloudRunServiceHook
+
     def __init__(
         self,
        gcp_conn_id: str = "google_cloud_default",
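
The same GoogleBaseAsyncHook migration applies to the Cloud Run async hooks, so get_conn() is now a coroutine and any caller that previously used it synchronously must await it. A minimal sketch; the operation name is hypothetical:

import asyncio

from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook

async def main() -> None:
    hook = CloudRunAsyncHook(gcp_conn_id="google_cloud_default")
    client = await hook.get_conn()  # previously: client = hook.get_conn()
    operation = await hook.get_operation(
        "projects/my-project/locations/us-central1/operations/operation-123"
    )
    print(type(client).__name__, operation.done)

asyncio.run(main())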

airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -38,7 +38,6 @@ from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.deprecated import deprecated
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
 
@@ -951,178 +950,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
         )
         return model, training_id
 
-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="September 15, 2025",
-        category=AirflowProviderDeprecationWarning,
-        reason="Deprecation of AutoMLText API",
-    )
-    def create_auto_ml_text_training_job(
-        self,
-        project_id: str,
-        region: str,
-        display_name: str,
-        dataset: datasets.TextDataset,
-        prediction_type: str,
-        multi_label: bool = False,
-        sentiment_max: int = 10,
-        labels: dict[str, str] | None = None,
-        training_encryption_spec_key_name: str | None = None,
-        model_encryption_spec_key_name: str | None = None,
-        training_fraction_split: float | None = None,
-        validation_fraction_split: float | None = None,
-        test_fraction_split: float | None = None,
-        training_filter_split: str | None = None,
-        validation_filter_split: str | None = None,
-        test_filter_split: str | None = None,
-        model_display_name: str | None = None,
-        model_labels: dict[str, str] | None = None,
-        sync: bool = True,
-        parent_model: str | None = None,
-        is_default_version: bool | None = None,
-        model_version_aliases: list[str] | None = None,
-        model_version_description: str | None = None,
-    ) -> tuple[models.Model | None, str]:
-        """
-        Create an AutoML Text Training Job.
-
-        WARNING: Text creation API is deprecated since September 15, 2024
-        (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
-
-        :param project_id: Required. Project to run training in.
-        :param region: Required. Location to run training in.
-        :param display_name: Required. The user-defined name of this TrainingPipeline.
-        :param dataset: Required. The dataset within the same Project from which data will be used to train
-            the Model. The Dataset must use schema compatible with Model being trained, and what is
-            compatible should be described in the used TrainingPipeline's [training_task_definition]
-            [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
-        :param prediction_type: The type of prediction the Model is to produce, one of:
-            "classification" - A classification model analyzes text data and returns a list of categories
-            that apply to the text found in the data. Vertex AI offers both single-label and multi-label text
-            classification models.
-            "extraction" - An entity extraction model inspects text data for known entities referenced in the
-            data and labels those entities in the text.
-            "sentiment" - A sentiment analysis model inspects text data and identifies the prevailing
-            emotional opinion within it, especially to determine a writer's attitude as positive, negative,
-            or neutral.
-        :param parent_model: Optional. The resource name or model ID of an existing model.
-            The new model uploaded by this job will be a version of `parent_model`.
-            Only set this field when training a new version of an existing model.
-        :param is_default_version: Optional. When set to True, the newly uploaded model version will
-            automatically have alias "default" included. Subsequent uses of
-            the model produced by this job without a version specified will
-            use this "default" version.
-            When set to False, the "default" alias will not be moved.
-            Actions targeting the model version produced by this job will need
-            to specifically reference this version by ID or alias.
-            New model uploads, i.e. version 1, will always be "default" aliased.
-        :param model_version_aliases: Optional. User provided version aliases so that the model version
-            uploaded by this job can be referenced via alias instead of
-            auto-generated version ID. A default version alias will be created
-            for the first version of the model.
-            The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
-        :param model_version_description: Optional. The description of the model version
-            being uploaded by this job.
-        :param multi_label: Required and only applicable for text classification task. If false, a
-            single-label (multi-class) Model will be trained (i.e. assuming that for each text snippet just
-            up to one annotation may be applicable). If true, a multi-label Model will be trained (i.e.
-            assuming that for each text snippet multiple annotations may be applicable).
-        :param sentiment_max: Required and only applicable for sentiment task. A sentiment is expressed as an
-            integer ordinal, where higher value means a more positive sentiment. The range of sentiments that
-            will be used is between 0 and sentimentMax (inclusive on both ends), and all the values in the
-            range must be represented in the dataset before a model can be created. Only the Annotations with
-            this sentimentMax will be used for training. sentimentMax value must be between 1 and 10
-            (inclusive).
-        :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
-            keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
-            lowercase letters, numeric characters, underscores and dashes. International characters are
-            allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
-        :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-            managed encryption key used to protect the training pipeline. Has the form:
-            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-            The key needs to be in the same region as where the compute resource is created.
-            If set, this TrainingPipeline will be secured by this key.
-            Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
-            is not set separately.
-        :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-            managed encryption key used to protect the model. Has the form:
-            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-            The key needs to be in the same region as where the compute resource is created.
-            If set, the trained Model will be secured by this key.
-        :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
-            the Model. This is ignored if Dataset is not provided.
-        :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
-            validate the Model. This is ignored if Dataset is not provided.
-        :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
-            the Model. This is ignored if Dataset is not provided.
-        :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-            this filter are used to train the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-            this filter are used to validate the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
-            filter are used to test the Model. A filter with same syntax as the one used in
-            DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
-            FilterSplit filters, then it is assigned to the first set that applies to it in the training,
-            validation, test order. This is ignored if Dataset is not provided.
-        :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
-            up to 128 characters long and can consist of any UTF-8 characters.
-            If not provided upon creation, the job's display_name is used.
-        :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
-            keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
-            lowercase letters, numeric characters, underscores and dashes. International characters are
-            allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
-        :param sync: Whether to execute this method synchronously. If False, this method will be executed in
-            concurrent Future and any downstream object will be immediately returned and synced when the
-            Future has completed.
-        """
-        self._job = AutoMLTextTrainingJob(
-            display_name=display_name,
-            prediction_type=prediction_type,
-            multi_label=multi_label,
-            sentiment_max=sentiment_max,
-            project=project_id,
-            location=region,
-            credentials=self.get_credentials(),
-            labels=labels,
-            training_encryption_spec_key_name=training_encryption_spec_key_name,
-            model_encryption_spec_key_name=model_encryption_spec_key_name,
-        )
-
-        if not self._job:
-            raise AirflowException("AutoMLTextTrainingJob was not created")
-
-        model = self._job.run(
-            dataset=dataset,
-            training_fraction_split=training_fraction_split,
-            validation_fraction_split=validation_fraction_split,
-            test_fraction_split=test_fraction_split,
-            training_filter_split=training_filter_split,
-            validation_filter_split=validation_filter_split,
-            test_filter_split=test_filter_split,
-            model_display_name=model_display_name,
-            model_labels=model_labels,
-            sync=sync,
-            parent_model=parent_model,
-            is_default_version=is_default_version,
-            model_version_aliases=model_version_aliases,
-            model_version_description=model_version_description,
-        )
-        training_id = self.extract_training_id(self._job.resource_name)
-        if model:
-            model.wait()
-        else:
-            self.log.warning(
-                "Training did not produce a Managed Model returning None. AutoML Text Training "
-                "Pipeline is not configured to upload a Model."
-            )
-        return model, training_id
-
     @GoogleBaseHook.fallback_to_default_project_id
     def create_auto_ml_video_training_job(
         self,

airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -213,9 +213,15 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         gcp_keyfile_dict: dict | None = None,
         gcp_scopes: Collection[str] | None = _DEFAULT_SCOPESS,
         project_id: str = PROVIDE_PROJECT_ID,
+        max_bytes: int = 0,
+        backup_count: int = 0,
+        delay: bool = False,
         **kwargs,
-    ):
-        super().__init__(base_log_folder)
+    ) -> None:
+        # support log file size handling of FileTaskHandler
+        super().__init__(
+            base_log_folder=base_log_folder, max_bytes=max_bytes, backup_count=backup_count, delay=delay
+        )
         self.handler: logging.FileHandler | None = None
         self.log_relative_path = ""
         self.closed = False
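
The handler now forwards the rotation arguments to FileTaskHandler, so local task log files can be size-limited before upload to GCS. A minimal sketch of constructing the handler directly (for example from a custom logging config); the paths and sizes are hypothetical:

from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler

handler = GCSTaskHandler(
    base_log_folder="/opt/airflow/logs",
    gcs_log_folder="gs://my-log-bucket/airflow/logs",
    max_bytes=10 * 1024 * 1024,  # rotate the local log file after roughly 10 MiB
    backup_count=3,              # keep at most three rotated local files
    delay=True,                  # do not open the file until the first record is emitted
)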

airflow/providers/google/cloud/operators/cloud_composer.py
@@ -21,7 +21,7 @@ import shlex
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
-from google.api_core.exceptions import AlreadyExists
+from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.orchestration.airflow.service_v1 import ImageVersion
 from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse
@@ -798,3 +798,86 @@ class CloudComposerRunAirflowCLICommandOperator(GoogleCloudBaseOperator):
         """Merge output to one string."""
         result_str = "\n".join(line_dict["content"] for line_dict in result["output"])
         return result_str
+
+
+class CloudComposerTriggerDAGRunOperator(GoogleCloudBaseOperator):
+    """
+    Trigger DAG run for provided Composer environment.
+
+    :param project_id: The ID of the Google Cloud project that the service belongs to.
+    :param region: The ID of the Google Cloud region that the service belongs to.
+    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
+    :param composer_dag_id: The ID of DAG which will be triggered.
+    :param composer_dag_conf: Configuration parameters for the DAG run.
+    :param timeout: The timeout for this request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "composer_dag_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        composer_dag_conf: dict | None = None,
+        timeout: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.composer_dag_conf = composer_dag_conf or {}
+        self.timeout = timeout
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        hook = CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        try:
+            environment = hook.get_environment(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+                timeout=self.timeout,
+            )
+        except NotFound as not_found_err:
+            self.log.info("The Composer environment %s does not exist.", self.environment_id)
+            raise AirflowException(not_found_err)
+        composer_airflow_uri = environment.config.airflow_uri
+
+        self.log.info(
+            "Triggering the DAG %s on the %s environment...", self.composer_dag_id, self.environment_id
+        )
+        dag_run = hook.trigger_dag_run(
+            composer_airflow_uri=composer_airflow_uri,
+            composer_dag_id=self.composer_dag_id,
+            composer_dag_conf=self.composer_dag_conf,
+            timeout=self.timeout,
+        )
+        self.log.info("The DAG %s was triggered with Run ID: %s", self.composer_dag_id, dag_run["dag_run_id"])
+
+        return dag_run
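
A minimal sketch of wiring the new operator into a DAG; the project, region, environment and DAG ids are hypothetical:

import pendulum

from airflow import DAG
from airflow.providers.google.cloud.operators.cloud_composer import (
    CloudComposerTriggerDAGRunOperator,
)

with DAG(
    dag_id="trigger_remote_composer_dag",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
    catchup=False,
):
    trigger_remote_dag = CloudComposerTriggerDAGRunOperator(
        task_id="trigger_remote_dag",
        project_id="my-project",
        region="us-central1",
        environment_id="my-composer-environment",
        composer_dag_id="daily_reporting",
        composer_dag_conf={"run_date": "2024-01-01"},
    )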

airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -222,7 +222,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         if self.deferrable:
             start_date, end_date = self._get_logical_dates(context)
             self.defer(
-                timeout=self.timeout,
+                timeout=timedelta(seconds=self.timeout) if self.timeout else None,
                 trigger=CloudComposerDAGRunTrigger(
                     project_id=self.project_id,
                     region=self.region,
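
The fix matters in deferrable mode, where the sensor-level timeout (given in seconds, as on any BaseSensorOperator) is now converted to a timedelta before being handed to defer(). A minimal usage sketch; the ids are hypothetical and the keyword names follow the sensor attributes shown above:

from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor

wait_for_remote_dag = CloudComposerDAGRunSensor(
    task_id="wait_for_remote_dag",
    project_id="my-project",
    region="us-central1",
    environment_id="my-composer-environment",
    composer_dag_id="daily_reporting",
    deferrable=True,
    timeout=3600,  # seconds; now passed to defer() as timedelta(seconds=3600)
)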

airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -25,13 +25,11 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
 if TYPE_CHECKING:
-    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -112,67 +110,3 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
             project_id=project_id,
             table_id=table_id,
         )
-
-    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
-        from airflow.providers.common.compat.openlineage.facet import Dataset
-        from airflow.providers.google.cloud.openlineage.utils import (
-            BIGQUERY_NAMESPACE,
-            get_facets_from_bq_table_for_given_fields,
-            get_identity_column_lineage_facet,
-        )
-        from airflow.providers.openlineage.extractors import OperatorLineage
-
-        if not self.bigquery_hook:
-            self.bigquery_hook = BigQueryHook(
-                gcp_conn_id=self.gcp_conn_id,
-                location=self.location,
-                impersonation_chain=self.impersonation_chain,
-            )
-
-        try:
-            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
-        except Exception:
-            self.log.debug(
-                "OpenLineage: could not fetch BigQuery table %s",
-                self.source_project_dataset_table,
-                exc_info=True,
-            )
-            return OperatorLineage()
-
-        if self.selected_fields:
-            if isinstance(self.selected_fields, str):
-                bigquery_field_names = list(self.selected_fields)
-            else:
-                bigquery_field_names = self.selected_fields
-        else:
-            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
-
-        input_dataset = Dataset(
-            namespace=BIGQUERY_NAMESPACE,
-            name=self.source_project_dataset_table,
-            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
-        )
-
-        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
-        default_schema = self.mssql_hook.get_openlineage_default_schema()
-        namespace = f"{db_info.scheme}://{db_info.authority}"
-
-        if self.target_table_name and "." in self.target_table_name:
-            schema_name, table_name = self.target_table_name.split(".", 1)
-        else:
-            schema_name = default_schema or ""
-            table_name = self.target_table_name or ""
-
-        if self.database:
-            output_name = f"{self.database}.{schema_name}.{table_name}"
-        else:
-            output_name = f"{schema_name}.{table_name}"
-
-        column_lineage_facet = get_identity_column_lineage_facet(
-            bigquery_field_names, input_datasets=[input_dataset]
-        )
-
-        output_facets = column_lineage_facet or {}
-        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
-
-        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])