apache-airflow-providers-google 14.1.0__py3-none-any.whl → 15.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +7 -33
- airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -17
- airflow/providers/google/cloud/hooks/bigquery.py +6 -11
- airflow/providers/google/cloud/hooks/cloud_batch.py +1 -2
- airflow/providers/google/cloud/hooks/cloud_build.py +1 -54
- airflow/providers/google/cloud/hooks/compute.py +4 -3
- airflow/providers/google/cloud/hooks/dataflow.py +2 -139
- airflow/providers/google/cloud/hooks/dataform.py +6 -12
- airflow/providers/google/cloud/hooks/datafusion.py +1 -2
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/gcs.py +13 -5
- airflow/providers/google/cloud/hooks/life_sciences.py +1 -1
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +3 -2
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +2 -272
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +2 -1
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +2 -1
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +1 -3
- airflow/providers/google/cloud/links/dataproc.py +0 -1
- airflow/providers/google/cloud/log/gcs_task_handler.py +147 -115
- airflow/providers/google/cloud/openlineage/facets.py +32 -32
- airflow/providers/google/cloud/openlineage/mixins.py +2 -2
- airflow/providers/google/cloud/operators/automl.py +1 -1
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -3
- airflow/providers/google/cloud/operators/datafusion.py +1 -22
- airflow/providers/google/cloud/operators/dataproc.py +1 -143
- airflow/providers/google/cloud/operators/dataproc_metastore.py +0 -1
- airflow/providers/google/cloud/operators/mlengine.py +3 -1406
- airflow/providers/google/cloud/operators/spanner.py +1 -2
- airflow/providers/google/cloud/operators/translate.py +2 -2
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +0 -12
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +1 -22
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +4 -3
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -1
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +23 -10
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/common/auth_backend/google_openid.py +1 -1
- airflow/providers/google/common/hooks/base_google.py +7 -28
- airflow/providers/google/get_provider_info.py +3 -1
- airflow/providers/google/marketing_platform/sensors/display_video.py +1 -1
- airflow/providers/google/suite/hooks/drive.py +2 -2
- {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/METADATA +11 -9
- {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/RECORD +49 -50
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
- {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/entry_points.txt +0 -0
@@ -37,6 +37,7 @@ from google.cloud.aiplatform import (
|
|
37
37
|
from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
|
38
38
|
|
39
39
|
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
40
|
+
from airflow.providers.google.common.consts import CLIENT_INFO
|
40
41
|
from airflow.providers.google.common.deprecated import deprecated
|
41
42
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
42
43
|
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
|
@@ -81,7 +82,7 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
|
|
81
82
|
client_options = ClientOptions()
|
82
83
|
|
83
84
|
return PipelineServiceClient(
|
84
|
-
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
|
85
|
+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
|
85
86
|
)
|
86
87
|
|
87
88
|
def get_job_service_client(
|
@@ -95,7 +96,7 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
|
|
95
96
|
client_options = ClientOptions()
|
96
97
|
|
97
98
|
return JobServiceClient(
|
98
|
-
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
|
99
|
+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
|
99
100
|
)
|
100
101
|
|
101
102
|
def get_auto_ml_tabular_training_job(
|
@@ -63,7 +63,7 @@ class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
|
|
63
63
|
client_options = ClientOptions()
|
64
64
|
|
65
65
|
return JobServiceClient(
|
66
|
-
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
|
66
|
+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
|
67
67
|
)
|
68
68
|
|
69
69
|
@staticmethod
|
@@ -42,9 +42,8 @@ from google.cloud.aiplatform_v1 import (
|
|
42
42
|
types,
|
43
43
|
)
|
44
44
|
|
45
|
-
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
45
|
+
from airflow.exceptions import AirflowException
|
46
46
|
from airflow.providers.google.common.consts import CLIENT_INFO
|
47
|
-
from airflow.providers.google.common.deprecated import deprecated
|
48
47
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
|
49
48
|
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
|
50
49
|
|
@@ -54,10 +53,9 @@ if TYPE_CHECKING:
|
|
54
53
|
from google.auth.credentials import Credentials
|
55
54
|
from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager
|
56
55
|
from google.cloud.aiplatform_v1.services.pipeline_service.pagers import (
|
57
|
-
ListPipelineJobsPager,
|
58
56
|
ListTrainingPipelinesPager,
|
59
57
|
)
|
60
|
-
from google.cloud.aiplatform_v1.types import CustomJob, PipelineJob, TrainingPipeline
|
58
|
+
from google.cloud.aiplatform_v1.types import CustomJob, TrainingPipeline
|
61
59
|
|
62
60
|
|
63
61
|
class CustomJobHook(GoogleBaseHook, OperationHelper):
|
@@ -368,54 +366,6 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
368
366
|
)
|
369
367
|
return model, training_id, custom_job_id
|
370
368
|
|
371
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
372
|
-
@deprecated(
|
373
|
-
planned_removal_date="March 01, 2025",
|
374
|
-
use_instead="PipelineJobHook.cancel_pipeline_job",
|
375
|
-
category=AirflowProviderDeprecationWarning,
|
376
|
-
)
|
377
|
-
def cancel_pipeline_job(
|
378
|
-
self,
|
379
|
-
project_id: str,
|
380
|
-
region: str,
|
381
|
-
pipeline_job: str,
|
382
|
-
retry: Retry | _MethodDefault = DEFAULT,
|
383
|
-
timeout: float | None = None,
|
384
|
-
metadata: Sequence[tuple[str, str]] = (),
|
385
|
-
) -> None:
|
386
|
-
"""
|
387
|
-
Cancel a PipelineJob.
|
388
|
-
|
389
|
-
Starts asynchronous cancellation on the PipelineJob. The server makes the best
|
390
|
-
effort to cancel the pipeline, but success is not guaranteed. Clients can use
|
391
|
-
[PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other
|
392
|
-
methods to check whether the cancellation succeeded or whether the pipeline completed despite
|
393
|
-
cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a
|
394
|
-
pipeline with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a
|
395
|
-
[google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and
|
396
|
-
[PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``.
|
397
|
-
|
398
|
-
This method is deprecated, please use `PipelineJobHook.cancel_pipeline_job` method.
|
399
|
-
|
400
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
401
|
-
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
402
|
-
:param pipeline_job: The name of the PipelineJob to cancel.
|
403
|
-
:param retry: Designation of what errors, if any, should be retried.
|
404
|
-
:param timeout: The timeout for this request.
|
405
|
-
:param metadata: Strings which should be sent along with the request as metadata.
|
406
|
-
"""
|
407
|
-
client = self.get_pipeline_service_client(region)
|
408
|
-
name = client.pipeline_job_path(project_id, region, pipeline_job)
|
409
|
-
|
410
|
-
client.cancel_pipeline_job(
|
411
|
-
request={
|
412
|
-
"name": name,
|
413
|
-
},
|
414
|
-
retry=retry,
|
415
|
-
timeout=timeout,
|
416
|
-
metadata=metadata,
|
417
|
-
)
|
418
|
-
|
419
369
|
@GoogleBaseHook.fallback_to_default_project_id
|
420
370
|
def cancel_training_pipeline(
|
421
371
|
self,
|
@@ -498,53 +448,6 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
498
448
|
metadata=metadata,
|
499
449
|
)
|
500
450
|
|
501
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
502
|
-
@deprecated(
|
503
|
-
planned_removal_date="March 01, 2025",
|
504
|
-
use_instead="PipelineJobHook.create_pipeline_job",
|
505
|
-
category=AirflowProviderDeprecationWarning,
|
506
|
-
)
|
507
|
-
def create_pipeline_job(
|
508
|
-
self,
|
509
|
-
project_id: str,
|
510
|
-
region: str,
|
511
|
-
pipeline_job: PipelineJob,
|
512
|
-
pipeline_job_id: str,
|
513
|
-
retry: Retry | _MethodDefault = DEFAULT,
|
514
|
-
timeout: float | None = None,
|
515
|
-
metadata: Sequence[tuple[str, str]] = (),
|
516
|
-
) -> PipelineJob:
|
517
|
-
"""
|
518
|
-
Create a PipelineJob. A PipelineJob will run immediately when created.
|
519
|
-
|
520
|
-
This method is deprecated, please use `PipelineJobHook.create_pipeline_job` method.
|
521
|
-
|
522
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
523
|
-
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
524
|
-
:param pipeline_job: Required. The PipelineJob to create.
|
525
|
-
:param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of
|
526
|
-
the PipelineJob name. If not provided, an ID will be automatically generated.
|
527
|
-
|
528
|
-
This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/.
|
529
|
-
:param retry: Designation of what errors, if any, should be retried.
|
530
|
-
:param timeout: The timeout for this request.
|
531
|
-
:param metadata: Strings which should be sent along with the request as metadata.
|
532
|
-
"""
|
533
|
-
client = self.get_pipeline_service_client(region)
|
534
|
-
parent = client.common_location_path(project_id, region)
|
535
|
-
|
536
|
-
result = client.create_pipeline_job(
|
537
|
-
request={
|
538
|
-
"parent": parent,
|
539
|
-
"pipeline_job": pipeline_job,
|
540
|
-
"pipeline_job_id": pipeline_job_id,
|
541
|
-
},
|
542
|
-
retry=retry,
|
543
|
-
timeout=timeout,
|
544
|
-
metadata=metadata,
|
545
|
-
)
|
546
|
-
return result
|
547
|
-
|
548
451
|
@GoogleBaseHook.fallback_to_default_project_id
|
549
452
|
def create_training_pipeline(
|
550
453
|
self,
|
@@ -2970,46 +2873,6 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
2970
2873
|
)
|
2971
2874
|
return result
|
2972
2875
|
|
2973
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
2974
|
-
@deprecated(
|
2975
|
-
planned_removal_date="March 01, 2025",
|
2976
|
-
use_instead="PipelineJobHook.get_pipeline_job",
|
2977
|
-
category=AirflowProviderDeprecationWarning,
|
2978
|
-
)
|
2979
|
-
def get_pipeline_job(
|
2980
|
-
self,
|
2981
|
-
project_id: str,
|
2982
|
-
region: str,
|
2983
|
-
pipeline_job: str,
|
2984
|
-
retry: Retry | _MethodDefault = DEFAULT,
|
2985
|
-
timeout: float | None = None,
|
2986
|
-
metadata: Sequence[tuple[str, str]] = (),
|
2987
|
-
) -> PipelineJob:
|
2988
|
-
"""
|
2989
|
-
Get a PipelineJob.
|
2990
|
-
|
2991
|
-
This method is deprecated, please use `PipelineJobHook.get_pipeline_job` method.
|
2992
|
-
|
2993
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
2994
|
-
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
2995
|
-
:param pipeline_job: Required. The name of the PipelineJob resource.
|
2996
|
-
:param retry: Designation of what errors, if any, should be retried.
|
2997
|
-
:param timeout: The timeout for this request.
|
2998
|
-
:param metadata: Strings which should be sent along with the request as metadata.
|
2999
|
-
"""
|
3000
|
-
client = self.get_pipeline_service_client(region)
|
3001
|
-
name = client.pipeline_job_path(project_id, region, pipeline_job)
|
3002
|
-
|
3003
|
-
result = client.get_pipeline_job(
|
3004
|
-
request={
|
3005
|
-
"name": name,
|
3006
|
-
},
|
3007
|
-
retry=retry,
|
3008
|
-
timeout=timeout,
|
3009
|
-
metadata=metadata,
|
3010
|
-
)
|
3011
|
-
return result
|
3012
|
-
|
3013
2876
|
@GoogleBaseHook.fallback_to_default_project_id
|
3014
2877
|
def get_training_pipeline(
|
3015
2878
|
self,
|
@@ -3076,101 +2939,6 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
3076
2939
|
)
|
3077
2940
|
return result
|
3078
2941
|
|
3079
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
3080
|
-
@deprecated(
|
3081
|
-
planned_removal_date="March 01, 2025",
|
3082
|
-
use_instead="PipelineJobHook.list_pipeline_jobs",
|
3083
|
-
category=AirflowProviderDeprecationWarning,
|
3084
|
-
)
|
3085
|
-
def list_pipeline_jobs(
|
3086
|
-
self,
|
3087
|
-
project_id: str,
|
3088
|
-
region: str,
|
3089
|
-
page_size: int | None = None,
|
3090
|
-
page_token: str | None = None,
|
3091
|
-
filter: str | None = None,
|
3092
|
-
order_by: str | None = None,
|
3093
|
-
retry: Retry | _MethodDefault = DEFAULT,
|
3094
|
-
timeout: float | None = None,
|
3095
|
-
metadata: Sequence[tuple[str, str]] = (),
|
3096
|
-
) -> ListPipelineJobsPager:
|
3097
|
-
"""
|
3098
|
-
List PipelineJobs in a Location.
|
3099
|
-
|
3100
|
-
This method is deprecated, please use `PipelineJobHook.list_pipeline_jobs` method.
|
3101
|
-
|
3102
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
3103
|
-
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
3104
|
-
:param filter: Optional. Lists the PipelineJobs that match the filter expression. The
|
3105
|
-
following fields are supported:
|
3106
|
-
|
3107
|
-
- ``pipeline_name``: Supports ``=`` and ``!=`` comparisons.
|
3108
|
-
- ``display_name``: Supports ``=``, ``!=`` comparisons, and
|
3109
|
-
``:`` wildcard.
|
3110
|
-
- ``pipeline_job_user_id``: Supports ``=``, ``!=``
|
3111
|
-
comparisons, and ``:`` wildcard. for example, can check
|
3112
|
-
if pipeline's display_name contains *step* by doing
|
3113
|
-
display_name:"*step*"
|
3114
|
-
- ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``,
|
3115
|
-
``<=``, and ``>=`` comparisons. Values must be in RFC
|
3116
|
-
3339 format.
|
3117
|
-
- ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``,
|
3118
|
-
``<=``, and ``>=`` comparisons. Values must be in RFC
|
3119
|
-
3339 format.
|
3120
|
-
- ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``,
|
3121
|
-
``<=``, and ``>=`` comparisons. Values must be in RFC
|
3122
|
-
3339 format.
|
3123
|
-
- ``labels``: Supports key-value equality and key presence.
|
3124
|
-
|
3125
|
-
Filter expressions can be combined together using logical
|
3126
|
-
operators (``AND`` & ``OR``). For example:
|
3127
|
-
``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``.
|
3128
|
-
|
3129
|
-
The syntax to define filter expression is based on
|
3130
|
-
https://google.aip.dev/160.
|
3131
|
-
:param page_size: Optional. The standard list page size.
|
3132
|
-
:param page_token: Optional. The standard list page token. Typically obtained via
|
3133
|
-
[ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token]
|
3134
|
-
of the previous
|
3135
|
-
[PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs]
|
3136
|
-
call.
|
3137
|
-
:param order_by: Optional. A comma-separated list of fields to order by. The default
|
3138
|
-
sort order is in ascending order. Use "desc" after a field
|
3139
|
-
name for descending. You can have multiple order_by fields
|
3140
|
-
provided e.g. "create_time desc, end_time", "end_time,
|
3141
|
-
start_time, update_time" For example, using "create_time
|
3142
|
-
desc, end_time" will order results by create time in
|
3143
|
-
descending order, and if there are multiple jobs having the
|
3144
|
-
same create time, order them by the end time in ascending
|
3145
|
-
order. if order_by is not specified, it will order by
|
3146
|
-
default order is create time in descending order. Supported
|
3147
|
-
fields:
|
3148
|
-
|
3149
|
-
- ``create_time``
|
3150
|
-
- ``update_time``
|
3151
|
-
- ``end_time``
|
3152
|
-
- ``start_time``
|
3153
|
-
:param retry: Designation of what errors, if any, should be retried.
|
3154
|
-
:param timeout: The timeout for this request.
|
3155
|
-
:param metadata: Strings which should be sent along with the request as metadata.
|
3156
|
-
"""
|
3157
|
-
client = self.get_pipeline_service_client(region)
|
3158
|
-
parent = client.common_location_path(project_id, region)
|
3159
|
-
|
3160
|
-
result = client.list_pipeline_jobs(
|
3161
|
-
request={
|
3162
|
-
"parent": parent,
|
3163
|
-
"page_size": page_size,
|
3164
|
-
"page_token": page_token,
|
3165
|
-
"filter": filter,
|
3166
|
-
"order_by": order_by,
|
3167
|
-
},
|
3168
|
-
retry=retry,
|
3169
|
-
timeout=timeout,
|
3170
|
-
metadata=metadata,
|
3171
|
-
)
|
3172
|
-
return result
|
3173
|
-
|
3174
2942
|
@GoogleBaseHook.fallback_to_default_project_id
|
3175
2943
|
def list_training_pipelines(
|
3176
2944
|
self,
|
@@ -3293,44 +3061,6 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
3293
3061
|
)
|
3294
3062
|
return result
|
3295
3063
|
|
3296
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
3297
|
-
@deprecated(
|
3298
|
-
planned_removal_date="March 01, 2025",
|
3299
|
-
use_instead="PipelineJobHook.delete_pipeline_job",
|
3300
|
-
category=AirflowProviderDeprecationWarning,
|
3301
|
-
)
|
3302
|
-
def delete_pipeline_job(
|
3303
|
-
self,
|
3304
|
-
project_id: str,
|
3305
|
-
region: str,
|
3306
|
-
pipeline_job: str,
|
3307
|
-
retry: Retry | _MethodDefault = DEFAULT,
|
3308
|
-
timeout: float | None = None,
|
3309
|
-
metadata: Sequence[tuple[str, str]] = (),
|
3310
|
-
) -> Operation:
|
3311
|
-
"""
|
3312
|
-
Delete a PipelineJob.
|
3313
|
-
|
3314
|
-
This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method.
|
3315
|
-
|
3316
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
3317
|
-
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
3318
|
-
:param pipeline_job: Required. The name of the PipelineJob resource to be deleted.
|
3319
|
-
:param retry: Designation of what errors, if any, should be retried.
|
3320
|
-
:param timeout: The timeout for this request.
|
3321
|
-
:param metadata: Strings which should be sent along with the request as metadata.
|
3322
|
-
"""
|
3323
|
-
client = self.get_pipeline_service_client(region)
|
3324
|
-
name = client.pipeline_job_path(project_id, region, pipeline_job)
|
3325
|
-
|
3326
|
-
result = client.delete_pipeline_job(
|
3327
|
-
request={"name": name},
|
3328
|
-
retry=retry,
|
3329
|
-
timeout=timeout,
|
3330
|
-
metadata=metadata,
|
3331
|
-
)
|
3332
|
-
return result
|
3333
|
-
|
3334
3064
|
|
3335
3065
|
class CustomJobAsyncHook(GoogleBaseAsyncHook):
|
3336
3066
|
"""Async hook for Custom Job Service Client."""
|
@@ -26,6 +26,7 @@ from google.api_core.client_options import ClientOptions
|
|
26
26
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
27
27
|
from google.cloud.aiplatform_v1 import EndpointServiceClient
|
28
28
|
|
29
|
+
from airflow.providers.google.common.consts import CLIENT_INFO
|
29
30
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
30
31
|
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
|
31
32
|
|
@@ -48,7 +49,7 @@ class EndpointServiceHook(GoogleBaseHook, OperationHelper):
|
|
48
49
|
client_options = ClientOptions()
|
49
50
|
|
50
51
|
return EndpointServiceClient(
|
51
|
-
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
|
52
|
+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
|
52
53
|
)
|
53
54
|
|
54
55
|
@staticmethod
|
@@ -69,7 +69,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook, OperationHelper):
|
|
69
69
|
client_options = ClientOptions()
|
70
70
|
|
71
71
|
return JobServiceClient(
|
72
|
-
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
|
72
|
+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
|
73
73
|
)
|
74
74
|
|
75
75
|
def get_hyperparameter_tuning_job_object(
|
@@ -28,6 +28,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
28
28
|
from google.cloud.aiplatform_v1 import ModelServiceClient
|
29
29
|
|
30
30
|
from airflow.exceptions import AirflowException
|
31
|
+
from airflow.providers.google.common.consts import CLIENT_INFO
|
31
32
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
32
33
|
|
33
34
|
if TYPE_CHECKING:
|
@@ -53,7 +54,7 @@ class ModelServiceHook(GoogleBaseHook, OperationHelper):
|
|
53
54
|
client_options = ClientOptions()
|
54
55
|
|
55
56
|
return ModelServiceClient(
|
56
|
-
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
|
57
|
+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
|
57
58
|
)
|
58
59
|
|
59
60
|
@staticmethod
|
@@ -82,15 +82,13 @@ class CloudStorageTransferJobLink(BaseGoogleLink):
|
|
82
82
|
|
83
83
|
@staticmethod
|
84
84
|
def persist(
|
85
|
-
task_instance,
|
86
85
|
context: Context,
|
87
86
|
project_id: str,
|
88
87
|
job_name: str,
|
89
88
|
):
|
90
89
|
job_name = job_name.split("/")[1] if job_name else ""
|
91
90
|
|
92
|
-
task_instance.xcom_push(
|
93
|
-
context,
|
91
|
+
context["ti"].xcom_push(
|
94
92
|
key=CloudStorageTransferJobLink.key,
|
95
93
|
value={
|
96
94
|
"project_id": project_id,
|
@@ -126,7 +126,6 @@ class DataprocLink(BaseOperatorLink):
|
|
126
126
|
|
127
127
|
def __attrs_post_init__(self):
|
128
128
|
# This link is still used into the selected operators
|
129
|
-
# - airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator
|
130
129
|
# - airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator
|
131
130
|
# As soon as we remove reference to this link we might deprecate it by add warning message
|
132
131
|
# with `stacklevel=3` below in this method.
|