apache-airflow-providers-google 10.22.0rc1__py3-none-any.whl → 10.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +91 -54
- airflow/providers/google/cloud/hooks/cloud_build.py +3 -2
- airflow/providers/google/cloud/hooks/dataflow.py +112 -47
- airflow/providers/google/cloud/hooks/datapipeline.py +3 -3
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +15 -26
- airflow/providers/google/cloud/hooks/life_sciences.py +5 -7
- airflow/providers/google/cloud/hooks/secret_manager.py +3 -3
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +28 -8
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +11 -6
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +214 -34
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +11 -4
- airflow/providers/google/cloud/links/automl.py +13 -22
- airflow/providers/google/cloud/log/gcs_task_handler.py +1 -2
- airflow/providers/google/cloud/operators/bigquery.py +6 -4
- airflow/providers/google/cloud/operators/dataflow.py +186 -4
- airflow/providers/google/cloud/operators/datafusion.py +3 -2
- airflow/providers/google/cloud/operators/datapipeline.py +5 -6
- airflow/providers/google/cloud/operators/dataproc.py +30 -33
- airflow/providers/google/cloud/operators/gcs.py +4 -4
- airflow/providers/google/cloud/operators/kubernetes_engine.py +16 -2
- airflow/providers/google/cloud/operators/life_sciences.py +5 -7
- airflow/providers/google/cloud/operators/mlengine.py +42 -65
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +18 -4
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +5 -5
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +280 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +4 -0
- airflow/providers/google/cloud/secrets/secret_manager.py +3 -5
- airflow/providers/google/cloud/sensors/bigquery.py +8 -27
- airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +9 -14
- airflow/providers/google/cloud/sensors/dataflow.py +1 -25
- airflow/providers/google/cloud/sensors/dataform.py +1 -4
- airflow/providers/google/cloud/sensors/datafusion.py +1 -7
- airflow/providers/google/cloud/sensors/dataplex.py +1 -31
- airflow/providers/google/cloud/sensors/dataproc.py +1 -16
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -7
- airflow/providers/google/cloud/sensors/gcs.py +5 -27
- airflow/providers/google/cloud/sensors/looker.py +1 -13
- airflow/providers/google/cloud/sensors/pubsub.py +11 -5
- airflow/providers/google/cloud/sensors/workflows.py +1 -4
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +6 -0
- airflow/providers/google/cloud/triggers/dataflow.py +145 -1
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +66 -3
- airflow/providers/google/common/deprecated.py +176 -0
- airflow/providers/google/common/hooks/base_google.py +3 -2
- airflow/providers/google/get_provider_info.py +8 -10
- airflow/providers/google/marketing_platform/hooks/analytics.py +4 -2
- airflow/providers/google/marketing_platform/hooks/search_ads.py +169 -30
- airflow/providers/google/marketing_platform/operators/analytics.py +16 -33
- airflow/providers/google/marketing_platform/operators/search_ads.py +217 -156
- airflow/providers/google/marketing_platform/sensors/display_video.py +1 -4
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0.dist-info}/METADATA +25 -23
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0.dist-info}/RECORD +56 -56
- airflow/providers/google/marketing_platform/sensors/search_ads.py +0 -92
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0.dist-info}/entry_points.txt +0 -0
@@ -21,18 +21,20 @@ from __future__ import annotations
|
|
21
21
|
|
22
22
|
from typing import TYPE_CHECKING, Sequence
|
23
23
|
|
24
|
-
from
|
24
|
+
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
|
25
25
|
|
26
26
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
27
27
|
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
|
28
28
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
29
|
+
from airflow.providers.google.common.deprecated import deprecated
|
29
30
|
|
30
31
|
if TYPE_CHECKING:
|
31
32
|
from airflow.utils.context import Context
|
32
33
|
|
33
34
|
|
34
35
|
@deprecated(
|
35
|
-
|
36
|
+
planned_removal_date="January 01, 2025",
|
37
|
+
use_instead="TextGenerationModelPredictOperator",
|
36
38
|
category=AirflowProviderDeprecationWarning,
|
37
39
|
)
|
38
40
|
class PromptLanguageModelOperator(GoogleCloudBaseOperator):
|
@@ -121,7 +123,8 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator):
|
|
121
123
|
|
122
124
|
|
123
125
|
@deprecated(
|
124
|
-
|
126
|
+
planned_removal_date="January 01, 2025",
|
127
|
+
use_instead="TextEmbeddingModelGetEmbeddingsOperator",
|
125
128
|
category=AirflowProviderDeprecationWarning,
|
126
129
|
)
|
127
130
|
class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
|
@@ -189,7 +192,8 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
189
192
|
|
190
193
|
|
191
194
|
@deprecated(
|
192
|
-
|
195
|
+
planned_removal_date="January 01, 2025",
|
196
|
+
use_instead="GenerativeModelGenerateContentOperator",
|
193
197
|
category=AirflowProviderDeprecationWarning,
|
194
198
|
)
|
195
199
|
class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
|
@@ -265,7 +269,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
|
|
265
269
|
|
266
270
|
|
267
271
|
@deprecated(
|
268
|
-
|
272
|
+
planned_removal_date="January 01, 2025",
|
273
|
+
use_instead="GenerativeModelGenerateContentOperator",
|
269
274
|
category=AirflowProviderDeprecationWarning,
|
270
275
|
)
|
271
276
|
class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
|
@@ -504,12 +509,14 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
504
509
|
|
505
510
|
:param project_id: Required. The ID of the Google Cloud project that the
|
506
511
|
service belongs to (templated).
|
507
|
-
:param contents: Required. The multi-part content of a message that a user or a program
|
508
|
-
gives to the generative model, in order to elicit a specific response.
|
509
512
|
:param location: Required. The ID of the Google Cloud location that the
|
510
513
|
service belongs to (templated).
|
514
|
+
:param contents: Required. The multi-part content of a message that a user or a program
|
515
|
+
gives to the generative model, in order to elicit a specific response.
|
511
516
|
:param generation_config: Optional. Generation configuration settings.
|
512
517
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
518
|
+
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
519
|
+
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
513
520
|
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
|
514
521
|
supporting prompts with text-only input, including natural language
|
515
522
|
tasks, multi-turn text and code chat, and code generation. It can
|
@@ -525,17 +532,18 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
525
532
|
account from the list granting this role to the originating account (templated).
|
526
533
|
"""
|
527
534
|
|
528
|
-
template_fields = ("location", "project_id", "impersonation_chain", "contents")
|
535
|
+
template_fields = ("location", "project_id", "impersonation_chain", "contents", "pretrained_model")
|
529
536
|
|
530
537
|
def __init__(
|
531
538
|
self,
|
532
539
|
*,
|
533
540
|
project_id: str,
|
534
|
-
contents: list,
|
535
541
|
location: str,
|
542
|
+
contents: list,
|
536
543
|
tools: list | None = None,
|
537
544
|
generation_config: dict | None = None,
|
538
545
|
safety_settings: dict | None = None,
|
546
|
+
system_instruction: str | None = None,
|
539
547
|
pretrained_model: str = "gemini-pro",
|
540
548
|
gcp_conn_id: str = "google_cloud_default",
|
541
549
|
impersonation_chain: str | Sequence[str] | None = None,
|
@@ -548,6 +556,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
548
556
|
self.tools = tools
|
549
557
|
self.generation_config = generation_config
|
550
558
|
self.safety_settings = safety_settings
|
559
|
+
self.system_instruction = system_instruction
|
551
560
|
self.pretrained_model = pretrained_model
|
552
561
|
self.gcp_conn_id = gcp_conn_id
|
553
562
|
self.impersonation_chain = impersonation_chain
|
@@ -564,6 +573,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
564
573
|
tools=self.tools,
|
565
574
|
generation_config=self.generation_config,
|
566
575
|
safety_settings=self.safety_settings,
|
576
|
+
system_instruction=self.system_instruction,
|
567
577
|
pretrained_model=self.pretrained_model,
|
568
578
|
)
|
569
579
|
|
@@ -571,3 +581,264 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
571
581
|
self.xcom_push(context, key="model_response", value=response)
|
572
582
|
|
573
583
|
return response
|
584
|
+
|
585
|
+
|
586
|
+
class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
587
|
+
"""
|
588
|
+
Use the Supervised Fine Tuning API to create a tuning job.
|
589
|
+
|
590
|
+
:param project_id: Required. The ID of the Google Cloud project that the
|
591
|
+
service belongs to.
|
592
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
593
|
+
:param source_model: Required. A pre-trained model optimized for performing natural
|
594
|
+
language tasks such as classification, summarization, extraction, content
|
595
|
+
creation, and ideation.
|
596
|
+
:param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
|
597
|
+
must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
|
598
|
+
:param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
|
599
|
+
to 128 characters long and can consist of any UTF-8 characters.
|
600
|
+
:param validation_dataset: Optional. Cloud Storage URI of your training dataset. The dataset must be
|
601
|
+
formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
|
602
|
+
:param epochs: Optional. To optimize performance on a specific dataset, try using a higher
|
603
|
+
epoch value. Increasing the number of epochs might improve results. However, be cautious
|
604
|
+
about over-fitting, especially when dealing with small datasets. If over-fitting occurs,
|
605
|
+
consider lowering the epoch number.
|
606
|
+
:param adapter_size: Optional. Adapter size for tuning.
|
607
|
+
:param learning_multiplier_rate: Optional. Multiplier for adjusting the default learning rate.
|
608
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
609
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
610
|
+
credentials, or chained list of accounts required to get the access_token
|
611
|
+
of the last account in the list, which will be impersonated in the request.
|
612
|
+
If set as a string, the account must grant the originating account
|
613
|
+
the Service Account Token Creator IAM role.
|
614
|
+
If set as a sequence, the identities from the list must grant
|
615
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
616
|
+
account from the list granting this role to the originating account (templated).
|
617
|
+
"""
|
618
|
+
|
619
|
+
template_fields = ("location", "project_id", "impersonation_chain", "train_dataset", "validation_dataset")
|
620
|
+
|
621
|
+
def __init__(
|
622
|
+
self,
|
623
|
+
*,
|
624
|
+
project_id: str,
|
625
|
+
location: str,
|
626
|
+
source_model: str,
|
627
|
+
train_dataset: str,
|
628
|
+
tuned_model_display_name: str | None = None,
|
629
|
+
validation_dataset: str | None = None,
|
630
|
+
epochs: int | None = None,
|
631
|
+
adapter_size: int | None = None,
|
632
|
+
learning_rate_multiplier: float | None = None,
|
633
|
+
gcp_conn_id: str = "google_cloud_default",
|
634
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
635
|
+
**kwargs,
|
636
|
+
) -> None:
|
637
|
+
super().__init__(**kwargs)
|
638
|
+
self.project_id = project_id
|
639
|
+
self.location = location
|
640
|
+
self.source_model = source_model
|
641
|
+
self.train_dataset = train_dataset
|
642
|
+
self.tuned_model_display_name = tuned_model_display_name
|
643
|
+
self.validation_dataset = validation_dataset
|
644
|
+
self.epochs = epochs
|
645
|
+
self.adapter_size = adapter_size
|
646
|
+
self.learning_rate_multiplier = learning_rate_multiplier
|
647
|
+
self.gcp_conn_id = gcp_conn_id
|
648
|
+
self.impersonation_chain = impersonation_chain
|
649
|
+
|
650
|
+
def execute(self, context: Context):
|
651
|
+
self.hook = GenerativeModelHook(
|
652
|
+
gcp_conn_id=self.gcp_conn_id,
|
653
|
+
impersonation_chain=self.impersonation_chain,
|
654
|
+
)
|
655
|
+
response = self.hook.supervised_fine_tuning_train(
|
656
|
+
project_id=self.project_id,
|
657
|
+
location=self.location,
|
658
|
+
source_model=self.source_model,
|
659
|
+
train_dataset=self.train_dataset,
|
660
|
+
validation_dataset=self.validation_dataset,
|
661
|
+
epochs=self.epochs,
|
662
|
+
adapter_size=self.adapter_size,
|
663
|
+
learning_rate_multiplier=self.learning_rate_multiplier,
|
664
|
+
tuned_model_display_name=self.tuned_model_display_name,
|
665
|
+
)
|
666
|
+
|
667
|
+
self.log.info("Tuned Model Name: %s", response.tuned_model_name)
|
668
|
+
self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
|
669
|
+
|
670
|
+
self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
|
671
|
+
self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
|
672
|
+
|
673
|
+
result = {
|
674
|
+
"tuned_model_name": response.tuned_model_name,
|
675
|
+
"tuned_model_endpoint_name": response.tuned_model_endpoint_name,
|
676
|
+
}
|
677
|
+
|
678
|
+
return result
|
679
|
+
|
680
|
+
|
681
|
+
class CountTokensOperator(GoogleCloudBaseOperator):
|
682
|
+
"""
|
683
|
+
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
|
684
|
+
|
685
|
+
:param project_id: Required. The ID of the Google Cloud project that the
|
686
|
+
service belongs to (templated).
|
687
|
+
:param location: Required. The ID of the Google Cloud location that the
|
688
|
+
service belongs to (templated).
|
689
|
+
:param contents: Required. The multi-part content of a message that a user or a program
|
690
|
+
gives to the generative model, in order to elicit a specific response.
|
691
|
+
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
|
692
|
+
supporting prompts with text-only input, including natural language
|
693
|
+
tasks, multi-turn text and code chat, and code generation. It can
|
694
|
+
output text and code.
|
695
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
696
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
697
|
+
credentials, or chained list of accounts required to get the access_token
|
698
|
+
of the last account in the list, which will be impersonated in the request.
|
699
|
+
If set as a string, the account must grant the originating account
|
700
|
+
the Service Account Token Creator IAM role.
|
701
|
+
If set as a sequence, the identities from the list must grant
|
702
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
703
|
+
account from the list granting this role to the originating account (templated).
|
704
|
+
"""
|
705
|
+
|
706
|
+
template_fields = ("location", "project_id", "impersonation_chain", "contents", "pretrained_model")
|
707
|
+
|
708
|
+
def __init__(
|
709
|
+
self,
|
710
|
+
*,
|
711
|
+
project_id: str,
|
712
|
+
location: str,
|
713
|
+
contents: list,
|
714
|
+
pretrained_model: str = "gemini-pro",
|
715
|
+
gcp_conn_id: str = "google_cloud_default",
|
716
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
717
|
+
**kwargs,
|
718
|
+
) -> None:
|
719
|
+
super().__init__(**kwargs)
|
720
|
+
self.project_id = project_id
|
721
|
+
self.location = location
|
722
|
+
self.contents = contents
|
723
|
+
self.pretrained_model = pretrained_model
|
724
|
+
self.gcp_conn_id = gcp_conn_id
|
725
|
+
self.impersonation_chain = impersonation_chain
|
726
|
+
|
727
|
+
def execute(self, context: Context):
|
728
|
+
self.hook = GenerativeModelHook(
|
729
|
+
gcp_conn_id=self.gcp_conn_id,
|
730
|
+
impersonation_chain=self.impersonation_chain,
|
731
|
+
)
|
732
|
+
response = self.hook.count_tokens(
|
733
|
+
project_id=self.project_id,
|
734
|
+
location=self.location,
|
735
|
+
contents=self.contents,
|
736
|
+
pretrained_model=self.pretrained_model,
|
737
|
+
)
|
738
|
+
|
739
|
+
self.log.info("Total tokens: %s", response.total_tokens)
|
740
|
+
self.log.info("Total billable characters: %s", response.total_billable_characters)
|
741
|
+
|
742
|
+
self.xcom_push(context, key="total_tokens", value=response.total_tokens)
|
743
|
+
self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)
|
744
|
+
|
745
|
+
return types_v1beta1.CountTokensResponse.to_dict(response)
|
746
|
+
|
747
|
+
|
748
|
+
class RunEvaluationOperator(GoogleCloudBaseOperator):
|
749
|
+
"""
|
750
|
+
Use the Rapid Evaluation API to evaluate a model.
|
751
|
+
|
752
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
753
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
754
|
+
:param pretrained_model: Required. A pre-trained model optimized for performing natural
|
755
|
+
language tasks such as classification, summarization, extraction, content
|
756
|
+
creation, and ideation.
|
757
|
+
:param eval_dataset: Required. A fixed dataset for evaluating a model against. Adheres to Rapid Evaluation API.
|
758
|
+
:param metrics: Required. A list of evaluation metrics to be used in the experiment. Adheres to Rapid Evaluation API.
|
759
|
+
:param experiment_name: Required. The name of the evaluation experiment.
|
760
|
+
:param experiment_run_name: Required. The specific run name or ID for this experiment.
|
761
|
+
:param prompt_template: Required. The template used to format the model's prompts during evaluation. Adheres to Rapid Evaluation API.
|
762
|
+
:param generation_config: Optional. A dictionary containing generation parameters for the model.
|
763
|
+
:param safety_settings: Optional. A dictionary specifying harm category thresholds for blocking model outputs.
|
764
|
+
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
765
|
+
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
766
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
767
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
768
|
+
credentials, or chained list of accounts required to get the access_token
|
769
|
+
of the last account in the list, which will be impersonated in the request.
|
770
|
+
If set as a string, the account must grant the originating account
|
771
|
+
the Service Account Token Creator IAM role.
|
772
|
+
If set as a sequence, the identities from the list must grant
|
773
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
774
|
+
account from the list granting this role to the originating account (templated).
|
775
|
+
"""
|
776
|
+
|
777
|
+
template_fields = (
|
778
|
+
"location",
|
779
|
+
"project_id",
|
780
|
+
"impersonation_chain",
|
781
|
+
"pretrained_model",
|
782
|
+
"eval_dataset",
|
783
|
+
"prompt_template",
|
784
|
+
"experiment_name",
|
785
|
+
"experiment_run_name",
|
786
|
+
)
|
787
|
+
|
788
|
+
def __init__(
|
789
|
+
self,
|
790
|
+
*,
|
791
|
+
project_id: str,
|
792
|
+
location: str,
|
793
|
+
pretrained_model: str,
|
794
|
+
eval_dataset: dict,
|
795
|
+
metrics: list,
|
796
|
+
experiment_name: str,
|
797
|
+
experiment_run_name: str,
|
798
|
+
prompt_template: str,
|
799
|
+
generation_config: dict | None = None,
|
800
|
+
safety_settings: dict | None = None,
|
801
|
+
system_instruction: str | None = None,
|
802
|
+
tools: list | None = None,
|
803
|
+
gcp_conn_id: str = "google_cloud_default",
|
804
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
805
|
+
**kwargs,
|
806
|
+
) -> None:
|
807
|
+
super().__init__(**kwargs)
|
808
|
+
|
809
|
+
self.project_id = project_id
|
810
|
+
self.location = location
|
811
|
+
self.pretrained_model = pretrained_model
|
812
|
+
self.eval_dataset = eval_dataset
|
813
|
+
self.metrics = metrics
|
814
|
+
self.experiment_name = experiment_name
|
815
|
+
self.experiment_run_name = experiment_run_name
|
816
|
+
self.prompt_template = prompt_template
|
817
|
+
self.generation_config = generation_config
|
818
|
+
self.safety_settings = safety_settings
|
819
|
+
self.system_instruction = system_instruction
|
820
|
+
self.tools = tools
|
821
|
+
self.gcp_conn_id = gcp_conn_id
|
822
|
+
self.impersonation_chain = impersonation_chain
|
823
|
+
|
824
|
+
def execute(self, context: Context):
|
825
|
+
self.hook = GenerativeModelHook(
|
826
|
+
gcp_conn_id=self.gcp_conn_id,
|
827
|
+
impersonation_chain=self.impersonation_chain,
|
828
|
+
)
|
829
|
+
response = self.hook.run_evaluation(
|
830
|
+
project_id=self.project_id,
|
831
|
+
location=self.location,
|
832
|
+
pretrained_model=self.pretrained_model,
|
833
|
+
eval_dataset=self.eval_dataset,
|
834
|
+
metrics=self.metrics,
|
835
|
+
experiment_name=self.experiment_name,
|
836
|
+
experiment_run_name=self.experiment_run_name,
|
837
|
+
prompt_template=self.prompt_template,
|
838
|
+
generation_config=self.generation_config,
|
839
|
+
safety_settings=self.safety_settings,
|
840
|
+
system_instruction=self.system_instruction,
|
841
|
+
tools=self.tools,
|
842
|
+
)
|
843
|
+
|
844
|
+
return response.summary_metrics
|
@@ -362,6 +362,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
362
362
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
363
363
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
364
364
|
:param model: Required. The Model to create.
|
365
|
+
:param parent_model: The name of the parent model to create a new version under.
|
365
366
|
:param retry: Designation of what errors, if any, should be retried.
|
366
367
|
:param timeout: The timeout for this request.
|
367
368
|
:param metadata: Strings which should be sent along with the request as metadata.
|
@@ -385,6 +386,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
385
386
|
project_id: str,
|
386
387
|
region: str,
|
387
388
|
model: Model | dict,
|
389
|
+
parent_model: str | None = None,
|
388
390
|
retry: Retry | _MethodDefault = DEFAULT,
|
389
391
|
timeout: float | None = None,
|
390
392
|
metadata: Sequence[tuple[str, str]] = (),
|
@@ -396,6 +398,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
396
398
|
self.project_id = project_id
|
397
399
|
self.region = region
|
398
400
|
self.model = model
|
401
|
+
self.parent_model = parent_model
|
399
402
|
self.retry = retry
|
400
403
|
self.timeout = timeout
|
401
404
|
self.metadata = metadata
|
@@ -412,6 +415,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
412
415
|
project_id=self.project_id,
|
413
416
|
region=self.region,
|
414
417
|
model=self.model,
|
418
|
+
parent_model=self.parent_model,
|
415
419
|
retry=self.retry,
|
416
420
|
timeout=self.timeout,
|
417
421
|
metadata=self.metadata,
|
@@ -21,7 +21,6 @@ from __future__ import annotations
|
|
21
21
|
import logging
|
22
22
|
from typing import Sequence
|
23
23
|
|
24
|
-
from deprecated import deprecated
|
25
24
|
from google.auth.exceptions import DefaultCredentialsError
|
26
25
|
|
27
26
|
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
@@ -30,6 +29,7 @@ from airflow.providers.google.cloud.utils.credentials_provider import (
|
|
30
29
|
_get_target_principal_and_delegates,
|
31
30
|
get_credentials_and_project_id,
|
32
31
|
)
|
32
|
+
from airflow.providers.google.common.deprecated import deprecated
|
33
33
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
|
34
34
|
from airflow.secrets import BaseSecretsBackend
|
35
35
|
from airflow.utils.log.logging_mixin import LoggingMixin
|
@@ -162,10 +162,8 @@ class CloudSecretManagerBackend(BaseSecretsBackend, LoggingMixin):
|
|
162
162
|
return self._get_secret(self.connections_prefix, conn_id)
|
163
163
|
|
164
164
|
@deprecated(
|
165
|
-
|
166
|
-
|
167
|
-
"in a future release. Please use method `get_conn_value` instead."
|
168
|
-
),
|
165
|
+
planned_removal_date="November 01, 2024",
|
166
|
+
use_instead="get_conn_value",
|
169
167
|
category=AirflowProviderDeprecationWarning,
|
170
168
|
)
|
171
169
|
def get_conn_uri(self, conn_id: str) -> str | None:
|
@@ -23,15 +23,14 @@ import warnings
|
|
23
23
|
from datetime import timedelta
|
24
24
|
from typing import TYPE_CHECKING, Any, Sequence
|
25
25
|
|
26
|
-
from deprecated import deprecated
|
27
|
-
|
28
26
|
from airflow.configuration import conf
|
29
|
-
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
27
|
+
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
30
28
|
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
|
31
29
|
from airflow.providers.google.cloud.triggers.bigquery import (
|
32
30
|
BigQueryTableExistenceTrigger,
|
33
31
|
BigQueryTablePartitionExistenceTrigger,
|
34
32
|
)
|
33
|
+
from airflow.providers.google.common.deprecated import deprecated
|
35
34
|
from airflow.sensors.base import BaseSensorOperator
|
36
35
|
|
37
36
|
if TYPE_CHECKING:
|
@@ -144,15 +143,9 @@ class BigQueryTableExistenceSensor(BaseSensorOperator):
|
|
144
143
|
if event:
|
145
144
|
if event["status"] == "success":
|
146
145
|
return event["message"]
|
147
|
-
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
|
148
|
-
if self.soft_fail:
|
149
|
-
raise AirflowSkipException(event["message"])
|
150
146
|
raise AirflowException(event["message"])
|
151
147
|
|
152
|
-
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
|
153
148
|
message = "No event received in trigger callback"
|
154
|
-
if self.soft_fail:
|
155
|
-
raise AirflowSkipException(message)
|
156
149
|
raise AirflowException(message)
|
157
150
|
|
158
151
|
|
@@ -260,25 +253,16 @@ class BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
|
|
260
253
|
if event["status"] == "success":
|
261
254
|
return event["message"]
|
262
255
|
|
263
|
-
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
|
264
|
-
if self.soft_fail:
|
265
|
-
raise AirflowSkipException(event["message"])
|
266
256
|
raise AirflowException(event["message"])
|
267
257
|
|
268
|
-
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
|
269
258
|
message = "No event received in trigger callback"
|
270
|
-
if self.soft_fail:
|
271
|
-
raise AirflowSkipException(message)
|
272
259
|
raise AirflowException(message)
|
273
260
|
|
274
261
|
|
275
262
|
@deprecated(
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
"Please use `BigQueryTableExistenceSensor` and "
|
280
|
-
"set `deferrable` attribute to `True` instead"
|
281
|
-
),
|
263
|
+
planned_removal_date="November 01, 2024",
|
264
|
+
use_instead="BigQueryTableExistenceSensor",
|
265
|
+
instructions="Please use BigQueryTableExistenceSensor and set deferrable attribute to True.",
|
282
266
|
category=AirflowProviderDeprecationWarning,
|
283
267
|
)
|
284
268
|
class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor):
|
@@ -315,12 +299,9 @@ class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor):
|
|
315
299
|
|
316
300
|
|
317
301
|
@deprecated(
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
"Please use `BigQueryTablePartitionExistenceSensor` and "
|
322
|
-
"set `deferrable` attribute to `True` instead"
|
323
|
-
),
|
302
|
+
planned_removal_date="November 01, 2024",
|
303
|
+
use_instead="BigQueryTablePartitionExistenceSensor",
|
304
|
+
instructions="Please use BigQueryTablePartitionExistenceSensor and set deferrable attribute to True.",
|
324
305
|
category=AirflowProviderDeprecationWarning,
|
325
306
|
)
|
326
307
|
class BigQueryTableExistencePartitionAsyncSensor(BigQueryTablePartitionExistenceSensor):
|
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Sequence
|
|
24
24
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
25
25
|
from google.cloud.bigquery_datatransfer_v1 import TransferState
|
26
26
|
|
27
|
-
from airflow.exceptions import AirflowException
|
27
|
+
from airflow.exceptions import AirflowException
|
28
28
|
from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
|
29
29
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
|
30
30
|
from airflow.sensors.base import BaseSensorOperator
|
@@ -142,9 +142,6 @@ class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
|
|
142
142
|
self.log.info("Status of %s run: %s", self.run_id, run.state)
|
143
143
|
|
144
144
|
if run.state in (TransferState.FAILED, TransferState.CANCELLED):
|
145
|
-
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
|
146
145
|
message = f"Transfer {self.run_id} did not succeed"
|
147
|
-
if self.soft_fail:
|
148
|
-
raise AirflowSkipException(message)
|
149
146
|
raise AirflowException(message)
|
150
147
|
return run.state in self.expected_statuses
|
@@ -24,17 +24,17 @@ from datetime import datetime, timedelta
|
|
24
24
|
from typing import TYPE_CHECKING, Any, Iterable, Sequence
|
25
25
|
|
26
26
|
from dateutil import parser
|
27
|
-
from deprecated import deprecated
|
28
27
|
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
|
29
28
|
|
30
29
|
from airflow.configuration import conf
|
31
|
-
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
30
|
+
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
32
31
|
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
|
33
32
|
from airflow.providers.google.cloud.triggers.cloud_composer import (
|
34
33
|
CloudComposerDAGRunTrigger,
|
35
34
|
CloudComposerExecutionTrigger,
|
36
35
|
)
|
37
36
|
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
|
37
|
+
from airflow.providers.google.common.deprecated import deprecated
|
38
38
|
from airflow.sensors.base import BaseSensorOperator
|
39
39
|
from airflow.utils.state import TaskInstanceState
|
40
40
|
|
@@ -43,12 +43,13 @@ if TYPE_CHECKING:
|
|
43
43
|
|
44
44
|
|
45
45
|
@deprecated(
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
)
|
46
|
+
planned_removal_date="November 01, 2024",
|
47
|
+
use_instead="CloudComposerCreateEnvironmentOperator, CloudComposerDeleteEnvironmentOperator, "
|
48
|
+
"CloudComposerUpdateEnvironmentOperator",
|
49
|
+
instructions="Please use CloudComposerCreateEnvironmentOperator, CloudComposerDeleteEnvironmentOperator "
|
50
|
+
"or CloudComposerUpdateEnvironmentOperator in deferrable or non-deferrable mode, "
|
51
|
+
"since since every operator gives user a possibility to wait (asynchronously or synchronously) "
|
52
|
+
"until the Operation is finished.",
|
52
53
|
category=AirflowProviderDeprecationWarning,
|
53
54
|
)
|
54
55
|
class CloudComposerEnvironmentSensor(BaseSensorOperator):
|
@@ -118,15 +119,9 @@ class CloudComposerEnvironmentSensor(BaseSensorOperator):
|
|
118
119
|
if event.get("operation_done"):
|
119
120
|
return event["operation_done"]
|
120
121
|
|
121
|
-
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
|
122
|
-
if self.soft_fail:
|
123
|
-
raise AirflowSkipException(event["message"])
|
124
122
|
raise AirflowException(event["message"])
|
125
123
|
|
126
|
-
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
|
127
124
|
message = "No event received in trigger callback"
|
128
|
-
if self.soft_fail:
|
129
|
-
raise AirflowSkipException(message)
|
130
125
|
raise AirflowException(message)
|
131
126
|
|
132
127
|
|