apache-airflow-providers-google 10.10.0rc1__py3-none-any.whl → 10.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_run.py +4 -2
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +131 -27
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +1 -9
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +121 -4
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -11
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -10
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +220 -6
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +409 -0
- airflow/providers/google/cloud/links/vertex_ai.py +49 -0
- airflow/providers/google/cloud/operators/dataproc.py +32 -10
- airflow/providers/google/cloud/operators/gcs.py +1 -1
- airflow/providers/google/cloud/operators/mlengine.py +116 -0
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +45 -0
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +2 -8
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +287 -201
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +1 -9
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +2 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +451 -12
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +464 -0
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +7 -1
- airflow/providers/google/get_provider_info.py +5 -0
- {apache_airflow_providers_google-10.10.0rc1.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/METADATA +8 -8
- {apache_airflow_providers_google-10.10.0rc1.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/RECORD +29 -27
- {apache_airflow_providers_google-10.10.0rc1.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/LICENSE +0 -0
- {apache_airflow_providers_google-10.10.0rc1.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/NOTICE +0 -0
- {apache_airflow_providers_google-10.10.0rc1.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.10.0rc1.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/entry_points.txt +0 -0
- {apache_airflow_providers_google-10.10.0rc1.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/top_level.txt +0 -0
@@ -15,17 +15,9 @@
|
|
15
15
|
# KIND, either express or implied. See the License for the
|
16
16
|
# specific language governing permissions and limitations
|
17
17
|
# under the License.
|
18
|
-
"""This module contains Google Vertex AI operators.
|
19
18
|
|
20
|
-
|
19
|
+
"""This module contains Google Vertex AI operators."""
|
21
20
|
|
22
|
-
undeployed
|
23
|
-
undeploy
|
24
|
-
Undeploys
|
25
|
-
aiplatform
|
26
|
-
FieldMask
|
27
|
-
unassigns
|
28
|
-
"""
|
29
21
|
from __future__ import annotations
|
30
22
|
|
31
23
|
from typing import TYPE_CHECKING, Sequence
|
@@ -15,16 +15,9 @@
|
|
15
15
|
# KIND, either express or implied. See the License for the
|
16
16
|
# specific language governing permissions and limitations
|
17
17
|
# under the License.
|
18
|
-
"""This module contains Google Vertex AI operators.
|
19
18
|
|
20
|
-
|
19
|
+
"""This module contains Google Vertex AI operators."""
|
21
20
|
|
22
|
-
irreproducible
|
23
|
-
codepoints
|
24
|
-
Tensorboard
|
25
|
-
aiplatform
|
26
|
-
myVPC
|
27
|
-
"""
|
28
21
|
from __future__ import annotations
|
29
22
|
|
30
23
|
from typing import TYPE_CHECKING, Sequence
|
@@ -67,7 +60,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
67
60
|
:param max_trial_count: Required. The desired total number of Trials.
|
68
61
|
:param parallel_trial_count: Required. The desired number of Trials to run in parallel.
|
69
62
|
:param worker_pool_specs: Required. The spec of the worker pools including machine type and Docker
|
70
|
-
image. Can provided as a list of dictionaries or list of WorkerPoolSpec proto messages.
|
63
|
+
image. Can be provided as a list of dictionaries or list of WorkerPoolSpec proto messages.
|
71
64
|
:param base_output_dir: Optional. GCS output directory of job. If not provided a timestamped
|
72
65
|
directory in the staging directory will be used.
|
73
66
|
:param custom_job_labels: Optional. The labels with user-defined metadata to organize CustomJobs.
|
@@ -15,13 +15,8 @@
|
|
15
15
|
# KIND, either express or implied. See the License for the
|
16
16
|
# specific language governing permissions and limitations
|
17
17
|
# under the License.
|
18
|
-
"""This module contains Google Vertex AI operators.
|
18
|
+
"""This module contains Google Vertex AI operators."""
|
19
19
|
|
20
|
-
.. spelling:word-list::
|
21
|
-
|
22
|
-
aiplatform
|
23
|
-
camelCase
|
24
|
-
"""
|
25
20
|
from __future__ import annotations
|
26
21
|
|
27
22
|
from typing import TYPE_CHECKING, Sequence
|
@@ -50,7 +45,10 @@ class DeleteModelOperator(GoogleCloudBaseOperator):
|
|
50
45
|
|
51
46
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
52
47
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
53
|
-
:param model_id: Required. The
|
48
|
+
:param model_id: Required. The ID of the Model resource to be deleted.
|
49
|
+
Could be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
50
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}` if model
|
51
|
+
has several versions.
|
54
52
|
:param retry: Designation of what errors, if any, should be retried.
|
55
53
|
:param timeout: The timeout for this request.
|
56
54
|
:param metadata: Strings which should be sent along with the request as metadata.
|
@@ -95,7 +93,7 @@ class DeleteModelOperator(GoogleCloudBaseOperator):
|
|
95
93
|
gcp_conn_id=self.gcp_conn_id,
|
96
94
|
impersonation_chain=self.impersonation_chain,
|
97
95
|
)
|
98
|
-
|
96
|
+
self.model_id = self.model_id.rpartition("@")[0] if "@" in self.model_id else self.model_id
|
99
97
|
try:
|
100
98
|
self.log.info("Deleting model: %s", self.model_id)
|
101
99
|
operation = hook.delete_model(
|
@@ -112,13 +110,91 @@ class DeleteModelOperator(GoogleCloudBaseOperator):
|
|
112
110
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
113
111
|
|
114
112
|
|
113
|
+
class GetModelOperator(GoogleCloudBaseOperator):
|
114
|
+
"""
|
115
|
+
Retrieves a Model.
|
116
|
+
|
117
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
118
|
+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
119
|
+
:param model_id: Required. The ID of the Model resource to be retrieved.
|
120
|
+
Could be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
121
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}` if model has
|
122
|
+
several versions.
|
123
|
+
:param retry: Designation of what errors, if any, should be retried.
|
124
|
+
:param timeout: The timeout for this request.
|
125
|
+
:param metadata: Strings which should be sent along with the request as metadata.
|
126
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
127
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
128
|
+
credentials, or chained list of accounts required to get the access_token
|
129
|
+
of the last account in the list, which will be impersonated in the request.
|
130
|
+
If set as a string, the account must grant the originating account
|
131
|
+
the Service Account Token Creator IAM role.
|
132
|
+
If set as a sequence, the identities from the list must grant
|
133
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
134
|
+
account from the list granting this role to the originating account (templated).
|
135
|
+
"""
|
136
|
+
|
137
|
+
template_fields = ("region", "model_id", "project_id", "impersonation_chain")
|
138
|
+
operator_extra_links = (VertexAIModelLink(),)
|
139
|
+
|
140
|
+
def __init__(
|
141
|
+
self,
|
142
|
+
*,
|
143
|
+
region: str,
|
144
|
+
project_id: str,
|
145
|
+
model_id: str,
|
146
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
147
|
+
timeout: float | None = None,
|
148
|
+
metadata: Sequence[tuple[str, str]] = (),
|
149
|
+
gcp_conn_id: str = "google_cloud_default",
|
150
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
151
|
+
**kwargs,
|
152
|
+
) -> None:
|
153
|
+
super().__init__(**kwargs)
|
154
|
+
self.region = region
|
155
|
+
self.project_id = project_id
|
156
|
+
self.model_id = model_id
|
157
|
+
self.retry = retry
|
158
|
+
self.timeout = timeout
|
159
|
+
self.metadata = metadata
|
160
|
+
self.gcp_conn_id = gcp_conn_id
|
161
|
+
self.impersonation_chain = impersonation_chain
|
162
|
+
|
163
|
+
def execute(self, context: Context):
|
164
|
+
hook = ModelServiceHook(
|
165
|
+
gcp_conn_id=self.gcp_conn_id,
|
166
|
+
impersonation_chain=self.impersonation_chain,
|
167
|
+
)
|
168
|
+
self.model_id = self.model_id.rpartition("@")[0] if "@" in self.model_id else self.model_id
|
169
|
+
try:
|
170
|
+
self.log.info("Retrieving model: %s", self.model_id)
|
171
|
+
model = hook.get_model(
|
172
|
+
project_id=self.project_id,
|
173
|
+
region=self.region,
|
174
|
+
model_id=self.model_id,
|
175
|
+
retry=self.retry,
|
176
|
+
timeout=self.timeout,
|
177
|
+
metadata=self.metadata,
|
178
|
+
)
|
179
|
+
self.log.info("Model found. Model ID: %s", self.model_id)
|
180
|
+
|
181
|
+
self.xcom_push(context, key="model_id", value=self.model_id)
|
182
|
+
VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
|
183
|
+
return Model.to_dict(model)
|
184
|
+
except NotFound:
|
185
|
+
self.log.info("The Model ID %s does not exist.", self.model_id)
|
186
|
+
|
187
|
+
|
115
188
|
class ExportModelOperator(GoogleCloudBaseOperator):
|
116
189
|
"""
|
117
190
|
Exports a trained, exportable Model to a location specified by the user.
|
118
191
|
|
119
192
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
120
193
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
121
|
-
:param model_id: Required. The
|
194
|
+
:param model_id: Required. The ID of the Model to export.
|
195
|
+
Could be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
196
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}` if model has
|
197
|
+
several versions.
|
122
198
|
:param output_config: Required. The desired output location and configuration.
|
123
199
|
:param retry: Designation of what errors, if any, should be retried.
|
124
200
|
:param timeout: The timeout for this request.
|
@@ -195,8 +271,8 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
195
271
|
:param retry: Designation of what errors, if any, should be retried.
|
196
272
|
:param filter: An expression for filtering the results of the request. For field names both
|
197
273
|
snake_case and camelCase are supported.
|
198
|
-
- ``model`` supports = and !=. ``model`` represents the Model ID,
|
199
|
-
Model's [resource name][google.cloud.aiplatform.v1.Model.name].
|
274
|
+
- ``model`` supports = and !=. ``model`` represents the Model ID, Could be in format the
|
275
|
+
last segment of the Model's [resource name][google.cloud.aiplatform.v1.Model.name].
|
200
276
|
- ``display_name`` supports = and !=
|
201
277
|
- ``labels`` supports general map functions that is:
|
202
278
|
-- ``labels.key=value`` - key:value equality
|
@@ -285,7 +361,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
285
361
|
|
286
362
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
287
363
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
288
|
-
:param model:
|
364
|
+
:param model: Required. The Model to create.
|
289
365
|
:param retry: Designation of what errors, if any, should be retried.
|
290
366
|
:param timeout: The timeout for this request.
|
291
367
|
:param metadata: Strings which should be sent along with the request as metadata.
|
@@ -349,3 +425,366 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
349
425
|
self.xcom_push(context, key="model_id", value=model_id)
|
350
426
|
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
|
351
427
|
return model_resp
|
428
|
+
|
429
|
+
|
430
|
+
class ListModelVersionsOperator(GoogleCloudBaseOperator):
|
431
|
+
"""
|
432
|
+
Lists Model versions in a Location.
|
433
|
+
|
434
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
435
|
+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
436
|
+
:param model_id: Required. The ID of the model to list versions for.
|
437
|
+
Could be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
438
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}` if model has
|
439
|
+
several versions.
|
440
|
+
:param retry: Designation of what errors, if any, should be retried.
|
441
|
+
:param timeout: The timeout for this request.
|
442
|
+
:param metadata: Strings which should be sent along with the request as metadata.
|
443
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
444
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
445
|
+
credentials, or chained list of accounts required to get the access_token
|
446
|
+
of the last account in the list, which will be impersonated in the request.
|
447
|
+
If set as a string, the account must grant the originating account
|
448
|
+
the Service Account Token Creator IAM role.
|
449
|
+
If set as a sequence, the identities from the list must grant
|
450
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
451
|
+
account from the list granting this role to the originating account (templated).
|
452
|
+
"""
|
453
|
+
|
454
|
+
template_fields = ("model_id", "region", "project_id", "impersonation_chain")
|
455
|
+
|
456
|
+
def __init__(
|
457
|
+
self,
|
458
|
+
*,
|
459
|
+
region: str,
|
460
|
+
project_id: str,
|
461
|
+
model_id: str,
|
462
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
463
|
+
timeout: float | None = None,
|
464
|
+
metadata: Sequence[tuple[str, str]] = (),
|
465
|
+
gcp_conn_id: str = "google_cloud_default",
|
466
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
467
|
+
**kwargs,
|
468
|
+
) -> None:
|
469
|
+
super().__init__(**kwargs)
|
470
|
+
self.region = region
|
471
|
+
self.project_id = project_id
|
472
|
+
self.model_id = model_id
|
473
|
+
self.retry = retry
|
474
|
+
self.timeout = timeout
|
475
|
+
self.metadata = metadata
|
476
|
+
self.gcp_conn_id = gcp_conn_id
|
477
|
+
self.impersonation_chain = impersonation_chain
|
478
|
+
|
479
|
+
def execute(self, context: Context):
|
480
|
+
hook = ModelServiceHook(
|
481
|
+
gcp_conn_id=self.gcp_conn_id,
|
482
|
+
impersonation_chain=self.impersonation_chain,
|
483
|
+
)
|
484
|
+
self.log.info("Retrieving versions list from model: %s", self.model_id)
|
485
|
+
results = hook.list_model_versions(
|
486
|
+
project_id=self.project_id,
|
487
|
+
region=self.region,
|
488
|
+
model_id=self.model_id,
|
489
|
+
retry=self.retry,
|
490
|
+
timeout=self.timeout,
|
491
|
+
metadata=self.metadata,
|
492
|
+
)
|
493
|
+
for result in results:
|
494
|
+
model = Model.to_dict(result)
|
495
|
+
self.log.info("Model name: %s;", model["name"])
|
496
|
+
self.log.info("Model version: %s, model alias %s;", model["version_id"], model["version_aliases"])
|
497
|
+
return [Model.to_dict(result) for result in results]
|
498
|
+
|
499
|
+
|
500
|
+
class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
501
|
+
"""
|
502
|
+
Sets the desired Model version as Default.
|
503
|
+
|
504
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
505
|
+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
506
|
+
:param model_id: Required. The ID of the model to set as default.
|
507
|
+
Should be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
508
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}` if model
|
509
|
+
has several versions.
|
510
|
+
:param retry: Designation of what errors, if any, should be retried.
|
511
|
+
:param timeout: The timeout for this request.
|
512
|
+
:param metadata: Strings which should be sent along with the request as metadata.
|
513
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
514
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
515
|
+
credentials, or chained list of accounts required to get the access_token
|
516
|
+
of the last account in the list, which will be impersonated in the request.
|
517
|
+
If set as a string, the account must grant the originating account
|
518
|
+
the Service Account Token Creator IAM role.
|
519
|
+
If set as a sequence, the identities from the list must grant
|
520
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
521
|
+
account from the list granting this role to the originating account (templated).
|
522
|
+
"""
|
523
|
+
|
524
|
+
template_fields = ("model_id", "project_id", "impersonation_chain")
|
525
|
+
operator_extra_links = (VertexAIModelLink(),)
|
526
|
+
|
527
|
+
def __init__(
|
528
|
+
self,
|
529
|
+
*,
|
530
|
+
region: str,
|
531
|
+
project_id: str,
|
532
|
+
model_id: str,
|
533
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
534
|
+
timeout: float | None = None,
|
535
|
+
metadata: Sequence[tuple[str, str]] = (),
|
536
|
+
gcp_conn_id: str = "google_cloud_default",
|
537
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
538
|
+
**kwargs,
|
539
|
+
) -> None:
|
540
|
+
super().__init__(**kwargs)
|
541
|
+
self.region = region
|
542
|
+
self.project_id = project_id
|
543
|
+
self.model_id = model_id
|
544
|
+
self.retry = retry
|
545
|
+
self.timeout = timeout
|
546
|
+
self.metadata = metadata
|
547
|
+
self.gcp_conn_id = gcp_conn_id
|
548
|
+
self.impersonation_chain = impersonation_chain
|
549
|
+
|
550
|
+
def execute(self, context: Context):
|
551
|
+
hook = ModelServiceHook(
|
552
|
+
gcp_conn_id=self.gcp_conn_id,
|
553
|
+
impersonation_chain=self.impersonation_chain,
|
554
|
+
)
|
555
|
+
|
556
|
+
self.log.info(
|
557
|
+
"Setting version %s as default on model %s", self.model_id.rpartition("@")[0], self.model_id
|
558
|
+
)
|
559
|
+
|
560
|
+
updated_model = hook.set_version_as_default(
|
561
|
+
region=self.region,
|
562
|
+
model_id=self.model_id,
|
563
|
+
project_id=self.project_id,
|
564
|
+
retry=self.retry,
|
565
|
+
timeout=self.timeout,
|
566
|
+
metadata=self.metadata,
|
567
|
+
)
|
568
|
+
VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
|
569
|
+
return Model.to_dict(updated_model)
|
570
|
+
|
571
|
+
|
572
|
+
class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
573
|
+
"""
|
574
|
+
Adds version aliases for the Model.
|
575
|
+
|
576
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
577
|
+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
578
|
+
:param model_id: Required. The ID of the model to add version aliases for.
|
579
|
+
Should be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
580
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}`.
|
581
|
+
:param version_aliases: List of version aliases to be added to model version.
|
582
|
+
:param retry: Designation of what errors, if any, should be retried.
|
583
|
+
:param timeout: The timeout for this request.
|
584
|
+
:param metadata: Strings which should be sent along with the request as metadata.
|
585
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
586
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
587
|
+
credentials, or chained list of accounts required to get the access_token
|
588
|
+
of the last account in the list, which will be impersonated in the request.
|
589
|
+
If set as a string, the account must grant the originating account
|
590
|
+
the Service Account Token Creator IAM role.
|
591
|
+
If set as a sequence, the identities from the list must grant
|
592
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
593
|
+
account from the list granting this role to the originating account (templated).
|
594
|
+
"""
|
595
|
+
|
596
|
+
template_fields = ("model_id", "project_id", "impersonation_chain")
|
597
|
+
operator_extra_links = (VertexAIModelLink(),)
|
598
|
+
|
599
|
+
def __init__(
|
600
|
+
self,
|
601
|
+
*,
|
602
|
+
region: str,
|
603
|
+
project_id: str,
|
604
|
+
model_id: str,
|
605
|
+
version_aliases: Sequence[str],
|
606
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
607
|
+
timeout: float | None = None,
|
608
|
+
metadata: Sequence[tuple[str, str]] = (),
|
609
|
+
gcp_conn_id: str = "google_cloud_default",
|
610
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
611
|
+
**kwargs,
|
612
|
+
) -> None:
|
613
|
+
super().__init__(**kwargs)
|
614
|
+
self.region = region
|
615
|
+
self.project_id = project_id
|
616
|
+
self.model_id = model_id
|
617
|
+
self.version_aliases = version_aliases
|
618
|
+
self.retry = retry
|
619
|
+
self.timeout = timeout
|
620
|
+
self.metadata = metadata
|
621
|
+
self.gcp_conn_id = gcp_conn_id
|
622
|
+
self.impersonation_chain = impersonation_chain
|
623
|
+
|
624
|
+
def execute(self, context: Context):
|
625
|
+
hook = ModelServiceHook(
|
626
|
+
gcp_conn_id=self.gcp_conn_id,
|
627
|
+
impersonation_chain=self.impersonation_chain,
|
628
|
+
)
|
629
|
+
self.log.info(
|
630
|
+
"Adding aliases %s to model version %s", self.version_aliases, self.model_id.rpartition("@")[0]
|
631
|
+
)
|
632
|
+
|
633
|
+
updated_model = hook.add_version_aliases(
|
634
|
+
region=self.region,
|
635
|
+
model_id=self.model_id,
|
636
|
+
version_aliases=self.version_aliases,
|
637
|
+
project_id=self.project_id,
|
638
|
+
retry=self.retry,
|
639
|
+
timeout=self.timeout,
|
640
|
+
metadata=self.metadata,
|
641
|
+
)
|
642
|
+
VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
|
643
|
+
return Model.to_dict(updated_model)
|
644
|
+
|
645
|
+
|
646
|
+
class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
647
|
+
"""
|
648
|
+
Deletes version aliases for the Model.
|
649
|
+
|
650
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
651
|
+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
652
|
+
:param model_id: Required. The ID of the model to delete version aliases from.
|
653
|
+
Should be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
654
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}`.
|
655
|
+
:param version_aliases: List of version aliases to be deleted from model version.
|
656
|
+
:param retry: Designation of what errors, if any, should be retried.
|
657
|
+
:param timeout: The timeout for this request.
|
658
|
+
:param metadata: Strings which should be sent along with the request as metadata.
|
659
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
660
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
661
|
+
credentials, or chained list of accounts required to get the access_token
|
662
|
+
of the last account in the list, which will be impersonated in the request.
|
663
|
+
If set as a string, the account must grant the originating account
|
664
|
+
the Service Account Token Creator IAM role.
|
665
|
+
If set as a sequence, the identities from the list must grant
|
666
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
667
|
+
account from the list granting this role to the originating account (templated).
|
668
|
+
"""
|
669
|
+
|
670
|
+
template_fields = ("model_id", "project_id", "impersonation_chain")
|
671
|
+
operator_extra_links = (VertexAIModelLink(),)
|
672
|
+
|
673
|
+
def __init__(
|
674
|
+
self,
|
675
|
+
*,
|
676
|
+
region: str,
|
677
|
+
project_id: str,
|
678
|
+
model_id: str,
|
679
|
+
version_aliases: Sequence[str],
|
680
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
681
|
+
timeout: float | None = None,
|
682
|
+
metadata: Sequence[tuple[str, str]] = (),
|
683
|
+
gcp_conn_id: str = "google_cloud_default",
|
684
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
685
|
+
**kwargs,
|
686
|
+
) -> None:
|
687
|
+
super().__init__(**kwargs)
|
688
|
+
self.region = region
|
689
|
+
self.project_id = project_id
|
690
|
+
self.model_id = model_id
|
691
|
+
self.version_aliases = version_aliases
|
692
|
+
self.retry = retry
|
693
|
+
self.timeout = timeout
|
694
|
+
self.metadata = metadata
|
695
|
+
self.gcp_conn_id = gcp_conn_id
|
696
|
+
self.impersonation_chain = impersonation_chain
|
697
|
+
|
698
|
+
def execute(self, context: Context):
|
699
|
+
hook = ModelServiceHook(
|
700
|
+
gcp_conn_id=self.gcp_conn_id,
|
701
|
+
impersonation_chain=self.impersonation_chain,
|
702
|
+
)
|
703
|
+
self.log.info(
|
704
|
+
"Deleting aliases %s from model version %s",
|
705
|
+
self.version_aliases,
|
706
|
+
self.model_id.rpartition("@")[0],
|
707
|
+
)
|
708
|
+
|
709
|
+
updated_model = hook.delete_version_aliases(
|
710
|
+
region=self.region,
|
711
|
+
model_id=self.model_id,
|
712
|
+
version_aliases=self.version_aliases,
|
713
|
+
project_id=self.project_id,
|
714
|
+
retry=self.retry,
|
715
|
+
timeout=self.timeout,
|
716
|
+
metadata=self.metadata,
|
717
|
+
)
|
718
|
+
VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
|
719
|
+
return Model.to_dict(updated_model)
|
720
|
+
|
721
|
+
|
722
|
+
class DeleteModelVersionOperator(GoogleCloudBaseOperator):
|
723
|
+
"""
|
724
|
+
Delete Model version in a Location.
|
725
|
+
|
726
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
727
|
+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
728
|
+
:param model_id: Required. The ID of the Model in which to delete version.
|
729
|
+
Should be in format `projects/{project}/locations/{location}/models/{model_id}@{version_id}` or
|
730
|
+
`projects/{project}/locations/{location}/models/{model_id}@{version_alias}`
|
731
|
+
several versions.
|
732
|
+
:param retry: Designation of what errors, if any, should be retried.
|
733
|
+
:param timeout: The timeout for this request.
|
734
|
+
:param metadata: Strings which should be sent along with the request as metadata.
|
735
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
736
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
737
|
+
credentials, or chained list of accounts required to get the access_token
|
738
|
+
of the last account in the list, which will be impersonated in the request.
|
739
|
+
If set as a string, the account must grant the originating account
|
740
|
+
the Service Account Token Creator IAM role.
|
741
|
+
If set as a sequence, the identities from the list must grant
|
742
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
743
|
+
account from the list granting this role to the originating account (templated).
|
744
|
+
"""
|
745
|
+
|
746
|
+
template_fields = ("model_id", "project_id", "impersonation_chain")
|
747
|
+
|
748
|
+
def __init__(
|
749
|
+
self,
|
750
|
+
*,
|
751
|
+
region: str,
|
752
|
+
project_id: str,
|
753
|
+
model_id: str,
|
754
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
755
|
+
timeout: float | None = None,
|
756
|
+
metadata: Sequence[tuple[str, str]] = (),
|
757
|
+
gcp_conn_id: str = "google_cloud_default",
|
758
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
759
|
+
**kwargs,
|
760
|
+
) -> None:
|
761
|
+
super().__init__(**kwargs)
|
762
|
+
self.region = region
|
763
|
+
self.project_id = project_id
|
764
|
+
self.model_id = model_id
|
765
|
+
self.retry = retry
|
766
|
+
self.timeout = timeout
|
767
|
+
self.metadata = metadata
|
768
|
+
self.gcp_conn_id = gcp_conn_id
|
769
|
+
self.impersonation_chain = impersonation_chain
|
770
|
+
|
771
|
+
def execute(self, context: Context):
|
772
|
+
hook = ModelServiceHook(
|
773
|
+
gcp_conn_id=self.gcp_conn_id,
|
774
|
+
impersonation_chain=self.impersonation_chain,
|
775
|
+
)
|
776
|
+
|
777
|
+
try:
|
778
|
+
self.log.info("Deleting model version: %s", self.model_id)
|
779
|
+
operation = hook.delete_model_version(
|
780
|
+
project_id=self.project_id,
|
781
|
+
region=self.region,
|
782
|
+
model_id=self.model_id,
|
783
|
+
retry=self.retry,
|
784
|
+
timeout=self.timeout,
|
785
|
+
metadata=self.metadata,
|
786
|
+
)
|
787
|
+
hook.wait_for_operation(timeout=self.timeout, operation=operation)
|
788
|
+
self.log.info("Model version was deleted.")
|
789
|
+
except NotFound:
|
790
|
+
self.log.info("The Model ID %s does not exist.", self.model_id)
|