django-lambda-tasks 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lambda_tasks/models.py ADDED
@@ -0,0 +1,240 @@
1
+ """Django model for persisting task execution records."""
2
+
3
+ import random
4
+ import traceback
5
+
6
+ import boto3
7
+ from django.core.exceptions import ImproperlyConfigured
8
+ from django.db import models, transaction
9
+ from django.db.models import Q
10
+ from django.utils.module_loading import import_string
11
+ from django.utils.timezone import now
12
+ from pydantic import BaseModel, ConfigDict, Field
13
+
14
+ from lambda_tasks.logging import task_logger
15
+ from lambda_tasks.settings import LambdaTasksSettings
16
+ from lambda_tasks.timeouts import TimeoutContext
17
+
18
+
19
class MaxRetriesExceededError(Exception):
    """Signal that a task hit the retry limit and will not be retried again.

    Args:
        task_name: Name of the task, embedded in the error message.
        n_retries: Retry count reported in the error message.
    """

    def __init__(self, *, task_name: str, n_retries: int) -> None:
        message = (
            f"Task '{task_name}' exceeded the maximum retry limit ({n_retries} retries)."
        )
        super().__init__(message)
24
+
25
+
26
class TaskStatus(models.TextChoices):
    """Closed set of states a ``TaskRecord`` can be in.

    Each member's stored value is identical to its name.
    """

    # Set when the record is created, just before the task body runs.
    RUNNING = "RUNNING"
    # The task returned normally, or raised an error the task is set to ignore.
    SUCCESS = "SUCCESS"
    # The task raised an unhandled error, or a retryable error past the limit.
    FAILED = "FAILED"
    # The task raised a retryable error and a new attempt was scheduled.
    RETRYING = "RETRYING"
31
+
32
+
33
class TaskRecord(models.Model):
    """Persistent record of one task invocation and its outcome.

    All fields are ``editable=False``. ``invocation_id`` is unique, so at most
    one record exists per invocation.
    """

    # Re-exported so callers can write ``TaskRecord.TaskStatus.RUNNING``.
    TaskStatus = TaskStatus

    task_name = models.CharField(max_length=255, editable=False)
    # NOTE(review): ``unique=True`` already implies an index in Django; the
    # explicit Index in ``Meta.indexes`` is kept to avoid a schema change.
    invocation_id = models.UUIDField(unique=True, editable=False)
    kwargs = models.JSONField(editable=False)
    n_retries = models.PositiveSmallIntegerField(editable=False)
    status = models.CharField(
        max_length=10,
        choices=TaskStatus,
        editable=False,
    )
    start_time = models.DateTimeField(null=True, editable=False)
    end_time = models.DateTimeField(null=True, editable=False)
    result = models.JSONField(null=True, editable=False)
    traceback = models.TextField(null=True, editable=False)

    @property
    def duration(self) -> object:
        """Return ``end_time - start_time``, or ``None`` if either is unset.

        Both timestamps are nullable (a record that is still running has no
        ``end_time``), so the subtraction is only attempted when both exist.
        """
        if self.start_time is None or self.end_time is None:
            return None
        return self.end_time - self.start_time

    class Meta:
        # Most recently started records first.
        ordering = ["-start_time"]
        indexes = [
            models.Index(fields=["task_name"]),
            models.Index(fields=["invocation_id"]),
            models.Index(fields=["status"]),
            models.Index(fields=["-start_time"]),
        ]
        constraints = [
            # Enforce the status choices at the database level as well.
            models.CheckConstraint(
                condition=Q(status__in=TaskStatus.values),
                name="taskrecord_status_valid",
            ),
        ]
68
+
69
+
70
class SQSLambdaTaskMessage(BaseModel):
    """Payload describing a single task invocation.

    Attributes are validated by pydantic; unknown fields are rejected
    (``extra="forbid"``) and ``n_retries`` must be non-negative.
    """

    model_config = ConfigDict(extra="forbid")

    task_name: str
    invocation_id: str
    kwargs: dict
    n_retries: int = Field(default=0, ge=0)

    def execute_immediately(self) -> None:
        """Execute the background task described by this message.

        Imports the task wrapper, resolves its timeouts, creates a
        ``TaskRecord`` (returning early if one already exists for this
        ``invocation_id``), runs the task inside an atomic block with timeout
        enforcement, and persists the outcome.

        Error handling, in order of precedence:
            * errors matching ``wrapper.ignore_errors`` are recorded as SUCCESS
              with the traceback stored;
            * errors matching ``wrapper.retry_on`` mark the record RETRYING and
              schedule a new attempt, or mark it FAILED and raise once the
              retry limit is reached;
            * any other ``Exception`` is logged and recorded as FAILED.

        Raises:
            TypeError: if ``task_name`` does not import to a ``LambdaTaskWrapper``.
            MaxRetriesExceededError: if a retryable error occurs when
                ``n_retries`` has already reached ``MAX_RETRIES``.
        """
        task_logger.invocation_id = self.invocation_id

        try:
            task_logger.info(f"Received {self.task_name}")

            # Local import to avoid circular dependency
            from lambda_tasks.decorators import LambdaTaskWrapper

            wrapper = import_string(self.task_name)

            if not isinstance(wrapper, LambdaTaskWrapper):
                raise TypeError(
                    f"import_string('{self.task_name}') returned {type(wrapper)!r}, "
                    f"expected LambdaTaskWrapper."
                )

            # Resolve before creating the TaskRecord — ConfigurationError here means
            # misconfigured settings, not a task failure, so no record should be written.
            soft_timeout, hard_timeout = wrapper.resolved_timeouts

            record, created = (
                TaskRecord.objects.get_or_create(  # ty: ignore[unresolved-attribute]
                    invocation_id=self.invocation_id,
                    defaults={
                        "task_name": self.task_name,
                        "kwargs": self.kwargs,
                        "n_retries": self.n_retries,
                        "status": TaskRecord.TaskStatus.RUNNING,
                        "start_time": now(),
                    },
                )
            )

            if not created:
                # A record for this invocation already exists: do not run twice.
                task_logger.warning(
                    f"Skipping duplicate delivery (existing record status={record.status})"
                )
                return

            ignored_exception: BaseException | None = None
            ignored_traceback: str | None = None
            result = None

            try:
                with transaction.atomic():
                    with TimeoutContext(
                        soft_timeout=soft_timeout, hard_timeout=hard_timeout
                    ):
                        result = wrapper(**self.kwargs)
            except Exception as error:
                if wrapper.ignore_errors and isinstance(error, wrapper.ignore_errors):
                    # Remember the error; the record is written as SUCCESS below.
                    ignored_exception = error
                    ignored_traceback = traceback.format_exc()
                elif wrapper.retry_on and isinstance(error, wrapper.retry_on):
                    conf = LambdaTasksSettings()
                    if self.n_retries >= conf.MAX_RETRIES:
                        record.status = TaskRecord.TaskStatus.FAILED
                        record.traceback = traceback.format_exc()
                        record.end_time = now()
                        record.save(update_fields=["status", "traceback", "end_time"])

                        task_logger.warning(
                            f"Failed due to MaxRetriesExceededError in {record.duration}"
                        )

                        raise MaxRetriesExceededError(
                            task_name=self.task_name, n_retries=self.n_retries
                        ) from error
                    else:
                        record.status = TaskRecord.TaskStatus.RETRYING
                        record.traceback = traceback.format_exc()
                        record.end_time = now()
                        record.save(update_fields=["status", "traceback", "end_time"])

                        # Report the error that triggered the retry;
                        # ``ignored_exception`` is always None in this branch.
                        task_logger.warning(
                            f"Retrying (due to {type(error).__name__}) after {record.duration}"
                        )

                        # Without an explicit delay, wait 1 to 5 seconds.
                        delay = (
                            wrapper._delay
                            if wrapper._delay != 0
                            else round(random.uniform(1, 5))
                        )
                        wrapper.execute_on_commit(
                            **self.kwargs,
                            _delay=delay,
                            _n_retries=self.n_retries + 1,
                        )
                        return
                else:
                    task_logger.error(error, exc_info=True)

                    record.status = TaskRecord.TaskStatus.FAILED
                    record.traceback = traceback.format_exc()
                    record.end_time = now()
                    record.save(update_fields=["status", "traceback", "end_time"])

                    task_logger.warning(f"Failed in {record.duration}")
                    return

            if ignored_exception is None:
                record.status = TaskRecord.TaskStatus.SUCCESS
                record.result = result
                record.end_time = now()
                record.save(update_fields=["status", "result", "end_time"])

                task_logger.info(f"Succeeded in {record.duration}")
            else:
                record.status = TaskRecord.TaskStatus.SUCCESS
                record.traceback = ignored_traceback
                record.end_time = now()
                record.save(update_fields=["status", "traceback", "end_time"])

                task_logger.info(
                    f"Succeeded (ignored {type(ignored_exception).__name__}) in {record.duration}"
                )

        finally:
            # Always clear the logger's invocation id, whatever the outcome.
            task_logger.invocation_id = None
204
+
205
+
206
class SQSLambdaTask(BaseModel):
    """A task message together with its delivery options.

    Attributes:
        message: The invocation payload; sent as the SQS message body.
        delay: Passed to SQS as ``DelaySeconds``.
        queue: Key into the ``QUEUES`` setting that selects the queue URL.
    """

    model_config = ConfigDict(extra="forbid")

    message: SQSLambdaTaskMessage
    delay: int
    queue: str

    def execute_on_commit(self) -> None:
        """Enqueue this task after the current transaction commits."""
        transaction.on_commit(self._execute)

    def _execute(self) -> None:
        """Send this task to SQS, or run it in-process when ``EAGER`` is set.

        Raises:
            ImproperlyConfigured: if the queue name is not found in settings.
            Any boto3 exception: propagated directly to the caller.
        """
        conf = LambdaTasksSettings()

        if conf.EAGER:
            self.message.execute_immediately()
            return

        try:
            queue_url = conf.QUEUES[self.queue]
        except KeyError:
            # The message already names the missing queue, so the KeyError
            # context is suppressed rather than shown as a second traceback.
            raise ImproperlyConfigured(
                f"Queue '{self.queue}' is not defined in LAMBDA_TASKS_QUEUES."
            ) from None

        client = boto3.client("sqs")
        client.send_message(
            QueueUrl=queue_url,
            MessageBody=self.message.model_dump_json(),
            DelaySeconds=self.delay,
        )
@@ -0,0 +1,183 @@
1
+ """
2
+ Resolves environment variables that reference AWS Secrets Manager ARNs.
3
+
4
+ Any env var prefixed with ``AWS_SECRETS_MANAGER_`` is treated as a pointer
5
+ to a secret value. The unprefixed name is the target env var to populate.
6
+
7
+ Required value format
8
+ ---------------------
9
+ Every reference must follow the full dynamic reference syntax::
10
+
11
+ AWS_SECRETS_MANAGER_DJANGO_ADMIN_URL=arn:aws:secretsmanager:eu-west-1:123:secret:my-secret:DJANGO_ADMIN_URL:AWSCURRENT:v1
12
+
13
+ That is: ``<arn>:<json-key>:<version-stage>:<version-id>``
14
+
15
+ All three suffix segments must be present and non-empty.
16
+ A malformed reference raises ``ValueError`` immediately so the Lambda
17
+ container fails at cold start rather than silently misconfiguring Django.
18
+
19
+ It is a configuration error to set both ``AWS_SECRETS_MANAGER_FOO`` and
20
+ ``FOO`` — use one or the other. Having both raises ``ValueError`` at cold
21
+ start so the misconfiguration is caught immediately.
22
+
23
+ Calls are batched by unique (ARN, version-stage, version-id) combination —
24
+ one ``GetSecretValue`` call per unique combination, regardless of how many
25
+ env vars reference it. Results are cached in-process so repeated calls to
26
+ ``resolve_secrets_into_env`` are free after the first.
27
+ """
28
+
29
+ import json
30
+ import logging
31
+ import os
32
+ from typing import NamedTuple
33
+
34
+ import boto3
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ _PREFIX = "AWS_SECRETS_MANAGER_"
39
+
40
+ # Module-level cache: (arn, version_stage, version_id) → decoded secret JSON object.
41
+ # Populated on first call; reused for the lifetime of the Lambda container.
42
+ _secret_cache: dict[tuple[str, str, str], dict[str, str]] = {}
43
+
44
+
45
class _SecretReference(NamedTuple):
    """Components of one parsed ``AWS_SECRETS_MANAGER_*`` reference value."""

    # ARN of the secret; used as ``SecretId`` when fetching.
    arn: str
    # Key to read from the decoded secret JSON.
    json_key: str
    # Used as ``VersionStage`` when fetching.
    version_stage: str
    # Used as ``VersionId`` when fetching.
    version_id: str
50
+
51
+
52
def _parse_reference(*, env_var: str, value: str) -> _SecretReference:
    """Split a Secrets Manager reference into its validated components.

    Expected format::

        <arn>:<json-key>:<version-stage>:<version-id>

    The ARN contributes 7 colon-separated segments
    (``arn:aws:secretsmanager:region:account:secret:name``) and the suffix
    contributes 3 (json-key, version-stage, version-id), so a valid reference
    has exactly 10 segments. Each of the three suffix segments must be
    non-empty.

    Args:
        env_var: Name of the environment variable holding the reference; only
            used in error messages.
        value: The raw reference string to parse.

    Raises ``ValueError`` if the format is invalid — this is intentional so
    the Lambda container fails at cold start rather than starting with a
    misconfigured environment.
    """
    segments = value.split(":")
    n_segments = len(segments)

    if n_segments != 10:
        raise ValueError(
            f"{env_var} has an invalid Secrets Manager reference format. "
            "Expected <arn>:<json-key>:<version-stage>:<version-id> "
            f"(10 colon-separated segments), got {n_segments}: {value!r}"
        )

    # The last three segments are the suffix; the seven before them are the ARN.
    *arn_segments, json_key, version_stage, version_id = segments

    suffix = {
        "json-key": json_key,
        "version-stage": version_stage,
        "version-id": version_id,
    }
    for label, segment in suffix.items():
        if not segment:
            raise ValueError(
                f"{env_var} is missing the {label} segment in its Secrets Manager "
                f"reference: {value!r}"
            )

    return _SecretReference(
        arn=":".join(arn_segments),
        json_key=json_key,
        version_stage=version_stage,
        version_id=version_id,
    )
99
+
100
+
101
def _fetch_secret(*, client: object, ref: _SecretReference) -> dict[str, str]:
    """Return the decoded JSON object of the secret version *ref* points to.

    The cache key is ``(arn, version_stage, version_id)`` so that references
    to different versions of the same secret are fetched independently. On a
    cache hit *client* is not used at all.

    Args:
        client: Object exposing ``get_secret_value`` (a boto3 Secrets Manager
            client); only needed when the secret is not cached yet.
        ref: The parsed reference identifying the secret version.

    Raises:
        ValueError: if the response has no ``SecretString``, if the string is
            not valid JSON, or if the decoded JSON is not an object.
    """
    cache_key = (ref.arn, ref.version_stage, ref.version_id)

    if cache_key in _secret_cache:
        return _secret_cache[cache_key]

    logger.debug(f"Fetching secret {ref}")

    response = client.get_secret_value(  # type: ignore[union-attr]
        SecretId=ref.arn,
        VersionStage=ref.version_stage,
        VersionId=ref.version_id,
    )

    # Binary secrets come back under ``SecretBinary`` with no ``SecretString``.
    secret_string = response.get("SecretString")
    if secret_string is None:
        raise ValueError(
            f"Secret for {ref} has no SecretString; binary secrets are not supported"
        )

    try:
        secret = json.loads(secret_string)
    except json.JSONDecodeError as error:
        raise ValueError(f"Could not decode secret as json for {ref}") from error

    # Callers index the result by key, so anything but an object is unusable;
    # reject it here instead of caching it.
    if not isinstance(secret, dict):
        raise ValueError(
            f"Secret for {ref} must be a JSON object, got {type(secret).__name__}"
        )

    _secret_cache[cache_key] = secret

    return secret
128
+
129
+
130
def resolve_secrets_into_env() -> None:
    """Scan env vars for ``AWS_SECRETS_MANAGER_*`` references and resolve them.

    For each matching env var the resolved value is written back into
    ``os.environ`` under the unprefixed name.

    Raises ``ValueError`` at cold start if:
    - A reference is malformed (wrong segment count, any empty suffix field)
    - The target env var is already set to something other than the value
      this function resolved — use ``AWS_SECRETS_MANAGER_FOO`` or ``FOO``,
      not both
    - The referenced JSON key is missing from the secret, or its value is not
      a string

    This function is idempotent — calling it multiple times is safe and cheap
    because resolved secrets are cached after the first fetch, and a target
    that already holds its cached secret value is not treated as a conflict.
    """
    # target_name → _SecretReference
    references: dict[str, _SecretReference] = {
        key[len(_PREFIX) :]: _parse_reference(env_var=key, value=value)
        for key, value in os.environ.items()
        if key.startswith(_PREFIX)
    }

    if not references:
        return

    def _already_resolved(target: str, ref: _SecretReference) -> bool:
        """Tell whether *target* already holds the cached value for *ref*."""
        cached = _secret_cache.get((ref.arn, ref.version_stage, ref.version_id))
        if not isinstance(cached, dict) or ref.json_key not in cached:
            return False
        return os.environ.get(target) == cached[ref.json_key]

    # Fail fast if any target is already set in the environment. Targets
    # written by an earlier call are skipped, so repeated calls do not trip
    # over their own output. On the first call the cache is empty, so every
    # pre-set target is reported before any API call is made.
    conflicts = [
        target
        for target, ref in references.items()
        if target in os.environ and not _already_resolved(target, ref)
    ]
    if conflicts:
        raise ValueError(
            "The following environment variables are set both directly and via "
            f"AWS_SECRETS_MANAGER_*: {', '.join(sorted(conflicts))}. "
            "Use one or the other, not both."
        )

    # Unique (arn, version_stage, version_id) combinations to minimise API calls.
    unique_cache_keys = {
        (ref.arn, ref.version_stage, ref.version_id) for ref in references.values()
    }
    uncached = unique_cache_keys - _secret_cache.keys()
    logger.info(
        f"Resolving {len(references)} secret reference(s) from"
        f" {len(unique_cache_keys)} unique secret version(s)"
        f" ({len(unique_cache_keys) - len(uncached)} already cached)"
    )

    # Only build a client when at least one secret still has to be fetched.
    client = boto3.client("secretsmanager") if uncached else None

    for target, ref in references.items():
        secret = _fetch_secret(client=client, ref=ref)

        try:
            secret_value = secret[ref.json_key]
        except KeyError:
            raise ValueError(
                f"{_PREFIX}{target} references key {ref.json_key!r}, which is not "
                f"present in secret {ref.arn}"
            ) from None

        # os.environ only accepts strings; name the offending variable rather
        # than letting a bare TypeError escape.
        if not isinstance(secret_value, str):
            raise ValueError(
                f"{_PREFIX}{target} references key {ref.json_key!r} in secret "
                f"{ref.arn}, whose value is not a string "
                f"(got {type(secret_value).__name__})"
            )

        os.environ[target] = secret_value

        logger.info(f"Resolved {target} from secret {ref}")
@@ -0,0 +1,76 @@
1
+ """
2
+ Lazy settings object for lambda_tasks.
3
+ Reads from django.conf.settings on first attribute access.
4
+ """
5
+
6
+ from django.conf import settings as django_settings
7
+ from django.core.exceptions import ImproperlyConfigured
8
+
9
+ _UNSET = object()
10
+
11
+ MAX_TIMEOUT = 900
12
+
13
+
14
class LambdaTasksSettings:
    """Thin accessor over Django settings for the lambda_tasks library.

    Every property reads ``django.conf.settings`` at access time; nothing is
    cached on the instance.
    """

    def _get(self, *, name: str, default: object = _UNSET) -> object:
        """Look up *name* on the Django settings object.

        Returns the setting's value when it is defined (including an explicit
        ``None``). When it is not defined, returns *default*, or ``None`` if
        no default was supplied.
        """
        found = getattr(django_settings, name, _UNSET)
        if found is not _UNSET:
            return found
        return None if default is _UNSET else default

    def _int_setting(self, *, name: str, fallback: int) -> int:
        """Read an integer setting, using *fallback* when it is unset or ``None``."""
        raw = self._get(name=name, default=fallback)
        return fallback if raw is None else int(raw)  # type: ignore[call-overload]

    def _resolve_queues(self) -> dict[str, str]:
        """Return ``LAMBDA_TASKS_QUEUES`` after checking it is usable.

        Raises:
            ImproperlyConfigured: if the setting is missing or empty, or if it
                has no ``'default'`` key.
        """
        queues = self._get(name="LAMBDA_TASKS_QUEUES")

        if not queues:
            raise ImproperlyConfigured(
                "LAMBDA_TASKS_QUEUES must be defined in Django settings."
            )

        if "default" not in queues:  # type: ignore[operator]
            raise ImproperlyConfigured(
                "LAMBDA_TASKS_QUEUES must contain a 'default' key."
            )

        return queues  # type: ignore[return-value]

    @property
    def QUEUES(self) -> dict[str, str]:
        """The validated ``LAMBDA_TASKS_QUEUES`` setting."""
        return self._resolve_queues()

    @property
    def EAGER(self) -> bool:
        """Truthiness of ``LAMBDA_TASKS_EAGER``; ``False`` when unset."""
        return bool(self._get(name="LAMBDA_TASKS_EAGER", default=False))

    @property
    def DEFAULT_SOFT_TIMEOUT(self) -> int:
        """``LAMBDA_TASKS_DEFAULT_SOFT_TIMEOUT`` as an int; 270 when unset or ``None``."""
        return self._int_setting(name="LAMBDA_TASKS_DEFAULT_SOFT_TIMEOUT", fallback=270)

    @property
    def DEFAULT_HARD_TIMEOUT(self) -> int:
        """``LAMBDA_TASKS_DEFAULT_HARD_TIMEOUT`` as an int; 300 when unset or ``None``."""
        return self._int_setting(name="LAMBDA_TASKS_DEFAULT_HARD_TIMEOUT", fallback=300)

    @property
    def MAX_RETRIES(self) -> int:
        """``LAMBDA_TASKS_MAX_RETRIES`` as an int; 2880 (60 * 24 * 2) when unset or ``None``."""
        return self._int_setting(name="LAMBDA_TASKS_MAX_RETRIES", fallback=2880)
@@ -0,0 +1,78 @@
1
+ """Timeout enforcement for lambda tasks using Unix SIGALRM."""
2
+
3
+ import signal
4
+
5
+ from lambda_tasks.settings import LambdaTasksSettings
6
+
7
+
8
class SoftTimeLimitExceeded(Exception):
    """Signal to a running task that its soft timeout has elapsed.

    Task code may catch this exception to clean up; the hard timeout still
    applies afterwards and ends execution if the task keeps running.
    """
14
+
15
+
16
class HardTimeLimitExceeded(Exception):
    """Signal that the hard timeout has elapsed and the task must stop.

    Task code is not expected to catch this exception. Because it derives
    from ``Exception``, a broad ``except Exception`` in task code would still
    intercept it.
    """
22
+
23
+
24
class TimeoutContext:
    """Context manager that enforces soft and hard timeouts via SIGALRM.

    Two-phase approach:
    1. Arm soft_timeout seconds. On SIGALRM, raise SoftTimeLimitExceeded
       inside the task and re-arm for (hard_timeout - soft_timeout) seconds
       (at least 1 second).
    2. Second SIGALRM raises HardTimeLimitExceeded.

    Any pre-existing alarm and handler are saved on enter and restored on
    exit. When the ``EAGER`` setting is on, no alarm is armed and the context
    manager does nothing.

    NOTE(review): a pre-existing alarm is restored with the number of seconds
    it had left at enter time; time spent inside the block is not subtracted.

    Usage::

        with TimeoutContext(soft_timeout=270, hard_timeout=300):
            run_task()
    """

    def __init__(self, soft_timeout: int, hard_timeout: int) -> None:
        self.soft_timeout = soft_timeout
        self.hard_timeout = hard_timeout
        self._previous_alarm: int = 0
        self._previous_handler = None
        self._in_hard_phase: bool = False
        # True only between an __enter__ that armed the alarm and its __exit__.
        self._armed: bool = False

    def _handler(self, signum: int, frame) -> None:
        """SIGALRM handler: raise the soft error first, the hard error second."""
        if self._in_hard_phase:
            # Phase 2: hard timeout expired
            raise HardTimeLimitExceeded()

        # Phase 1 → Phase 2: raise soft, re-arm for remaining hard window
        self._in_hard_phase = True
        remaining = self.hard_timeout - self.soft_timeout
        # alarm(0) would cancel the alarm, so always leave at least one second.
        signal.alarm(max(1, remaining))
        raise SoftTimeLimitExceeded()

    def __enter__(self) -> "TimeoutContext":
        conf = LambdaTasksSettings()

        if conf.EAGER:
            return self

        self._in_hard_phase = False
        self._previous_handler = signal.signal(signal.SIGALRM, self._handler)
        self._previous_alarm = signal.alarm(self.soft_timeout)
        # Remember that we armed, so __exit__ undoes exactly what was done
        # here even if the EAGER setting changes in between.
        self._armed = True

        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if not self._armed:
            return
        self._armed = False

        # Cancel any pending alarm we armed
        signal.alarm(0)
        # Restore the original signal handler. signal.signal() returns None
        # when the previous handler was not installed from Python, and None
        # cannot be passed back, so fall back to the default action.
        previous = self._previous_handler
        signal.signal(
            signal.SIGALRM, signal.SIG_DFL if previous is None else previous
        )
        # Restore any pre-existing alarm
        if self._previous_alarm > 0:
            signal.alarm(self._previous_alarm)