apache-airflow-providers-standard 1.0.0.dev1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of apache-airflow-providers-standard might be problematic.

Files changed (50)
  1. airflow/providers/standard/LICENSE +0 -52
  2. airflow/providers/standard/__init__.py +23 -1
  3. airflow/providers/standard/decorators/__init__.py +16 -0
  4. airflow/providers/standard/decorators/bash.py +121 -0
  5. airflow/providers/standard/decorators/branch_external_python.py +63 -0
  6. airflow/providers/standard/decorators/branch_python.py +62 -0
  7. airflow/providers/standard/decorators/branch_virtualenv.py +62 -0
  8. airflow/providers/standard/decorators/external_python.py +70 -0
  9. airflow/providers/standard/decorators/python.py +86 -0
  10. airflow/providers/standard/decorators/python_virtualenv.py +67 -0
  11. airflow/providers/standard/decorators/sensor.py +83 -0
  12. airflow/providers/standard/decorators/short_circuit.py +65 -0
  13. airflow/providers/standard/get_provider_info.py +89 -7
  14. airflow/providers/standard/hooks/__init__.py +16 -0
  15. airflow/providers/standard/hooks/filesystem.py +89 -0
  16. airflow/providers/standard/hooks/package_index.py +95 -0
  17. airflow/providers/standard/hooks/subprocess.py +119 -0
  18. airflow/providers/standard/operators/bash.py +73 -56
  19. airflow/providers/standard/operators/branch.py +105 -0
  20. airflow/providers/standard/operators/datetime.py +15 -5
  21. airflow/providers/standard/operators/empty.py +39 -0
  22. airflow/providers/standard/operators/latest_only.py +127 -0
  23. airflow/providers/standard/operators/python.py +1143 -0
  24. airflow/providers/standard/operators/smooth.py +38 -0
  25. airflow/providers/standard/operators/trigger_dagrun.py +391 -0
  26. airflow/providers/standard/operators/weekday.py +19 -9
  27. airflow/providers/standard/sensors/bash.py +15 -11
  28. airflow/providers/standard/sensors/date_time.py +32 -8
  29. airflow/providers/standard/sensors/external_task.py +593 -0
  30. airflow/providers/standard/sensors/filesystem.py +158 -0
  31. airflow/providers/standard/sensors/python.py +84 -0
  32. airflow/providers/standard/sensors/time.py +28 -5
  33. airflow/providers/standard/sensors/time_delta.py +68 -15
  34. airflow/providers/standard/sensors/weekday.py +25 -7
  35. airflow/providers/standard/triggers/__init__.py +16 -0
  36. airflow/providers/standard/triggers/external_task.py +288 -0
  37. airflow/providers/standard/triggers/file.py +131 -0
  38. airflow/providers/standard/triggers/temporal.py +113 -0
  39. airflow/providers/standard/utils/__init__.py +16 -0
  40. airflow/providers/standard/utils/python_virtualenv.py +209 -0
  41. airflow/providers/standard/utils/python_virtualenv_script.jinja2 +82 -0
  42. airflow/providers/standard/utils/sensor_helper.py +137 -0
  43. airflow/providers/standard/utils/skipmixin.py +192 -0
  44. airflow/providers/standard/utils/weekday.py +77 -0
  45. airflow/providers/standard/version_compat.py +36 -0
  46. {apache_airflow_providers_standard-1.0.0.dev1.dist-info → apache_airflow_providers_standard-1.1.0.dist-info}/METADATA +16 -35
  47. apache_airflow_providers_standard-1.1.0.dist-info/RECORD +51 -0
  48. {apache_airflow_providers_standard-1.0.0.dev1.dist-info → apache_airflow_providers_standard-1.1.0.dist-info}/WHEEL +1 -1
  49. apache_airflow_providers_standard-1.0.0.dev1.dist-info/RECORD +0 -17
  50. {apache_airflow_providers_standard-1.0.0.dev1.dist-info → apache_airflow_providers_standard-1.1.0.dist-info}/entry_points.txt +0 -0
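
Judging by the new decorators package above, this release ships the task-flow decorators (@task.bash, @task.python, the branching variants, @task.sensor and @task.short_circuit) as part of the standard provider. A minimal sketch of how they are typically used once the provider is installed; the DAG and task names are illustrative:

    from datetime import datetime

    from airflow.decorators import dag, task


    @dag(schedule=None, start_date=datetime(2025, 1, 1), catchup=False)
    def standard_decorators_demo():
        @task.bash
        def greet() -> str:
            # A @task.bash callable returns the bash command to execute.
            return "echo 'hello from the standard provider'"

        @task
        def report():
            print("bash task finished")

        greet() >> report()


    standard_decorators_demo()

Selected new files from the diff are reproduced in full below.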
airflow/providers/standard/utils/python_virtualenv.py
@@ -0,0 +1,209 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ """Utilities for creating a virtual environment."""
+
+ from __future__ import annotations
+
+ import os
+ import shutil
+ import sys
+ from pathlib import Path
+
+ import jinja2
+ from jinja2 import select_autoescape
+
+ from airflow.configuration import conf
+ from airflow.utils.process_utils import execute_in_subprocess
+
+
+ def _is_uv_installed() -> bool:
+     """
+     Verify whether the uv tool is installed by checking if it's included in the system PATH or installed as a package.
+
+     :return: True if it is. Whichever way of checking it works, is fine.
+     """
+     return bool(shutil.which("uv"))
+
+
+ def _use_uv() -> bool:
+     """
+     Check if the uv tool should be used.
+
+     :return: True if uv should be used.
+     """
+     venv_install_method = conf.get("standard", "venv_install_method", fallback="auto").lower()
+     if venv_install_method == "auto":
+         return _is_uv_installed()
+     if venv_install_method == "uv":
+         return True
+     return False
+
+
+ def _generate_uv_cmd(tmp_dir: str, python_bin: str, system_site_packages: bool) -> list[str]:
+     """Build the command to install the venv via UV."""
+     cmd = ["uv", "venv", "--allow-existing", "--seed"]
+     if python_bin is not None:
+         cmd += ["--python", python_bin]
+     if system_site_packages:
+         cmd.append("--system-site-packages")
+     cmd.append(tmp_dir)
+     return cmd
+
+
+ def _generate_venv_cmd(tmp_dir: str, python_bin: str, system_site_packages: bool) -> list[str]:
+     """We are using venv command instead of venv module to allow creation of venv for different python versions."""
+     if python_bin is None:
+         python_bin = sys.executable
+     cmd = [python_bin, "-m", "venv", tmp_dir]
+     if system_site_packages:
+         cmd.append("--system-site-packages")
+     return cmd
+
+
+ def _generate_uv_install_cmd_from_file(
+     tmp_dir: str, requirements_file_path: str, pip_install_options: list[str]
+ ) -> list[str]:
+     return [
+         "uv",
+         "pip",
+         "install",
+         "--python",
+         f"{tmp_dir}/bin/python",
+         *pip_install_options,
+         "-r",
+         requirements_file_path,
+     ]
+
+
+ def _generate_pip_install_cmd_from_file(
+     tmp_dir: str, requirements_file_path: str, pip_install_options: list[str]
+ ) -> list[str]:
+     return [f"{tmp_dir}/bin/pip", "install", *pip_install_options, "-r", requirements_file_path]
+
+
+ def _generate_uv_install_cmd_from_list(
+     tmp_dir: str, requirements: list[str], pip_install_options: list[str]
+ ) -> list[str]:
+     return ["uv", "pip", "install", "--python", f"{tmp_dir}/bin/python", *pip_install_options, *requirements]
+
+
+ def _generate_pip_install_cmd_from_list(
+     tmp_dir: str, requirements: list[str], pip_install_options: list[str]
+ ) -> list[str]:
+     return [f"{tmp_dir}/bin/pip", "install", *pip_install_options, *requirements]
+
+
+ def _generate_pip_conf(conf_file: Path, index_urls: list[str]) -> None:
+     if index_urls:
+         pip_conf_options = f"index-url = {index_urls[0]}"
+         if len(index_urls) > 1:
+             pip_conf_options += f"\nextra-index-url = {' '.join(x for x in index_urls[1:])}"
+     else:
+         pip_conf_options = "no-index = true"
+     conf_file.write_text(f"[global]\n{pip_conf_options}")
+
+
+ def prepare_virtualenv(
+     venv_directory: str,
+     python_bin: str,
+     system_site_packages: bool,
+     requirements: list[str] | None = None,
+     requirements_file_path: str | None = None,
+     pip_install_options: list[str] | None = None,
+     index_urls: list[str] | None = None,
+ ) -> str:
+     """
+     Create a virtual environment and install the additional python packages.
+
+     :param venv_directory: The path for directory where the environment will be created.
+     :param python_bin: Path to the Python executable.
+     :param system_site_packages: Whether to include system_site_packages in your virtualenv.
+         See virtualenv documentation for more information.
+     :param requirements: List of additional python packages.
+     :param requirements_file_path: Path to the ``requirements.txt`` file.
+     :param pip_install_options: a list of pip install options when installing requirements
+         See 'pip install -h' for available options
+     :param index_urls: an optional list of index urls to load Python packages from.
+         If not provided the system pip conf will be used to source packages from.
+     :return: Path to a binary file with Python in a virtual environment.
+     """
+     if pip_install_options is None:
+         pip_install_options = []
+
+     if requirements is not None and requirements_file_path is not None:
+         raise ValueError("Either requirements OR requirements_file_path has to be passed, but not both")
+
+     if index_urls is not None:
+         _generate_pip_conf(Path(venv_directory) / "pip.conf", index_urls)
+
+     if _use_uv():
+         venv_cmd = _generate_uv_cmd(venv_directory, python_bin, system_site_packages)
+     else:
+         venv_cmd = _generate_venv_cmd(venv_directory, python_bin, system_site_packages)
+     execute_in_subprocess(venv_cmd)
+
+     pip_cmd = None
+     if requirements is not None and len(requirements) != 0:
+         if _use_uv():
+             pip_cmd = _generate_uv_install_cmd_from_list(venv_directory, requirements, pip_install_options)
+         else:
+             pip_cmd = _generate_pip_install_cmd_from_list(venv_directory, requirements, pip_install_options)
+     if requirements_file_path is not None and requirements_file_path:
+         if _use_uv():
+             pip_cmd = _generate_uv_install_cmd_from_file(
+                 venv_directory, requirements_file_path, pip_install_options
+             )
+         else:
+             pip_cmd = _generate_pip_install_cmd_from_file(
+                 venv_directory, requirements_file_path, pip_install_options
+             )
+
+     if pip_cmd:
+         execute_in_subprocess(pip_cmd)
+
+     return f"{venv_directory}/bin/python"
+
+
+ def write_python_script(
+     jinja_context: dict,
+     filename: str,
+     render_template_as_native_obj: bool = False,
+ ):
+     """
+     Render the python script to a file to execute in the virtual environment.
+
+     :param jinja_context: The jinja context variables to unpack and replace with its placeholders in the
+         template file.
+     :param filename: The name of the file to dump the rendered script to.
+     :param render_template_as_native_obj: If ``True``, rendered Jinja template would be converted
+         to a native Python object
+     """
+     template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__))
+     template_env: jinja2.Environment
+     if render_template_as_native_obj:
+         template_env = jinja2.nativetypes.NativeEnvironment(
+             loader=template_loader, undefined=jinja2.StrictUndefined
+         )
+     else:
+         template_env = jinja2.Environment(
+             loader=template_loader,
+             undefined=jinja2.StrictUndefined,
+             autoescape=select_autoescape(["html", "xml"]),
+         )
+     template = template_env.get_template("python_virtualenv_script.jinja2")
+     template.stream(**jinja_context).dump(filename)
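
A short usage sketch of prepare_virtualenv as defined above; the interpreter path and requirement pin are illustrative:

    import tempfile

    from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv

    with tempfile.TemporaryDirectory() as venv_dir:
        # Creates the venv (via uv when it is available and configured,
        # otherwise via python -m venv) and installs the requirements into it.
        python_path = prepare_virtualenv(
            venv_directory=venv_dir,
            python_bin="python3",
            system_site_packages=False,
            requirements=["requests==2.32.3"],
        )
        print(python_path)  # <venv_dir>/bin/python

Note that passing both requirements and requirements_file_path raises a ValueError, and index_urls, when given, is written to a pip.conf inside the venv directory before anything is installed.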
airflow/providers/standard/utils/python_virtualenv_script.jinja2
@@ -0,0 +1,82 @@
+ {#
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -#}
+ from __future__ import annotations
+
+ import {{ pickling_library }}
+ import sys
+ import os
+ # Setting the PYTHON_OPERATORS_VIRTUAL_ENV_MODE environment variable to 1
+ # helps to avoid re-creating the ORM session in the settings file; otherwise
+ # it fails with airflow-db-not-allowed
+
+ os.environ["PYTHON_OPERATORS_VIRTUAL_ENV_MODE"] = "1"
+ {% if expect_airflow %}
+ {# Check whether Airflow is available in the environment.
+  # If it is, we'll want to ensure that we integrate any macros that are being provided
+  # by plugins prior to unpickling the task context. #}
+ if sys.version_info >= (3,6):
+     try:
+         from airflow.plugins_manager import integrate_macros_plugins
+         integrate_macros_plugins()
+     except ImportError:
+         {# Airflow is not available in this environment, therefore we won't
+          # be able to integrate any plugin macros. #}
+         pass
+ {% endif %}
+
+ # Script
+ {{ python_callable_source }}
+
+ # monkey patching for the cases when python_callable is part of the dag module.
+ {% if modified_dag_module_name is defined %}
+
+ import types
+
+ {{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}")
+
+ {{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}
+
+ sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}
+
+ {% endif %}
+
+ {% if op_args or op_kwargs %}
+ with open(sys.argv[1], "rb") as file:
+     arg_dict = {{ pickling_library }}.load(file)
+ {% else %}
+ arg_dict = {"args": [], "kwargs": {}}
+ {% endif %}
+
+ {% if string_args_global | default(true) -%}
+ # Read string args
+ with open(sys.argv[3], "r") as file:
+     virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
+ {% endif %}
+
+ try:
+     res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
+ except Exception as e:
+     with open(sys.argv[4], "w") as file:
+         file.write(str(e))
+     raise
+
+ # Write output
+ with open(sys.argv[2], "wb") as file:
+     if res is not None:
+         {{ pickling_library }}.dump(res, file)
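
For context, a hedged sketch of a jinja_context that satisfies this template when rendered through write_python_script; the exact keys are chosen by the Python operators that call it, so treat the values below as illustrative:

    from airflow.providers.standard.utils.python_virtualenv import write_python_script

    write_python_script(
        jinja_context={
            "pickling_library": "pickle",
            "python_callable": "my_callable",
            "python_callable_source": "def my_callable(x):\n    return x * 2\n",
            "expect_airflow": False,  # skip the plugin-macro integration block
            "op_args": [21],          # non-empty, so args are unpickled from sys.argv[1]
            "op_kwargs": {},
            "string_args_global": True,
        },
        filename="/tmp/venv_script.py",  # illustrative path
    )

Because both environments use jinja2.StrictUndefined, every placeholder the template dereferences must be present in the context, except those guarded by an "is defined" test or a "default" filter (modified_dag_module_name and string_args_global above).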
airflow/providers/standard/utils/sensor_helper.py
@@ -0,0 +1,137 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, cast
+
+ from sqlalchemy import func, select, tuple_
+
+ from airflow.models import DagBag, DagRun, TaskInstance
+ from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
+ from airflow.utils.session import NEW_SESSION, provide_session
+
+ if TYPE_CHECKING:
+     from sqlalchemy.orm import Session
+     from sqlalchemy.sql import Executable
+
+
+ @provide_session
+ def _get_count(
+     dttm_filter,
+     external_task_ids,
+     external_task_group_id,
+     external_dag_id,
+     states,
+     session: Session = NEW_SESSION,
+ ) -> int:
+     """
+     Get the count of records against dttm filter and states.
+
+     :param dttm_filter: date time filter for logical date
+     :param external_task_ids: The list of task_ids
+     :param external_task_group_id: The ID of the external task group
+     :param external_dag_id: The ID of the external DAG.
+     :param states: task or dag states
+     :param session: airflow session object
+     """
+     TI = TaskInstance
+     DR = DagRun
+     if not dttm_filter:
+         return 0
+
+     if external_task_ids:
+         count = (
+             session.scalar(
+                 _count_stmt(TI, states, dttm_filter, external_dag_id).where(TI.task_id.in_(external_task_ids))
+             )
+         ) / len(external_task_ids)
+     elif external_task_group_id:
+         external_task_group_task_ids = _get_external_task_group_task_ids(
+             dttm_filter, external_task_group_id, external_dag_id, session
+         )
+         if not external_task_group_task_ids:
+             count = 0
+         else:
+             count = (
+                 session.scalar(
+                     _count_stmt(TI, states, dttm_filter, external_dag_id).where(
+                         tuple_(TI.task_id, TI.map_index).in_(external_task_group_task_ids)
+                     )
+                 )
+                 / len(external_task_group_task_ids)
+                 * len(dttm_filter)
+             )
+     else:
+         count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id))
+     return cast("int", count)
+
+
+ def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
+     """
+     Get the count of records against dttm filter and states.
+
+     :param model: The SQLAlchemy model representing the relevant table.
+     :param states: task or dag states
+     :param dttm_filter: date time filter for logical date
+     :param external_dag_id: The ID of the external DAG.
+     """
+     date_field = model.logical_date if AIRFLOW_V_3_0_PLUS else model.execution_date
+
+     return select(func.count()).where(
+         model.dag_id == external_dag_id, model.state.in_(states), date_field.in_(dttm_filter)
+     )
+
+
+ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, external_dag_id, session):
+     """
+     Get the (task_id, map_index) pairs of the tasks in the external task group against the dttm filter.
+
+     :param dttm_filter: date time filter for logical date
+     :param external_task_group_id: The ID of the external task group
+     :param external_dag_id: The ID of the external DAG.
+     :param session: airflow session object
+     """
+     refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(external_dag_id, session)
+     task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
+
+     if task_group:
+         date_field = TaskInstance.logical_date if AIRFLOW_V_3_0_PLUS else TaskInstance.execution_date
+
+         group_tasks = session.scalars(
+             select(TaskInstance).filter(
+                 TaskInstance.dag_id == external_dag_id,
+                 TaskInstance.task_id.in_(task.task_id for task in task_group),
+                 date_field.in_(dttm_filter),
+             )
+         )
+
+         return [(t.task_id, t.map_index) for t in group_tasks]
+
+     # returning default task_id as group_id itself, this will avoid any failure in case of
+     # 'check_existence=False' and will fail on timeout
+     return [(external_task_group_id, -1)]
+
+
+ def _get_count_by_matched_states(
+     run_id_task_state_map: dict[str, dict[str, Any]],
+     states: list[str],
+ ):
+     count = 0
+     for _, task_states in run_id_task_state_map.items():
+         if all(state in states for state in task_states.values() if state):
+             count += 1
+     return count
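
Of these helpers, _get_count_by_matched_states is pure Python and easy to illustrate: it counts runs whose non-null task states all fall within the allowed set.

    from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states

    run_states = {
        "run_1": {"extract": "success", "load": "success"},
        "run_2": {"extract": "success", "load": "failed"},
        "run_3": {"extract": "success", "load": None},  # None states are ignored
    }
    # run_1 and run_3 qualify; run_2 has a task in a disallowed state.
    assert _get_count_by_matched_states(run_states, ["success"]) == 2

_get_count and _get_external_task_group_task_ids, by contrast, need a live metadata-database session and are exercised through the external task sensors rather than called directly.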
airflow/providers/standard/utils/skipmixin.py
@@ -0,0 +1,192 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from collections.abc import Iterable, Sequence
+ from types import GeneratorType
+ from typing import TYPE_CHECKING
+
+ from airflow.exceptions import AirflowException
+ from airflow.utils.log.logging_mixin import LoggingMixin
+
+ if TYPE_CHECKING:
+     from airflow.models.operator import Operator
+     from airflow.sdk.definitions._internal.node import DAGNode
+     from airflow.sdk.types import RuntimeTaskInstanceProtocol
+
+ # The key used by SkipMixin to store XCom data.
+ XCOM_SKIPMIXIN_KEY = "skipmixin_key"
+
+ # The dictionary key used to denote task IDs that are skipped
+ XCOM_SKIPMIXIN_SKIPPED = "skipped"
+
+ # The dictionary key used to denote task IDs that are followed
+ XCOM_SKIPMIXIN_FOLLOWED = "followed"
+
+
+ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
+     from airflow.models.baseoperator import BaseOperator
+     from airflow.models.mappedoperator import MappedOperator
+
+     return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))]
+
+
+ # This class should only be used in Airflow 3.0 and later.
+ class SkipMixin(LoggingMixin):
+     """A Mixin to skip Tasks Instances."""
+
+     @staticmethod
+     def _set_state_to_skipped(
+         tasks: Sequence[str | tuple[str, int]],
+         map_index: int | None,
+     ) -> None:
+         """
+         Set state of task instances to skipped from the same dag run.
+
+         Raises
+         ------
+         DownstreamTasksSkipped
+             If there are downstream task instances to skip.
+         """
+         # Import is internal for backward compatibility when importing PythonOperator
+         # from airflow.providers.common.compat.standard.operators
+         from airflow.exceptions import DownstreamTasksSkipped
+
+         # The following could be applied only for non-mapped tasks,
+         # as future mapped tasks have not been expanded yet. Such tasks
+         # have to be handled by NotPreviouslySkippedDep.
+         if tasks and map_index == -1:
+             raise DownstreamTasksSkipped(tasks=tasks)
+
+     def skip(
+         self,
+         ti: RuntimeTaskInstanceProtocol,
+         tasks: Iterable[DAGNode],
+     ):
+         """
+         Set tasks instances to skipped from the same dag run.
+
+         If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
+         so that NotPreviouslySkippedDep knows these tasks should be skipped when they
+         are cleared.
+
+         :param ti: the task instance for which to set the tasks to skipped
+         :param tasks: tasks to skip (not task_ids)
+         """
+         # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
+         task_id: str | None = getattr(self, "task_id", None)
+         task_list = _ensure_tasks(tasks)
+         if not task_list:
+             return
+
+         task_ids_list = [d.task_id for d in task_list]
+
+         if task_id is not None:
+             ti.xcom_push(
+                 key=XCOM_SKIPMIXIN_KEY,
+                 value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
+             )
+
+         self._set_state_to_skipped(task_ids_list, ti.map_index)
+
+     def skip_all_except(
+         self,
+         ti: RuntimeTaskInstanceProtocol,
+         branch_task_ids: None | str | Iterable[str],
+     ):
+         """
+         Implement the logic for a branching operator.
+
+         Given a single task ID or list of task IDs to follow, this skips all other tasks
+         immediately downstream of this operator.
+
+         branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
+         newly added tasks should be skipped when they are cleared.
+         """
+         # Ensure we don't serialize a generator object
+         if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
+             branch_task_ids = list(branch_task_ids)
+         log = self.log  # Note: need to fetch the logger from the instance; a static logger breaks pytest
+         if isinstance(branch_task_ids, str):
+             branch_task_id_set = {branch_task_ids}
+         elif isinstance(branch_task_ids, Iterable):
+             branch_task_id_set = set(branch_task_ids)
+             invalid_task_ids_type = {
+                 (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
+             }
+             if invalid_task_ids_type:
+                 raise AirflowException(
+                     f"'branch_task_ids' expected all task IDs are strings. "
+                     f"Invalid tasks found: {invalid_task_ids_type}."
+                 )
+         elif branch_task_ids is None:
+             branch_task_id_set = set()
+         else:
+             raise AirflowException(
+                 "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
+                 f"but got {type(branch_task_ids).__name__!r}."
+             )
+
+         log.info("Following branch %s", branch_task_id_set)
+
+         if TYPE_CHECKING:
+             assert ti.task
+
+         task = ti.task
+         dag = ti.task.dag
+
+         valid_task_ids = set(dag.task_ids)
+         invalid_task_ids = branch_task_id_set - valid_task_ids
+         if invalid_task_ids:
+             raise AirflowException(
+                 "'branch_task_ids' must contain only valid task_ids. "
+                 f"Invalid tasks found: {invalid_task_ids}."
+             )
+
+         downstream_tasks = _ensure_tasks(task.downstream_list)
+
+         if downstream_tasks:
+             # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
+             # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
+             # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
+             # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
+             # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
+             # exclude it from skipping.
+             #
+             # branch  ----->  join
+             #   \            ^
+             #     v        /
+             #       task1
+             #
+             for branch_task_id in list(branch_task_id_set):
+                 branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
+
+             skip_tasks = [
+                 (t.task_id, ti.map_index) for t in downstream_tasks if t.task_id not in branch_task_id_set
+             ]
+
+             follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
+             log.info("Skipping tasks %s", skip_tasks)
+             ti.xcom_push(
+                 key=XCOM_SKIPMIXIN_KEY,
+                 value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids},
+             )
+             # The following could be applied only for non-mapped tasks,
+             # as future mapped tasks have not been expanded yet. Such tasks
+             # have to be handled by NotPreviouslySkippedDep.
+             self._set_state_to_skipped(skip_tasks, ti.map_index)  # type: ignore[arg-type]
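
SkipMixin itself is internal; user code usually reaches skip_all_except through a branching operator. A hedged sketch of the typical entry point, assuming the standard provider's BranchPythonOperator wiring (DAG and task IDs are illustrative):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.standard.operators.empty import EmptyOperator
    from airflow.providers.standard.operators.python import BranchPythonOperator


    def choose(**context):
        # The returned task ID is followed; every other task immediately
        # downstream of "branch" is skipped via SkipMixin.skip_all_except().
        return "weekday_path" if context["logical_date"].weekday() < 5 else "weekend_path"


    with DAG(dag_id="branching_demo", schedule=None, start_date=datetime(2025, 1, 1)):
        branch = BranchPythonOperator(task_id="branch", python_callable=choose)
        branch >> [EmptyOperator(task_id="weekday_path"), EmptyOperator(task_id="weekend_path")]

Per the diagram in the comments above, a join task downstream of a followed branch is not skipped: skip_all_except first expands branch_task_ids with all of their downstream relatives before computing the skip set.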
airflow/providers/standard/utils/weekday.py
@@ -0,0 +1,77 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ """Get the ISO standard day number of the week from a given day string."""
+
+ from __future__ import annotations
+
+ import enum
+ from collections.abc import Iterable
+
+
+ @enum.unique
+ class WeekDay(enum.IntEnum):
+     """Python Enum containing Days of the Week."""
+
+     MONDAY = 1
+     TUESDAY = 2
+     WEDNESDAY = 3
+     THURSDAY = 4
+     FRIDAY = 5
+     SATURDAY = 6
+     SUNDAY = 7
+
+     @classmethod
+     def get_weekday_number(cls, week_day_str: str):
+         """
+         Return the ISO Week Day Number for a Week Day.
+
+         :param week_day_str: Full Name of the Week Day. Example: "Sunday"
+         :return: ISO Week Day Number corresponding to the provided Weekday
+         """
+         sanitized_week_day_str = week_day_str.upper()
+
+         if sanitized_week_day_str not in cls.__members__:
+             raise AttributeError(f'Invalid Week Day passed: "{week_day_str}"')
+
+         return cls[sanitized_week_day_str]
+
+     @classmethod
+     def convert(cls, day: str | WeekDay) -> int:
+         """Return the day number in the week."""
+         if isinstance(day, WeekDay):
+             return day
+         return cls.get_weekday_number(week_day_str=day)
+
+     @classmethod
+     def validate_week_day(
+         cls,
+         week_day: str | WeekDay | Iterable[str] | Iterable[WeekDay],
+     ) -> set[int]:
+         """Validate each item of iterable and create a set to ease compare of values."""
+         if not isinstance(week_day, Iterable):
+             if isinstance(week_day, WeekDay):
+                 week_day = {week_day}
+             else:
+                 raise TypeError(
+                     f"Unsupported Type for week_day parameter: {type(week_day)}."
+                     " Input should be an iterable type:"
+                     " str, set, list, dict or WeekDay enum type"
+                 )
+         if isinstance(week_day, str):
+             week_day = {week_day}
+
+         return {cls.convert(item) for item in week_day}
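
The WeekDay helpers accept day names, enum members, or iterables of either; a few worked calls using only the API shown above:

    from airflow.providers.standard.utils.weekday import WeekDay

    assert WeekDay.convert("Sunday") == 7                              # ISO numbering, Monday == 1
    assert WeekDay.validate_week_day("Monday") == {1}                  # a bare string is wrapped in a set
    assert WeekDay.validate_week_day(["Saturday", WeekDay.SUNDAY]) == {6, 7}

This normalized set of integers is what the weekday sensor (sensors/weekday.py) and branch operator (operators/weekday.py) compare against the run date's ISO weekday.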