workercommon 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workercommon/__init__.py +7 -0
- workercommon/background.py +72 -0
- workercommon/commands.py +42 -0
- workercommon/config.py +81 -0
- workercommon/connectionpool.py +97 -0
- workercommon/database/__init__.py +16 -0
- workercommon/database/base.py +173 -0
- workercommon/database/pg.py +278 -0
- workercommon/database/py.typed +0 -0
- workercommon/locking.py +52 -0
- workercommon/py.typed +0 -0
- workercommon/rabbitmqueue.py +455 -0
- workercommon/test.py +177 -0
- workercommon/worker.py +347 -0
- workercommon-0.4.1.dist-info/METADATA +25 -0
- workercommon-0.4.1.dist-info/RECORD +19 -0
- workercommon-0.4.1.dist-info/WHEEL +5 -0
- workercommon-0.4.1.dist-info/licenses/LICENSE.TXT +9 -0
- workercommon-0.4.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
# Refuse to load on interpreters older than 3.8 (compare major/minor only).
if sys.version_info[:2] < (3, 8):
    raise RuntimeError("At least Python 3.8 is required")
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import json
|
|
9
|
+
import pika.connection
|
|
10
|
+
import pika.channel
|
|
11
|
+
from pika import (
|
|
12
|
+
BasicProperties,
|
|
13
|
+
BlockingConnection,
|
|
14
|
+
SelectConnection,
|
|
15
|
+
URLParameters,
|
|
16
|
+
)
|
|
17
|
+
from pika.spec import PERSISTENT_DELIVERY_MODE, TRANSIENT_DELIVERY_MODE
|
|
18
|
+
from concurrent.futures import Future
|
|
19
|
+
from functools import partial
|
|
20
|
+
from threading import RLock, Lock, Thread
|
|
21
|
+
from typing import (
|
|
22
|
+
Any,
|
|
23
|
+
Callable,
|
|
24
|
+
Iterable,
|
|
25
|
+
List,
|
|
26
|
+
NamedTuple,
|
|
27
|
+
Optional,
|
|
28
|
+
Tuple,
|
|
29
|
+
)
|
|
30
|
+
from urllib.parse import quote as urlquote
|
|
31
|
+
from datetime import datetime
|
|
32
|
+
from collections import namedtuple
|
|
33
|
+
from copy import copy
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Parameters(NamedTuple):
    """Connection settings plus the queue/exchange this package works with."""

    mqparams: pika.connection.Parameters  # pika connection parameters
    queue: str
    exchange: str  # never None: ``of`` normalizes a missing exchange to ""
    durable: bool
    prefetch: int
    routing_key: str

    @classmethod
    def of(
        cls,
        host: str,
        port: int,
        queue: str,
        exchange: Optional[str] = None,
        durable: bool = False,
        username: str = "guest",
        password: str = "guest",
        prefetch: int = 100,
        vhost: str = "/",
        routing_key: str = "",
    ) -> "Parameters":
        """Build Parameters from individual connection settings.

        The username, password and vhost are percent-encoded with no "safe"
        characters. They are embedded in an ``amqp://`` URL that pika's
        ``URLParameters`` parses, so a raw "/" (e.g. in a password) would
        otherwise end the authority part of the URL and corrupt it.

        Returns:
            A ``Parameters`` whose ``mqparams`` is a pika ``URLParameters``.
        """
        cleaned_username = urlquote(username, safe="")
        cleaned_password = urlquote(password, safe="")
        cleaned_vhost = urlquote(vhost, safe="")

        parameters = URLParameters(
            f"amqp://{cleaned_username}:{cleaned_password}@{host}:{port}/{cleaned_vhost}"
        )
        return cls(parameters, queue, exchange or "", durable, prefetch, routing_key)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Function-style alias for the Parameters.of constructor,
# e.g. build_parameters("localhost", 5672, "jobs").
build_parameters = Parameters.of
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class BaseQueue(object):
    """Publishing logic shared by the RabbitMQ queue wrappers.

    Subclasses provide the channel, the parameters, a way to run a callable
    on the connection's own thread (``add_callback_threadsafe``) and the
    event loop (``run``).
    """

    # Seconds send() waits for the connection thread to perform the publish.
    SEND_TIMEOUT = 60

    def get_channel(self) -> pika.channel.Channel:
        """Return the channel to publish on (subclass responsibility)."""
        raise NotImplementedError

    def get_parameters(self) -> Parameters:
        """Return the queue/exchange parameters (subclass responsibility)."""
        raise NotImplementedError

    @classmethod
    def serialize_for_json(cls, obj: Any) -> str:
        """Serve as the ``json.dumps`` fallback and encode datetimes as ISO-8601.

        Raises:
            TypeError: for any other type json cannot encode natively.
        """
        if isinstance(obj, datetime):
            return obj.isoformat()

        raise TypeError(f"Cannot serialize: {obj}")

    def send_callback(self, message: Any, persist: bool = False) -> None:
        """Publish ``message`` as JSON right now, on the calling thread.

        Uses the exchange and routing key from ``get_parameters()``. When
        ``persist`` is true the message is marked persistent; otherwise it is
        marked transient.
        """
        properties = BasicProperties(
            content_type="application/json",
            delivery_mode=(
                PERSISTENT_DELIVERY_MODE if persist else TRANSIENT_DELIVERY_MODE
            ),
        )

        parameters = self.get_parameters()
        self.get_channel().basic_publish(
            exchange=parameters.exchange,
            routing_key=parameters.routing_key,
            body=json.dumps(message, default=self.serialize_for_json),
            properties=properties,
        )

    def send(self, message: Any, persist: bool = False) -> None:
        """Publish ``message`` via the connection thread and wait for the result.

        Raises:
            Whatever the publish raised, or a timeout error if the connection
            thread did not run it within ``SEND_TIMEOUT`` seconds.
        """
        future = self.add_callback_threadsafe(
            partial(self.send_callback, message, persist)
        )
        try:
            future.result(self.SEND_TIMEOUT)
        except BaseException:
            # If the publish has not started yet (e.g. we timed out), cancel it
            # so it cannot go out later after the caller already saw a failure
            # and possibly retried. This is a no-op once it is running or done.
            future.cancel()
            raise

    def add_callback_threadsafe(self, callback: Callable[[], Any]) -> Future:
        """Schedule ``callback`` on the connection thread (subclass responsibility)."""
        raise NotImplementedError

    def run(self) -> None:
        """Run the connection's event loop (subclass responsibility)."""
        raise NotImplementedError

    @staticmethod
    def futurify(callback: Callable[[], Any]) -> Tuple[Callable[[], None], Future]:
        """Pair ``callback`` with a Future that receives its outcome.

        Returns:
            ``(wrapped, future)``. Calling ``wrapped()`` runs ``callback``
            unless ``future`` was cancelled first. It stores the return value,
            or any raised exception, in ``future``.
        """
        future: Future = Future()

        def wrapped():
            # Honour a cancellation made by a waiter that gave up.
            if not future.set_running_or_notify_cancel():
                return
            try:
                result = callback()
            except BaseException as e:
                future.set_exception(e)
            else:
                future.set_result(result)

        return wrapped, future
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# One outbound message and whether it should be published persistently.
ToSend = namedtuple("ToSend", "message persist")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class SendQueue(BaseQueue):
    """Publisher backed by a pika ``BlockingConnection``.

    The constructor connects, then declares and binds the exchange and queue.
    The inherited ``send`` schedules the publish with
    ``add_callback_threadsafe`` and waits for it. Some thread must therefore
    be running ``run`` (which processes the connection's events) for ``send``
    to complete; SendQueueWrapper arranges this. ``quick_send`` is a one-shot
    alternative that publishes directly on the calling thread.
    """

    @classmethod
    def build_channel(
        cls,
        connection: pika.connection.Connection,
        parameters: Parameters,
        **queue_arguments: Any,
    ) -> pika.channel.Channel:
        """Open a channel, declare the exchange and queue, and bind them.

        Extra keyword arguments become the queue declaration's ``arguments``.
        """
        channel = connection.channel()
        channel.exchange_declare(
            exchange=parameters.exchange, durable=parameters.durable
        )
        channel.queue_declare(
            queue=parameters.queue,
            durable=parameters.durable,
            arguments=queue_arguments,
        )
        channel.queue_bind(
            queue=parameters.queue,
            exchange=parameters.exchange,
            routing_key=parameters.routing_key,
        )

        return channel

    def __init__(self, parameters: Parameters, **queue_arguments: Any):
        """Connect and build the channel.

        On any failure, whatever was opened is closed and the error re-raised.
        """
        self.lock = RLock()
        self.parameters = parameters
        self.connection: Optional[BlockingConnection] = None
        self.channel: Optional[pika.channel.Channel] = None
        try:
            self.connection = BlockingConnection(self.parameters.mqparams)
            self.channel = self.build_channel(self.connection, self.parameters, **queue_arguments)
        except BaseException:
            self.close()
            raise

    def get_channel(self) -> pika.channel.Channel:
        """Return the open channel, or raise RuntimeError after close()."""
        if self.channel is None:
            raise RuntimeError("Channel is not available")
        return self.channel

    def get_parameters(self) -> Parameters:
        return self.parameters

    def get_connection(self) -> BlockingConnection:
        """Return the open connection, or raise RuntimeError after close()."""
        with self.lock:
            if self.connection is None:
                raise RuntimeError("Connection is not available")
            return self.connection

    def close(self) -> None:
        """Close the connection, if any, and drop the connection and channel.

        The references are cleared before closing. The queue therefore ends up
        unusable even when ``connection.close()`` raises (for example, because
        the connection was already lost); that exception still propagates.
        """
        with self.lock:
            connection = self.connection
            self.connection = None
            self.channel = None
            if connection is not None:
                connection.close()

    @classmethod
    def quick_send(cls, parameters: Parameters, messages_iter: Iterable[ToSend]) -> None:
        """Open a short-lived queue, publish the messages in order, then close it.

        The calling thread owns the new connection, and no thread is
        processing its events. Going through ``send()`` would wait for a
        callback that never runs, so every publish would time out. Publishing
        is therefore done directly with ``send_callback``.

        Raises:
            Whatever connecting or publishing raised. The messages not yet
            published are logged first.
        """
        messages = list(messages_iter)
        if not messages:
            return
        sent = 0
        queue: Optional[SendQueue] = None
        try:
            queue = cls(parameters)
            for msg in messages:
                queue.send_callback(msg.message, msg.persist)
                sent += 1
        except BaseException:
            logging.exception(f"Failed to send messages: {messages[sent:]}")
            raise
        finally:
            if queue is not None:
                queue.close()

    def add_callback_threadsafe(self, callback: Callable[[], Any]) -> Future:
        """Schedule ``callback`` on the connection and return a Future for its outcome.

        If scheduling itself fails, the returned Future already holds that error.
        """
        wrapped, future = self.futurify(callback)
        try:
            self.get_connection().add_callback_threadsafe(wrapped)
        except BaseException as e:
            future.set_exception(e)
        return future

    def run(self) -> None:
        """Process connection events forever.

        This is what executes callbacks scheduled by ``add_callback_threadsafe``.
        Once the queue is closed, ``get_connection`` raises RuntimeError, which
        ends the loop.
        """
        logging.info(f"Starting data events for {self.get_connection()}")
        while True:
            self.get_connection().process_data_events(None)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class SendQueueWrapper(object):
    """Wraps a SendQueue and recreates it on error.

    Owns one SendQueue and the thread running its event loop. ``send`` makes
    up to TRIES attempts and replaces the queue (and its thread) after each
    failure. If every attempt fails, the message is logged and dropped
    rather than raised.
    """

    TRIES = 3

    def __init__(self, parameters: Parameters, **queue_arguments: Any):
        self.lock = Lock()
        self.parameters = parameters
        self.queue_arguments = queue_arguments
        self.queue: Optional[SendQueue] = SendQueue(self.parameters, **self.queue_arguments)
        try:
            self.queue_thread: Optional[Thread] = self.create_queue_thread(self.queue)
        except BaseException:
            # Don't leak the connection opened above if the thread cannot start.
            self.queue.close()
            raise
        self.closed = False

    def close(self) -> None:
        """Close the wrapped queue and wait up to 10 s for its thread.

        Subsequent ``send`` calls raise RuntimeError. Even if closing the queue
        raises, the queue is dropped and its thread is still joined.
        """
        with self.lock:
            self.closed = True
            queue = self.queue
            self.queue = None
            try:
                if queue is not None:
                    queue.close()
            finally:
                self.join_queue_thread()

    @staticmethod
    def create_queue_thread(queue: SendQueue) -> Thread:
        """Start and return a thread running ``queue.run``."""
        queue_thread = Thread(target=queue.run)
        queue_thread.start()
        return queue_thread

    def join_queue_thread(self) -> None:
        "Must be called with self.lock held!"
        if self.queue_thread is not None:
            try:
                self.queue_thread.join(10)
            except BaseException:
                logging.exception(f"Failed to wait for queue thread {self.queue_thread}")
            self.queue_thread = None

    def send(self, message: Any, persist: bool = False) -> None:
        """Publish ``message``, reconnecting and retrying up to TRIES times.

        Raises:
            RuntimeError: if the wrapper has been closed.
        """
        for i in range(self.TRIES):
            with self.lock:
                if self.closed:
                    raise RuntimeError("Queue is closed")
                try:
                    if self.queue is None:
                        self.queue = SendQueue(self.parameters, **self.queue_arguments)
                        self.join_queue_thread()
                        self.queue_thread = self.create_queue_thread(self.queue)
                        logging.info("Re-established broken connection")

                    self.queue.send(message, persist)
                    break

                except BaseException:
                    logging.exception("Failed to send. Attempting to reconnect")
                    if self.queue is not None:
                        try:
                            self.queue.close()
                        except BaseException:
                            logging.warning("Failed to close queue")
                        self.queue = None

                    self.join_queue_thread()

        else:
            # Best effort: after exhausting all tries the message is only logged.
            logging.error(f"Failed to send after {self.TRIES} tries: {message}")
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
# Consumer callback handed to ``basic_consume``:
# (channel, delivery method frame, message properties, body).
# pika passes the body undecoded, as bytes.
Callback = Callable[
    [pika.channel.Channel, pika.spec.Basic.Deliver, pika.spec.BasicProperties, bytes], None
]
# Called when a ReceiveSendQueue connection closes. The arguments are whether
# the close was unexpected (the queue had not been terminated) and a snapshot
# of the messages still in the outbox.
DisconnectCallback = Callable[[bool, List[ToSend]], None]
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class ReceiveSendQueue(BaseQueue):
    """Consumer/publisher built on pika's asynchronous ``SelectConnection``.

    Consumes ``parameters.queue``, passing each delivery to ``on_message``,
    and publishes to ``parameters.exchange``. Messages given to ``send`` sit
    in an outbox until their publish succeeds. Anything still there when a
    new channel becomes ready is re-published by ``recovery``. A snapshot of
    the outbox is also passed to ``on_disconnect`` when the connection closes.
    ``run`` drives the connection's ioloop.
    """

    def __init__(
        self,
        parameters: Parameters,
        on_message: Callback,
        on_disconnect: DisconnectCallback,
        **queue_arguments: Any,
    ):
        self.lock = RLock()
        self.outbox: List[ToSend] = []
        self.parameters = parameters
        self.on_message = on_message
        self.on_disconnect = on_disconnect
        self.terminated = False
        self.queue_arguments = queue_arguments
        self.connection: Optional[pika.connection.Connection] = None
        self.channel: Optional[pika.channel.Channel] = None
        self.init()

    def add_to_outbox(self, items: Iterable[ToSend]) -> None:
        """Queue ``items`` to be published by the next ``recovery``."""
        with self.lock:
            self.outbox.extend(items)

    def get_outbox(self) -> List[ToSend]:
        """Return a shallow copy of the messages not yet published."""
        with self.lock:
            return copy(self.outbox)

    def init(self):
        """(Re)create the SelectConnection.

        The channel is set up later from the connection's open callback.
        If creating the connection fails, the queue is closed and the error
        re-raised.
        """
        with self.lock:
            try:
                self.terminated = False
                self.connection = None
                self.channel = None
                self.connection = SelectConnection(
                    parameters=self.parameters.mqparams,
                    on_open_callback=self.on_connection_open,
                    on_close_callback=self.on_connection_closed,
                )
            except BaseException:
                logging.exception("Failed to initialize connection")
                self.close()
                raise

    def recovery(self, channel: pika.channel.Channel) -> None:
        """Re-publish the outbox, oldest message first.

        This runs from ``on_channel_ready`` on the ioloop thread. Routing it
        through ``send()`` would block the loop while waiting for a callback
        that only the loop can run, so it publishes directly with
        ``send_callback``. Each message leaves the outbox only after it has
        been handed to the channel. A failure therefore keeps it, and
        everything after it, queued in order.

        NOTE(review): ``close()`` also calls this, possibly from a thread other
        than the ioloop's. pika connections are not thread-safe, so confirm
        that path is acceptable.
        """
        with self.lock:
            while self.outbox:
                to_send = self.outbox[0]
                self.send_callback(to_send.message, to_send.persist)
                del self.outbox[0]

    def get_channel(self) -> pika.channel.Channel:
        """Return the current channel, or raise RuntimeError if there is none."""
        with self.lock:
            if self.channel is None:
                raise RuntimeError("Channel is not available")
            return self.channel

    def get_parameters(self) -> Parameters:
        return self.parameters

    def get_connection(self) -> SelectConnection:
        """Return the current connection, or raise RuntimeError if there is none."""
        with self.lock:
            if self.connection is None:
                raise RuntimeError("Connection is not available")
            return self.connection

    def send(self, message: Any, persist: bool = False) -> None:
        """Publish ``message`` through the ioloop, keeping it in the outbox until done.

        The message stays in the outbox if publishing fails, so a later
        ``recovery`` can resend it. Do not call this from the ioloop thread:
        it waits for the loop to run the publish.
        """
        entry = ToSend(message, persist)
        with self.lock:
            self.outbox.append(entry)
        super().send(message, persist)
        with self.lock:
            # Remove exactly our entry. Other threads may have appended after
            # it, and recovery may already have published and removed it.
            for index, pending in enumerate(self.outbox):
                if pending is entry:
                    del self.outbox[index]
                    break

    def close(self) -> None:
        """Try to flush the outbox, then stop the ioloop and close the connection.

        Failures are logged rather than raised. Messages still unsent at the
        end are logged as a warning.
        """
        try:
            with self.lock:
                if self.outbox:
                    logging.info(f"Trying to empty outbox before closing: {self.outbox}")
                    self.recovery(self.channel)
        except BaseException:
            logging.exception(f"Failed to empty outbox: {self.outbox}")

        connection: Optional[pika.connection.Connection] = None
        with self.lock:
            # Marks this close as intentional for on_connection_closed.
            self.terminated = True
            self.channel = None
            connection = self.connection
            self.connection = None

        if connection is not None:
            try:
                connection.ioloop.stop()
            except BaseException:
                logging.exception("Failed to stop ioloop")
            finally:
                try:
                    connection.close()
                except BaseException:
                    logging.exception("Failed to close connection")

        with self.lock:
            if self.outbox:
                logging.warning(f"Outbound items not sent: {self.outbox}")

    def on_connection_closed(
        self, connection: pika.connection.Connection, error: Any
    ) -> None:
        """pika on_close_callback: stop the ioloop and report via ``on_disconnect``.

        ``on_disconnect`` receives True when the close was not requested
        through ``close()``, and False otherwise, together with a snapshot of
        the outbox.

        NOTE(review): pika 1.x raises ConnectionWrongStateError from
        ``close()`` on a connection that is already closed. The finally
        clauses still stop the ioloop and notify ``on_disconnect``, but that
        exception then propagates out of this callback. Confirm this is
        intended.
        """
        with self.lock:
            if not self.terminated:
                logging.warning(
                    f"Flagging to set up a new connection because the old one ({connection}) was lost: {error}"
                )
                try:
                    connection.close()
                finally:
                    try:
                        connection.ioloop.stop()
                    finally:
                        self.on_disconnect(True, copy(self.outbox))
            else:
                self.on_disconnect(False, copy(self.outbox))

    def on_connection_open(self, connection: pika.connection.Connection) -> None:
        """pika on_open_callback: open a channel; setup continues in on_channel_open."""
        self.channel = connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel: pika.channel.Channel) -> None:
        """Set the prefetch count and declare the queue."""
        channel.basic_qos(prefetch_count=self.parameters.prefetch)
        channel.queue_declare(
            queue=self.parameters.queue,
            durable=self.parameters.durable,
            callback=lambda _: self.on_queue_declare(channel),
            arguments=self.queue_arguments,
        )

    def on_queue_declare(self, channel: pika.channel.Channel) -> None:
        """Declare the exchange, then bind the queue to it with the routing key."""
        channel.exchange_declare(
            exchange=self.parameters.exchange,
            durable=self.parameters.durable,
            callback=lambda _: channel.queue_bind(
                queue=self.parameters.queue,
                exchange=self.parameters.exchange,
                routing_key=self.parameters.routing_key,
                callback=lambda _: self.on_channel_ready(channel),
            ),
        )

    def on_channel_ready(self, channel: pika.channel.Channel) -> None:
        """Flush the outbox, then start consuming with manual acknowledgements."""
        self.recovery(channel)
        channel.basic_consume(
            on_message_callback=self.on_message,
            queue=self.parameters.queue,
            auto_ack=False,
        )

    def add_callback_threadsafe(self, callback: Callable[[], Any]) -> Future:
        """Schedule ``callback`` on the ioloop and return a Future for its outcome."""
        wrapped, future = self.futurify(callback)
        try:
            # Schedule the wrapper, not the bare callback. Only the wrapper
            # resolves the Future that send() waits on.
            self.get_connection().ioloop.add_callback_threadsafe(wrapped)
        except BaseException as e:
            future.set_exception(e)
        return future

    def run(self) -> None:
        """Run the connection's ioloop until it is stopped."""
        self.get_connection().ioloop.start()
|
workercommon/test.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
# The tests use the walrus operator, so require Python 3.8+ (major/minor compare).
if sys.version_info[:2] < (3, 8):
    raise RuntimeError("At least Python 3.8 is required")
|
|
6
|
+
import unittest
|
|
7
|
+
|
|
8
|
+
from os import environ
|
|
9
|
+
from getpass import getpass
|
|
10
|
+
import logging
|
|
11
|
+
from database import *
|
|
12
|
+
from typing import List
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DBTesting(object):
    """Mixin that opens PgConnections configured from environment variables.

    Reads DB_HOST (optional), DB_PORT (optional, default 5432), DB_USER and
    DB_DATABASE (both required). The password is prompted for once and then
    reused.
    """

    # One-slot cache for the password typed at the prompt.
    DB_PASSWORD: List[str] = []

    @classmethod
    def get_password(cls) -> str:
        """Return the cached password, prompting for it the first time."""
        cache = cls.DB_PASSWORD
        if not cache:
            cache.append(getpass("Password: "))
        return cache[0]

    @classmethod
    def build_connection(cls) -> Connection:
        """Open a new PgConnection from the environment settings."""
        port = environ.get("DB_PORT", "")
        port_number = int(port) if port else 5432
        # Arguments are evaluated left to right, so DB_USER is still checked
        # before the password prompt.
        return PgConnection(
            environ.get("DB_HOST", None),
            port_number,
            environ["DB_USER"],
            cls.get_password(),
            environ["DB_DATABASE"],
        )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class TestConnect(unittest.TestCase, DBTesting):
    """Smoke test: a connection can be opened and closed again."""

    def test_connect(self):
        connection = self.build_connection()
        connection.close()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TestOperations(unittest.TestCase, DBTesting):
    # Cursor and table-helper tests against a live database. Each test gets
    # its own connection, and the tables it creates are TEMPORARY.

    def setUp(self):
        self.connection = self.build_connection()

    def tearDown(self):
        self.connection.close()
        self.connection = None

    def test_cursor(self):
        # cursor.get() must raise RuntimeError outside the cursor's `with`
        # block, both before entering and after leaving it, and return a
        # value inside.
        cursor = self.connection.cursor()
        self.assertIsNotNone(cursor)
        with self.assertRaises(RuntimeError):
            cursor.get()

        with cursor:
            self.assertIsNotNone(cursor.get())

        with self.assertRaises(RuntimeError):
            cursor.get()

    def test_direct_table(self):
        # execute_direct(...).fetchone() is indexed positionally here, while
        # rows from execute(...) are indexed by column name.
        with self.connection.cursor() as cursor:
            cursor.execute("CREATE TEMPORARY TABLE test(value1 INTEGER, value2 TEXT)")
            cursor.execute("INSERT INTO test(value1, value2) VALUES(%s, %s)", 1, "foo")

            # Basic functionality
            row = cursor.execute_direct(
                "SELECT value1, value2 FROM test LIMIT 1"
            ).fetchone()
            self.assertIsNotNone(row)
            self.assertTrue(cursor.columns)
            self.assertEqual(row[0], 1)
            self.assertEqual(row[1], "foo")

            # Advanced functionality
            rows = list(cursor.execute("SELECT value1, value2 FROM test LIMIT 1"))
            self.assertEqual(len(rows), 1)
            dict_row = rows[0]
            self.assertEqual(dict_row["value1"], 1)
            self.assertEqual(dict_row["value2"], "foo")

    def test_table(self):
        # The expected keys 1, 2, 3 assume the SERIAL column hands out ids in
        # insertion order. The steps below are order-dependent; do not reorder.
        table_type = PgBaseTable("table_test", "key", ["key", "value1", "value2"])
        with self.connection.cursor() as cursor:
            table_cursor = TableCursor(cursor, table_type)

            def verify2():
                # Row 2 is untouched by every step until the final delete;
                # re-checked after each mutation.
                row = table_cursor[2]
                self.assertTrue(row)
                self.assertEqual(row["key"], 2)
                self.assertEqual(row["value1"], 30)
                self.assertEqual(row["value2"], "brown")

            cursor.execute(
                "CREATE TEMPORARY TABLE table_test(key SERIAL PRIMARY KEY NOT NULL, value1 INTEGER, value2 TEXT)"
            )
            # Test insert and get
            self.assertEqual(table_type.insert(cursor, value1=100, value2="cow"), 1)

            row = table_type.select_by_id(cursor, 1)
            self.assertTrue(row)
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 100)
            self.assertEqual(row["value2"], "cow")

            # Test insert again and get
            self.assertEqual(table_type.insert(cursor, value1=30, value2="brown"), 2)

            verify2()

            # Test selecting by other column values
            # NOTE(review): the leading 100 looks like a row limit — confirm
            # against TableCursor.where.
            rows = list(table_cursor.where(100, value1=100))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 100)
            self.assertEqual(row["value2"], "cow")

            # Test custom selecting
            rows = list(table_cursor.custom_where("value1 > %s", 50))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 100)
            self.assertEqual(row["value2"], "cow")

            # Test selecting by other column values where a None is present
            self.assertEqual(table_type.insert(cursor, value1=None, value2="skittle"), 3)
            rows = list(table_cursor.where(100, value1=None))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 3)
            self.assertIsNone(row["value1"])
            self.assertEqual(row["value2"], "skittle")

            # Test upserting on an alternate column
            # (matches the row whose value1 is NULL, i.e. key 3).
            self.assertEqual(
                table_cursor.upsert_on_column("value1", None, value2="1234"), 3
            )
            rows = list(table_cursor.where(100, value1=None))
            self.assertEqual(len(rows), 1)
            row = rows[0]
            self.assertEqual(row["key"], 3)
            self.assertIsNone(row["value1"])
            self.assertEqual(row["value2"], "1234")

            # Test upserting
            table_cursor[1] = {"value1": 200, "value2": "quack"}
            self.assertEqual(table_type.upsert(cursor, 1, value1=200, value2="quack"), 1)

            row = table_type.select_by_id(cursor, 1)
            self.assertTrue(row)
            self.assertEqual(row["key"], 1)
            self.assertEqual(row["value1"], 200)
            self.assertEqual(row["value2"], "quack")

            verify2()

            # Test deletion
            self.assertTrue(table_type.delete(cursor, 1))
            with self.assertRaises(KeyError):
                table_type.select_by_id(cursor, 1)
            verify2()

            # Deleting an already-deleted key reports False via the table
            # API and raises KeyError via the cursor's `del`.
            self.assertFalse(table_type.delete(cursor, 1))

            with self.assertRaises(KeyError):
                del table_cursor[1]

            del table_cursor[2]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# Run the suite only when executed as a script. At import time (for example,
# test discovery by another runner), unittest.main() would parse the wrong
# argv and call sys.exit().
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG, stream=sys.stderr)
    unittest.main()
|