apache-airflow-providers-standard 1.9.0__py3-none-any.whl → 1.9.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. airflow/providers/standard/__init__.py +1 -1
  2. airflow/providers/standard/decorators/bash.py +7 -13
  3. airflow/providers/standard/decorators/branch_external_python.py +2 -8
  4. airflow/providers/standard/decorators/branch_python.py +2 -7
  5. airflow/providers/standard/decorators/branch_virtualenv.py +2 -7
  6. airflow/providers/standard/decorators/external_python.py +2 -7
  7. airflow/providers/standard/decorators/python.py +2 -7
  8. airflow/providers/standard/decorators/python_virtualenv.py +2 -9
  9. airflow/providers/standard/decorators/sensor.py +2 -9
  10. airflow/providers/standard/decorators/short_circuit.py +2 -8
  11. airflow/providers/standard/decorators/stub.py +6 -12
  12. airflow/providers/standard/example_dags/example_bash_decorator.py +1 -6
  13. airflow/providers/standard/example_dags/example_branch_operator.py +1 -6
  14. airflow/providers/standard/example_dags/example_branch_operator_decorator.py +1 -6
  15. airflow/providers/standard/example_dags/example_external_task_parent_deferrable.py +2 -6
  16. airflow/providers/standard/example_dags/example_hitl_operator.py +1 -1
  17. airflow/providers/standard/example_dags/example_sensors.py +1 -6
  18. airflow/providers/standard/example_dags/example_short_circuit_decorator.py +1 -6
  19. airflow/providers/standard/example_dags/example_short_circuit_operator.py +1 -6
  20. airflow/providers/standard/hooks/filesystem.py +1 -1
  21. airflow/providers/standard/hooks/package_index.py +1 -1
  22. airflow/providers/standard/hooks/subprocess.py +3 -10
  23. airflow/providers/standard/operators/bash.py +3 -12
  24. airflow/providers/standard/operators/branch.py +1 -1
  25. airflow/providers/standard/operators/datetime.py +2 -6
  26. airflow/providers/standard/operators/empty.py +1 -1
  27. airflow/providers/standard/operators/hitl.py +12 -9
  28. airflow/providers/standard/operators/latest_only.py +3 -8
  29. airflow/providers/standard/operators/python.py +9 -9
  30. airflow/providers/standard/operators/smooth.py +1 -1
  31. airflow/providers/standard/operators/trigger_dagrun.py +16 -21
  32. airflow/providers/standard/operators/weekday.py +2 -6
  33. airflow/providers/standard/sensors/bash.py +3 -8
  34. airflow/providers/standard/sensors/date_time.py +2 -6
  35. airflow/providers/standard/sensors/external_task.py +77 -55
  36. airflow/providers/standard/sensors/filesystem.py +1 -1
  37. airflow/providers/standard/sensors/python.py +2 -6
  38. airflow/providers/standard/sensors/time.py +1 -6
  39. airflow/providers/standard/sensors/time_delta.py +3 -7
  40. airflow/providers/standard/sensors/weekday.py +2 -7
  41. airflow/providers/standard/triggers/external_task.py +36 -36
  42. airflow/providers/standard/triggers/file.py +1 -1
  43. airflow/providers/standard/triggers/hitl.py +135 -86
  44. airflow/providers/standard/triggers/temporal.py +1 -5
  45. airflow/providers/standard/utils/python_virtualenv.py +36 -3
  46. airflow/providers/standard/utils/sensor_helper.py +19 -8
  47. airflow/providers/standard/utils/skipmixin.py +1 -7
  48. airflow/providers/standard/version_compat.py +4 -21
  49. {apache_airflow_providers_standard-1.9.0.dist-info → apache_airflow_providers_standard-1.9.2rc1.dist-info}/METADATA +36 -13
  50. apache_airflow_providers_standard-1.9.2rc1.dist-info/RECORD +78 -0
  51. apache_airflow_providers_standard-1.9.2rc1.dist-info/licenses/NOTICE +5 -0
  52. apache_airflow_providers_standard-1.9.0.dist-info/RECORD +0 -77
  53. {apache_airflow_providers_standard-1.9.0.dist-info → apache_airflow_providers_standard-1.9.2rc1.dist-info}/WHEEL +0 -0
  54. {apache_airflow_providers_standard-1.9.0.dist-info → apache_airflow_providers_standard-1.9.2rc1.dist-info}/entry_points.txt +0 -0
  55. {airflow/providers/standard → apache_airflow_providers_standard-1.9.2rc1.dist-info/licenses}/LICENSE +0 -0
airflow/providers/standard/operators/smooth.py

@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING
 from airflow.providers.standard.version_compat import BaseOperator
 
 if TYPE_CHECKING:
-    from airflow.sdk.definitions.context import Context
+    from airflow.providers.common.compat.sdk import Context
 
 
 class SmoothOperator(BaseOperator):
airflow/providers/standard/operators/trigger_dagrun.py

@@ -21,6 +21,7 @@ import datetime
 import json
 import time
 from collections.abc import Sequence
+from json import JSONDecodeError
 from typing import TYPE_CHECKING, Any
 
 from sqlalchemy import select
@@ -37,13 +38,9 @@ from airflow.exceptions import (
 from airflow.models.dag import DagModel
 from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
+from airflow.providers.common.compat.sdk import BaseOperatorLink, XCom, timezone
 from airflow.providers.standard.triggers.external_task import DagStateTrigger
-from airflow.providers.standard.version_compat import (
-    AIRFLOW_V_3_0_PLUS,
-    BaseOperator,
-    BaseOperatorLink,
-    timezone,
-)
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
 from airflow.utils.state import DagRunState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType
 
@@ -55,17 +52,7 @@ if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
     from airflow.models.taskinstancekey import TaskInstanceKey
-
-    try:
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context
-
-    if AIRFLOW_V_3_0_PLUS:
-        from airflow.sdk.execution_time.xcom import XCom
-    else:
-        from airflow.models import XCom
+    from airflow.providers.common.compat.sdk import Context
 
 
 class DagIsPaused(AirflowException):
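The recurring theme of this release: per-module try/except import shims are replaced by a single import from airflow.providers.common.compat.sdk. A minimal sketch of what such a compat module does, assuming it simply centralizes the version-dependent fallbacks each module previously repeated (the real module in apache-airflow-providers-common-compat covers many more names):

# compat_sdk_sketch.py -- illustrative stand-in, NOT the actual compat module
try:
    # Airflow 3.x: task-facing names live in the Task SDK
    from airflow.sdk.definitions.context import Context
except ImportError:
    # Airflow 2.x fallback
    from airflow.utils.context import Context

try:
    from airflow.sdk import timezone
except ImportError:
    from airflow.utils import timezone

__all__ = ["Context", "timezone"]

Consumers then write one flat import instead of duplicating the shim, which is exactly the shape of the changes throughout this diff.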
@@ -203,6 +190,9 @@ class TriggerDagRunOperator(BaseOperator):
                 f"Expected str, datetime.datetime, or None for parameter 'logical_date'. Got {type(logical_date).__name__}"
             )
 
+        if fail_when_dag_is_paused and AIRFLOW_V_3_0_PLUS:
+            raise NotImplementedError("Setting `fail_when_dag_is_paused` not yet supported for Airflow 3.x")
+
     def execute(self, context: Context):
         if self.logical_date is NOTSET:
             # If no logical_date is provided we will set utcnow()
@@ -213,9 +203,11 @@ class TriggerDagRunOperator(BaseOperator):
             parsed_logical_date = timezone.parse(self.logical_date)
 
         try:
+            if self.conf and isinstance(self.conf, str):
+                self.conf = json.loads(self.conf)
             json.dumps(self.conf)
-        except TypeError:
-            raise ValueError("conf parameter should be JSON Serializable")
+        except (TypeError, JSONDecodeError):
+            raise ValueError("conf parameter should be JSON Serializable %s", self.conf)
 
         if self.trigger_run_id:
             run_id = str(self.trigger_run_id)
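The operator now also accepts conf passed as a JSON string, decoding it before the serializability check. A stdlib-only sketch of the new validation path (the function name is hypothetical):

import json
from json import JSONDecodeError

def validate_conf(conf):
    # Mirrors the new check: decode a string conf, then require JSON-serializability.
    if conf and isinstance(conf, str):
        conf = json.loads(conf)  # raises JSONDecodeError for malformed strings
    json.dumps(conf)             # raises TypeError for non-serializable objects
    return conf

print(validate_conf('{"key": "value"}'))  # {'key': 'value'}
try:
    validate_conf("{not json}")
except JSONDecodeError as exc:
    print(f"rejected: {exc}")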
@@ -231,9 +223,12 @@ class TriggerDagRunOperator(BaseOperator):
 
         if self.fail_when_dag_is_paused:
             dag_model = DagModel.get_current(self.trigger_dag_id)
+            if not dag_model:
+                raise ValueError(f"Dag {self.trigger_dag_id} is not found")
             if dag_model.is_paused:
-                if AIRFLOW_V_3_0_PLUS:
-                    raise DagIsPaused(dag_id=self.trigger_dag_id)
+                # TODO: enable this when dag state endpoint available from task sdk
+                # if AIRFLOW_V_3_0_PLUS:
+                #     raise DagIsPaused(dag_id=self.trigger_dag_id)
                 raise AirflowException(f"Dag {self.trigger_dag_id} is paused")
 
         if AIRFLOW_V_3_0_PLUS:
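Condensed, the reworked guard checks for a missing DAG before checking the paused flag (previously a None dag_model would crash on .is_paused). A standalone sketch with a hypothetical stand-in for the Airflow exception type:

class AirflowException(Exception):  # stand-in for airflow.exceptions.AirflowException
    pass

def check_target_dag(dag_model, trigger_dag_id, fail_when_dag_is_paused=True):
    # Order of checks added in this release: missing DAG first, then paused state.
    if not fail_when_dag_is_paused:
        return
    if not dag_model:  # new: DagModel.get_current() may return None
        raise ValueError(f"Dag {trigger_dag_id} is not found")
    if dag_model.is_paused:
        raise AirflowException(f"Dag {trigger_dag_id} is paused")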
airflow/providers/standard/operators/weekday.py

@@ -20,16 +20,12 @@ from __future__ import annotations
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
 
+from airflow.providers.common.compat.sdk import timezone
 from airflow.providers.standard.operators.branch import BaseBranchOperator
 from airflow.providers.standard.utils.weekday import WeekDay
-from airflow.utils import timezone
 
 if TYPE_CHECKING:
-    try:
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context
+    from airflow.providers.common.compat.sdk import Context
 
 
 class BranchDayOfWeekOperator(BaseBranchOperator):
airflow/providers/standard/sensors/bash.py

@@ -17,21 +17,16 @@
 # under the License.
 from __future__ import annotations
 
-import os
 from collections.abc import Sequence
 from subprocess import PIPE, STDOUT, Popen
 from tempfile import NamedTemporaryFile, TemporaryDirectory, gettempdir
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowFailException
-from airflow.providers.standard.version_compat import BaseSensorOperator
+from airflow.providers.common.compat.sdk import BaseSensorOperator
 
 if TYPE_CHECKING:
-    try:
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context
+    from airflow.providers.common.compat.sdk import Context
 
 
 class BashSensor(BaseSensorOperator):
@@ -93,7 +88,7 @@ class BashSensor(BaseSensorOperator):
             close_fds=True,
             cwd=tmp_dir,
             env=self.env,
-            preexec_fn=os.setsid,
+            start_new_session=True,
         ) as resp:
             if resp.stdout:
                 self.log.info("Output:")
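start_new_session=True is the documented, thread-safe replacement for preexec_fn=os.setsid: both put the child in a new session so its whole process group can be signalled independently of the parent. A small POSIX-only sketch:

import os
import signal
import subprocess

# Equivalent to preexec_fn=os.setsid, but safe to use from threaded code (POSIX only).
proc = subprocess.Popen(["sleep", "60"], start_new_session=True)
print(f"child pid={proc.pid}, pgid={os.getpgid(proc.pid)}")

# Because the child leads its own session/group, the entire group can be
# terminated without touching the parent process:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
proc.wait()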
airflow/providers/standard/sensors/date_time.py

@@ -22,13 +22,9 @@ from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, NoReturn
 
+from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger
-from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseSensorOperator
-
-try:
-    from airflow.sdk import timezone
-except ImportError:  # TODO: Remove this when min airflow version is 3.1.0 for standard provider
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
 
 try:
     from airflow.triggers.base import StartTriggerArgs  # type: ignore[no-redef]
airflow/providers/standard/sensors/external_task.py

@@ -18,13 +18,15 @@ from __future__ import annotations
 
 import datetime
 import os
+import typing
 import warnings
-from collections.abc import Callable, Collection, Iterable
-from typing import TYPE_CHECKING, Any, ClassVar
+from collections.abc import Callable, Collection, Iterable, Sequence
+from typing import TYPE_CHECKING, ClassVar
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowSkipException
 from airflow.models.dag import DagModel
+from airflow.providers.common.compat.sdk import BaseOperatorLink, BaseSensorOperator
 from airflow.providers.standard.exceptions import (
     DuplicateStateError,
     ExternalDagDeletedError,
@@ -42,8 +44,6 @@ from airflow.providers.standard.version_compat import (
     AIRFLOW_V_3_0_PLUS,
     AIRFLOW_V_3_2_PLUS,
     BaseOperator,
-    BaseOperatorLink,
-    BaseSensorOperator,
 )
 from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.state import State, TaskInstanceState
@@ -60,11 +60,7 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models.taskinstancekey import TaskInstanceKey
-
-    if AIRFLOW_V_3_0_PLUS:
-        from airflow.sdk.definitions.context import Context
-    else:
-        from airflow.utils.context import Context
+    from airflow.providers.common.compat.sdk import Context
 
 
 class ExternalDagLink(BaseOperatorLink):
@@ -256,17 +252,15 @@ class ExternalTaskSensor(BaseSensorOperator):
         self.deferrable = deferrable
         self.poll_interval = poll_interval
 
-    def _get_dttm_filter(self, context):
+    def _get_dttm_filter(self, context: Context) -> Sequence[datetime.datetime]:
         logical_date = self._get_logical_date(context)
 
         if self.execution_delta:
-            dttm = logical_date - self.execution_delta
-        elif self.execution_date_fn:
-            dttm = self._handle_execution_date_fn(context=context)
-        else:
-            dttm = logical_date
-
-        return dttm if isinstance(dttm, list) else [dttm]
+            return [logical_date - self.execution_delta]
+        if self.execution_date_fn:
+            result = self._handle_execution_date_fn(context=context)
+            return result if isinstance(result, list) else [result]
+        return [logical_date]
 
     def poke(self, context: Context) -> bool:
         # delay check to poke rather than __init__ in case it was supplied as XComArgs
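The refactor keeps the same semantics: the result is always a list of datetimes, with execution_delta shifting the sensor's own logical date backwards to the external run it should probe. A quick standalone illustration:

from datetime import datetime, timedelta

logical_date = datetime(2025, 1, 10, 12, 0)
execution_delta = timedelta(hours=1)

# With execution_delta set, the sensor waits on the external DAG run
# one hour before its own logical date:
dttm_filter = [logical_date - execution_delta]
print(dttm_filter)  # [datetime.datetime(2025, 1, 10, 11, 0)]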
@@ -303,7 +297,7 @@ class ExternalTaskSensor(BaseSensorOperator):
             return self._poke_af3(context, dttm_filter)
         return self._poke_af2(dttm_filter)
 
-    def _poke_af3(self, context: Context, dttm_filter: list[datetime.datetime]) -> bool:
+    def _poke_af3(self, context: Context, dttm_filter: Sequence[datetime.datetime]) -> bool:
         from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states
 
         self._has_checked_existence = True
@@ -313,20 +307,20 @@ class ExternalTaskSensor(BaseSensorOperator):
        if self.external_task_ids:
            return ti.get_ti_count(
                dag_id=self.external_dag_id,
-                task_ids=self.external_task_ids,  # type: ignore[arg-type]
-                logical_dates=dttm_filter,
+                task_ids=list(self.external_task_ids),
+                logical_dates=list(dttm_filter),
                states=states,
            )
        if self.external_task_group_id:
            run_id_task_state_map = ti.get_task_states(
                dag_id=self.external_dag_id,
                task_group_id=self.external_task_group_id,
-                logical_dates=dttm_filter,
+                logical_dates=list(dttm_filter),
            )
            return _get_count_by_matched_states(run_id_task_state_map, states)
        return ti.get_dr_count(
            dag_id=self.external_dag_id,
-            logical_dates=dttm_filter,
+            logical_dates=list(dttm_filter),
            states=states,
        )
 
@@ -344,7 +338,7 @@ class ExternalTaskSensor(BaseSensorOperator):
         count_allowed = self._calculate_count(count, dttm_filter)
         return count_allowed == len(dttm_filter)
 
-    def _calculate_count(self, count: int, dttm_filter: list[datetime.datetime]) -> float | int:
+    def _calculate_count(self, count: int, dttm_filter: Sequence[datetime.datetime]) -> float | int:
         """Calculate the normalized count based on the type of check."""
         if self.external_task_ids:
             return count / len(self.external_task_ids)
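The normalization makes the success test independent of how many task ids are being monitored. A worked example of the condition in poke(): with 2 external task ids and 3 logical dates, all-success means get_ti_count reports 6 matching task instances:

# Worked example of the success condition
external_task_ids = ["extract", "load"]
dttm_filter = ["2025-01-01", "2025-01-02", "2025-01-03"]  # three logical dates

count = 6  # task instances in an allowed state, as returned by get_ti_count
count_allowed = count / len(external_task_ids)  # 6 / 2 = 3.0
print(count_allowed == len(dttm_filter))        # True -> sensor succeeds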
@@ -400,7 +394,7 @@ class ExternalTaskSensor(BaseSensorOperator):
    if not AIRFLOW_V_3_0_PLUS:
 
        @provide_session
-        def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session = NEW_SESSION) -> bool:
+        def _poke_af2(self, dttm_filter: Sequence[datetime.datetime], session: Session = NEW_SESSION) -> bool:
            if self.check_existence and not self._has_checked_existence:
                self._check_for_existence(session=session)
 
@@ -421,27 +415,48 @@ class ExternalTaskSensor(BaseSensorOperator):
             super().execute(context)
         else:
             dttm_filter = self._get_dttm_filter(context)
-            logical_or_execution_dates = (
-                {"logical_dates": dttm_filter} if AIRFLOW_V_3_0_PLUS else {"execution_dates": dttm_filter}
-            )
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=WorkflowTrigger(
-                    external_dag_id=self.external_dag_id,
-                    external_task_group_id=self.external_task_group_id,
-                    external_task_ids=self.external_task_ids,
-                    allowed_states=self.allowed_states,
-                    failed_states=self.failed_states,
-                    skipped_states=self.skipped_states,
-                    poke_interval=self.poll_interval,
-                    soft_fail=self.soft_fail,
-                    **logical_or_execution_dates,
-                ),
-                method_name="execute_complete",
-            )
+            if AIRFLOW_V_3_0_PLUS:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=WorkflowTrigger(
+                        external_dag_id=self.external_dag_id,
+                        external_task_group_id=self.external_task_group_id,
+                        external_task_ids=self.external_task_ids,
+                        allowed_states=self.allowed_states,
+                        failed_states=self.failed_states,
+                        skipped_states=self.skipped_states,
+                        poke_interval=self.poll_interval,
+                        soft_fail=self.soft_fail,
+                        logical_dates=list(dttm_filter),
+                        run_ids=None,
+                        execution_dates=None,
+                    ),
+                    method_name="execute_complete",
+                )
+            else:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=WorkflowTrigger(
+                        external_dag_id=self.external_dag_id,
+                        external_task_group_id=self.external_task_group_id,
+                        external_task_ids=self.external_task_ids,
+                        allowed_states=self.allowed_states,
+                        failed_states=self.failed_states,
+                        skipped_states=self.skipped_states,
+                        poke_interval=self.poll_interval,
+                        soft_fail=self.soft_fail,
+                        execution_dates=list(dttm_filter),
+                        logical_dates=None,
+                        run_ids=None,
+                    ),
+                    method_name="execute_complete",
+                )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, typing.Any] | None = None) -> None:
         """Execute when the trigger fires - return immediately."""
+        if event is None:
+            raise ExternalTaskNotFoundError("No event received from trigger")
+
         if event["status"] == "success":
             self.log.info("External tasks %s has executed successfully.", self.external_task_ids)
         elif event["status"] == "skipped":
@@ -458,13 +473,14 @@ class ExternalTaskSensor(BaseSensorOperator):
                 "name of executed task and Dag."
             )
 
-    def _check_for_existence(self, session) -> None:
+    def _check_for_existence(self, session: Session) -> None:
         dag_to_wait = DagModel.get_current(self.external_dag_id, session)
 
         if not dag_to_wait:
             raise ExternalDagNotFoundError(f"The external DAG {self.external_dag_id} does not exist.")
 
-        if not os.path.exists(correct_maybe_zipped(dag_to_wait.fileloc)):
+        path = correct_maybe_zipped(dag_to_wait.fileloc)
+        if not path or not os.path.exists(path):
             raise ExternalDagDeletedError(f"The external DAG {self.external_dag_id} was deleted.")
 
         if self.external_task_ids:
@@ -493,7 +509,7 @@ class ExternalTaskSensor(BaseSensorOperator):
 
         self._has_checked_existence = True
 
-    def get_count(self, dttm_filter, session, states) -> int:
+    def get_count(self, dttm_filter: Sequence[datetime.datetime], session: Session, states: list[str]) -> int:
         """
         Get the count of records against dttm filter and states.
 
@@ -514,15 +530,19 @@
             session,
         )
 
-    def get_external_task_group_task_ids(self, session, dttm_filter):
+    def get_external_task_group_task_ids(
+        self, session: Session, dttm_filter: Sequence[datetime.datetime]
+    ) -> list[tuple[str, int]]:
         warnings.warn(
             "This method is deprecated and will be removed in future.", DeprecationWarning, stacklevel=2
         )
+        if self.external_task_group_id is None:
+            return []
         return _get_external_task_group_task_ids(
-            dttm_filter, self.external_task_group_id, self.external_dag_id, session
+            list(dttm_filter), self.external_task_group_id, self.external_dag_id, session
         )
 
-    def _get_logical_date(self, context) -> datetime.datetime:
+    def _get_logical_date(self, context: Context) -> datetime.datetime:
         """
         Handle backwards- and forwards-compatible retrieval of the date.
 
@@ -532,19 +552,21 @@
         if AIRFLOW_V_3_0_PLUS:
             logical_date = context.get("logical_date")
             dag_run = context.get("dag_run")
-            if not (logical_date or (dag_run and dag_run.run_after)):
-                raise ValueError(
-                    "Either `logical_date` or `dag_run.run_after` must be provided in the context"
-                )
-            return logical_date or dag_run.run_after
+            if logical_date:
+                return logical_date
+            if dag_run and hasattr(dag_run, "run_after") and dag_run.run_after:
+                return dag_run.run_after
+            raise ValueError("Either `logical_date` or `dag_run.run_after` must be provided in the context")
 
         # Airflow 2.x and earlier: contexts used "execution_date"
         execution_date = context.get("execution_date")
         if not execution_date:
             raise ValueError("Either `execution_date` must be provided in the context`")
+        if not isinstance(execution_date, datetime.datetime):
+            raise ValueError("execution_date must be a datetime object")
         return execution_date
 
-    def _handle_execution_date_fn(self, context) -> Any:
+    def _handle_execution_date_fn(self, context: Context) -> datetime.datetime | list[datetime.datetime]:
         """
         Handle backward compatibility.
 
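The rewritten lookup resolves the date in explicit priority order instead of one compound boolean, which also avoids dereferencing dag_run.run_after when dag_run is None. A standalone sketch of the same logic, with a plain dict standing in for the Airflow context:

import datetime

def resolve_logical_date(context: dict) -> datetime.datetime:
    # Priority 1: an explicit logical_date in the context
    if context.get("logical_date"):
        return context["logical_date"]
    # Priority 2: fall back to the DagRun's run_after timestamp
    dag_run = context.get("dag_run")
    if dag_run is not None and getattr(dag_run, "run_after", None):
        return dag_run.run_after
    raise ValueError("Either `logical_date` or `dag_run.run_after` must be provided in the context")

print(resolve_logical_date({"logical_date": datetime.datetime(2025, 1, 1)}))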
airflow/providers/standard/sensors/filesystem.py

@@ -27,9 +27,9 @@ from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
+from airflow.providers.common.compat.sdk import BaseSensorOperator
 from airflow.providers.standard.hooks.filesystem import FSHook
 from airflow.providers.standard.triggers.file import FileTrigger
-from airflow.providers.standard.version_compat import BaseSensorOperator
 
 try:
     from airflow.triggers.base import StartTriggerArgs  # type: ignore[no-redef]
airflow/providers/standard/sensors/python.py

@@ -20,15 +20,11 @@ from __future__ import annotations
 from collections.abc import Callable, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
-from airflow.providers.standard.version_compat import BaseSensorOperator, PokeReturnValue, context_merge
+from airflow.providers.common.compat.sdk import BaseSensorOperator, PokeReturnValue, context_merge
 from airflow.utils.operator_helpers import determine_kwargs
 
 if TYPE_CHECKING:
-    try:
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context  # type: ignore[no-redef, attr-defined]
+    from airflow.providers.common.compat.sdk import Context
 
 
 class PythonSensor(BaseSensorOperator):
airflow/providers/standard/sensors/time.py

@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger
-from airflow.providers.standard.version_compat import BaseSensorOperator
 
 try:
     from airflow.triggers.base import StartTriggerArgs  # type: ignore[no-redef]
@@ -42,11 +42,6 @@ except ImportError:  # TODO: Remove this when min airflow version is 2.10.0 for
         timeout: datetime.timedelta | None = None
 
 
-try:
-    from airflow.sdk import timezone
-except ImportError:
-    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
-
 if TYPE_CHECKING:
     from airflow.sdk import Context
 
airflow/providers/standard/sensors/time_delta.py

@@ -27,16 +27,12 @@ from packaging.version import Version
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
-from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseSensorOperator
-from airflow.utils import timezone
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
-    try:
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context
+    from airflow.providers.common.compat.sdk import Context
 
 
 def _get_airflow_version():
airflow/providers/standard/sensors/weekday.py

@@ -20,16 +20,11 @@ from __future__ import annotations
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
 
+from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
 from airflow.providers.standard.utils.weekday import WeekDay
-from airflow.providers.standard.version_compat import BaseSensorOperator
-from airflow.utils import timezone
 
 if TYPE_CHECKING:
-    try:
-        from airflow.sdk.definitions.context import Context
-    except ImportError:
-        # TODO: Remove once provider drops support for Airflow 2
-        from airflow.utils.context import Context
+    from airflow.providers.common.compat.sdk import Context
 
 
 class DayOfWeekSensor(BaseSensorOperator):
airflow/providers/standard/triggers/external_task.py

@@ -18,10 +18,11 @@ from __future__ import annotations
 
 import asyncio
 import typing
+from collections.abc import Collection
 from typing import Any
 
 from asgiref.sync import sync_to_async
-from sqlalchemy import func
+from sqlalchemy import func, select
 
 from airflow.models import DagRun
 from airflow.providers.standard.utils.sensor_helper import _get_count
@@ -60,9 +61,9 @@ class WorkflowTrigger(BaseTrigger):
         logical_dates: list[datetime] | None = None,
         external_task_ids: typing.Collection[str] | None = None,
         external_task_group_id: str | None = None,
-        failed_states: typing.Iterable[str] | None = None,
-        skipped_states: typing.Iterable[str] | None = None,
-        allowed_states: typing.Iterable[str] | None = None,
+        failed_states: Collection[str] | None = None,
+        skipped_states: Collection[str] | None = None,
+        allowed_states: Collection[str] | None = None,
         poke_interval: float = 2.0,
         soft_fail: bool = False,
         **kwargs,
@@ -129,43 +130,41 @@ class WorkflowTrigger(BaseTrigger):
             self.log.info("Sleeping for %s seconds", self.poke_interval)
             await asyncio.sleep(self.poke_interval)
 
-    async def _get_count_af_3(self, states):
+    async def _get_count_af_3(self, states: Collection[str] | None) -> int:
         from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        params = {
-            "dag_id": self.external_dag_id,
-            "logical_dates": self.logical_dates,
-            "run_ids": self.run_ids,
-        }
         if self.external_task_ids:
             count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
-                task_ids=self.external_task_ids,
-                states=states,
-                **params,
+                dag_id=self.external_dag_id,
+                task_ids=list(self.external_task_ids),
+                logical_dates=self.logical_dates,
+                run_ids=self.run_ids,
+                states=list(states) if states else None,
             )
-        elif self.external_task_group_id:
+            return int(count / len(self.external_task_ids))
+        if self.external_task_group_id:
             run_id_task_state_map = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.external_dag_id,
                 task_group_id=self.external_task_group_id,
-                **params,
+                logical_dates=self.logical_dates,
+                run_ids=self.run_ids,
             )
             count = await sync_to_async(_get_count_by_matched_states)(
                 run_id_task_state_map=run_id_task_state_map,
-                states=states,
+                states=states or [],
             )
-        else:
-            count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
-                dag_id=self.external_dag_id,
-                logical_dates=self.logical_dates,
-                run_ids=self.run_ids,
-                states=states,
-            )
-        if self.external_task_ids:
-            return count / len(self.external_task_ids)
+            return count
+        count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
+            dag_id=self.external_dag_id,
+            logical_dates=self.logical_dates,
+            run_ids=self.run_ids,
+            states=list(states) if states else None,
+        )
         return count
 
     @sync_to_async
-    def _get_count(self, states: typing.Iterable[str] | None) -> int:
+    def _get_count(self, states: Collection[str] | None) -> int:
         """
         Get the count of records against dttm filter and states. Async wrapper for _get_count.
 
@@ -228,8 +227,8 @@ class DagStateTrigger(BaseTrigger):
            runs_ids_or_dates = len(self.execution_dates)
 
        if AIRFLOW_V_3_0_PLUS:
-            event = await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
-            yield TriggerEvent(event)
+            data = await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
+            yield TriggerEvent(data)
            return
        else:
            while True:
@@ -239,7 +238,7 @@ class DagStateTrigger(BaseTrigger):
                    return
                await asyncio.sleep(self.poll_interval)
 
-    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> tuple[str, dict[str, Any]]:
+    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, typing.Any]:
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
        cls_path, data = self.serialize()
@@ -259,7 +258,7 @@ class DagStateTrigger(BaseTrigger):
                        run_id=run_id,
                    )
                    data[run_id] = state
-                return cls_path, data
+                return data
            await asyncio.sleep(self.poll_interval)
 
    if not AIRFLOW_V_3_0_PLUS:
@@ -270,17 +269,18 @@ class DagStateTrigger(BaseTrigger):
        def count_dags(self, *, session: Session = NEW_SESSION) -> int:
            """Count how many dag runs in the database match our criteria."""
            _dag_run_date_condition = (
-                DagRun.run_id.in_(self.run_ids)
+                DagRun.run_id.in_(self.run_ids or [])
                if AIRFLOW_V_3_0_PLUS
                else DagRun.execution_date.in_(self.execution_dates)
            )
-            count = (
-                session.query(func.count("*"))  # .count() is inefficient
-                .filter(
+            stmt = (
+                select(func.count())
+                .select_from(DagRun)
+                .where(
                    DagRun.dag_id == self.dag_id,
                    DagRun.state.in_(self.states),
                    _dag_run_date_condition,
                )
-                .scalar()
            )
-            return typing.cast("int", count)
+            result = session.execute(stmt).scalar()
+            return result or 0
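The rewrite moves from the legacy session.query(func.count("*")).filter(...).scalar() pattern to a SQLAlchemy 2.0-style select(). A self-contained sketch against an in-memory SQLite database, using a toy table rather than the real DagRun model:

from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Run(Base):  # toy stand-in for the DagRun model
    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)
    state = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Run(dag_id="d1", state="success"), Run(dag_id="d1", state="running")])
    session.commit()

    # 2.0-style count, mirroring the new count_dags() implementation:
    stmt = select(func.count()).select_from(Run).where(Run.dag_id == "d1", Run.state.in_(["success"]))
    result = session.execute(stmt).scalar()
    print(result or 0)  # 1

Returning `result or 0` also removes the typing.cast that papered over a possibly-None scalar.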
airflow/providers/standard/triggers/file.py

@@ -79,7 +79,7 @@ class FileTrigger(BaseTrigger):
                self.log.info("Found File %s last modified: %s", path, mod_time)
                yield TriggerEvent(True)
                return
-            for _, _, files in os.walk(self.filepath):
+            for _, _, files in os.walk(path):
                if files:
                    yield TriggerEvent(True)
                    return
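The fix matters because self.filepath may be a glob pattern, while path is a concrete match produced by the surrounding loop; os.walk on a pattern (or any non-existent path) silently yields nothing, so directory contents were never detected. A small illustration:

import os
import tempfile

base = tempfile.mkdtemp()
os.makedirs(os.path.join(base, "data"))
open(os.path.join(base, "data", "part-0001.csv"), "w").close()

pattern = os.path.join(base, "dat*")  # what self.filepath may hold
match = os.path.join(base, "data")    # a concrete path, as glob(pattern) would yield

print(list(os.walk(pattern)))                         # [] -- walking the raw pattern finds nothing
print(any(files for _, _, files in os.walk(match)))   # True -- walking the match works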