oban 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oban/__init__.py +22 -0
- oban/__main__.py +12 -0
- oban/_backoff.py +87 -0
- oban/_config.py +171 -0
- oban/_executor.py +188 -0
- oban/_extensions.py +16 -0
- oban/_leader.py +118 -0
- oban/_lifeline.py +77 -0
- oban/_notifier.py +324 -0
- oban/_producer.py +334 -0
- oban/_pruner.py +93 -0
- oban/_query.py +409 -0
- oban/_recorded.py +34 -0
- oban/_refresher.py +88 -0
- oban/_scheduler.py +359 -0
- oban/_stager.py +115 -0
- oban/_worker.py +78 -0
- oban/cli.py +436 -0
- oban/decorators.py +218 -0
- oban/job.py +315 -0
- oban/oban.py +1084 -0
- oban/py.typed +0 -0
- oban/queries/__init__.py +0 -0
- oban/queries/ack_job.sql +11 -0
- oban/queries/all_jobs.sql +25 -0
- oban/queries/cancel_many_jobs.sql +37 -0
- oban/queries/cleanup_expired_leaders.sql +4 -0
- oban/queries/cleanup_expired_producers.sql +2 -0
- oban/queries/delete_many_jobs.sql +5 -0
- oban/queries/delete_producer.sql +2 -0
- oban/queries/elect_leader.sql +10 -0
- oban/queries/fetch_jobs.sql +44 -0
- oban/queries/get_job.sql +23 -0
- oban/queries/insert_job.sql +28 -0
- oban/queries/insert_producer.sql +2 -0
- oban/queries/install.sql +113 -0
- oban/queries/prune_jobs.sql +18 -0
- oban/queries/reelect_leader.sql +12 -0
- oban/queries/refresh_producers.sql +3 -0
- oban/queries/rescue_jobs.sql +18 -0
- oban/queries/reset.sql +5 -0
- oban/queries/resign_leader.sql +4 -0
- oban/queries/retry_many_jobs.sql +13 -0
- oban/queries/stage_jobs.sql +34 -0
- oban/queries/uninstall.sql +4 -0
- oban/queries/update_job.sql +54 -0
- oban/queries/update_producer.sql +3 -0
- oban/queries/verify_structure.sql +9 -0
- oban/schema.py +115 -0
- oban/telemetry/__init__.py +10 -0
- oban/telemetry/core.py +170 -0
- oban/telemetry/logger.py +147 -0
- oban/testing.py +439 -0
- oban-0.5.0.dist-info/METADATA +290 -0
- oban-0.5.0.dist-info/RECORD +59 -0
- oban-0.5.0.dist-info/WHEEL +5 -0
- oban-0.5.0.dist-info/entry_points.txt +2 -0
- oban-0.5.0.dist-info/licenses/LICENSE.txt +201 -0
- oban-0.5.0.dist-info/top_level.txt +1 -0
oban/_notifier.py
ADDED
@@ -0,0 +1,324 @@

```python
from __future__ import annotations

import asyncio
import base64
import gzip
import inspect
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable
from uuid import uuid4

import orjson
from psycopg import AsyncConnection, OperationalError

from ._backoff import jittery_exponential

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ._query import Query


def encode_payload(payload: dict) -> str:
    """Encode a dict payload to an efficient format for publishing.

    Args:
        payload: Dict to encode

    Returns:
        Base64-encoded gzipped JSON string
    """
    dumped = orjson.dumps(payload)
    zipped = gzip.compress(dumped)

    return base64.b64encode(zipped).decode("ascii")


def decode_payload(payload: str) -> dict:
    """Decode a payload string to a dict.

    Handles both compressed (base64+gzip) and plain JSON payloads.

    Args:
        payload: Payload string to decode

    Returns:
        Decoded dict
    """
    # Payloads created by SQL queries won't be encoded or compressed.
    if payload.startswith("{"):
        return orjson.loads(payload)
    else:
        decoded = base64.b64decode(payload)
        unzipped = gzip.decompress(decoded)

        return orjson.loads(unzipped.decode("utf-8"))
```
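To make the wire format concrete, a round trip through both helpers looks like this (the encoded string is illustrative; exact bytes depend on gzip):

```python
payload = {"queue": "default", "action": "insert"}

encoded = encode_payload(payload)  # ASCII-safe, e.g. "H4sIAAAA..."
assert decode_payload(encoded) == payload

# Payloads emitted directly by SQL (plain JSON) skip the base64/gzip path:
assert decode_payload('{"queue": "default"}') == {"queue": "default"}
```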
```python
@runtime_checkable
class Notifier(Protocol):
    """Protocol for pub/sub notification systems.

    Notifiers enable real-time communication between Oban components using
    channels. The default implementation uses PostgreSQL LISTEN/NOTIFY.
    """

    async def start(self) -> None:
        """Start the notifier and establish necessary connections."""
        ...

    async def stop(self) -> None:
        """Stop the notifier and clean up resources."""
        ...

    async def listen(
        self, channel: str, callback: Callable[[str, dict], Any], wait: bool = True
    ) -> str:
        """Register a callback for a channel.

        Args:
            channel: Channel name to listen on
            callback: Sync or async function called when a notification is received.
                Receives (channel, payload) as arguments where payload is a dict.
                Async callbacks are executed in the background without blocking.
            wait: If True, blocks until the subscription is fully established and ready to
                receive notifications. If False, returns immediately after registering.
                Defaults to True for test reliability.

        Returns:
            Token (UUID string) used to unregister this subscription

        Example:
            >>> async def handler(channel: str, payload: dict):
            ...     print(f"Received on {channel}: {payload}")
            ...
            >>> token = await notifier.listen("insert", handler)
        """
        ...

    async def unlisten(self, token: str) -> None:
        """Unregister a subscription by token.

        Args:
            token: The token returned from listen()

        Example:
            >>> token = await notifier.listen("insert", handler)
            >>> await notifier.unlisten(token)
        """
        ...

    async def notify(self, channel: str, payloads: dict | list[dict]) -> None:
        """Send one or more notifications to a channel.

        Args:
            channel: Channel name to send notification on
            payloads: Payload dict or list of payload dicts to send

        Example:
            >>> await notifier.notify("insert", {"queue": "default"})
            >>> await notifier.notify("insert", [{"queue": "default"}, {"queue": "mailers"}])
        """
        ...
```
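Because `Notifier` is a `runtime_checkable` protocol, any object with matching methods passes an `isinstance` check, which makes swapping transports straightforward. As a sketch, a hypothetical in-process notifier for tests (not part of this package) could satisfy it like so:

```python
class MemoryNotifier:
    """Hypothetical in-process Notifier that fans payloads out to local callbacks."""

    def __init__(self) -> None:
        self._subscriptions: dict[str, dict[str, Callable]] = defaultdict(dict)
        self._tokens: dict[str, str] = {}

    async def start(self) -> None:
        pass

    async def stop(self) -> None:
        pass

    async def listen(self, channel, callback, wait=True) -> str:
        token = str(uuid4())
        self._tokens[token] = channel
        self._subscriptions[channel][token] = callback
        return token

    async def unlisten(self, token) -> None:
        channel = self._tokens.pop(token, None)
        if channel is not None:
            self._subscriptions[channel].pop(token, None)

    async def notify(self, channel, payloads) -> None:
        if isinstance(payloads, dict):
            payloads = [payloads]
        for payload in payloads:
            for callback in list(self._subscriptions[channel].values()):
                result = callback(channel, payload)
                if inspect.isawaitable(result):
                    await result


assert isinstance(MemoryNotifier(), Notifier)
```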
```python
class PostgresNotifier:
    """PostgreSQL-based notifier using LISTEN/NOTIFY.

    Maintains a dedicated connection for receiving notifications and dispatches
    them to registered callbacks. Automatically reconnects on connection loss.
    """

    def __init__(
        self,
        *,
        query: Query,
        prefix: str = "public",
        beat_interval: float = 30.0,
        connect_timeout: float = 5.0,
        notify_timeout: float = 0.1,
    ) -> None:
        self._query = query
        self._prefix = prefix
        self._beat_interval = beat_interval
        self._connect_timeout = connect_timeout
        self._notify_timeout = notify_timeout

        self._pending_listen = set()
        self._pending_unlisten = set()
        self._listen_events = {}
        self._subscriptions = defaultdict(dict)
        self._tokens = {}

        self._conn = None
        self._loop_task = None
        self._beat_task = None
        self._reconnect_attempts = 0

    def _to_full_channel(self, channel: str) -> str:
        return f"{self._prefix}.oban_{channel}"

    def _from_full_channel(self, full_channel: str) -> str:
        (_prefix, channel) = full_channel.split(".", 1)

        return channel[5:]

    async def start(self) -> None:
        await self._connect()

    async def stop(self) -> None:
        if not self._loop_task or not self._beat_task or not self._conn:
            return

        self._loop_task.cancel()
        self._beat_task.cancel()

        await asyncio.gather(self._loop_task, self._beat_task, return_exceptions=True)

        try:
            await self._conn.close()
        except Exception:
            pass

    async def listen(
        self, channel: str, callback: Callable[[str, dict], Any], wait: bool = True
    ) -> str:
        token = str(uuid4())

        if channel not in self._subscriptions:
            self._pending_listen.add(channel)
            self._listen_events[channel] = asyncio.Event()

        self._tokens[token] = channel
        self._subscriptions[channel][token] = callback

        if wait and channel in self._listen_events:
            await self._listen_events[channel].wait()

        return token

    async def unlisten(self, token: str) -> None:
        if token not in self._tokens:
            return

        channel = self._tokens.pop(token)

        if channel in self._subscriptions:
            self._subscriptions[channel].pop(token, None)

            if len(self._subscriptions[channel]) == 0:
                del self._subscriptions[channel]
                self._pending_unlisten.add(channel)

    async def notify(self, channel: str, payloads: dict | list[dict]) -> None:
        if isinstance(payloads, dict):
            payloads = [payloads]

        channel = self._to_full_channel(channel)
        encoded = [encode_payload(payload) for payload in payloads]

        await self._query.notify(channel, encoded)
```
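End to end, the public API above composes like this (a sketch: the `Query` object carries the DSN and the NOTIFY SQL, and is normally wired up by the Oban runtime rather than by hand):

```python
async def demo(query) -> None:
    notifier = PostgresNotifier(query=query, prefix="public")
    await notifier.start()

    def on_insert(channel: str, payload: dict) -> None:
        print(f"{channel}: {payload}")

    token = await notifier.listen("insert", on_insert)     # LISTEN "public.oban_insert"
    await notifier.notify("insert", {"queue": "default"})  # payload is gzipped + base64'd

    await notifier.unlisten(token)
    await notifier.stop()
```

The connection-management internals that make this work continue below.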
```python
    # PostgresNotifier, continued: connection management and dispatch internals.

    async def _connect(self) -> None:
        dsn = self._query.dsn

        self._conn = await asyncio.wait_for(
            AsyncConnection.connect(dsn, autocommit=True),
            timeout=self._connect_timeout,
        )

        for channel in list(self._subscriptions.keys()):
            self._pending_listen.add(channel)

        self._loop_task = asyncio.create_task(self._loop(), name="oban-notifier-loop")
        self._beat_task = asyncio.create_task(self._beat(), name="oban-notifier-beat")

        self._reconnect_attempts = 0

    async def _reconnect(self) -> None:
        self._reconnect_attempts += 1

        delay = jittery_exponential(self._reconnect_attempts, max_pow=4)

        await asyncio.sleep(delay)

        try:
            await self._connect()
        except (OSError, OperationalError, asyncio.TimeoutError):
            asyncio.create_task(self._reconnect())

    async def _loop(self) -> None:
        if not self._conn:
            return

        while True:
            try:
                await self._process_pending()
                await self._process_notifications()
            except asyncio.CancelledError:
                raise
            except (OSError, OperationalError):
                self._conn = None
                asyncio.create_task(self._reconnect())
                break

    async def _process_pending(self) -> None:
        if not self._conn:
            return

        for channel in list(self._pending_listen):
            full_channel = self._to_full_channel(channel)
            await self._conn.execute(f'LISTEN "{full_channel}"')

            self._pending_listen.discard(channel)

            if channel in self._listen_events:
                self._listen_events[channel].set()
                del self._listen_events[channel]

        for channel in list(self._pending_unlisten):
            full_channel = self._to_full_channel(channel)
            await self._conn.execute(f'UNLISTEN "{full_channel}"')

            self._pending_unlisten.discard(channel)

    async def _process_notifications(self) -> None:
        if not self._conn:
            return

        gen = self._conn.notifies(timeout=self._notify_timeout)

        async for notify in gen:
            await self._dispatch(notify)

    async def _dispatch(self, notify) -> None:
        channel = self._from_full_channel(notify.channel)
        payload = decode_payload(notify.payload)

        if channel in self._subscriptions:
            for callback in self._subscriptions[channel].values():
                try:
                    if inspect.iscoroutinefunction(callback):
                        asyncio.create_task(callback(channel, payload))
                    else:
                        callback(channel, payload)
                except Exception:
                    logger.exception(
                        "Error in notifier callback for channel %s with payload %s",
                        channel,
                        payload,
                    )

    async def _beat(self) -> None:
        while True:
            try:
                await asyncio.sleep(self._beat_interval)

                if self._conn and not self._conn.closed:
                    await self._conn.execute("SELECT 1")

            except asyncio.CancelledError:
                raise
            except (OSError, OperationalError):
                pass
```
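Note the two callback styles `_dispatch` supports: plain functions run inline on the notifier loop (keep them fast), while coroutine functions are scheduled with `asyncio.create_task` so slow handlers never block notification processing:

```python
def sync_handler(channel: str, payload: dict) -> None:
    print("handled inline:", channel, payload)

async def async_handler(channel: str, payload: dict) -> None:
    await asyncio.sleep(1.0)  # safe: runs as a background task
    print("handled in background:", channel, payload)

# Both may subscribe to the same channel (assuming a started notifier):
#   await notifier.listen("signal", sync_handler)
#   await notifier.listen("signal", async_handler)
```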
oban/_producer.py
ADDED
@@ -0,0 +1,334 @@

```python
from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from . import telemetry
from ._executor import Executor
from ._extensions import use_ext
from .job import Job

if TYPE_CHECKING:
    from ._notifier import Notifier
    from ._query import Query

logger = logging.getLogger(__name__)


def _init(producer: Producer) -> dict:
    return {"local_limit": producer._limit, "paused": producer._paused}


def _pause(producer: Producer) -> dict:
    producer._paused = True

    return {"paused": True}


def _resume(producer: Producer) -> dict:
    producer._paused = False

    return {"paused": False}


def _scale(producer: Producer, **kwargs: Any) -> dict:
    if "limit" in kwargs:
        producer._limit = kwargs["limit"]

    return {"local_limit": producer._limit}


def _validate(*, queue: str, limit: int, **extra) -> None:
    if not isinstance(queue, str):
        raise TypeError(f"queue must be a string, got {queue}")
    if not queue.strip():
        raise ValueError("queue must not be blank")

    if limit < 1:
        raise ValueError(f"Queue '{queue}' limit must be positive")


async def _get_jobs(producer: Producer) -> list[Job]:
    demand = producer._limit - len(producer._running_jobs)

    if demand > 0:
        return await producer._query.fetch_jobs(
            demand=demand,
            queue=producer._queue,
            node=producer._node,
            uuid=producer._uuid,
        )
    else:
        return []


class LocalDispatcher:
    def dispatch(self, producer: Producer, job: Job) -> asyncio.Task:
        return asyncio.create_task(producer._execute(job))
```
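The dispatcher seam is deliberately tiny: anything exposing `dispatch(producer, job)` and returning an `asyncio.Task` can stand in for `LocalDispatcher`. A hypothetical wrapper that logs before delegating (illustrative only, not part of the package):

```python
class LoggingDispatcher:
    """Hypothetical dispatcher that logs each job before executing it."""

    def dispatch(self, producer: Producer, job: Job) -> asyncio.Task:
        async def run() -> None:
            logger.info("starting job %s on queue %s", job.id, producer._queue)
            await producer._execute(job)

        return asyncio.create_task(run())
```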
```python
@dataclass(frozen=True, slots=True)
class QueueInfo:
    """Information about a queue's runtime state."""

    limit: int
    """The concurrency limit for this queue."""

    meta: dict[str, Any]
    """Queue metadata including extension options (global_limit, partition, etc.)."""

    node: str
    """The node name where this queue is running."""

    paused: bool
    """Whether the queue is currently paused."""

    queue: str
    """The queue name."""

    running: list[int]
    """List of currently executing job IDs."""

    started_at: datetime | None
    """When the queue was started."""
```
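`Producer.check()` below returns one of these snapshots; a populated instance reads like this (values illustrative):

```python
info = QueueInfo(
    limit=10,
    meta={},
    node="worker-1",
    paused=False,
    queue="default",
    running=[101, 102],
    started_at=datetime.now(timezone.utc),
)

assert not info.paused and len(info.running) == 2
```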
```python
class Producer:
    def __init__(
        self,
        *,
        debounce_interval: float = 0.005,
        dispatcher: Any = None,
        limit: int = 10,
        paused: bool = False,
        queue: str = "default",
        name: str,
        node: str,
        notifier: Notifier,
        query: Query,
        **extra,
    ) -> None:
        self._debounce_interval = debounce_interval
        self._dispatcher = dispatcher or LocalDispatcher()
        self._extra = extra
        self._limit = limit
        self._name = name
        self._node = node
        self._notifier = notifier
        self._paused = paused
        self._query = query
        self._queue = queue

        self._validate()

        self._init_lock = asyncio.Lock()
        self._last_fetch_time = 0.0
        self._listen_token = None
        self._loop_task = None
        self._notified = asyncio.Event()
        self._pending_acks = []
        self._running_jobs = {}
        self._started_at = None
        self._uuid = str(uuid4())

    def _validate(self, **opts) -> None:
        params = {"queue": self._queue, "limit": self._limit, **self._extra}
        merged = {**params, **opts}

        use_ext("producer.validate", _validate, **merged)

    async def start(self) -> None:
        async with self._init_lock:
            self._started_at = datetime.now(timezone.utc)

            await self._query.insert_producer(
                uuid=self._uuid,
                name=self._name,
                node=self._node,
                queue=self._queue,
                meta=use_ext("producer.init", _init, self),
            )

            self._listen_token = await self._notifier.listen(
                "signal", self._on_signal, wait=False
            )

            self._loop_task = asyncio.create_task(
                self._loop(), name=f"oban-producer-{self._queue}"
            )

    async def stop(self) -> None:
        async with self._init_lock:
            if not self._listen_token or not self._loop_task:
                return

            self._loop_task.cancel()

            await self._notifier.unlisten(self._listen_token)

            running_tasks = [task for (_job, task) in self._running_jobs.values()]

            if running_tasks:
                try:
                    now = datetime.now(timezone.utc).isoformat()
                    await self._query.update_producer(
                        uuid=self._uuid,
                        meta={"paused": True, "shutdown_started_at": now},
                    )
                except Exception:
                    logger.debug(
                        "Failed to update producer %s during shutdown", self._uuid
                    )

            await asyncio.gather(
                self._loop_task,
                *running_tasks,
                return_exceptions=True,
            )

            try:
                await self._query.delete_producer(self._uuid)
            except Exception:
                logger.debug(
                    "Failed to delete producer %s during shutdown, will be cleaned up asynchronously",
                    self._uuid,
                )

    def notify(self) -> None:
        self._notified.set()

    async def pause(self) -> None:
        meta = use_ext("producer.pause", _pause, self)

        await self._query.update_producer(uuid=self._uuid, meta=meta)

    async def resume(self) -> None:
        meta = use_ext("producer.resume", _resume, self)

        await self._query.update_producer(uuid=self._uuid, meta=meta)

        self.notify()

    async def scale(self, **kwargs: Any) -> None:
        self._validate(**kwargs)

        meta = use_ext("producer.scale", _scale, self, **kwargs)

        await self._query.update_producer(uuid=self._uuid, meta=meta)

        self.notify()

    def check(self) -> QueueInfo:
        return QueueInfo(
            limit=self._limit,
            meta=self._extra,
            node=self._node,
            paused=self._paused,
            queue=self._queue,
            running=list(self._running_jobs.keys()),
            started_at=self._started_at,
        )
```
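Pause, resume, and scale all follow the same shape: mutate local state through the extension hook, persist the new meta row, then (for resume and scale) wake the fetch loop. Assuming a started producer:

```python
await producer.pause()          # stop fetching; running jobs finish normally
await producer.scale(limit=20)  # validate, then raise the local concurrency limit
await producer.resume()         # persist paused=False and wake the loop

info = producer.check()
assert info.limit == 20 and not info.paused
```

The fetch loop and signal handling that drive these continue below.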
```python
    # Producer, continued: fetch loop, acking, execution, and signal handling.

    async def _loop(self) -> None:
        while True:
            try:
                await asyncio.wait_for(self._notified.wait(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

            self._notified.clear()

            try:
                await self._debounce()
                await self._produce()
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception("Error in producer for queue %s", self._queue)

    async def _debounce(self) -> None:
        now = asyncio.get_event_loop().time()
        elapsed = now - self._last_fetch_time

        if elapsed < self._debounce_interval:
            await asyncio.sleep(self._debounce_interval - elapsed)

        self._last_fetch_time = asyncio.get_event_loop().time()

    async def _produce(self) -> None:
        if self._paused or (self._limit - len(self._running_jobs)) <= 0:
            return

        _ack = await self._ack_jobs()
        jobs = await self._get_jobs()

        for job in jobs:
            task = self._dispatcher.dispatch(self, job)
            task.add_done_callback(
                lambda _, job_id=job.id: self._on_job_complete(job_id)
            )

            self._running_jobs[job.id] = (job, task)

    async def _ack_jobs(self):
        with telemetry.span("oban.producer.ack", {"queue": self._queue}) as context:
            if self._pending_acks:
                acked_ids = await self._query.ack_jobs(self._pending_acks)
                acked_set = set(acked_ids)

                self._pending_acks = [
                    ack for ack in self._pending_acks if ack.id not in acked_set
                ]
            else:
                acked_ids = []

            context.add({"count": len(acked_ids)})

    async def _get_jobs(self):
        with telemetry.span("oban.producer.get", {"queue": self._queue}) as context:
            jobs = await use_ext("producer.get_jobs", _get_jobs, self)

            context.add({"count": len(jobs)})

            return jobs

    async def _execute(self, job: Job) -> None:
        job._cancellation = asyncio.Event()

        executor = await Executor(job=job, safe=True).execute()

        self._pending_acks.append(executor.action)

    def _on_job_complete(self, job_id: int) -> None:
        self._running_jobs.pop(job_id, None)

        self.notify()

    async def _on_signal(self, _channel: str, payload: dict) -> None:
        ident = payload.get("ident", "any")
        queue = payload.get("queue", "*")

        if queue != "*" and queue != self._queue:
            return

        if ident != "any" and ident != f"{self._name}.{self._node}":
            return

        match payload.get("action"):
            case "pause":
                await self.pause()
            case "resume":
                await self.resume()
            case "pkill":
                job_id = payload["job_id"]

                if job_id in self._running_jobs:
                    (job, _task) = self._running_jobs[job_id]

                    job._cancellation.set()
```