apache-airflow-providers-standard 1.9.1rc1__py3-none-any.whl → 1.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. airflow/providers/standard/__init__.py +3 -3
  2. airflow/providers/standard/decorators/bash.py +1 -2
  3. airflow/providers/standard/example_dags/example_bash_decorator.py +1 -1
  4. airflow/providers/standard/exceptions.py +1 -1
  5. airflow/providers/standard/hooks/subprocess.py +2 -9
  6. airflow/providers/standard/operators/bash.py +7 -3
  7. airflow/providers/standard/operators/datetime.py +1 -2
  8. airflow/providers/standard/operators/hitl.py +20 -10
  9. airflow/providers/standard/operators/latest_only.py +19 -10
  10. airflow/providers/standard/operators/python.py +39 -6
  11. airflow/providers/standard/operators/trigger_dagrun.py +82 -27
  12. airflow/providers/standard/sensors/bash.py +2 -4
  13. airflow/providers/standard/sensors/date_time.py +1 -16
  14. airflow/providers/standard/sensors/external_task.py +91 -51
  15. airflow/providers/standard/sensors/filesystem.py +2 -19
  16. airflow/providers/standard/sensors/time.py +2 -18
  17. airflow/providers/standard/sensors/time_delta.py +7 -6
  18. airflow/providers/standard/triggers/external_task.py +43 -40
  19. airflow/providers/standard/triggers/file.py +1 -1
  20. airflow/providers/standard/triggers/hitl.py +136 -87
  21. airflow/providers/standard/utils/openlineage.py +185 -0
  22. airflow/providers/standard/utils/python_virtualenv.py +38 -4
  23. airflow/providers/standard/utils/python_virtualenv_script.jinja2 +18 -3
  24. airflow/providers/standard/utils/sensor_helper.py +19 -8
  25. airflow/providers/standard/utils/skipmixin.py +2 -2
  26. airflow/providers/standard/version_compat.py +1 -0
  27. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/METADATA +25 -11
  28. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/RECORD +32 -30
  29. apache_airflow_providers_standard-1.10.3.dist-info/licenses/NOTICE +5 -0
  30. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/WHEEL +0 -0
  31. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/entry_points.txt +0 -0
  32. {airflow/providers/standard → apache_airflow_providers_standard-1.10.3.dist-info/licenses}/LICENSE +0 -0
airflow/providers/standard/sensors/external_task.py

@@ -18,14 +18,18 @@ from __future__ import annotations
 
 import datetime
 import os
+import typing
 import warnings
-from collections.abc import Callable, Collection, Iterable
-from typing import TYPE_CHECKING, Any, ClassVar
+from collections.abc import Callable, Collection, Iterable, Sequence
+from typing import TYPE_CHECKING, ClassVar
 
-from airflow.configuration import conf
-from airflow.exceptions import AirflowSkipException
 from airflow.models.dag import DagModel
-from airflow.providers.common.compat.sdk import BaseOperatorLink, BaseSensorOperator
+from airflow.providers.common.compat.sdk import (
+    AirflowSkipException,
+    BaseOperatorLink,
+    BaseSensorOperator,
+    conf,
+)
 from airflow.providers.standard.exceptions import (
     DuplicateStateError,
     ExternalDagDeletedError,
@@ -250,18 +254,21 @@ class ExternalTaskSensor(BaseSensorOperator):
         self._has_checked_existence = False
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.external_dates_filter: str | None = None
 
-    def _get_dttm_filter(self, context):
+    def _get_dttm_filter(self, context: Context) -> Sequence[datetime.datetime]:
         logical_date = self._get_logical_date(context)
 
         if self.execution_delta:
-            dttm = logical_date - self.execution_delta
-        elif self.execution_date_fn:
-            dttm = self._handle_execution_date_fn(context=context)
-        else:
-            dttm = logical_date
+            return [logical_date - self.execution_delta]
+        if self.execution_date_fn:
+            result = self._handle_execution_date_fn(context=context)
+            return result if isinstance(result, list) else [result]
+        return [logical_date]
 
-        return dttm if isinstance(dttm, list) else [dttm]
+    @staticmethod
+    def _serialize_dttm_filter(dttm_filter: Sequence[datetime.datetime]) -> str:
+        return ",".join(dt.isoformat() for dt in dttm_filter)
 
     def poke(self, context: Context) -> bool:
         # delay check to poke rather than __init__ in case it was supplied as XComArgs
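Note: the refactored _get_dttm_filter now always returns a sequence, and the new static helper serializes it for listeners. A minimal sketch of that serialization (values illustrative, not from the provider):

    import datetime

    # Two hypothetical logical dates, serialized the way _serialize_dttm_filter does.
    dttm_filter = [
        datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
        datetime.datetime(2024, 1, 2, tzinfo=datetime.timezone.utc),
    ]
    serialized = ",".join(dt.isoformat() for dt in dttm_filter)
    print(serialized)  # 2024-01-01T00:00:00+00:00,2024-01-02T00:00:00+00:00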
@@ -269,7 +276,9 @@ class ExternalTaskSensor(BaseSensorOperator):
             raise ValueError("Duplicate task_ids passed in external_task_ids parameter")
 
         dttm_filter = self._get_dttm_filter(context)
-        serialized_dttm_filter = ",".join(dt.isoformat() for dt in dttm_filter)
+        serialized_dttm_filter = self._serialize_dttm_filter(dttm_filter)
+        # Save as attribute - to be used by listeners
+        self.external_dates_filter = serialized_dttm_filter
 
         if self.external_task_ids:
             self.log.info(
@@ -298,7 +307,7 @@ class ExternalTaskSensor(BaseSensorOperator):
             return self._poke_af3(context, dttm_filter)
         return self._poke_af2(dttm_filter)
 
-    def _poke_af3(self, context: Context, dttm_filter: list[datetime.datetime]) -> bool:
+    def _poke_af3(self, context: Context, dttm_filter: Sequence[datetime.datetime]) -> bool:
         from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states
 
         self._has_checked_existence = True
@@ -308,20 +317,20 @@ class ExternalTaskSensor(BaseSensorOperator):
         if self.external_task_ids:
             return ti.get_ti_count(
                 dag_id=self.external_dag_id,
-                task_ids=self.external_task_ids,  # type: ignore[arg-type]
-                logical_dates=dttm_filter,
+                task_ids=list(self.external_task_ids),
+                logical_dates=list(dttm_filter),
                 states=states,
             )
         if self.external_task_group_id:
             run_id_task_state_map = ti.get_task_states(
                 dag_id=self.external_dag_id,
                 task_group_id=self.external_task_group_id,
-                logical_dates=dttm_filter,
+                logical_dates=list(dttm_filter),
             )
             return _get_count_by_matched_states(run_id_task_state_map, states)
         return ti.get_dr_count(
             dag_id=self.external_dag_id,
-            logical_dates=dttm_filter,
+            logical_dates=list(dttm_filter),
             states=states,
         )
 
@@ -339,7 +348,7 @@ class ExternalTaskSensor(BaseSensorOperator):
         count_allowed = self._calculate_count(count, dttm_filter)
         return count_allowed == len(dttm_filter)
 
-    def _calculate_count(self, count: int, dttm_filter: list[datetime.datetime]) -> float | int:
+    def _calculate_count(self, count: int, dttm_filter: Sequence[datetime.datetime]) -> float | int:
         """Calculate the normalized count based on the type of check."""
         if self.external_task_ids:
             return count / len(self.external_task_ids)
@@ -395,7 +404,7 @@ class ExternalTaskSensor(BaseSensorOperator):
     if not AIRFLOW_V_3_0_PLUS:
 
         @provide_session
-        def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session = NEW_SESSION) -> bool:
+        def _poke_af2(self, dttm_filter: Sequence[datetime.datetime], session: Session = NEW_SESSION) -> bool:
             if self.check_existence and not self._has_checked_existence:
                 self._check_for_existence(session=session)
 
@@ -416,27 +425,51 @@ class ExternalTaskSensor(BaseSensorOperator):
             super().execute(context)
         else:
             dttm_filter = self._get_dttm_filter(context)
-            logical_or_execution_dates = (
-                {"logical_dates": dttm_filter} if AIRFLOW_V_3_0_PLUS else {"execution_dates": dttm_filter}
-            )
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=WorkflowTrigger(
-                    external_dag_id=self.external_dag_id,
-                    external_task_group_id=self.external_task_group_id,
-                    external_task_ids=self.external_task_ids,
-                    allowed_states=self.allowed_states,
-                    failed_states=self.failed_states,
-                    skipped_states=self.skipped_states,
-                    poke_interval=self.poll_interval,
-                    soft_fail=self.soft_fail,
-                    **logical_or_execution_dates,
-                ),
-                method_name="execute_complete",
-            )
+            if AIRFLOW_V_3_0_PLUS:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=WorkflowTrigger(
+                        external_dag_id=self.external_dag_id,
+                        external_task_group_id=self.external_task_group_id,
+                        external_task_ids=self.external_task_ids,
+                        allowed_states=self.allowed_states,
+                        failed_states=self.failed_states,
+                        skipped_states=self.skipped_states,
+                        poke_interval=self.poll_interval,
+                        soft_fail=self.soft_fail,
+                        logical_dates=list(dttm_filter),
+                        run_ids=None,
+                        execution_dates=None,
+                    ),
+                    method_name="execute_complete",
+                )
+            else:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=WorkflowTrigger(
+                        external_dag_id=self.external_dag_id,
+                        external_task_group_id=self.external_task_group_id,
+                        external_task_ids=self.external_task_ids,
+                        allowed_states=self.allowed_states,
+                        failed_states=self.failed_states,
+                        skipped_states=self.skipped_states,
+                        poke_interval=self.poll_interval,
+                        soft_fail=self.soft_fail,
+                        execution_dates=list(dttm_filter),
+                        logical_dates=None,
+                        run_ids=None,
+                    ),
+                    method_name="execute_complete",
+                )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, typing.Any] | None = None) -> None:
         """Execute when the trigger fires - return immediately."""
+        if event is None:
+            raise ExternalTaskNotFoundError("No event received from trigger")
+
+        # Re-set as attribute after coming back from deferral - to be used by listeners
+        self.external_dates_filter = self._serialize_dttm_filter(self._get_dttm_filter(context))
+
         if event["status"] == "success":
             self.log.info("External tasks %s has executed successfully.", self.external_task_ids)
         elif event["status"] == "skipped":
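Note: a hedged usage sketch of the deferrable path above; with deferrable=True the sensor hands off to WorkflowTrigger, which now receives logical_dates on Airflow 3.x and execution_dates on 2.x (DAG and task ids here are made up):

    from datetime import timedelta

    from airflow.providers.standard.sensors.external_task import ExternalTaskSensor

    wait_for_upstream = ExternalTaskSensor(
        task_id="wait_for_upstream",
        external_dag_id="upstream_dag",      # hypothetical DAG id
        external_task_id="final_task",       # hypothetical task id
        execution_delta=timedelta(hours=1),  # look one hour back
        deferrable=True,                     # takes the defer() branch above
        poll_interval=30.0,
    )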
@@ -453,13 +486,14 @@ class ExternalTaskSensor(BaseSensorOperator):
                 "name of executed task and Dag."
             )
 
-    def _check_for_existence(self, session) -> None:
+    def _check_for_existence(self, session: Session) -> None:
         dag_to_wait = DagModel.get_current(self.external_dag_id, session)
 
         if not dag_to_wait:
             raise ExternalDagNotFoundError(f"The external DAG {self.external_dag_id} does not exist.")
 
-        if not os.path.exists(correct_maybe_zipped(dag_to_wait.fileloc)):
+        path = correct_maybe_zipped(dag_to_wait.fileloc)
+        if not path or not os.path.exists(path):
             raise ExternalDagDeletedError(f"The external DAG {self.external_dag_id} was deleted.")
 
         if self.external_task_ids:
@@ -488,7 +522,7 @@ class ExternalTaskSensor(BaseSensorOperator):
 
         self._has_checked_existence = True
 
-    def get_count(self, dttm_filter, session, states) -> int:
+    def get_count(self, dttm_filter: Sequence[datetime.datetime], session: Session, states: list[str]) -> int:
         """
         Get the count of records against dttm filter and states.
 
@@ -509,15 +543,19 @@ class ExternalTaskSensor(BaseSensorOperator):
             session,
         )
 
-    def get_external_task_group_task_ids(self, session, dttm_filter):
+    def get_external_task_group_task_ids(
+        self, session: Session, dttm_filter: Sequence[datetime.datetime]
+    ) -> list[tuple[str, int]]:
         warnings.warn(
             "This method is deprecated and will be removed in future.", DeprecationWarning, stacklevel=2
         )
+        if self.external_task_group_id is None:
+            return []
         return _get_external_task_group_task_ids(
-            dttm_filter, self.external_task_group_id, self.external_dag_id, session
+            list(dttm_filter), self.external_task_group_id, self.external_dag_id, session
         )
 
-    def _get_logical_date(self, context) -> datetime.datetime:
+    def _get_logical_date(self, context: Context) -> datetime.datetime:
         """
         Handle backwards- and forwards-compatible retrieval of the date.
 
@@ -527,19 +565,21 @@ class ExternalTaskSensor(BaseSensorOperator):
         if AIRFLOW_V_3_0_PLUS:
             logical_date = context.get("logical_date")
             dag_run = context.get("dag_run")
-            if not (logical_date or (dag_run and dag_run.run_after)):
-                raise ValueError(
-                    "Either `logical_date` or `dag_run.run_after` must be provided in the context"
-                )
-            return logical_date or dag_run.run_after
+            if logical_date:
+                return logical_date
+            if dag_run and hasattr(dag_run, "run_after") and dag_run.run_after:
+                return dag_run.run_after
+            raise ValueError("Either `logical_date` or `dag_run.run_after` must be provided in the context")
 
         # Airflow 2.x and earlier: contexts used "execution_date"
         execution_date = context.get("execution_date")
         if not execution_date:
             raise ValueError("Either `execution_date` must be provided in the context`")
+        if not isinstance(execution_date, datetime.datetime):
+            raise ValueError("execution_date must be a datetime object")
         return execution_date
 
-    def _handle_execution_date_fn(self, context) -> Any:
+    def _handle_execution_date_fn(self, context: Context) -> datetime.datetime | list[datetime.datetime]:
         """
         Handle backward compatibility.
 
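Note: a stand-alone illustration of the new fallback order in _get_logical_date on Airflow 3.x, using stub objects rather than Airflow's real context:

    import datetime

    class _StubDagRun:
        run_after = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)

    context = {"logical_date": None, "dag_run": _StubDagRun()}

    logical_date = context.get("logical_date")
    dag_run = context.get("dag_run")
    if logical_date:
        date = logical_date  # first choice: the explicit logical_date
    elif dag_run and hasattr(dag_run, "run_after") and dag_run.run_after:
        date = dag_run.run_after  # second choice: the run's run_after
    else:
        raise ValueError("Either `logical_date` or `dag_run.run_after` must be provided")
    print(date)  # 2024-01-01 00:00:00+00:00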
airflow/providers/standard/sensors/filesystem.py

@@ -20,31 +20,14 @@ from __future__ import annotations
 import datetime
 import os
 from collections.abc import Sequence
-from dataclasses import dataclass
 from functools import cached_property
 from glob import glob
 from typing import TYPE_CHECKING, Any
 
-from airflow.configuration import conf
-from airflow.exceptions import AirflowException
-from airflow.providers.common.compat.sdk import BaseSensorOperator
+from airflow.providers.common.compat.sdk import AirflowException, BaseSensorOperator, conf
 from airflow.providers.standard.hooks.filesystem import FSHook
 from airflow.providers.standard.triggers.file import FileTrigger
-
-try:
-    from airflow.triggers.base import StartTriggerArgs  # type: ignore[no-redef]
-except ImportError:  # TODO: Remove this when min airflow version is 2.10.0 for standard provider
-
-    @dataclass
-    class StartTriggerArgs:  # type: ignore[no-redef]
-        """Arguments required for start task execution from triggerer."""
-
-        trigger_cls: str
-        next_method: str
-        trigger_kwargs: dict[str, Any] | None = None
-        next_kwargs: dict[str, Any] | None = None
-        timeout: datetime.timedelta | None = None
-
+from airflow.triggers.base import StartTriggerArgs
 
 if TYPE_CHECKING:
     from airflow.sdk import Context
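Note: removing the vendored fallback means the provider now relies on airflow.triggers.base.StartTriggerArgs existing (Airflow 2.10+, per the removed TODO comment). A sketch of constructing it, with purely illustrative field values mirroring the deleted dataclass:

    from datetime import timedelta

    from airflow.triggers.base import StartTriggerArgs

    start_trigger_args = StartTriggerArgs(
        trigger_cls="airflow.providers.standard.triggers.file.FileTrigger",
        next_method="execute_complete",
        trigger_kwargs={"filepath": "/data/incoming", "poke_interval": 60.0},  # hypothetical values
        next_kwargs=None,
        timeout=timedelta(hours=1),
    )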
airflow/providers/standard/sensors/date_time.py

@@ -19,28 +19,12 @@ from __future__ import annotations
 
 import datetime
 import warnings
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
-from airflow.configuration import conf
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
+from airflow.providers.common.compat.sdk import BaseSensorOperator, conf, timezone
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger
-
-try:
-    from airflow.triggers.base import StartTriggerArgs  # type: ignore[no-redef]
-except ImportError:  # TODO: Remove this when min airflow version is 2.10.0 for standard provider
-
-    @dataclass
-    class StartTriggerArgs:  # type: ignore[no-redef]
-        """Arguments required for start task execution from triggerer."""
-
-        trigger_cls: str
-        next_method: str
-        trigger_kwargs: dict[str, Any] | None = None
-        next_kwargs: dict[str, Any] | None = None
-        timeout: datetime.timedelta | None = None
-
+from airflow.triggers.base import StartTriggerArgs
 
 if TYPE_CHECKING:
     from airflow.sdk import Context
airflow/providers/standard/sensors/time_delta.py

@@ -25,9 +25,8 @@ from typing import TYPE_CHECKING, Any
 from deprecated.classic import deprecated
 from packaging.version import Version
 
-from airflow.configuration import conf
-from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowSkipException
-from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.common.compat.sdk import AirflowSkipException, BaseSensorOperator, conf, timezone
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
 from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
 
@@ -194,9 +193,11 @@ class WaitSensor(BaseSensorOperator):
     def execute(self, context: Context) -> None:
         if self.deferrable:
             self.defer(
-                trigger=TimeDeltaTrigger(self.time_to_wait, end_from_trigger=True)
-                if AIRFLOW_V_3_0_PLUS
-                else TimeDeltaTrigger(self.time_to_wait),
+                trigger=(
+                    TimeDeltaTrigger(self.time_to_wait, end_from_trigger=True)
+                    if AIRFLOW_V_3_0_PLUS
+                    else TimeDeltaTrigger(self.time_to_wait)
+                ),
                 method_name="execute_complete",
             )
         else:
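Note: the change above is only a formatting reflow; the behavior it expresses is that on Airflow 3.x end_from_trigger=True lets the triggerer finish the task directly, while on 2.x the task resumes on a worker first. A sketch with a stand-in for the version flag:

    from datetime import timedelta

    from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger

    AIRFLOW_V_3_0_PLUS = True  # stand-in for the provider's version_compat flag
    time_to_wait = timedelta(minutes=5)
    trigger = (
        TimeDeltaTrigger(time_to_wait, end_from_trigger=True)
        if AIRFLOW_V_3_0_PLUS
        else TimeDeltaTrigger(time_to_wait)
    )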
airflow/providers/standard/triggers/external_task.py

@@ -18,10 +18,11 @@ from __future__ import annotations
 
 import asyncio
 import typing
+from collections.abc import Collection
 from typing import Any
 
 from asgiref.sync import sync_to_async
-from sqlalchemy import func
+from sqlalchemy import func, select
 
 from airflow.models import DagRun
 from airflow.providers.standard.utils.sensor_helper import _get_count
@@ -60,9 +61,9 @@ class WorkflowTrigger(BaseTrigger):
         logical_dates: list[datetime] | None = None,
         external_task_ids: typing.Collection[str] | None = None,
         external_task_group_id: str | None = None,
-        failed_states: typing.Iterable[str] | None = None,
-        skipped_states: typing.Iterable[str] | None = None,
-        allowed_states: typing.Iterable[str] | None = None,
+        failed_states: Collection[str] | None = None,
+        skipped_states: Collection[str] | None = None,
+        allowed_states: Collection[str] | None = None,
         poke_interval: float = 2.0,
         soft_fail: bool = False,
         **kwargs,
@@ -129,43 +130,41 @@ class WorkflowTrigger(BaseTrigger):
             self.log.info("Sleeping for %s seconds", self.poke_interval)
             await asyncio.sleep(self.poke_interval)
 
-    async def _get_count_af_3(self, states):
+    async def _get_count_af_3(self, states: Collection[str] | None) -> int:
         from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        params = {
-            "dag_id": self.external_dag_id,
-            "logical_dates": self.logical_dates,
-            "run_ids": self.run_ids,
-        }
         if self.external_task_ids:
             count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
-                task_ids=self.external_task_ids,
-                states=states,
-                **params,
+                dag_id=self.external_dag_id,
+                task_ids=list(self.external_task_ids),
+                logical_dates=self.logical_dates,
+                run_ids=self.run_ids,
+                states=list(states) if states else None,
             )
-        elif self.external_task_group_id:
+            return int(count / len(self.external_task_ids))
+        if self.external_task_group_id:
             run_id_task_state_map = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.external_dag_id,
                 task_group_id=self.external_task_group_id,
-                **params,
+                logical_dates=self.logical_dates,
+                run_ids=self.run_ids,
             )
             count = await sync_to_async(_get_count_by_matched_states)(
                 run_id_task_state_map=run_id_task_state_map,
-                states=states,
-            )
-        else:
-            count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
-                dag_id=self.external_dag_id,
-                logical_dates=self.logical_dates,
-                run_ids=self.run_ids,
-                states=states,
+                states=states or [],
             )
-        if self.external_task_ids:
-            return count / len(self.external_task_ids)
+            return count
+        count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
+            dag_id=self.external_dag_id,
+            logical_dates=self.logical_dates,
+            run_ids=self.run_ids,
+            states=list(states) if states else None,
+        )
         return count
 
     @sync_to_async
-    def _get_count(self, states: typing.Iterable[str] | None) -> int:
+    def _get_count(self, states: Collection[str] | None) -> int:
         """
         Get the count of records against dttm filter and states. Async wrapper for _get_count.
 
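Note: the early return in the task-ids branch keeps the old ratio semantics: get_ti_count totals matching task instances across runs, so dividing by the number of watched tasks yields how many runs are fully satisfied. A worked example with made-up numbers:

    external_task_ids = ["extract", "load"]  # hypothetical watched tasks
    matching_ti_count = 6                    # e.g. both tasks finished in 3 runs
    runs_satisfied = int(matching_ti_count / len(external_task_ids))
    assert runs_satisfied == 3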
@@ -227,23 +226,26 @@ class DagStateTrigger(BaseTrigger):
         elif self.execution_dates:
             runs_ids_or_dates = len(self.execution_dates)
 
+        cls_path, data = self.serialize()
+
         if AIRFLOW_V_3_0_PLUS:
-            event = await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
-            yield TriggerEvent(event)
+            data.update(  # update with {run_id: run_state} dict
+                await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
+            )
+            yield TriggerEvent((cls_path, data))
             return
         else:
             while True:
                 num_dags = await self.count_dags()
                 if num_dags == runs_ids_or_dates:
-                    yield TriggerEvent(self.serialize())
+                    yield TriggerEvent((cls_path, data))
                     return
                 await asyncio.sleep(self.poll_interval)
 
-    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> tuple[str, dict[str, Any]]:
+    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, str]:
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        cls_path, data = self.serialize()
-
+        run_states: dict[str, str] = {}  # {run_id: run_state}
         while True:
             num_dags = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
                 dag_id=self.dag_id,
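Note: the event payload stays a (classpath, kwargs) tuple as before, but on Airflow 3.x the kwargs dict now also carries per-run terminal states. A sketch of the resulting shape (ids and states illustrative):

    cls_path = "airflow.providers.standard.triggers.external_task.DagStateTrigger"
    data = {"dag_id": "upstream_dag", "states": ["success"]}       # serialized kwargs (abridged)
    data.update({"manual__2024-01-01T00:00:00+00:00": "success"})  # merged {run_id: run_state}
    event_payload = (cls_path, data)  # what TriggerEvent wraps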
@@ -258,8 +260,8 @@ class DagStateTrigger(BaseTrigger):
                         dag_id=self.dag_id,
                         run_id=run_id,
                     )
-                    data[run_id] = state
-                return cls_path, data
+                    run_states[run_id] = state
+                return run_states
             await asyncio.sleep(self.poll_interval)
 
     if not AIRFLOW_V_3_0_PLUS:
@@ -270,17 +272,18 @@
         def count_dags(self, *, session: Session = NEW_SESSION) -> int:
             """Count how many dag runs in the database match our criteria."""
             _dag_run_date_condition = (
-                DagRun.run_id.in_(self.run_ids)
+                DagRun.run_id.in_(self.run_ids or [])
                 if AIRFLOW_V_3_0_PLUS
                 else DagRun.execution_date.in_(self.execution_dates)
             )
-            count = (
-                session.query(func.count("*"))  # .count() is inefficient
-                .filter(
+            stmt = (
+                select(func.count())
+                .select_from(DagRun)
+                .where(
                     DagRun.dag_id == self.dag_id,
                     DagRun.state.in_(self.states),
                     _dag_run_date_condition,
                 )
-                .scalar()
             )
-            return typing.cast("int", count)
+            result = session.execute(stmt).scalar()
+            return result or 0
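Note: the rewrite swaps legacy Query-style counting for a SQLAlchemy 2.0-style statement and makes the result None-safe. A stand-alone sketch of the same pattern, assuming a configured Airflow metadata database:

    from sqlalchemy import func, select

    from airflow.models import DagRun
    from airflow.settings import Session

    session = Session()
    stmt = (
        select(func.count())
        .select_from(DagRun)
        .where(DagRun.dag_id == "upstream_dag", DagRun.state.in_(["success"]))
    )
    count = session.execute(stmt).scalar() or 0  # "or 0" mirrors the new code's None guard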
airflow/providers/standard/triggers/file.py

@@ -79,7 +79,7 @@ class FileTrigger(BaseTrigger):
                     self.log.info("Found File %s last modified: %s", path, mod_time)
                     yield TriggerEvent(True)
                     return
-                for _, _, files in os.walk(self.filepath):
+                for _, _, files in os.walk(path):
                     if files:
                         yield TriggerEvent(True)
                         return
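Note: the one-line fix matters because the loop variable path is a concrete match produced by glob, while self.filepath may be a glob pattern; os.walk treats a pattern string as a literal (usually nonexistent) path and silently yields nothing. A hedged reproduction of the intended behavior:

    import os
    from glob import glob

    filepath = "/data/*"  # hypothetical glob pattern
    for path in glob(filepath, recursive=True):
        # Walk each concrete match, not the raw pattern string.
        for _, _, files in os.walk(path):
            if files:
                print("found files under", path)
                break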