apache-airflow-providers-standard 0.1.1rc1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registries.

Potentially problematic release: this version of apache-airflow-providers-standard might be problematic.

Files changed (40)
  1. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/PKG-INFO +8 -8
  2. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/README.rst +3 -3
  3. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/pyproject.toml +6 -6
  4. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/__init__.py +1 -1
  5. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/get_provider_info.py +3 -2
  6. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/bash.py +11 -48
  7. apache_airflow_providers_standard-0.2.0/src/airflow/providers/standard/operators/branch.py +105 -0
  8. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/datetime.py +1 -1
  9. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/latest_only.py +1 -1
  10. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/python.py +46 -28
  11. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/trigger_dagrun.py +52 -22
  12. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/weekday.py +3 -3
  13. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/time_delta.py +30 -15
  14. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/weekday.py +17 -3
  15. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2 +5 -0
  16. apache_airflow_providers_standard-0.2.0/src/airflow/providers/standard/utils/skipmixin.py +192 -0
  17. apache_airflow_providers_standard-0.2.0/src/airflow/providers/standard/utils/weekday.py +77 -0
  18. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/LICENSE +0 -0
  19. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/hooks/__init__.py +0 -0
  20. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/hooks/filesystem.py +0 -0
  21. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/hooks/package_index.py +0 -0
  22. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/hooks/subprocess.py +0 -0
  23. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/__init__.py +0 -0
  24. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/empty.py +0 -0
  25. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/operators/smooth.py +0 -0
  26. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/__init__.py +0 -0
  27. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/bash.py +0 -0
  28. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/date_time.py +0 -0
  29. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/external_task.py +0 -0
  30. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/filesystem.py +0 -0
  31. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/python.py +0 -0
  32. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/sensors/time.py +0 -0
  33. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/triggers/__init__.py +0 -0
  34. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/triggers/external_task.py +0 -0
  35. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/triggers/file.py +0 -0
  36. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/triggers/temporal.py +0 -0
  37. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/utils/__init__.py +0 -0
  38. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/utils/python_virtualenv.py +0 -0
  39. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/utils/sensor_helper.py +0 -0
  40. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0}/src/airflow/providers/standard/version_compat.py +0 -0
PKG-INFO +8 -8

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: apache-airflow-providers-standard
- Version: 0.1.1rc1
+ Version: 0.2.0
  Summary: Provider package apache-airflow-providers-standard for Apache Airflow
  Keywords: airflow-provider,standard,airflow,integration
  Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -20,13 +20,13 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Topic :: System :: Monitoring
- Requires-Dist: apache-airflow>=2.9.0rc0
+ Requires-Dist: apache-airflow>=2.9.0
  Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
- Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html
- Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1
+ Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0/changelog.html
+ Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0
+ Project-URL: Mastodon, https://fosstodon.org/@airflow
  Project-URL: Slack Chat, https://s.apache.org/airflow-slack
  Project-URL: Source Code, https://github.com/apache/airflow
- Project-URL: Twitter, https://x.com/ApacheAirflow
  Project-URL: YouTube, https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/


@@ -54,7 +54,7 @@ Project-URL: YouTube, https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/

  Package ``apache-airflow-providers-standard``

- Release: ``0.1.1``
+ Release: ``0.2.0``


  Airflow Standard Provider
@@ -67,7 +67,7 @@ This is a provider package for ``standard`` provider. All classes for this provi
  are in ``airflow.providers.standard`` python package.

  You can find package information and changelog for the provider
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/>`_.
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0/>`_.

  Installation
  ------------
@@ -88,5 +88,5 @@ PIP package Version required
  ================== ==================

  The changelog for the provider package can be found in the
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html>`_.
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0/changelog.html>`_.

README.rst +3 -3

@@ -23,7 +23,7 @@

  Package ``apache-airflow-providers-standard``

- Release: ``0.1.1``
+ Release: ``0.2.0``


  Airflow Standard Provider
@@ -36,7 +36,7 @@ This is a provider package for ``standard`` provider. All classes for this provi
  are in ``airflow.providers.standard`` python package.

  You can find package information and changelog for the provider
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/>`_.
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0/>`_.

  Installation
  ------------
@@ -57,4 +57,4 @@ PIP package Version required
  ================== ==================

  The changelog for the provider package can be found in the
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html>`_.
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0/changelog.html>`_.
pyproject.toml +6 -6

@@ -20,12 +20,12 @@
  # IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE
  # `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY
  [build-system]
- requires = ["flit_core==3.11.0"]
+ requires = ["flit_core==3.12.0"]
  build-backend = "flit_core.buildapi"

  [project]
  name = "apache-airflow-providers-standard"
- version = "0.1.1.rc1"
+ version = "0.2.0"
  description = "Provider package apache-airflow-providers-standard for Apache Airflow"
  readme = "README.rst"
  authors = [
@@ -57,7 +57,7 @@ requires-python = "~=3.9"
  # Make sure to run ``breeze static-checks --type update-providers-dependencies --all-files``
  # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
  dependencies = [
- "apache-airflow>=2.9.0rc0",
+ "apache-airflow>=2.9.0",
  ]

  [dependency-groups]
@@ -79,12 +79,12 @@ apache-airflow-providers-fab = {workspace = true}
  apache-airflow-providers-standard = {workspace = true}

  [project.urls]
- "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1"
- "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html"
+ "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0"
+ "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0/changelog.html"
  "Bug Tracker" = "https://github.com/apache/airflow/issues"
  "Source Code" = "https://github.com/apache/airflow"
  "Slack Chat" = "https://s.apache.org/airflow-slack"
- "Twitter" = "https://x.com/ApacheAirflow"
+ "Mastodon" = "https://fosstodon.org/@airflow"
  "YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/"

  [project.entry-points."apache_airflow_provider"]
src/airflow/providers/standard/__init__.py +1 -1

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

  __all__ = ["__version__"]

- __version__ = "0.1.1"
+ __version__ = "0.2.0"

  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
  "2.9.0"
src/airflow/providers/standard/get_provider_info.py +3 -2

@@ -27,8 +27,8 @@ def get_provider_info():
  "name": "Standard",
  "description": "Airflow Standard Provider\n",
  "state": "ready",
- "source-date-epoch": 1741509906,
- "versions": ["0.1.1", "0.1.0", "0.0.3", "0.0.2", "0.0.1"],
+ "source-date-epoch": 1742480519,
+ "versions": ["0.2.0", "0.1.1", "0.1.0", "0.0.3", "0.0.2", "0.0.1"],
  "integrations": [
  {
  "integration-name": "Standard",
@@ -53,6 +53,7 @@ def get_provider_info():
  "airflow.providers.standard.operators.trigger_dagrun",
  "airflow.providers.standard.operators.latest_only",
  "airflow.providers.standard.operators.smooth",
+ "airflow.providers.standard.operators.branch",
  ],
  }
  ],
src/airflow/providers/standard/operators/bash.py +11 -48

@@ -27,12 +27,15 @@ from typing import TYPE_CHECKING, Any, Callable, cast
  from airflow.exceptions import AirflowException, AirflowSkipException
  from airflow.models.baseoperator import BaseOperator
  from airflow.providers.standard.hooks.subprocess import SubprocessHook, SubprocessResult, working_directory
- from airflow.utils.operator_helpers import context_to_airflow_vars
- from airflow.utils.session import NEW_SESSION, provide_session
- from airflow.utils.types import ArgNotSet
+ from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
+
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk.execution_time.context import context_to_airflow_vars
+ else:
+ from airflow.utils.operator_helpers import context_to_airflow_vars # type: ignore[no-redef, attr-defined]

  if TYPE_CHECKING:
- from sqlalchemy.orm import Session as SASession
+ from airflow.utils.types import ArgNotSet

  try:
  from airflow.sdk.definitions.context import Context
@@ -182,43 +185,15 @@ class BashOperator(BaseOperator):
  self.cwd = cwd
  self.append_env = append_env
  self.output_processor = output_processor
-
- # When using the @task.bash decorator, the Bash command is not known until the underlying Python
- # callable is executed and therefore set to NOTSET initially. This flag is useful during execution to
- # determine whether the bash_command value needs to re-rendered.
- self._init_bash_command_not_set = isinstance(self.bash_command, ArgNotSet)
-
- # Keep a copy of the original bash_command, without the Jinja template rendered.
- # This is later used to determine if the bash_command is a script or an inline string command.
- # We do this later, because the bash_command is not available in __init__ when using @task.bash.
- self._unrendered_bash_command: str | ArgNotSet = bash_command
+ self._is_inline_cmd = None
+ if isinstance(bash_command, str):
+ self._is_inline_cmd = self._is_inline_command(bash_command=bash_command)

  @cached_property
  def subprocess_hook(self):
  """Returns hook for running the bash command."""
  return SubprocessHook()

- # TODO: This should be replaced with Task SDK API call
- @staticmethod
- @provide_session
- def refresh_bash_command(ti, session: SASession = NEW_SESSION) -> None:
- """
- Rewrite the underlying rendered bash_command value for a task instance in the metadatabase.
-
- TaskInstance.get_rendered_template_fields() cannot be used because this will retrieve the
- RenderedTaskInstanceFields from the metadatabase which doesn't have the runtime-evaluated bash_command
- value.
-
- :meta private:
- """
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
-
- """Update rendered task instance fields for cases where runtime evaluated, not templated."""
-
- rtif = RenderedTaskInstanceFields(ti)
- RenderedTaskInstanceFields.write(rtif, session=session)
- RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
-
  def get_env(self, context) -> dict:
  """Build the set of environment variables to be exposed for the bash command."""
  system_env = os.environ.copy()
@@ -247,19 +222,7 @@ class BashOperator(BaseOperator):
  raise AirflowException(f"The cwd {self.cwd} must be a directory")
  env = self.get_env(context)

- # Because the bash_command value is evaluated at runtime using the @task.bash decorator, the
- # RenderedTaskInstanceField data needs to be rewritten and the bash_command value re-rendered -- the
- # latter because the returned command from the decorated callable could contain a Jinja expression.
- # Both will ensure the correct Bash command is executed and that the Rendered Template view in the UI
- # displays the executed command (otherwise it will display as an ArgNotSet type).
- if self._init_bash_command_not_set:
- is_inline_command = self._is_inline_command(bash_command=cast(str, self.bash_command))
- ti = context["ti"]
- self.refresh_bash_command(ti)
- else:
- is_inline_command = self._is_inline_command(bash_command=cast(str, self._unrendered_bash_command))
-
- if is_inline_command:
+ if self._is_inline_cmd:
  result = self._run_inline_command(bash_path=bash_path, env=env)
  else:
  result = self._run_rendered_script_file(bash_path=bash_path, env=env)
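
For context on the simplification above: in 0.2.0 the operator decides once, in __init__, whether a plain-string bash_command is an inline command or a script reference (_is_inline_cmd), instead of re-deriving it at execute() time as 0.1.1 did for @task.bash. A minimal usage sketch (task id and command are illustrative, not part of the release):

from airflow.providers.standard.operators.bash import BashOperator

# A plain-string command is classified by _is_inline_command() when the operator is built.
hello = BashOperator(task_id="say_hello", bash_command="echo 'hello from the standard provider'")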
src/airflow/providers/standard/operators/branch.py (new file) +105 -0

@@ -0,0 +1,105 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ """Branching operators."""
+
+ from __future__ import annotations
+
+ from collections.abc import Iterable
+ from typing import TYPE_CHECKING
+
+ from airflow.models.baseoperator import BaseOperator
+ from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
+
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.providers.standard.utils.skipmixin import SkipMixin
+ else:
+ from airflow.models.skipmixin import SkipMixin
+
+ if TYPE_CHECKING:
+ from airflow.sdk.definitions.context import Context
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol
+
+
+ class BranchMixIn(SkipMixin):
+ """Utility helper which handles the branching as one-liner."""
+
+ def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
+ """Implement the handling of branching including logging."""
+ self.log.info("Branch into %s", branches_to_execute)
+ branch_task_ids = self._expand_task_group_roots(context["ti"], branches_to_execute)
+ self.skip_all_except(context["ti"], branch_task_ids)
+ return branches_to_execute
+
+ def _expand_task_group_roots(
+ self, ti: RuntimeTaskInstanceProtocol, branches_to_execute: str | Iterable[str]
+ ) -> Iterable[str]:
+ """Expand any task group into its root task ids."""
+ if TYPE_CHECKING:
+ assert ti.task
+
+ task = ti.task
+ dag = task.dag
+ if TYPE_CHECKING:
+ assert dag
+
+ if branches_to_execute is None:
+ return
+ elif isinstance(branches_to_execute, str) or not isinstance(branches_to_execute, Iterable):
+ branches_to_execute = [branches_to_execute]
+
+ for branch in branches_to_execute:
+ if branch in dag.task_group_dict:
+ tg = dag.task_group_dict[branch]
+ root_ids = [root.task_id for root in tg.roots]
+ self.log.info("Expanding task group %s into %s", tg.group_id, root_ids)
+ yield from root_ids
+ else:
+ yield branch
+
+
+ class BaseBranchOperator(BaseOperator, BranchMixIn):
+ """
+ A base class for creating operators with branching functionality, like to BranchPythonOperator.
+
+ Users should create a subclass from this operator and implement the function
+ `choose_branch(self, context)`. This should run whatever business logic
+ is needed to determine the branch, and return one of the following:
+ - A single task_id (as a str)
+ - A single task_group_id (as a str)
+ - A list containing a combination of task_ids and task_group_ids
+
+ The operator will continue with the returned task_id(s) and/or task_group_id(s), and all other
+ tasks directly downstream of this operator will be skipped.
+ """
+
+ inherits_from_skipmixin = True
+
+ def choose_branch(self, context: Context) -> str | Iterable[str]:
+ """
+ Abstract method to choose which branch to run.
+
+ Subclasses should implement this, running whatever logic is
+ necessary to choose a branch and returning a task_id or list of
+ task_ids.
+
+ :param context: Context dictionary as passed to execute()
+ """
+ raise NotImplementedError
+
+ def execute(self, context: Context):
+ return self.do_branch(context, self.choose_branch(context))
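
A short sketch of how the new operator is intended to be used, following the docstring above (the class and task ids are hypothetical, not part of the release):

from airflow.providers.standard.operators.branch import BaseBranchOperator


class WeekendBranchOperator(BaseBranchOperator):
    """Follow 'weekend_task' on Saturday/Sunday, otherwise 'weekday_task'."""

    def choose_branch(self, context):
        # Return a task_id, a task_group_id, or a list of them; everything else
        # directly downstream of this operator gets skipped via skip_all_except().
        if context["logical_date"].isoweekday() >= 6:
            return "weekend_task"
        return "weekday_task"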
src/airflow/providers/standard/operators/datetime.py +1 -1

@@ -21,7 +21,7 @@ from collections.abc import Iterable
  from typing import TYPE_CHECKING

  from airflow.exceptions import AirflowException
- from airflow.operators.branch import BaseBranchOperator
+ from airflow.providers.standard.operators.branch import BaseBranchOperator
  from airflow.utils import timezone

  if TYPE_CHECKING:
src/airflow/providers/standard/operators/latest_only.py +1 -1

@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING

  import pendulum

- from airflow.operators.branch import BaseBranchOperator
+ from airflow.providers.standard.operators.branch import BaseBranchOperator
  from airflow.utils.types import DagRunType

  if TYPE_CHECKING:
src/airflow/providers/standard/operators/python.py +46 -28

@@ -42,9 +42,7 @@ from airflow.exceptions import (
  DeserializingResultError,
  )
  from airflow.models.baseoperator import BaseOperator
- from airflow.models.skipmixin import SkipMixin
  from airflow.models.variable import Variable
- from airflow.operators.branch import BranchMixIn
  from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script
  from airflow.providers.standard.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
  from airflow.utils import hashlib_wrapper
@@ -53,6 +51,14 @@ from airflow.utils.file import get_unique_dag_module_name
  from airflow.utils.operator_helpers import KeywordParameters
  from airflow.utils.process_utils import execute_in_subprocess, execute_in_subprocess_with_kwargs

+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.providers.standard.operators.branch import BranchMixIn
+ from airflow.providers.standard.utils.skipmixin import SkipMixin
+ else:
+ from airflow.models.skipmixin import SkipMixin
+ from airflow.operators.branch import BranchMixIn
+
+
  log = logging.getLogger(__name__)

  if TYPE_CHECKING:
@@ -60,10 +66,12 @@ if TYPE_CHECKING:

  from pendulum.datetime import DateTime

+ from airflow.sdk.execution_time.callback_runner import ExecutionCallableRunner
+ from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol
+
  try:
  from airflow.sdk.definitions.context import Context
- except ImportError:
- # TODO: Remove once provider drops support for Airflow 2
+ except ImportError: # TODO: Remove once provider drops support for Airflow 2
  from airflow.utils.context import Context

  _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
@@ -184,14 +192,22 @@ class PythonOperator(BaseOperator):
  context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
  self.op_kwargs = self.determine_kwargs(context)

- if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.context import context_get_outlet_events
+ # This needs to be lazy because subclasses may implement execute_callable
+ # by running a separate process that can't use the eager result.
+ def __prepare_execution() -> tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None:
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk.execution_time.callback_runner import create_executable_runner
+ from airflow.sdk.execution_time.context import context_get_outlet_events

- self._asset_events = context_get_outlet_events(context)
- elif AIRFLOW_V_2_10_PLUS:
- from airflow.utils.context import context_get_outlet_events
+ return create_executable_runner, context_get_outlet_events(context)
+ if AIRFLOW_V_2_10_PLUS:
+ from airflow.utils.context import context_get_outlet_events # type: ignore
+ from airflow.utils.operator_helpers import ExecutionCallableRunner # type: ignore

- self._dataset_events = context_get_outlet_events(context)
+ return ExecutionCallableRunner, context_get_outlet_events(context)
+ return None
+
+ self.__prepare_execution = __prepare_execution

  return_value = self.execute_callable()
  if self.show_return_value_in_logs:
@@ -204,21 +220,18 @@ class PythonOperator(BaseOperator):
  def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
  return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()

+ __prepare_execution: Callable[[], tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None]
+
  def execute_callable(self) -> Any:
  """
  Call the python callable with the given arguments.

  :return: the return value of the call.
  """
- try:
- from airflow.utils.operator_helpers import ExecutionCallableRunner
-
- asset_events = self._asset_events if AIRFLOW_V_3_0_PLUS else self._dataset_events
-
- runner = ExecutionCallableRunner(self.python_callable, asset_events, logger=self.log)
- except ImportError:
- # Handle Pre Airflow 3.10 case where ExecutionCallableRunner was not available
+ if (execution_preparation := self.__prepare_execution()) is None:
  return self.python_callable(*self.op_args, **self.op_kwargs)
+ create_execution_runner, asset_events = execution_preparation
+ runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
  return runner.run(*self.op_args, **self.op_kwargs)


@@ -236,6 +249,8 @@ class BranchPythonOperator(PythonOperator, BranchMixIn):
  the DAG run's state to be inferred.
  """

+ inherits_from_skipmixin = True
+
  def execute(self, context: Context) -> Any:
  return self.do_branch(context, super().execute(context))

@@ -266,6 +281,8 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
  skipped but the ``trigger_rule`` defined for all other downstream tasks will be respected.
  """

+ inherits_from_skipmixin = True
+
  def __init__(self, *, ignore_downstream_trigger_rules: bool = True, **kwargs) -> None:
  super().__init__(**kwargs)
  self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules
@@ -295,25 +312,24 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):

  to_skip = get_tasks_to_skip()

- # this let's us avoid an intermediate list unless debug logging
+ # this lets us avoid an intermediate list unless debug logging
  if self.log.getEffectiveLevel() <= logging.DEBUG:
  self.log.debug("Downstream task IDs %s", to_skip := list(get_tasks_to_skip()))

  self.log.info("Skipping downstream tasks")
  if AIRFLOW_V_3_0_PLUS:
  self.skip(
- dag_id=dag_run.dag_id,
- run_id=dag_run.run_id,
+ ti=context["ti"],
  tasks=to_skip,
- map_index=context["ti"].map_index,
  )
  else:
- self.skip(
- dag_run=dag_run,
- tasks=to_skip,
- execution_date=cast("DateTime", dag_run.logical_date), # type: ignore[call-arg, union-attr]
- map_index=context["ti"].map_index,
- )
+ if to_skip:
+ self.skip(
+ dag_run=context["dag_run"],
+ tasks=to_skip,
+ execution_date=cast("DateTime", dag_run.logical_date), # type: ignore[call-arg, union-attr]
+ map_index=context["ti"].map_index,
+ )

  self.log.info("Done.")
  # returns the result of the super execute method as it is instead of returning None
@@ -868,6 +884,8 @@ class BranchPythonVirtualenvOperator(PythonVirtualenvOperator, BranchMixIn):
  :ref:`howto/operator:BranchPythonVirtualenvOperator`
  """

+ inherits_from_skipmixin = True
+
  def execute(self, context: Context) -> Any:
  return self.do_branch(context, super().execute(context))
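
The operators touched in this file (BranchPythonOperator, ShortCircuitOperator, BranchPythonVirtualenvOperator) keep their public API; only the skip plumbing and the lazy runner preparation change. A minimal usage sketch with hypothetical task ids:

from airflow.providers.standard.operators.python import BranchPythonOperator, ShortCircuitOperator

branch = BranchPythonOperator(
    task_id="branch",
    python_callable=lambda: "fast_path",  # return the task_id(s) to follow
)
guard = ShortCircuitOperator(
    task_id="guard",
    python_callable=lambda: False,  # a falsy result skips the downstream tasks
    ignore_downstream_trigger_rules=True,
)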
src/airflow/providers/standard/operators/trigger_dagrun.py +52 -22

@@ -38,13 +38,12 @@ from airflow.models import BaseOperator
  from airflow.models.dag import DagModel
  from airflow.models.dagbag import DagBag
  from airflow.models.dagrun import DagRun
- from airflow.models.xcom import XCom
  from airflow.providers.standard.triggers.external_task import DagStateTrigger
  from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.utils import timezone
  from airflow.utils.session import provide_session
  from airflow.utils.state import DagRunState
- from airflow.utils.types import DagRunTriggeredByType, DagRunType
+ from airflow.utils.types import DagRunType

  XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
  XCOM_RUN_ID = "trigger_run_id"
@@ -63,7 +62,9 @@ if TYPE_CHECKING:

  if AIRFLOW_V_3_0_PLUS:
  from airflow.sdk import BaseOperatorLink
+ from airflow.sdk.execution_time.xcom import XCom
  else:
+ from airflow.models import XCom # type: ignore[no-redef]
  from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef]


@@ -77,15 +78,15 @@ class TriggerDagRunLink(BaseOperatorLink):
  name = "Triggered DAG"

  def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
-
  if TYPE_CHECKING:
  assert isinstance(operator, TriggerDagRunOperator)

- if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
- trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id)
- else:
- trigger_dag_id = operator.trigger_dag_id
+ trigger_dag_id = operator.trigger_dag_id
+ if not AIRFLOW_V_3_0_PLUS:
+ from airflow.models.renderedtifields import RenderedTaskInstanceFields
+
+ if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
+ trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef]

  # Fetch the correct dag_run_id for the triggerED dag which is
  # stored in xcom during execution of the triggerING task.
@@ -173,7 +174,7 @@ class TriggerDagRunOperator(BaseOperator):
  self.allowed_states = [DagRunState(s) for s in allowed_states]
  else:
  self.allowed_states = [DagRunState.SUCCESS]
- if failed_states or failed_states == []:
+ if failed_states is not None:
  self.failed_states = [DagRunState(s) for s in failed_states]
  else:
  self.failed_states = [DagRunState.FAILED]
@@ -197,25 +198,51 @@ class TriggerDagRunOperator(BaseOperator):
  try:
  json.dumps(self.conf)
  except TypeError:
- raise AirflowException("conf parameter should be JSON Serializable")
+ raise ValueError("conf parameter should be JSON Serializable")

  if self.trigger_run_id:
  run_id = str(self.trigger_run_id)
  else:
- run_id = DagRun.generate_run_id(
- run_type=DagRunType.MANUAL,
- logical_date=parsed_logical_date,
- run_after=parsed_logical_date or timezone.utcnow(),
- )
+ if AIRFLOW_V_3_0_PLUS:
+ run_id = DagRun.generate_run_id(
+ run_type=DagRunType.MANUAL,
+ logical_date=parsed_logical_date,
+ run_after=parsed_logical_date or timezone.utcnow(),
+ )
+ else:
+ run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_logical_date or timezone.utcnow()) # type: ignore[misc,call-arg]

+ if AIRFLOW_V_3_0_PLUS:
+ self._trigger_dag_af_3(context=context, run_id=run_id, parsed_logical_date=parsed_logical_date)
+ else:
+ self._trigger_dag_af_2(context=context, run_id=run_id, parsed_logical_date=parsed_logical_date)
+
+ def _trigger_dag_af_3(self, context, run_id, parsed_logical_date):
+ from airflow.exceptions import DagRunTriggerException
+
+ raise DagRunTriggerException(
+ trigger_dag_id=self.trigger_dag_id,
+ dag_run_id=run_id,
+ conf=self.conf,
+ logical_date=parsed_logical_date,
+ reset_dag_run=self.reset_dag_run,
+ skip_when_already_exists=self.skip_when_already_exists,
+ wait_for_completion=self.wait_for_completion,
+ allowed_states=self.allowed_states,
+ failed_states=self.failed_states,
+ poke_interval=self.poke_interval,
+ )
+
+ # TODO: Support deferral
+
+ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date):
  try:
  dag_run = trigger_dag(
  dag_id=self.trigger_dag_id,
  run_id=run_id,
  conf=self.conf,
- logical_date=parsed_logical_date,
+ execution_date=parsed_logical_date,
  replace_microseconds=False,
- triggered_by=DagRunTriggeredByType.OPERATOR,
  )

  except DagRunAlreadyExists as e:
@@ -231,7 +258,7 @@ class TriggerDagRunOperator(BaseOperator):
  # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
  dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
  dag = dag_bag.get_dag(self.trigger_dag_id)
- dag.clear(run_id=dag_run.run_id)
+ dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
  else:
  if self.skip_when_already_exists:
  raise AirflowSkipException(
@@ -252,6 +279,7 @@ class TriggerDagRunOperator(BaseOperator):
  trigger=DagStateTrigger(
  dag_id=self.trigger_dag_id,
  states=self.allowed_states + self.failed_states,
+ execution_dates=[dag_run.logical_date],
  run_ids=[run_id],
  poll_interval=self.poke_interval,
  ),
@@ -278,16 +306,18 @@ class TriggerDagRunOperator(BaseOperator):

  @provide_session
  def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]):
- # This run_ids is parsed from the return trigger event
- provided_run_id = event[1]["run_ids"][0]
+ # This logical_date is parsed from the return trigger event
+ provided_logical_date = event[1]["execution_dates"][0]
  try:
  # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
  dag_run = session.execute(
- select(DagRun).where(DagRun.dag_id == self.trigger_dag_id, DagRun.run_id == provided_run_id)
+ select(DagRun).where(
+ DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
+ )
  ).scalar_one()
  except NoResultFound:
  raise AirflowException(
- f"No DAG run found for DAG {self.trigger_dag_id} and run ID {provided_run_id}"
+ f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}"
  )

  state = dag_run.state
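
Public usage of the operator is unchanged by the split into _trigger_dag_af_3/_trigger_dag_af_2; a small sketch (dag ids and conf values are illustrative):

from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator

trigger = TriggerDagRunOperator(
    task_id="trigger_downstream",
    trigger_dag_id="downstream_dag",
    conf={"source": "upstream_dag"},  # must be JSON serializable, per the check above
    wait_for_completion=True,
    poke_interval=30,
)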
src/airflow/providers/standard/operators/weekday.py +3 -3

@@ -20,9 +20,9 @@ from __future__ import annotations
  from collections.abc import Iterable
  from typing import TYPE_CHECKING

- from airflow.operators.branch import BaseBranchOperator
+ from airflow.providers.standard.operators.branch import BaseBranchOperator
+ from airflow.providers.standard.utils.weekday import WeekDay
  from airflow.utils import timezone
- from airflow.utils.weekday import WeekDay

  if TYPE_CHECKING:
  try:
@@ -63,7 +63,7 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
  .. code-block:: python

  # import WeekDay Enum
- from airflow.utils.weekday import WeekDay
+ from airflow.providers.standard.utils.weekday import WeekDay
  from airflow.providers.standard.operators.empty import EmptyOperator
  from airflow.operators.weekday import BranchDayOfWeekOperator

src/airflow/providers/standard/sensors/time_delta.py +30 -15

@@ -46,29 +46,48 @@ def _get_airflow_version():

  class TimeDeltaSensor(BaseSensorOperator):
  """
- Waits for a timedelta after the run's data interval.
+ Waits for a timedelta.

- :param delta: time length to wait after the data interval before succeeding.
+ The delta will be evaluated against data_interval_end if present for the dag run,
+ otherwise run_after will be used.
+
+ :param delta: time to wait before succeeding.

  .. seealso::
  For more information on how to use this sensor, take a look at the guide:
  :ref:`howto/operator:TimeDeltaSensor`

-
  """

  def __init__(self, *, delta, **kwargs):
  super().__init__(**kwargs)
  self.delta = delta

- def poke(self, context: Context):
- data_interval_end = context["data_interval_end"]
+ def _derive_base_time(self, context: Context) -> datetime:
+ """
+ Get the "base time" against which the delta should be calculated.
+
+ If data_interval_end is populated, use it; else use run_after.
+ """
+ data_interval_end = context.get("data_interval_end")
+ if data_interval_end:
+ if not isinstance(data_interval_end, datetime):
+ raise ValueError("`data_interval_end` returned non-datetime object")

- if not isinstance(data_interval_end, datetime):
- raise ValueError("`data_interval_end` returned non-datetime object")
+ return data_interval_end

- target_dttm: datetime = data_interval_end + self.delta
- self.log.info("Checking if the time (%s) has come", target_dttm)
+ if not data_interval_end and not AIRFLOW_V_3_0_PLUS:
+ raise ValueError("`data_interval_end` not found in task context.")
+
+ dag_run = context.get("dag_run")
+ if not dag_run:
+ raise ValueError("`dag_run` not found in task context")
+ return dag_run.run_after
+
+ def poke(self, context: Context) -> bool:
+ base_time = self._derive_base_time(context=context)
+ target_dttm = base_time + self.delta
+ self.log.info("Checking if the delta has elapsed base_time=%s, delta=%s", base_time, self.delta)
  return timezone.utcnow() > target_dttm


@@ -92,12 +111,8 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):
  self.end_from_trigger = end_from_trigger

  def execute(self, context: Context) -> bool | NoReturn:
- data_interval_end = context["data_interval_end"]
-
- if not isinstance(data_interval_end, datetime):
- raise ValueError("`data_interval_end` returned non-datetime object")
-
- target_dttm: datetime = data_interval_end + self.delta
+ base_time = self._derive_base_time(context=context)
+ target_dttm: datetime = base_time + self.delta

  if timezone.utcnow() > target_dttm:
  # If the target datetime is in the past, return immediately
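
In practice the sensor is constructed exactly as before; what changes above is the base time the delta is added to (data_interval_end when the run has one, run_after otherwise). A sketch with an illustrative task id:

from datetime import timedelta

from airflow.providers.standard.sensors.time_delta import TimeDeltaSensor

wait_two_hours = TimeDeltaSensor(task_id="wait_two_hours", delta=timedelta(hours=2))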
src/airflow/providers/standard/sensors/weekday.py +17 -3

@@ -20,9 +20,9 @@ from __future__ import annotations
  from collections.abc import Iterable
  from typing import TYPE_CHECKING

+ from airflow.providers.standard.utils.weekday import WeekDay
  from airflow.sensors.base import BaseSensorOperator
  from airflow.utils import timezone
- from airflow.utils.weekday import WeekDay

  if TYPE_CHECKING:
  try:
@@ -54,7 +54,7 @@ class DayOfWeekSensor(BaseSensorOperator):
  **Example** (with :class:`~airflow.utils.weekday.WeekDay` enum): ::

  # import WeekDay Enum
- from airflow.utils.weekday import WeekDay
+ from airflow.providers.standard.utils.weekday import WeekDay

  weekend_check = DayOfWeekSensor(
  task_id="weekend_check",
@@ -103,7 +103,21 @@ class DayOfWeekSensor(BaseSensorOperator):
  self.week_day,
  WeekDay(timezone.utcnow().isoweekday()).name,
  )
+
  if self.use_task_logical_date:
- return context["logical_date"].isoweekday() in self._week_day_num
+ logical_date = context.get("logical_date")
+ dag_run = context.get("dag_run")
+
+ if not (logical_date or (dag_run and dag_run.run_after)):
+ raise ValueError(
+ "Either `logical_date` or `run_after` should be provided in the task context when "
+ "`use_task_logical_date` is True"
+ )
+
+ determined_weekday_num = (
+ logical_date.isoweekday() if logical_date else dag_run.run_after.isoweekday() # type: ignore[union-attr]
+ )
+
+ return determined_weekday_num in self._week_day_num
  else:
  return timezone.utcnow().isoweekday() in self._week_day_num
src/airflow/providers/standard/utils/python_virtualenv_script.jinja2 +5 -0

@@ -20,7 +20,12 @@ from __future__ import annotations

  import {{ pickling_library }}
  import sys
+ import os
+ # Setting the PYTHON_OPERATORS_VIRTUAL_ENV_MODE environment variable to 1,
+ # helps to avoid the issue of re creating the orm session in the settings file, otherwise
+ # it fails with airflow-db-not-allowed

+ os.environ["PYTHON_OPERATORS_VIRTUAL_ENV_MODE"] = "1"
  {% if expect_airflow %}
  {# Check whether Airflow is available in the environment.
  # If it is, we'll want to ensure that we integrate any macros that are being provided
src/airflow/providers/standard/utils/skipmixin.py (new file) +192 -0

@@ -0,0 +1,192 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from collections.abc import Iterable, Sequence
+ from types import GeneratorType
+ from typing import TYPE_CHECKING
+
+ from airflow.exceptions import AirflowException
+ from airflow.utils.log.logging_mixin import LoggingMixin
+
+ if TYPE_CHECKING:
+ from airflow.models.operator import Operator
+ from airflow.sdk.definitions._internal.node import DAGNode
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol
+
+ # The key used by SkipMixin to store XCom data.
+ XCOM_SKIPMIXIN_KEY = "skipmixin_key"
+
+ # The dictionary key used to denote task IDs that are skipped
+ XCOM_SKIPMIXIN_SKIPPED = "skipped"
+
+ # The dictionary key used to denote task IDs that are followed
+ XCOM_SKIPMIXIN_FOLLOWED = "followed"
+
+
+ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
+ from airflow.models.baseoperator import BaseOperator
+ from airflow.models.mappedoperator import MappedOperator
+
+ return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))]
+
+
+ # This class should only be used in Airflow 3.0 and later.
+ class SkipMixin(LoggingMixin):
+ """A Mixin to skip Tasks Instances."""
+
+ @staticmethod
+ def _set_state_to_skipped(
+ tasks: Sequence[str | tuple[str, int]],
+ map_index: int | None,
+ ) -> None:
+ """
+ Set state of task instances to skipped from the same dag run.
+
+ Raises
+ ------
+ SkipDownstreamTaskInstances
+ If the task instances are not in the same dag run.
+ """
+ # Import is internal for backward compatibility when importing PythonOperator
+ # from airflow.providers.common.compat.standard.operators
+ from airflow.exceptions import DownstreamTasksSkipped
+
+ # The following could be applied only for non-mapped tasks,
+ # as future mapped tasks have not been expanded yet. Such tasks
+ # have to be handled by NotPreviouslySkippedDep.
+ if tasks and map_index == -1:
+ raise DownstreamTasksSkipped(tasks=tasks)
+
+ def skip(
+ self,
+ ti: RuntimeTaskInstanceProtocol,
+ tasks: Iterable[DAGNode],
+ ):
+ """
+ Set tasks instances to skipped from the same dag run.
+
+ If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
+ so that NotPreviouslySkippedDep knows these tasks should be skipped when they
+ are cleared.
+
+ :param ti: the task instance for which to set the tasks to skipped
+ :param tasks: tasks to skip (not task_ids)
+ """
+ # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
+ task_id: str | None = getattr(self, "task_id", None)
+ task_list = _ensure_tasks(tasks)
+ if not task_list:
+ return
+
+ task_ids_list = [d.task_id for d in task_list]
+
+ if task_id is not None:
+ ti.xcom_push(
+ key=XCOM_SKIPMIXIN_KEY,
+ value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
+ )
+
+ self._set_state_to_skipped(task_ids_list, ti.map_index)
+
+ def skip_all_except(
+ self,
+ ti: RuntimeTaskInstanceProtocol,
+ branch_task_ids: None | str | Iterable[str],
+ ):
+ """
+ Implement the logic for a branching operator.
+
+ Given a single task ID or list of task IDs to follow, this skips all other tasks
+ immediately downstream of this operator.
+
+ branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
+ newly added tasks should be skipped when they are cleared.
+ """
+ # Ensure we don't serialize a generator object
+ if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
+ branch_task_ids = list(branch_task_ids)
+ log = self.log # Note: need to catch logger form instance, static logger breaks pytest
+ if isinstance(branch_task_ids, str):
+ branch_task_id_set = {branch_task_ids}
+ elif isinstance(branch_task_ids, Iterable):
+ branch_task_id_set = set(branch_task_ids)
+ invalid_task_ids_type = {
+ (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
+ }
+ if invalid_task_ids_type:
+ raise AirflowException(
+ f"'branch_task_ids' expected all task IDs are strings. "
+ f"Invalid tasks found: {invalid_task_ids_type}."
+ )
+ elif branch_task_ids is None:
+ branch_task_id_set = set()
+ else:
+ raise AirflowException(
+ "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
+ f"but got {type(branch_task_ids).__name__!r}."
+ )
+
+ log.info("Following branch %s", branch_task_id_set)
+
+ if TYPE_CHECKING:
+ assert ti.task
+
+ task = ti.task
+ dag = ti.task.dag
+
+ valid_task_ids = set(dag.task_ids)
+ invalid_task_ids = branch_task_id_set - valid_task_ids
+ if invalid_task_ids:
+ raise AirflowException(
+ "'branch_task_ids' must contain only valid task_ids. "
+ f"Invalid tasks found: {invalid_task_ids}."
+ )
+
+ downstream_tasks = _ensure_tasks(task.downstream_list)
+
+ if downstream_tasks:
+ # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
+ # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
+ # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
+ # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
+ # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
+ # exclude it from skipping.
+ #
+ # branch -----> join
+ # \ ^
+ # v /
+ # task1
+ #
+ for branch_task_id in list(branch_task_id_set):
+ branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
+
+ skip_tasks = [
+ (t.task_id, ti.map_index) for t in downstream_tasks if t.task_id not in branch_task_id_set
+ ]
+
+ follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
+ log.info("Skipping tasks %s", skip_tasks)
+ ti.xcom_push(
+ key=XCOM_SKIPMIXIN_KEY,
+ value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids},
+ )
+ # The following could be applied only for non-mapped tasks,
+ # as future mapped tasks have not been expanded yet. Such tasks
+ # have to be handled by NotPreviouslySkippedDep.
+ self._set_state_to_skipped(skip_tasks, ti.map_index) # type: ignore[arg-type]
src/airflow/providers/standard/utils/weekday.py (new file) +77 -0

@@ -0,0 +1,77 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ """Get the ISO standard day number of the week from a given day string."""
+
+ from __future__ import annotations
+
+ import enum
+ from collections.abc import Iterable
+
+
+ @enum.unique
+ class WeekDay(enum.IntEnum):
+ """Python Enum containing Days of the Week."""
+
+ MONDAY = 1
+ TUESDAY = 2
+ WEDNESDAY = 3
+ THURSDAY = 4
+ FRIDAY = 5
+ SATURDAY = 6
+ SUNDAY = 7
+
+ @classmethod
+ def get_weekday_number(cls, week_day_str: str):
+ """
+ Return the ISO Week Day Number for a Week Day.
+
+ :param week_day_str: Full Name of the Week Day. Example: "Sunday"
+ :return: ISO Week Day Number corresponding to the provided Weekday
+ """
+ sanitized_week_day_str = week_day_str.upper()
+
+ if sanitized_week_day_str not in cls.__members__:
+ raise AttributeError(f'Invalid Week Day passed: "{week_day_str}"')
+
+ return cls[sanitized_week_day_str]
+
+ @classmethod
+ def convert(cls, day: str | WeekDay) -> int:
+ """Return the day number in the week."""
+ if isinstance(day, WeekDay):
+ return day
+ return cls.get_weekday_number(week_day_str=day)
+
+ @classmethod
+ def validate_week_day(
+ cls,
+ week_day: str | WeekDay | Iterable[str] | Iterable[WeekDay],
+ ) -> set[int]:
+ """Validate each item of iterable and create a set to ease compare of values."""
+ if not isinstance(week_day, Iterable):
+ if isinstance(week_day, WeekDay):
+ week_day = {week_day}
+ else:
+ raise TypeError(
+ f"Unsupported Type for week_day parameter: {type(week_day)}."
+ "Input should be iterable type:"
+ "str, set, list, dict or Weekday enum type"
+ )
+ if isinstance(week_day, str):
+ week_day = {week_day}
+
+ return {cls.convert(item) for item in week_day}
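
A quick illustration of the helper added here, which normalises day names and WeekDay members to ISO weekday numbers:

from airflow.providers.standard.utils.weekday import WeekDay

assert WeekDay.convert("Sunday") == 7
days = WeekDay.validate_week_day(["Saturday", "Sunday"])
assert WeekDay.SATURDAY in days and WeekDay.SUNDAY in days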