apache-airflow-providers-google 10.20.0rc1__py3-none-any.whl → 10.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +16 -8
  3. airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -1
  4. airflow/providers/google/cloud/_internal_client/secret_manager_client.py +6 -3
  5. airflow/providers/google/cloud/hooks/bigquery.py +158 -79
  6. airflow/providers/google/cloud/hooks/cloud_sql.py +12 -6
  7. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +34 -17
  8. airflow/providers/google/cloud/hooks/dataflow.py +30 -26
  9. airflow/providers/google/cloud/hooks/dataform.py +2 -1
  10. airflow/providers/google/cloud/hooks/datafusion.py +4 -2
  11. airflow/providers/google/cloud/hooks/dataproc.py +102 -51
  12. airflow/providers/google/cloud/hooks/functions.py +20 -10
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -11
  14. airflow/providers/google/cloud/hooks/os_login.py +2 -1
  15. airflow/providers/google/cloud/hooks/secret_manager.py +18 -9
  16. airflow/providers/google/cloud/hooks/translate.py +2 -1
  17. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -1
  18. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +141 -0
  19. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +2 -1
  20. airflow/providers/google/cloud/links/base.py +2 -1
  21. airflow/providers/google/cloud/links/datafusion.py +2 -1
  22. airflow/providers/google/cloud/log/stackdriver_task_handler.py +4 -2
  23. airflow/providers/google/cloud/openlineage/mixins.py +10 -0
  24. airflow/providers/google/cloud/openlineage/utils.py +4 -2
  25. airflow/providers/google/cloud/operators/bigquery.py +55 -21
  26. airflow/providers/google/cloud/operators/cloud_batch.py +3 -1
  27. airflow/providers/google/cloud/operators/cloud_sql.py +22 -11
  28. airflow/providers/google/cloud/operators/dataform.py +2 -1
  29. airflow/providers/google/cloud/operators/dataproc.py +75 -34
  30. airflow/providers/google/cloud/operators/dataproc_metastore.py +24 -12
  31. airflow/providers/google/cloud/operators/gcs.py +2 -1
  32. airflow/providers/google/cloud/operators/pubsub.py +10 -5
  33. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +3 -3
  34. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +12 -9
  35. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +243 -0
  36. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +2 -1
  37. airflow/providers/google/cloud/operators/vision.py +36 -18
  38. airflow/providers/google/cloud/sensors/gcs.py +11 -2
  39. airflow/providers/google/cloud/sensors/pubsub.py +2 -1
  40. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +21 -12
  41. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +1 -1
  42. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +2 -1
  43. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +17 -5
  44. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +12 -6
  45. airflow/providers/google/cloud/transfers/local_to_gcs.py +5 -1
  46. airflow/providers/google/cloud/transfers/mysql_to_gcs.py +2 -1
  47. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +2 -1
  48. airflow/providers/google/cloud/transfers/presto_to_gcs.py +2 -1
  49. airflow/providers/google/cloud/transfers/s3_to_gcs.py +2 -1
  50. airflow/providers/google/cloud/transfers/trino_to_gcs.py +2 -1
  51. airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
  52. airflow/providers/google/cloud/triggers/cloud_run.py +2 -1
  53. airflow/providers/google/cloud/triggers/dataflow.py +2 -1
  54. airflow/providers/google/cloud/triggers/vertex_ai.py +2 -1
  55. airflow/providers/google/cloud/utils/external_token_supplier.py +4 -2
  56. airflow/providers/google/cloud/utils/field_sanitizer.py +4 -2
  57. airflow/providers/google/cloud/utils/field_validator.py +6 -3
  58. airflow/providers/google/cloud/utils/helpers.py +2 -1
  59. airflow/providers/google/common/hooks/base_google.py +2 -1
  60. airflow/providers/google/common/utils/id_token_credentials.py +2 -1
  61. airflow/providers/google/get_provider_info.py +3 -2
  62. airflow/providers/google/go_module_utils.py +4 -2
  63. airflow/providers/google/marketing_platform/hooks/analytics_admin.py +12 -6
  64. airflow/providers/google/marketing_platform/links/analytics_admin.py +2 -1
  65. airflow/providers/google/suite/transfers/local_to_drive.py +2 -1
  66. {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0.dist-info}/METADATA +14 -14
  67. {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0.dist-info}/RECORD +69 -69
  68. {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0.dist-info}/WHEEL +0 -0
  69. {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0.dist-info}/entry_points.txt +0 -0
@@ -21,6 +21,9 @@ from __future__ import annotations
21
21
 
22
22
  from typing import TYPE_CHECKING, Sequence
23
23
 
24
+ from deprecated import deprecated
25
+
26
+ from airflow.exceptions import AirflowProviderDeprecationWarning
24
27
  from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
25
28
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
26
29
 
@@ -28,6 +31,10 @@ if TYPE_CHECKING:
28
31
  from airflow.utils.context import Context
29
32
 
30
33
 
34
+ @deprecated(
35
+ reason="This operator is deprecated and will be removed after 01.01.2025, please use `TextGenerationModelPredictOperator`.",
36
+ category=AirflowProviderDeprecationWarning,
37
+ )
31
38
  class PromptLanguageModelOperator(GoogleCloudBaseOperator):
32
39
  """
33
40
  Uses the Vertex AI PaLM API to generate natural language text.
@@ -113,6 +120,10 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator):
113
120
  return response
114
121
 
115
122
 
123
+ @deprecated(
124
+ reason="This operator is deprecated and will be removed after 01.01.2025, please use `TextEmbeddingModelGetEmbeddingsOperator`.",
125
+ category=AirflowProviderDeprecationWarning,
126
+ )
116
127
  class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
117
128
  """
118
129
  Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
@@ -177,6 +188,10 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
177
188
  return response
178
189
 
179
190
 
191
+ @deprecated(
192
+ reason="This operator is deprecated and will be removed after 01.01.2025, please use `GenerativeModelGenerateContentOperator`.",
193
+ category=AirflowProviderDeprecationWarning,
194
+ )
180
195
  class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
181
196
  """
182
197
  Use the Vertex AI Gemini Pro foundation model to generate natural language text.
@@ -249,6 +264,10 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
249
264
  return response
250
265
 
251
266
 
267
+ @deprecated(
268
+ reason="This operator is deprecated and will be removed after 01.01.2025, please use `GenerativeModelGenerateContentOperator`.",
269
+ category=AirflowProviderDeprecationWarning,
270
+ )
252
271
  class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
253
272
  """
254
273
  Use the Vertex AI Gemini Pro foundation model to generate natural language text.
@@ -328,3 +347,227 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
328
347
  self.xcom_push(context, key="prompt_response", value=response)
329
348
 
330
349
  return response
350
+
351
+
352
+ class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
353
+ """
354
+ Uses the Vertex AI PaLM API to generate natural language text.
355
+
356
+ :param project_id: Required. The ID of the Google Cloud project that the
357
+ service belongs to (templated).
358
+ :param location: Required. The ID of the Google Cloud location that the
359
+ service belongs to (templated).
360
+ :param prompt: Required. Inputs or queries that a user or a program gives
361
+ to the Vertex AI PaLM API, in order to elicit a specific response (templated).
362
+ :param pretrained_model: By default uses the pre-trained model `text-bison`,
363
+ optimized for performing natural language tasks such as classification,
364
+ summarization, extraction, content creation, and ideation.
365
+ :param temperature: Temperature controls the degree of randomness in token
366
+ selection. Defaults to 0.0.
367
+ :param max_output_tokens: Token limit determines the maximum amount of text
368
+ output. Defaults to 256.
369
+ :param top_p: Tokens are selected from most probable to least until the sum
370
+ of their probabilities equals the top_p value. Defaults to 0.8.
371
+ :param top_k: A top_k of 1 means the selected token is the most probable
372
+ among all tokens. Defaults to 40.
373
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
374
+ :param impersonation_chain: Optional service account to impersonate using short-term
375
+ credentials, or chained list of accounts required to get the access_token
376
+ of the last account in the list, which will be impersonated in the request.
377
+ If set as a string, the account must grant the originating account
378
+ the Service Account Token Creator IAM role.
379
+ If set as a sequence, the identities from the list must grant
380
+ Service Account Token Creator IAM role to the directly preceding identity, with first
381
+ account from the list granting this role to the originating account (templated).
382
+ """
383
+
384
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt")
385
+
386
+ def __init__(
387
+ self,
388
+ *,
389
+ project_id: str,
390
+ location: str,
391
+ prompt: str,
392
+ pretrained_model: str = "text-bison",
393
+ temperature: float = 0.0,
394
+ max_output_tokens: int = 256,
395
+ top_p: float = 0.8,
396
+ top_k: int = 40,
397
+ gcp_conn_id: str = "google_cloud_default",
398
+ impersonation_chain: str | Sequence[str] | None = None,
399
+ **kwargs,
400
+ ) -> None:
401
+ super().__init__(**kwargs)
402
+ self.project_id = project_id
403
+ self.location = location
404
+ self.prompt = prompt
405
+ self.pretrained_model = pretrained_model
406
+ self.temperature = temperature
407
+ self.max_output_tokens = max_output_tokens
408
+ self.top_p = top_p
409
+ self.top_k = top_k
410
+ self.gcp_conn_id = gcp_conn_id
411
+ self.impersonation_chain = impersonation_chain
412
+
413
+ def execute(self, context: Context):
414
+ self.hook = GenerativeModelHook(
415
+ gcp_conn_id=self.gcp_conn_id,
416
+ impersonation_chain=self.impersonation_chain,
417
+ )
418
+
419
+ self.log.info("Submitting prompt")
420
+ response = self.hook.text_generation_model_predict(
421
+ project_id=self.project_id,
422
+ location=self.location,
423
+ prompt=self.prompt,
424
+ pretrained_model=self.pretrained_model,
425
+ temperature=self.temperature,
426
+ max_output_tokens=self.max_output_tokens,
427
+ top_p=self.top_p,
428
+ top_k=self.top_k,
429
+ )
430
+
431
+ self.log.info("Model response: %s", response)
432
+ self.xcom_push(context, key="model_response", value=response)
433
+
434
+ return response
435
+
436
+
437
+ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
438
+ """
439
+ Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
440
+
441
+ :param project_id: Required. The ID of the Google Cloud project that the
442
+ service belongs to (templated).
443
+ :param location: Required. The ID of the Google Cloud location that the
444
+ service belongs to (templated).
445
+ :param prompt: Required. Inputs or queries that a user or a program gives
446
+ to the Vertex AI PaLM API, in order to elicit a specific response (templated).
447
+ :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
448
+ optimized for performing text embeddings.
449
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
450
+ :param impersonation_chain: Optional service account to impersonate using short-term
451
+ credentials, or chained list of accounts required to get the access_token
452
+ of the last account in the list, which will be impersonated in the request.
453
+ If set as a string, the account must grant the originating account
454
+ the Service Account Token Creator IAM role.
455
+ If set as a sequence, the identities from the list must grant
456
+ Service Account Token Creator IAM role to the directly preceding identity, with first
457
+ account from the list granting this role to the originating account (templated).
458
+ """
459
+
460
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt")
461
+
462
+ def __init__(
463
+ self,
464
+ *,
465
+ project_id: str,
466
+ location: str,
467
+ prompt: str,
468
+ pretrained_model: str = "textembedding-gecko",
469
+ gcp_conn_id: str = "google_cloud_default",
470
+ impersonation_chain: str | Sequence[str] | None = None,
471
+ **kwargs,
472
+ ) -> None:
473
+ super().__init__(**kwargs)
474
+ self.project_id = project_id
475
+ self.location = location
476
+ self.prompt = prompt
477
+ self.pretrained_model = pretrained_model
478
+ self.gcp_conn_id = gcp_conn_id
479
+ self.impersonation_chain = impersonation_chain
480
+
481
+ def execute(self, context: Context):
482
+ self.hook = GenerativeModelHook(
483
+ gcp_conn_id=self.gcp_conn_id,
484
+ impersonation_chain=self.impersonation_chain,
485
+ )
486
+
487
+ self.log.info("Generating text embeddings")
488
+ response = self.hook.text_embedding_model_get_embeddings(
489
+ project_id=self.project_id,
490
+ location=self.location,
491
+ prompt=self.prompt,
492
+ pretrained_model=self.pretrained_model,
493
+ )
494
+
495
+ self.log.info("Model response: %s", response)
496
+ self.xcom_push(context, key="model_response", value=response)
497
+
498
+ return response
499
+
500
+
501
+ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
502
+ """
503
+ Use the Vertex AI Gemini Pro foundation model to generate content.
504
+
505
+ :param project_id: Required. The ID of the Google Cloud project that the
506
+ service belongs to (templated).
507
+ :param contents: Required. The multi-part content of a message that a user or a program
508
+ gives to the generative model, in order to elicit a specific response (templated).
509
+ :param location: Required. The ID of the Google Cloud location that the
510
+ service belongs to (templated).
511
+ :param generation_config: Optional. Generation configuration settings.
512
+ :param safety_settings: Optional. Per request settings for blocking unsafe content.
513
+ :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
514
+ supporting prompts with text-only input, including natural language
515
+ tasks, multi-turn text and code chat, and code generation. It can
516
+ output text and code.
517
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
518
+ :param impersonation_chain: Optional service account to impersonate using short-term
519
+ credentials, or chained list of accounts required to get the access_token
520
+ of the last account in the list, which will be impersonated in the request.
521
+ If set as a string, the account must grant the originating account
522
+ the Service Account Token Creator IAM role.
523
+ If set as a sequence, the identities from the list must grant
524
+ Service Account Token Creator IAM role to the directly preceding identity, with first
525
+ account from the list granting this role to the originating account (templated).
526
+ """
527
+
528
+ template_fields = ("location", "project_id", "impersonation_chain", "contents")
529
+
530
+ def __init__(
531
+ self,
532
+ *,
533
+ project_id: str,
534
+ contents: list,
535
+ location: str,
536
+ tools: list | None = None,
537
+ generation_config: dict | None = None,
538
+ safety_settings: dict | None = None,
539
+ pretrained_model: str = "gemini-pro",
540
+ gcp_conn_id: str = "google_cloud_default",
541
+ impersonation_chain: str | Sequence[str] | None = None,
542
+ **kwargs,
543
+ ) -> None:
544
+ super().__init__(**kwargs)
545
+ self.project_id = project_id
546
+ self.location = location
547
+ self.contents = contents
548
+ self.tools = tools
549
+ self.generation_config = generation_config
550
+ self.safety_settings = safety_settings
551
+ self.pretrained_model = pretrained_model
552
+ self.gcp_conn_id = gcp_conn_id
553
+ self.impersonation_chain = impersonation_chain
554
+
555
+ def execute(self, context: Context):
556
+ self.hook = GenerativeModelHook(
557
+ gcp_conn_id=self.gcp_conn_id,
558
+ impersonation_chain=self.impersonation_chain,
559
+ )
560
+ response = self.hook.generative_model_generate_content(
561
+ project_id=self.project_id,
562
+ location=self.location,
563
+ contents=self.contents,
564
+ tools=self.tools,
565
+ generation_config=self.generation_config,
566
+ safety_settings=self.safety_settings,
567
+ pretrained_model=self.pretrained_model,
568
+ )
569
+
570
+ self.log.info("Model response: %s", response)
571
+ self.xcom_push(context, key="model_response", value=response)
572
+
573
+ return response
@@ -301,7 +301,8 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
301
301
 
302
302
 
303
303
  class ListPipelineJobOperator(GoogleCloudBaseOperator):
304
- """Lists PipelineJob in a Location.
304
+ """
305
+ Lists PipelineJob in a Location.
305
306
 
306
307
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
307
308
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -47,7 +47,8 @@ MetaData = Sequence[Tuple[str, str]]
47
47
 
48
48
 
49
49
  class CloudVisionCreateProductSetOperator(GoogleCloudBaseOperator):
50
- """Create a new ProductSet resource.
50
+ """
51
+ Create a new ProductSet resource.
51
52
 
52
53
  .. seealso::
53
54
  For more information on how to use this operator, take a look at the guide:
@@ -139,7 +140,8 @@ class CloudVisionCreateProductSetOperator(GoogleCloudBaseOperator):
139
140
 
140
141
 
141
142
  class CloudVisionGetProductSetOperator(GoogleCloudBaseOperator):
142
- """Get information associated with a ProductSet.
143
+ """
144
+ Get information associated with a ProductSet.
143
145
 
144
146
  .. seealso::
145
147
  For more information on how to use this operator, take a look at the guide:
@@ -216,7 +218,8 @@ class CloudVisionGetProductSetOperator(GoogleCloudBaseOperator):
216
218
 
217
219
 
218
220
  class CloudVisionUpdateProductSetOperator(GoogleCloudBaseOperator):
219
- """Make changes to a `ProductSet` resource.
221
+ """
222
+ Make changes to a `ProductSet` resource.
220
223
 
221
224
  Only ``display_name`` can be updated currently.
222
225
 
@@ -322,7 +325,8 @@ class CloudVisionUpdateProductSetOperator(GoogleCloudBaseOperator):
322
325
 
323
326
 
324
327
  class CloudVisionDeleteProductSetOperator(GoogleCloudBaseOperator):
325
- """Permanently deletes a ``ProductSet``.
328
+ """
329
+ Permanently deletes a ``ProductSet``.
326
330
 
327
331
  ``Products`` and ``ReferenceImages`` in the ``ProductSet`` are not deleted.
328
332
  The actual image files are not deleted from Google Cloud Storage.
@@ -402,7 +406,8 @@ class CloudVisionDeleteProductSetOperator(GoogleCloudBaseOperator):
402
406
 
403
407
 
404
408
  class CloudVisionCreateProductOperator(GoogleCloudBaseOperator):
405
- """Create and return a new product resource.
409
+ """
410
+ Create and return a new product resource.
406
411
 
407
412
  Possible errors regarding the ``Product`` object provided:
408
413
 
@@ -499,7 +504,8 @@ class CloudVisionCreateProductOperator(GoogleCloudBaseOperator):
499
504
 
500
505
 
501
506
  class CloudVisionGetProductOperator(GoogleCloudBaseOperator):
502
- """Get information associated with a ``Product``.
507
+ """
508
+ Get information associated with a ``Product``.
503
509
 
504
510
  Possible errors:
505
511
 
@@ -580,7 +586,8 @@ class CloudVisionGetProductOperator(GoogleCloudBaseOperator):
580
586
 
581
587
 
582
588
  class CloudVisionUpdateProductOperator(GoogleCloudBaseOperator):
583
- """Make changes to a Product resource.
589
+ """
590
+ Make changes to a Product resource.
584
591
 
585
592
  Only the display_name, description, and labels fields can be updated right now.
586
593
 
@@ -693,7 +700,8 @@ class CloudVisionUpdateProductOperator(GoogleCloudBaseOperator):
693
700
 
694
701
 
695
702
  class CloudVisionDeleteProductOperator(GoogleCloudBaseOperator):
696
- """Permanently delete a product and its reference images.
703
+ """
704
+ Permanently delete a product and its reference images.
697
705
 
698
706
  Metadata of the product and all its images will be deleted right away, but
699
707
  search queries against ProductSets containing the product may still work
@@ -778,7 +786,8 @@ class CloudVisionDeleteProductOperator(GoogleCloudBaseOperator):
778
786
 
779
787
 
780
788
  class CloudVisionImageAnnotateOperator(GoogleCloudBaseOperator):
781
- """Run image detection and annotation for an image or a batch of images.
789
+ """
790
+ Run image detection and annotation for an image or a batch of images.
782
791
 
783
792
  .. seealso::
784
793
  For more information on how to use this operator, take a look at the guide:
@@ -845,7 +854,8 @@ class CloudVisionImageAnnotateOperator(GoogleCloudBaseOperator):
845
854
 
846
855
 
847
856
  class CloudVisionCreateReferenceImageOperator(GoogleCloudBaseOperator):
848
- """Create and return a new ReferenceImage ID resource.
857
+ """
858
+ Create and return a new ReferenceImage ID resource.
849
859
 
850
860
  .. seealso::
851
861
  For more information on how to use this operator, take a look at the guide:
@@ -948,7 +958,8 @@ class CloudVisionCreateReferenceImageOperator(GoogleCloudBaseOperator):
948
958
 
949
959
 
950
960
  class CloudVisionDeleteReferenceImageOperator(GoogleCloudBaseOperator):
951
- """Delete a ReferenceImage ID resource.
961
+ """
962
+ Delete a ReferenceImage ID resource.
952
963
 
953
964
  .. seealso::
954
965
  For more information on how to use this operator, take a look at the guide:
@@ -1033,7 +1044,8 @@ class CloudVisionDeleteReferenceImageOperator(GoogleCloudBaseOperator):
1033
1044
 
1034
1045
 
1035
1046
  class CloudVisionAddProductToProductSetOperator(GoogleCloudBaseOperator):
1036
- """Add a Product to the specified ProductSet.
1047
+ """
1048
+ Add a Product to the specified ProductSet.
1037
1049
 
1038
1050
  If the Product is already present, no change is made. One Product can be
1039
1051
  added to at most 100 ProductSets.
@@ -1122,7 +1134,8 @@ class CloudVisionAddProductToProductSetOperator(GoogleCloudBaseOperator):
1122
1134
 
1123
1135
 
1124
1136
  class CloudVisionRemoveProductFromProductSetOperator(GoogleCloudBaseOperator):
1125
- """Remove a Product from the specified ProductSet.
1137
+ """
1138
+ Remove a Product from the specified ProductSet.
1126
1139
 
1127
1140
  .. seealso::
1128
1141
  For more information on how to use this operator, take a look at the guide:
@@ -1204,7 +1217,8 @@ class CloudVisionRemoveProductFromProductSetOperator(GoogleCloudBaseOperator):
1204
1217
 
1205
1218
 
1206
1219
  class CloudVisionDetectTextOperator(GoogleCloudBaseOperator):
1207
- """Detect Text in the image.
1220
+ """
1221
+ Detect Text in the image.
1208
1222
 
1209
1223
  .. seealso::
1210
1224
  For more information on how to use this operator, take a look at the guide:
@@ -1285,7 +1299,8 @@ class CloudVisionDetectTextOperator(GoogleCloudBaseOperator):
1285
1299
 
1286
1300
 
1287
1301
  class CloudVisionTextDetectOperator(GoogleCloudBaseOperator):
1288
- """Detect Document Text in the image.
1302
+ """
1303
+ Detect Document Text in the image.
1289
1304
 
1290
1305
  .. seealso::
1291
1306
  For more information on how to use this operator, take a look at the guide:
@@ -1365,7 +1380,8 @@ class CloudVisionTextDetectOperator(GoogleCloudBaseOperator):
1365
1380
 
1366
1381
 
1367
1382
  class CloudVisionDetectImageLabelsOperator(GoogleCloudBaseOperator):
1368
- """Detect Document Text in the image.
1383
+ """
1384
+ Detect labels in the image.
1369
1385
 
1370
1386
  .. seealso::
1371
1387
  For more information on how to use this operator, take a look at the guide:
@@ -1435,7 +1451,8 @@ class CloudVisionDetectImageLabelsOperator(GoogleCloudBaseOperator):
1435
1451
 
1436
1452
 
1437
1453
  class CloudVisionDetectImageSafeSearchOperator(GoogleCloudBaseOperator):
1438
- """Detect Document Text in the image.
1454
+ """
1455
+ Detect explicit content (safe search) in the image.
1439
1456
 
1440
1457
  .. seealso::
1441
1458
  For more information on how to use this operator, take a look at the guide:
@@ -1507,7 +1524,8 @@ class CloudVisionDetectImageSafeSearchOperator(GoogleCloudBaseOperator):
1507
1524
  def prepare_additional_parameters(
1508
1525
  additional_properties: dict | None, language_hints: Any, web_detection_params: Any
1509
1526
  ) -> dict | None:
1510
- """Create a value for the ``additional_properties`` parameter.
1527
+ """
1528
+ Create a value for the ``additional_properties`` parameter.
1511
1529
 
1512
1530
  The new value is based on ``language_hints``, ``web_detection_params``, and
1513
1531
  ``additional_properties`` parameters specified by the user.
@@ -188,7 +188,15 @@ def ts_function(context):
188
188
  try:
189
189
  return context["data_interval_end"]
190
190
  except KeyError:
191
- return context["dag"].following_schedule(context["execution_date"])
191
+ from airflow.utils import timezone
192
+
193
+ data_interval = context["dag"].infer_automated_data_interval(
194
+ timezone.coerce_datetime(context["execution_date"])
195
+ )
196
+ next_info = context["dag"].next_dagrun_info(data_interval, restricted=False)
197
+ if next_info is None:
198
+ return None
199
+ return next_info.data_interval.start
192
200
 
193
201
 
194
202
  class GCSObjectUpdateSensor(BaseSensorOperator):
@@ -575,7 +583,8 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator):
575
583
  )
576
584
 
577
585
  def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
578
- """Rely on trigger to throw an exception, otherwise it assumes execution was successful.
586
+ """
587
+ Rely on trigger to throw an exception, otherwise it assumes execution was successful.
579
588
 
580
589
  Callback for when the trigger fires - returns immediately.
581
590
 
@@ -185,7 +185,8 @@ class PubSubPullSensor(BaseSensorOperator):
185
185
  pulled_messages: list[ReceivedMessage],
186
186
  context: Context,
187
187
  ):
188
- """Convert `ReceivedMessage` objects into JSON-serializable dicts.
188
+ """
189
+ Convert `ReceivedMessage` objects into JSON-serializable dicts.
189
190
 
190
191
  This method can be overridden by subclasses or by `messages_callback` constructor argument.
191
192
 
@@ -142,8 +142,6 @@ class BigQueryToGCSOperator(BaseOperator):
142
142
  self.hook: BigQueryHook | None = None
143
143
  self.deferrable = deferrable
144
144
 
145
- self._job_id: str = ""
146
-
147
145
  @staticmethod
148
146
  def _handle_job_error(job: BigQueryJob | UnknownJob) -> None:
149
147
  if job.error_result:
@@ -212,7 +210,7 @@ class BigQueryToGCSOperator(BaseOperator):
212
210
  self.hook = hook
213
211
 
214
212
  configuration = self._prepare_configuration()
215
- job_id = hook.generate_job_id(
213
+ self.job_id = hook.generate_job_id(
216
214
  job_id=self.job_id,
217
215
  dag_id=self.dag_id,
218
216
  task_id=self.task_id,
@@ -224,14 +222,14 @@ class BigQueryToGCSOperator(BaseOperator):
224
222
  try:
225
223
  self.log.info("Executing: %s", configuration)
226
224
  job: BigQueryJob | UnknownJob = self._submit_job(
227
- hook=hook, job_id=job_id, configuration=configuration
225
+ hook=hook, job_id=self.job_id, configuration=configuration
228
226
  )
229
227
  except Conflict:
230
228
  # If the job already exists retrieve it
231
229
  job = hook.get_job(
232
230
  project_id=self.project_id,
233
231
  location=self.location,
234
- job_id=job_id,
232
+ job_id=self.job_id,
235
233
  )
236
234
  if job.state in self.reattach_states:
237
235
  # We are reattaching to a job
@@ -240,12 +238,12 @@ class BigQueryToGCSOperator(BaseOperator):
240
238
  else:
241
239
  # Same job configuration so we need force_rerun
242
240
  raise AirflowException(
243
- f"Job with id: {job_id} already exists and is in {job.state} state. If you "
241
+ f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
244
242
  f"want to force rerun it consider setting `force_rerun=True`."
245
243
  f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
246
244
  )
247
245
 
248
- self._job_id = job.job_id
246
+ self.job_id = job.job_id
249
247
  conf = job.to_api_repr()["configuration"]["extract"]["sourceTable"]
250
248
  dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"]
251
249
  BigQueryTableLink.persist(
@@ -261,7 +259,7 @@ class BigQueryToGCSOperator(BaseOperator):
261
259
  timeout=self.execution_timeout,
262
260
  trigger=BigQueryInsertJobTrigger(
263
261
  conn_id=self.gcp_conn_id,
264
- job_id=self._job_id,
262
+ job_id=self.job_id,
265
263
  project_id=self.project_id or self.hook.project_id,
266
264
  location=self.location or self.hook.location,
267
265
  impersonation_chain=self.impersonation_chain,
@@ -272,7 +270,8 @@ class BigQueryToGCSOperator(BaseOperator):
272
270
  job.result(timeout=self.result_timeout, retry=self.result_retry)
273
271
 
274
272
  def execute_complete(self, context: Context, event: dict[str, Any]):
275
- """Return immediately and relies on trigger to throw a success event. Callback for the trigger.
273
+ """
274
+ Return immediately and rely on the trigger to throw a success event. Callback for the trigger.
276
275
 
277
276
  Relies on trigger to throw an exception, otherwise it assumes execution was successful.
278
277
  """
@@ -283,6 +282,8 @@ class BigQueryToGCSOperator(BaseOperator):
283
282
  self.task_id,
284
283
  event["message"],
285
284
  )
285
+ # Save job_id as an attribute to be later used by listeners
286
+ self.job_id = event.get("job_id")
286
287
 
287
288
  def get_openlineage_facets_on_complete(self, task_instance):
288
289
  """Implement on_complete as we will include final BQ job id."""
@@ -302,7 +303,15 @@ class BigQueryToGCSOperator(BaseOperator):
302
303
  )
303
304
  from airflow.providers.openlineage.extractors import OperatorLineage
304
305
 
305
- table_object = self.hook.get_client(self.hook.project_id).get_table(self.source_project_dataset_table)
306
+ if not self.hook:
307
+ self.hook = BigQueryHook(
308
+ gcp_conn_id=self.gcp_conn_id,
309
+ location=self.location,
310
+ impersonation_chain=self.impersonation_chain,
311
+ )
312
+
313
+ project_id = self.project_id or self.hook.project_id
314
+ table_object = self.hook.get_client(project_id).get_table(self.source_project_dataset_table)
306
315
 
307
316
  input_dataset = Dataset(
308
317
  namespace="bigquery",
@@ -346,9 +355,9 @@ class BigQueryToGCSOperator(BaseOperator):
346
355
  output_datasets.append(dataset)
347
356
 
348
357
  run_facets = {}
349
- if self._job_id:
358
+ if self.job_id:
350
359
  run_facets = {
351
- "externalQuery": ExternalQueryRunFacet(externalQueryId=self._job_id, source="bigquery"),
360
+ "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery"),
352
361
  }
353
362
 
354
363
  return OperatorLineage(inputs=[input_dataset], outputs=output_datasets, run_facets=run_facets)
@@ -76,7 +76,7 @@ class BigQueryToPostgresOperator(BigQueryToSqlBaseOperator):
76
76
  self.replace_index = replace_index
77
77
 
78
78
  def get_sql_hook(self) -> PostgresHook:
79
- return PostgresHook(schema=self.database, postgres_conn_id=self.postgres_conn_id)
79
+ return PostgresHook(database=self.database, postgres_conn_id=self.postgres_conn_id)
80
80
 
81
81
  def execute(self, context: Context) -> None:
82
82
  big_query_hook = BigQueryHook(
@@ -43,7 +43,8 @@ class FlushAction(Enum):
43
43
 
44
44
 
45
45
  class FacebookAdsReportToGcsOperator(BaseOperator):
46
- """Fetch from Facebook Ads API.
46
+ """
47
+ Fetch from Facebook Ads API.
47
48
 
48
49
  This converts and saves the data as a temporary JSON file, and uploads the
49
50
  JSON to Google Cloud Storage.
@@ -449,7 +449,8 @@ class GCSToBigQueryOperator(BaseOperator):
449
449
  return self._find_max_value_in_column()
450
450
 
451
451
  def execute_complete(self, context: Context, event: dict[str, Any]):
452
- """Return immediately and relies on trigger to throw a success event. Callback for the trigger.
452
+ """
453
+ Return immediately and rely on the trigger to throw a success event. Callback for the trigger.
453
454
 
454
455
  Relies on trigger to throw an exception, otherwise it assumes execution was successful.
455
456
  """
@@ -460,6 +461,8 @@ class GCSToBigQueryOperator(BaseOperator):
460
461
  self.task_id,
461
462
  event["message"],
462
463
  )
464
+ # Save job_id as an attribute to be later used by listeners
465
+ self.job_id = event.get("job_id")
463
466
  return self._find_max_value_in_column()
464
467
 
465
468
  def _find_max_value_in_column(self):
@@ -756,17 +759,26 @@ class GCSToBigQueryOperator(BaseOperator):
756
759
  )
757
760
  from airflow.providers.openlineage.extractors import OperatorLineage
758
761
 
759
- table_object = self.hook.get_client(self.hook.project_id).get_table(
760
- self.destination_project_dataset_table
761
- )
762
+ if not self.hook:
763
+ self.hook = BigQueryHook(
764
+ gcp_conn_id=self.gcp_conn_id,
765
+ location=self.location,
766
+ impersonation_chain=self.impersonation_chain,
767
+ )
768
+
769
+ project_id = self.project_id or self.hook.project_id
770
+ table_object = self.hook.get_client(project_id).get_table(self.destination_project_dataset_table)
762
771
 
763
772
  output_dataset_facets = get_facets_from_bq_table(table_object)
764
773
 
774
+ source_objects = (
775
+ self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
776
+ )
765
777
  input_dataset_facets = {
766
778
  "schema": output_dataset_facets["schema"],
767
779
  }
768
780
  input_datasets = []
769
- for blob in sorted(self.source_objects):
781
+ for blob in sorted(source_objects):
770
782
  additional_facets = {}
771
783
 
772
784
  if "*" in blob: