workercommon 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,455 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+
4
if not sys.version_info >= (3, 8):
    # Fail fast on interpreters older than the minimum this module supports.
    raise RuntimeError("At least Python 3.8 is required")
6
+
7
+ import logging
8
+ import json
9
+ import pika.connection
10
+ import pika.channel
11
+ from pika import (
12
+ BasicProperties,
13
+ BlockingConnection,
14
+ SelectConnection,
15
+ URLParameters,
16
+ )
17
+ from pika.spec import PERSISTENT_DELIVERY_MODE, TRANSIENT_DELIVERY_MODE
18
+ from concurrent.futures import Future
19
+ from functools import partial
20
+ from threading import RLock, Lock, Thread
21
+ from typing import (
22
+ Any,
23
+ Callable,
24
+ Iterable,
25
+ List,
26
+ NamedTuple,
27
+ Optional,
28
+ Tuple,
29
+ )
30
+ from urllib.parse import quote as urlquote
31
+ from datetime import datetime
32
+ from collections import namedtuple
33
+ from copy import copy
34
+
35
+
36
class Parameters(NamedTuple):
    """Connection settings plus queue/exchange topology for one RabbitMQ queue.

    ``mqparams`` is the pika connection parameter object; the other fields
    name the queue, exchange and routing key and control durability and the
    consumer prefetch count.
    """

    mqparams: pika.connection.Parameters
    queue: str
    exchange: str
    durable: bool
    prefetch: int
    routing_key: str

    @classmethod
    def of(
        cls,
        host: str,
        port: int,
        queue: str,
        exchange: Optional[str] = None,
        durable: bool = False,
        username: str = "guest",
        password: str = "guest",
        prefetch: int = 100,
        vhost: str = "/",
        routing_key: str = "",
    ) -> "Parameters":
        """Build ``Parameters`` from individual connection settings.

        Username, password and vhost are fully percent-encoded before being
        placed in the ``amqp://`` URL, so reserved characters in them
        (including ``/``, ``:`` and ``@``) cannot corrupt the URL.
        A falsy ``exchange`` (``None`` or ``""``) is stored as ``""``.
        """
        # safe="" so "/" is encoded as well: quote()'s default safe="/" would
        # leave it raw and end the URL's netloc in the middle of a credential.
        cleaned_username = urlquote(username, safe="")
        cleaned_password = urlquote(password, safe="")
        cleaned_vhost = urlquote(vhost, safe="")

        parameters = URLParameters(
            f"amqp://{cleaned_username}:{cleaned_password}@{host}:{port}/{cleaned_vhost}"
        )
        cleaned_exchange = exchange if exchange else ""
        return cls(parameters, queue, cleaned_exchange, durable, prefetch, routing_key)
67
+
68
+
69
# Shorthand alias for ``Parameters.of`` (same signature and return value).
build_parameters = Parameters.of
70
+
71
+
72
class BaseQueue(object):
    """Shared JSON publishing logic for the queue implementations.

    Subclasses supply the channel, the parameters, a way to schedule a
    callback (``add_callback_threadsafe``) and the event loop (``run``).
    """

    def get_channel(self) -> pika.channel.Channel:
        """Subclass hook: return the channel to publish on."""
        raise NotImplementedError

    def get_parameters(self) -> Parameters:
        """Subclass hook: return the queue parameters."""
        raise NotImplementedError

    @classmethod
    def serialize_for_json(cls, obj: Any) -> str:
        """``json.dumps`` fallback: datetimes become ISO-8601 strings.

        Raises:
            TypeError: for any other unsupported object.
        """
        if not isinstance(obj, datetime):
            raise TypeError(f"Cannot serialize: {obj}")
        return obj.isoformat()

    def send_callback(self, message: Any, persist: bool = False) -> None:
        """Publish ``message`` as JSON on ``get_channel()`` right now.

        ``persist`` selects the persistent rather than transient delivery mode.
        """
        mode = PERSISTENT_DELIVERY_MODE if persist else TRANSIENT_DELIVERY_MODE
        props = BasicProperties(content_type="application/json", delivery_mode=mode)
        params = self.get_parameters()
        channel = self.get_channel()
        payload = json.dumps(message, default=self.serialize_for_json)
        channel.basic_publish(
            exchange=params.exchange,
            routing_key=params.routing_key,
            body=payload,
            properties=props,
        )

    def send(self, message: Any, persist: bool = False) -> None:
        """Schedule ``send_callback`` via ``add_callback_threadsafe`` and wait.

        Blocks for at most 60 seconds; the Future re-raises any publish error
        (or raises a timeout error).
        """
        task = partial(self.send_callback, message, persist)
        pending = self.add_callback_threadsafe(task)
        pending.result(60)

    def add_callback_threadsafe(self, callback: Callable[[], Any]) -> Future:
        """Subclass hook: arrange for ``callback`` to run; return its Future."""
        raise NotImplementedError

    def run(self) -> None:
        """Subclass hook: drive the connection's event loop."""
        raise NotImplementedError

    @staticmethod
    def futurify(callback: Callable[[], Any]) -> Tuple[Callable[[], None], Future]:
        """Pair ``callback`` with a Future that receives its outcome.

        Returns ``(runner, future)``: calling ``runner()`` invokes ``callback``
        and stores its return value, or whatever it raised, in ``future``.
        """
        outcome: Future = Future()

        def runner():
            try:
                outcome.set_result(callback())
            except BaseException as exc:
                outcome.set_exception(exc)

        return runner, outcome
125
+
126
+
127
class ToSend(NamedTuple):
    """A message awaiting publication together with its persistence flag."""

    message: Any
    persist: bool
128
+
129
+
130
class SendQueue(BaseQueue):
    """Publisher built on a pika ``BlockingConnection``.

    The inherited ``send`` schedules the publish through
    ``add_callback_threadsafe`` and waits up to 60 s for it, so some other
    thread must be executing ``run`` meanwhile. ``quick_send`` instead
    publishes directly on the calling thread.

    NOTE(review): ``close`` uses the connection from the caller's thread even
    while ``run`` may be active on another; pika documents only
    ``add_callback_threadsafe`` as thread-safe on ``BlockingConnection``.
    """

    @classmethod
    def build_channel(
        cls,
        connection: pika.connection.Connection,
        parameters: Parameters,
        **queue_arguments: Any,
    ) -> pika.channel.Channel:
        """Open a channel, declare the exchange and queue, and bind them.

        ``queue_arguments`` is forwarded as ``queue_declare``'s ``arguments``.
        """
        channel = connection.channel()
        channel.exchange_declare(
            exchange=parameters.exchange, durable=parameters.durable
        )
        channel.queue_declare(
            queue=parameters.queue,
            durable=parameters.durable,
            arguments=queue_arguments,
        )
        channel.queue_bind(
            queue=parameters.queue,
            exchange=parameters.exchange,
            routing_key=parameters.routing_key,
        )

        return channel

    def __init__(self, parameters: Parameters, **queue_arguments: Any):
        """Connect and set up the channel; on failure, clean up and re-raise."""
        self.lock = RLock()
        self.parameters = parameters
        self.connection: Optional[BlockingConnection] = None
        self.channel: Optional[pika.channel.Channel] = None
        try:
            self.connection = BlockingConnection(self.parameters.mqparams)
            self.channel = self.build_channel(
                self.connection, self.parameters, **queue_arguments
            )
        except BaseException:
            # A failure while cleaning up must not hide the original error.
            try:
                self.close()
            except BaseException:
                logging.exception("Failed to close connection after setup error")
            raise

    def get_channel(self) -> pika.channel.Channel:
        """Return the channel, or raise ``RuntimeError`` once closed."""
        if self.channel is None:
            raise RuntimeError("Channel is not available")
        return self.channel

    def get_parameters(self) -> Parameters:
        """Return the parameters this queue was created with."""
        return self.parameters

    def get_connection(self) -> BlockingConnection:
        """Return the connection, or raise ``RuntimeError`` once closed."""
        with self.lock:
            if self.connection is None:
                raise RuntimeError("Connection is not available")
            return self.connection

    def close(self) -> None:
        """Close the connection (if any) and drop the connection/channel.

        The references are cleared before ``connection.close()`` is called, so
        even if it raises (e.g. pika rejects closing an already-closed
        connection) later calls see a closed queue rather than a dead
        connection. Any such error still propagates.
        """
        with self.lock:
            connection = self.connection
            self.connection = None
            self.channel = None
            if connection is not None:
                connection.close()

    @classmethod
    def quick_send(cls, parameters: Parameters, messages_iter: Iterable[ToSend]) -> None:
        """Connect, publish every message in order, then disconnect.

        Messages are published directly on the calling thread: nothing runs
        this connection's event loop here, so the inherited ``send`` (which
        waits for a callback that only that loop executes) could only time
        out. On failure the unsent messages are logged and the exception is
        re-raised; the connection is closed in every case.
        """
        messages = list(messages_iter)
        if not messages:
            return
        sent = 0
        queue: Optional[SendQueue] = None
        try:
            queue = cls(parameters)
            for item in messages:
                queue.send_callback(item.message, item.persist)
                sent += 1
        except BaseException:
            logging.exception(f"Failed to send messages: {messages[sent:]}")
            raise
        finally:
            if queue is not None:
                queue.close()

    def add_callback_threadsafe(self, callback: Callable[[], Any]) -> Future:
        """Queue ``callback`` on the connection; return a Future for its outcome.

        The callback runs when the connection processes data events (see
        ``run``). Scheduling errors are reported through the Future.
        """
        wrapped, future = self.futurify(callback)
        try:
            self.get_connection().add_callback_threadsafe(wrapped)
        except BaseException as e:
            future.set_exception(e)
        return future

    def run(self) -> None:
        """Process connection events forever (until the connection fails)."""
        logging.info(f"Starting data events for {self.get_connection()}")
        while True:
            self.get_connection().process_data_events(None)
221
+
222
+
223
class SendQueueWrapper(object):
    """Wraps a SendQueue and recreates it on error.

    The wrapped queue's ``run`` loop executes on a background thread so that
    its ``send`` (which waits on that loop) can complete. ``send`` retries up
    to ``TRIES`` times, rebuilding the queue and its thread after a failure.
    """

    TRIES = 3

    def __init__(self, parameters: Parameters, **queue_arguments: Any):
        self.lock = Lock()
        self.parameters = parameters
        self.queue_arguments = queue_arguments
        self.queue: Optional[SendQueue] = SendQueue(self.parameters, **self.queue_arguments)
        try:
            self.queue_thread: Optional[Thread] = self.create_queue_thread(self.queue)
        except BaseException:
            # Don't leak the connection just opened if the thread can't start.
            try:
                self.queue.close()
            except BaseException:
                logging.exception("Failed to close queue")
            raise
        self.closed = False

    def close(self) -> None:
        """Mark closed, close the queue, and wait (up to 10 s) for its thread.

        The queue reference is dropped and the thread joined even when closing
        the queue raises; that error is then propagated.
        """
        with self.lock:
            self.closed = True
            try:
                if self.queue is not None:
                    self.queue.close()
            finally:
                self.queue = None
                self.join_queue_thread()

    @staticmethod
    def create_queue_thread(queue: SendQueue) -> Thread:
        """Start and return a thread running ``queue.run``."""
        queue_thread = Thread(target=queue.run)
        queue_thread.start()
        return queue_thread

    def join_queue_thread(self) -> None:
        "Must be called with self.lock held!"
        if self.queue_thread is not None:
            try:
                self.queue_thread.join(10)
            except BaseException:
                logging.exception(f"Failed to wait for queue thread {self.queue_thread}")
            self.queue_thread = None

    def send(self, message: Any, persist: bool = False) -> None:
        """Publish ``message``, reconnecting and retrying up to ``TRIES`` times.

        Raises ``RuntimeError`` if the wrapper is closed. If every attempt
        fails, the failure is logged and the call returns without raising.
        """
        for i in range(self.TRIES):
            with self.lock:
                if self.closed:
                    raise RuntimeError("Queue is closed")
                try:
                    if self.queue is None:
                        self.queue = SendQueue(self.parameters, **self.queue_arguments)
                        self.join_queue_thread()
                        self.queue_thread = self.create_queue_thread(self.queue)
                        logging.info("Re-established broken connection")

                    self.queue.send(message, persist)
                    break

                except BaseException:
                    logging.exception("Failed to send. Attempting to reconnect")
                    if self.queue is not None:
                        try:
                            self.queue.close()
                        except BaseException:
                            logging.warning("Failed to close queue")
                        self.queue = None

                    self.join_queue_thread()

        else:
            # The loop never hit ``break``: every attempt failed.
            logging.error(f"Failed to send after {self.TRIES} tries: {message}")
287
+
288
+
289
# Consumer callback handed to ``basic_consume``: (channel, method, properties,
# body). pika delivers the message body as ``bytes``.
Callback = Callable[
    [pika.channel.Channel, pika.spec.Basic.Deliver, pika.spec.BasicProperties, bytes],
    None,
]
# Disconnect callback: (unexpected, outbox). The flag is True when the
# connection closed without ``close()`` having been requested; the list is a
# copy of the messages not yet published.
DisconnectCallback = Callable[[bool, List[ToSend]], None]
293
+
294
+
295
class ReceiveSendQueue(BaseQueue):
    """Consumer and publisher on an asynchronous pika ``SelectConnection``.

    ``run`` drives the connection's ioloop on the calling thread. Once the
    channel is set up, pending outbox messages are published and
    ``on_message`` is registered as the consumer. A message passed to
    ``send`` stays in the outbox until it has been published;
    ``on_disconnect`` receives a copy of whatever is left, and
    ``add_to_outbox`` can queue such items on an instance before ``run``.
    """

    def __init__(
        self,
        parameters: Parameters,
        on_message: Callback,
        on_disconnect: DisconnectCallback,
        **queue_arguments: Any,
    ):
        self.lock = RLock()
        self.outbox: List[ToSend] = []
        self.parameters = parameters
        self.on_message = on_message
        self.on_disconnect = on_disconnect
        self.terminated = False
        self.queue_arguments = queue_arguments
        self.connection: Optional[pika.connection.Connection] = None
        self.channel: Optional[pika.channel.Channel] = None
        # Identity of the thread currently inside ``run`` (None when idle).
        self._ioloop_thread: Optional[int] = None
        self.init()

    def add_to_outbox(self, items: Iterable[ToSend]) -> None:
        """Queue ``items`` to be published once the channel is ready."""
        with self.lock:
            self.outbox.extend(items)

    def get_outbox(self) -> List[ToSend]:
        """Return a copy of the messages not yet published."""
        with self.lock:
            return copy(self.outbox)

    def init(self):
        """Create the SelectConnection; on failure, log, close and re-raise."""
        with self.lock:
            try:
                self.terminated = False
                self.connection = None
                self.channel = None
                self.connection = SelectConnection(
                    parameters=self.parameters.mqparams,
                    on_open_callback=self.on_connection_open,
                    on_close_callback=self.on_connection_closed,
                )
            except BaseException:
                logging.exception("Failed to initialize connection")
                self.close()
                raise

    def recovery(self, channel: pika.channel.Channel) -> None:
        """Publish all outbox messages, oldest first, synchronously.

        Must run on the ioloop thread (it is called from ``on_channel_ready``
        and, via ``add_callback_threadsafe``, from ``close``). It publishes via
        ``send_callback`` instead of ``send``: ``send`` waits for a callback
        scheduled on the ioloop, which would block the very loop that has to
        run it. Messages are removed only after being published, so on error
        the unsent ones remain at the head of the outbox. ``channel`` is not
        used; publishing goes through ``get_channel()``.
        """
        with self.lock:
            published = 0
            try:
                for pending in self.outbox:
                    self.send_callback(pending.message, pending.persist)
                    published += 1
            finally:
                del self.outbox[:published]

    def get_channel(self) -> pika.channel.Channel:
        """Return the channel, or raise ``RuntimeError`` if there is none."""
        with self.lock:
            if self.channel is None:
                raise RuntimeError("Channel is not available")
            return self.channel

    def get_parameters(self) -> Parameters:
        """Return the parameters this queue was created with."""
        return self.parameters

    def get_connection(self) -> SelectConnection:
        """Return the connection, or raise ``RuntimeError`` if there is none."""
        with self.lock:
            if self.connection is None:
                raise RuntimeError("Connection is not available")
            return self.connection

    def send(self, message: Any, persist: bool = False) -> None:
        """Publish ``message`` on the ioloop thread, tracking it in the outbox.

        The outbox entry is removed only after a successful publish, so on
        failure it stays there (and is reported by ``close``/``on_disconnect``).
        """
        entry = ToSend(message, persist)
        with self.lock:
            self.outbox.append(entry)
        super().send(message, persist)
        with self.lock:
            # Remove this call's own entry by identity: another thread may
            # have appended after it, so the last element need not be ours.
            for index, queued in enumerate(self.outbox):
                if queued is entry:
                    del self.outbox[index]
                    break

    def close(self) -> None:
        """Try to flush the outbox, then stop the ioloop and close.

        Messages still unpublished afterwards are logged as a warning and stay
        available through ``get_outbox``.
        """
        with self.lock:
            has_pending = bool(self.outbox)
            channel = self.channel
        if has_pending:
            try:
                logging.info(f"Trying to empty outbox before closing: {self.get_outbox()}")
                # Publish on the ioloop thread; pika channels are not thread-safe.
                self.add_callback_threadsafe(partial(self.recovery, channel)).result(60)
            except BaseException:
                logging.exception(f"Failed to empty outbox: {self.get_outbox()}")

        connection: Optional[pika.connection.Connection] = None
        with self.lock:
            self.terminated = True
            self.channel = None
            connection = self.connection
            self.connection = None

        if connection is not None:
            try:
                connection.ioloop.stop()
            except BaseException:
                logging.exception("Failed to stop ioloop")
            finally:
                try:
                    connection.close()
                except BaseException:
                    logging.exception("Failed to close connection")

        with self.lock:
            if self.outbox:
                logging.warning(f"Outbound items not sent: {self.outbox}")

    def on_connection_closed(
        self, connection: pika.connection.Connection, error: Any
    ) -> None:
        """pika close callback: stop the ioloop and notify ``on_disconnect``.

        ``on_disconnect`` receives ``True`` when the close was not requested
        via ``close()``, plus a copy of the unsent outbox.
        """
        with self.lock:
            if not self.terminated:
                logging.warning(
                    f"Flagging to set up a new connection because the old one ({connection}) was lost: {error}"
                )
                try:
                    # pika raises if close() is called on a connection that is
                    # already closed or closing, as it normally is here.
                    if not (connection.is_closed or connection.is_closing):
                        connection.close()
                finally:
                    try:
                        connection.ioloop.stop()
                    finally:
                        self.on_disconnect(True, copy(self.outbox))
            else:
                self.on_disconnect(False, copy(self.outbox))

    def on_connection_open(self, connection: pika.connection.Connection) -> None:
        """pika open callback: open a channel, continuing in on_channel_open."""
        self.channel = connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel: pika.channel.Channel) -> None:
        """Set the prefetch count and declare the queue."""
        channel.basic_qos(prefetch_count=self.parameters.prefetch)
        channel.queue_declare(
            queue=self.parameters.queue,
            durable=self.parameters.durable,
            callback=lambda _: self.on_queue_declare(channel),
            arguments=self.queue_arguments,
        )

    def on_queue_declare(self, channel: pika.channel.Channel) -> None:
        """Declare the exchange, then bind the queue to it."""
        channel.exchange_declare(
            exchange=self.parameters.exchange,
            durable=self.parameters.durable,
            callback=lambda _: channel.queue_bind(
                queue=self.parameters.queue,
                exchange=self.parameters.exchange,
                routing_key=self.parameters.routing_key,
                callback=lambda _: self.on_channel_ready(channel),
            ),
        )

    def on_channel_ready(self, channel: pika.channel.Channel) -> None:
        """Flush the outbox, then start consuming with manual acks."""
        self.recovery(channel)
        channel.basic_consume(
            on_message_callback=self.on_message,
            queue=self.parameters.queue,
            auto_ack=False,
        )

    def add_callback_threadsafe(self, callback: Callable[[], Any]) -> Future:
        """Run ``callback`` on the ioloop thread; return a Future for its outcome.

        When called from the ioloop thread itself (e.g. inside ``on_message``),
        the callback runs immediately: waiting on the Future there would block
        the loop that has to execute it. Scheduling errors are reported
        through the Future.
        """
        from threading import get_ident

        wrapped, future = self.futurify(callback)
        if getattr(self, "_ioloop_thread", None) == get_ident():
            wrapped()
            return future
        try:
            self.get_connection().ioloop.add_callback_threadsafe(wrapped)
        except BaseException as e:
            future.set_exception(e)
        return future

    def run(self) -> None:
        """Run the connection's ioloop on the calling thread until stopped."""
        from threading import get_ident

        connection = self.get_connection()
        self._ioloop_thread = get_ident()
        try:
            connection.ioloop.start()
        finally:
            self._ioloop_thread = None
workercommon/test.py ADDED
@@ -0,0 +1,177 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+
4
if not sys.version_info >= (3, 8):
    # Fail fast on interpreters older than the minimum these tests support.
    raise RuntimeError("At least Python 3.8 is required")
6
+ import unittest
7
+
8
+ from os import environ
9
+ from getpass import getpass
10
+ import logging
11
+ from database import *
12
+ from typing import List
13
+
14
+
15
class DBTesting(object):
    """Mixin that opens database connections from ``DB_*`` environment variables.

    The password is prompted for interactively once and cached on the class.
    """

    # One-slot cache for the prompted password; empty until the first prompt.
    DB_PASSWORD: List[str] = []

    @classmethod
    def get_password(cls) -> str:
        """Return the cached password, prompting for it on first use."""
        cache = cls.DB_PASSWORD
        if not cache:
            cache.append(getpass("Password: "))
        return cache[0]

    @classmethod
    def build_connection(cls) -> Connection:
        """Open a ``PgConnection`` from DB_HOST, DB_PORT, DB_USER and DB_DATABASE.

        DB_HOST may be unset (passed as ``None``); DB_PORT defaults to 5432
        when unset or empty; DB_USER and DB_DATABASE are required.
        """
        port_text = environ.get("DB_PORT", "")
        port_number = int(port_text) if port_text else 5432
        return PgConnection(
            environ.get("DB_HOST", None),
            port_number,
            environ["DB_USER"],
            cls.get_password(),
            environ["DB_DATABASE"],
        )
37
+
38
+
39
class TestConnect(unittest.TestCase, DBTesting):
    """Smoke test: a connection can be opened and closed again."""

    def test_connect(self):
        connection = self.build_connection()
        connection.close()
42
+
43
+
44
class TestOperations(unittest.TestCase, DBTesting):
    """Integration tests for cursors and table helpers from ``database``.

    Each test gets a fresh connection in ``setUp`` that ``tearDown`` closes.
    Tables are created with CREATE TEMPORARY TABLE, presumably so they vanish
    when the connection closes.
    """

    def setUp(self):
        self.connection = self.build_connection()

    def tearDown(self):
        self.connection.close()
        self.connection = None

    def test_cursor(self):
        """A cursor's ``get()`` only works inside its ``with`` block."""
        cursor = self.connection.cursor()
        self.assertIsNotNone(cursor)
        # Expected to raise before the cursor is entered...
        with self.assertRaises(RuntimeError):
            cursor.get()

        with cursor:
            self.assertIsNotNone(cursor.get())

        # ...and again after it has been exited.
        with self.assertRaises(RuntimeError):
            cursor.get()

    def test_direct_table(self):
        """Rows come back positionally via execute_direct and by name via execute."""
        with self.connection.cursor() as cursor:
            cursor.execute("CREATE TEMPORARY TABLE test(value1 INTEGER, value2 TEXT)")
            # Query parameters are passed positionally after the SQL string.
            cursor.execute("INSERT INTO test(value1, value2) VALUES(%s, %s)", 1, "foo")

            # Basic functionality
            row = cursor.execute_direct(
                "SELECT value1, value2 FROM test LIMIT 1"
            ).fetchone()
            self.assertIsNotNone(row)
            self.assertTrue(cursor.columns)
            self.assertEqual(row[0], 1)
            self.assertEqual(row[1], "foo")

            # Advanced functionality
            # execute() is iterable and yields rows indexable by column name.
            rows = list(cursor.execute("SELECT value1, value2 FROM test LIMIT 1"))
            self.assertEqual(len(rows), 1)
            dict_row = rows[0]
            self.assertEqual(dict_row["value1"], 1)
            self.assertEqual(dict_row["value2"], "foo")

    def test_table(self):
        """Insert/select/where/upsert/delete through PgBaseTable and TableCursor.

        The steps build on each other: keys come from the SERIAL column, so
        inserts are expected to return 1, 2, 3 in that order.
        """
        table_type = PgBaseTable("table_test", "key", ["key", "value1", "value2"])
        with self.connection.cursor() as cursor:
            table_cursor = TableCursor(cursor, table_type)

            # Row 2 is expected to stay unchanged by every later step.
            def verify2():
                row = table_cursor[2]
                self.assertTrue(row)
                self.assertEqual(row["key"], 2)
                self.assertEqual(row["value1"], 30)
                self.assertEqual(row["value2"], "brown")

            cursor.execute(
                "CREATE TEMPORARY TABLE table_test(key SERIAL PRIMARY KEY NOT NULL, value1 INTEGER, value2 TEXT)"
            )
            # Test insert and get
            self.assertEqual(table_type.insert(cursor, value1=100, value2="cow"), 1)

            row = table_type.select_by_id(cursor, 1)
            self.assertTrue(row)
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 100)
            self.assertEqual(row["value2"], "cow")

            # Test insert again and get
            self.assertEqual(table_type.insert(cursor, value1=30, value2="brown"), 2)

            verify2()

            # Test selecting by other column values
            # NOTE(review): the leading 100 looks like a row limit — confirm in
            # the database module.
            rows = list(table_cursor.where(100, value1=100))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 100)
            self.assertEqual(row["value2"], "cow")

            # Test custom selecting
            rows = list(table_cursor.custom_where("value1 > %s", 50))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 100)
            self.assertEqual(row["value2"], "cow")

            # Test selecting by other column values where a None is present
            # (where() is expected to match NULL for a None value).
            self.assertEqual(table_type.insert(cursor, value1=None, value2="skittle"), 3)
            rows = list(table_cursor.where(100, value1=None))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 3)
            self.assertIsNone(row["value1"])
            self.assertEqual(row["value2"], "skittle")

            # Test upserting on an alternate column
            # Expected to update the existing row 3 rather than insert a new one.
            self.assertEqual(
                table_cursor.upsert_on_column("value1", None, value2="1234"), 3
            )
            rows = list(table_cursor.where(100, value1=None))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 3)
            self.assertIsNone(row["value1"])
            self.assertEqual(row["value2"], "1234")

            # Test upserting
            # Both the item-assignment form and upsert() are exercised on key 1.
            table_cursor[1] = {"value1": 200, "value2": "quack"}
            self.assertEqual(table_type.upsert(cursor, 1, value1=200, value2="quack"), 1)

            row = table_type.select_by_id(cursor, 1)
            self.assertTrue(row)
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 200)
            self.assertEqual(row["value2"], "quack")

            verify2()

            # Test deletion
            # delete() is expected to return truthy when a row was removed and
            # falsy otherwise; select_by_id raises KeyError for a missing key.
            self.assertTrue(table_type.delete(cursor, 1))
            with self.assertRaises(KeyError):
                table_type.select_by_id(cursor, 1)
            verify2()

            self.assertFalse(table_type.delete(cursor, 1))

            # Deleting a missing key through the cursor is expected to raise.
            with self.assertRaises(KeyError):
                del table_cursor[1]

            del table_cursor[2]
174
+
175
+
176
if __name__ == "__main__":
    # Configure logging and run the suite only when executed as a script.
    # Without this guard, merely importing the module (e.g. during test
    # discovery) reconfigures logging and calls unittest.main(), which exits
    # the process via sys.exit.
    logging.basicConfig(level=logging.DEBUG, stream=sys.stderr)
    unittest.main()