beanqueue 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Launch Platform
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.1
2
+ Name: beanqueue
3
+ Version: 0.1.0
4
+ Summary: BeanQueue or BQ for short, PostgreSQL SKIP LOCK based worker queue library
5
+ License: MIT
6
+ Author: Fang-Pen Lin
7
+ Author-email: fangpen@launchplatform.com
8
+ Requires-Python: >=3.11,<4.0
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Requires-Dist: click (>=8.1.7,<9.0.0)
14
+ Requires-Dist: dependency-injector (>=4.41.0,<5.0.0)
15
+ Requires-Dist: pg-activity (>=3.5.1,<4.0.0)
16
+ Requires-Dist: pydantic-settings (>=2.2.1,<3.0.0)
17
+ Requires-Dist: sqlalchemy (>=2.0.30,<3.0.0)
18
+ Requires-Dist: venusian (>=3.1.0,<4.0.0)
19
+ Description-Content-Type: text/markdown
20
+
21
+ # bq
22
+ BeanQueue or BQ for short, PostgreSQL SKIP LOCK based worker queue library
23
+
@@ -0,0 +1,23 @@
1
+ bq/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ bq/cmds/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ bq/cmds/create_tables.py,sha256=lI9GmovZT7-_ws4jq8xV-0pHStiyPZoE6Ygb0Xn1q2c,609
4
+ bq/cmds/process.py,sha256=9CE2PCzdl3WAdx1whlcqmDwuQujeuh6rvrLjEAcXlAs,5879
5
+ bq/cmds/submit.py,sha256=WhzitLawbR_hEjLqCuvr2yhKJcwGNplarntpAhjGBb4,1148
6
+ bq/config.py,sha256=LUEwzodrt6H32bC2Fn8_J8oQQQ9dSFZnIg2Uv_EMLus,1508
7
+ bq/container.py,sha256=NQLpc2zEap8LW2XkqRQjGtzyZslPL5iDnEZKBnH67eo,1558
8
+ bq/db/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ bq/db/base.py,sha256=0LXS0WlLr1KAHGZng46SmUwJam8m2AxVDBqK1Yzwx8w,88
10
+ bq/db/session.py,sha256=ife8ocHxCT7EbDgyEG38TFhP4l4wQv_Bo16bxt9XCos,157
11
+ bq/models/__init__.py,sha256=bibCa4EgF9MtkmSSqa15C7Jc7tN9P6LS78YOBkvX72s,114
12
+ bq/models/helpers.py,sha256=vf8IvREk2Lfc_LsdxWfUgVc99Fx1K2RsWVj9Xi4Dhbg,184
13
+ bq/models/task.py,sha256=hVkYgt-TPFdKmbmL9E_AUbWiZlk0nkq7ZGWfkyoZICk,3604
14
+ bq/models/worker.py,sha256=EuC8k6g9l_xddjXxhMqXmHt6uSO86VZFs9hJH1DQUi8,1940
15
+ bq/processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ bq/processors/registry.py,sha256=sc9rfkWAJFFh5p8K-Vv1z2GCzgqoTLX3yNMJ3j4F15k,4554
17
+ bq/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
+ bq/services/dispatch.py,sha256=uSssmqGKCxsr3ZzbglXs01XZNi3nv3zUfO2d88u3Wv0,3463
19
+ bq/services/worker.py,sha256=-NDP-6dA40OYdw00sbS3jLqU2PSY62oL_1gPVK5ijMQ,2554
20
+ beanqueue-0.1.0.dist-info/LICENSE,sha256=sNYpr-_bmGA6hBtD7qtmZmVgQH6HC6KXlx7tOOOjFq8,1072
21
+ beanqueue-0.1.0.dist-info/METADATA,sha256=YRbok6yaosRGobhhsnQYWTiRouSeytUIBOSfV82jseY,834
22
+ beanqueue-0.1.0.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
23
+ beanqueue-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: poetry-core 1.8.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
bq/__init__.py ADDED
File without changes
bq/cmds/__init__.py ADDED
File without changes
@@ -0,0 +1,25 @@
1
+ import logging
2
+
3
+ import click
4
+ from dependency_injector.wiring import inject
5
+ from dependency_injector.wiring import Provide
6
+ from sqlalchemy.engine import Engine
7
+
8
+ from .. import models # noqa
9
+ from ..container import Container
10
+ from ..db.base import Base
11
+
12
+
13
@click.command()
@inject
def main(engine: Engine = Provide[Container.db_engine]):
    """Create all BeanQueue tables on the injected database engine."""
    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)
    # Base.metadata knows every model registered against the declarative base.
    Base.metadata.create_all(bind=engine)
    log.info("Done, tables created")
20
+
21
+
22
if __name__ == "__main__":
    # Script entry point: build the DI container and wire this module so the
    # @inject-decorated main() receives the configured engine provider.
    container = Container()
    container.wire(modules=[__name__])
    main()
bq/cmds/process.py ADDED
@@ -0,0 +1,178 @@
1
+ import functools
2
+ import importlib
3
+ import logging
4
+ import platform
5
+ import sys
6
+ import threading
7
+ import time
8
+ import typing
9
+ import uuid
10
+
11
+ import click
12
+ from dependency_injector.wiring import inject
13
+ from dependency_injector.wiring import Provide
14
+ from sqlalchemy import func
15
+ from sqlalchemy.orm import Session as DBSession
16
+
17
+ from .. import models
18
+ from ..config import Config
19
+ from ..container import Container
20
+ from ..processors.registry import collect
21
+ from ..services.dispatch import DispatchService
22
+ from ..services.worker import WorkerService
23
+
24
+
25
def update_workers(
    make_session: typing.Callable[[], DBSession],
    worker_id: uuid.UUID,
    heartbeat_period: int,
    heartbeat_timeout: int,
):
    """Background thread body: keep this worker's heartbeat fresh and reap dead peers.

    Runs forever; intended to be started as a daemon thread. Each cycle it:
      1. flags workers whose heartbeat is older than ``heartbeat_timeout`` seconds,
      2. reschedules their PROCESSING tasks back to PENDING,
      3. notifies the dead workers' channels so live workers pick those tasks up,
      4. sleeps ``heartbeat_period`` seconds, then refreshes our own heartbeat.
    """
    # A dedicated session: this runs in its own thread and must not share the
    # main thread's session object.
    db: DBSession = make_session()
    worker_service = WorkerService(session=db)
    dispatch_service = DispatchService(session=db)
    current_worker = db.get(models.Worker, worker_id)
    logger = logging.getLogger(__name__)
    logger.info(
        "Updating worker %s with heartbeat_period=%s, heartbeat_timeout=%s",
        current_worker.id,
        heartbeat_period,
        heartbeat_timeout,
    )
    while True:
        dead_workers = worker_service.fetch_dead_workers(timeout=heartbeat_timeout)
        # Reschedule tasks for all dead workers in one statement.
        task_count = worker_service.reschedule_dead_tasks(
            dead_workers.with_entities(models.Worker.id)
        )
        found_dead_worker = False
        for dead_worker in dead_workers:
            found_dead_worker = True
            logger.info(
                "Found dead worker %s (name=%s), reschedule %s dead tasks in channels %s",
                dead_worker.id,
                dead_worker.name,
                task_count,
                dead_worker.channels,
            )
            dispatch_service.notify(dead_worker.channels)
        if found_dead_worker:
            # Commit only when something changed; this also releases the row
            # locks taken by the SKIP LOCKED dead-worker query.
            db.commit()

        time.sleep(heartbeat_period)
        # Use the database clock (func.now()), not the local clock, so all
        # workers compare heartbeats against the same time source.
        current_worker.last_heartbeat = func.now()
        db.add(current_worker)
        db.commit()
65
+
66
+
67
@inject
def process_tasks(
    channels: tuple[str, ...],
    config: Config = Provide[Container.config],
    session_factory: typing.Callable = Provide[Container.session_factory],
    db: DBSession = Provide[Container.session],
    dispatch_service: DispatchService = Provide[Container.dispatch_service],
    worker_service: WorkerService = Provide[Container.worker_service],
):
    """Main worker loop: register a worker row, then claim and run tasks forever.

    Scans the configured processor packages, records a Worker in the database,
    starts a daemon heartbeat thread, and loops claiming PENDING tasks from
    *channels*. On KeyboardInterrupt/SystemExit it marks the worker SHUTDOWN,
    reschedules its in-flight tasks and notifies the channels before exiting.
    """
    logger = logging.getLogger(__name__)

    if not channels:
        channels = ["default"]

    if not config.PROCESSOR_PACKAGES:
        logger.error("No PROCESSOR_PACKAGES provided")
        sys.exit(-1)

    logger.info("Scanning packages %s", config.PROCESSOR_PACKAGES)
    pkgs = list(map(importlib.import_module, config.PROCESSOR_PACKAGES))
    registry = collect(pkgs)
    for channel, module_processors in registry.processors.items():
        logger.info("Collected processors with channel %r", channel)
        for module, func_processors in module_processors.items():
            for processor in func_processors.values():
                logger.info(
                    "  Processor module %r, processor %r", module, processor.name
                )

    worker = models.Worker(name=platform.node(), channels=channels)
    db.add(worker)
    # LISTEN is issued before the commit so the connection is already
    # subscribed when the worker row becomes visible.
    dispatch_service.listen(channels)
    db.commit()

    logger.info("Created worker %s, name=%s", worker.id, worker.name)
    logger.info("Processing tasks in channels = %s ...", channels)

    worker_update_thread = threading.Thread(
        target=functools.partial(
            update_workers,
            make_session=session_factory,
            worker_id=worker.id,
            heartbeat_period=config.WORKER_HEARTBEAT_PERIOD,
            heartbeat_timeout=config.WORKER_HEARTBEAT_TIMEOUT,
        ),
        name="update_workers",
    )
    worker_update_thread.daemon = True
    worker_update_thread.start()

    # Capture the id now; the ORM object may be expired after commits.
    worker_id = worker.id

    try:
        while True:
            # Inner loop: keep claiming batches until the queue is drained.
            while True:
                tasks = dispatch_service.dispatch(
                    channels,
                    worker_id=worker_id,
                    limit=config.BATCH_SIZE,
                ).all()
                for task in tasks:
                    logger.info(
                        "Processing task %s, channel=%s, module=%s, func=%s",
                        task.id,
                        task.channel,
                        task.module,
                        task.func_name,
                    )
                    # TODO: support processor pool and other approaches to dispatch the workload
                    registry.process(task)
                if not tasks:
                    # we should try to keep dispatching until we cannot find tasks
                    break
                else:
                    db.commit()
            # we will not see notifications in a transaction, need to close the transaction first before entering
            # polling
            db.close()
            try:
                for notification in dispatch_service.poll(timeout=config.POLL_TIMEOUT):
                    logger.debug("Receive notification %s", notification)
            except TimeoutError:
                logger.debug("Poll timeout, try again")
                continue
    except (SystemExit, KeyboardInterrupt):
        db.rollback()
        logger.info("Shutting down ...")
        worker_update_thread.join(5)

    # Graceful shutdown bookkeeping: release our tasks back to the queue.
    worker.state = models.WorkerState.SHUTDOWN
    db.add(worker)
    task_count = worker_service.reschedule_dead_tasks([worker.id])
    logger.info("Reschedule %s tasks", task_count)
    dispatch_service.notify(channels)
    db.commit()

    logger.info("Shutdown gracefully")
164
+
165
+
166
@click.command()
@click.argument("channels", nargs=-1)
def main(
    channels: tuple[str, ...],
):
    """CLI entry point: run the worker loop for the given channels."""
    process_tasks(channels)
172
+
173
+
174
if __name__ == "__main__":
    # Script entry point: configure logging, then wire this module so the
    # @inject-decorated process_tasks() gets its providers from the container.
    logging.basicConfig(level=logging.INFO)
    container = Container()
    container.wire(modules=[__name__])
    main()
bq/cmds/submit.py ADDED
@@ -0,0 +1,48 @@
1
+ import json
2
+ import logging
3
+
4
+ import click
5
+ from dependency_injector.wiring import inject
6
+ from dependency_injector.wiring import Provide
7
+
8
+ from .. import models
9
+ from ..container import Container
10
+ from ..db.session import Session
11
+
12
+
13
@click.command()
@click.argument("channel", nargs=1)
@click.argument("module", nargs=1)
@click.argument("func", nargs=1)
@click.option(
    "-k", "--kwargs", type=str, help="Keyword arguments as JSON", default=None
)
@inject
def main(
    channel: str,
    module: str,
    func: str,
    kwargs: str | None,
    db: Session = Provide[Container.session],
):
    """Create a single task row targeting module.func on *channel* and commit it."""
    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    log.info(
        "Submit task with channel=%s, module=%s, func=%s", channel, module, func
    )
    # Absent/empty --kwargs means "no keyword arguments".
    parsed_kwargs = json.loads(kwargs) if kwargs else {}
    task = models.Task(
        channel=channel, module=module, func_name=func, kwargs=parsed_kwargs
    )
    db.add(task)
    db.commit()
    log.info("Done, submit task %s", task.id)
43
+
44
+
45
if __name__ == "__main__":
    # Script entry point: build the DI container and wire this module so the
    # @inject-decorated main() receives the shared session provider.
    container = Container()
    container.wire(modules=[__name__])
    main()
bq/config.py ADDED
@@ -0,0 +1,47 @@
1
+ import typing
2
+
3
+ from pydantic import field_validator
4
+ from pydantic import PostgresDsn
5
+ from pydantic import ValidationInfo
6
+ from pydantic_settings import BaseSettings
7
+ from pydantic_settings import SettingsConfigDict
8
+
9
+
10
class Config(BaseSettings):
    """BeanQueue settings, read from the environment with a ``BQ_`` prefix."""

    # Packages to scan for processor functions
    PROCESSOR_PACKAGES: list[str] = []

    # Size of tasks batch to fetch each time from the database
    BATCH_SIZE: int = 1

    # How long we should poll before timeout in seconds
    POLL_TIMEOUT: int = 60

    # Interval of worker heartbeat update cycle in seconds
    WORKER_HEARTBEAT_PERIOD: int = 30

    # Timeout of worker heartbeat in seconds
    WORKER_HEARTBEAT_TIMEOUT: int = 100

    # Individual connection pieces, used to assemble DATABASE_URL when it is
    # not provided explicitly.
    POSTGRES_SERVER: str = "localhost"
    POSTGRES_USER: str = "bq"
    POSTGRES_PASSWORD: str = ""
    POSTGRES_DB: str = "bq"
    # The URL of postgresql database to connect
    DATABASE_URL: typing.Optional[PostgresDsn] = None

    @field_validator("DATABASE_URL", mode="before")
    def assemble_db_connection(
        cls, v: typing.Optional[str], info: ValidationInfo
    ) -> typing.Any:
        """Build DATABASE_URL from POSTGRES_* fields when not set explicitly.

        Runs in "before" mode so an explicitly provided string is passed
        through untouched. The POSTGRES_* fields are declared earlier in the
        class, so pydantic has already validated them and they are available
        via ``info.data``.
        """
        if isinstance(v, str):
            return v
        return PostgresDsn.build(
            scheme="postgresql",
            username=info.data.get("POSTGRES_USER"),
            password=info.data.get("POSTGRES_PASSWORD"),
            host=info.data.get("POSTGRES_SERVER"),
            path=f"{info.data.get('POSTGRES_DB') or ''}",
        )

    # Environment variables are matched case-sensitively and must carry the
    # "BQ_" prefix, e.g. BQ_DATABASE_URL.
    model_config = SettingsConfigDict(case_sensitive=True, env_prefix="BQ_")
bq/container.py ADDED
@@ -0,0 +1,54 @@
1
+ import functools
2
+ import typing
3
+
4
+ from dependency_injector import containers
5
+ from dependency_injector import providers
6
+ from sqlalchemy import create_engine
7
+ from sqlalchemy import Engine
8
+ from sqlalchemy.orm import Session as DBSession
9
+ from sqlalchemy.pool import SingletonThreadPool
10
+
11
+ from .config import Config
12
+ from .db.session import SessionMaker
13
+ from .services.dispatch import DispatchService
14
+ from .services.worker import WorkerService
15
+
16
+
17
def make_db_engine(config: Config) -> Engine:
    """Create the SQLAlchemy engine for the configured DATABASE_URL.

    SingletonThreadPool maintains one connection per thread, so each worker
    thread gets its own connection.
    """
    database_url = str(config.DATABASE_URL)
    return create_engine(database_url, poolclass=SingletonThreadPool)
19
+
20
+
21
def make_session_factory(engine: Engine) -> typing.Callable:
    """Return a zero-argument callable that creates sessions bound to *engine*."""
    return functools.partial(SessionMaker, bind=engine)
23
+
24
+
25
def make_session(factory: typing.Callable) -> DBSession:
    """Invoke the session factory once and return the resulting session."""
    session = factory()
    return session
27
+
28
+
29
def make_dispatch_service(session: DBSession) -> DispatchService:
    """Build a DispatchService bound to the shared session."""
    service = DispatchService(session)
    return service
31
+
32
+
33
def make_worker_service(session: DBSession) -> WorkerService:
    """Build a WorkerService bound to the shared session."""
    service = WorkerService(session)
    return service
35
+
36
+
37
class Container(containers.DeclarativeContainer):
    """Dependency-injection container wiring config, engine, session and services.

    Every provider is a Singleton: one engine, one session factory, one shared
    session, and one instance of each service per container instance.
    """

    config = providers.Singleton(Config)

    db_engine: Engine = providers.Singleton(make_db_engine, config=config)

    session_factory: typing.Callable = providers.Singleton(
        make_session_factory, engine=db_engine
    )

    # NOTE(review): because this is a Singleton, every injection point in the
    # process shares the same session object.
    session: DBSession = providers.Singleton(make_session, factory=session_factory)

    dispatch_service: DispatchService = providers.Singleton(
        make_dispatch_service, session=session
    )

    worker_service: WorkerService = providers.Singleton(
        make_worker_service, session=session
    )
bq/db/__init__.py ADDED
File without changes
bq/db/base.py ADDED
@@ -0,0 +1,5 @@
1
+ from sqlalchemy.orm import DeclarativeBase
2
+
3
+
4
class Base(DeclarativeBase):
    """Declarative base shared by all BeanQueue ORM models."""

    pass
bq/db/session.py ADDED
@@ -0,0 +1,5 @@
1
+ from sqlalchemy.orm import scoped_session
2
+ from sqlalchemy.orm import sessionmaker
3
+
4
# Module-level session factory; it is bound to an engine at runtime
# (see the container's session_factory provider, which passes bind=engine).
SessionMaker = sessionmaker()
# Scoped (thread-local) session registry built on top of the factory.
Session = scoped_session(SessionMaker)
bq/models/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .task import Task
2
+ from .task import TaskState
3
+ from .worker import Worker
4
+ from .worker import WorkerState
bq/models/helpers.py ADDED
@@ -0,0 +1,5 @@
1
+ import typing
2
+
3
+
4
def make_repr_attrs(items: typing.Sequence[typing.Tuple[str, typing.Any]]) -> str:
    """Format (name, value) pairs as a space-separated ``name=value`` string."""
    return " ".join(f"{name}={value}" for name, value in items)
bq/models/task.py ADDED
@@ -0,0 +1,120 @@
1
+ import enum
2
+
3
+ from sqlalchemy import Column
4
+ from sqlalchemy import Connection
5
+ from sqlalchemy import DateTime
6
+ from sqlalchemy import Enum
7
+ from sqlalchemy import event
8
+ from sqlalchemy import ForeignKey
9
+ from sqlalchemy import func
10
+ from sqlalchemy import inspect
11
+ from sqlalchemy import String
12
+ from sqlalchemy.dialects.postgresql import JSONB
13
+ from sqlalchemy.dialects.postgresql import UUID
14
+ from sqlalchemy.orm import Mapper
15
+ from sqlalchemy.orm import relationship
16
+
17
+ from ..db.base import Base
18
+ from .helpers import make_repr_attrs
19
+
20
+
21
class TaskState(enum.Enum):
    """Lifecycle states of a task row."""

    # task just created, not scheduled yet
    PENDING = "PENDING"
    # a worker is processing the task right now
    PROCESSING = "PROCESSING"
    # the task is done
    DONE = "DONE"
    # the task is failed
    FAILED = "FAILED"
30
+
31
+
32
class Task(Base):
    """A queued unit of work: which function to call, with what arguments.

    Rows move PENDING -> PROCESSING -> DONE/FAILED; the after_insert /
    after_update listeners below emit a Postgres NOTIFY on the task's channel
    whenever it becomes PENDING.
    """

    # primary key generated by Postgres (gen_random_uuid())
    id = Column(
        UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()
    )
    # foreign key id of assigned worker
    worker_id = Column(
        UUID(as_uuid=True),
        ForeignKey("bq_workers.id", name="fk_workers_id"),
        nullable=True,
    )
    # current state of the task
    state = Column(
        Enum(TaskState),
        nullable=False,
        default=TaskState.PENDING,
        server_default=TaskState.PENDING.value,
        index=True,
    )
    # channel for workers and job creator to listen/notify
    channel = Column(String, nullable=False, index=True)
    # module of the processor function
    module = Column(String, nullable=False)
    # func name of the processor func
    func_name = Column(String, nullable=False)
    # keyword arguments
    kwargs = Column(JSONB, nullable=True)
    # Result of the task
    result = Column(JSONB, nullable=True)
    # Error message
    error_message = Column(String, nullable=True)
    # created datetime of the task
    created_at = Column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )

    # the worker currently assigned to this task, if any
    worker = relationship("Worker", back_populates="tasks", uselist=False)

    __tablename__ = "bq_tasks"

    def __repr__(self) -> str:
        items = [
            ("id", self.id),
            ("state", self.state),
            ("channel", self.channel),
        ]
        return f"<{self.__class__.__name__} {make_repr_attrs(items)}>"
78
+
79
+
80
def notify_if_needed(connection: Connection, task: Task):
    """Emit a Postgres NOTIFY on the task's channel, at most once per transaction.

    The set of already-notified channels is stashed as an attribute on the
    current session transaction object, so many inserts/updates within one
    transaction produce a single NOTIFY per channel.
    """
    session = inspect(task).session
    transaction = session.get_transaction()
    if transaction is not None:
        key = "_notified_channels"
        if not hasattr(transaction, key):
            setattr(transaction, key, set())
        notified_channels = getattr(transaction, key)
        if task.channel in notified_channels:
            # This channel was already notified within this transaction.
            return
        notified_channels.add(task.channel)

    # Quote the channel so arbitrary channel strings cannot break the statement.
    quoted_channel = connection.dialect.identifier_preparer.quote_identifier(
        task.channel
    )
    connection.exec_driver_sql(f"NOTIFY {quoted_channel}")
100
+
101
+
102
@event.listens_for(Task, "after_insert")
def task_insert_notify(mapper: Mapper, connection: Connection, target: Task):
    """After inserting a Task, notify its channel if it was created PENDING."""
    from .. import models

    # Only freshly-pending tasks are interesting to workers.
    if target.state == models.TaskState.PENDING:
        notify_if_needed(connection, target)
109
+
110
+
111
@event.listens_for(Task, "after_update")
def task_update_notify(mapper: Mapper, connection: Connection, target: Task):
    """After an update, notify the channel when the state changed to PENDING."""
    from .. import models

    # Only fire when the state column actually changed in this flush.
    history = inspect(target).attrs.state.history
    if not history.has_changes():
        return
    if target.state == models.TaskState.PENDING:
        notify_if_needed(connection, target)
bq/models/worker.py ADDED
@@ -0,0 +1,69 @@
1
+ import enum
2
+
3
+ from sqlalchemy import Column
4
+ from sqlalchemy import DateTime
5
+ from sqlalchemy import Enum
6
+ from sqlalchemy import func
7
+ from sqlalchemy import String
8
+ from sqlalchemy.dialects.postgresql import ARRAY
9
+ from sqlalchemy.dialects.postgresql import UUID
10
+ from sqlalchemy.orm import relationship
11
+
12
+ from ..db.base import Base
13
+ from .helpers import make_repr_attrs
14
+
15
+
16
class WorkerState(enum.Enum):
    """Lifecycle states of a worker row."""

    # the worker is running
    RUNNING = "RUNNING"
    # the worker shuts down normally
    SHUTDOWN = "SHUTDOWN"
    # The worker has no heartbeat for a while
    NO_HEARTBEAT = "NO_HEARTBEAT"
23
+
24
+
25
class Worker(Base):
    """A task-processing worker process registered in the database.

    Workers periodically refresh ``last_heartbeat``; peers flag a worker
    NO_HEARTBEAT and reschedule its tasks when the heartbeat goes stale.
    """

    # primary key generated by Postgres (gen_random_uuid())
    id = Column(
        UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()
    )
    # current state of the worker
    state = Column(
        Enum(WorkerState),
        nullable=False,
        default=WorkerState.RUNNING,
        server_default=WorkerState.RUNNING.value,
        index=True,
    )
    # name of the worker
    name = Column(String, nullable=False)
    # the channels we are processing
    channels = Column(ARRAY(String), nullable=False)
    # last heartbeat of this worker
    last_heartbeat = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        index=True,
    )
    # created datetime of the worker
    created_at = Column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )

    # tasks currently assigned to this worker, oldest first
    tasks = relationship(
        "Task",
        back_populates="worker",
        cascade="all,delete",
        order_by="Task.created_at",
    )

    __tablename__ = "bq_workers"

    def __repr__(self) -> str:
        items = [
            ("id", self.id),
            ("name", self.name),
            ("channels", self.channels),
            ("state", self.state),
        ]
        return f"<{self.__class__.__name__} {make_repr_attrs(items)}>"
File without changes
@@ -0,0 +1,136 @@
1
+ import collections
2
+ import dataclasses
3
+ import inspect
4
+ import logging
5
+ import typing
6
+
7
+ import venusian
8
+ from sqlalchemy.orm import object_session
9
+
10
+ from bq import models
11
+
12
+ BQ_PROCESSOR_CATEGORY = "bq_processor"
13
+
14
+
15
@dataclasses.dataclass(frozen=True)
class Processor:
    """Immutable registration record for a task-processing function."""

    # channel the processor consumes tasks from
    channel: str
    # dotted module path of the processor function
    module: str
    # function name of the processor
    name: str
    # the callable that actually processes a task
    func: typing.Callable
    # should we auto complete the task or not
    auto_complete: bool = True
    # should we auto rollback the transaction when encounter unhandled exception
    auto_rollback_on_exc: bool = True
25
+
26
+
27
class ProcessorHelper:
    """Callable wrapper returned by the @processor decorator.

    Calling the helper invokes the wrapped function directly; ``run`` builds a
    Task row (not added to any session) targeting the processor.
    """

    def __init__(self, processor: Processor, task_cls: typing.Type = models.Task):
        self._proc = processor
        self._task_factory = task_cls

    def __call__(self, *args, **kwargs):
        # Transparent pass-through to the underlying function.
        return self._proc.func(*args, **kwargs)

    def run(self, **kwargs) -> models.Task:
        """Build a Task for this processor carrying the given keyword arguments."""
        proc = self._proc
        return self._task_factory(
            channel=proc.channel,
            module=proc.module,
            func_name=proc.name,
            kwargs=kwargs,
        )
42
+
43
+
44
def process_task(task: models.Task, processor: Processor):
    """Execute *processor* for *task* inside a savepoint and record the outcome.

    The processor function may optionally declare ``task`` and/or ``db``
    parameters; they are injected automatically. The call runs inside a nested
    transaction (savepoint) so a failure can be rolled back without discarding
    the task-state bookkeeping performed afterwards.

    Returns the processor's return value, or None when it raised.
    """
    logger = logging.getLogger(__name__)
    db = object_session(task)
    func_signature = inspect.signature(processor.func)
    base_kwargs = {}
    # Inject well-known parameters only when the processor asks for them.
    if "task" in func_signature.parameters:
        base_kwargs["task"] = task
    if "db" in func_signature.parameters:
        base_kwargs["db"] = db
    with db.begin_nested() as savepoint:
        try:
            # task.kwargs is a nullable JSONB column; treat NULL as "no
            # arguments" instead of crashing on ** expansion of None.
            result = processor.func(**base_kwargs, **(task.kwargs or {}))
            savepoint.commit()
        except Exception as exc:
            logger.error("Unhandled exception for task %s", task.id, exc_info=True)
            if processor.auto_rollback_on_exc:
                # Undo whatever the processor changed inside the savepoint.
                savepoint.rollback()
            # TODO: add error event
            task.state = models.TaskState.FAILED
            task.error_message = str(exc)
            db.add(task)
            return
    if processor.auto_complete:
        logger.info("Task %s auto complete", task.id)
        task.state = models.TaskState.DONE
        task.result = result
        db.add(task)
    return result
72
+
73
+
74
class Registry:
    """In-memory index of processors keyed by channel, module and function name."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # channel -> module -> function name -> Processor
        self.processors = collections.defaultdict(lambda: collections.defaultdict(dict))

    def add(self, processor: Processor):
        """Register a processor under its channel/module/name coordinates."""
        channel_map = self.processors[processor.channel]
        channel_map[processor.module][processor.name] = processor

    def process(self, task: models.Task) -> typing.Any:
        """Find and run the processor for *task*; mark the task FAILED if none exists."""
        processor = (
            self.processors.get(task.channel, {})
            .get(task.module, {})
            .get(task.func_name)
        )
        db = object_session(task)
        if processor is None:
            self.logger.error(
                "Cannot find processor for task %s with module=%s, func=%s",
                task.id,
                task.module,
                task.func_name,
            )
            # TODO: add error event
            task.state = models.TaskState.FAILED
            task.error_message = f"Cannot find processor for task with module={task.module}, func={task.func_name}"
            db.add(task)
            return
        return process_task(task, processor)
100
+
101
+
102
def processor(
    channel: str,
    auto_complete: bool = True,
    auto_rollback_on_exc: bool = True,
    task_cls: typing.Type = models.Task,
) -> typing.Callable:
    """Decorator registering a function as a task processor on *channel*.

    The decorated function is replaced by a ProcessorHelper, so callers can
    still invoke it directly and can build Task rows via ``.run(...)``.
    Registration is deferred: venusian records a callback that only runs when
    ``collect()`` scans the defining package.
    """

    def decorator(wrapped: typing.Callable):
        processor = Processor(
            module=wrapped.__module__,
            name=wrapped.__name__,
            channel=channel,
            func=wrapped,
            auto_complete=auto_complete,
            auto_rollback_on_exc=auto_rollback_on_exc,
        )
        helper_obj = ProcessorHelper(processor, task_cls=task_cls)

        def callback(scanner: venusian.Scanner, name: str, ob: typing.Callable):
            # Sanity check: the attribute name found by the scanner should
            # match the wrapped function's own name.
            if processor.name != name:
                raise ValueError("Name is not the same")
            scanner.registry.add(processor)

        venusian.attach(helper_obj, callback, category=BQ_PROCESSOR_CATEGORY)
        return helper_obj

    return decorator
128
+
129
+
130
def collect(packages: list[typing.Any], registry: Registry | None = None) -> Registry:
    """Scan *packages* for @processor-decorated functions and register them.

    A fresh Registry is created when none is supplied; the (possibly new)
    registry is returned.
    """
    target = registry if registry is not None else Registry()
    scanner = venusian.Scanner(registry=target)
    for pkg in packages:
        scanner.scan(pkg, categories=(BQ_PROCESSOR_CATEGORY,))
    return target
File without changes
@@ -0,0 +1,96 @@
1
+ import dataclasses
2
+ import select
3
+ import typing
4
+ import uuid
5
+
6
+ from sqlalchemy.orm import Query
7
+
8
+ from .. import models
9
+ from ..db.session import Session
10
+
11
+
12
@dataclasses.dataclass(frozen=True)
class Notification:
    """A Postgres LISTEN/NOTIFY notification received from the driver."""

    # backend process id that sent the notification
    pid: int
    # channel the notification was delivered on
    channel: str
    # optional payload string attached to the NOTIFY
    payload: typing.Optional[str] = None
17
+
18
+
19
class DispatchService:
    """Dispatches pending tasks to workers using SKIP LOCKED plus LISTEN/NOTIFY."""

    def __init__(self, session: Session):
        self.session = session

    def make_task_query(self, channels: typing.Sequence[str], limit: int = 1) -> Query:
        """Query ids of up to *limit* oldest PENDING tasks in *channels*.

        SKIP LOCKED lets concurrent workers claim disjoint rows without
        blocking each other.
        """
        return (
            self.session.query(models.Task.id)
            .filter(models.Task.channel.in_(channels))
            .filter(models.Task.state == models.TaskState.PENDING)
            .order_by(models.Task.created_at)
            .limit(limit)
            .with_for_update(skip_locked=True)
        )

    def make_update_query(self, task_query: typing.Any, worker_id: uuid.UUID):
        """UPDATE statement claiming the tasks selected by *task_query* for *worker_id*."""
        return (
            models.Task.__table__.update()
            .where(models.Task.id.in_(task_query))
            .values(
                state=models.TaskState.PROCESSING,
                worker_id=worker_id,
            )
            .returning(models.Task.id)
        )

    def dispatch(
        self, channels: typing.Sequence[str], worker_id: uuid.UUID, limit: int = 1
    ) -> Query:
        """Atomically claim up to *limit* pending tasks; return a query of the claimed rows."""
        task_query = self.make_task_query(channels, limit=limit)
        task_subquery = task_query.scalar_subquery()
        task_ids = [
            item[0]
            for item in self.session.execute(
                self.make_update_query(task_subquery, worker_id=worker_id)
            )
        ]
        # TODO: ideally returning with (models.Task) should return the whole model, but SQLAlchemy is returning
        # it columns in rows. We can save a round trip if we can find out how to solve this
        return self.session.query(models.Task).filter(models.Task.id.in_(task_ids))

    def listen(self, channels: typing.Sequence[str]):
        """Issue LISTEN for each channel on this session's connection."""
        conn = self.session.connection()
        for channel in channels:
            quoted_channel = conn.dialect.identifier_preparer.quote_identifier(channel)
            conn.exec_driver_sql(f"LISTEN {quoted_channel}")

    def poll(self, timeout: int = 5) -> typing.Generator[Notification, None, None]:
        """Yield pending notifications, waiting up to *timeout* seconds for new ones.

        Raises TimeoutError when nothing arrives within *timeout*.
        NOTE(review): relies on a psycopg2-style driver connection exposing
        ``poll()`` and a ``notifies`` list — confirm if the DB driver changes.
        """
        conn = self.session.connection()
        driver_conn = conn.connection.driver_connection

        def pop_notifies():
            # Drain the driver's notification list front-to-back.
            while driver_conn.notifies:
                notify = driver_conn.notifies.pop(0)
                yield Notification(
                    pid=notify.pid,
                    channel=notify.channel,
                    payload=notify.payload,
                )

        # poll first to see if there's anything already
        driver_conn.poll()
        if driver_conn.notifies:
            yield from pop_notifies()
        else:
            # okay, nothing, let's select and wait for new stuff
            if select.select([driver_conn], [], [], timeout) == ([], [], []):
                # nope, nothing, times out
                raise TimeoutError("Timeout waiting for new notifications")
            else:
                # yep, we got something
                driver_conn.poll()
                yield from pop_notifies()

    def notify(self, channels: typing.Sequence[str]):
        """Issue NOTIFY on each channel to wake up listening workers."""
        conn = self.session.connection()
        for channel in channels:
            quoted_channel = conn.dialect.identifier_preparer.quote_identifier(channel)
            conn.exec_driver_sql(f"NOTIFY {quoted_channel}")
bq/services/worker.py ADDED
@@ -0,0 +1,69 @@
1
+ import datetime
2
+ import typing
3
+
4
+ from sqlalchemy import func
5
+ from sqlalchemy.orm import Query
6
+ from sqlalchemy.orm import Session
7
+
8
+ from .. import models
9
+
10
+
11
class WorkerService:
    """Database operations for worker heartbeat tracking and dead-worker cleanup."""

    def __init__(self, session: Session):
        self.session = session

    def update_heartbeat(self, worker: models.Worker):
        """Refresh the worker's heartbeat using the database clock (func.now())."""
        worker.last_heartbeat = func.now()
        self.session.add(worker)

    def make_dead_worker_query(self, timeout: int, limit: int = 5) -> Query:
        """Query ids of RUNNING workers whose heartbeat is older than *timeout* seconds.

        SKIP LOCKED keeps concurrent reapers from blocking on the same rows.
        """
        return (
            self.session.query(models.Worker.id)
            .filter(
                models.Worker.last_heartbeat
                < (func.now() - datetime.timedelta(seconds=timeout))
            )
            .filter(models.Worker.state == models.WorkerState.RUNNING)
            .limit(limit)
            .with_for_update(skip_locked=True)
        )

    def make_update_dead_worker_query(self, worker_query: typing.Any):
        """UPDATE statement flagging the selected workers as NO_HEARTBEAT."""
        return (
            models.Worker.__table__.update()
            .where(models.Worker.id.in_(worker_query))
            .values(
                state=models.WorkerState.NO_HEARTBEAT,
            )
            .returning(models.Worker.id)
        )

    def fetch_dead_workers(self, timeout: int, limit: int = 5) -> Query:
        """Mark up to *limit* timed-out workers NO_HEARTBEAT; return a query of them."""
        dead_worker_query = self.make_dead_worker_query(timeout=timeout, limit=limit)
        dead_worker_subquery = dead_worker_query.scalar_subquery()
        worker_ids = [
            item[0]
            for item in self.session.execute(
                self.make_update_dead_worker_query(dead_worker_subquery)
            )
        ]
        # TODO: ideally returning with (models.Task) should return the whole model, but SQLAlchemy is returning
        # it columns in rows. We can save a round trip if we can find out how to solve this
        return self.session.query(models.Worker).filter(
            models.Worker.id.in_(worker_ids)
        )

    def make_update_tasks_query(self, worker_query: typing.Any):
        """UPDATE statement moving PROCESSING tasks of the given workers back to PENDING."""
        return (
            models.Task.__table__.update()
            .where(models.Task.worker_id.in_(worker_query))
            .where(models.Task.state == models.TaskState.PROCESSING)
            .values(
                state=models.TaskState.PENDING,
            )
        )

    def reschedule_dead_tasks(self, worker_query: typing.Any) -> int:
        """Reschedule PROCESSING tasks owned by the given workers; return the affected row count."""
        update_dead_task_query = self.make_update_tasks_query(worker_query=worker_query)
        res = self.session.execute(update_dead_task_query)
        return res.rowcount