apache-airflow-providers-google 10.19.0__py3-none-any.whl → 10.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. airflow/providers/google/LICENSE +4 -4
  2. airflow/providers/google/__init__.py +1 -1
  3. airflow/providers/google/ads/hooks/ads.py +4 -4
  4. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +26 -0
  5. airflow/providers/google/cloud/hooks/dataflow.py +132 -1
  6. airflow/providers/google/cloud/hooks/datapipeline.py +22 -73
  7. airflow/providers/google/cloud/hooks/gcs.py +21 -0
  8. airflow/providers/google/cloud/hooks/pubsub.py +10 -1
  9. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -0
  10. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +15 -3
  11. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
  12. airflow/providers/google/cloud/links/dataflow.py +25 -0
  13. airflow/providers/google/cloud/openlineage/mixins.py +271 -0
  14. airflow/providers/google/cloud/openlineage/utils.py +5 -218
  15. airflow/providers/google/cloud/operators/bigquery.py +74 -20
  16. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +76 -0
  17. airflow/providers/google/cloud/operators/dataflow.py +235 -1
  18. airflow/providers/google/cloud/operators/datapipeline.py +29 -121
  19. airflow/providers/google/cloud/operators/dataplex.py +1 -1
  20. airflow/providers/google/cloud/operators/dataproc_metastore.py +17 -6
  21. airflow/providers/google/cloud/operators/kubernetes_engine.py +9 -6
  22. airflow/providers/google/cloud/operators/pubsub.py +18 -0
  23. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +6 -0
  24. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +16 -0
  25. airflow/providers/google/cloud/sensors/cloud_composer.py +171 -2
  26. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +13 -0
  27. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +56 -1
  28. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +6 -12
  29. airflow/providers/google/cloud/triggers/cloud_composer.py +115 -0
  30. airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -0
  31. airflow/providers/google/cloud/utils/credentials_provider.py +81 -6
  32. airflow/providers/google/cloud/utils/external_token_supplier.py +175 -0
  33. airflow/providers/google/common/hooks/base_google.py +35 -1
  34. airflow/providers/google/common/utils/id_token_credentials.py +1 -1
  35. airflow/providers/google/get_provider_info.py +19 -14
  36. {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/METADATA +41 -35
  37. {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/RECORD +39 -37
  38. {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/WHEEL +0 -0
  39. {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/LICENSE

@@ -215,7 +215,7 @@ Third party Apache 2.0 licenses
 
 The following components are provided under the Apache 2.0 License.
 See project link for details. The text of each license is also included
-at licenses/LICENSE-[project].txt.
+at 3rd-party-licenses/LICENSE-[project].txt.
 
 (ALv2 License) hue v4.3.0 (https://github.com/cloudera/hue/)
 (ALv2 License) jqclock v2.3.0 (https://github.com/JohnRDOrazio/jQuery-Clock-Plugin)
@@ -227,7 +227,7 @@ MIT licenses
 ========================================================================
 
 The following components are provided under the MIT License. See project link for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (MIT License) jquery v3.5.1 (https://jquery.org/license/)
 (MIT License) dagre-d3 v0.6.4 (https://github.com/cpettitt/dagre-d3)
@@ -243,11 +243,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
 BSD 3-Clause licenses
 ========================================================================
 The following components are provided under the BSD 3-Clause license. See project links for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (BSD 3 License) d3 v5.16.0 (https://d3js.org)
 (BSD 3 License) d3-shape v2.1.0 (https://github.com/d3/d3-shape)
 (BSD 3 License) cgroupspy 0.2.1 (https://github.com/cloudsigma/cgroupspy)
 
 ========================================================================
-See licenses/LICENSES-ui.txt for packages used in `/airflow/www`
+See 3rd-party-licenses/LICENSES-ui.txt for packages used in `/airflow/www`

airflow/providers/google/__init__.py

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "10.19.0"
+__version__ = "10.20.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"

airflow/providers/google/ads/hooks/ads.py

@@ -32,9 +32,9 @@ from airflow.hooks.base import BaseHook
 from airflow.providers.google.common.hooks.base_google import get_field
 
 if TYPE_CHECKING:
-    from google.ads.googleads.v16.services.services.customer_service import CustomerServiceClient
-    from google.ads.googleads.v16.services.services.google_ads_service import GoogleAdsServiceClient
-    from google.ads.googleads.v16.services.types.google_ads_service import GoogleAdsRow
+    from google.ads.googleads.v17.services.services.customer_service import CustomerServiceClient
+    from google.ads.googleads.v17.services.services.google_ads_service import GoogleAdsServiceClient
+    from google.ads.googleads.v17.services.types.google_ads_service import GoogleAdsRow
     from google.api_core.page_iterator import GRPCIterator
 
 
@@ -100,7 +100,7 @@ class GoogleAdsHook(BaseHook):
     :param api_version: The Google Ads API version to use.
     """
 
-    default_api_version = "v16"
+    default_api_version = "v17"
 
     def __init__(
         self,
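
Usage note (illustration, not part of the diff): the default Google Ads API version moves from v16 to v17 in this release. A minimal sketch of pinning the version explicitly, assuming the hook's `api_version` constructor argument and default connection ids:

    from airflow.providers.google.ads.hooks.ads import GoogleAdsHook

    # Pin the API version if your integration is not ready for v17;
    # omitting api_version now resolves to "v17".
    hook = GoogleAdsHook(api_version="v16")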

airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py

@@ -344,6 +344,32 @@ class CloudDataTransferServiceHook(GoogleBaseHook):
             .execute(num_retries=self.num_retries)
         )
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def run_transfer_job(self, job_name: str, project_id: str) -> dict:
+        """Run Google Storage Transfer Service job.
+
+        :param job_name: (Required) Name of the job to be fetched
+        :param project_id: (Optional) the ID of the project that owns the Transfer
+            Job. If set to None or missing, the default project_id from the Google Cloud
+            connection is used.
+        :return: If successful, Operation. See:
+            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/Operation
+
+        .. seealso:: https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/run
+
+        """
+        return (
+            self.get_conn()
+            .transferJobs()
+            .run(
+                jobName=job_name,
+                body={
+                    PROJECT_ID: project_id,
+                },
+            )
+            .execute(num_retries=self.num_retries)
+        )
+
     def cancel_transfer_operation(self, operation_name: str) -> None:
         """Cancel a transfer operation in Google Storage Transfer Service.
 
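Usage sketch (illustration, not part of the diff) for the new `run_transfer_job` hook method; the job name and project id are placeholders and a configured `google_cloud_default` connection is assumed:

    from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
        CloudDataTransferServiceHook,
    )

    hook = CloudDataTransferServiceHook()
    # Triggers an immediate run of an existing transfer job and returns the Operation resource.
    operation = hook.run_transfer_job(job_name="transferJobs/123456789", project_id="my-project")
    print(operation["name"])  # e.g. "transferOperations/..."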

airflow/providers/google/cloud/hooks/dataflow.py

@@ -71,7 +71,7 @@ DEFAULT_DATAFLOW_LOCATION = "us-central1"
 
 
 JOB_ID_PATTERN = re.compile(
-    r"Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]"
+    r"Submitted job: (?P<job_id_java>[^\"\n]*)|Created job with id: \[(?P<job_id_python>[^\"\n]*)\]"
)
 
 T = TypeVar("T", bound=Callable)
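
Illustration (not part of the diff): the tightened `JOB_ID_PATTERN` stops the captured job id at a quote or newline, so ids extracted from quoted or JSON-formatted launcher log lines no longer pick up trailing characters. A small self-contained check:

    import re

    JOB_ID_PATTERN = re.compile(
        r"Submitted job: (?P<job_id_java>[^\"\n]*)|Created job with id: \[(?P<job_id_python>[^\"\n]*)\]"
    )

    line = '{"message": "Submitted job: 2024-06-01_12_34_56-1234567890123456789"}'
    match = JOB_ID_PATTERN.search(line)
    # The old ".*" group would have captured the trailing '"}' as part of the id.
    print(match.group("job_id_java"))  # 2024-06-01_12_34_56-1234567890123456789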

@@ -582,6 +582,11 @@ class DataflowHook(GoogleBaseHook):
         http_authorized = self._authorize()
         return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
 
+    def get_pipelines_conn(self) -> build:
+        """Return a Google Cloud Data Pipelines service object."""
+        http_authorized = self._authorize()
+        return build("datapipelines", "v1", http=http_authorized, cache_discovery=False)
+
     @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
@@ -1351,6 +1356,132 @@ class DataflowHook(GoogleBaseHook):
 
         return job_controller._check_dataflow_job_state(job)
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_data_pipeline(
+        self,
+        body: dict,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ):
+        """
+        Create a new Dataflow Data Pipelines instance.
+
+        :param body: The request body (contains instance of Pipeline). See:
+            https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
+        :param project_id: The ID of the GCP project that owns the job.
+        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+
+        Returns the created Data Pipelines instance in JSON representation.
+        """
+        parent = self.build_parent_name(project_id, location)
+        service = self.get_pipelines_conn()
+        request = (
+            service.projects()
+            .locations()
+            .pipelines()
+            .create(
+                parent=parent,
+                body=body,
+            )
+        )
+        response = request.execute(num_retries=self.num_retries)
+        return response
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_data_pipeline(
+        self,
+        pipeline_name: str,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> dict:
+        """
+        Retrieve a new Dataflow Data Pipelines instance.
+
+        :param pipeline_name: The display name of the pipeline. In example
+            projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+        :param project_id: The ID of the GCP project that owns the job.
+        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+
+        Returns the created Data Pipelines instance in JSON representation.
+        """
+        parent = self.build_parent_name(project_id, location)
+        service = self.get_pipelines_conn()
+        request = (
+            service.projects()
+            .locations()
+            .pipelines()
+            .get(
+                name=f"{parent}/pipelines/{pipeline_name}",
+            )
+        )
+        response = request.execute(num_retries=self.num_retries)
+        return response
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def run_data_pipeline(
+        self,
+        pipeline_name: str,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> dict:
+        """
+        Run a Dataflow Data Pipeline Instance.
+
+        :param pipeline_name: The display name of the pipeline. In example
+            projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+        :param project_id: The ID of the GCP project that owns the job.
+        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+
+        Returns the created Job in JSON representation.
+        """
+        parent = self.build_parent_name(project_id, location)
+        service = self.get_pipelines_conn()
+        request = (
+            service.projects()
+            .locations()
+            .pipelines()
+            .run(
+                name=f"{parent}/pipelines/{pipeline_name}",
+                body={},
+            )
+        )
+        response = request.execute(num_retries=self.num_retries)
+        return response
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_data_pipeline(
+        self,
+        pipeline_name: str,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> dict | None:
+        """
+        Delete a Dataflow Data Pipelines Instance.
+
+        :param pipeline_name: The display name of the pipeline. In example
+            projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+        :param project_id: The ID of the GCP project that owns the job.
+        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+
+        Returns the created Job in JSON representation.
+        """
+        parent = self.build_parent_name(project_id, location)
+        service = self.get_pipelines_conn()
+        request = (
+            service.projects()
+            .locations()
+            .pipelines()
+            .delete(
+                name=f"{parent}/pipelines/{pipeline_name}",
+            )
+        )
+        response = request.execute(num_retries=self.num_retries)
+        return response
+
+    @staticmethod
+    def build_parent_name(project_id: str, location: str):
+        return f"projects/{project_id}/locations/{location}"
+
 
 
 class AsyncDataflowHook(GoogleBaseAsyncHook):
     """Async hook class for dataflow service."""

airflow/providers/google/cloud/hooks/datapipeline.py

@@ -19,40 +19,30 @@
 
 from __future__ import annotations
 
-from typing import Sequence
+from typing import TYPE_CHECKING
 
-from googleapiclient.discovery import build
+from deprecated import deprecated
 
-from airflow.providers.google.common.hooks.base_google import (
-    GoogleBaseHook,
-)
-
-DEFAULT_DATAPIPELINE_LOCATION = "us-central1"
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.dataflow import DataflowHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 
+if TYPE_CHECKING:
+    from googleapiclient.discovery import build
 
-class DataPipelineHook(GoogleBaseHook):
-    """
-    Hook for Google Data Pipelines.
+DEFAULT_DATAPIPELINE_LOCATION = "us-central1"
 
-    All the methods in the hook where project_id is used must be called with
-    keyword arguments rather than positional.
-    """
 
-    def __init__(
-        self,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(
-            gcp_conn_id=gcp_conn_id,
-            impersonation_chain=impersonation_chain,
-        )
+@deprecated(
+    reason="This hook is deprecated and will be removed after 01.12.2024. Please use `DataflowHook`.",
+    category=AirflowProviderDeprecationWarning,
+)
+class DataPipelineHook(DataflowHook):
+    """Hook for Google Data Pipelines."""
 
     def get_conn(self) -> build:
         """Return a Google Cloud Data Pipelines service object."""
-        http_authorized = self._authorize()
-        return build("datapipelines", "v1", http=http_authorized, cache_discovery=False)
+        return super().get_pipelines_conn()
 
     @GoogleBaseHook.fallback_to_default_project_id
     def create_data_pipeline(
@@ -60,31 +50,9 @@ class DataPipelineHook(GoogleBaseHook):
         body: dict,
         project_id: str,
         location: str = DEFAULT_DATAPIPELINE_LOCATION,
-    ) -> None:
-        """
-        Create a new Data Pipelines instance from the Data Pipelines API.
-
-        :param body: The request body (contains instance of Pipeline). See:
-            https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
-        :param project_id: The ID of the GCP project that owns the job.
-        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-
-        Returns the created Data Pipelines instance in JSON representation.
-        """
-        parent = self.build_parent_name(project_id, location)
-        service = self.get_conn()
-        self.log.info(dir(service.projects().locations()))
-        request = (
-            service.projects()
-            .locations()
-            .pipelines()
-            .create(
-                parent=parent,
-                body=body,
-            )
-        )
-        response = request.execute(num_retries=self.num_retries)
-        return response
+    ) -> dict:
+        """Create a new Data Pipelines instance from the Data Pipelines API."""
+        return super().create_data_pipeline(body=body, project_id=project_id, location=location)
 
     @GoogleBaseHook.fallback_to_default_project_id
     def run_data_pipeline(
@@ -92,30 +60,11 @@
         data_pipeline_name: str,
         project_id: str,
         location: str = DEFAULT_DATAPIPELINE_LOCATION,
-    ) -> None:
-        """
-        Run a Data Pipelines Instance using the Data Pipelines API.
-
-        :param data_pipeline_name: The display name of the pipeline. In example
-            projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
-        :param project_id: The ID of the GCP project that owns the job.
-        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-
-        Returns the created Job in JSON representation.
-        """
-        parent = self.build_parent_name(project_id, location)
-        service = self.get_conn()
-        request = (
-            service.projects()
-            .locations()
-            .pipelines()
-            .run(
-                name=f"{parent}/pipelines/{data_pipeline_name}",
-                body={},
-            )
+    ) -> dict:
+        """Run a Data Pipelines Instance using the Data Pipelines API."""
+        return super().run_data_pipeline(
+            pipeline_name=data_pipeline_name, project_id=project_id, location=location
         )
-        response = request.execute(num_retries=self.num_retries)
-        return response
 
     @staticmethod
     def build_parent_name(project_id: str, location: str):
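
Migration sketch (illustration, not part of the diff): `DataPipelineHook` keeps working but emits `AirflowProviderDeprecationWarning` when instantiated; the replacement calls on `DataflowHook` use `pipeline_name` instead of `data_pipeline_name`. Names below are placeholders:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook
    from airflow.providers.google.cloud.hooks.datapipeline import DataPipelineHook

    # Deprecated path (still functional, warns on instantiation):
    DataPipelineHook().run_data_pipeline(data_pipeline_name="my-pipeline", project_id="my-project")

    # Replacement (note the renamed keyword argument):
    DataflowHook().run_data_pipeline(pipeline_name="my-pipeline", project_id="my-project")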

airflow/providers/google/cloud/hooks/gcs.py

@@ -1010,6 +1010,27 @@ class GCSHook(GoogleBaseHook):
         self.log.info("The md5Hash of %s is %s", object_name, blob_md5hash)
         return blob_md5hash
 
+    def get_metadata(self, bucket_name: str, object_name: str) -> dict | None:
+        """
+        Get the metadata of an object in Google Cloud Storage.
+
+        :param bucket_name: Name of the Google Cloud Storage bucket where the object is.
+        :param object_name: The name of the object containing the desired metadata
+        :return: The metadata associated with the object
+        """
+        self.log.info("Retrieving the metadata dict of object (%s) in bucket (%s)", object_name, bucket_name)
+        client = self.get_conn()
+        bucket = client.bucket(bucket_name)
+        blob = bucket.get_blob(blob_name=object_name)
+        if blob is None:
+            raise ValueError("Object (%s) not found in bucket (%s)", object_name, bucket_name)
+        blob_metadata = blob.metadata
+        if blob_metadata:
+            self.log.info("Retrieved metadata of object (%s) with %s fields", object_name, len(blob_metadata))
+        else:
+            self.log.info("Metadata of object (%s) is empty or it does not exist", object_name)
+        return blob_metadata
+
     @GoogleBaseHook.fallback_to_default_project_id
     def create_bucket(
         self,
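
Usage sketch (illustration, not part of the diff) for `GCSHook.get_metadata`; bucket and object names are placeholders. The method raises `ValueError` when the object does not exist and returns `None` when the object carries no custom metadata:

    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    hook = GCSHook()
    metadata = hook.get_metadata(bucket_name="my-bucket", object_name="data/report.csv")
    if metadata:
        print(metadata.get("uploaded-by"))  # user-set metadata keys on the object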

airflow/providers/google/cloud/hooks/pubsub.py

@@ -36,6 +36,7 @@ from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.exceptions import NotFound
 from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
+from google.cloud.pubsub_v1.types import PublisherOptions
 from google.pubsub_v1.services.subscriber.async_client import SubscriberAsyncClient
 from googleapiclient.errors import HttpError
 
@@ -79,6 +80,7 @@ class PubSubHook(GoogleBaseHook):
         self,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        enable_message_ordering: bool = False,
         **kwargs,
     ) -> None:
         if kwargs.get("delegate_to") is not None:
@@ -90,6 +92,7 @@ class PubSubHook(GoogleBaseHook):
             gcp_conn_id=gcp_conn_id,
             impersonation_chain=impersonation_chain,
         )
+        self.enable_message_ordering = enable_message_ordering
         self._client = None
 
     def get_conn(self) -> PublisherClient:
@@ -99,7 +102,13 @@
         :return: Google Cloud Pub/Sub client object.
         """
         if not self._client:
-            self._client = PublisherClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+            self._client = PublisherClient(
+                credentials=self.get_credentials(),
+                client_info=CLIENT_INFO,
+                publisher_options=PublisherOptions(
+                    enable_message_ordering=self.enable_message_ordering,
+                ),
+            )
         return self._client
 
     @cached_property
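
Usage sketch (illustration, not part of the diff): with `enable_message_ordering=True`, the hook's `get_conn()` now builds its `PublisherClient` with `PublisherOptions(enable_message_ordering=True)`. The snippet publishes through that client directly with an ordering key; topic and project names are placeholders, and the subscription must also have ordering enabled:

    from airflow.providers.google.cloud.hooks.pubsub import PubSubHook

    hook = PubSubHook(enable_message_ordering=True)
    client = hook.get_conn()

    topic_path = "projects/my-project/topics/my-topic"
    # Messages sharing an ordering key are delivered in publish order.
    client.publish(topic_path, b"first", ordering_key="order-group-1").result()
    client.publish(topic_path, b"second", ordering_key="order-group-1").result()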

airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py

@@ -551,6 +551,8 @@ class AutoMLHook(GoogleBaseHook):
         is_default_version: bool | None = None,
         model_version_aliases: list[str] | None = None,
         model_version_description: str | None = None,
+        window_stride_length: int | None = None,
+        window_max_count: int | None = None,
     ) -> tuple[models.Model | None, str]:
         """
         Create an AutoML Forecasting Training Job.
@@ -703,6 +705,10 @@
         :param sync: Whether to execute this method synchronously. If False, this method will be executed in
             concurrent Future and any downstream object will be immediately returned and synced when the
             Future has completed.
+        :param window_stride_length: Optional. Step length used to generate input examples. Every
+            ``window_stride_length`` rows will be used to generate a sliding window.
+        :param window_max_count: Optional. Number of rows that should be used to generate input examples. If the
+            total row count is larger than this number, the input data will be randomly sampled to hit the count.
         """
         if column_transformations:
             warnings.warn(
@@ -758,6 +764,8 @@
             is_default_version=is_default_version,
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
+            window_stride_length=window_stride_length,
+            window_max_count=window_max_count,
         )
         training_id = self.extract_training_id(self._job.resource_name)
         if model:
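
Abridged sketch (illustration, not part of the diff) of the two new forecasting arguments; the method name is assumed from the hook's forecasting API and all other required arguments (dataset, column specs, horizon, granularity) are omitted, so this is an outline rather than a runnable call:

    from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook

    hook = AutoMLHook()
    model, training_id = hook.create_auto_ml_forecasting_training_job(
        project_id="my-project",
        region="us-central1",
        display_name="forecast-job",
        # ... dataset, target/time columns, forecast horizon and granularity as before ...
        window_stride_length=1,  # generate an input example from every row
        # window_max_count=1000,  # alternatively, cap the number of generated examples
    )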

airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py

@@ -141,6 +141,8 @@ class GenerativeModelHook(GoogleBaseHook):
         self,
         prompt: str,
         location: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro",
         project_id: str = PROVIDE_PROJECT_ID,
     ) -> str:
@@ -149,17 +151,21 @@
 
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Multi-modal model, in order to elicit a specific response.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per request settings for blocking unsafe content.
         :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
             output text and code.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
         model = self.get_generative_model(pretrained_model)
-        response = model.generate_content(prompt)
+        response = model.generate_content(
+            contents=[prompt], generation_config=generation_config, safety_settings=safety_settings
+        )
 
         return response.text
 
@@ -170,6 +176,8 @@
         location: str,
         media_gcs_path: str,
         mime_type: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro-vision",
         project_id: str = PROVIDE_PROJECT_ID,
     ) -> str:
@@ -178,6 +186,8 @@
 
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Multi-modal model, in order to elicit a specific response.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per request settings for blocking unsafe content.
         :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
@@ -192,6 +202,8 @@
 
         model = self.get_generative_model(pretrained_model)
         part = self.get_generative_model_part(media_gcs_path, mime_type)
-        response = model.generate_content([prompt, part])
+        response = model.generate_content(
+            contents=[prompt, part], generation_config=generation_config, safety_settings=safety_settings
+        )
 
         return response.text
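
Usage sketch (illustration, not part of the diff): the multimodal prompt methods now forward `generation_config` and `safety_settings` to `generate_content`. The method name and values below are assumptions for illustration; `generation_config` takes the usual Vertex AI generation parameters:

    from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook

    hook = GenerativeModelHook()
    text = hook.prompt_multimodal_model(
        prompt="Summarize this release in one sentence.",
        location="us-central1",
        generation_config={"temperature": 0.2, "max_output_tokens": 256},
        # safety_settings maps harm categories to block thresholds; omitted here.
    )
    print(text)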

airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py

@@ -126,7 +126,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
             base_output_dir=base_output_dir,
             project=project,
             location=location,
-            credentials=self.get_credentials,
+            credentials=self.get_credentials(),
             labels=labels,
             encryption_spec_key_name=encryption_spec_key_name,
             staging_bucket=staging_bucket,

airflow/providers/google/cloud/links/dataflow.py

@@ -30,6 +30,9 @@ if TYPE_CHECKING:
 DATAFLOW_BASE_LINK = "/dataflow/jobs"
 DATAFLOW_JOB_LINK = DATAFLOW_BASE_LINK + "/{region}/{job_id}?project={project_id}"
 
+DATAFLOW_PIPELINE_BASE_LINK = "/dataflow/pipelines"
+DATAFLOW_PIPELINE_LINK = DATAFLOW_PIPELINE_BASE_LINK + "/{location}/{pipeline_name}?project={project_id}"
+
 
 class DataflowJobLink(BaseGoogleLink):
     """Helper class for constructing Dataflow Job Link."""
@@ -51,3 +54,25 @@ class DataflowJobLink(BaseGoogleLink):
             key=DataflowJobLink.key,
             value={"project_id": project_id, "region": region, "job_id": job_id},
         )
+
+
+class DataflowPipelineLink(BaseGoogleLink):
+    """Helper class for constructing Dataflow Pipeline Link."""
+
+    name = "Dataflow Pipeline"
+    key = "dataflow_pipeline_config"
+    format_str = DATAFLOW_PIPELINE_LINK
+
+    @staticmethod
+    def persist(
+        operator_instance: BaseOperator,
+        context: Context,
+        project_id: str | None,
+        location: str | None,
+        pipeline_name: str | None,
+    ):
+        operator_instance.xcom_push(
+            context,
+            key=DataflowPipelineLink.key,
+            value={"project_id": project_id, "location": location, "pipeline_name": pipeline_name},
+        )
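
Illustration (not part of the diff) of the URL the new `DataflowPipelineLink` resolves to, assuming the console base URL used by `BaseGoogleLink` (`https://console.cloud.google.com`) and placeholder values:

    # Same constant as in the hunk above, reproduced here so the snippet is self-contained.
    DATAFLOW_PIPELINE_LINK = "/dataflow/pipelines" + "/{location}/{pipeline_name}?project={project_id}"

    url = "https://console.cloud.google.com" + DATAFLOW_PIPELINE_LINK.format(
        location="us-central1", pipeline_name="my-pipeline", project_id="my-project"
    )
    print(url)  # .../dataflow/pipelines/us-central1/my-pipeline?project=my-project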