apache-airflow-providers-google 14.1.0__py3-none-any.whl → 15.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +7 -33
  3. airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -17
  4. airflow/providers/google/cloud/hooks/bigquery.py +6 -11
  5. airflow/providers/google/cloud/hooks/cloud_batch.py +1 -2
  6. airflow/providers/google/cloud/hooks/cloud_build.py +1 -54
  7. airflow/providers/google/cloud/hooks/compute.py +4 -3
  8. airflow/providers/google/cloud/hooks/dataflow.py +2 -139
  9. airflow/providers/google/cloud/hooks/dataform.py +6 -12
  10. airflow/providers/google/cloud/hooks/datafusion.py +1 -2
  11. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  12. airflow/providers/google/cloud/hooks/gcs.py +13 -5
  13. airflow/providers/google/cloud/hooks/life_sciences.py +1 -1
  14. airflow/providers/google/cloud/hooks/translate.py +1 -1
  15. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +3 -2
  16. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +1 -1
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +2 -272
  18. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +2 -1
  19. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
  20. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +2 -1
  21. airflow/providers/google/cloud/links/cloud_storage_transfer.py +1 -3
  22. airflow/providers/google/cloud/links/dataproc.py +0 -1
  23. airflow/providers/google/cloud/log/gcs_task_handler.py +147 -115
  24. airflow/providers/google/cloud/openlineage/facets.py +32 -32
  25. airflow/providers/google/cloud/openlineage/mixins.py +2 -2
  26. airflow/providers/google/cloud/operators/automl.py +1 -1
  27. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -3
  28. airflow/providers/google/cloud/operators/datafusion.py +1 -22
  29. airflow/providers/google/cloud/operators/dataproc.py +1 -143
  30. airflow/providers/google/cloud/operators/dataproc_metastore.py +0 -1
  31. airflow/providers/google/cloud/operators/mlengine.py +3 -1406
  32. airflow/providers/google/cloud/operators/spanner.py +1 -2
  33. airflow/providers/google/cloud/operators/translate.py +2 -2
  34. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +0 -12
  35. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +1 -22
  36. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +4 -3
  37. airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -1
  38. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
  39. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +23 -10
  40. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -2
  41. airflow/providers/google/common/auth_backend/google_openid.py +1 -1
  42. airflow/providers/google/common/hooks/base_google.py +7 -28
  43. airflow/providers/google/get_provider_info.py +3 -1
  44. airflow/providers/google/marketing_platform/sensors/display_video.py +1 -1
  45. airflow/providers/google/suite/hooks/drive.py +2 -2
  46. {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/METADATA +11 -9
  47. {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/RECORD +49 -50
  48. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
  49. {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/WHEEL +0 -0
  50. {apache_airflow_providers_google-14.1.0.dist-info → apache_airflow_providers_google-15.0.0rc1.dist-info}/entry_points.txt +0 -0
@@ -20,25 +20,13 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  import logging
23
- import re
24
- import time
25
23
  from collections.abc import Sequence
26
- from typing import TYPE_CHECKING, Any
24
+ from typing import TYPE_CHECKING
27
25
 
28
- from googleapiclient.errors import HttpError
29
-
30
- from airflow.configuration import conf
31
- from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
26
+ from airflow.exceptions import AirflowProviderDeprecationWarning
32
27
  from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook
33
- from airflow.providers.google.cloud.links.mlengine import (
34
- MLEngineJobDetailsLink,
35
- MLEngineJobSListLink,
36
- MLEngineModelLink,
37
- MLEngineModelsListLink,
38
- MLEngineModelVersionDetailsLink,
39
- )
28
+ from airflow.providers.google.cloud.links.mlengine import MLEngineModelLink
40
29
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
41
- from airflow.providers.google.cloud.triggers.mlengine import MLEngineStartTrainingJobTrigger
42
30
  from airflow.providers.google.common.deprecated import deprecated
43
31
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
44
32
 
@@ -49,326 +37,6 @@ if TYPE_CHECKING:
49
37
  log = logging.getLogger(__name__)
50
38
 
51
39
 
52
- def _normalize_mlengine_job_id(job_id: str) -> str:
53
- """
54
- Replace invalid MLEngine job_id characters with '_'.
55
-
56
- This also adds a leading 'z' in case job_id starts with an invalid
57
- character.
58
-
59
- :param job_id: A job_id str that may have invalid characters.
60
- :return: A valid job_id representation.
61
- """
62
- # Add a prefix when a job_id starts with a digit or a template
63
- match = re.search(r"\d|\{{2}", job_id)
64
- if match and match.start() == 0:
65
- job = f"z_{job_id}"
66
- else:
67
- job = job_id
68
-
69
- # Clean up 'bad' characters except templates
70
- tracker = 0
71
- cleansed_job_id = ""
72
- for match in re.finditer(r"\{{2}.+?\}{2}", job):
73
- cleansed_job_id += re.sub(r"[^0-9a-zA-Z]+", "_", job[tracker : match.start()])
74
- cleansed_job_id += job[match.start() : match.end()]
75
- tracker = match.end()
76
-
77
- # Clean up last substring or the full string if no templates
78
- cleansed_job_id += re.sub(r"[^0-9a-zA-Z]+", "_", job[tracker:])
79
-
80
- return cleansed_job_id
81
-
82
-
83
- @deprecated(
84
- planned_removal_date="March 01, 2025",
85
- use_instead="CreateBatchPredictionJobOperator",
86
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
87
- category=AirflowProviderDeprecationWarning,
88
- )
89
- class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
90
- """
91
- Start a Google Cloud ML Engine prediction job.
92
-
93
- .. warning::
94
- This operator is deprecated. Please use
95
- :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction.CreateBatchPredictionJobOperator`
96
- instead.
97
-
98
- .. seealso::
99
- For more information on how to use this operator, take a look at the guide:
100
- :ref:`howto/operator:MLEngineStartBatchPredictionJobOperator`
101
-
102
- NOTE: For model origin, users should consider exactly one from the
103
- three options below:
104
-
105
- 1. Populate ``uri`` field only, which should be a GCS location that
106
- points to a tensorflow savedModel directory.
107
- 2. Populate ``model_name`` field only, which refers to an existing
108
- model, and the default version of the model will be used.
109
- 3. Populate both ``model_name`` and ``version_name`` fields, which
110
- refers to a specific version of a specific model.
111
-
112
- In options 2 and 3, both model and version name should contain the
113
- minimal identifier. For instance, call::
114
-
115
- MLEngineStartBatchPredictionJobOperator(
116
- ...,
117
- model_name='my_model',
118
- version_name='my_version',
119
- ...)
120
-
121
- if the desired model version is
122
- ``projects/my_project/models/my_model/versions/my_version``.
123
-
124
- See https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs
125
- for further documentation on the parameters.
126
-
127
- :param job_id: A unique id for the prediction job on Google Cloud
128
- ML Engine. (templated)
129
- :param data_format: The format of the input data.
130
- It will default to 'DATA_FORMAT_UNSPECIFIED' if is not provided
131
- or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"].
132
- :param input_paths: A list of GCS paths of input data for batch
133
- prediction. Accepting wildcard operator ``*``, but only at the end. (templated)
134
- :param output_path: The GCS path where the prediction results are
135
- written to. (templated)
136
- :param region: The Google Compute Engine region to run the
137
- prediction job in. (templated)
138
- :param model_name: The Google Cloud ML Engine model to use for prediction.
139
- If version_name is not provided, the default version of this
140
- model will be used.
141
- Should not be None if version_name is provided.
142
- Should be None if uri is provided. (templated)
143
- :param version_name: The Google Cloud ML Engine model version to use for
144
- prediction.
145
- Should be None if uri is provided. (templated)
146
- :param uri: The GCS path of the saved model to use for prediction.
147
- Should be None if model_name is provided.
148
- It should be a GCS path pointing to a tensorflow SavedModel. (templated)
149
- :param max_worker_count: The maximum number of workers to be used
150
- for parallel processing. Defaults to 10 if not specified. Should be a
151
- string representing the worker count ("10" instead of 10, "50" instead
152
- of 50, etc.)
153
- :param runtime_version: The Google Cloud ML Engine runtime version to use
154
- for batch prediction.
155
- :param signature_name: The name of the signature defined in the SavedModel
156
- to use for this job.
157
- :param project_id: The Google Cloud project name where the prediction job is submitted.
158
- If set to None or missing, the default project_id from the Google Cloud connection is used.
159
- (templated)
160
- :param gcp_conn_id: The connection ID used for connection to Google
161
- Cloud Platform.
162
- :param labels: a dictionary containing labels for the job; passed to BigQuery
163
- :param impersonation_chain: Optional service account to impersonate using short-term
164
- credentials, or chained list of accounts required to get the access_token
165
- of the last account in the list, which will be impersonated in the request.
166
- If set as a string, the account must grant the originating account
167
- the Service Account Token Creator IAM role.
168
- If set as a sequence, the identities from the list must grant
169
- Service Account Token Creator IAM role to the directly preceding identity, with first
170
- account from the list granting this role to the originating account (templated).
171
-
172
- :raises: ``ValueError``: if a unique model/version origin cannot be
173
- determined.
174
- """
175
-
176
- template_fields: Sequence[str] = (
177
- "project_id",
178
- "job_id",
179
- "region",
180
- "input_paths",
181
- "output_path",
182
- "model_name",
183
- "version_name",
184
- "uri",
185
- "impersonation_chain",
186
- )
187
-
188
- def __init__(
189
- self,
190
- *,
191
- job_id: str,
192
- region: str,
193
- data_format: str,
194
- input_paths: list[str],
195
- output_path: str,
196
- model_name: str | None = None,
197
- version_name: str | None = None,
198
- uri: str | None = None,
199
- max_worker_count: int | None = None,
200
- runtime_version: str | None = None,
201
- signature_name: str | None = None,
202
- project_id: str | None = None,
203
- gcp_conn_id: str = "google_cloud_default",
204
- labels: dict[str, str] | None = None,
205
- impersonation_chain: str | Sequence[str] | None = None,
206
- **kwargs,
207
- ) -> None:
208
- super().__init__(**kwargs)
209
-
210
- self.project_id = project_id
211
- self.job_id = job_id
212
- self.region = region
213
- self._data_format = data_format
214
- self.input_paths = input_paths
215
- self.output_path = output_path
216
- self.model_name = model_name
217
- self.version_name = version_name
218
- self.uri = uri
219
- self._max_worker_count = max_worker_count
220
- self._runtime_version = runtime_version
221
- self._signature_name = signature_name
222
- self._gcp_conn_id = gcp_conn_id
223
- self._labels = labels
224
- self.impersonation_chain = impersonation_chain
225
-
226
- def execute(self, context: Context):
227
- if not self.project_id:
228
- raise AirflowException("Google Cloud project id is required.")
229
- if not self.job_id:
230
- raise AirflowException("An unique job id is required for Google MLEngine prediction job.")
231
-
232
- if self.uri:
233
- if self.model_name or self.version_name:
234
- raise AirflowException(
235
- "Ambiguous model origin: Both uri and model/version name are provided."
236
- )
237
-
238
- if self.version_name and not self.model_name:
239
- raise AirflowException(
240
- "Missing model: Batch prediction expects a model name when a version name is provided."
241
- )
242
-
243
- if not (self.uri or self.model_name):
244
- raise AirflowException(
245
- "Missing model origin: Batch prediction expects a model, "
246
- "a model & version combination, or a URI to a savedModel."
247
- )
248
- job_id = _normalize_mlengine_job_id(self.job_id)
249
- prediction_request: dict[str, Any] = {
250
- "jobId": job_id,
251
- "predictionInput": {
252
- "dataFormat": self._data_format,
253
- "inputPaths": self.input_paths,
254
- "outputPath": self.output_path,
255
- "region": self.region,
256
- },
257
- }
258
- if self._labels:
259
- prediction_request["labels"] = self._labels
260
-
261
- if self.uri:
262
- prediction_request["predictionInput"]["uri"] = self.uri
263
- elif self.model_name:
264
- origin_name = f"projects/{self.project_id}/models/{self.model_name}"
265
- if not self.version_name:
266
- prediction_request["predictionInput"]["modelName"] = origin_name
267
- else:
268
- prediction_request["predictionInput"]["versionName"] = (
269
- origin_name + f"/versions/{self.version_name}"
270
- )
271
-
272
- if self._max_worker_count:
273
- prediction_request["predictionInput"]["maxWorkerCount"] = self._max_worker_count
274
-
275
- if self._runtime_version:
276
- prediction_request["predictionInput"]["runtimeVersion"] = self._runtime_version
277
-
278
- if self._signature_name:
279
- prediction_request["predictionInput"]["signatureName"] = self._signature_name
280
-
281
- hook = MLEngineHook(gcp_conn_id=self._gcp_conn_id, impersonation_chain=self.impersonation_chain)
282
-
283
- # Helper method to check if the existing job's prediction input is the
284
- # same as the request we get here.
285
- def check_existing_job(existing_job):
286
- return existing_job.get("predictionInput") == prediction_request["predictionInput"]
287
-
288
- finished_prediction_job = hook.create_job(
289
- project_id=self.project_id, job=prediction_request, use_existing_job_fn=check_existing_job
290
- )
291
-
292
- if finished_prediction_job["state"] != "SUCCEEDED":
293
- self.log.error("MLEngine batch prediction job failed: %s", finished_prediction_job)
294
- raise RuntimeError(finished_prediction_job["errorMessage"])
295
-
296
- return finished_prediction_job["predictionOutput"]
297
-
298
-
299
- @deprecated(
300
- planned_removal_date="March 01, 2025",
301
- use_instead="MLEngineCreateModelOperator, MLEngineGetModelOperator",
302
- category=AirflowProviderDeprecationWarning,
303
- )
304
- class MLEngineManageModelOperator(GoogleCloudBaseOperator):
305
- """
306
- Operator for managing a Google Cloud ML Engine model.
307
-
308
- .. warning::
309
- This operator is deprecated. Consider using operators for specific operations:
310
- MLEngineCreateModelOperator, MLEngineGetModelOperator.
311
-
312
- :param model: A dictionary containing the information about the model.
313
- If the `operation` is `create`, then the `model` parameter should
314
- contain all the information about this model such as `name`.
315
-
316
- If the `operation` is `get`, the `model` parameter
317
- should contain the `name` of the model.
318
- :param operation: The operation to perform. Available operations are:
319
-
320
- * ``create``: Creates a new model as provided by the `model` parameter.
321
- * ``get``: Gets a particular model where the name is specified in `model`.
322
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
323
- If set to None or missing, the default project_id from the Google Cloud connection is used.
324
- (templated)
325
- :param gcp_conn_id: The connection ID to use when fetching connection info.
326
- :param impersonation_chain: Optional service account to impersonate using short-term
327
- credentials, or chained list of accounts required to get the access_token
328
- of the last account in the list, which will be impersonated in the request.
329
- If set as a string, the account must grant the originating account
330
- the Service Account Token Creator IAM role.
331
- If set as a sequence, the identities from the list must grant
332
- Service Account Token Creator IAM role to the directly preceding identity, with first
333
- account from the list granting this role to the originating account (templated).
334
- """
335
-
336
- template_fields: Sequence[str] = (
337
- "project_id",
338
- "model",
339
- "impersonation_chain",
340
- )
341
-
342
- def __init__(
343
- self,
344
- *,
345
- model: dict,
346
- operation: str = "create",
347
- project_id: str = PROVIDE_PROJECT_ID,
348
- gcp_conn_id: str = "google_cloud_default",
349
- impersonation_chain: str | Sequence[str] | None = None,
350
- **kwargs,
351
- ) -> None:
352
- super().__init__(**kwargs)
353
- self.project_id = project_id
354
- self.model = model
355
- self._operation = operation
356
- self._gcp_conn_id = gcp_conn_id
357
- self.impersonation_chain = impersonation_chain
358
-
359
- def execute(self, context: Context):
360
- hook = MLEngineHook(
361
- gcp_conn_id=self._gcp_conn_id,
362
- impersonation_chain=self.impersonation_chain,
363
- )
364
- if self._operation == "create":
365
- return hook.create_model(project_id=self.project_id, model=self.model)
366
- elif self._operation == "get":
367
- return hook.get_model(project_id=self.project_id, model_name=self.model["name"])
368
- else:
369
- raise ValueError(f"Unknown operation: {self._operation}")
370
-
371
-
372
40
  @deprecated(
373
41
  planned_removal_date="November 01, 2025",
374
42
  use_instead="appropriate VertexAI operator",
@@ -442,1074 +110,3 @@ class MLEngineCreateModelOperator(GoogleCloudBaseOperator):
442
110
  )
443
111
 
444
112
  return hook.create_model(project_id=self.project_id, model=self.model)
445
-
446
-
447
- @deprecated(
448
- planned_removal_date="March 01, 2025",
449
- use_instead="GetModelOperator",
450
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
451
- category=AirflowProviderDeprecationWarning,
452
- )
453
- class MLEngineGetModelOperator(GoogleCloudBaseOperator):
454
- """
455
- Gets a particular model.
456
-
457
- .. warning::
458
- This operator is deprecated. Please use
459
- :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator` instead.
460
-
461
- .. seealso::
462
- For more information on how to use this operator, take a look at the guide:
463
- :ref:`howto/operator:MLEngineGetModelOperator`
464
-
465
- The name of model should be specified in `model_name`.
466
-
467
- :param model_name: The name of the model.
468
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
469
- If set to None or missing, the default project_id from the Google Cloud connection is used.
470
- (templated)
471
- :param gcp_conn_id: The connection ID to use when fetching connection info.
472
- :param impersonation_chain: Optional service account to impersonate using short-term
473
- credentials, or chained list of accounts required to get the access_token
474
- of the last account in the list, which will be impersonated in the request.
475
- If set as a string, the account must grant the originating account
476
- the Service Account Token Creator IAM role.
477
- If set as a sequence, the identities from the list must grant
478
- Service Account Token Creator IAM role to the directly preceding identity, with first
479
- account from the list granting this role to the originating account (templated).
480
- """
481
-
482
- template_fields: Sequence[str] = (
483
- "project_id",
484
- "model_name",
485
- "impersonation_chain",
486
- )
487
- operator_extra_links = (MLEngineModelLink(),)
488
-
489
- def __init__(
490
- self,
491
- *,
492
- model_name: str,
493
- project_id: str = PROVIDE_PROJECT_ID,
494
- gcp_conn_id: str = "google_cloud_default",
495
- impersonation_chain: str | Sequence[str] | None = None,
496
- **kwargs,
497
- ) -> None:
498
- super().__init__(**kwargs)
499
- self.project_id = project_id
500
- self.model_name = model_name
501
- self._gcp_conn_id = gcp_conn_id
502
- self.impersonation_chain = impersonation_chain
503
-
504
- def execute(self, context: Context):
505
- hook = MLEngineHook(
506
- gcp_conn_id=self._gcp_conn_id,
507
- impersonation_chain=self.impersonation_chain,
508
- )
509
- project_id = self.project_id or hook.project_id
510
- if project_id:
511
- MLEngineModelLink.persist(
512
- context=context,
513
- task_instance=self,
514
- project_id=project_id,
515
- model_id=self.model_name,
516
- )
517
-
518
- return hook.get_model(project_id=self.project_id, model_name=self.model_name)
519
-
520
-
521
- @deprecated(
522
- planned_removal_date="March 01, 2025",
523
- use_instead="DeleteModelOperator",
524
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
525
- category=AirflowProviderDeprecationWarning,
526
- )
527
- class MLEngineDeleteModelOperator(GoogleCloudBaseOperator):
528
- """
529
- Deletes a model.
530
-
531
- .. warning::
532
- This operator is deprecated. Please use
533
- :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator` instead.
534
-
535
- .. seealso::
536
- For more information on how to use this operator, take a look at the guide:
537
- :ref:`howto/operator:MLEngineDeleteModelOperator`
538
-
539
- The model should be provided by the `model_name` parameter.
540
-
541
- :param model_name: The name of the model.
542
- :param delete_contents: (Optional) Whether to force the deletion even if the models is not empty.
543
- Will delete all version (if any) in the dataset if set to True.
544
- The default value is False.
545
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
546
- If set to None or missing, the default project_id from the Google Cloud connection is used.
547
- (templated)
548
- :param gcp_conn_id: The connection ID to use when fetching connection info.
549
- :param impersonation_chain: Optional service account to impersonate using short-term
550
- credentials, or chained list of accounts required to get the access_token
551
- of the last account in the list, which will be impersonated in the request.
552
- If set as a string, the account must grant the originating account
553
- the Service Account Token Creator IAM role.
554
- If set as a sequence, the identities from the list must grant
555
- Service Account Token Creator IAM role to the directly preceding identity, with first
556
- account from the list granting this role to the originating account (templated).
557
- """
558
-
559
- template_fields: Sequence[str] = (
560
- "project_id",
561
- "model_name",
562
- "impersonation_chain",
563
- )
564
- operator_extra_links = (MLEngineModelsListLink(),)
565
-
566
- def __init__(
567
- self,
568
- *,
569
- model_name: str,
570
- delete_contents: bool = False,
571
- project_id: str = PROVIDE_PROJECT_ID,
572
- gcp_conn_id: str = "google_cloud_default",
573
- impersonation_chain: str | Sequence[str] | None = None,
574
- **kwargs,
575
- ) -> None:
576
- super().__init__(**kwargs)
577
- self.project_id = project_id
578
- self.model_name = model_name
579
- self._delete_contents = delete_contents
580
- self._gcp_conn_id = gcp_conn_id
581
- self.impersonation_chain = impersonation_chain
582
-
583
- def execute(self, context: Context):
584
- hook = MLEngineHook(
585
- gcp_conn_id=self._gcp_conn_id,
586
- impersonation_chain=self.impersonation_chain,
587
- )
588
-
589
- project_id = self.project_id or hook.project_id
590
- if project_id:
591
- MLEngineModelsListLink.persist(
592
- context=context,
593
- task_instance=self,
594
- project_id=project_id,
595
- )
596
-
597
- return hook.delete_model(
598
- project_id=self.project_id, model_name=self.model_name, delete_contents=self._delete_contents
599
- )
600
-
601
-
602
- @deprecated(
603
- planned_removal_date="March 01, 2025",
604
- use_instead="MLEngineCreateVersion, MLEngineSetDefaultVersion, MLEngineListVersions, "
605
- "MLEngineDeleteVersion",
606
- category=AirflowProviderDeprecationWarning,
607
- )
608
- class MLEngineManageVersionOperator(GoogleCloudBaseOperator):
609
- """
610
- Operator for managing a Google Cloud ML Engine version.
611
-
612
- .. warning::
613
- This operator is deprecated. Consider using operators for specific operations:
614
- MLEngineCreateVersionOperator, MLEngineSetDefaultVersionOperator,
615
- MLEngineListVersionsOperator, MLEngineDeleteVersionOperator.
616
-
617
- :param model_name: The name of the Google Cloud ML Engine model that the version
618
- belongs to. (templated)
619
- :param version_name: A name to use for the version being operated upon.
620
- If not None and the `version` argument is None or does not have a value for
621
- the `name` key, then this will be populated in the payload for the
622
- `name` key. (templated)
623
- :param version: A dictionary containing the information about the version.
624
- If the `operation` is `create`, `version` should contain all the
625
- information about this version such as name, and deploymentUrl.
626
- If the `operation` is `get` or `delete`, the `version` parameter
627
- should contain the `name` of the version.
628
- If it is None, the only `operation` possible would be `list`. (templated)
629
- :param operation: The operation to perform. Available operations are:
630
-
631
- * ``create``: Creates a new version in the model specified by `model_name`,
632
- in which case the `version` parameter should contain all the
633
- information to create that version
634
- (e.g. `name`, `deploymentUrl`).
635
-
636
- * ``set_defaults``: Sets a version in the model specified by `model_name` to be the default.
637
- The name of the version should be specified in the `version`
638
- parameter.
639
-
640
- * ``list``: Lists all available versions of the model specified
641
- by `model_name`.
642
-
643
- * ``delete``: Deletes the version specified in `version` parameter from the
644
- model specified by `model_name`).
645
- The name of the version should be specified in the `version`
646
- parameter.
647
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
648
- If set to None or missing, the default project_id from the Google Cloud connection is used.
649
- (templated)
650
- :param gcp_conn_id: The connection ID to use when fetching connection info.
651
- :param impersonation_chain: Optional service account to impersonate using short-term
652
- credentials, or chained list of accounts required to get the access_token
653
- of the last account in the list, which will be impersonated in the request.
654
- If set as a string, the account must grant the originating account
655
- the Service Account Token Creator IAM role.
656
- If set as a sequence, the identities from the list must grant
657
- Service Account Token Creator IAM role to the directly preceding identity, with first
658
- account from the list granting this role to the originating account (templated).
659
- """
660
-
661
- template_fields: Sequence[str] = (
662
- "project_id",
663
- "model_name",
664
- "version_name",
665
- "version",
666
- "impersonation_chain",
667
- )
668
-
669
- def __init__(
670
- self,
671
- *,
672
- model_name: str,
673
- version_name: str | None = None,
674
- version: dict | None = None,
675
- operation: str = "create",
676
- project_id: str = PROVIDE_PROJECT_ID,
677
- gcp_conn_id: str = "google_cloud_default",
678
- impersonation_chain: str | Sequence[str] | None = None,
679
- **kwargs,
680
- ) -> None:
681
- super().__init__(**kwargs)
682
- self.project_id = project_id
683
- self.model_name = model_name
684
- self.version_name = version_name
685
- self.version = version or {}
686
- self._operation = operation
687
- self._gcp_conn_id = gcp_conn_id
688
- self.impersonation_chain = impersonation_chain
689
-
690
- def execute(self, context: Context):
691
- if "name" not in self.version:
692
- self.version["name"] = self.version_name
693
-
694
- hook = MLEngineHook(
695
- gcp_conn_id=self._gcp_conn_id,
696
- impersonation_chain=self.impersonation_chain,
697
- )
698
-
699
- if self._operation == "create":
700
- if not self.version:
701
- raise ValueError(f"version attribute of {self.__class__.__name__} could not be empty")
702
- return hook.create_version(
703
- project_id=self.project_id, model_name=self.model_name, version_spec=self.version
704
- )
705
- elif self._operation == "set_default":
706
- return hook.set_default_version(
707
- project_id=self.project_id, model_name=self.model_name, version_name=self.version["name"]
708
- )
709
- elif self._operation == "list":
710
- return hook.list_versions(project_id=self.project_id, model_name=self.model_name)
711
- elif self._operation == "delete":
712
- return hook.delete_version(
713
- project_id=self.project_id, model_name=self.model_name, version_name=self.version["name"]
714
- )
715
- else:
716
- raise ValueError(f"Unknown operation: {self._operation}")
717
-
718
-
719
- @deprecated(
720
- planned_removal_date="March 01, 2025",
721
- use_instead="parent_model parameter for VertexAI operators",
722
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
723
- category=AirflowProviderDeprecationWarning,
724
- )
725
- class MLEngineCreateVersionOperator(GoogleCloudBaseOperator):
726
- """
727
- Creates a new version in the model.
728
-
729
- .. warning::
730
- This operator is deprecated. Please use parent_model parameter of VertexAI
731
- operators instead.
732
-
733
- .. seealso::
734
- For more information on how to use this operator, take a look at the guide:
735
- :ref:`howto/operator:MLEngineCreateVersionOperator`
736
-
737
- Model should be specified by `model_name`, in which case the `version` parameter should contain all the
738
- information to create that version
739
-
740
- :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated)
741
- :param version: A dictionary containing the information about the version. (templated)
742
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
743
- If set to None or missing, the default project_id from the Google Cloud connection is used.
744
- (templated)
745
- :param gcp_conn_id: The connection ID to use when fetching connection info.
746
- :param impersonation_chain: Optional service account to impersonate using short-term
747
- credentials, or chained list of accounts required to get the access_token
748
- of the last account in the list, which will be impersonated in the request.
749
- If set as a string, the account must grant the originating account
750
- the Service Account Token Creator IAM role.
751
- If set as a sequence, the identities from the list must grant
752
- Service Account Token Creator IAM role to the directly preceding identity, with first
753
- account from the list granting this role to the originating account (templated).
754
- """
755
-
756
- template_fields: Sequence[str] = (
757
- "project_id",
758
- "model_name",
759
- "version",
760
- "impersonation_chain",
761
- )
762
- operator_extra_links = (MLEngineModelVersionDetailsLink(),)
763
-
764
- def __init__(
765
- self,
766
- *,
767
- model_name: str,
768
- version: dict,
769
- project_id: str = PROVIDE_PROJECT_ID,
770
- gcp_conn_id: str = "google_cloud_default",
771
- impersonation_chain: str | Sequence[str] | None = None,
772
- **kwargs,
773
- ) -> None:
774
- super().__init__(**kwargs)
775
- self.project_id = project_id
776
- self.model_name = model_name
777
- self.version = version
778
- self._gcp_conn_id = gcp_conn_id
779
- self.impersonation_chain = impersonation_chain
780
-
781
- def _validate_inputs(self):
782
- if not self.model_name:
783
- raise AirflowException("The model_name parameter could not be empty.")
784
-
785
- if not self.version:
786
- raise AirflowException("The version parameter could not be empty.")
787
-
788
- def execute(self, context: Context):
789
- self._validate_inputs()
790
- hook = MLEngineHook(
791
- gcp_conn_id=self._gcp_conn_id,
792
- impersonation_chain=self.impersonation_chain,
793
- )
794
-
795
- project_id = self.project_id or hook.project_id
796
- if project_id:
797
- MLEngineModelVersionDetailsLink.persist(
798
- context=context,
799
- task_instance=self,
800
- project_id=project_id,
801
- model_id=self.model_name,
802
- version_id=self.version["name"],
803
- )
804
-
805
- return hook.create_version(
806
- project_id=self.project_id, model_name=self.model_name, version_spec=self.version
807
- )
808
-
809
-
810
- @deprecated(
811
- planned_removal_date="March 01, 2025",
812
- use_instead="SetDefaultVersionOnModelOperator",
813
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
814
- category=AirflowProviderDeprecationWarning,
815
- )
816
- class MLEngineSetDefaultVersionOperator(GoogleCloudBaseOperator):
817
- """
818
- Sets a version in the model.
819
-
820
- .. warning::
821
- This operator is deprecated. Please use
822
- :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.SetDefaultVersionOnModelOperator`
823
- instead.
824
-
825
- .. seealso::
826
- For more information on how to use this operator, take a look at the guide:
827
- :ref:`howto/operator:MLEngineSetDefaultVersionOperator`
828
-
829
- The model should be specified by `model_name` to be the default. The name of the version should be
830
- specified in the `version_name` parameter.
831
-
832
- :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated)
833
- :param version_name: A name to use for the version being operated upon. (templated)
834
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
835
- If set to None or missing, the default project_id from the Google Cloud connection is used.
836
- (templated)
837
- :param gcp_conn_id: The connection ID to use when fetching connection info.
838
- :param impersonation_chain: Optional service account to impersonate using short-term
839
- credentials, or chained list of accounts required to get the access_token
840
- of the last account in the list, which will be impersonated in the request.
841
- If set as a string, the account must grant the originating account
842
- the Service Account Token Creator IAM role.
843
- If set as a sequence, the identities from the list must grant
844
- Service Account Token Creator IAM role to the directly preceding identity, with first
845
- account from the list granting this role to the originating account (templated).
846
- """
847
-
848
- template_fields: Sequence[str] = (
849
- "project_id",
850
- "model_name",
851
- "version_name",
852
- "impersonation_chain",
853
- )
854
- operator_extra_links = (MLEngineModelVersionDetailsLink(),)
855
-
856
- def __init__(
857
- self,
858
- *,
859
- model_name: str,
860
- version_name: str,
861
- project_id: str = PROVIDE_PROJECT_ID,
862
- gcp_conn_id: str = "google_cloud_default",
863
- impersonation_chain: str | Sequence[str] | None = None,
864
- **kwargs,
865
- ) -> None:
866
- super().__init__(**kwargs)
867
- self.project_id = project_id
868
- self.model_name = model_name
869
- self.version_name = version_name
870
- self._gcp_conn_id = gcp_conn_id
871
- self.impersonation_chain = impersonation_chain
872
-
873
- def _validate_inputs(self):
874
- if not self.model_name:
875
- raise AirflowException("The model_name parameter could not be empty.")
876
-
877
- if not self.version_name:
878
- raise AirflowException("The version_name parameter could not be empty.")
879
-
880
- def execute(self, context: Context):
881
- self._validate_inputs()
882
- hook = MLEngineHook(
883
- gcp_conn_id=self._gcp_conn_id,
884
- impersonation_chain=self.impersonation_chain,
885
- )
886
-
887
- project_id = self.project_id or hook.project_id
888
- if project_id:
889
- MLEngineModelVersionDetailsLink.persist(
890
- context=context,
891
- task_instance=self,
892
- project_id=project_id,
893
- model_id=self.model_name,
894
- version_id=self.version_name,
895
- )
896
-
897
- return hook.set_default_version(
898
- project_id=self.project_id, model_name=self.model_name, version_name=self.version_name
899
- )
900
-
901
-
902
- @deprecated(
903
- planned_removal_date="March 01, 2025",
904
- use_instead="istModelVersionsOperator",
905
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
906
- category=AirflowProviderDeprecationWarning,
907
- )
908
- class MLEngineListVersionsOperator(GoogleCloudBaseOperator):
909
- """
910
- Lists all available versions of the model.
911
-
912
- .. warning::
913
- This operator is deprecated. Please use
914
- :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.ListModelVersionsOperator`
915
- instead.
916
-
917
- .. seealso::
918
- For more information on how to use this operator, take a look at the guide:
919
- :ref:`howto/operator:MLEngineListVersionsOperator`
920
-
921
- The model should be specified by `model_name`.
922
-
923
- :param model_name: The name of the Google Cloud ML Engine model that the version
924
- belongs to. (templated)
925
- :param gcp_conn_id: The connection ID to use when fetching connection info.
926
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
927
- If set to None or missing, the default project_id from the Google Cloud connection is used.
928
- (templated)
929
- :param impersonation_chain: Optional service account to impersonate using short-term
930
- credentials, or chained list of accounts required to get the access_token
931
- of the last account in the list, which will be impersonated in the request.
932
- If set as a string, the account must grant the originating account
933
- the Service Account Token Creator IAM role.
934
- If set as a sequence, the identities from the list must grant
935
- Service Account Token Creator IAM role to the directly preceding identity, with first
936
- account from the list granting this role to the originating account (templated).
937
- """
938
-
939
- template_fields: Sequence[str] = (
940
- "project_id",
941
- "model_name",
942
- "impersonation_chain",
943
- )
944
- operator_extra_links = (MLEngineModelLink(),)
945
-
946
- def __init__(
947
- self,
948
- *,
949
- model_name: str,
950
- project_id: str = PROVIDE_PROJECT_ID,
951
- gcp_conn_id: str = "google_cloud_default",
952
- impersonation_chain: str | Sequence[str] | None = None,
953
- **kwargs,
954
- ) -> None:
955
- super().__init__(**kwargs)
956
- self.project_id = project_id
957
- self.model_name = model_name
958
- self._gcp_conn_id = gcp_conn_id
959
- self.impersonation_chain = impersonation_chain
960
-
961
- def _validate_inputs(self):
962
- if not self.model_name:
963
- raise AirflowException("The model_name parameter could not be empty.")
964
-
965
- def execute(self, context: Context):
966
- self._validate_inputs()
967
- hook = MLEngineHook(
968
- gcp_conn_id=self._gcp_conn_id,
969
- impersonation_chain=self.impersonation_chain,
970
- )
971
-
972
- project_id = self.project_id or hook.project_id
973
- if project_id:
974
- MLEngineModelLink.persist(
975
- context=context,
976
- task_instance=self,
977
- project_id=project_id,
978
- model_id=self.model_name,
979
- )
980
-
981
- return hook.list_versions(
982
- project_id=self.project_id,
983
- model_name=self.model_name,
984
- )
985
-
986
-
987
- @deprecated(
988
- planned_removal_date="March 01, 2025",
989
- use_instead="DeleteModelVersionOperator",
990
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
991
- category=AirflowProviderDeprecationWarning,
992
- )
993
- class MLEngineDeleteVersionOperator(GoogleCloudBaseOperator):
994
- """
995
- Deletes the version from the model.
996
-
997
- .. warning::
998
- This operator is deprecated. Please use
999
- :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelVersionOperator`
1000
- instead.
1001
-
1002
- .. seealso::
1003
- For more information on how to use this operator, take a look at the guide:
1004
- :ref:`howto/operator:MLEngineDeleteVersionOperator`
1005
-
1006
- The name of the version should be specified in `version_name` parameter from the model specified
1007
- by `model_name`.
1008
-
1009
- :param model_name: The name of the Google Cloud ML Engine model that the version
1010
- belongs to. (templated)
1011
- :param version_name: A name to use for the version being operated upon. (templated)
1012
- :param project_id: The Google Cloud project name to which MLEngine
1013
- model belongs.
1014
- :param gcp_conn_id: The connection ID to use when fetching connection info.
1015
- :param impersonation_chain: Optional service account to impersonate using short-term
1016
- credentials, or chained list of accounts required to get the access_token
1017
- of the last account in the list, which will be impersonated in the request.
1018
- If set as a string, the account must grant the originating account
1019
- the Service Account Token Creator IAM role.
1020
- If set as a sequence, the identities from the list must grant
1021
- Service Account Token Creator IAM role to the directly preceding identity, with first
1022
- account from the list granting this role to the originating account (templated).
1023
- """
1024
-
1025
- template_fields: Sequence[str] = (
1026
- "project_id",
1027
- "model_name",
1028
- "version_name",
1029
- "impersonation_chain",
1030
- )
1031
- operator_extra_links = (MLEngineModelLink(),)
1032
-
1033
- def __init__(
1034
- self,
1035
- *,
1036
- model_name: str,
1037
- version_name: str,
1038
- project_id: str = PROVIDE_PROJECT_ID,
1039
- gcp_conn_id: str = "google_cloud_default",
1040
- impersonation_chain: str | Sequence[str] | None = None,
1041
- **kwargs,
1042
- ) -> None:
1043
- super().__init__(**kwargs)
1044
- self.project_id = project_id
1045
- self.model_name = model_name
1046
- self.version_name = version_name
1047
- self._gcp_conn_id = gcp_conn_id
1048
- self.impersonation_chain = impersonation_chain
1049
-
1050
- def _validate_inputs(self):
1051
- if not self.model_name:
1052
- raise AirflowException("The model_name parameter could not be empty.")
1053
-
1054
- if not self.version_name:
1055
- raise AirflowException("The version_name parameter could not be empty.")
1056
-
1057
- def execute(self, context: Context):
1058
- self._validate_inputs()
1059
- hook = MLEngineHook(
1060
- gcp_conn_id=self._gcp_conn_id,
1061
- impersonation_chain=self.impersonation_chain,
1062
- )
1063
-
1064
- project_id = self.project_id or hook.project_id
1065
- if project_id:
1066
- MLEngineModelLink.persist(
1067
- context=context,
1068
- task_instance=self,
1069
- project_id=project_id,
1070
- model_id=self.model_name,
1071
- )
1072
-
1073
- return hook.delete_version(
1074
- project_id=self.project_id, model_name=self.model_name, version_name=self.version_name
1075
- )
1076
-
1077
-
1078
- @deprecated(
1079
- planned_removal_date="March 01, 2025",
1080
- use_instead="CreateCustomPythonPackageTrainingJobOperator",
1081
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
1082
- category=AirflowProviderDeprecationWarning,
1083
- )
1084
- class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
1085
- """
1086
- Operator for launching a MLEngine training job.
1087
-
1088
- .. warning::
1089
- This operator is deprecated. Please use
1090
- :class:`airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`
1091
- instead.
1092
-
1093
- .. seealso::
1094
- For more information on how to use this operator, take a look at the guide:
1095
- :ref:`howto/operator:MLEngineStartTrainingJobOperator`
1096
-
1097
- For more information about used parameters, check:
1098
- https://cloud.google.com/sdk/gcloud/reference/ml-engine/jobs/submit/training
1099
-
1100
- :param job_id: A unique templated id for the submitted Google MLEngine
1101
- training job. (templated)
1102
- :param region: The Google Compute Engine region to run the MLEngine training
1103
- job in (templated).
1104
- :param package_uris: A list of Python package locations for the training
1105
- job, which should include the main training program and any additional
1106
- dependencies. This is mutually exclusive with a custom image specified
1107
- via master_config. (templated)
1108
- :param training_python_module: The name of the Python module to run within
1109
- the training job after installing the packages. This is mutually
1110
- exclusive with a custom image specified via master_config. (templated)
1111
- :param training_args: A list of command-line arguments to pass to the
1112
- training program. (templated)
1113
- :param scale_tier: Resource tier for MLEngine training job. (templated)
1114
- :param master_type: The type of virtual machine to use for the master
1115
- worker. It must be set whenever scale_tier is CUSTOM. (templated)
1116
- :param master_config: The configuration for the master worker. If this is
1117
- provided, master_type must be set as well. If a custom image is
1118
- specified, this is mutually exclusive with package_uris and
1119
- training_python_module. (templated)
1120
- :param runtime_version: The Google Cloud ML runtime version to use for
1121
- training. (templated)
1122
- :param python_version: The version of Python used in training. (templated)
1123
- :param job_dir: A Google Cloud Storage path in which to store training
1124
- outputs and other data needed for training. (templated)
1125
- :param service_account: Optional service account to use when running the training application.
1126
- (templated)
1127
- The specified service account must have the `iam.serviceAccounts.actAs` role. The
1128
- Google-managed Cloud ML Engine service account must have the `iam.serviceAccountAdmin` role
1129
- for the specified service account.
1130
- If set to None or missing, the Google-managed Cloud ML Engine service account will be used.
1131
- :param project_id: The Google Cloud project name within which MLEngine training job should run.
1132
- :param gcp_conn_id: The connection ID to use when fetching connection info.
1133
- :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real
1134
- training job will be launched, but the MLEngine training job request
1135
- will be printed out. In 'CLOUD' mode, a real MLEngine training job
1136
- creation request will be issued.
1137
- :param labels: a dictionary containing labels for the job; passed to BigQuery
1138
- :param hyperparameters: Optional HyperparameterSpec dictionary for hyperparameter tuning.
1139
- For further reference, check:
1140
- https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#HyperparameterSpec
1141
- :param impersonation_chain: Optional service account to impersonate using short-term
1142
- credentials, or chained list of accounts required to get the access_token
1143
- of the last account in the list, which will be impersonated in the request.
1144
- If set as a string, the account must grant the originating account
1145
- the Service Account Token Creator IAM role.
1146
- If set as a sequence, the identities from the list must grant
1147
- Service Account Token Creator IAM role to the directly preceding identity, with first
1148
- account from the list granting this role to the originating account (templated).
1149
- :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called
1150
- :param deferrable: Run operator in the deferrable mode
1151
- """
1152
-
1153
- template_fields: Sequence[str] = (
1154
- "project_id",
1155
- "job_id",
1156
- "region",
1157
- "package_uris",
1158
- "training_python_module",
1159
- "training_args",
1160
- "scale_tier",
1161
- "master_type",
1162
- "master_config",
1163
- "runtime_version",
1164
- "python_version",
1165
- "job_dir",
1166
- "service_account",
1167
- "hyperparameters",
1168
- "impersonation_chain",
1169
- )
1170
- operator_extra_links = (MLEngineJobDetailsLink(),)
1171
-
1172
- def __init__(
1173
- self,
1174
- *,
1175
- job_id: str,
1176
- region: str,
1177
- project_id: str,
1178
- package_uris: list[str] | None = None,
1179
- training_python_module: str | None = None,
1180
- training_args: list[str] | None = None,
1181
- scale_tier: str | None = None,
1182
- master_type: str | None = None,
1183
- master_config: dict | None = None,
1184
- runtime_version: str | None = None,
1185
- python_version: str | None = None,
1186
- job_dir: str | None = None,
1187
- service_account: str | None = None,
1188
- gcp_conn_id: str = "google_cloud_default",
1189
- mode: str = "PRODUCTION",
1190
- labels: dict[str, str] | None = None,
1191
- impersonation_chain: str | Sequence[str] | None = None,
1192
- hyperparameters: dict | None = None,
1193
- deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
1194
- cancel_on_kill: bool = True,
1195
- **kwargs,
1196
- ) -> None:
1197
- super().__init__(**kwargs)
1198
- self.project_id = project_id
1199
- self.job_id = job_id
1200
- self.region = region
1201
- self.package_uris = package_uris
1202
- self.training_python_module = training_python_module
1203
- self.training_args = training_args
1204
- self.scale_tier = scale_tier
1205
- self.master_type = master_type
1206
- self.master_config = master_config
1207
- self.runtime_version = runtime_version
1208
- self.python_version = python_version
1209
- self.job_dir = job_dir
1210
- self.service_account = service_account
1211
- self._gcp_conn_id = gcp_conn_id
1212
- self._mode = mode
1213
- self._labels = labels
1214
- self.hyperparameters = hyperparameters
1215
- self.impersonation_chain = impersonation_chain
1216
- self.deferrable = deferrable
1217
- self.cancel_on_kill = cancel_on_kill
1218
-
1219
- def _handle_job_error(self, finished_training_job) -> None:
1220
- if finished_training_job["state"] != "SUCCEEDED":
1221
- self.log.error("MLEngine training job failed: %s", finished_training_job)
1222
- raise RuntimeError(finished_training_job["errorMessage"])
1223
-
1224
- def execute(self, context: Context):
1225
- custom = self.scale_tier is not None and self.scale_tier.upper() == "CUSTOM"
1226
- custom_image = (
1227
- custom and self.master_config is not None and self.master_config.get("imageUri", None) is not None
1228
- )
1229
-
1230
- if not self.project_id:
1231
- raise AirflowException("Google Cloud project id is required.")
1232
- if not self.job_id:
1233
- raise AirflowException("An unique job id is required for Google MLEngine training job.")
1234
- if not self.region:
1235
- raise AirflowException("Google Compute Engine region is required.")
1236
- if custom and not self.master_type:
1237
- raise AirflowException("master_type must be set when scale_tier is CUSTOM")
1238
- if self.master_config and not self.master_type:
1239
- raise AirflowException("master_type must be set when master_config is provided")
1240
- if not (self.package_uris and self.training_python_module) and not custom_image:
1241
- raise AirflowException(
1242
- "Either a Python package with a Python module or a custom Docker image should be provided."
1243
- )
1244
- if (self.package_uris or self.training_python_module) and custom_image:
1245
- raise AirflowException(
1246
- "Either a Python package with a Python module or "
1247
- "a custom Docker image should be provided but not both."
1248
- )
1249
- job_id = _normalize_mlengine_job_id(self.job_id)
1250
- self.job_id = job_id
1251
- training_request: dict[str, Any] = {
1252
- "jobId": self.job_id,
1253
- "trainingInput": {
1254
- "scaleTier": self.scale_tier,
1255
- "region": self.region,
1256
- },
1257
- }
1258
- if self.package_uris:
1259
- training_request["trainingInput"]["packageUris"] = self.package_uris
1260
-
1261
- if self.training_python_module:
1262
- training_request["trainingInput"]["pythonModule"] = self.training_python_module
1263
-
1264
- if self.training_args:
1265
- training_request["trainingInput"]["args"] = self.training_args
1266
-
1267
- if self.master_type:
1268
- training_request["trainingInput"]["masterType"] = self.master_type
1269
-
1270
- if self.master_config:
1271
- training_request["trainingInput"]["masterConfig"] = self.master_config
1272
-
1273
- if self.runtime_version:
1274
- training_request["trainingInput"]["runtimeVersion"] = self.runtime_version
1275
-
1276
- if self.python_version:
1277
- training_request["trainingInput"]["pythonVersion"] = self.python_version
1278
-
1279
- if self.job_dir:
1280
- training_request["trainingInput"]["jobDir"] = self.job_dir
1281
-
1282
- if self.service_account:
1283
- training_request["trainingInput"]["serviceAccount"] = self.service_account
1284
-
1285
- if self.hyperparameters:
1286
- training_request["trainingInput"]["hyperparameters"] = self.hyperparameters
1287
-
1288
- if self._labels:
1289
- training_request["labels"] = self._labels
1290
-
1291
- if self._mode == "DRY_RUN":
1292
- self.log.info("In dry_run mode.")
1293
- self.log.info("MLEngine Training job request is: %s", training_request)
1294
- return
1295
-
1296
- hook = MLEngineHook(
1297
- gcp_conn_id=self._gcp_conn_id,
1298
- impersonation_chain=self.impersonation_chain,
1299
- )
1300
- self.hook = hook
1301
-
1302
- try:
1303
- self.log.info("Executing: %s'", training_request)
1304
- self.job_id = self.hook.create_job_without_waiting_result(
1305
- project_id=self.project_id,
1306
- body=training_request,
1307
- )
1308
- except HttpError as e:
1309
- if e.resp.status == 409:
1310
- # If the job already exists retrieve it
1311
- self.hook.get_job(project_id=self.project_id, job_id=self.job_id)
1312
- if self.project_id:
1313
- MLEngineJobDetailsLink.persist(
1314
- context=context,
1315
- task_instance=self,
1316
- project_id=self.project_id,
1317
- job_id=self.job_id,
1318
- )
1319
- self.log.error(
1320
- "Failed to create new job with given name since it already exists. "
1321
- "The existing one will be used."
1322
- )
1323
- else:
1324
- raise e
1325
-
1326
- context["ti"].xcom_push(key="job_id", value=self.job_id)
1327
- if self.deferrable:
1328
- self.defer(
1329
- timeout=self.execution_timeout,
1330
- trigger=MLEngineStartTrainingJobTrigger(
1331
- conn_id=self._gcp_conn_id,
1332
- job_id=self.job_id,
1333
- project_id=self.project_id,
1334
- region=self.region,
1335
- runtime_version=self.runtime_version,
1336
- python_version=self.python_version,
1337
- job_dir=self.job_dir,
1338
- package_uris=self.package_uris,
1339
- training_python_module=self.training_python_module,
1340
- training_args=self.training_args,
1341
- labels=self._labels,
1342
- gcp_conn_id=self._gcp_conn_id,
1343
- impersonation_chain=self.impersonation_chain,
1344
- ),
1345
- method_name="execute_complete",
1346
- )
1347
- else:
1348
- finished_training_job = self._wait_for_job_done(self.project_id, self.job_id)
1349
- self._handle_job_error(finished_training_job)
1350
- gcp_metadata = {
1351
- "job_id": self.job_id,
1352
- "project_id": self.project_id,
1353
- }
1354
- context["task_instance"].xcom_push("gcp_metadata", gcp_metadata)
1355
-
1356
- project_id = self.project_id or hook.project_id
1357
- if project_id:
1358
- MLEngineJobDetailsLink.persist(
1359
- context=context,
1360
- task_instance=self,
1361
- project_id=project_id,
1362
- job_id=job_id,
1363
- )
1364
-
1365
- def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30):
1366
- """
1367
- Wait for the Job to reach a terminal state.
1368
-
1369
- This method will periodically check the job state until the job reach
1370
- a terminal state.
1371
-
1372
- :param project_id: The project in which the Job is located. If set to None or missing, the default
1373
- project_id from the Google Cloud connection is used. (templated)
1374
- :param job_id: A unique id for the Google MLEngine job. (templated)
1375
- :param interval: Time expressed in seconds after which the job status is checked again. (templated)
1376
- :raises: googleapiclient.errors.HttpError
1377
- """
1378
- self.log.info("Waiting for job. job_id=%s", job_id)
1379
-
1380
- if interval <= 0:
1381
- raise ValueError("Interval must be > 0")
1382
- while True:
1383
- job = self.hook.get_job(project_id, job_id)
1384
- if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]:
1385
- return job
1386
- time.sleep(interval)
1387
-
1388
- def execute_complete(self, context: Context, event: dict[str, Any]):
1389
- """
1390
- Act as a callback for when the trigger fires - returns immediately.
1391
-
1392
- Relies on trigger to throw an exception, otherwise it assumes execution was successful.
1393
- """
1394
- if event["status"] == "error":
1395
- raise AirflowException(event["message"])
1396
- self.log.info(
1397
- "%s completed with response %s ",
1398
- self.task_id,
1399
- event["message"],
1400
- )
1401
- if self.project_id:
1402
- MLEngineJobDetailsLink.persist(
1403
- context=context,
1404
- task_instance=self,
1405
- project_id=self.project_id,
1406
- job_id=self.job_id,
1407
- )
1408
-
1409
- def on_kill(self) -> None:
1410
- if self.job_id and self.cancel_on_kill:
1411
- self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) # type: ignore[union-attr]
1412
- else:
1413
- self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.job_id)
1414
-
1415
-
1416
- @deprecated(
1417
- planned_removal_date="March 01, 2025",
1418
- use_instead="CancelCustomTrainingJobOperator",
1419
- reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.",
1420
- category=AirflowProviderDeprecationWarning,
1421
- )
1422
- class MLEngineTrainingCancelJobOperator(GoogleCloudBaseOperator):
1423
- """
1424
- Operator for cleaning up failed MLEngine training job.
1425
-
1426
- .. warning::
1427
- This operator is deprecated. Please use
1428
- :class:`airflow.providers.google.cloud.operators.vertex_ai.custom_job.CancelCustomTrainingJobOperator`
1429
- instead.
1430
-
1431
- :param job_id: A unique templated id for the submitted Google MLEngine
1432
- training job. (templated)
1433
- :param project_id: The Google Cloud project name within which MLEngine training job should run.
1434
- If set to None or missing, the default project_id from the Google Cloud connection is used.
1435
- (templated)
1436
- :param gcp_conn_id: The connection ID to use when fetching connection info.
1437
- :param impersonation_chain: Optional service account to impersonate using short-term
1438
- credentials, or chained list of accounts required to get the access_token
1439
- of the last account in the list, which will be impersonated in the request.
1440
- If set as a string, the account must grant the originating account
1441
- the Service Account Token Creator IAM role.
1442
- If set as a sequence, the identities from the list must grant
1443
- Service Account Token Creator IAM role to the directly preceding identity, with first
1444
- account from the list granting this role to the originating account (templated).
1445
- """
1446
-
1447
- template_fields: Sequence[str] = (
1448
- "project_id",
1449
- "job_id",
1450
- "impersonation_chain",
1451
- )
1452
- operator_extra_links = (MLEngineJobSListLink(),)
1453
-
1454
- def __init__(
1455
- self,
1456
- *,
1457
- job_id: str,
1458
- project_id: str = PROVIDE_PROJECT_ID,
1459
- gcp_conn_id: str = "google_cloud_default",
1460
- impersonation_chain: str | Sequence[str] | None = None,
1461
- **kwargs,
1462
- ) -> None:
1463
- super().__init__(**kwargs)
1464
- self.project_id = project_id
1465
- self.job_id = job_id
1466
- self._gcp_conn_id = gcp_conn_id
1467
- self.impersonation_chain = impersonation_chain
1468
-
1469
- @property
1470
- @deprecated(
1471
- planned_removal_date="March 01, 2025",
1472
- use_instead="project_id",
1473
- category=AirflowProviderDeprecationWarning,
1474
- )
1475
- def _project_id(self):
1476
- """Alias for ``project_id``, used for compatibility (deprecated)."""
1477
- return self.project_id
1478
-
1479
- @property
1480
- @deprecated(
1481
- planned_removal_date="March 01, 2025",
1482
- use_instead="job_id",
1483
- category=AirflowProviderDeprecationWarning,
1484
- )
1485
- def _job_id(self):
1486
- """Alias for ``job_id``, used for compatibility (deprecated)."""
1487
- return self.job_id
1488
-
1489
- @property
1490
- @deprecated(
1491
- planned_removal_date="March 01, 2025",
1492
- use_instead="impersonation_chain",
1493
- category=AirflowProviderDeprecationWarning,
1494
- )
1495
- def _impersonation_chain(self):
1496
- """Alias for ``impersonation_chain``, used for compatibility (deprecated)."""
1497
- return self.impersonation_chain
1498
-
1499
- def execute(self, context: Context):
1500
- if not self.project_id:
1501
- raise AirflowException("Google Cloud project id is required.")
1502
- hook = MLEngineHook(
1503
- gcp_conn_id=self._gcp_conn_id,
1504
- impersonation_chain=self.impersonation_chain,
1505
- )
1506
-
1507
- project_id = self.project_id or hook.project_id
1508
- if project_id:
1509
- MLEngineJobSListLink.persist(
1510
- context=context,
1511
- task_instance=self,
1512
- project_id=project_id,
1513
- )
1514
-
1515
- hook.cancel_job(project_id=self.project_id, job_id=_normalize_mlengine_job_id(self.job_id))