apache-airflow-providers-google 10.16.0rc1__py3-none-any.whl → 10.17.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +5 -4
- airflow/providers/google/ads/operators/ads.py +1 -0
- airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +1 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +1 -0
- airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +1 -0
- airflow/providers/google/cloud/example_dags/example_looker.py +1 -0
- airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +1 -0
- airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py +1 -0
- airflow/providers/google/cloud/fs/gcs.py +1 -2
- airflow/providers/google/cloud/hooks/automl.py +1 -0
- airflow/providers/google/cloud/hooks/bigquery.py +87 -24
- airflow/providers/google/cloud/hooks/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/hooks/bigtable.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_build.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_sql.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +9 -4
- airflow/providers/google/cloud/hooks/compute.py +1 -0
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -2
- airflow/providers/google/cloud/hooks/dataflow.py +6 -5
- airflow/providers/google/cloud/hooks/datafusion.py +1 -0
- airflow/providers/google/cloud/hooks/datapipeline.py +1 -0
- airflow/providers/google/cloud/hooks/dataplex.py +1 -0
- airflow/providers/google/cloud/hooks/dataprep.py +1 -0
- airflow/providers/google/cloud/hooks/dataproc.py +3 -2
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -0
- airflow/providers/google/cloud/hooks/datastore.py +1 -0
- airflow/providers/google/cloud/hooks/dlp.py +1 -0
- airflow/providers/google/cloud/hooks/functions.py +1 -0
- airflow/providers/google/cloud/hooks/gcs.py +12 -5
- airflow/providers/google/cloud/hooks/kms.py +1 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +178 -300
- airflow/providers/google/cloud/hooks/life_sciences.py +1 -0
- airflow/providers/google/cloud/hooks/looker.py +1 -0
- airflow/providers/google/cloud/hooks/mlengine.py +1 -0
- airflow/providers/google/cloud/hooks/natural_language.py +1 -0
- airflow/providers/google/cloud/hooks/os_login.py +1 -0
- airflow/providers/google/cloud/hooks/pubsub.py +1 -0
- airflow/providers/google/cloud/hooks/secret_manager.py +1 -0
- airflow/providers/google/cloud/hooks/spanner.py +1 -0
- airflow/providers/google/cloud/hooks/speech_to_text.py +1 -0
- airflow/providers/google/cloud/hooks/stackdriver.py +1 -0
- airflow/providers/google/cloud/hooks/text_to_speech.py +1 -0
- airflow/providers/google/cloud/hooks/translate.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +255 -3
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +197 -0
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +9 -9
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +231 -12
- airflow/providers/google/cloud/hooks/video_intelligence.py +1 -0
- airflow/providers/google/cloud/hooks/vision.py +1 -0
- airflow/providers/google/cloud/links/automl.py +1 -0
- airflow/providers/google/cloud/links/bigquery.py +1 -0
- airflow/providers/google/cloud/links/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/links/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/links/cloud_sql.py +1 -0
- airflow/providers/google/cloud/links/cloud_tasks.py +1 -0
- airflow/providers/google/cloud/links/compute.py +1 -0
- airflow/providers/google/cloud/links/datacatalog.py +1 -0
- airflow/providers/google/cloud/links/dataflow.py +1 -0
- airflow/providers/google/cloud/links/dataform.py +1 -0
- airflow/providers/google/cloud/links/datafusion.py +1 -0
- airflow/providers/google/cloud/links/dataplex.py +1 -0
- airflow/providers/google/cloud/links/dataproc.py +1 -0
- airflow/providers/google/cloud/links/kubernetes_engine.py +28 -0
- airflow/providers/google/cloud/links/mlengine.py +1 -0
- airflow/providers/google/cloud/links/pubsub.py +1 -0
- airflow/providers/google/cloud/links/spanner.py +1 -0
- airflow/providers/google/cloud/links/stackdriver.py +1 -0
- airflow/providers/google/cloud/links/workflows.py +1 -0
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +18 -4
- airflow/providers/google/cloud/operators/automl.py +1 -0
- airflow/providers/google/cloud/operators/bigquery.py +21 -0
- airflow/providers/google/cloud/operators/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/operators/bigtable.py +1 -0
- airflow/providers/google/cloud/operators/cloud_base.py +1 -0
- airflow/providers/google/cloud/operators/cloud_build.py +1 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/operators/cloud_sql.py +1 -0
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +11 -5
- airflow/providers/google/cloud/operators/compute.py +1 -0
- airflow/providers/google/cloud/operators/dataflow.py +1 -0
- airflow/providers/google/cloud/operators/datafusion.py +1 -0
- airflow/providers/google/cloud/operators/datapipeline.py +1 -0
- airflow/providers/google/cloud/operators/dataprep.py +1 -0
- airflow/providers/google/cloud/operators/dataproc.py +3 -2
- airflow/providers/google/cloud/operators/dataproc_metastore.py +1 -0
- airflow/providers/google/cloud/operators/datastore.py +1 -0
- airflow/providers/google/cloud/operators/functions.py +1 -0
- airflow/providers/google/cloud/operators/gcs.py +1 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +600 -4
- airflow/providers/google/cloud/operators/life_sciences.py +1 -0
- airflow/providers/google/cloud/operators/looker.py +1 -0
- airflow/providers/google/cloud/operators/mlengine.py +283 -259
- airflow/providers/google/cloud/operators/natural_language.py +1 -0
- airflow/providers/google/cloud/operators/pubsub.py +1 -0
- airflow/providers/google/cloud/operators/spanner.py +1 -0
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -0
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -0
- airflow/providers/google/cloud/operators/translate.py +1 -0
- airflow/providers/google/cloud/operators/translate_speech.py +1 -0
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +14 -7
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +67 -13
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +26 -8
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +1 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +306 -0
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +29 -48
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +52 -17
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -0
- airflow/providers/google/cloud/operators/vision.py +1 -0
- airflow/providers/google/cloud/secrets/secret_manager.py +1 -0
- airflow/providers/google/cloud/sensors/bigquery.py +1 -0
- airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/sensors/bigtable.py +1 -0
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -0
- airflow/providers/google/cloud/sensors/dataflow.py +1 -0
- airflow/providers/google/cloud/sensors/dataform.py +1 -0
- airflow/providers/google/cloud/sensors/datafusion.py +1 -0
- airflow/providers/google/cloud/sensors/dataplex.py +1 -0
- airflow/providers/google/cloud/sensors/dataprep.py +1 -0
- airflow/providers/google/cloud/sensors/dataproc.py +1 -0
- airflow/providers/google/cloud/sensors/gcs.py +1 -0
- airflow/providers/google/cloud/sensors/looker.py +1 -0
- airflow/providers/google/cloud/sensors/pubsub.py +1 -0
- airflow/providers/google/cloud/sensors/tasks.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +3 -2
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +19 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -5
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +4 -2
- airflow/providers/google/cloud/triggers/bigquery.py +4 -3
- airflow/providers/google/cloud/triggers/cloud_batch.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -0
- airflow/providers/google/cloud/triggers/cloud_sql.py +2 -0
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +14 -2
- airflow/providers/google/cloud/triggers/dataplex.py +1 -0
- airflow/providers/google/cloud/triggers/dataproc.py +1 -0
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +72 -2
- airflow/providers/google/cloud/triggers/mlengine.py +2 -0
- airflow/providers/google/cloud/triggers/pubsub.py +3 -3
- airflow/providers/google/cloud/triggers/vertex_ai.py +107 -15
- airflow/providers/google/cloud/utils/field_sanitizer.py +2 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -0
- airflow/providers/google/cloud/utils/helpers.py +1 -0
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -0
- airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -0
- airflow/providers/google/cloud/utils/openlineage.py +1 -0
- airflow/providers/google/common/auth_backend/google_openid.py +1 -0
- airflow/providers/google/common/hooks/base_google.py +2 -1
- airflow/providers/google/common/hooks/discovery_api.py +1 -0
- airflow/providers/google/common/links/storage.py +1 -0
- airflow/providers/google/common/utils/id_token_credentials.py +1 -0
- airflow/providers/google/firebase/hooks/firestore.py +1 -0
- airflow/providers/google/get_provider_info.py +9 -3
- airflow/providers/google/go_module_utils.py +1 -0
- airflow/providers/google/leveldb/hooks/leveldb.py +8 -7
- airflow/providers/google/marketing_platform/example_dags/example_display_video.py +1 -0
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +1 -0
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/hooks/display_video.py +1 -0
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -0
- airflow/providers/google/marketing_platform/operators/analytics.py +1 -0
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +4 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/operators/display_video.py +1 -0
- airflow/providers/google/marketing_platform/operators/search_ads.py +1 -0
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/sensors/display_video.py +2 -1
- airflow/providers/google/marketing_platform/sensors/search_ads.py +1 -0
- airflow/providers/google/suite/hooks/calendar.py +1 -0
- airflow/providers/google/suite/hooks/drive.py +1 -0
- airflow/providers/google/suite/hooks/sheets.py +1 -0
- airflow/providers/google/suite/sensors/drive.py +1 -0
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +7 -0
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +4 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +1 -0
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/METADATA +22 -17
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/RECORD +196 -194
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/entry_points.txt +0 -0
@@ -111,7 +111,7 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
111
111
|
In options 2 and 3, both model and version name should contain the
|
112
112
|
minimal identifier. For instance, call::
|
113
113
|
|
114
|
-
|
114
|
+
MLEngineStartBatchPredictionJobOperator(
|
115
115
|
...,
|
116
116
|
model_name='my_model',
|
117
117
|
version_name='my_version',
|
@@ -173,15 +173,15 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
173
173
|
"""
|
174
174
|
|
175
175
|
template_fields: Sequence[str] = (
|
176
|
-
"
|
177
|
-
"
|
178
|
-
"
|
179
|
-
"
|
180
|
-
"
|
181
|
-
"
|
182
|
-
"
|
183
|
-
"
|
184
|
-
"
|
176
|
+
"project_id",
|
177
|
+
"job_id",
|
178
|
+
"region",
|
179
|
+
"input_paths",
|
180
|
+
"output_path",
|
181
|
+
"model_name",
|
182
|
+
"version_name",
|
183
|
+
"uri",
|
184
|
+
"impersonation_chain",
|
185
185
|
)
|
186
186
|
|
187
187
|
def __init__(
|
@@ -206,67 +206,66 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
206
206
|
) -> None:
|
207
207
|
super().__init__(**kwargs)
|
208
208
|
|
209
|
-
self.
|
210
|
-
self.
|
211
|
-
self.
|
209
|
+
self.project_id = project_id
|
210
|
+
self.job_id = job_id
|
211
|
+
self.region = region
|
212
212
|
self._data_format = data_format
|
213
|
-
self.
|
214
|
-
self.
|
215
|
-
self.
|
216
|
-
self.
|
217
|
-
self.
|
213
|
+
self.input_paths = input_paths
|
214
|
+
self.output_path = output_path
|
215
|
+
self.model_name = model_name
|
216
|
+
self.version_name = version_name
|
217
|
+
self.uri = uri
|
218
218
|
self._max_worker_count = max_worker_count
|
219
219
|
self._runtime_version = runtime_version
|
220
220
|
self._signature_name = signature_name
|
221
221
|
self._gcp_conn_id = gcp_conn_id
|
222
222
|
self._labels = labels
|
223
|
-
self.
|
223
|
+
self.impersonation_chain = impersonation_chain
|
224
224
|
|
225
|
-
|
225
|
+
def execute(self, context: Context):
|
226
|
+
if not self.project_id:
|
226
227
|
raise AirflowException("Google Cloud project id is required.")
|
227
|
-
if not self.
|
228
|
+
if not self.job_id:
|
228
229
|
raise AirflowException("An unique job id is required for Google MLEngine prediction job.")
|
229
230
|
|
230
|
-
if self.
|
231
|
-
if self.
|
231
|
+
if self.uri:
|
232
|
+
if self.model_name or self.version_name:
|
232
233
|
raise AirflowException(
|
233
234
|
"Ambiguous model origin: Both uri and model/version name are provided."
|
234
235
|
)
|
235
236
|
|
236
|
-
if self.
|
237
|
+
if self.version_name and not self.model_name:
|
237
238
|
raise AirflowException(
|
238
239
|
"Missing model: Batch prediction expects a model name when a version name is provided."
|
239
240
|
)
|
240
241
|
|
241
|
-
if not (self.
|
242
|
+
if not (self.uri or self.model_name):
|
242
243
|
raise AirflowException(
|
243
244
|
"Missing model origin: Batch prediction expects a model, "
|
244
245
|
"a model & version combination, or a URI to a savedModel."
|
245
246
|
)
|
246
|
-
|
247
|
-
def execute(self, context: Context):
|
248
|
-
job_id = _normalize_mlengine_job_id(self._job_id)
|
247
|
+
job_id = _normalize_mlengine_job_id(self.job_id)
|
249
248
|
prediction_request: dict[str, Any] = {
|
250
249
|
"jobId": job_id,
|
251
250
|
"predictionInput": {
|
252
251
|
"dataFormat": self._data_format,
|
253
|
-
"inputPaths": self.
|
254
|
-
"outputPath": self.
|
255
|
-
"region": self.
|
252
|
+
"inputPaths": self.input_paths,
|
253
|
+
"outputPath": self.output_path,
|
254
|
+
"region": self.region,
|
256
255
|
},
|
257
256
|
}
|
258
257
|
if self._labels:
|
259
258
|
prediction_request["labels"] = self._labels
|
260
259
|
|
261
|
-
if self.
|
262
|
-
prediction_request["predictionInput"]["uri"] = self.
|
263
|
-
elif self.
|
264
|
-
origin_name = f"projects/{self.
|
265
|
-
if not self.
|
260
|
+
if self.uri:
|
261
|
+
prediction_request["predictionInput"]["uri"] = self.uri
|
262
|
+
elif self.model_name:
|
263
|
+
origin_name = f"projects/{self.project_id}/models/{self.model_name}"
|
264
|
+
if not self.version_name:
|
266
265
|
prediction_request["predictionInput"]["modelName"] = origin_name
|
267
266
|
else:
|
268
267
|
prediction_request["predictionInput"]["versionName"] = (
|
269
|
-
origin_name + f"/versions/{self.
|
268
|
+
origin_name + f"/versions/{self.version_name}"
|
270
269
|
)
|
271
270
|
|
272
271
|
if self._max_worker_count:
|
@@ -278,7 +277,7 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
278
277
|
if self._signature_name:
|
279
278
|
prediction_request["predictionInput"]["signatureName"] = self._signature_name
|
280
279
|
|
281
|
-
hook = MLEngineHook(gcp_conn_id=self._gcp_conn_id, impersonation_chain=self.
|
280
|
+
hook = MLEngineHook(gcp_conn_id=self._gcp_conn_id, impersonation_chain=self.impersonation_chain)
|
282
281
|
|
283
282
|
# Helper method to check if the existing job's prediction input is the
|
284
283
|
# same as the request we get here.
|
@@ -286,7 +285,7 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
286
285
|
return existing_job.get("predictionInput") == prediction_request["predictionInput"]
|
287
286
|
|
288
287
|
finished_prediction_job = hook.create_job(
|
289
|
-
project_id=self.
|
288
|
+
project_id=self.project_id, job=prediction_request, use_existing_job_fn=check_existing_job
|
290
289
|
)
|
291
290
|
|
292
291
|
if finished_prediction_job["state"] != "SUCCEEDED":
|
@@ -336,9 +335,9 @@ class MLEngineManageModelOperator(GoogleCloudBaseOperator):
|
|
336
335
|
"""
|
337
336
|
|
338
337
|
template_fields: Sequence[str] = (
|
339
|
-
"
|
340
|
-
"
|
341
|
-
"
|
338
|
+
"project_id",
|
339
|
+
"model",
|
340
|
+
"impersonation_chain",
|
342
341
|
)
|
343
342
|
|
344
343
|
def __init__(
|
@@ -352,21 +351,21 @@ class MLEngineManageModelOperator(GoogleCloudBaseOperator):
|
|
352
351
|
**kwargs,
|
353
352
|
) -> None:
|
354
353
|
super().__init__(**kwargs)
|
355
|
-
self.
|
356
|
-
self.
|
354
|
+
self.project_id = project_id
|
355
|
+
self.model = model
|
357
356
|
self._operation = operation
|
358
357
|
self._gcp_conn_id = gcp_conn_id
|
359
|
-
self.
|
358
|
+
self.impersonation_chain = impersonation_chain
|
360
359
|
|
361
360
|
def execute(self, context: Context):
|
362
361
|
hook = MLEngineHook(
|
363
362
|
gcp_conn_id=self._gcp_conn_id,
|
364
|
-
impersonation_chain=self.
|
363
|
+
impersonation_chain=self.impersonation_chain,
|
365
364
|
)
|
366
365
|
if self._operation == "create":
|
367
|
-
return hook.create_model(project_id=self.
|
366
|
+
return hook.create_model(project_id=self.project_id, model=self.model)
|
368
367
|
elif self._operation == "get":
|
369
|
-
return hook.get_model(project_id=self.
|
368
|
+
return hook.get_model(project_id=self.project_id, model_name=self.model["name"])
|
370
369
|
else:
|
371
370
|
raise ValueError(f"Unknown operation: {self._operation}")
|
372
371
|
|
@@ -408,9 +407,9 @@ class MLEngineCreateModelOperator(GoogleCloudBaseOperator):
|
|
408
407
|
"""
|
409
408
|
|
410
409
|
template_fields: Sequence[str] = (
|
411
|
-
"
|
412
|
-
"
|
413
|
-
"
|
410
|
+
"project_id",
|
411
|
+
"model",
|
412
|
+
"impersonation_chain",
|
414
413
|
)
|
415
414
|
operator_extra_links = (MLEngineModelLink(),)
|
416
415
|
|
@@ -424,27 +423,27 @@ class MLEngineCreateModelOperator(GoogleCloudBaseOperator):
|
|
424
423
|
**kwargs,
|
425
424
|
) -> None:
|
426
425
|
super().__init__(**kwargs)
|
427
|
-
self.
|
428
|
-
self.
|
426
|
+
self.project_id = project_id
|
427
|
+
self.model = model
|
429
428
|
self._gcp_conn_id = gcp_conn_id
|
430
|
-
self.
|
429
|
+
self.impersonation_chain = impersonation_chain
|
431
430
|
|
432
431
|
def execute(self, context: Context):
|
433
432
|
hook = MLEngineHook(
|
434
433
|
gcp_conn_id=self._gcp_conn_id,
|
435
|
-
impersonation_chain=self.
|
434
|
+
impersonation_chain=self.impersonation_chain,
|
436
435
|
)
|
437
436
|
|
438
|
-
project_id = self.
|
437
|
+
project_id = self.project_id or hook.project_id
|
439
438
|
if project_id:
|
440
439
|
MLEngineModelLink.persist(
|
441
440
|
context=context,
|
442
441
|
task_instance=self,
|
443
442
|
project_id=project_id,
|
444
|
-
model_id=self.
|
443
|
+
model_id=self.model["name"],
|
445
444
|
)
|
446
445
|
|
447
|
-
return hook.create_model(project_id=self.
|
446
|
+
return hook.create_model(project_id=self.project_id, model=self.model)
|
448
447
|
|
449
448
|
|
450
449
|
@deprecated(
|
@@ -484,9 +483,9 @@ class MLEngineGetModelOperator(GoogleCloudBaseOperator):
|
|
484
483
|
"""
|
485
484
|
|
486
485
|
template_fields: Sequence[str] = (
|
487
|
-
"
|
488
|
-
"
|
489
|
-
"
|
486
|
+
"project_id",
|
487
|
+
"model_name",
|
488
|
+
"impersonation_chain",
|
490
489
|
)
|
491
490
|
operator_extra_links = (MLEngineModelLink(),)
|
492
491
|
|
@@ -500,26 +499,26 @@ class MLEngineGetModelOperator(GoogleCloudBaseOperator):
|
|
500
499
|
**kwargs,
|
501
500
|
) -> None:
|
502
501
|
super().__init__(**kwargs)
|
503
|
-
self.
|
504
|
-
self.
|
502
|
+
self.project_id = project_id
|
503
|
+
self.model_name = model_name
|
505
504
|
self._gcp_conn_id = gcp_conn_id
|
506
|
-
self.
|
505
|
+
self.impersonation_chain = impersonation_chain
|
507
506
|
|
508
507
|
def execute(self, context: Context):
|
509
508
|
hook = MLEngineHook(
|
510
509
|
gcp_conn_id=self._gcp_conn_id,
|
511
|
-
impersonation_chain=self.
|
510
|
+
impersonation_chain=self.impersonation_chain,
|
512
511
|
)
|
513
|
-
project_id = self.
|
512
|
+
project_id = self.project_id or hook.project_id
|
514
513
|
if project_id:
|
515
514
|
MLEngineModelLink.persist(
|
516
515
|
context=context,
|
517
516
|
task_instance=self,
|
518
517
|
project_id=project_id,
|
519
|
-
model_id=self.
|
518
|
+
model_id=self.model_name,
|
520
519
|
)
|
521
520
|
|
522
|
-
return hook.get_model(project_id=self.
|
521
|
+
return hook.get_model(project_id=self.project_id, model_name=self.model_name)
|
523
522
|
|
524
523
|
|
525
524
|
@deprecated(
|
@@ -563,9 +562,9 @@ class MLEngineDeleteModelOperator(GoogleCloudBaseOperator):
|
|
563
562
|
"""
|
564
563
|
|
565
564
|
template_fields: Sequence[str] = (
|
566
|
-
"
|
567
|
-
"
|
568
|
-
"
|
565
|
+
"project_id",
|
566
|
+
"model_name",
|
567
|
+
"impersonation_chain",
|
569
568
|
)
|
570
569
|
operator_extra_links = (MLEngineModelsListLink(),)
|
571
570
|
|
@@ -580,19 +579,19 @@ class MLEngineDeleteModelOperator(GoogleCloudBaseOperator):
|
|
580
579
|
**kwargs,
|
581
580
|
) -> None:
|
582
581
|
super().__init__(**kwargs)
|
583
|
-
self.
|
584
|
-
self.
|
582
|
+
self.project_id = project_id
|
583
|
+
self.model_name = model_name
|
585
584
|
self._delete_contents = delete_contents
|
586
585
|
self._gcp_conn_id = gcp_conn_id
|
587
|
-
self.
|
586
|
+
self.impersonation_chain = impersonation_chain
|
588
587
|
|
589
588
|
def execute(self, context: Context):
|
590
589
|
hook = MLEngineHook(
|
591
590
|
gcp_conn_id=self._gcp_conn_id,
|
592
|
-
impersonation_chain=self.
|
591
|
+
impersonation_chain=self.impersonation_chain,
|
593
592
|
)
|
594
593
|
|
595
|
-
project_id = self.
|
594
|
+
project_id = self.project_id or hook.project_id
|
596
595
|
if project_id:
|
597
596
|
MLEngineModelsListLink.persist(
|
598
597
|
context=context,
|
@@ -601,7 +600,7 @@ class MLEngineDeleteModelOperator(GoogleCloudBaseOperator):
|
|
601
600
|
)
|
602
601
|
|
603
602
|
return hook.delete_model(
|
604
|
-
project_id=self.
|
603
|
+
project_id=self.project_id, model_name=self.model_name, delete_contents=self._delete_contents
|
605
604
|
)
|
606
605
|
|
607
606
|
|
@@ -667,11 +666,11 @@ class MLEngineManageVersionOperator(GoogleCloudBaseOperator):
|
|
667
666
|
"""
|
668
667
|
|
669
668
|
template_fields: Sequence[str] = (
|
670
|
-
"
|
671
|
-
"
|
672
|
-
"
|
673
|
-
"
|
674
|
-
"
|
669
|
+
"project_id",
|
670
|
+
"model_name",
|
671
|
+
"version_name",
|
672
|
+
"version",
|
673
|
+
"impersonation_chain",
|
675
674
|
)
|
676
675
|
|
677
676
|
def __init__(
|
@@ -687,38 +686,38 @@ class MLEngineManageVersionOperator(GoogleCloudBaseOperator):
|
|
687
686
|
**kwargs,
|
688
687
|
) -> None:
|
689
688
|
super().__init__(**kwargs)
|
690
|
-
self.
|
691
|
-
self.
|
692
|
-
self.
|
693
|
-
self.
|
689
|
+
self.project_id = project_id
|
690
|
+
self.model_name = model_name
|
691
|
+
self.version_name = version_name
|
692
|
+
self.version = version or {}
|
694
693
|
self._operation = operation
|
695
694
|
self._gcp_conn_id = gcp_conn_id
|
696
|
-
self.
|
695
|
+
self.impersonation_chain = impersonation_chain
|
697
696
|
|
698
697
|
def execute(self, context: Context):
|
699
|
-
if "name" not in self.
|
700
|
-
self.
|
698
|
+
if "name" not in self.version:
|
699
|
+
self.version["name"] = self.version_name
|
701
700
|
|
702
701
|
hook = MLEngineHook(
|
703
702
|
gcp_conn_id=self._gcp_conn_id,
|
704
|
-
impersonation_chain=self.
|
703
|
+
impersonation_chain=self.impersonation_chain,
|
705
704
|
)
|
706
705
|
|
707
706
|
if self._operation == "create":
|
708
|
-
if not self.
|
707
|
+
if not self.version:
|
709
708
|
raise ValueError(f"version attribute of {self.__class__.__name__} could not be empty")
|
710
709
|
return hook.create_version(
|
711
|
-
project_id=self.
|
710
|
+
project_id=self.project_id, model_name=self.model_name, version_spec=self.version
|
712
711
|
)
|
713
712
|
elif self._operation == "set_default":
|
714
713
|
return hook.set_default_version(
|
715
|
-
project_id=self.
|
714
|
+
project_id=self.project_id, model_name=self.model_name, version_name=self.version["name"]
|
716
715
|
)
|
717
716
|
elif self._operation == "list":
|
718
|
-
return hook.list_versions(project_id=self.
|
717
|
+
return hook.list_versions(project_id=self.project_id, model_name=self.model_name)
|
719
718
|
elif self._operation == "delete":
|
720
719
|
return hook.delete_version(
|
721
|
-
project_id=self.
|
720
|
+
project_id=self.project_id, model_name=self.model_name, version_name=self.version["name"]
|
722
721
|
)
|
723
722
|
else:
|
724
723
|
raise ValueError(f"Unknown operation: {self._operation}")
|
@@ -762,10 +761,10 @@ class MLEngineCreateVersionOperator(GoogleCloudBaseOperator):
|
|
762
761
|
"""
|
763
762
|
|
764
763
|
template_fields: Sequence[str] = (
|
765
|
-
"
|
766
|
-
"
|
767
|
-
"
|
768
|
-
"
|
764
|
+
"project_id",
|
765
|
+
"model_name",
|
766
|
+
"version",
|
767
|
+
"impersonation_chain",
|
769
768
|
)
|
770
769
|
operator_extra_links = (MLEngineModelVersionDetailsLink(),)
|
771
770
|
|
@@ -780,38 +779,38 @@ class MLEngineCreateVersionOperator(GoogleCloudBaseOperator):
|
|
780
779
|
**kwargs,
|
781
780
|
) -> None:
|
782
781
|
super().__init__(**kwargs)
|
783
|
-
self.
|
784
|
-
self.
|
785
|
-
self.
|
782
|
+
self.project_id = project_id
|
783
|
+
self.model_name = model_name
|
784
|
+
self.version = version
|
786
785
|
self._gcp_conn_id = gcp_conn_id
|
787
|
-
self.
|
788
|
-
self._validate_inputs()
|
786
|
+
self.impersonation_chain = impersonation_chain
|
789
787
|
|
790
788
|
def _validate_inputs(self):
|
791
|
-
if not self.
|
789
|
+
if not self.model_name:
|
792
790
|
raise AirflowException("The model_name parameter could not be empty.")
|
793
791
|
|
794
|
-
if not self.
|
792
|
+
if not self.version:
|
795
793
|
raise AirflowException("The version parameter could not be empty.")
|
796
794
|
|
797
795
|
def execute(self, context: Context):
|
796
|
+
self._validate_inputs()
|
798
797
|
hook = MLEngineHook(
|
799
798
|
gcp_conn_id=self._gcp_conn_id,
|
800
|
-
impersonation_chain=self.
|
799
|
+
impersonation_chain=self.impersonation_chain,
|
801
800
|
)
|
802
801
|
|
803
|
-
project_id = self.
|
802
|
+
project_id = self.project_id or hook.project_id
|
804
803
|
if project_id:
|
805
804
|
MLEngineModelVersionDetailsLink.persist(
|
806
805
|
context=context,
|
807
806
|
task_instance=self,
|
808
807
|
project_id=project_id,
|
809
|
-
model_id=self.
|
810
|
-
version_id=self.
|
808
|
+
model_id=self.model_name,
|
809
|
+
version_id=self.version["name"],
|
811
810
|
)
|
812
811
|
|
813
812
|
return hook.create_version(
|
814
|
-
project_id=self.
|
813
|
+
project_id=self.project_id, model_name=self.model_name, version_spec=self.version
|
815
814
|
)
|
816
815
|
|
817
816
|
|
@@ -855,10 +854,10 @@ class MLEngineSetDefaultVersionOperator(GoogleCloudBaseOperator):
|
|
855
854
|
"""
|
856
855
|
|
857
856
|
template_fields: Sequence[str] = (
|
858
|
-
"
|
859
|
-
"
|
860
|
-
"
|
861
|
-
"
|
857
|
+
"project_id",
|
858
|
+
"model_name",
|
859
|
+
"version_name",
|
860
|
+
"impersonation_chain",
|
862
861
|
)
|
863
862
|
operator_extra_links = (MLEngineModelVersionDetailsLink(),)
|
864
863
|
|
@@ -873,38 +872,38 @@ class MLEngineSetDefaultVersionOperator(GoogleCloudBaseOperator):
|
|
873
872
|
**kwargs,
|
874
873
|
) -> None:
|
875
874
|
super().__init__(**kwargs)
|
876
|
-
self.
|
877
|
-
self.
|
878
|
-
self.
|
875
|
+
self.project_id = project_id
|
876
|
+
self.model_name = model_name
|
877
|
+
self.version_name = version_name
|
879
878
|
self._gcp_conn_id = gcp_conn_id
|
880
|
-
self.
|
881
|
-
self._validate_inputs()
|
879
|
+
self.impersonation_chain = impersonation_chain
|
882
880
|
|
883
881
|
def _validate_inputs(self):
|
884
|
-
if not self.
|
882
|
+
if not self.model_name:
|
885
883
|
raise AirflowException("The model_name parameter could not be empty.")
|
886
884
|
|
887
|
-
if not self.
|
885
|
+
if not self.version_name:
|
888
886
|
raise AirflowException("The version_name parameter could not be empty.")
|
889
887
|
|
890
888
|
def execute(self, context: Context):
|
889
|
+
self._validate_inputs()
|
891
890
|
hook = MLEngineHook(
|
892
891
|
gcp_conn_id=self._gcp_conn_id,
|
893
|
-
impersonation_chain=self.
|
892
|
+
impersonation_chain=self.impersonation_chain,
|
894
893
|
)
|
895
894
|
|
896
|
-
project_id = self.
|
895
|
+
project_id = self.project_id or hook.project_id
|
897
896
|
if project_id:
|
898
897
|
MLEngineModelVersionDetailsLink.persist(
|
899
898
|
context=context,
|
900
899
|
task_instance=self,
|
901
900
|
project_id=project_id,
|
902
|
-
model_id=self.
|
903
|
-
version_id=self.
|
901
|
+
model_id=self.model_name,
|
902
|
+
version_id=self.version_name,
|
904
903
|
)
|
905
904
|
|
906
905
|
return hook.set_default_version(
|
907
|
-
project_id=self.
|
906
|
+
project_id=self.project_id, model_name=self.model_name, version_name=self.version_name
|
908
907
|
)
|
909
908
|
|
910
909
|
|
@@ -947,9 +946,9 @@ class MLEngineListVersionsOperator(GoogleCloudBaseOperator):
|
|
947
946
|
"""
|
948
947
|
|
949
948
|
template_fields: Sequence[str] = (
|
950
|
-
"
|
951
|
-
"
|
952
|
-
"
|
949
|
+
"project_id",
|
950
|
+
"model_name",
|
951
|
+
"impersonation_chain",
|
953
952
|
)
|
954
953
|
operator_extra_links = (MLEngineModelLink(),)
|
955
954
|
|
@@ -963,34 +962,34 @@ class MLEngineListVersionsOperator(GoogleCloudBaseOperator):
|
|
963
962
|
**kwargs,
|
964
963
|
) -> None:
|
965
964
|
super().__init__(**kwargs)
|
966
|
-
self.
|
967
|
-
self.
|
965
|
+
self.project_id = project_id
|
966
|
+
self.model_name = model_name
|
968
967
|
self._gcp_conn_id = gcp_conn_id
|
969
|
-
self.
|
970
|
-
self._validate_inputs()
|
968
|
+
self.impersonation_chain = impersonation_chain
|
971
969
|
|
972
970
|
def _validate_inputs(self):
|
973
|
-
if not self.
|
971
|
+
if not self.model_name:
|
974
972
|
raise AirflowException("The model_name parameter could not be empty.")
|
975
973
|
|
976
974
|
def execute(self, context: Context):
|
975
|
+
self._validate_inputs()
|
977
976
|
hook = MLEngineHook(
|
978
977
|
gcp_conn_id=self._gcp_conn_id,
|
979
|
-
impersonation_chain=self.
|
978
|
+
impersonation_chain=self.impersonation_chain,
|
980
979
|
)
|
981
980
|
|
982
|
-
project_id = self.
|
981
|
+
project_id = self.project_id or hook.project_id
|
983
982
|
if project_id:
|
984
983
|
MLEngineModelLink.persist(
|
985
984
|
context=context,
|
986
985
|
task_instance=self,
|
987
986
|
project_id=project_id,
|
988
|
-
model_id=self.
|
987
|
+
model_id=self.model_name,
|
989
988
|
)
|
990
989
|
|
991
990
|
return hook.list_versions(
|
992
|
-
project_id=self.
|
993
|
-
model_name=self.
|
991
|
+
project_id=self.project_id,
|
992
|
+
model_name=self.model_name,
|
994
993
|
)
|
995
994
|
|
996
995
|
|
@@ -1034,10 +1033,10 @@ class MLEngineDeleteVersionOperator(GoogleCloudBaseOperator):
|
|
1034
1033
|
"""
|
1035
1034
|
|
1036
1035
|
template_fields: Sequence[str] = (
|
1037
|
-
"
|
1038
|
-
"
|
1039
|
-
"
|
1040
|
-
"
|
1036
|
+
"project_id",
|
1037
|
+
"model_name",
|
1038
|
+
"version_name",
|
1039
|
+
"impersonation_chain",
|
1041
1040
|
)
|
1042
1041
|
operator_extra_links = (MLEngineModelLink(),)
|
1043
1042
|
|
@@ -1052,37 +1051,37 @@ class MLEngineDeleteVersionOperator(GoogleCloudBaseOperator):
|
|
1052
1051
|
**kwargs,
|
1053
1052
|
) -> None:
|
1054
1053
|
super().__init__(**kwargs)
|
1055
|
-
self.
|
1056
|
-
self.
|
1057
|
-
self.
|
1054
|
+
self.project_id = project_id
|
1055
|
+
self.model_name = model_name
|
1056
|
+
self.version_name = version_name
|
1058
1057
|
self._gcp_conn_id = gcp_conn_id
|
1059
|
-
self.
|
1060
|
-
self._validate_inputs()
|
1058
|
+
self.impersonation_chain = impersonation_chain
|
1061
1059
|
|
1062
1060
|
def _validate_inputs(self):
|
1063
|
-
if not self.
|
1061
|
+
if not self.model_name:
|
1064
1062
|
raise AirflowException("The model_name parameter could not be empty.")
|
1065
1063
|
|
1066
|
-
if not self.
|
1064
|
+
if not self.version_name:
|
1067
1065
|
raise AirflowException("The version_name parameter could not be empty.")
|
1068
1066
|
|
1069
1067
|
def execute(self, context: Context):
|
1068
|
+
self._validate_inputs()
|
1070
1069
|
hook = MLEngineHook(
|
1071
1070
|
gcp_conn_id=self._gcp_conn_id,
|
1072
|
-
impersonation_chain=self.
|
1071
|
+
impersonation_chain=self.impersonation_chain,
|
1073
1072
|
)
|
1074
1073
|
|
1075
|
-
project_id = self.
|
1074
|
+
project_id = self.project_id or hook.project_id
|
1076
1075
|
if project_id:
|
1077
1076
|
MLEngineModelLink.persist(
|
1078
1077
|
context=context,
|
1079
1078
|
task_instance=self,
|
1080
1079
|
project_id=project_id,
|
1081
|
-
model_id=self.
|
1080
|
+
model_id=self.model_name,
|
1082
1081
|
)
|
1083
1082
|
|
1084
1083
|
return hook.delete_version(
|
1085
|
-
project_id=self.
|
1084
|
+
project_id=self.project_id, model_name=self.model_name, version_name=self.version_name
|
1086
1085
|
)
|
1087
1086
|
|
1088
1087
|
|
@@ -1163,21 +1162,21 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
|
|
1163
1162
|
"""
|
1164
1163
|
|
1165
1164
|
template_fields: Sequence[str] = (
|
1166
|
-
"
|
1167
|
-
"
|
1168
|
-
"
|
1169
|
-
"
|
1170
|
-
"
|
1171
|
-
"
|
1172
|
-
"
|
1173
|
-
"
|
1174
|
-
"
|
1175
|
-
"
|
1176
|
-
"
|
1177
|
-
"
|
1178
|
-
"
|
1179
|
-
"
|
1180
|
-
"
|
1165
|
+
"project_id",
|
1166
|
+
"job_id",
|
1167
|
+
"region",
|
1168
|
+
"package_uris",
|
1169
|
+
"training_python_module",
|
1170
|
+
"training_args",
|
1171
|
+
"scale_tier",
|
1172
|
+
"master_type",
|
1173
|
+
"master_config",
|
1174
|
+
"runtime_version",
|
1175
|
+
"python_version",
|
1176
|
+
"job_dir",
|
1177
|
+
"service_account",
|
1178
|
+
"hyperparameters",
|
1179
|
+
"impersonation_chain",
|
1181
1180
|
)
|
1182
1181
|
operator_extra_links = (MLEngineJobDetailsLink(),)
|
1183
1182
|
|
@@ -1207,98 +1206,95 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
|
|
1207
1206
|
**kwargs,
|
1208
1207
|
) -> None:
|
1209
1208
|
super().__init__(**kwargs)
|
1210
|
-
self.
|
1211
|
-
self.
|
1212
|
-
self.
|
1213
|
-
self.
|
1214
|
-
self.
|
1215
|
-
self.
|
1216
|
-
self.
|
1217
|
-
self.
|
1218
|
-
self.
|
1219
|
-
self.
|
1220
|
-
self.
|
1221
|
-
self.
|
1222
|
-
self.
|
1209
|
+
self.project_id = project_id
|
1210
|
+
self.job_id = job_id
|
1211
|
+
self.region = region
|
1212
|
+
self.package_uris = package_uris
|
1213
|
+
self.training_python_module = training_python_module
|
1214
|
+
self.training_args = training_args
|
1215
|
+
self.scale_tier = scale_tier
|
1216
|
+
self.master_type = master_type
|
1217
|
+
self.master_config = master_config
|
1218
|
+
self.runtime_version = runtime_version
|
1219
|
+
self.python_version = python_version
|
1220
|
+
self.job_dir = job_dir
|
1221
|
+
self.service_account = service_account
|
1223
1222
|
self._gcp_conn_id = gcp_conn_id
|
1224
1223
|
self._mode = mode
|
1225
1224
|
self._labels = labels
|
1226
|
-
self.
|
1227
|
-
self.
|
1225
|
+
self.hyperparameters = hyperparameters
|
1226
|
+
self.impersonation_chain = impersonation_chain
|
1228
1227
|
self.deferrable = deferrable
|
1229
1228
|
self.cancel_on_kill = cancel_on_kill
|
1230
1229
|
|
1231
|
-
|
1230
|
+
def _handle_job_error(self, finished_training_job) -> None:
|
1231
|
+
if finished_training_job["state"] != "SUCCEEDED":
|
1232
|
+
self.log.error("MLEngine training job failed: %s", finished_training_job)
|
1233
|
+
raise RuntimeError(finished_training_job["errorMessage"])
|
1234
|
+
|
1235
|
+
def execute(self, context: Context):
|
1236
|
+
custom = self.scale_tier is not None and self.scale_tier.upper() == "CUSTOM"
|
1232
1237
|
custom_image = (
|
1233
|
-
custom
|
1234
|
-
and self._master_config is not None
|
1235
|
-
and self._master_config.get("imageUri", None) is not None
|
1238
|
+
custom and self.master_config is not None and self.master_config.get("imageUri", None) is not None
|
1236
1239
|
)
|
1237
1240
|
|
1238
|
-
if not self.
|
1241
|
+
if not self.project_id:
|
1239
1242
|
raise AirflowException("Google Cloud project id is required.")
|
1240
|
-
if not self.
|
1243
|
+
if not self.job_id:
|
1241
1244
|
raise AirflowException("An unique job id is required for Google MLEngine training job.")
|
1242
|
-
if not self.
|
1245
|
+
if not self.region:
|
1243
1246
|
raise AirflowException("Google Compute Engine region is required.")
|
1244
|
-
if custom and not self.
|
1247
|
+
if custom and not self.master_type:
|
1245
1248
|
raise AirflowException("master_type must be set when scale_tier is CUSTOM")
|
1246
|
-
if self.
|
1249
|
+
if self.master_config and not self.master_type:
|
1247
1250
|
raise AirflowException("master_type must be set when master_config is provided")
|
1248
|
-
if not (package_uris and training_python_module) and not custom_image:
|
1251
|
+
if not (self.package_uris and self.training_python_module) and not custom_image:
|
1249
1252
|
raise AirflowException(
|
1250
1253
|
"Either a Python package with a Python module or a custom Docker image should be provided."
|
1251
1254
|
)
|
1252
|
-
if (package_uris or training_python_module) and custom_image:
|
1255
|
+
if (self.package_uris or self.training_python_module) and custom_image:
|
1253
1256
|
raise AirflowException(
|
1254
1257
|
"Either a Python package with a Python module or "
|
1255
1258
|
"a custom Docker image should be provided but not both."
|
1256
1259
|
)
|
1257
|
-
|
1258
|
-
def _handle_job_error(self, finished_training_job) -> None:
|
1259
|
-
if finished_training_job["state"] != "SUCCEEDED":
|
1260
|
-
self.log.error("MLEngine training job failed: %s", finished_training_job)
|
1261
|
-
raise RuntimeError(finished_training_job["errorMessage"])
|
1262
|
-
|
1263
|
-
def execute(self, context: Context):
|
1264
|
-
job_id = _normalize_mlengine_job_id(self._job_id)
|
1260
|
+
job_id = _normalize_mlengine_job_id(self.job_id)
|
1265
1261
|
self.job_id = job_id
|
1266
1262
|
training_request: dict[str, Any] = {
|
1267
1263
|
"jobId": self.job_id,
|
1268
1264
|
"trainingInput": {
|
1269
|
-
"scaleTier": self.
|
1270
|
-
"region": self.
|
1265
|
+
"scaleTier": self.scale_tier,
|
1266
|
+
"region": self.region,
|
1271
1267
|
},
|
1272
1268
|
}
|
1273
|
-
if self.
|
1274
|
-
training_request["trainingInput"]["packageUris"] = self.
|
1269
|
+
if self.package_uris:
|
1270
|
+
training_request["trainingInput"]["packageUris"] = self.package_uris
|
1275
1271
|
|
1276
|
-
if self.
|
1277
|
-
training_request["trainingInput"]["pythonModule"] = self.
|
1272
|
+
if self.training_python_module:
|
1273
|
+
training_request["trainingInput"]["pythonModule"] = self.training_python_module
|
1278
1274
|
|
1279
|
-
if self.
|
1280
|
-
training_request["trainingInput"]["args"] = self.
|
1275
|
+
if self.training_args:
|
1276
|
+
training_request["trainingInput"]["args"] = self.training_args
|
1281
1277
|
|
1282
|
-
if self.
|
1283
|
-
training_request["trainingInput"]["masterType"] = self.
|
1278
|
+
if self.master_type:
|
1279
|
+
training_request["trainingInput"]["masterType"] = self.master_type
|
1284
1280
|
|
1285
|
-
if self.
|
1286
|
-
training_request["trainingInput"]["masterConfig"] = self.
|
1281
|
+
if self.master_config:
|
1282
|
+
training_request["trainingInput"]["masterConfig"] = self.master_config
|
1287
1283
|
|
1288
|
-
if self.
|
1289
|
-
training_request["trainingInput"]["runtimeVersion"] = self.
|
1284
|
+
if self.runtime_version:
|
1285
|
+
training_request["trainingInput"]["runtimeVersion"] = self.runtime_version
|
1290
1286
|
|
1291
|
-
if self.
|
1292
|
-
training_request["trainingInput"]["pythonVersion"] = self.
|
1287
|
+
if self.python_version:
|
1288
|
+
training_request["trainingInput"]["pythonVersion"] = self.python_version
|
1293
1289
|
|
1294
|
-
if self.
|
1295
|
-
training_request["trainingInput"]["jobDir"] = self.
|
1290
|
+
if self.job_dir:
|
1291
|
+
training_request["trainingInput"]["jobDir"] = self.job_dir
|
1296
1292
|
|
1297
|
-
if self.
|
1298
|
-
training_request["trainingInput"]["serviceAccount"] = self.
|
1293
|
+
if self.service_account:
|
1294
|
+
training_request["trainingInput"]["serviceAccount"] = self.service_account
|
1299
1295
|
|
1300
|
-
if self.
|
1301
|
-
training_request["trainingInput"]["hyperparameters"] = self.
|
1296
|
+
if self.hyperparameters:
|
1297
|
+
training_request["trainingInput"]["hyperparameters"] = self.hyperparameters
|
1302
1298
|
|
1303
1299
|
if self._labels:
|
1304
1300
|
training_request["labels"] = self._labels
|
@@ -1310,25 +1306,25 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
|
|
1310
1306
|
|
1311
1307
|
hook = MLEngineHook(
|
1312
1308
|
gcp_conn_id=self._gcp_conn_id,
|
1313
|
-
impersonation_chain=self.
|
1309
|
+
impersonation_chain=self.impersonation_chain,
|
1314
1310
|
)
|
1315
1311
|
self.hook = hook
|
1316
1312
|
|
1317
1313
|
try:
|
1318
1314
|
self.log.info("Executing: %s'", training_request)
|
1319
1315
|
self.job_id = self.hook.create_job_without_waiting_result(
|
1320
|
-
project_id=self.
|
1316
|
+
project_id=self.project_id,
|
1321
1317
|
body=training_request,
|
1322
1318
|
)
|
1323
1319
|
except HttpError as e:
|
1324
1320
|
if e.resp.status == 409:
|
1325
1321
|
# If the job already exists retrieve it
|
1326
|
-
self.hook.get_job(project_id=self.
|
1327
|
-
if self.
|
1322
|
+
self.hook.get_job(project_id=self.project_id, job_id=self.job_id)
|
1323
|
+
if self.project_id:
|
1328
1324
|
MLEngineJobDetailsLink.persist(
|
1329
1325
|
context=context,
|
1330
1326
|
task_instance=self,
|
1331
|
-
project_id=self.
|
1327
|
+
project_id=self.project_id,
|
1332
1328
|
job_id=self.job_id,
|
1333
1329
|
)
|
1334
1330
|
self.log.error(
|
@@ -1345,30 +1341,30 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
|
|
1345
1341
|
trigger=MLEngineStartTrainingJobTrigger(
|
1346
1342
|
conn_id=self._gcp_conn_id,
|
1347
1343
|
job_id=self.job_id,
|
1348
|
-
project_id=self.
|
1349
|
-
region=self.
|
1350
|
-
runtime_version=self.
|
1351
|
-
python_version=self.
|
1352
|
-
job_dir=self.
|
1353
|
-
package_uris=self.
|
1354
|
-
training_python_module=self.
|
1355
|
-
training_args=self.
|
1344
|
+
project_id=self.project_id,
|
1345
|
+
region=self.region,
|
1346
|
+
runtime_version=self.runtime_version,
|
1347
|
+
python_version=self.python_version,
|
1348
|
+
job_dir=self.job_dir,
|
1349
|
+
package_uris=self.package_uris,
|
1350
|
+
training_python_module=self.training_python_module,
|
1351
|
+
training_args=self.training_args,
|
1356
1352
|
labels=self._labels,
|
1357
1353
|
gcp_conn_id=self._gcp_conn_id,
|
1358
|
-
impersonation_chain=self.
|
1354
|
+
impersonation_chain=self.impersonation_chain,
|
1359
1355
|
),
|
1360
1356
|
method_name="execute_complete",
|
1361
1357
|
)
|
1362
1358
|
else:
|
1363
|
-
finished_training_job = self._wait_for_job_done(self.
|
1359
|
+
finished_training_job = self._wait_for_job_done(self.project_id, self.job_id)
|
1364
1360
|
self._handle_job_error(finished_training_job)
|
1365
1361
|
gcp_metadata = {
|
1366
1362
|
"job_id": self.job_id,
|
1367
|
-
"project_id": self.
|
1363
|
+
"project_id": self.project_id,
|
1368
1364
|
}
|
1369
1365
|
context["task_instance"].xcom_push("gcp_metadata", gcp_metadata)
|
1370
1366
|
|
1371
|
-
project_id = self.
|
1367
|
+
project_id = self.project_id or hook.project_id
|
1372
1368
|
if project_id:
|
1373
1369
|
MLEngineJobDetailsLink.persist(
|
1374
1370
|
context=context,
|
@@ -1413,19 +1409,19 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
|
|
1413
1409
|
self.task_id,
|
1414
1410
|
event["message"],
|
1415
1411
|
)
|
1416
|
-
if self.
|
1412
|
+
if self.project_id:
|
1417
1413
|
MLEngineJobDetailsLink.persist(
|
1418
1414
|
context=context,
|
1419
1415
|
task_instance=self,
|
1420
|
-
project_id=self.
|
1421
|
-
job_id=self.
|
1416
|
+
project_id=self.project_id,
|
1417
|
+
job_id=self.job_id,
|
1422
1418
|
)
|
1423
1419
|
|
1424
1420
|
def on_kill(self) -> None:
|
1425
1421
|
if self.job_id and self.cancel_on_kill:
|
1426
|
-
self.hook.cancel_job(job_id=self.job_id, project_id=self.
|
1422
|
+
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) # type: ignore[union-attr]
|
1427
1423
|
else:
|
1428
|
-
self.log.info("Skipping to cancel job: %s:%s.%s", self.
|
1424
|
+
self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.job_id)
|
1429
1425
|
|
1430
1426
|
|
1431
1427
|
@deprecated(
|
@@ -1461,9 +1457,9 @@ class MLEngineTrainingCancelJobOperator(GoogleCloudBaseOperator):
|
|
1461
1457
|
"""
|
1462
1458
|
|
1463
1459
|
template_fields: Sequence[str] = (
|
1464
|
-
"
|
1465
|
-
"
|
1466
|
-
"
|
1460
|
+
"project_id",
|
1461
|
+
"job_id",
|
1462
|
+
"impersonation_chain",
|
1467
1463
|
)
|
1468
1464
|
operator_extra_links = (MLEngineJobSListLink(),)
|
1469
1465
|
|
@@ -1477,21 +1473,49 @@ class MLEngineTrainingCancelJobOperator(GoogleCloudBaseOperator):
|
|
1477
1473
|
**kwargs,
|
1478
1474
|
) -> None:
|
1479
1475
|
super().__init__(**kwargs)
|
1480
|
-
self.
|
1481
|
-
self.
|
1476
|
+
self.project_id = project_id
|
1477
|
+
self.job_id = job_id
|
1482
1478
|
self._gcp_conn_id = gcp_conn_id
|
1483
|
-
self.
|
1479
|
+
self.impersonation_chain = impersonation_chain
|
1484
1480
|
|
1485
|
-
|
1486
|
-
|
1481
|
+
@property
|
1482
|
+
@deprecated(
|
1483
|
+
reason="`_project_id` is deprecated and will be removed in the future. Please use `project_id`"
|
1484
|
+
" instead.",
|
1485
|
+
category=AirflowProviderDeprecationWarning,
|
1486
|
+
)
|
1487
|
+
def _project_id(self):
|
1488
|
+
"""Alias for ``project_id``, used for compatibility (deprecated)."""
|
1489
|
+
return self.project_id
|
1490
|
+
|
1491
|
+
@property
|
1492
|
+
@deprecated(
|
1493
|
+
reason="`_job_id` is deprecated and will be removed in the future. Please use `job_id` instead.",
|
1494
|
+
category=AirflowProviderDeprecationWarning,
|
1495
|
+
)
|
1496
|
+
def _job_id(self):
|
1497
|
+
"""Alias for ``job_id``, used for compatibility (deprecated)."""
|
1498
|
+
return self.job_id
|
1499
|
+
|
1500
|
+
@property
|
1501
|
+
@deprecated(
|
1502
|
+
reason="`_impersonation_chain` is deprecated and will be removed in the future."
|
1503
|
+
" Please use `impersonation_chain` instead.",
|
1504
|
+
category=AirflowProviderDeprecationWarning,
|
1505
|
+
)
|
1506
|
+
def _impersonation_chain(self):
|
1507
|
+
"""Alias for ``impersonation_chain``, used for compatibility (deprecated)."""
|
1508
|
+
return self.impersonation_chain
|
1487
1509
|
|
1488
1510
|
def execute(self, context: Context):
|
1511
|
+
if not self.project_id:
|
1512
|
+
raise AirflowException("Google Cloud project id is required.")
|
1489
1513
|
hook = MLEngineHook(
|
1490
1514
|
gcp_conn_id=self._gcp_conn_id,
|
1491
|
-
impersonation_chain=self.
|
1515
|
+
impersonation_chain=self.impersonation_chain,
|
1492
1516
|
)
|
1493
1517
|
|
1494
|
-
project_id = self.
|
1518
|
+
project_id = self.project_id or hook.project_id
|
1495
1519
|
if project_id:
|
1496
1520
|
MLEngineJobSListLink.persist(
|
1497
1521
|
context=context,
|
@@ -1499,4 +1523,4 @@ class MLEngineTrainingCancelJobOperator(GoogleCloudBaseOperator):
|
|
1499
1523
|
project_id=project_id,
|
1500
1524
|
)
|
1501
1525
|
|
1502
|
-
hook.cancel_job(project_id=self.
|
1526
|
+
hook.cancel_job(project_id=self.project_id, job_id=_normalize_mlengine_job_id(self.job_id))
|