plain.jobs 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of plain.jobs has been flagged as potentially problematic; review the advisory details below before using it.

plain/jobs/models.py ADDED
@@ -0,0 +1,438 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import logging
5
+ import traceback
6
+ import uuid
7
+ from typing import Self
8
+
9
+ from opentelemetry import trace
10
+ from opentelemetry.semconv._incubating.attributes.code_attributes import (
11
+ CODE_NAMESPACE,
12
+ )
13
+ from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
14
+ MESSAGING_CONSUMER_GROUP_NAME,
15
+ MESSAGING_DESTINATION_NAME,
16
+ MESSAGING_MESSAGE_ID,
17
+ MESSAGING_OPERATION_NAME,
18
+ MESSAGING_OPERATION_TYPE,
19
+ MESSAGING_SYSTEM,
20
+ MessagingOperationTypeValues,
21
+ )
22
+ from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
23
+ from opentelemetry.trace import Link, SpanContext, SpanKind
24
+
25
+ from plain import models
26
+ from plain.models import transaction
27
+ from plain.runtime import settings
28
+ from plain.utils import timezone
29
+
30
+ from .registry import jobs_registry
31
+
32
# Module-level logger and tracer, both namespaced under "plain.jobs" so job
# execution logs/spans can be filtered and configured as a group.
logger = logging.getLogger("plain.jobs")
tracer = trace.get_tracer("plain.jobs")
34
+
35
+
36
@models.register_model
class JobRequest(models.Model):
    """
    Keep all pending job requests in a single table.

    A JobRequest is the queued, not-yet-running form of a job. When a worker
    picks one up it is atomically converted into a JobProcess and the request
    row is deleted (see convert_to_job_process).
    """

    created_at = models.DateTimeField(auto_now_add=True)
    uuid = models.UUIDField(default=uuid.uuid4)

    # Registry key (dotted path or alias) used to look the job up in jobs_registry
    job_class = models.CharField(max_length=255)
    # JSON payload of serialized call args/kwargs (may be null)
    parameters = models.JSONField(required=False, allow_null=True)
    # NOTE(review): ordering below sorts ascending by priority, so lower
    # numbers would come first — confirm against the worker's pickup query.
    priority = models.IntegerField(default=0)
    source = models.TextField(required=False)
    queue = models.CharField(default="default", max_length=255)

    # Total retries allowed, and which attempt this request represents (0 = first run)
    retries = models.IntegerField(default=0)
    retry_attempt = models.IntegerField(default=0)

    # Non-empty keys are unique per job_class for first attempts (see constraint below)
    unique_key = models.CharField(max_length=255, required=False)

    # presumably the earliest time the job may be picked up (delayed runs /
    # retry backoff) — confirm against the worker's scheduling query
    start_at = models.DateTimeField(required=False, allow_null=True)

    # OpenTelemetry trace context
    # Hex-encoded ids of the span active when the job was enqueued;
    # re-linked as a span Link when the job runs (see JobProcess.run)
    trace_id = models.CharField(max_length=34, required=False, allow_null=True)
    span_id = models.CharField(max_length=18, required=False, allow_null=True)

    # expires_at = models.DateTimeField(required=False, allow_null=True)

    model_options = models.Options(
        ordering=["priority", "-created_at"],
        indexes=[
            models.Index(fields=["priority"]),
            models.Index(fields=["created_at"]),
            models.Index(fields=["queue"]),
            models.Index(fields=["start_at"]),
            models.Index(fields=["unique_key"]),
            models.Index(fields=["job_class"]),
            models.Index(fields=["trace_id"]),
            # Used to dedupe unique in-process jobs
            models.Index(
                name="job_request_class_unique_key", fields=["job_class", "unique_key"]
            ),
        ],
        # The job_class and unique_key should be unique at the db-level,
        # but only if unique_key is not ""
        constraints=[
            models.UniqueConstraint(
                fields=["job_class", "unique_key"],
                condition=models.Q(unique_key__gt="", retry_attempt=0),
                name="plainjobs_jobrequest_unique_job_class_key",
            ),
            models.UniqueConstraint(
                fields=["uuid"], name="plainjobs_jobrequest_unique_uuid"
            ),
        ],
    )

    def __str__(self) -> str:
        return f"{self.job_class} [{self.uuid}]"

    def convert_to_job_process(self) -> JobProcess:
        """
        JobRequests are the pending jobs that are waiting to be executed.
        We immediately convert them to JobProcess when they are picked up.

        Returns the newly created JobProcess. Create-then-delete runs inside
        one transaction so the job exists in exactly one of the two tables at
        any commit point.
        """
        with transaction.atomic():
            result = JobProcess.query.create(
                job_request_uuid=self.uuid,
                job_class=self.job_class,
                parameters=self.parameters,
                priority=self.priority,
                source=self.source,
                queue=self.queue,
                retries=self.retries,
                retry_attempt=self.retry_attempt,
                unique_key=self.unique_key,
                trace_id=self.trace_id,
                span_id=self.span_id,
            )

            # Delete the pending JobRequest now
            self.delete()

        return result
120
+
121
+
122
class JobQuerySet(models.QuerySet["JobProcess"]):
    """Query helpers for in-flight JobProcess rows."""

    def running(self) -> Self:
        """Jobs a worker has already started (started_at is set)."""
        return self.filter(started_at__isnull=False)

    def waiting(self) -> Self:
        """Jobs created but not yet started by any worker."""
        return self.filter(started_at__isnull=True)

    def mark_lost_jobs(self) -> None:
        """Convert jobs older than JOBS_TIMEOUT into LOST results.

        Lost jobs have been pending for too long and will probably never get
        picked up by a worker process. In theory a per-job timeout could mark
        them timed-out more quickly, but if they are still running we can't
        actually send a signal to cancel them.
        """
        deadline = timezone.now() - datetime.timedelta(seconds=settings.JOBS_TIMEOUT)
        # Age alone decides — whether or not it started, it shouldn't take this long.
        # The conversion saves them into the results table, but lost jobs are
        # only retried if they have a retry!
        for stale_job in self.filter(created_at__lt=deadline):
            stale_job.convert_to_result(status=JobResultStatuses.LOST)
144
+
145
+
146
@models.register_model
class JobProcess(models.Model):
    """
    All active jobs are stored in this table.

    Rows are created from a JobRequest, executed once via run(), and then
    converted into a JobResult (which deletes this row).
    """

    uuid = models.UUIDField(default=uuid.uuid4)
    created_at = models.DateTimeField(auto_now_add=True)
    # Set by run() when a worker begins executing; null means still waiting
    started_at = models.DateTimeField(required=False, allow_null=True)

    # From the JobRequest
    job_request_uuid = models.UUIDField()
    job_class = models.CharField(max_length=255)
    parameters = models.JSONField(required=False, allow_null=True)
    priority = models.IntegerField(default=0)
    source = models.TextField(required=False)
    queue = models.CharField(default="default", max_length=255)
    retries = models.IntegerField(default=0)
    retry_attempt = models.IntegerField(default=0)
    unique_key = models.CharField(max_length=255, required=False)

    # OpenTelemetry trace context (hex strings carried over from the request)
    trace_id = models.CharField(max_length=34, required=False, allow_null=True)
    span_id = models.CharField(max_length=18, required=False, allow_null=True)

    query = JobQuerySet()

    model_options = models.Options(
        ordering=["-created_at"],
        indexes=[
            models.Index(fields=["created_at"]),
            models.Index(fields=["queue"]),
            models.Index(fields=["unique_key"]),
            models.Index(fields=["started_at"]),
            models.Index(fields=["job_class"]),
            models.Index(fields=["job_request_uuid"]),
            models.Index(fields=["trace_id"]),
            # Used to dedupe unique in-process jobs
            models.Index(
                name="job_class_unique_key", fields=["job_class", "unique_key"]
            ),
        ],
        constraints=[
            models.UniqueConstraint(fields=["uuid"], name="plainjobs_job_unique_uuid"),
        ],
    )

    def run(self) -> JobResult:
        """Execute the job and convert this row into a JobResult.

        Wraps execution in an OpenTelemetry CONSUMER span, linked (when
        possible) to the producer span captured at enqueue time. Exceptions
        raised by the job are caught and recorded as an ERRORED result rather
        than re-raised.
        """
        links = []
        if self.trace_id and self.span_id:
            # trace_id/span_id are stored as hex strings; rebuild the remote
            # SpanContext so this run links back to the enqueueing trace.
            try:
                links.append(
                    Link(
                        SpanContext(
                            trace_id=int(self.trace_id, 16),
                            span_id=int(self.span_id, 16),
                            is_remote=True,
                        )
                    )
                )
            except (ValueError, TypeError):
                # Malformed hex — run the job anyway, just without the link.
                logger.warning("Invalid trace context for job %s", self.uuid)

        with (
            tracer.start_as_current_span(
                f"run {self.job_class}",
                kind=SpanKind.CONSUMER,
                attributes={
                    MESSAGING_SYSTEM: "plain.jobs",
                    MESSAGING_OPERATION_TYPE: MessagingOperationTypeValues.PROCESS.value,
                    MESSAGING_OPERATION_NAME: "run",
                    MESSAGING_MESSAGE_ID: str(self.uuid),
                    MESSAGING_DESTINATION_NAME: self.queue,
                    MESSAGING_CONSUMER_GROUP_NAME: self.queue,  # Workers consume from specific queues
                    CODE_NAMESPACE: self.job_class,
                },
                links=links,
            ) as span
        ):
            # This is how we know it has been picked up
            self.started_at = timezone.now()
            self.save(update_fields=["started_at"])

            try:
                job = jobs_registry.load_job(self.job_class, self.parameters)
                job.run()
                status = JobResultStatuses.SUCCESSFUL
                error = ""
                span.set_status(trace.StatusCode.OK)
            except Exception as e:
                status = JobResultStatuses.ERRORED
                # NOTE(review): format_tb stores only the traceback frames,
                # not the exception type/message — consider format_exception.
                error = "".join(traceback.format_tb(e.__traceback__))
                logger.exception(e)
                span.record_exception(e)
                span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
                span.set_attribute(ERROR_TYPE, type(e).__name__)

            return self.convert_to_result(status=status, error=error)

    def convert_to_result(self, *, status: str, error: str = "") -> JobResult:
        """
        Convert this JobProcess to a JobResult.

        Creates the result row and deletes this row in one transaction, so the
        job exists in exactly one table at any commit point.
        """
        with transaction.atomic():
            result = JobResult.query.create(
                ended_at=timezone.now(),
                error=error,
                status=status,
                # From the JobProcess
                job_process_uuid=self.uuid,
                started_at=self.started_at,
                # From the JobRequest
                job_request_uuid=self.job_request_uuid,
                job_class=self.job_class,
                parameters=self.parameters,
                priority=self.priority,
                source=self.source,
                queue=self.queue,
                retries=self.retries,
                retry_attempt=self.retry_attempt,
                unique_key=self.unique_key,
                trace_id=self.trace_id,
                span_id=self.span_id,
            )

            # Delete the JobProcess now
            self.delete()

        return result

    def as_json(self) -> dict[str, str | int | dict | None]:
        """A JSON-compatible representation to make it easier to reference in Sentry or logging"""
        return {
            "uuid": str(self.uuid),
            "created_at": self.created_at.isoformat(),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "job_request_uuid": str(self.job_request_uuid),
            "job_class": self.job_class,
            "parameters": self.parameters,
            "priority": self.priority,
            "source": self.source,
            "queue": self.queue,
            "retries": self.retries,
            "retry_attempt": self.retry_attempt,
            "unique_key": self.unique_key,
            "trace_id": self.trace_id,
            "span_id": self.span_id,
        }
294
+
295
+
296
class JobResultQuerySet(models.QuerySet["JobResult"]):
    """Query helpers over finished JobResult rows."""

    def successful(self) -> Self:
        """Results that completed without error."""
        return self.filter(status=JobResultStatuses.SUCCESSFUL)

    def cancelled(self) -> Self:
        """Results that were cancelled before finishing."""
        return self.filter(status=JobResultStatuses.CANCELLED)

    def lost(self) -> Self:
        """Results marked lost (never finished / process lost)."""
        return self.filter(status=JobResultStatuses.LOST)

    def errored(self) -> Self:
        """Results whose job raised an exception."""
        return self.filter(status=JobResultStatuses.ERRORED)

    def retried(self) -> Self:
        """Results that either spawned a retry or were themselves a retry."""
        was_retried = models.Q(retry_job_request_uuid__isnull=False)
        is_a_retry = models.Q(retry_attempt__gt=0)
        return self.filter(was_retried | is_a_retry)

    def failed(self) -> Self:
        """Any terminal failure: errored, lost, or cancelled."""
        failure_statuses = [
            JobResultStatuses.ERRORED,
            JobResultStatuses.LOST,
            JobResultStatuses.CANCELLED,
        ]
        return self.filter(status__in=failure_statuses)

    def retryable(self) -> Self:
        """Failed results that still have retry budget and no retry scheduled."""
        return self.failed().filter(
            retry_job_request_uuid__isnull=True,
            retries__gt=0,
            retry_attempt__lt=models.F("retries"),
        )

    def retry_failed_jobs(self) -> None:
        """Schedule a retry for every retryable result, isolating per-job errors."""
        for failed_result in self.retryable():
            try:
                failed_result.retry_job()
            except Exception:
                # If something went wrong (like a job class being deleted)
                # then we immediately increment the retry_attempt on the
                # existing row so it won't be retried forever.
                logger.exception(
                    "Failed to retry job (incrementing retry_attempt): %s",
                    failed_result,
                )
                failed_result.retry_attempt += 1
                failed_result.save(update_fields=["retry_attempt"])
344
+
345
+
346
class JobResultStatuses(models.TextChoices):
    """Terminal states recorded on a JobResult."""

    # Completed without raising
    SUCCESSFUL = "SUCCESSFUL", "Successful"
    # Threw an error
    ERRORED = "ERRORED", "Errored"
    # Cancelled (probably by deploy)
    CANCELLED = "CANCELLED", "Cancelled"
    # Either process lost, lost in transit, or otherwise never finished
    LOST = "LOST", "Lost"
354
+
355
+
356
@models.register_model
class JobResult(models.Model):
    """
    All in-process and completed jobs are stored in this table.

    A JobResult is created from a JobProcess by convert_to_result() and is the
    permanent record of one run, including its status, error text, and the
    retry bookkeeping used by retry_failed_jobs().
    """

    uuid = models.UUIDField(default=uuid.uuid4)
    created_at = models.DateTimeField(auto_now_add=True)

    # From the Job
    job_process_uuid = models.UUIDField()
    started_at = models.DateTimeField(required=False, allow_null=True)
    ended_at = models.DateTimeField(required=False, allow_null=True)
    # Traceback text captured by JobProcess.run (empty on success)
    error = models.TextField(required=False)
    status = models.CharField(
        max_length=20,
        choices=JobResultStatuses.choices,
    )

    # From the JobRequest
    job_request_uuid = models.UUIDField()
    job_class = models.CharField(max_length=255)
    parameters = models.JSONField(required=False, allow_null=True)
    priority = models.IntegerField(default=0)
    source = models.TextField(required=False)
    queue = models.CharField(default="default", max_length=255)
    retries = models.IntegerField(default=0)
    retry_attempt = models.IntegerField(default=0)
    unique_key = models.CharField(max_length=255, required=False)

    # Retries
    # UUID of the JobRequest created when this result was retried;
    # null means no retry has been scheduled yet
    retry_job_request_uuid = models.UUIDField(required=False, allow_null=True)

    # OpenTelemetry trace context
    trace_id = models.CharField(max_length=34, required=False, allow_null=True)
    span_id = models.CharField(max_length=18, required=False, allow_null=True)

    query = JobResultQuerySet()

    model_options = models.Options(
        ordering=["-created_at"],
        indexes=[
            models.Index(fields=["created_at"]),
            models.Index(fields=["job_process_uuid"]),
            models.Index(fields=["started_at"]),
            models.Index(fields=["ended_at"]),
            models.Index(fields=["status"]),
            models.Index(fields=["job_request_uuid"]),
            models.Index(fields=["job_class"]),
            models.Index(fields=["queue"]),
            models.Index(fields=["trace_id"]),
        ],
        constraints=[
            models.UniqueConstraint(
                fields=["uuid"], name="plainjobs_jobresult_unique_uuid"
            ),
        ],
    )

    def retry_job(self, delay: int | None = None) -> JobRequest:
        """Schedule a new run of this job and record the link back to it.

        Args:
            delay: seconds to wait before the retry; when None, asks the job
                class for its own delay via get_retry_delay().
                NOTE(review): `delay or ...` treats an explicit 0 as falsy and
                falls back to the job's delay — confirm that's intended.

        Returns:
            The JobRequest created for the retry.

        Raises:
            Whatever load_job/run_in_worker raise (e.g. the job class no
            longer exists) — callers like retry_failed_jobs handle this.
        """
        retry_attempt = self.retry_attempt + 1
        job = jobs_registry.load_job(self.job_class, self.parameters)
        retry_delay = delay or job.get_retry_delay(retry_attempt)

        with transaction.atomic():
            result = job.run_in_worker(
                # Pass most of what we know through so it stays consistent
                queue=self.queue,
                delay=retry_delay,
                priority=self.priority,
                retries=self.retries,
                retry_attempt=retry_attempt,
                # Unique key could be passed also?
            )

            # TODO it is actually possible that result is a list
            # of pending jobs, which would need to be handled...
            # Right now it will throw an exception which could be caught by retry_failed_jobs.

            self.retry_job_request_uuid = result.uuid  # type: ignore
            self.save(update_fields=["retry_job_request_uuid"])

        return result  # type: ignore
@@ -0,0 +1,193 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ from typing import Any
5
+
6
+ from plain.models import Model, models_registry
7
+
8
+
9
class JobParameter:
    """Base class for job parameter serialization/deserialization.

    Subclasses set STR_PREFIX and override serialize()/deserialize().
    Returning None from either hook means "this handler does not apply";
    the base class therefore handles nothing.
    """

    STR_PREFIX: str | None = None  # Subclasses should define this

    @classmethod
    def serialize(cls, value: Any) -> str | None:
        """Return serialized string or None if can't handle this value."""
        return None

    @classmethod
    def deserialize(cls, data: Any) -> Any:
        """Return deserialized value or None if can't handle this data."""
        return None

    @classmethod
    def _extract_string_value(cls, data: Any) -> str | None:
        """Strip STR_PREFIX from data, or None if the format doesn't match."""
        prefix = cls.STR_PREFIX
        if not prefix or not isinstance(data, str):
            return None
        # Must have the prefix AND something after it.
        if data.startswith(prefix) and len(data) > len(prefix):
            return data[len(prefix):]
        return None
32
+
33
+
34
class ModelParameter(JobParameter):
    """Handle Plain model instances using a new string format."""

    STR_PREFIX = "__plain://model/"

    @classmethod
    def serialize(cls, value: Any) -> str | None:
        """Serialize a model instance as __plain://model/<package>/<model>/<id>."""
        if isinstance(value, Model):
            return f"{cls.STR_PREFIX}{value.model_options.package_label}/{value.model_options.model_name}/{value.id}"
        return None

    @classmethod
    def deserialize(cls, data: Any) -> Model | None:
        """Load the referenced instance, or None if data isn't ours or lookup fails."""
        value_part = cls._extract_string_value(data)
        if value_part is None:
            return None
        # Validate the shape outside the try block; no exceptions possible here.
        parts = value_part.split("/")
        if len(parts) != 3 or not all(parts):
            return None
        package, model_name, obj_id = parts
        # Fix: the original `except (ValueError, Exception)` was redundant —
        # Exception already covers ValueError. We still deliberately swallow
        # lookup failures (unknown model, missing row) so a stale reference
        # deserializes to None instead of crashing the worker, but the try
        # body is now limited to the lines that can actually raise.
        try:
            model = models_registry.get_model(package, model_name)
            return model.query.get(id=obj_id)
        except Exception:
            return None
57
+
58
+
59
class DateParameter(JobParameter):
    """Handle date objects."""

    STR_PREFIX = "__plain://date/"

    @classmethod
    def serialize(cls, value: Any) -> str | None:
        """Serialize a plain date (but not a datetime) as a prefixed ISO string."""
        # datetime is a subclass of date, so exclude it explicitly —
        # DateTimeParameter is responsible for datetimes.
        is_pure_date = isinstance(value, datetime.date) and not isinstance(
            value, datetime.datetime
        )
        if is_pure_date:
            return f"{cls.STR_PREFIX}{value.isoformat()}"
        return None

    @classmethod
    def deserialize(cls, data: Any) -> datetime.date | None:
        """Parse a prefixed ISO date string back into a date, else None."""
        raw = cls._extract_string_value(data)
        if raw is None:
            return None
        try:
            return datetime.date.fromisoformat(raw)
        except ValueError:
            return None
80
+
81
+
82
class DateTimeParameter(JobParameter):
    """Handle datetime objects."""

    STR_PREFIX = "__plain://datetime/"

    @classmethod
    def serialize(cls, value: Any) -> str | None:
        """Serialize any datetime as a prefixed ISO-8601 string."""
        if not isinstance(value, datetime.datetime):
            return None
        return f"{cls.STR_PREFIX}{value.isoformat()}"

    @classmethod
    def deserialize(cls, data: Any) -> datetime.datetime | None:
        """Parse a prefixed ISO-8601 string back into a datetime, else None."""
        raw = cls._extract_string_value(data)
        if raw is None:
            return None
        try:
            return datetime.datetime.fromisoformat(raw)
        except ValueError:
            return None
101
+
102
+
103
class LegacyModelParameter(JobParameter):
    """Legacy model parameter handling for backwards compatibility."""

    STR_PREFIX = "gid://"

    @classmethod
    def serialize(cls, value: Any) -> str | None:
        # Don't serialize new instances with legacy format
        return None

    @classmethod
    def deserialize(cls, data: Any) -> Model | None:
        """Resolve a legacy gid://<package>/<model>/<id> reference, else None."""
        value_part = cls._extract_string_value(data)
        if value_part is None:
            return None
        try:
            # Fix: the original reused the name `model` for both the model-name
            # string and the model class; use distinct names.
            package, model_name, obj_id = value_part.split("/")
        except ValueError:
            # Wrong number of segments — not a valid legacy reference.
            return None
        # Fix: the original `except (ValueError, Exception)` was redundant —
        # Exception already covers ValueError. Lookup failures are still
        # deliberately swallowed so stale legacy references deserialize to
        # None rather than crash the worker.
        try:
            model = models_registry.get_model(package, model_name)
            return model.query.get(id=obj_id)
        except Exception:
            return None
123
+
124
+
125
# Registry of parameter types to check in order
# The order matters - more specific types should come first
# DateTimeParameter must come before DateParameter since datetime is a subclass of date
# LegacyModelParameter is last since it only handles deserialization
# (its serialize() always returns None, so it never wins on the way in)
PARAMETER_TYPES = [
    ModelParameter,
    DateTimeParameter,
    DateParameter,
    LegacyModelParameter,
]
135
+
136
+
137
class JobParameters:
    """
    Main interface for serializing and deserializing job parameters.

    Delegates each individual value to the handlers in PARAMETER_TYPES;
    values no handler claims pass through unchanged (assumed JSON-compatible).
    """

    @staticmethod
    def to_json(args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
        """Serialize call arguments into a JSON-compatible dict."""
        return {
            "args": [JobParameters._serialize_value(a) for a in args],
            "kwargs": {
                name: JobParameters._serialize_value(v)
                for name, v in kwargs.items()
            },
        }

    @staticmethod
    def _serialize_value(value: Any) -> Any:
        """Serialize one value via the first parameter type that accepts it."""
        for handler in PARAMETER_TYPES:
            serialized = handler.serialize(value)
            if serialized is not None:
                return serialized
        # No handler claimed it — pass through as-is.
        return value

    @staticmethod
    def from_json(data: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Inverse of to_json: rebuild (args, kwargs) from the stored dict."""
        restored_args = tuple(
            JobParameters._deserialize_value(a) for a in data["args"]
        )
        restored_kwargs = {
            name: JobParameters._deserialize_value(v)
            for name, v in data["kwargs"].items()
        }
        return restored_args, restored_kwargs

    @staticmethod
    def _deserialize_value(value: Any) -> Any:
        """Deserialize one value via the first parameter type that recognizes it."""
        for handler in PARAMETER_TYPES:
            restored = handler.deserialize(value)
            if restored is not None:
                return restored
        # No handler recognized it — pass through as-is.
        return value
plain/jobs/registry.py ADDED
@@ -0,0 +1,60 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from typing import TYPE_CHECKING, Any, TypeVar
5
+
6
+ from .parameters import JobParameters
7
+
8
+ if TYPE_CHECKING:
9
+ from .jobs import Job
10
+
11
# Type variable for the register_job decorator: preserves the concrete Job subclass type.
T = TypeVar("T", bound=type["Job"])
12
+
13
+
14
class JobsRegistry:
    """In-memory mapping of job class names (and aliases) to Job classes."""

    def __init__(self) -> None:
        self.jobs: dict[str, type[Job]] = {}
        # Flipped to True externally once registration is complete; guards
        # against loading jobs before all classes are registered.
        self.ready = False

    def register_job(self, job_class: type[Job], alias: str = "") -> None:
        """Register a job class under its dotted name, and optionally an alias."""
        name = self.get_job_class_name(job_class)
        self.jobs[name] = job_class

        if alias:
            self.jobs[alias] = job_class

    def get_job_class_name(self, job_class: type[Job]) -> str:
        """Canonical registry key: "<module>.<qualname>"."""
        return f"{job_class.__module__}.{job_class.__qualname__}"

    def get_job_class(self, name: str) -> type[Job]:
        """Look up a registered job class by name or alias (raises KeyError)."""
        return self.jobs[name]

    def load_job(self, job_class_name: str, parameters: dict[str, Any] | None) -> Job:
        """Instantiate a registered job from its stored JSON parameters.

        Raises:
            RuntimeError: if called before the registry is ready.
            KeyError: if the job class was never registered.
        """
        if not self.ready:
            raise RuntimeError("Jobs registry is not ready yet")

        job_class = self.get_job_class(job_class_name)
        if parameters is None:
            # Fix: the parameters column is nullable (JSONField allow_null),
            # but from_json(None) would crash on None["args"]. Treat a missing
            # payload as a no-argument call.
            args: tuple[Any, ...] = ()
            kwargs: dict[str, Any] = {}
        else:
            args, kwargs = JobParameters.from_json(parameters)
        return job_class(*args, **kwargs)
39
+
40
+
41
# Module-level singleton used by the register_job decorator and by job loading.
jobs_registry = JobsRegistry()
42
+
43
+
44
def register_job(
    job_class: T | None = None, *, alias: str = ""
) -> T | Callable[[T], T]:
    """
    A decorator that registers a job class in the jobs registry with an optional alias.
    Can be used both with and without parentheses.
    """

    def decorate(cls: T) -> T:
        jobs_registry.register_job(cls, alias=alias)
        return cls

    # Called as @register_job(alias=...): hand back the decorator itself.
    if job_class is None:
        return decorate
    # Called as bare @register_job: decorate the class immediately.
    return decorate(job_class)