beanqueue 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beanqueue-0.1.0.dist-info/LICENSE +21 -0
- beanqueue-0.1.0.dist-info/METADATA +23 -0
- beanqueue-0.1.0.dist-info/RECORD +23 -0
- beanqueue-0.1.0.dist-info/WHEEL +4 -0
- bq/__init__.py +0 -0
- bq/cmds/__init__.py +0 -0
- bq/cmds/create_tables.py +25 -0
- bq/cmds/process.py +178 -0
- bq/cmds/submit.py +48 -0
- bq/config.py +47 -0
- bq/container.py +54 -0
- bq/db/__init__.py +0 -0
- bq/db/base.py +5 -0
- bq/db/session.py +5 -0
- bq/models/__init__.py +4 -0
- bq/models/helpers.py +5 -0
- bq/models/task.py +120 -0
- bq/models/worker.py +69 -0
- bq/processors/__init__.py +0 -0
- bq/processors/registry.py +136 -0
- bq/services/__init__.py +0 -0
- bq/services/dispatch.py +96 -0
- bq/services/worker.py +69 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Launch Platform
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: beanqueue
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: BeanQueue or BQ for short, PostgreSQL SKIP LOCK based worker queue library
|
|
5
|
+
License: MIT
|
|
6
|
+
Author: Fang-Pen Lin
|
|
7
|
+
Author-email: fangpen@launchplatform.com
|
|
8
|
+
Requires-Python: >=3.11,<4.0
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Requires-Dist: click (>=8.1.7,<9.0.0)
|
|
14
|
+
Requires-Dist: dependency-injector (>=4.41.0,<5.0.0)
|
|
15
|
+
Requires-Dist: pg-activity (>=3.5.1,<4.0.0)
|
|
16
|
+
Requires-Dist: pydantic-settings (>=2.2.1,<3.0.0)
|
|
17
|
+
Requires-Dist: sqlalchemy (>=2.0.30,<3.0.0)
|
|
18
|
+
Requires-Dist: venusian (>=3.1.0,<4.0.0)
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
|
|
21
|
+
# bq
|
|
22
|
+
BeanQueue or BQ for short, PostgreSQL SKIP LOCK based worker queue library
|
|
23
|
+
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
bq/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
bq/cmds/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
bq/cmds/create_tables.py,sha256=lI9GmovZT7-_ws4jq8xV-0pHStiyPZoE6Ygb0Xn1q2c,609
|
|
4
|
+
bq/cmds/process.py,sha256=9CE2PCzdl3WAdx1whlcqmDwuQujeuh6rvrLjEAcXlAs,5879
|
|
5
|
+
bq/cmds/submit.py,sha256=WhzitLawbR_hEjLqCuvr2yhKJcwGNplarntpAhjGBb4,1148
|
|
6
|
+
bq/config.py,sha256=LUEwzodrt6H32bC2Fn8_J8oQQQ9dSFZnIg2Uv_EMLus,1508
|
|
7
|
+
bq/container.py,sha256=NQLpc2zEap8LW2XkqRQjGtzyZslPL5iDnEZKBnH67eo,1558
|
|
8
|
+
bq/db/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
|
+
bq/db/base.py,sha256=0LXS0WlLr1KAHGZng46SmUwJam8m2AxVDBqK1Yzwx8w,88
|
|
10
|
+
bq/db/session.py,sha256=ife8ocHxCT7EbDgyEG38TFhP4l4wQv_Bo16bxt9XCos,157
|
|
11
|
+
bq/models/__init__.py,sha256=bibCa4EgF9MtkmSSqa15C7Jc7tN9P6LS78YOBkvX72s,114
|
|
12
|
+
bq/models/helpers.py,sha256=vf8IvREk2Lfc_LsdxWfUgVc99Fx1K2RsWVj9Xi4Dhbg,184
|
|
13
|
+
bq/models/task.py,sha256=hVkYgt-TPFdKmbmL9E_AUbWiZlk0nkq7ZGWfkyoZICk,3604
|
|
14
|
+
bq/models/worker.py,sha256=EuC8k6g9l_xddjXxhMqXmHt6uSO86VZFs9hJH1DQUi8,1940
|
|
15
|
+
bq/processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
+
bq/processors/registry.py,sha256=sc9rfkWAJFFh5p8K-Vv1z2GCzgqoTLX3yNMJ3j4F15k,4554
|
|
17
|
+
bq/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
+
bq/services/dispatch.py,sha256=uSssmqGKCxsr3ZzbglXs01XZNi3nv3zUfO2d88u3Wv0,3463
|
|
19
|
+
bq/services/worker.py,sha256=-NDP-6dA40OYdw00sbS3jLqU2PSY62oL_1gPVK5ijMQ,2554
|
|
20
|
+
beanqueue-0.1.0.dist-info/LICENSE,sha256=sNYpr-_bmGA6hBtD7qtmZmVgQH6HC6KXlx7tOOOjFq8,1072
|
|
21
|
+
beanqueue-0.1.0.dist-info/METADATA,sha256=YRbok6yaosRGobhhsnQYWTiRouSeytUIBOSfV82jseY,834
|
|
22
|
+
beanqueue-0.1.0.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
|
23
|
+
beanqueue-0.1.0.dist-info/RECORD,,
|
bq/__init__.py
ADDED
|
File without changes
|
bq/cmds/__init__.py
ADDED
|
File without changes
|
bq/cmds/create_tables.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import logging

import click
from dependency_injector.wiring import inject
from dependency_injector.wiring import Provide
from sqlalchemy.engine import Engine

# imported for side effect: model classes must be registered on Base.metadata
from .. import models  # noqa
from ..container import Container
from ..db.base import Base


@click.command()
@inject
def main(engine: Engine = Provide[Container.db_engine]):
    # CLI entry point: create all bq tables on the injected engine.
    # (kept as a comment rather than a docstring so click help output is unchanged)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    # create_all is a no-op for tables that already exist
    Base.metadata.create_all(bind=engine)
    logger.info("Done, tables created")


if __name__ == "__main__":
    container = Container()
    container.wire(modules=[__name__])
    main()
|
bq/cmds/process.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import importlib
|
|
3
|
+
import logging
|
|
4
|
+
import platform
|
|
5
|
+
import sys
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
import typing
|
|
9
|
+
import uuid
|
|
10
|
+
|
|
11
|
+
import click
|
|
12
|
+
from dependency_injector.wiring import inject
|
|
13
|
+
from dependency_injector.wiring import Provide
|
|
14
|
+
from sqlalchemy import func
|
|
15
|
+
from sqlalchemy.orm import Session as DBSession
|
|
16
|
+
|
|
17
|
+
from .. import models
|
|
18
|
+
from ..config import Config
|
|
19
|
+
from ..container import Container
|
|
20
|
+
from ..processors.registry import collect
|
|
21
|
+
from ..services.dispatch import DispatchService
|
|
22
|
+
from ..services.worker import WorkerService
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def update_workers(
    make_session: typing.Callable[[], DBSession],
    worker_id: uuid.UUID,
    heartbeat_period: int,
    heartbeat_timeout: int,
):
    """Background loop: keep this worker's heartbeat fresh and reap dead workers.

    Runs forever; intended to run in a daemon thread with its own database
    session created through ``make_session``.

    :param make_session: factory producing a new DB session for this thread
    :param worker_id: id of the worker row owned by this process
    :param heartbeat_period: seconds to sleep between heartbeat updates
    :param heartbeat_timeout: seconds after which a silent worker counts as dead
    """
    db: DBSession = make_session()
    worker_service = WorkerService(session=db)
    dispatch_service = DispatchService(session=db)
    current_worker = db.get(models.Worker, worker_id)
    logger = logging.getLogger(__name__)
    logger.info(
        "Updating worker %s with heartbeat_period=%s, heartbeat_timeout=%s",
        current_worker.id,
        heartbeat_period,
        heartbeat_timeout,
    )
    while True:
        # workers whose last heartbeat is older than the timeout
        dead_workers = worker_service.fetch_dead_workers(timeout=heartbeat_timeout)
        # put their tasks back so live workers can pick them up
        task_count = worker_service.reschedule_dead_tasks(
            dead_workers.with_entities(models.Worker.id)
        )
        found_dead_worker = False
        for dead_worker in dead_workers:
            found_dead_worker = True
            logger.info(
                "Found dead worker %s (name=%s), reschedule %s dead tasks in channels %s",
                dead_worker.id,
                dead_worker.name,
                task_count,
                dead_worker.channels,
            )
            # wake up listeners on the dead worker's channels
            dispatch_service.notify(dead_worker.channels)
        if found_dead_worker:
            db.commit()

        time.sleep(heartbeat_period)
        # refresh our own heartbeat using the database clock, not local time
        current_worker.last_heartbeat = func.now()
        db.add(current_worker)
        db.commit()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@inject
def process_tasks(
    channels: tuple[str, ...],
    config: Config = Provide[Container.config],
    session_factory: typing.Callable = Provide[Container.session_factory],
    db: DBSession = Provide[Container.session],
    dispatch_service: DispatchService = Provide[Container.dispatch_service],
    worker_service: WorkerService = Provide[Container.worker_service],
):
    """Main worker loop: claim and run tasks from the given channels until
    interrupted.

    Scans ``config.PROCESSOR_PACKAGES`` for processors, registers a Worker
    row, spawns a daemon heartbeat thread, then alternates between draining
    pending tasks and blocking on LISTEN/NOTIFY. On SIGINT/SystemExit it
    marks the worker SHUTDOWN and reschedules its unfinished tasks.
    """
    logger = logging.getLogger(__name__)

    if not channels:
        channels = ["default"]

    if not config.PROCESSOR_PACKAGES:
        logger.error("No PROCESSOR_PACKAGES provided")
        sys.exit(-1)

    logger.info("Scanning packages %s", config.PROCESSOR_PACKAGES)
    pkgs = list(map(importlib.import_module, config.PROCESSOR_PACKAGES))
    registry = collect(pkgs)
    for channel, module_processors in registry.processors.items():
        logger.info("Collected processors with channel %r", channel)
        for module, func_processors in module_processors.items():
            for processor in func_processors.values():
                logger.info(
                    "  Processor module %r, processor %r", module, processor.name
                )

    # register this process as a worker row and start listening before the
    # first commit so no notification is missed
    worker = models.Worker(name=platform.node(), channels=channels)
    db.add(worker)
    dispatch_service.listen(channels)
    db.commit()

    logger.info("Created worker %s, name=%s", worker.id, worker.name)
    logger.info("Processing tasks in channels = %s ...", channels)

    worker_update_thread = threading.Thread(
        target=functools.partial(
            update_workers,
            make_session=session_factory,
            worker_id=worker.id,
            heartbeat_period=config.WORKER_HEARTBEAT_PERIOD,
            heartbeat_timeout=config.WORKER_HEARTBEAT_TIMEOUT,
        ),
        name="update_workers",
    )
    worker_update_thread.daemon = True
    worker_update_thread.start()

    worker_id = worker.id

    try:
        while True:
            # inner loop: drain pending tasks until a dispatch comes back empty
            while True:
                tasks = dispatch_service.dispatch(
                    channels,
                    worker_id=worker_id,
                    limit=config.BATCH_SIZE,
                ).all()
                for task in tasks:
                    logger.info(
                        "Processing task %s, channel=%s, module=%s, func=%s",
                        task.id,
                        task.channel,
                        task.module,
                        task.func_name,
                    )
                    # TODO: support processor pool and other approaches to dispatch the workload
                    registry.process(task)
                if not tasks:
                    # we should try to keep dispatching until we cannot find tasks
                    break
                else:
                    db.commit()
            # we will not see notifications in a transaction, need to close the transaction first before entering
            # polling
            db.close()
            try:
                for notification in dispatch_service.poll(timeout=config.POLL_TIMEOUT):
                    logger.debug("Receive notification %s", notification)
            except TimeoutError:
                logger.debug("Poll timeout, try again")
                continue
    except (SystemExit, KeyboardInterrupt):
        db.rollback()
        logger.info("Shutting down ...")
        worker_update_thread.join(5)

        worker.state = models.WorkerState.SHUTDOWN
        db.add(worker)
        # hand our unfinished tasks back to the queue and wake other workers
        task_count = worker_service.reschedule_dead_tasks([worker.id])
        logger.info("Reschedule %s tasks", task_count)
        dispatch_service.notify(channels)
        db.commit()

        logger.info("Shutdown gracefully")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@click.command()
@click.argument("channels", nargs=-1)
def main(
    channels: tuple[str, ...],
):
    # thin CLI wrapper; the DI-wired process_tasks does all the work
    process_tasks(channels)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    container = Container()
    container.wire(modules=[__name__])
    main()
|
bq/cmds/submit.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import json
import logging

import click
from dependency_injector.wiring import inject
from dependency_injector.wiring import Provide

from .. import models
from ..container import Container
from ..db.session import Session


@click.command()
@click.argument("channel", nargs=1)
@click.argument("module", nargs=1)
@click.argument("func", nargs=1)
@click.option(
    "-k", "--kwargs", type=str, help="Keyword arguments as JSON", default=None
)
@inject
def main(
    channel: str,
    module: str,
    func: str,
    kwargs: str | None,
    db: Session = Provide[Container.session],
):
    # CLI entry point: insert a Task row for CHANNEL/MODULE/FUNC.
    # (comment instead of a docstring so click help output is unchanged)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info(
        "Submit task with channel=%s, module=%s, func=%s", channel, module, func
    )
    kwargs_value = {}
    if kwargs:
        # --kwargs carries a JSON object string for the task's kwargs column
        kwargs_value = json.loads(kwargs)
    task = models.Task(
        channel=channel, module=module, func_name=func, kwargs=kwargs_value
    )
    db.add(task)
    # committing fires the after_insert listener which NOTIFYs the channel
    db.commit()
    logger.info("Done, submit task %s", task.id)


if __name__ == "__main__":
    container = Container()
    container.wire(modules=[__name__])
    main()
|
bq/config.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import typing

from pydantic import field_validator
from pydantic import PostgresDsn
from pydantic import ValidationInfo
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict


class Config(BaseSettings):
    # Settings are read from environment variables prefixed with BQ_
    # (see model_config at the bottom).

    # Packages to scan for processor functions
    PROCESSOR_PACKAGES: list[str] = []

    # Size of tasks batch to fetch each time from the database
    BATCH_SIZE: int = 1

    # How long we should poll before timeout in seconds
    POLL_TIMEOUT: int = 60

    # Interval of worker heartbeat update cycle in seconds
    WORKER_HEARTBEAT_PERIOD: int = 30

    # Timeout of worker heartbeat in seconds
    WORKER_HEARTBEAT_TIMEOUT: int = 100

    # individual connection parts, used only when DATABASE_URL is not given
    POSTGRES_SERVER: str = "localhost"
    POSTGRES_USER: str = "bq"
    POSTGRES_PASSWORD: str = ""
    POSTGRES_DB: str = "bq"
    # The URL of postgresql database to connect
    DATABASE_URL: typing.Optional[PostgresDsn] = None

    @field_validator("DATABASE_URL", mode="before")
    def assemble_db_connection(
        cls, v: typing.Optional[str], info: ValidationInfo
    ) -> typing.Any:
        # Build DATABASE_URL from the POSTGRES_* fields when it is not set
        # explicitly; an explicit URL string always wins.
        if isinstance(v, str):
            return v
        return PostgresDsn.build(
            scheme="postgresql",
            username=info.data.get("POSTGRES_USER"),
            password=info.data.get("POSTGRES_PASSWORD"),
            host=info.data.get("POSTGRES_SERVER"),
            path=f"{info.data.get('POSTGRES_DB') or ''}",
        )

    model_config = SettingsConfigDict(case_sensitive=True, env_prefix="BQ_")
|
bq/container.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import functools
import typing

from dependency_injector import containers
from dependency_injector import providers
from sqlalchemy import create_engine
from sqlalchemy import Engine
from sqlalchemy.orm import Session as DBSession
from sqlalchemy.pool import SingletonThreadPool

from .config import Config
from .db.session import SessionMaker
from .services.dispatch import DispatchService
from .services.worker import WorkerService


def make_db_engine(config: Config) -> Engine:
    """Create the SQLAlchemy engine for the configured database URL."""
    # NOTE(review): SingletonThreadPool keeps one connection per thread —
    # presumably so the heartbeat thread and main loop each get their own
    # connection; confirm against the worker threading model.
    return create_engine(str(config.DATABASE_URL), poolclass=SingletonThreadPool)


def make_session_factory(engine: Engine) -> typing.Callable:
    """Return a zero-argument callable producing sessions bound to engine."""
    return functools.partial(SessionMaker, bind=engine)


def make_session(factory: typing.Callable) -> DBSession:
    """Create one session from the factory."""
    return factory()


def make_dispatch_service(session: DBSession) -> DispatchService:
    """Build a DispatchService on the shared session."""
    return DispatchService(session)


def make_worker_service(session: DBSession) -> WorkerService:
    """Build a WorkerService on the shared session."""
    return WorkerService(session)


class Container(containers.DeclarativeContainer):
    """Dependency-injection container wiring config, DB and services.

    All providers are Singletons, so the session (and the services built on
    it) are shared across everything wired to this container.
    """

    config = providers.Singleton(Config)

    db_engine: Engine = providers.Singleton(make_db_engine, config=config)

    session_factory: typing.Callable = providers.Singleton(
        make_session_factory, engine=db_engine
    )

    session: DBSession = providers.Singleton(make_session, factory=session_factory)

    dispatch_service: DispatchService = providers.Singleton(
        make_dispatch_service, session=session
    )

    worker_service: WorkerService = providers.Singleton(
        make_worker_service, session=session
    )
|
bq/db/__init__.py
ADDED
|
File without changes
|
bq/db/base.py
ADDED
bq/db/session.py
ADDED
bq/models/__init__.py
ADDED
bq/models/helpers.py
ADDED
bq/models/task.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import Column
|
|
4
|
+
from sqlalchemy import Connection
|
|
5
|
+
from sqlalchemy import DateTime
|
|
6
|
+
from sqlalchemy import Enum
|
|
7
|
+
from sqlalchemy import event
|
|
8
|
+
from sqlalchemy import ForeignKey
|
|
9
|
+
from sqlalchemy import func
|
|
10
|
+
from sqlalchemy import inspect
|
|
11
|
+
from sqlalchemy import String
|
|
12
|
+
from sqlalchemy.dialects.postgresql import JSONB
|
|
13
|
+
from sqlalchemy.dialects.postgresql import UUID
|
|
14
|
+
from sqlalchemy.orm import Mapper
|
|
15
|
+
from sqlalchemy.orm import relationship
|
|
16
|
+
|
|
17
|
+
from ..db.base import Base
|
|
18
|
+
from .helpers import make_repr_attrs
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TaskState(enum.Enum):
    """Lifecycle states of a queued task."""

    # freshly created, waiting for a worker to claim it
    PENDING = "PENDING"
    # claimed and currently being executed by a worker
    PROCESSING = "PROCESSING"
    # finished successfully
    DONE = "DONE"
    # finished with an error
    FAILED = "FAILED"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Task(Base):
    """A queued unit of work: which processor to run, with what arguments."""

    # primary key, generated server-side by PostgreSQL
    id = Column(
        UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()
    )
    # foreign key id of assigned worker
    worker_id = Column(
        UUID(as_uuid=True),
        ForeignKey("bq_workers.id", name="fk_workers_id"),
        nullable=True,
    )
    # current state of the task
    state = Column(
        Enum(TaskState),
        nullable=False,
        default=TaskState.PENDING,
        server_default=TaskState.PENDING.value,
        index=True,
    )
    # channel for workers and job creator to listen/notify
    channel = Column(String, nullable=False, index=True)
    # module of the processor function
    module = Column(String, nullable=False)
    # func name of the processor func
    func_name = Column(String, nullable=False)
    # keyword arguments passed to the processor function (NULL means none)
    kwargs = Column(JSONB, nullable=True)
    # Result of the task, recorded on auto-complete
    result = Column(JSONB, nullable=True)
    # Error message recorded when the task fails
    error_message = Column(String, nullable=True)
    # created datetime of the task
    created_at = Column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )

    # the worker currently assigned to this task, if any
    worker = relationship("Worker", back_populates="tasks", uselist=False)

    __tablename__ = "bq_tasks"

    def __repr__(self) -> str:
        items = [
            ("id", self.id),
            ("state", self.state),
            ("channel", self.channel),
        ]
        return f"<{self.__class__.__name__} {make_repr_attrs(items)}>"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def notify_if_needed(connection: Connection, task: Task):
    """Send a PostgreSQL NOTIFY for the task's channel, at most once per
    channel per transaction.

    :param connection: connection the flush is running on
    :param task: the task whose channel should be notified
    """
    session = inspect(task).session
    transaction = session.get_transaction()
    if transaction is not None:
        # de-duplicate within the transaction by stashing the set of
        # already-notified channels on the transaction object itself
        key = "_notified_channels"
        if hasattr(transaction, key):
            notified_channels = getattr(transaction, key)
        else:
            notified_channels = set()
            setattr(transaction, key, notified_channels)

        if task.channel in notified_channels:
            # already notified, skip
            return
        notified_channels.add(task.channel)

    # quote the channel name so arbitrary channel strings are safe to inline
    # into the NOTIFY statement
    quoted_channel = connection.dialect.identifier_preparer.quote_identifier(
        task.channel
    )
    connection.exec_driver_sql(f"NOTIFY {quoted_channel}")
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@event.listens_for(Task, "after_insert")
def task_insert_notify(mapper: Mapper, connection: Connection, target: Task):
    """Notify the task's channel when a new PENDING task row is inserted."""
    from .. import models

    if target.state == models.TaskState.PENDING:
        notify_if_needed(connection, target)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@event.listens_for(Task, "after_update")
def task_update_notify(mapper: Mapper, connection: Connection, target: Task):
    """Notify the task's channel when an update moves a task back to PENDING."""
    from .. import models

    # only fire when the state column itself changed in this flush
    state_history = inspect(target).attrs.state.history
    if state_history.has_changes() and target.state == models.TaskState.PENDING:
        notify_if_needed(connection, target)
|
bq/models/worker.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import Column
|
|
4
|
+
from sqlalchemy import DateTime
|
|
5
|
+
from sqlalchemy import Enum
|
|
6
|
+
from sqlalchemy import func
|
|
7
|
+
from sqlalchemy import String
|
|
8
|
+
from sqlalchemy.dialects.postgresql import ARRAY
|
|
9
|
+
from sqlalchemy.dialects.postgresql import UUID
|
|
10
|
+
from sqlalchemy.orm import relationship
|
|
11
|
+
|
|
12
|
+
from ..db.base import Base
|
|
13
|
+
from .helpers import make_repr_attrs
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class WorkerState(enum.Enum):
    """Lifecycle states of a worker process."""

    # actively processing tasks
    RUNNING = "RUNNING"
    # exited cleanly
    SHUTDOWN = "SHUTDOWN"
    # heartbeat missing for longer than the timeout; presumed dead
    NO_HEARTBEAT = "NO_HEARTBEAT"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Worker(Base):
    """A running (or previously running) worker process and its channels."""

    # primary key, generated server-side by PostgreSQL
    id = Column(
        UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()
    )
    # current state of the worker
    state = Column(
        Enum(WorkerState),
        nullable=False,
        default=WorkerState.RUNNING,
        server_default=WorkerState.RUNNING.value,
        index=True,
    )
    # name of the worker (set to platform.node() by the process command)
    name = Column(String, nullable=False)
    # the channels we are processing
    channels = Column(ARRAY(String), nullable=False)
    # last heartbeat of this worker; used to detect dead workers
    last_heartbeat = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        index=True,
    )
    # created datetime of the worker
    created_at = Column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )

    # tasks currently assigned to this worker, oldest first
    tasks = relationship(
        "Task",
        back_populates="worker",
        cascade="all,delete",
        order_by="Task.created_at",
    )

    __tablename__ = "bq_workers"

    def __repr__(self) -> str:
        items = [
            ("id", self.id),
            ("name", self.name),
            ("channels", self.channels),
            ("state", self.state),
        ]
        return f"<{self.__class__.__name__} {make_repr_attrs(items)}>"
|
|
File without changes
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import dataclasses
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
import typing
|
|
6
|
+
|
|
7
|
+
import venusian
|
|
8
|
+
from sqlalchemy.orm import object_session
|
|
9
|
+
|
|
10
|
+
from bq import models
|
|
11
|
+
|
|
12
|
+
# venusian scan category under which processors are registered
BQ_PROCESSOR_CATEGORY = "bq_processor"


@dataclasses.dataclass(frozen=True)
class Processor:
    """Immutable description of a task-processing function."""

    channel: str
    module: str
    name: str
    func: typing.Callable
    # should we auto complete the task or not
    auto_complete: bool = True
    # should we auto rollback the transaction when encounter unhandled exception
    auto_rollback_on_exc: bool = True
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ProcessorHelper:
    """Callable returned by the @processor decorator.

    Calling it directly just invokes the wrapped function; ``run(**kwargs)``
    instead builds a Task row so the work can be executed later by a worker.
    """

    def __init__(self, processor: Processor, task_cls: typing.Type = models.Task):
        self._processor = processor
        # task class to instantiate in run(); overridable for custom Task models
        self._task_cls = task_cls

    def __call__(self, *args, **kwargs):
        # transparent pass-through to the wrapped function
        return self._processor.func(*args, **kwargs)

    def run(self, **kwargs) -> models.Task:
        # build (but do not persist) a Task row targeting the wrapped processor;
        # the caller is responsible for adding/committing it
        return self._task_cls(
            channel=self._processor.channel,
            module=self._processor.module,
            func_name=self._processor.name,
            kwargs=kwargs,
        )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def process_task(task: models.Task, processor: Processor):
    """Execute ``processor.func`` for ``task`` inside a savepoint.

    Passes ``task``/``db`` to the function only when its signature declares
    them. On an unhandled exception the savepoint is rolled back (when
    ``auto_rollback_on_exc``) and the task is marked FAILED; otherwise, when
    ``auto_complete`` is set, the task is marked DONE and its result stored.

    :param task: the claimed task to run (must be attached to a session)
    :param processor: the registered processor describing what to call
    :return: the function's return value, or None on failure
    """
    logger = logging.getLogger(__name__)
    db = object_session(task)
    func_signature = inspect.signature(processor.func)
    base_kwargs = {}
    # only inject task/db when the processor function actually declares them
    if "task" in func_signature.parameters:
        base_kwargs["task"] = task
    if "db" in func_signature.parameters:
        base_kwargs["db"] = db
    with db.begin_nested() as savepoint:
        try:
            # task.kwargs is a nullable JSONB column — treat NULL as "no
            # keyword arguments" instead of raising TypeError on ** expansion
            result = processor.func(**base_kwargs, **(task.kwargs or {}))
            savepoint.commit()
        except Exception as exc:
            logger.error("Unhandled exception for task %s", task.id, exc_info=True)
            if processor.auto_rollback_on_exc:
                # undo whatever the processor did inside the savepoint
                savepoint.rollback()
            # TODO: add error event
            task.state = models.TaskState.FAILED
            task.error_message = str(exc)
            db.add(task)
            return
    if processor.auto_complete:
        logger.info("Task %s auto complete", task.id)
        task.state = models.TaskState.DONE
        task.result = result
        db.add(task)
    return result
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class Registry:
    """In-memory index of processors keyed by channel -> module -> func name."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # nested defaultdicts so add() never needs to create levels explicitly
        self.processors = collections.defaultdict(lambda: collections.defaultdict(dict))

    def add(self, processor: Processor):
        """Register a processor under its channel, module and function name."""
        channel_map = self.processors[processor.channel]
        channel_map[processor.module][processor.name] = processor

    def process(self, task: models.Task) -> typing.Any:
        """Run the processor matching the task; mark it FAILED when none exists."""
        processor = (
            self.processors.get(task.channel, {})
            .get(task.module, {})
            .get(task.func_name)
        )
        db = object_session(task)
        if processor is None:
            self.logger.error(
                "Cannot find processor for task %s with module=%s, func=%s",
                task.id,
                task.module,
                task.func_name,
            )
            # TODO: add error event
            task.state = models.TaskState.FAILED
            task.error_message = f"Cannot find processor for task with module={task.module}, func={task.func_name}"
            db.add(task)
            return
        return process_task(task, processor)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def processor(
    channel: str,
    auto_complete: bool = True,
    auto_rollback_on_exc: bool = True,
    task_cls: typing.Type = models.Task,
) -> typing.Callable:
    """Decorator registering a function as a bq task processor on ``channel``.

    The decorated name is replaced by a ProcessorHelper: calling it directly
    still invokes the original function, while ``.run(**kwargs)`` builds a
    Task row for deferred execution. Actual registration happens lazily when
    a venusian scan (see ``collect``) visits the defining module.

    :param channel: channel the processor consumes tasks from
    :param auto_complete: mark the task DONE automatically after the function returns
    :param auto_rollback_on_exc: roll back to the savepoint on unhandled exceptions
    :param task_cls: task model class used by ``ProcessorHelper.run``
    """

    def decorator(wrapped: typing.Callable):
        processor = Processor(
            module=wrapped.__module__,
            name=wrapped.__name__,
            channel=channel,
            func=wrapped,
            auto_complete=auto_complete,
            auto_rollback_on_exc=auto_rollback_on_exc,
        )
        helper_obj = ProcessorHelper(processor, task_cls=task_cls)

        def callback(scanner: venusian.Scanner, name: str, ob: typing.Callable):
            # sanity check: the attribute name the scanner found must match
            # the wrapped function's __name__
            if processor.name != name:
                raise ValueError("Name is not the same")
            scanner.registry.add(processor)

        # defer registration until a scan with our category runs
        venusian.attach(helper_obj, callback, category=BQ_PROCESSOR_CATEGORY)
        return helper_obj

    return decorator
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def collect(packages: list[typing.Any], registry: Registry | None = None) -> Registry:
    """Scan the given packages with venusian and return the populated registry.

    A fresh Registry is created when none is supplied.
    """
    target = registry if registry is not None else Registry()
    scanner = venusian.Scanner(registry=target)
    for pkg in packages:
        scanner.scan(pkg, categories=(BQ_PROCESSOR_CATEGORY,))
    return target
|
bq/services/__init__.py
ADDED
|
File without changes
|
bq/services/dispatch.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import select
|
|
3
|
+
import typing
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
from sqlalchemy.orm import Query
|
|
7
|
+
|
|
8
|
+
from .. import models
|
|
9
|
+
from ..db.session import Session
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclasses.dataclass(frozen=True)
class Notification:
    """A PostgreSQL NOTIFY event received on a listening connection."""

    # backend process id that sent the notification
    pid: int
    # channel the notification was sent on
    channel: str
    # optional payload string
    payload: typing.Optional[str] = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DispatchService:
    """Task dispatching built on PostgreSQL FOR UPDATE SKIP LOCKED and
    LISTEN/NOTIFY.

    Note: the ``poll`` method continues below this span in the original file.
    """

    def __init__(self, session: Session):
        self.session = session

    def make_task_query(self, channels: typing.Sequence[str], limit: int = 1) -> Query:
        """Query for ids of the oldest PENDING tasks in ``channels``, locking
        the rows with SKIP LOCKED so concurrent workers never grab the same
        tasks."""
        return (
            self.session.query(models.Task.id)
            .filter(models.Task.channel.in_(channels))
            .filter(models.Task.state == models.TaskState.PENDING)
            .order_by(models.Task.created_at)
            .limit(limit)
            .with_for_update(skip_locked=True)
        )

    def make_update_query(self, task_query: typing.Any, worker_id: uuid.UUID):
        """UPDATE statement claiming the tasks selected by ``task_query`` for
        ``worker_id``; returns the claimed ids via RETURNING."""
        return (
            models.Task.__table__.update()
            .where(models.Task.id.in_(task_query))
            .values(
                state=models.TaskState.PROCESSING,
                worker_id=worker_id,
            )
            .returning(models.Task.id)
        )

    def dispatch(
        self, channels: typing.Sequence[str], worker_id: uuid.UUID, limit: int = 1
    ) -> Query:
        """Atomically claim up to ``limit`` pending tasks from ``channels``
        for ``worker_id`` and return a query yielding the claimed Task rows."""
        task_query = self.make_task_query(channels, limit=limit)
        task_subquery = task_query.scalar_subquery()
        task_ids = [
            item[0]
            for item in self.session.execute(
                self.make_update_query(task_subquery, worker_id=worker_id)
            )
        ]
        # TODO: ideally returning with (models.Task) should return the whole model, but SQLAlchemy is returning
        # it columns in rows. We can save a round trip if we can find out how to solve this
        return self.session.query(models.Task).filter(models.Task.id.in_(task_ids))

    def listen(self, channels: typing.Sequence[str]):
        """Issue LISTEN for each channel on this session's connection."""
        conn = self.session.connection()
        for channel in channels:
            # quote so arbitrary channel strings are safe to inline
            quoted_channel = conn.dialect.identifier_preparer.quote_identifier(channel)
            conn.exec_driver_sql(f"LISTEN {quoted_channel}")
|
|
64
|
+
|
|
65
|
+
def poll(self, timeout: int = 5) -> typing.Generator[Notification, None, None]:
|
|
66
|
+
conn = self.session.connection()
|
|
67
|
+
driver_conn = conn.connection.driver_connection
|
|
68
|
+
|
|
69
|
+
def pop_notifies():
|
|
70
|
+
while driver_conn.notifies:
|
|
71
|
+
notify = driver_conn.notifies.pop(0)
|
|
72
|
+
yield Notification(
|
|
73
|
+
pid=notify.pid,
|
|
74
|
+
channel=notify.channel,
|
|
75
|
+
payload=notify.payload,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# poll first to see if there's anything already
|
|
79
|
+
driver_conn.poll()
|
|
80
|
+
if driver_conn.notifies:
|
|
81
|
+
yield from pop_notifies()
|
|
82
|
+
else:
|
|
83
|
+
# okay, nothing, let's select and wait for new stuff
|
|
84
|
+
if select.select([driver_conn], [], [], timeout) == ([], [], []):
|
|
85
|
+
# nope, nothing, times out
|
|
86
|
+
raise TimeoutError("Timeout waiting for new notifications")
|
|
87
|
+
else:
|
|
88
|
+
# yep, we got something
|
|
89
|
+
driver_conn.poll()
|
|
90
|
+
yield from pop_notifies()
|
|
91
|
+
|
|
92
|
+
def notify(self, channels: typing.Sequence[str]):
|
|
93
|
+
conn = self.session.connection()
|
|
94
|
+
for channel in channels:
|
|
95
|
+
quoted_channel = conn.dialect.identifier_preparer.quote_identifier(channel)
|
|
96
|
+
conn.exec_driver_sql(f"NOTIFY {quoted_channel}")
|
bq/services/worker.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import func
|
|
5
|
+
from sqlalchemy.orm import Query
|
|
6
|
+
from sqlalchemy.orm import Session
|
|
7
|
+
|
|
8
|
+
from .. import models
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class WorkerService:
    """Tracks worker heartbeats and recovers tasks left behind by dead workers."""

    def __init__(self, session: Session):
        self.session = session

    def update_heartbeat(self, worker: models.Worker):
        """Stamp *worker* with the database server's current time."""
        worker.last_heartbeat = func.now()
        self.session.add(worker)

    def make_dead_worker_query(self, timeout: int, limit: int = 5) -> Query:
        """Build a query for ids of RUNNING workers whose heartbeat is older than *timeout* seconds.

        Rows are locked FOR UPDATE SKIP LOCKED so concurrent reapers don't race.
        """
        # compare against the server clock so worker and DB clocks can't disagree
        cutoff = func.now() - datetime.timedelta(seconds=timeout)
        query = self.session.query(models.Worker.id)
        query = query.filter(models.Worker.last_heartbeat < cutoff)
        query = query.filter(models.Worker.state == models.WorkerState.RUNNING)
        return query.limit(limit).with_for_update(skip_locked=True)

    def make_update_dead_worker_query(self, worker_query: typing.Any):
        """Build an UPDATE marking workers in *worker_query* as NO_HEARTBEAT."""
        table = models.Worker.__table__
        return (
            table.update()
            .where(models.Worker.id.in_(worker_query))
            .values(state=models.WorkerState.NO_HEARTBEAT)
            .returning(models.Worker.id)
        )

    def fetch_dead_workers(self, timeout: int, limit: int = 5) -> Query:
        """Mark stale workers NO_HEARTBEAT and return a query over those workers."""
        stale = self.make_dead_worker_query(timeout=timeout, limit=limit)
        stale_subquery = stale.scalar_subquery()
        worker_ids = []
        for row in self.session.execute(
            self.make_update_dead_worker_query(stale_subquery)
        ):
            worker_ids.append(row[0])
        # TODO: ideally returning with (models.Task) should return the whole model, but SQLAlchemy is returning
        # it columns in rows. We can save a round trip if we can find out how to solve this
        return self.session.query(models.Worker).filter(
            models.Worker.id.in_(worker_ids)
        )

    def make_update_tasks_query(self, worker_query: typing.Any):
        """Build an UPDATE resetting PROCESSING tasks owned by *worker_query* back to PENDING."""
        table = models.Task.__table__
        return (
            table.update()
            .where(models.Task.worker_id.in_(worker_query))
            .where(models.Task.state == models.TaskState.PROCESSING)
            .values(state=models.TaskState.PENDING)
        )

    def reschedule_dead_tasks(self, worker_query: typing.Any) -> int:
        """Reset tasks held by the given workers to PENDING; return the number of rows changed."""
        update_stmt = self.make_update_tasks_query(worker_query=worker_query)
        result = self.session.execute(update_stmt)
        return result.rowcount
|