apache-airflow-providers-standard 0.1.1rc1__tar.gz → 0.2.0b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of apache-airflow-providers-standard might be problematic. See the registry's release page for more details.

Files changed (40):
  1. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/PKG-INFO +8 -8
  2. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/README.rst +3 -3
  3. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/pyproject.toml +5 -5
  4. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/__init__.py +1 -1
  5. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/get_provider_info.py +3 -2
  6. apache_airflow_providers_standard-0.2.0b1/src/airflow/providers/standard/operators/branch.py +105 -0
  7. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/datetime.py +1 -1
  8. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/latest_only.py +1 -1
  9. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/python.py +26 -17
  10. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/trigger_dagrun.py +52 -22
  11. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/weekday.py +3 -3
  12. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/time_delta.py +30 -15
  13. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/weekday.py +17 -3
  14. apache_airflow_providers_standard-0.2.0b1/src/airflow/providers/standard/utils/skipmixin.py +192 -0
  15. apache_airflow_providers_standard-0.2.0b1/src/airflow/providers/standard/utils/weekday.py +77 -0
  16. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/LICENSE +0 -0
  17. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/hooks/__init__.py +0 -0
  18. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/hooks/filesystem.py +0 -0
  19. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/hooks/package_index.py +0 -0
  20. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/hooks/subprocess.py +0 -0
  21. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/__init__.py +0 -0
  22. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/bash.py +0 -0
  23. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/empty.py +0 -0
  24. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/operators/smooth.py +0 -0
  25. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/__init__.py +0 -0
  26. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/bash.py +0 -0
  27. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/date_time.py +0 -0
  28. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/external_task.py +0 -0
  29. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/filesystem.py +0 -0
  30. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/python.py +0 -0
  31. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/sensors/time.py +0 -0
  32. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/triggers/__init__.py +0 -0
  33. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/triggers/external_task.py +0 -0
  34. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/triggers/file.py +0 -0
  35. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/triggers/temporal.py +0 -0
  36. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/utils/__init__.py +0 -0
  37. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/utils/python_virtualenv.py +0 -0
  38. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2 +0 -0
  39. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/utils/sensor_helper.py +0 -0
  40. {apache_airflow_providers_standard-0.1.1rc1 → apache_airflow_providers_standard-0.2.0b1}/src/airflow/providers/standard/version_compat.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: apache-airflow-providers-standard
3
- Version: 0.1.1rc1
3
+ Version: 0.2.0b1
4
4
  Summary: Provider package apache-airflow-providers-standard for Apache Airflow
5
5
  Keywords: airflow-provider,standard,airflow,integration
6
6
  Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -20,13 +20,13 @@ Classifier: Programming Language :: Python :: 3.10
20
20
  Classifier: Programming Language :: Python :: 3.11
21
21
  Classifier: Programming Language :: Python :: 3.12
22
22
  Classifier: Topic :: System :: Monitoring
23
- Requires-Dist: apache-airflow>=2.9.0rc0
23
+ Requires-Dist: apache-airflow>=2.9.0
24
24
  Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
25
- Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html
26
- Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1
25
+ Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1/changelog.html
26
+ Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1
27
+ Project-URL: Mastodon, https://fosstodon.org/@airflow
27
28
  Project-URL: Slack Chat, https://s.apache.org/airflow-slack
28
29
  Project-URL: Source Code, https://github.com/apache/airflow
29
- Project-URL: Twitter, https://x.com/ApacheAirflow
30
30
  Project-URL: YouTube, https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/
31
31
 
32
32
 
@@ -54,7 +54,7 @@ Project-URL: YouTube, https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/
54
54
 
55
55
  Package ``apache-airflow-providers-standard``
56
56
 
57
- Release: ``0.1.1``
57
+ Release: ``0.2.0b1``
58
58
 
59
59
 
60
60
  Airflow Standard Provider
@@ -67,7 +67,7 @@ This is a provider package for ``standard`` provider. All classes for this provi
67
67
  are in ``airflow.providers.standard`` python package.
68
68
 
69
69
  You can find package information and changelog for the provider
70
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/>`_.
70
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1/>`_.
71
71
 
72
72
  Installation
73
73
  ------------
@@ -88,5 +88,5 @@ PIP package Version required
88
88
  ================== ==================
89
89
 
90
90
  The changelog for the provider package can be found in the
91
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html>`_.
91
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1/changelog.html>`_.
92
92
 
@@ -23,7 +23,7 @@
23
23
 
24
24
  Package ``apache-airflow-providers-standard``
25
25
 
26
- Release: ``0.1.1``
26
+ Release: ``0.2.0b1``
27
27
 
28
28
 
29
29
  Airflow Standard Provider
@@ -36,7 +36,7 @@ This is a provider package for ``standard`` provider. All classes for this provi
36
36
  are in ``airflow.providers.standard`` python package.
37
37
 
38
38
  You can find package information and changelog for the provider
39
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/>`_.
39
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1/>`_.
40
40
 
41
41
  Installation
42
42
  ------------
@@ -57,4 +57,4 @@ PIP package Version required
57
57
  ================== ==================
58
58
 
59
59
  The changelog for the provider package can be found in the
60
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html>`_.
60
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1/changelog.html>`_.
@@ -25,7 +25,7 @@ build-backend = "flit_core.buildapi"
25
25
 
26
26
  [project]
27
27
  name = "apache-airflow-providers-standard"
28
- version = "0.1.1.rc1"
28
+ version = "0.2.0b1"
29
29
  description = "Provider package apache-airflow-providers-standard for Apache Airflow"
30
30
  readme = "README.rst"
31
31
  authors = [
@@ -57,7 +57,7 @@ requires-python = "~=3.9"
57
57
  # Make sure to run ``breeze static-checks --type update-providers-dependencies --all-files``
58
58
  # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
59
59
  dependencies = [
60
- "apache-airflow>=2.9.0rc0",
60
+ "apache-airflow>=2.9.0",
61
61
  ]
62
62
 
63
63
  [dependency-groups]
@@ -79,12 +79,12 @@ apache-airflow-providers-fab = {workspace = true}
79
79
  apache-airflow-providers-standard = {workspace = true}
80
80
 
81
81
  [project.urls]
82
- "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1"
83
- "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.1.1/changelog.html"
82
+ "Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1"
83
+ "Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/0.2.0b1/changelog.html"
84
84
  "Bug Tracker" = "https://github.com/apache/airflow/issues"
85
85
  "Source Code" = "https://github.com/apache/airflow"
86
86
  "Slack Chat" = "https://s.apache.org/airflow-slack"
87
- "Twitter" = "https://x.com/ApacheAirflow"
87
+ "Mastodon" = "https://fosstodon.org/@airflow"
88
88
  "YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/"
89
89
 
90
90
  [project.entry-points."apache_airflow_provider"]
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
29
29
 
30
30
  __all__ = ["__version__"]
31
31
 
32
- __version__ = "0.1.1"
32
+ __version__ = "0.2.0b1"
33
33
 
34
34
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
35
35
  "2.9.0"
@@ -27,8 +27,8 @@ def get_provider_info():
27
27
  "name": "Standard",
28
28
  "description": "Airflow Standard Provider\n",
29
29
  "state": "ready",
30
- "source-date-epoch": 1741509906,
31
- "versions": ["0.1.1", "0.1.0", "0.0.3", "0.0.2", "0.0.1"],
30
+ "source-date-epoch": 1742480519,
31
+ "versions": ["0.2.0b1", "0.1.1", "0.1.0", "0.0.3", "0.0.2", "0.0.1"],
32
32
  "integrations": [
33
33
  {
34
34
  "integration-name": "Standard",
@@ -53,6 +53,7 @@ def get_provider_info():
53
53
  "airflow.providers.standard.operators.trigger_dagrun",
54
54
  "airflow.providers.standard.operators.latest_only",
55
55
  "airflow.providers.standard.operators.smooth",
56
+ "airflow.providers.standard.operators.branch",
56
57
  ],
57
58
  }
58
59
  ],
@@ -0,0 +1,105 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ """Branching operators."""
19
+
20
+ from __future__ import annotations
21
+
22
+ from collections.abc import Iterable
23
+ from typing import TYPE_CHECKING
24
+
25
+ from airflow.models.baseoperator import BaseOperator
26
+ from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
27
+
28
+ if AIRFLOW_V_3_0_PLUS:
29
+ from airflow.providers.standard.utils.skipmixin import SkipMixin
30
+ else:
31
+ from airflow.models.skipmixin import SkipMixin
32
+
33
+ if TYPE_CHECKING:
34
+ from airflow.sdk.definitions.context import Context
35
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol
36
+
37
+
38
class BranchMixIn(SkipMixin):
    """Mixin that turns a branch decision into skips with a single ``do_branch`` call."""

    def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
        """
        Log the chosen branch(es) and hand the expanded ids to ``skip_all_except``.

        Task group ids in ``branches_to_execute`` are replaced by the ids of that
        group's root tasks before being passed on.

        :param context: Execution context; only its ``"ti"`` entry is read here.
        :param branches_to_execute: A task id / task group id, or an iterable of them.
        :return: ``branches_to_execute`` exactly as it was passed in.
        """
        self.log.info("Branch into %s", branches_to_execute)
        task_instance = context["ti"]
        followed_ids = self._expand_task_group_roots(task_instance, branches_to_execute)
        self.skip_all_except(task_instance, followed_ids)
        return branches_to_execute

    def _expand_task_group_roots(
        self, ti: RuntimeTaskInstanceProtocol, branches_to_execute: str | Iterable[str]
    ) -> Iterable[str]:
        """
        Lazily yield task ids, replacing each task group id with its root task ids.

        Ids that are not keys of the DAG's ``task_group_dict`` are yielded unchanged.
        ``None`` yields nothing; a single string (or any non-iterable value) is
        treated as a one-element collection.
        """
        if TYPE_CHECKING:
            assert ti.task

        dag = ti.task.dag
        if TYPE_CHECKING:
            assert dag

        if branches_to_execute is None:
            return

        # A bare str would otherwise be iterated character by character, so wrap it
        # (and any other non-iterable scalar) in a single-element list.
        if isinstance(branches_to_execute, str) or not isinstance(branches_to_execute, Iterable):
            candidates = [branches_to_execute]
        else:
            candidates = branches_to_execute

        for candidate in candidates:
            if candidate not in dag.task_group_dict:
                yield candidate
                continue
            group = dag.task_group_dict[candidate]
            root_ids = [root.task_id for root in group.roots]
            self.log.info("Expanding task group %s into %s", group.group_id, root_ids)
            yield from root_ids
73
+
74
+
75
class BaseBranchOperator(BaseOperator, BranchMixIn):
    """
    A base class for creating operators with branching functionality, similar to BranchPythonOperator.

    Users should create a subclass from this operator and implement the method
    ``choose_branch(self, context)``. This should run whatever business logic
    is needed to determine the branch, and return one of the following:

    - A single task_id (as a str)
    - A single task_group_id (as a str)
    - A list containing a combination of task_ids and task_group_ids

    The operator will continue with the returned task_id(s) and/or task_group_id(s), and all other
    tasks directly downstream of this operator will be skipped.
    """

    # Marks this operator as a SkipMixin subclass; the same flag is set on the other
    # branching / short-circuit operators in this provider.
    # NOTE(review): the consumer of this flag lives outside this module — confirm its exact use.
    inherits_from_skipmixin = True

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        """
        Abstract method to choose which branch to run.

        Subclasses should implement this, running whatever logic is necessary to
        choose a branch, and return a task_id, a task_group_id, or an iterable
        combining task_ids and task_group_ids.

        :param context: Context dictionary as passed to execute()
        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def execute(self, context: Context) -> str | Iterable[str]:
        """
        Choose the branch(es) via ``choose_branch`` and follow them via ``do_branch``.

        :param context: Context dictionary for this task run.
        :return: The value returned by ``choose_branch``, passed through unchanged.
        """
        return self.do_branch(context, self.choose_branch(context))
@@ -21,7 +21,7 @@ from collections.abc import Iterable
21
21
  from typing import TYPE_CHECKING
22
22
 
23
23
  from airflow.exceptions import AirflowException
24
- from airflow.operators.branch import BaseBranchOperator
24
+ from airflow.providers.standard.operators.branch import BaseBranchOperator
25
25
  from airflow.utils import timezone
26
26
 
27
27
  if TYPE_CHECKING:
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING
24
24
 
25
25
  import pendulum
26
26
 
27
- from airflow.operators.branch import BaseBranchOperator
27
+ from airflow.providers.standard.operators.branch import BaseBranchOperator
28
28
  from airflow.utils.types import DagRunType
29
29
 
30
30
  if TYPE_CHECKING:
@@ -42,9 +42,7 @@ from airflow.exceptions import (
42
42
  DeserializingResultError,
43
43
  )
44
44
  from airflow.models.baseoperator import BaseOperator
45
- from airflow.models.skipmixin import SkipMixin
46
45
  from airflow.models.variable import Variable
47
- from airflow.operators.branch import BranchMixIn
48
46
  from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script
49
47
  from airflow.providers.standard.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
50
48
  from airflow.utils import hashlib_wrapper
@@ -53,6 +51,14 @@ from airflow.utils.file import get_unique_dag_module_name
53
51
  from airflow.utils.operator_helpers import KeywordParameters
54
52
  from airflow.utils.process_utils import execute_in_subprocess, execute_in_subprocess_with_kwargs
55
53
 
54
+ if AIRFLOW_V_3_0_PLUS:
55
+ from airflow.providers.standard.operators.branch import BranchMixIn
56
+ from airflow.providers.standard.utils.skipmixin import SkipMixin
57
+ else:
58
+ from airflow.models.skipmixin import SkipMixin
59
+ from airflow.operators.branch import BranchMixIn
60
+
61
+
56
62
  log = logging.getLogger(__name__)
57
63
 
58
64
  if TYPE_CHECKING:
@@ -212,13 +218,11 @@ class PythonOperator(BaseOperator):
212
218
  """
213
219
  try:
214
220
  from airflow.utils.operator_helpers import ExecutionCallableRunner
215
-
216
- asset_events = self._asset_events if AIRFLOW_V_3_0_PLUS else self._dataset_events
217
-
218
- runner = ExecutionCallableRunner(self.python_callable, asset_events, logger=self.log)
219
221
  except ImportError:
220
- # Handle Pre Airflow 3.10 case where ExecutionCallableRunner was not available
222
+ # Handle Pre Airflow 2.10 case where ExecutionCallableRunner was not available
221
223
  return self.python_callable(*self.op_args, **self.op_kwargs)
224
+ asset_events = self._asset_events if AIRFLOW_V_3_0_PLUS else self._dataset_events
225
+ runner = ExecutionCallableRunner(self.python_callable, asset_events, logger=self.log)
222
226
  return runner.run(*self.op_args, **self.op_kwargs)
223
227
 
224
228
 
@@ -236,6 +240,8 @@ class BranchPythonOperator(PythonOperator, BranchMixIn):
236
240
  the DAG run's state to be inferred.
237
241
  """
238
242
 
243
+ inherits_from_skipmixin = True
244
+
239
245
  def execute(self, context: Context) -> Any:
240
246
  return self.do_branch(context, super().execute(context))
241
247
 
@@ -266,6 +272,8 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
266
272
  skipped but the ``trigger_rule`` defined for all other downstream tasks will be respected.
267
273
  """
268
274
 
275
+ inherits_from_skipmixin = True
276
+
269
277
  def __init__(self, *, ignore_downstream_trigger_rules: bool = True, **kwargs) -> None:
270
278
  super().__init__(**kwargs)
271
279
  self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules
@@ -295,25 +303,24 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
295
303
 
296
304
  to_skip = get_tasks_to_skip()
297
305
 
298
- # this let's us avoid an intermediate list unless debug logging
306
+ # this lets us avoid an intermediate list unless debug logging
299
307
  if self.log.getEffectiveLevel() <= logging.DEBUG:
300
308
  self.log.debug("Downstream task IDs %s", to_skip := list(get_tasks_to_skip()))
301
309
 
302
310
  self.log.info("Skipping downstream tasks")
303
311
  if AIRFLOW_V_3_0_PLUS:
304
312
  self.skip(
305
- dag_id=dag_run.dag_id,
306
- run_id=dag_run.run_id,
313
+ ti=context["ti"],
307
314
  tasks=to_skip,
308
- map_index=context["ti"].map_index,
309
315
  )
310
316
  else:
311
- self.skip(
312
- dag_run=dag_run,
313
- tasks=to_skip,
314
- execution_date=cast("DateTime", dag_run.logical_date), # type: ignore[call-arg, union-attr]
315
- map_index=context["ti"].map_index,
316
- )
317
+ if to_skip:
318
+ self.skip(
319
+ dag_run=context["dag_run"],
320
+ tasks=to_skip,
321
+ execution_date=cast("DateTime", dag_run.logical_date), # type: ignore[call-arg, union-attr]
322
+ map_index=context["ti"].map_index,
323
+ )
317
324
 
318
325
  self.log.info("Done.")
319
326
  # returns the result of the super execute method as it is instead of returning None
@@ -868,6 +875,8 @@ class BranchPythonVirtualenvOperator(PythonVirtualenvOperator, BranchMixIn):
868
875
  :ref:`howto/operator:BranchPythonVirtualenvOperator`
869
876
  """
870
877
 
878
+ inherits_from_skipmixin = True
879
+
871
880
  def execute(self, context: Context) -> Any:
872
881
  return self.do_branch(context, super().execute(context))
873
882
 
@@ -38,13 +38,12 @@ from airflow.models import BaseOperator
38
38
  from airflow.models.dag import DagModel
39
39
  from airflow.models.dagbag import DagBag
40
40
  from airflow.models.dagrun import DagRun
41
- from airflow.models.xcom import XCom
42
41
  from airflow.providers.standard.triggers.external_task import DagStateTrigger
43
42
  from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
44
43
  from airflow.utils import timezone
45
44
  from airflow.utils.session import provide_session
46
45
  from airflow.utils.state import DagRunState
47
- from airflow.utils.types import DagRunTriggeredByType, DagRunType
46
+ from airflow.utils.types import DagRunType
48
47
 
49
48
  XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
50
49
  XCOM_RUN_ID = "trigger_run_id"
@@ -63,7 +62,9 @@ if TYPE_CHECKING:
63
62
 
64
63
  if AIRFLOW_V_3_0_PLUS:
65
64
  from airflow.sdk import BaseOperatorLink
65
+ from airflow.sdk.execution_time.xcom import XCom
66
66
  else:
67
+ from airflow.models import XCom # type: ignore[no-redef]
67
68
  from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef]
68
69
 
69
70
 
@@ -77,15 +78,15 @@ class TriggerDagRunLink(BaseOperatorLink):
77
78
  name = "Triggered DAG"
78
79
 
79
80
  def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
80
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
81
-
82
81
  if TYPE_CHECKING:
83
82
  assert isinstance(operator, TriggerDagRunOperator)
84
83
 
85
- if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
86
- trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id)
87
- else:
88
- trigger_dag_id = operator.trigger_dag_id
84
+ trigger_dag_id = operator.trigger_dag_id
85
+ if not AIRFLOW_V_3_0_PLUS:
86
+ from airflow.models.renderedtifields import RenderedTaskInstanceFields
87
+
88
+ if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
89
+ trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef]
89
90
 
90
91
  # Fetch the correct dag_run_id for the triggerED dag which is
91
92
  # stored in xcom during execution of the triggerING task.
@@ -173,7 +174,7 @@ class TriggerDagRunOperator(BaseOperator):
173
174
  self.allowed_states = [DagRunState(s) for s in allowed_states]
174
175
  else:
175
176
  self.allowed_states = [DagRunState.SUCCESS]
176
- if failed_states or failed_states == []:
177
+ if failed_states is not None:
177
178
  self.failed_states = [DagRunState(s) for s in failed_states]
178
179
  else:
179
180
  self.failed_states = [DagRunState.FAILED]
@@ -197,25 +198,51 @@ class TriggerDagRunOperator(BaseOperator):
197
198
  try:
198
199
  json.dumps(self.conf)
199
200
  except TypeError:
200
- raise AirflowException("conf parameter should be JSON Serializable")
201
+ raise ValueError("conf parameter should be JSON Serializable")
201
202
 
202
203
  if self.trigger_run_id:
203
204
  run_id = str(self.trigger_run_id)
204
205
  else:
205
- run_id = DagRun.generate_run_id(
206
- run_type=DagRunType.MANUAL,
207
- logical_date=parsed_logical_date,
208
- run_after=parsed_logical_date or timezone.utcnow(),
209
- )
206
+ if AIRFLOW_V_3_0_PLUS:
207
+ run_id = DagRun.generate_run_id(
208
+ run_type=DagRunType.MANUAL,
209
+ logical_date=parsed_logical_date,
210
+ run_after=parsed_logical_date or timezone.utcnow(),
211
+ )
212
+ else:
213
+ run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_logical_date or timezone.utcnow()) # type: ignore[misc,call-arg]
210
214
 
215
+ if AIRFLOW_V_3_0_PLUS:
216
+ self._trigger_dag_af_3(context=context, run_id=run_id, parsed_logical_date=parsed_logical_date)
217
+ else:
218
+ self._trigger_dag_af_2(context=context, run_id=run_id, parsed_logical_date=parsed_logical_date)
219
+
220
+ def _trigger_dag_af_3(self, context, run_id, parsed_logical_date):
221
+ from airflow.exceptions import DagRunTriggerException
222
+
223
+ raise DagRunTriggerException(
224
+ trigger_dag_id=self.trigger_dag_id,
225
+ dag_run_id=run_id,
226
+ conf=self.conf,
227
+ logical_date=parsed_logical_date,
228
+ reset_dag_run=self.reset_dag_run,
229
+ skip_when_already_exists=self.skip_when_already_exists,
230
+ wait_for_completion=self.wait_for_completion,
231
+ allowed_states=self.allowed_states,
232
+ failed_states=self.failed_states,
233
+ poke_interval=self.poke_interval,
234
+ )
235
+
236
+ # TODO: Support deferral
237
+
238
+ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date):
211
239
  try:
212
240
  dag_run = trigger_dag(
213
241
  dag_id=self.trigger_dag_id,
214
242
  run_id=run_id,
215
243
  conf=self.conf,
216
- logical_date=parsed_logical_date,
244
+ execution_date=parsed_logical_date,
217
245
  replace_microseconds=False,
218
- triggered_by=DagRunTriggeredByType.OPERATOR,
219
246
  )
220
247
 
221
248
  except DagRunAlreadyExists as e:
@@ -231,7 +258,7 @@ class TriggerDagRunOperator(BaseOperator):
231
258
  # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
232
259
  dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
233
260
  dag = dag_bag.get_dag(self.trigger_dag_id)
234
- dag.clear(run_id=dag_run.run_id)
261
+ dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
235
262
  else:
236
263
  if self.skip_when_already_exists:
237
264
  raise AirflowSkipException(
@@ -252,6 +279,7 @@ class TriggerDagRunOperator(BaseOperator):
252
279
  trigger=DagStateTrigger(
253
280
  dag_id=self.trigger_dag_id,
254
281
  states=self.allowed_states + self.failed_states,
282
+ execution_dates=[dag_run.logical_date],
255
283
  run_ids=[run_id],
256
284
  poll_interval=self.poke_interval,
257
285
  ),
@@ -278,16 +306,18 @@ class TriggerDagRunOperator(BaseOperator):
278
306
 
279
307
  @provide_session
280
308
  def execute_complete(self, context: Context, session: Session, event: tuple[str, dict[str, Any]]):
281
- # This run_ids is parsed from the return trigger event
282
- provided_run_id = event[1]["run_ids"][0]
309
+ # This logical_date is parsed from the return trigger event
310
+ provided_logical_date = event[1]["execution_dates"][0]
283
311
  try:
284
312
  # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
285
313
  dag_run = session.execute(
286
- select(DagRun).where(DagRun.dag_id == self.trigger_dag_id, DagRun.run_id == provided_run_id)
314
+ select(DagRun).where(
315
+ DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
316
+ )
287
317
  ).scalar_one()
288
318
  except NoResultFound:
289
319
  raise AirflowException(
290
- f"No DAG run found for DAG {self.trigger_dag_id} and run ID {provided_run_id}"
320
+ f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}"
291
321
  )
292
322
 
293
323
  state = dag_run.state
@@ -20,9 +20,9 @@ from __future__ import annotations
20
20
  from collections.abc import Iterable
21
21
  from typing import TYPE_CHECKING
22
22
 
23
- from airflow.operators.branch import BaseBranchOperator
23
+ from airflow.providers.standard.operators.branch import BaseBranchOperator
24
+ from airflow.providers.standard.utils.weekday import WeekDay
24
25
  from airflow.utils import timezone
25
- from airflow.utils.weekday import WeekDay
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  try:
@@ -63,7 +63,7 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
63
63
  .. code-block:: python
64
64
 
65
65
  # import WeekDay Enum
66
- from airflow.utils.weekday import WeekDay
66
+ from airflow.providers.standard.utils.weekday import WeekDay
67
67
  from airflow.providers.standard.operators.empty import EmptyOperator
68
68
  from airflow.operators.weekday import BranchDayOfWeekOperator
69
69
 
@@ -46,29 +46,48 @@ def _get_airflow_version():
46
46
 
47
47
  class TimeDeltaSensor(BaseSensorOperator):
48
48
  """
49
- Waits for a timedelta after the run's data interval.
49
+ Waits for a timedelta.
50
50
 
51
- :param delta: time length to wait after the data interval before succeeding.
51
+ The delta will be evaluated against data_interval_end if present for the dag run,
52
+ otherwise run_after will be used.
53
+
54
+ :param delta: time to wait before succeeding.
52
55
 
53
56
  .. seealso::
54
57
  For more information on how to use this sensor, take a look at the guide:
55
58
  :ref:`howto/operator:TimeDeltaSensor`
56
59
 
57
-
58
60
  """
59
61
 
60
62
  def __init__(self, *, delta, **kwargs):
61
63
  super().__init__(**kwargs)
62
64
  self.delta = delta
63
65
 
64
- def poke(self, context: Context):
65
- data_interval_end = context["data_interval_end"]
66
+ def _derive_base_time(self, context: Context) -> datetime:
67
+ """
68
+ Get the "base time" against which the delta should be calculated.
69
+
70
+ If data_interval_end is populated, use it; else use run_after.
71
+ """
72
+ data_interval_end = context.get("data_interval_end")
73
+ if data_interval_end:
74
+ if not isinstance(data_interval_end, datetime):
75
+ raise ValueError("`data_interval_end` returned non-datetime object")
66
76
 
67
- if not isinstance(data_interval_end, datetime):
68
- raise ValueError("`data_interval_end` returned non-datetime object")
77
+ return data_interval_end
69
78
 
70
- target_dttm: datetime = data_interval_end + self.delta
71
- self.log.info("Checking if the time (%s) has come", target_dttm)
79
+ if not data_interval_end and not AIRFLOW_V_3_0_PLUS:
80
+ raise ValueError("`data_interval_end` not found in task context.")
81
+
82
+ dag_run = context.get("dag_run")
83
+ if not dag_run:
84
+ raise ValueError("`dag_run` not found in task context")
85
+ return dag_run.run_after
86
+
87
+ def poke(self, context: Context) -> bool:
88
+ base_time = self._derive_base_time(context=context)
89
+ target_dttm = base_time + self.delta
90
+ self.log.info("Checking if the delta has elapsed base_time=%s, delta=%s", base_time, self.delta)
72
91
  return timezone.utcnow() > target_dttm
73
92
 
74
93
 
@@ -92,12 +111,8 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):
92
111
  self.end_from_trigger = end_from_trigger
93
112
 
94
113
  def execute(self, context: Context) -> bool | NoReturn:
95
- data_interval_end = context["data_interval_end"]
96
-
97
- if not isinstance(data_interval_end, datetime):
98
- raise ValueError("`data_interval_end` returned non-datetime object")
99
-
100
- target_dttm: datetime = data_interval_end + self.delta
114
+ base_time = self._derive_base_time(context=context)
115
+ target_dttm: datetime = base_time + self.delta
101
116
 
102
117
  if timezone.utcnow() > target_dttm:
103
118
  # If the target datetime is in the past, return immediately
@@ -20,9 +20,9 @@ from __future__ import annotations
20
20
  from collections.abc import Iterable
21
21
  from typing import TYPE_CHECKING
22
22
 
23
+ from airflow.providers.standard.utils.weekday import WeekDay
23
24
  from airflow.sensors.base import BaseSensorOperator
24
25
  from airflow.utils import timezone
25
- from airflow.utils.weekday import WeekDay
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  try:
@@ -54,7 +54,7 @@ class DayOfWeekSensor(BaseSensorOperator):
54
54
  **Example** (with :class:`~airflow.utils.weekday.WeekDay` enum): ::
55
55
 
56
56
  # import WeekDay Enum
57
- from airflow.utils.weekday import WeekDay
57
+ from airflow.providers.standard.utils.weekday import WeekDay
58
58
 
59
59
  weekend_check = DayOfWeekSensor(
60
60
  task_id="weekend_check",
@@ -103,7 +103,21 @@ class DayOfWeekSensor(BaseSensorOperator):
103
103
  self.week_day,
104
104
  WeekDay(timezone.utcnow().isoweekday()).name,
105
105
  )
106
+
106
107
  if self.use_task_logical_date:
107
- return context["logical_date"].isoweekday() in self._week_day_num
108
+ logical_date = context.get("logical_date")
109
+ dag_run = context.get("dag_run")
110
+
111
+ if not (logical_date or (dag_run and dag_run.run_after)):
112
+ raise ValueError(
113
+ "Either `logical_date` or `run_after` should be provided in the task context when "
114
+ "`use_task_logical_date` is True"
115
+ )
116
+
117
+ determined_weekday_num = (
118
+ logical_date.isoweekday() if logical_date else dag_run.run_after.isoweekday() # type: ignore[union-attr]
119
+ )
120
+
121
+ return determined_weekday_num in self._week_day_num
108
122
  else:
109
123
  return timezone.utcnow().isoweekday() in self._week_day_num
@@ -0,0 +1,192 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ from __future__ import annotations
19
+
20
+ from collections.abc import Iterable, Sequence
21
+ from types import GeneratorType
22
+ from typing import TYPE_CHECKING
23
+
24
+ from airflow.exceptions import AirflowException
25
+ from airflow.utils.log.logging_mixin import LoggingMixin
26
+
27
+ if TYPE_CHECKING:
28
+ from airflow.models.operator import Operator
29
+ from airflow.sdk.definitions._internal.node import DAGNode
30
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol
31
+
32
# XCom key under which SkipMixin stores its skip/follow bookkeeping.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# Key inside that XCom value listing the skipped task IDs (pushed by ``SkipMixin.skip``).
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# Key inside that XCom value listing the followed task IDs (pushed by ``SkipMixin.skip_all_except``).
XCOM_SKIPMIXIN_FOLLOWED = "followed"
40
+
41
+
42
def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """
    Keep only the operator nodes from ``nodes``.

    :param nodes: DAG nodes to filter.
    :return: a new list with the nodes that are ``BaseOperator`` or ``MappedOperator``
        instances, in their original order.
    """
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    operator_types = (BaseOperator, MappedOperator)
    return [node for node in nodes if isinstance(node, operator_types)]
47
+
48
+
49
# This class should only be used in Airflow 3.0 and later.
class SkipMixin(LoggingMixin):
    """A mixin to skip task instances of the same dag run."""

    @staticmethod
    def _set_state_to_skipped(
        tasks: Sequence[str | tuple[str, int]],
        map_index: int | None,
    ) -> None:
        """
        Request that task instances of the same dag run be skipped.

        The request is made by raising ``DownstreamTasksSkipped``. This happens only when
        ``tasks`` is non-empty and the calling task instance is not mapped
        (``map_index == -1``); in every other case this method does nothing.

        :param tasks: task IDs, or ``(task_id, map_index)`` tuples, to skip.
        :param map_index: map index of the task instance doing the skipping.
        :raises DownstreamTasksSkipped: when ``tasks`` is non-empty and ``map_index == -1``.
        """
        # Import is internal for backward compatibility when importing PythonOperator
        # from airflow.providers.common.compat.standard.operators
        from airflow.exceptions import DownstreamTasksSkipped

        # The following could be applied only for non-mapped tasks,
        # as future mapped tasks have not been expanded yet. Such tasks
        # have to be handled by NotPreviouslySkippedDep.
        if tasks and map_index == -1:
            raise DownstreamTasksSkipped(tasks=tasks)

    def skip(
        self,
        ti: RuntimeTaskInstanceProtocol,
        tasks: Iterable[DAGNode],
    ):
        """
        Set task instances in the same dag run as ``ti`` to skipped.

        Nodes that are not operators are ignored. If none remain, this returns without doing
        anything. If this instance has a ``task_id`` attribute, the list of skipped task IDs
        is stored to XCom, so that NotPreviouslySkippedDep knows these tasks should be skipped
        when they are cleared.

        :param ti: the task instance for which to set the tasks to skipped
        :param tasks: tasks to skip (not task_ids)
        """
        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        task_ids_list = [d.task_id for d in task_list]

        if task_id is not None:
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
            )

        self._set_state_to_skipped(task_ids_list, ti.map_index)

    def skip_all_except(
        self,
        ti: RuntimeTaskInstanceProtocol,
        branch_task_ids: None | str | Iterable[str],
    ):
        """
        Implement the logic for a branching operator.

        Given a single task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator. A downstream task that is also downstream of
        a followed branch is followed too (see the "join" case in the comment below).

        The IDs of the immediately-downstream tasks that are followed are stored to XCom, so
        that NotPreviouslySkippedDep knows skipped tasks or newly added tasks should be skipped
        when they are cleared.

        :param ti: the task instance doing the branching.
        :param branch_task_ids: ``None``, a single task ID, or an iterable of task IDs to follow.
        :raises AirflowException: if ``branch_task_ids`` has an unsupported type, contains
            non-string items, or names task IDs that are not in the DAG.
        """
        # Ensure we don't serialize a generator object
        if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
            branch_task_ids = list(branch_task_ids)
        log = self.log  # Note: need to catch logger from instance, static logger breaks pytest
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            branch_task_id_set = set(branch_task_ids)
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs are strings. "
                    f"Invalid tasks found: {invalid_task_ids_type}."
                )
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        log.info("Following branch %s", branch_task_id_set)

        if TYPE_CHECKING:
            assert ti.task

        task = ti.task
        dag = task.dag

        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  -----> join
            #     \            ^
            #       v        /
            #         task1
            #
            # Iterate over a snapshot because the set is extended inside the loop.
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            skip_tasks = [
                (t.task_id, ti.map_index) for t in downstream_tasks if t.task_id not in branch_task_id_set
            ]

            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
            log.info("Skipping tasks %s", skip_tasks)
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids},
            )
            # The following could be applied only for non-mapped tasks,
            # as future mapped tasks have not been expanded yet. Such tasks
            # have to be handled by NotPreviouslySkippedDep.
            self._set_state_to_skipped(skip_tasks, ti.map_index)  # type: ignore[arg-type]
@@ -0,0 +1,77 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+ """Get the ISO standard day number of the week from a given day string."""
18
+
19
+ from __future__ import annotations
20
+
21
+ import enum
22
+ from collections.abc import Iterable
23
+
24
+
25
@enum.unique
class WeekDay(enum.IntEnum):
    """Python Enum containing Days of the Week, valued by ISO weekday number (Monday=1)."""

    MONDAY = 1
    TUESDAY = 2
    WEDNESDAY = 3
    THURSDAY = 4
    FRIDAY = 5
    SATURDAY = 6
    SUNDAY = 7

    @classmethod
    def get_weekday_number(cls, week_day_str: str) -> WeekDay:
        """
        Return the ISO Week Day Number for a Week Day.

        The lookup is case-insensitive.

        :param week_day_str: Full Name of the Week Day. Example: "Sunday"
        :return: the enum member (an ``int``) corresponding to the provided Weekday
        :raises AttributeError: if ``week_day_str`` does not name a day of the week
        """
        sanitized_week_day_str = week_day_str.upper()

        if sanitized_week_day_str not in cls.__members__:
            raise AttributeError(f'Invalid Week Day passed: "{week_day_str}"')

        return cls[sanitized_week_day_str]

    @classmethod
    def convert(cls, day: str | WeekDay) -> int:
        """
        Return the day number in the week.

        :param day: a ``WeekDay`` member (returned unchanged) or a day name.
        :return: the ISO weekday number as a ``WeekDay`` member.
        """
        if isinstance(day, WeekDay):
            return day
        return cls.get_weekday_number(week_day_str=day)

    @classmethod
    def validate_week_day(
        cls,
        week_day: str | WeekDay | Iterable[str] | Iterable[WeekDay],
    ) -> set[int]:
        """
        Validate each item of iterable and create a set to ease compare of values.

        :param week_day: a day name, a ``WeekDay`` member, or an iterable of either
            (iterating a dict uses its keys).
        :return: the set of ISO weekday numbers.
        :raises TypeError: if ``week_day`` is neither iterable nor a ``WeekDay`` member.
        :raises AttributeError: if any item is not a valid day name.
        """
        if not isinstance(week_day, Iterable):
            if isinstance(week_day, WeekDay):
                week_day = {week_day}
            else:
                # Adjacent literals are concatenated; the explicit spaces keep the message readable.
                raise TypeError(
                    f"Unsupported Type for week_day parameter: {type(week_day)}. "
                    "Input should be iterable type: "
                    "str, set, list, dict or Weekday enum type"
                )
        # A bare string is iterable, but it must be treated as one day name, not its characters.
        if isinstance(week_day, str):
            week_day = {week_day}

        return {cls.convert(item) for item in week_day}