apache-airflow-providers-standard 1.9.1rc1__py3-none-any.whl → 1.10.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- airflow/providers/standard/__init__.py +3 -3
- airflow/providers/standard/decorators/bash.py +1 -2
- airflow/providers/standard/example_dags/example_bash_decorator.py +1 -1
- airflow/providers/standard/exceptions.py +1 -1
- airflow/providers/standard/hooks/subprocess.py +2 -9
- airflow/providers/standard/operators/bash.py +7 -3
- airflow/providers/standard/operators/datetime.py +1 -2
- airflow/providers/standard/operators/hitl.py +20 -10
- airflow/providers/standard/operators/latest_only.py +19 -10
- airflow/providers/standard/operators/python.py +39 -6
- airflow/providers/standard/operators/trigger_dagrun.py +82 -27
- airflow/providers/standard/sensors/bash.py +2 -4
- airflow/providers/standard/sensors/date_time.py +1 -16
- airflow/providers/standard/sensors/external_task.py +91 -51
- airflow/providers/standard/sensors/filesystem.py +2 -19
- airflow/providers/standard/sensors/time.py +2 -18
- airflow/providers/standard/sensors/time_delta.py +7 -6
- airflow/providers/standard/triggers/external_task.py +43 -40
- airflow/providers/standard/triggers/file.py +1 -1
- airflow/providers/standard/triggers/hitl.py +136 -87
- airflow/providers/standard/utils/openlineage.py +185 -0
- airflow/providers/standard/utils/python_virtualenv.py +38 -4
- airflow/providers/standard/utils/python_virtualenv_script.jinja2 +18 -3
- airflow/providers/standard/utils/sensor_helper.py +19 -8
- airflow/providers/standard/utils/skipmixin.py +2 -2
- airflow/providers/standard/version_compat.py +1 -0
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/METADATA +25 -11
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/RECORD +32 -30
- apache_airflow_providers_standard-1.10.3.dist-info/licenses/NOTICE +5 -0
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/entry_points.txt +0 -0
- {airflow/providers/standard → apache_airflow_providers_standard-1.10.3.dist-info/licenses}/LICENSE +0 -0
airflow/providers/standard/sensors/external_task.py

```diff
@@ -18,14 +18,18 @@ from __future__ import annotations
 
 import datetime
 import os
+import typing
 import warnings
-from collections.abc import Callable, Collection, Iterable
-from typing import TYPE_CHECKING,
+from collections.abc import Callable, Collection, Iterable, Sequence
+from typing import TYPE_CHECKING, ClassVar
 
-from airflow.configuration import conf
-from airflow.exceptions import AirflowSkipException
 from airflow.models.dag import DagModel
-from airflow.providers.common.compat.sdk import
+from airflow.providers.common.compat.sdk import (
+    AirflowSkipException,
+    BaseOperatorLink,
+    BaseSensorOperator,
+    conf,
+)
 from airflow.providers.standard.exceptions import (
     DuplicateStateError,
     ExternalDagDeletedError,
```
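Several names that previously came from core Airflow modules (`conf`, `AirflowSkipException`, plus the sensor/operator base classes) are now pulled from `airflow.providers.common.compat.sdk`. A compat layer like this typically just re-exports each symbol from whichever location the running Airflow version provides; a minimal sketch of that pattern (illustrative only — the actual contents of the compat module are not part of this diff):

```python
# Hypothetical sketch of a version-compat re-export, not the real module:
# prefer the Airflow 3 Task SDK location, fall back to the Airflow 2 one.
try:
    from airflow.sdk.bases.sensor import BaseSensorOperator  # Airflow 3
except ImportError:
    from airflow.sensors.base import BaseSensorOperator  # Airflow 2 fallback
```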
```diff
@@ -250,18 +254,21 @@ class ExternalTaskSensor(BaseSensorOperator):
         self._has_checked_existence = False
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.external_dates_filter: str | None = None
 
-    def _get_dttm_filter(self, context):
+    def _get_dttm_filter(self, context: Context) -> Sequence[datetime.datetime]:
         logical_date = self._get_logical_date(context)
 
         if self.execution_delta:
-
-
-
-
-
+            return [logical_date - self.execution_delta]
+        if self.execution_date_fn:
+            result = self._handle_execution_date_fn(context=context)
+            return result if isinstance(result, list) else [result]
+        return [logical_date]
 
-
+    @staticmethod
+    def _serialize_dttm_filter(dttm_filter: Sequence[datetime.datetime]) -> str:
+        return ",".join(dt.isoformat() for dt in dttm_filter)
 
     def poke(self, context: Context) -> bool:
         # delay check to poke rather than __init__ in case it was supplied as XComArgs
```
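The new `external_dates_filter` attribute exposes the date filter to listeners as a comma-separated string of ISO-8601 timestamps, via the `_serialize_dttm_filter` helper. A small standalone illustration of the serialization:

```python
import datetime

# Mirrors the _serialize_dttm_filter helper added above.
dttm_filter = [
    datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
    datetime.datetime(2024, 1, 2, tzinfo=datetime.timezone.utc),
]
serialized = ",".join(dt.isoformat() for dt in dttm_filter)
print(serialized)  # 2024-01-01T00:00:00+00:00,2024-01-02T00:00:00+00:00
```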
```diff
@@ -269,7 +276,9 @@ class ExternalTaskSensor(BaseSensorOperator):
             raise ValueError("Duplicate task_ids passed in external_task_ids parameter")
 
         dttm_filter = self._get_dttm_filter(context)
-        serialized_dttm_filter =
+        serialized_dttm_filter = self._serialize_dttm_filter(dttm_filter)
+        # Save as attribute - to be used by listeners
+        self.external_dates_filter = serialized_dttm_filter
 
         if self.external_task_ids:
             self.log.info(
```
```diff
@@ -298,7 +307,7 @@ class ExternalTaskSensor(BaseSensorOperator):
             return self._poke_af3(context, dttm_filter)
         return self._poke_af2(dttm_filter)
 
-    def _poke_af3(self, context: Context, dttm_filter:
+    def _poke_af3(self, context: Context, dttm_filter: Sequence[datetime.datetime]) -> bool:
         from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states
 
         self._has_checked_existence = True
```
```diff
@@ -308,20 +317,20 @@ class ExternalTaskSensor(BaseSensorOperator):
         if self.external_task_ids:
             return ti.get_ti_count(
                 dag_id=self.external_dag_id,
-                task_ids=self.external_task_ids,
-                logical_dates=dttm_filter,
+                task_ids=list(self.external_task_ids),
+                logical_dates=list(dttm_filter),
                 states=states,
             )
         if self.external_task_group_id:
             run_id_task_state_map = ti.get_task_states(
                 dag_id=self.external_dag_id,
                 task_group_id=self.external_task_group_id,
-                logical_dates=dttm_filter,
+                logical_dates=list(dttm_filter),
             )
             return _get_count_by_matched_states(run_id_task_state_map, states)
         return ti.get_dr_count(
             dag_id=self.external_dag_id,
-            logical_dates=dttm_filter,
+            logical_dates=list(dttm_filter),
             states=states,
         )
 
```
```diff
@@ -339,7 +348,7 @@ class ExternalTaskSensor(BaseSensorOperator):
         count_allowed = self._calculate_count(count, dttm_filter)
         return count_allowed == len(dttm_filter)
 
-    def _calculate_count(self, count: int, dttm_filter:
+    def _calculate_count(self, count: int, dttm_filter: Sequence[datetime.datetime]) -> float | int:
         """Calculate the normalized count based on the type of check."""
         if self.external_task_ids:
             return count / len(self.external_task_ids)
```
```diff
@@ -395,7 +404,7 @@ class ExternalTaskSensor(BaseSensorOperator):
     if not AIRFLOW_V_3_0_PLUS:
 
         @provide_session
-        def _poke_af2(self, dttm_filter:
+        def _poke_af2(self, dttm_filter: Sequence[datetime.datetime], session: Session = NEW_SESSION) -> bool:
             if self.check_existence and not self._has_checked_existence:
                 self._check_for_existence(session=session)
 
```
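The Airflow 2-only `_poke_af2` keeps the `@provide_session` pattern, where `session: Session = NEW_SESSION` is a sentinel default that the decorator replaces with a live session when the caller supplies none. A minimal sketch of that pattern, assuming `airflow.utils.session`'s decorator:

```python
from sqlalchemy.orm import Session

from airflow.utils.session import NEW_SESSION, provide_session


class Example:
    @provide_session
    def poke_af2(self, session: Session = NEW_SESSION) -> bool:
        # provide_session injects a live Session when the caller omits one;
        # NEW_SESSION is just a sentinel default, so call sites can simply
        # write Example().poke_af2() with no session argument.
        return session is not NEW_SESSION
```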
```diff
@@ -416,27 +425,51 @@ class ExternalTaskSensor(BaseSensorOperator):
             super().execute(context)
         else:
             dttm_filter = self._get_dttm_filter(context)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if AIRFLOW_V_3_0_PLUS:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=WorkflowTrigger(
+                        external_dag_id=self.external_dag_id,
+                        external_task_group_id=self.external_task_group_id,
+                        external_task_ids=self.external_task_ids,
+                        allowed_states=self.allowed_states,
+                        failed_states=self.failed_states,
+                        skipped_states=self.skipped_states,
+                        poke_interval=self.poll_interval,
+                        soft_fail=self.soft_fail,
+                        logical_dates=list(dttm_filter),
+                        run_ids=None,
+                        execution_dates=None,
+                    ),
+                    method_name="execute_complete",
+                )
+            else:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=WorkflowTrigger(
+                        external_dag_id=self.external_dag_id,
+                        external_task_group_id=self.external_task_group_id,
+                        external_task_ids=self.external_task_ids,
+                        allowed_states=self.allowed_states,
+                        failed_states=self.failed_states,
+                        skipped_states=self.skipped_states,
+                        poke_interval=self.poll_interval,
+                        soft_fail=self.soft_fail,
+                        execution_dates=list(dttm_filter),
+                        logical_dates=None,
+                        run_ids=None,
+                    ),
+                    method_name="execute_complete",
+                )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, typing.Any] | None = None) -> None:
         """Execute when the trigger fires - return immediately."""
+        if event is None:
+            raise ExternalTaskNotFoundError("No event received from trigger")
+
+        # Re-set as attribute after coming back from deferral - to be used by listeners
+        self.external_dates_filter = self._serialize_dttm_filter(self._get_dttm_filter(context))
+
         if event["status"] == "success":
             self.log.info("External tasks %s has executed successfully.", self.external_task_ids)
         elif event["status"] == "skipped":
```
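The new deferral path consists of two nearly identical `defer` calls; the only difference is which keyword carries the date filter (`logical_dates` on Airflow 3, `execution_dates` on Airflow 2). A condensed sketch of the same logic, as an illustrative refactor rather than the shipped code:

```python
# Illustrative condensation of the branch above (not the shipped code):
# only the date keyword differs between the Airflow 3 and Airflow 2 paths.
date_kwargs = (
    {"logical_dates": list(dttm_filter), "execution_dates": None}
    if AIRFLOW_V_3_0_PLUS
    else {"logical_dates": None, "execution_dates": list(dttm_filter)}
)
self.defer(
    timeout=self.execution_timeout,
    trigger=WorkflowTrigger(
        external_dag_id=self.external_dag_id,
        external_task_group_id=self.external_task_group_id,
        external_task_ids=self.external_task_ids,
        allowed_states=self.allowed_states,
        failed_states=self.failed_states,
        skipped_states=self.skipped_states,
        poke_interval=self.poll_interval,
        soft_fail=self.soft_fail,
        run_ids=None,
        **date_kwargs,
    ),
    method_name="execute_complete",
)
```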
```diff
@@ -453,13 +486,14 @@ class ExternalTaskSensor(BaseSensorOperator):
                 "name of executed task and Dag."
             )
 
-    def _check_for_existence(self, session) -> None:
+    def _check_for_existence(self, session: Session) -> None:
         dag_to_wait = DagModel.get_current(self.external_dag_id, session)
 
         if not dag_to_wait:
             raise ExternalDagNotFoundError(f"The external DAG {self.external_dag_id} does not exist.")
 
-
+        path = correct_maybe_zipped(dag_to_wait.fileloc)
+        if not path or not os.path.exists(path):
             raise ExternalDagDeletedError(f"The external DAG {self.external_dag_id} was deleted.")
 
         if self.external_task_ids:
```
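`_check_for_existence` now routes the stored `fileloc` through `correct_maybe_zipped` before the `os.path.exists` test. Assuming the usual semantics of that Airflow helper, a path pointing inside a `.zip` DAG bundle is mapped to the archive itself, since the inner path never exists as a plain file. A rough approximation, for illustration only:

```python
import re


def correct_maybe_zipped_sketch(fileloc: str) -> str:
    """Rough approximation of the helper's behaviour, not the real code."""
    match = re.search(r"(.*\.zip)", fileloc)
    return match.group(1) if match else fileloc


print(correct_maybe_zipped_sketch("/dags/bundle.zip/my_dag.py"))  # /dags/bundle.zip
print(correct_maybe_zipped_sketch("/dags/my_dag.py"))  # /dags/my_dag.py
```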
```diff
@@ -488,7 +522,7 @@ class ExternalTaskSensor(BaseSensorOperator):
 
         self._has_checked_existence = True
 
-    def get_count(self, dttm_filter, session, states) -> int:
+    def get_count(self, dttm_filter: Sequence[datetime.datetime], session: Session, states: list[str]) -> int:
         """
         Get the count of records against dttm filter and states.
 
```
```diff
@@ -509,15 +543,19 @@ class ExternalTaskSensor(BaseSensorOperator):
             session,
         )
 
-    def get_external_task_group_task_ids(
+    def get_external_task_group_task_ids(
+        self, session: Session, dttm_filter: Sequence[datetime.datetime]
+    ) -> list[tuple[str, int]]:
         warnings.warn(
             "This method is deprecated and will be removed in future.", DeprecationWarning, stacklevel=2
         )
+        if self.external_task_group_id is None:
+            return []
         return _get_external_task_group_task_ids(
-            dttm_filter, self.external_task_group_id, self.external_dag_id, session
+            list(dttm_filter), self.external_task_group_id, self.external_dag_id, session
        )
 
-    def _get_logical_date(self, context) -> datetime.datetime:
+    def _get_logical_date(self, context: Context) -> datetime.datetime:
         """
         Handle backwards- and forwards-compatible retrieval of the date.
 
```
|
-
def _get_logical_date(self, context) -> datetime.datetime:
|
|
558
|
+
def _get_logical_date(self, context: Context) -> datetime.datetime:
|
|
521
559
|
"""
|
|
522
560
|
Handle backwards- and forwards-compatible retrieval of the date.
|
|
523
561
|
|
|
@@ -527,19 +565,21 @@ class ExternalTaskSensor(BaseSensorOperator):
|
|
|
527
565
|
if AIRFLOW_V_3_0_PLUS:
|
|
528
566
|
logical_date = context.get("logical_date")
|
|
529
567
|
dag_run = context.get("dag_run")
|
|
530
|
-
if
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
568
|
+
if logical_date:
|
|
569
|
+
return logical_date
|
|
570
|
+
if dag_run and hasattr(dag_run, "run_after") and dag_run.run_after:
|
|
571
|
+
return dag_run.run_after
|
|
572
|
+
raise ValueError("Either `logical_date` or `dag_run.run_after` must be provided in the context")
|
|
535
573
|
|
|
536
574
|
# Airflow 2.x and earlier: contexts used "execution_date"
|
|
537
575
|
execution_date = context.get("execution_date")
|
|
538
576
|
if not execution_date:
|
|
539
577
|
raise ValueError("Either `execution_date` must be provided in the context`")
|
|
578
|
+
if not isinstance(execution_date, datetime.datetime):
|
|
579
|
+
raise ValueError("execution_date must be a datetime object")
|
|
540
580
|
return execution_date
|
|
541
581
|
|
|
542
|
-
def _handle_execution_date_fn(self, context) ->
|
|
582
|
+
def _handle_execution_date_fn(self, context: Context) -> datetime.datetime | list[datetime.datetime]:
|
|
543
583
|
"""
|
|
544
584
|
Handle backward compatibility.
|
|
545
585
|
|
|
@@ -20,31 +20,14 @@ from __future__ import annotations
|
|
|
20
20
|
import datetime
|
|
21
21
|
import os
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from dataclasses import dataclass
|
|
24
23
|
from functools import cached_property
|
|
25
24
|
from glob import glob
|
|
26
25
|
from typing import TYPE_CHECKING, Any
|
|
27
26
|
|
|
28
|
-
from airflow.
|
|
29
|
-
from airflow.exceptions import AirflowException
|
|
30
|
-
from airflow.providers.common.compat.sdk import BaseSensorOperator
|
|
27
|
+
from airflow.providers.common.compat.sdk import AirflowException, BaseSensorOperator, conf
|
|
31
28
|
from airflow.providers.standard.hooks.filesystem import FSHook
|
|
32
29
|
from airflow.providers.standard.triggers.file import FileTrigger
|
|
33
|
-
|
|
34
|
-
try:
|
|
35
|
-
from airflow.triggers.base import StartTriggerArgs # type: ignore[no-redef]
|
|
36
|
-
except ImportError: # TODO: Remove this when min airflow version is 2.10.0 for standard provider
|
|
37
|
-
|
|
38
|
-
@dataclass
|
|
39
|
-
class StartTriggerArgs: # type: ignore[no-redef]
|
|
40
|
-
"""Arguments required for start task execution from triggerer."""
|
|
41
|
-
|
|
42
|
-
trigger_cls: str
|
|
43
|
-
next_method: str
|
|
44
|
-
trigger_kwargs: dict[str, Any] | None = None
|
|
45
|
-
next_kwargs: dict[str, Any] | None = None
|
|
46
|
-
timeout: datetime.timedelta | None = None
|
|
47
|
-
|
|
30
|
+
from airflow.triggers.base import StartTriggerArgs
|
|
48
31
|
|
|
49
32
|
if TYPE_CHECKING:
|
|
50
33
|
from airflow.sdk import Context
|
|
airflow/providers/standard/sensors/time.py

```diff
@@ -19,28 +19,12 @@ from __future__ import annotations
 
 import datetime
 import warnings
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
-from airflow.configuration import conf
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
+from airflow.providers.common.compat.sdk import BaseSensorOperator, conf, timezone
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger
-
-try:
-    from airflow.triggers.base import StartTriggerArgs  # type: ignore[no-redef]
-except ImportError:  # TODO: Remove this when min airflow version is 2.10.0 for standard provider
-
-    @dataclass
-    class StartTriggerArgs:  # type: ignore[no-redef]
-        """Arguments required for start task execution from triggerer."""
-
-        trigger_cls: str
-        next_method: str
-        trigger_kwargs: dict[str, Any] | None = None
-        next_kwargs: dict[str, Any] | None = None
-        timeout: datetime.timedelta | None = None
-
+from airflow.triggers.base import StartTriggerArgs
 
 if TYPE_CHECKING:
     from airflow.sdk import Context
```
airflow/providers/standard/sensors/time_delta.py

```diff
@@ -25,9 +25,8 @@ from typing import TYPE_CHECKING, Any
 from deprecated.classic import deprecated
 from packaging.version import Version
 
-from airflow.
-from airflow.
-from airflow.providers.common.compat.sdk import BaseSensorOperator, timezone
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.common.compat.sdk import AirflowSkipException, BaseSensorOperator, conf, timezone
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
 from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
 
```
```diff
@@ -194,9 +193,11 @@ class WaitSensor(BaseSensorOperator):
     def execute(self, context: Context) -> None:
         if self.deferrable:
             self.defer(
-                trigger=
-
-
+                trigger=(
+                    TimeDeltaTrigger(self.time_to_wait, end_from_trigger=True)
+                    if AIRFLOW_V_3_0_PLUS
+                    else TimeDeltaTrigger(self.time_to_wait)
+                ),
                 method_name="execute_complete",
             )
         else:
```
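`WaitSensor.execute` now builds the trigger conditionally: on Airflow 3 it passes `end_from_trigger=True`, which, assuming the newer `TimeDeltaTrigger` API, lets the triggerer mark the task finished directly instead of resuming it on a worker. A minimal sketch under that assumption:

```python
import datetime

from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger

delay = datetime.timedelta(minutes=10)

# Airflow 3 form (assumed semantics): the task completes from the
# triggerer itself, so no worker slot is used for execute_complete.
modern = TimeDeltaTrigger(delay, end_from_trigger=True)

# Airflow 2-compatible form: the task resumes on a worker when it fires.
legacy = TimeDeltaTrigger(delay)
```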
airflow/providers/standard/triggers/external_task.py

```diff
@@ -18,10 +18,11 @@ from __future__ import annotations
 
 import asyncio
 import typing
+from collections.abc import Collection
 from typing import Any
 
 from asgiref.sync import sync_to_async
-from sqlalchemy import func
+from sqlalchemy import func, select
 
 from airflow.models import DagRun
 from airflow.providers.standard.utils.sensor_helper import _get_count
```
```diff
@@ -60,9 +61,9 @@ class WorkflowTrigger(BaseTrigger):
         logical_dates: list[datetime] | None = None,
         external_task_ids: typing.Collection[str] | None = None,
         external_task_group_id: str | None = None,
-        failed_states:
-        skipped_states:
-        allowed_states:
+        failed_states: Collection[str] | None = None,
+        skipped_states: Collection[str] | None = None,
+        allowed_states: Collection[str] | None = None,
         poke_interval: float = 2.0,
         soft_fail: bool = False,
         **kwargs,
```
```diff
@@ -129,43 +130,41 @@ class WorkflowTrigger(BaseTrigger):
             self.log.info("Sleeping for %s seconds", self.poke_interval)
             await asyncio.sleep(self.poke_interval)
 
-    async def _get_count_af_3(self, states):
+    async def _get_count_af_3(self, states: Collection[str] | None) -> int:
         from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        params = {
-            "dag_id": self.external_dag_id,
-            "logical_dates": self.logical_dates,
-            "run_ids": self.run_ids,
-        }
         if self.external_task_ids:
             count = await sync_to_async(RuntimeTaskInstance.get_ti_count)(
-
-
-
+                dag_id=self.external_dag_id,
+                task_ids=list(self.external_task_ids),
+                logical_dates=self.logical_dates,
+                run_ids=self.run_ids,
+                states=list(states) if states else None,
             )
-
+            return int(count / len(self.external_task_ids))
+        if self.external_task_group_id:
             run_id_task_state_map = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.external_dag_id,
                 task_group_id=self.external_task_group_id,
-
+                logical_dates=self.logical_dates,
+                run_ids=self.run_ids,
             )
             count = await sync_to_async(_get_count_by_matched_states)(
                 run_id_task_state_map=run_id_task_state_map,
-                states=states,
-            )
-        else:
-            count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
-                dag_id=self.external_dag_id,
-                logical_dates=self.logical_dates,
-                run_ids=self.run_ids,
-                states=states,
+                states=states or [],
             )
-
-
+            return count
+        count = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
+            dag_id=self.external_dag_id,
+            logical_dates=self.logical_dates,
+            run_ids=self.run_ids,
+            states=list(states) if states else None,
+        )
         return count
 
     @sync_to_async
-    def _get_count(self, states:
+    def _get_count(self, states: Collection[str] | None) -> int:
         """
         Get the count of records against dttm filter and states. Async wrapper for _get_count.
 
```
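The rewritten `_get_count_af_3` drops the shared `params` dict and the `if/elif/else` ladder in favor of guard clauses with early returns, normalizing arguments per branch (`list(states) if states else None`, `states or []`). A structure-only sketch, with hypothetical helper names standing in for the `RuntimeTaskInstance` calls above:

```python
# Guard-clause shape of the rewrite; the helpers are hypothetical
# stand-ins for the get_ti_count / get_task_states / get_dr_count calls.
async def get_count(self, states):
    if self.external_task_ids:
        count = await self._count_tis(states)  # hypothetical helper
        return int(count / len(self.external_task_ids))
    if self.external_task_group_id:
        return await self._count_group_states(states)  # hypothetical helper
    return await self._count_dag_runs(states)  # hypothetical helper
```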
```diff
@@ -227,23 +226,26 @@ class DagStateTrigger(BaseTrigger):
         elif self.execution_dates:
             runs_ids_or_dates = len(self.execution_dates)
 
+        cls_path, data = self.serialize()
+
         if AIRFLOW_V_3_0_PLUS:
-
-
+            data.update(  # update with {run_id: run_state} dict
+                await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
+            )
+            yield TriggerEvent((cls_path, data))
             return
         else:
             while True:
                 num_dags = await self.count_dags()
                 if num_dags == runs_ids_or_dates:
-                    yield TriggerEvent(
+                    yield TriggerEvent((cls_path, data))
                     return
                 await asyncio.sleep(self.poll_interval)
 
-    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) ->
+    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, str]:
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-
-
+        run_states: dict[str, str] = {}  # {run_id: run_state}
         while True:
             num_dags = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
                 dag_id=self.dag_id,
```
```diff
@@ -258,8 +260,8 @@ class DagStateTrigger(BaseTrigger):
                         dag_id=self.dag_id,
                         run_id=run_id,
                     )
-
-                    return
+                    run_states[run_id] = state
+                return run_states
             await asyncio.sleep(self.poll_interval)
 
     if not AIRFLOW_V_3_0_PLUS:
```
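On Airflow 3 the trigger now emits its own serialized form, augmented with the `{run_id: run_state}` map returned by `validate_count_dags_af_3`, so the operator side can recover both the trigger kwargs and the final run states from one payload. Schematically (field names beyond those shown in the diff are illustrative):

```python
# Schematic TriggerEvent payload after this change; the exact serialized
# keys come from DagStateTrigger.serialize() and are not all shown here.
cls_path = "airflow.providers.standard.triggers.external_task.DagStateTrigger"
data = {
    "dag_id": "target_dag",           # illustrative serialized kwarg
    "states": ["success", "failed"],  # illustrative serialized kwarg
    "manual__2024-01-01T00:00:00+00:00": "success",  # {run_id: run_state}
}
event_payload = (cls_path, data)
```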
|
|
@@ -270,17 +272,18 @@ class DagStateTrigger(BaseTrigger):
|
|
|
270
272
|
def count_dags(self, *, session: Session = NEW_SESSION) -> int:
|
|
271
273
|
"""Count how many dag runs in the database match our criteria."""
|
|
272
274
|
_dag_run_date_condition = (
|
|
273
|
-
DagRun.run_id.in_(self.run_ids)
|
|
275
|
+
DagRun.run_id.in_(self.run_ids or [])
|
|
274
276
|
if AIRFLOW_V_3_0_PLUS
|
|
275
277
|
else DagRun.execution_date.in_(self.execution_dates)
|
|
276
278
|
)
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
.
|
|
279
|
+
stmt = (
|
|
280
|
+
select(func.count())
|
|
281
|
+
.select_from(DagRun)
|
|
282
|
+
.where(
|
|
280
283
|
DagRun.dag_id == self.dag_id,
|
|
281
284
|
DagRun.state.in_(self.states),
|
|
282
285
|
_dag_run_date_condition,
|
|
283
286
|
)
|
|
284
|
-
.scalar()
|
|
285
287
|
)
|
|
286
|
-
|
|
288
|
+
result = session.execute(stmt).scalar()
|
|
289
|
+
return result or 0
|
|
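`count_dags` moves from the legacy SQLAlchemy 1.x `Query` chain to a 2.0-style `select()` executed on the session, with an explicit `or 0` guard on the scalar result. A self-contained comparison of the two idioms on a stand-in model (generic SQLAlchemy, not Airflow code):

```python
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Run(Base):  # stand-in model for illustration
    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Run(dag_id="a"), Run(dag_id="a"), Run(dag_id="b")])
    session.commit()

    # Legacy 1.x Query API, as removed by this diff:
    legacy = session.query(func.count()).select_from(Run).filter(Run.dag_id == "a").scalar()

    # 2.0-style select(), as introduced by this diff:
    stmt = select(func.count()).select_from(Run).where(Run.dag_id == "a")
    modern = session.execute(stmt).scalar() or 0

    assert legacy == modern == 2
```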
airflow/providers/standard/triggers/file.py

```diff
@@ -79,7 +79,7 @@ class FileTrigger(BaseTrigger):
                     self.log.info("Found File %s last modified: %s", path, mod_time)
                     yield TriggerEvent(True)
                     return
-                for _, _, files in os.walk(
+                for _, _, files in os.walk(path):
                     if files:
                         yield TriggerEvent(True)
                         return
```
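The `FileTrigger` change completes a previously truncated call: `os.walk(path)` yields a `(dirpath, dirnames, filenames)` tuple per directory level, and the trigger fires on the first non-empty `filenames`. For example:

```python
import os

# os.walk yields (dirpath, dirnames, filenames) for each directory level;
# the trigger above fires as soon as any level has a non-empty file list.
for dirpath, dirnames, filenames in os.walk("/tmp"):
    if filenames:
        print(f"found files under {dirpath}: {filenames[:3]}")
        break
```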