eventsourcing 9.5.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eventsourcing/__init__.py +0 -0
- eventsourcing/application.py +998 -0
- eventsourcing/cipher.py +107 -0
- eventsourcing/compressor.py +15 -0
- eventsourcing/cryptography.py +91 -0
- eventsourcing/dcb/__init__.py +0 -0
- eventsourcing/dcb/api.py +144 -0
- eventsourcing/dcb/application.py +159 -0
- eventsourcing/dcb/domain.py +369 -0
- eventsourcing/dcb/msgpack.py +38 -0
- eventsourcing/dcb/persistence.py +193 -0
- eventsourcing/dcb/popo.py +178 -0
- eventsourcing/dcb/postgres_tt.py +704 -0
- eventsourcing/dcb/tests.py +608 -0
- eventsourcing/dispatch.py +80 -0
- eventsourcing/domain.py +1964 -0
- eventsourcing/interface.py +164 -0
- eventsourcing/persistence.py +1429 -0
- eventsourcing/popo.py +267 -0
- eventsourcing/postgres.py +1441 -0
- eventsourcing/projection.py +502 -0
- eventsourcing/py.typed +0 -0
- eventsourcing/sqlite.py +816 -0
- eventsourcing/system.py +1203 -0
- eventsourcing/tests/__init__.py +3 -0
- eventsourcing/tests/application.py +483 -0
- eventsourcing/tests/domain.py +105 -0
- eventsourcing/tests/persistence.py +1744 -0
- eventsourcing/tests/postgres_utils.py +131 -0
- eventsourcing/utils.py +257 -0
- eventsourcing-9.5.0b3.dist-info/METADATA +253 -0
- eventsourcing-9.5.0b3.dist-info/RECORD +35 -0
- eventsourcing-9.5.0b3.dist-info/WHEEL +4 -0
- eventsourcing-9.5.0b3.dist-info/licenses/AUTHORS +10 -0
- eventsourcing-9.5.0b3.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,1429 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import typing
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from collections import deque
|
|
7
|
+
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from decimal import Decimal
|
|
11
|
+
from functools import lru_cache
|
|
12
|
+
from queue import Queue
|
|
13
|
+
from threading import Condition, Event, Lock, Semaphore, Thread, Timer
|
|
14
|
+
from time import monotonic, sleep, time
|
|
15
|
+
from types import GenericAlias, ModuleType, TracebackType
|
|
16
|
+
from typing import Any, Generic, cast
|
|
17
|
+
from uuid import UUID
|
|
18
|
+
|
|
19
|
+
from typing_extensions import Self, TypeVar
|
|
20
|
+
|
|
21
|
+
from eventsourcing.domain import (
|
|
22
|
+
DomainEventProtocol,
|
|
23
|
+
EventSourcingError,
|
|
24
|
+
HasOriginatorIDVersion,
|
|
25
|
+
TAggregateID,
|
|
26
|
+
)
|
|
27
|
+
from eventsourcing.utils import (
|
|
28
|
+
Environment,
|
|
29
|
+
EnvType,
|
|
30
|
+
TopicError,
|
|
31
|
+
get_topic,
|
|
32
|
+
resolve_topic,
|
|
33
|
+
strtobool,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Transcoding(ABC):
    """Base class for a pluggable transcoding of one specific Python type.

    Concrete subclasses set ``type`` and ``name`` and implement
    :meth:`encode` and :meth:`decode`.
    """

    #: The Python class whose instances this transcoding handles.
    type: type
    #: Tag written alongside encoded data, used to select this transcoding
    #: again when decoding.
    name: str

    @abstractmethod
    def encode(self, obj: Any) -> Any:
        """Return an encoded representation of ``obj``."""

    @abstractmethod
    def decode(self, data: Any) -> Any:
        """Rebuild the original object from previously encoded ``data``."""
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Transcoder(ABC):
    """Base class for objects that turn Python values into bytes and back."""

    @abstractmethod
    def encode(self, obj: Any) -> bytes:
        """Serialise ``obj`` into a ``bytes`` value."""

    @abstractmethod
    def decode(self, data: bytes) -> Any:
        """Deserialise a value from bytes produced by :meth:`encode`."""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class TranscodingNotRegisteredError(EventSourcingError, TypeError):
    """Signals that no transcoding is registered for a type or a type name.

    Raised by :class:`JSONTranscoder` when encoding an object whose type has no
    registered transcoding, and when decoding data tagged with an unknown name.
    """
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class JSONTranscoder(Transcoder):
    """Extensible transcoder built on the Python :mod:`json` module.

    Values the JSON encoder cannot handle natively are encoded via registered
    :class:`Transcoding` objects, and written as a two-key JSON object of the
    form ``{"_type_": <transcoding name>, "_data_": <encoded value>}``.
    """

    def __init__(self) -> None:
        # Registries used when encoding (keyed by exact type) and when
        # decoding (keyed by transcoding name).
        self.types: dict[type, Transcoding] = {}
        self.names: dict[str, Transcoding] = {}
        # Compact separators, and non-ASCII characters emitted unescaped.
        self.encoder = json.JSONEncoder(
            default=self._encode_obj,
            separators=(",", ":"),
            ensure_ascii=False,
        )
        self.decoder = json.JSONDecoder(object_hook=self._decode_obj)

    def register(self, transcoding: Transcoding) -> None:
        """Make ``transcoding`` available for both encoding and decoding."""
        self.types[transcoding.type] = transcoding
        self.names[transcoding.name] = transcoding

    def encode(self, obj: Any) -> bytes:
        """Serialise ``obj`` as UTF-8 encoded JSON."""
        text = self.encoder.encode(obj)
        return text.encode("utf8")

    def decode(self, data: bytes) -> Any:
        """Parse UTF-8 encoded JSON previously produced by :meth:`encode`."""
        text = data.decode("utf8")
        return self.decoder.decode(text)

    def _encode_obj(self, o: Any) -> dict[str, Any]:
        # JSON encoder fallback: only called for values the encoder cannot
        # serialise itself. The lookup is by exact type, so subclasses of a
        # registered type need their own registration.
        obj_type = type(o)
        if obj_type not in self.types:
            msg = (
                f"Object of type {obj_type} is not "
                "serializable. Please define and register "
                "a custom transcoding for this type."
            )
            raise TranscodingNotRegisteredError(msg) from None
        transcoding = self.types[obj_type]
        # Key order matters for the exact bytes produced: "_type_" first.
        return {
            "_type_": transcoding.name,
            "_data_": transcoding.encode(o),
        }

    def _decode_obj(self, d: dict[str, Any]) -> Any:
        # JSON object hook: only objects with exactly the keys "_type_" and
        # "_data_" are treated as transcoded values; others pass through.
        if len(d) != 2 or "_type_" not in d or "_data_" not in d:
            return d
        _type_ = d["_type_"]
        try:
            transcoding = self.names[cast("str", _type_)]
        except KeyError as e:
            msg = (
                f"Data serialized with name '{_type_}' is not "
                "deserializable. Please register a "
                "custom transcoding for this type."
            )
            raise TranscodingNotRegisteredError(msg) from e
        return transcoding.decode(d["_data_"])
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class UUIDAsHex(Transcoding):
    """Transcodes :class:`~uuid.UUID` values as their hex-digit strings."""

    type = UUID
    name = "uuid_hex"

    def encode(self, obj: UUID) -> str:
        """Return the 32 hex digits of ``obj`` (without hyphens)."""
        return obj.hex

    def decode(self, data: str) -> UUID:
        """Rebuild a UUID from the hex string produced by :meth:`encode`."""
        assert isinstance(data, str)
        return UUID(hex=data)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class DecimalAsStr(Transcoding):
    """Transcodes :class:`~decimal.Decimal` values via their string form."""

    type = Decimal
    name = "decimal_str"

    def encode(self, obj: Decimal) -> str:
        """Return ``str(obj)``."""
        text = str(obj)
        return text

    def decode(self, data: str) -> Decimal:
        """Rebuild a Decimal from the string produced by :meth:`encode`."""
        value = Decimal(data)
        return value
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class DatetimeAsISO(Transcoding):
    """Transcodes :class:`~datetime.datetime` values as ISO 8601 strings."""

    type = datetime
    name = "datetime_iso"

    def encode(self, obj: datetime) -> str:
        """Return ``obj.isoformat()``."""
        iso_text = obj.isoformat()
        return iso_text

    def decode(self, data: str) -> datetime:
        """Parse a string produced by :meth:`encode` back into a datetime."""
        assert isinstance(data, str)
        parsed = datetime.fromisoformat(data)
        return parsed
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@dataclass(frozen=True)
class StoredEvent:
    """Immutable, serialised form of a :class:`~eventsourcing.domain.DomainEvent`.

    Used for objects such as aggregate
    :class:`~eventsourcing.domain.Aggregate.Event` objects and
    :class:`~eventsourcing.domain.Snapshot` objects.
    """

    #: ID of the originating aggregate.
    originator_id: UUID | str
    #: Position of the event within its aggregate's sequence.
    originator_version: int
    #: Topic identifying the class of the domain event object.
    topic: str
    #: Serialised state of the domain event object.
    state: bytes
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class Compressor(ABC):
    """Abstract interface for reversible byte compression."""

    @abstractmethod
    def compress(self, data: bytes) -> bytes:
        """Return a compressed copy of ``data``."""

    @abstractmethod
    def decompress(self, data: bytes) -> bytes:
        """Return the original bytes from compressed ``data``."""
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class Cipher(ABC):
    """Abstract interface for symmetric encryption of bytes."""

    @abstractmethod
    def __init__(self, environment: Environment):
        """Configure the cipher from the given environment."""

    @abstractmethod
    def encrypt(self, plaintext: bytes) -> bytes:
        """Encrypt ``plaintext`` and return the ciphertext."""

    @abstractmethod
    def decrypt(self, ciphertext: bytes) -> bytes:
        """Decrypt ``ciphertext`` and return the plaintext."""
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class MapperDeserialisationError(EventSourcingError, ValueError):
    """Signals that a Mapper could not decrypt, decompress or decode the
    state of a stored event.
    """
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# Type variable constrained to either the UUID class or the str class.
TAggregateIDType = TypeVar(
    "TAggregateIDType",
    type[UUID],
    type[str],
)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class Mapper(Generic[TAggregateID]):
    """Converts between domain event objects and :class:`StoredEvent` objects.

    Uses a :class:`Transcoder`, and optionally a cryptographic cipher and
    compressor. When writing, event state is transcoded, then compressed (if a
    compressor is set), then encrypted (if a cipher is set). Reading applies
    the inverse steps in reverse order.
    """

    def __init__(
        self,
        transcoder: Transcoder,
        compressor: Compressor | None = None,
        cipher: Cipher | None = None,
    ):
        self.transcoder = transcoder
        self.compressor = compressor
        self.cipher = cipher

    def to_stored_event(
        self, domain_event: DomainEventProtocol[TAggregateID]
    ) -> StoredEvent:
        """Converts the given domain event to a :class:`StoredEvent` object.

        ``originator_id`` and ``originator_version`` are taken out of the
        event's ``__dict__`` and stored as separate fields. The remaining
        attributes become the serialised state. If the event class has a
        ``class_version`` greater than 1, it is added to the serialised state
        so that :meth:`to_domain_event` can upcast older stored events.
        """
        topic = get_topic(domain_event.__class__)
        # Work on a copy so popping fields doesn't mutate the event object.
        event_state = domain_event.__dict__.copy()
        originator_id = event_state.pop("originator_id")
        originator_version = event_state.pop("originator_version")
        class_version = getattr(type(domain_event), "class_version", 1)
        if class_version > 1:
            event_state["class_version"] = class_version
        stored_state = self.transcoder.encode(event_state)
        if self.compressor:
            stored_state = self.compressor.compress(stored_state)
        if self.cipher:
            stored_state = self.cipher.encrypt(stored_state)
        return StoredEvent(
            originator_id=originator_id,
            originator_version=originator_version,
            topic=topic,
            state=stored_state,
        )

    def to_domain_event(
        self, stored_event: StoredEvent
    ) -> DomainEventProtocol[TAggregateID]:
        """Converts the given :class:`StoredEvent` to a domain event object.

        The steps are:

        * The stored state is decrypted and decompressed (when a cipher or
          compressor is set) and then decoded.
        * The originator ID is converted to the type the event class expects.
        * If the stored ``class_version`` is older than the event class's
          ``class_version``, the class's ``upcast_v<N>_v<N+1>`` methods are
          applied to the state dict one version at a time.
        * The event object is created without calling ``__init__``.

        Raises:
            MapperDeserialisationError: if decrypting, decompressing or
                decoding the stored state fails.
        """
        cls = resolve_topic(stored_event.topic)

        stored_state = stored_event.state
        try:
            if self.cipher:
                stored_state = self.cipher.decrypt(stored_state)
            if self.compressor:
                stored_state = self.compressor.decompress(stored_state)
            event_state: dict[str, Any] = self.transcoder.decode(stored_state)
        except Exception as e:
            msg = (
                f"Failed to deserialise state of stored event with "
                f"topic '{stored_event.topic}', "
                f"originator_id '{stored_event.originator_id}' and "
                f"originator_version {stored_event.originator_version}: {e}"
            )
            raise MapperDeserialisationError(msg) from e

        id_convertor = find_id_convertor(
            cls, cast(Hashable, type(stored_event.originator_id))
        )
        event_state["originator_id"] = id_convertor(stored_event.originator_id)
        event_state["originator_version"] = stored_event.originator_version

        # Upcast older stored state, one class version at a time. Each upcast
        # method mutates the state dict in place.
        class_version = getattr(cls, "class_version", 1)
        from_version = event_state.pop("class_version", 1)
        while from_version < class_version:
            getattr(cls, f"upcast_v{from_version}_v{from_version + 1}")(event_state)
            from_version += 1

        # Bypass __init__: restore the decoded state directly into __dict__.
        domain_event = object.__new__(cls)
        domain_event.__dict__.update(event_state)
        return domain_event
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@lru_cache
def find_id_convertor(
    domain_event_cls: type[object], originator_id_cls: type[UUID | str]
) -> Callable[[UUID | str], UUID | str]:
    """Return a function that converts stored originator IDs for an event class.

    How the expected originator ID type is found:

    * For subclasses of :class:`HasOriginatorIDVersion`, it is the class's
      ``originator_id_type`` attribute.
    * Otherwise, it comes from the first ``originator_id`` annotation found
      along the class's MRO. The annotation must be ``str``, ``UUID``, or one
      of the string forms ``"str"``, ``"UUID"`` or ``"uuid.UUID"``.

    If the stored ID class is ``str`` and the expected type is ``UUID``, the
    returned convertor parses the string into a UUID. Otherwise the stored ID
    is passed through unchanged. Results are cached per argument pair.

    Raises:
        TypeError: if ``originator_id_type`` is None, if an ``originator_id``
            annotation is neither UUID nor str, or if no such annotation is
            found on the class or its bases.
    """
    # Accepted 'originator_id' annotations, mapped to the class they denote.
    # String forms occur when annotation evaluation is postponed (PEP 563).
    valid_annotations = {
        str: str,
        UUID: UUID,
        "str": str,
        "UUID": UUID,
        "uuid.UUID": UUID,
    }

    if issubclass(domain_event_cls, HasOriginatorIDVersion):
        # For classes that inherit CanMutateAggregate, and don't use a different
        # mapper, then assume they aren't overriding __init_subclass__ in a way
        # that prevents 'originator_id_type' being found from type arguments and
        # set on the class.
        # TODO: Write a test where a custom class does override __init_subclass__
        #  so that the next line will cause an AssertionError. Then fix this code.
        if domain_event_cls.originator_id_type is None:
            msg = "originator_id_type cannot be None"
            raise TypeError(msg)
        originator_id_type = domain_event_cls.originator_id_type
    else:
        # Otherwise use the first 'originator_id' annotation along the MRO.
        for cls in domain_event_cls.__mro__:
            try:
                annotation = cls.__annotations__["originator_id"]
            except (KeyError, AttributeError):  # noqa: PERF203
                continue
            if annotation not in valid_annotations:
                msg = f"originator_id annotation on {cls} is not either UUID or str"
                raise TypeError(msg)
            originator_id_type = valid_annotations[annotation]
            break
        else:
            msg = (
                f"Neither event class {domain_event_cls} "
                "nor its bases have an originator_id annotation"
            )
            raise TypeError(msg)

    # String IDs must be parsed when the event class expects UUIDs.
    if originator_id_cls is str and originator_id_type is UUID:
        return str_to_uuid_convertor
    return pass_through_convertor
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def str_to_uuid_convertor(originator_id: UUID | str) -> UUID | str:
    """Parse a string originator ID into a :class:`~uuid.UUID`."""
    assert isinstance(originator_id, str)
    return UUID(hex=originator_id)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def pass_through_convertor(originator_id: UUID | str) -> UUID | str:
    """Identity convertor: return the stored originator ID unchanged."""
    unchanged = originator_id
    return unchanged
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class RecordConflictError(EventSourcingError):
    """Legacy exception, superseded by :class:`IntegrityError`.

    :class:`IntegrityError` subclasses this class, so existing handlers for
    this exception still catch integrity errors.
    """
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class PersistenceError(EventSourcingError):
    """Base class for the persistence exceptions in this module.

    Examples include :class:`InterfaceError`, :class:`DatabaseError` and
    :class:`WaitInterruptedError`. Exception class names follow PEP 249:
    https://peps.python.org/pep-0249/#exceptions
    """
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
class InterfaceError(PersistenceError):
    """Error related to the database interface, rather than to the
    database itself.
    """
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
class DatabaseError(PersistenceError):
    """Error related to the database; base of the more specific database errors."""
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
class DataError(DatabaseError):
    """Error caused by problems with the processed data, for example a
    division by zero or a numeric value out of range.
    """
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
class OperationalError(DatabaseError):
    """Error related to the database's operation and not necessarily under
    the programmer's control.

    Examples: an unexpected disconnect, a data source name that cannot be
    found, a transaction that could not be processed, or a memory
    allocation error during processing.
    """
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
class IntegrityError(DatabaseError, RecordConflictError):
    """Error raised when the relational integrity of the database is
    affected, for example when a foreign key check fails.

    Also a :class:`RecordConflictError`, for compatibility with code written
    against that legacy exception.
    """
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class InternalError(DatabaseError):
    """Error raised when the database hits an internal problem, for example
    a cursor that is no longer valid or a transaction that is out of sync.
    """
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
class ProgrammingError(DatabaseError):
    """Database programming error.

    Examples: a table that is missing or already exists, an SQL syntax
    error, or the wrong number of parameters.
    """
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class NotSupportedError(DatabaseError):
    """Error raised when a method or database API is used that the database
    does not support.

    Example: calling ``rollback()`` on a connection that does not support
    transactions, or that has transactions turned off.
    """
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
class WaitInterruptedError(PersistenceError):
    """Raised when waiting for a tracking record is interrupted.

    For example, :meth:`TrackingRecorder.wait` raises it when its
    ``interrupt`` event is set before the timeout is reached.
    """
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class Recorder:
    """Common base class of all recorder types."""
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class AggregateRecorder(Recorder, ABC):
    """Abstract interface for writing and reading stored events."""

    @abstractmethod
    def insert_events(
        self, stored_events: Sequence[StoredEvent], **kwargs: Any
    ) -> Sequence[int] | None:
        """Write the given stored events into the database.

        Implementations may return the notification IDs assigned to the
        events, or None.
        """

    @abstractmethod
    def select_events(
        self,
        originator_id: UUID | str,
        *,
        gt: int | None = None,
        lte: int | None = None,
        desc: bool = False,
        limit: int | None = None,
    ) -> Sequence[StoredEvent]:
        """Read the stored events of the given originator from the database.

        The keyword arguments presumably select a range of originator
        versions (``gt``/``lte``), the ordering (``desc``) and a maximum
        number of results (``limit``).
        """
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@dataclass(frozen=True)
class Notification(StoredEvent):
    """Immutable stored event together with its position in an application
    sequence.
    """

    #: Position in an application sequence.
    id: int
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
class ApplicationRecorder(AggregateRecorder):
    """Abstract interface for recorders that record events in both an
    aggregate sequence and an application sequence.
    """

    @abstractmethod
    def select_notifications(
        self,
        start: int | None,
        limit: int,
        stop: int | None = None,
        topics: Sequence[str] = (),
        *,
        inclusive_of_start: bool = True,
    ) -> Sequence[Notification]:
        """Return Notification objects for events in an application sequence.

        When ``inclusive_of_start`` is True (the default), the returned
        notifications have IDs greater than or equal to ``start`` and less
        than or equal to ``stop``. When it is False, their IDs are greater
        than ``start`` and less than or equal to ``stop``.
        """

    @abstractmethod
    def max_notification_id(self) -> int | None:
        """Return the largest notification ID in an application sequence,
        or None when no stored events have been recorded.
        """

    @abstractmethod
    def subscribe(
        self, gt: int | None = None, topics: Sequence[str] = ()
    ) -> Subscription[ApplicationRecorder]:
        """Return an iterator of Notification objects for events in an
        application sequence.

        After yielding the last recorded event the iterator blocks, and it
        resumes yielding as new events are recorded.

        Yielded notifications have IDs greater than the optional ``gt``
        argument.
        """
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
class TrackingRecorder(Recorder, ABC):
    """Abstract base class for recorders that record tracking
    objects atomically with other state.
    """

    @abstractmethod
    def insert_tracking(self, tracking: Tracking) -> None:
        """Records a tracking object."""

    @abstractmethod
    def max_tracking_id(self, application_name: str) -> int | None:
        """Returns the largest notification ID across all recorded tracking objects
        for the named application, or None if no tracking objects have been recorded.
        """

    def has_tracking_id(
        self, application_name: str, notification_id: int | None
    ) -> bool:
        """Returns True if the given notification_id is None, or if a tracking
        object has been recorded with the given application_name and a
        notification ID greater than or equal to the given notification_id.
        """
        if notification_id is None:
            return True
        max_tracking_id = self.max_tracking_id(application_name)
        return max_tracking_id is not None and max_tracking_id >= notification_id

    def wait(
        self,
        application_name: str,
        notification_id: int | None,
        timeout: float = 1.0,
        interrupt: Event | None = None,
    ) -> None:
        """Block until a tracking object with the given application name and a
        notification ID greater than or equal to the given value has been
        recorded.

        Polls has_tracking_id() with exponential backoff until the timeout is
        reached, or until the optional interrupt event is set. The backoff
        starts at 0.1 seconds and doubles up to 0.8 seconds. Each sleep is
        clamped to the time remaining before the deadline, and the tracking
        record is checked once more at the deadline before timing out.

        The timeout argument should be a floating point number specifying a
        timeout for the operation in seconds (or fractions thereof). The default
        is 1.0 seconds.

        Raises TimeoutError if the timeout is reached.

        Raises WaitInterruptedError if the `interrupt` is set before `timeout`
        is reached.
        """
        deadline = monotonic() + timeout
        sleep_interval = 0.1  # seconds
        max_sleep_interval = 0.8  # seconds
        while True:
            if self.has_tracking_id(application_name, notification_id):
                return
            remaining = deadline - monotonic()
            if remaining <= 0:
                msg = (
                    f"Timed out waiting for notification {notification_id} "
                    f"from application '{application_name}' to be processed"
                )
                raise TimeoutError(msg)
            # Never sleep past the deadline. The loop then re-checks once
            # before reporting a timeout, so a record that arrives during the
            # final sleep is still found.
            delay = min(sleep_interval, remaining)
            if interrupt is not None:
                if interrupt.wait(timeout=delay):
                    raise WaitInterruptedError
            else:
                sleep(delay)
            sleep_interval = min(sleep_interval * 2, max_sleep_interval)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
class ProcessRecorder(TrackingRecorder, ApplicationRecorder, ABC):
    """Abstract recorder combining application and tracking recording.

    Concrete subclasses implement the abstract methods of both
    :class:`TrackingRecorder` and :class:`ApplicationRecorder`.
    """
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
@dataclass(frozen=True)
class Recording(Generic[TAggregateID]):
    """Immutable record pairing a domain event with its notification."""

    #: The domain event that has been recorded.
    domain_event: DomainEventProtocol[TAggregateID]
    #: A Notification that represents the domain event in the application
    #: sequence.
    notification: Notification
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
class EventStore(Generic[TAggregateID]):
    """Stores and retrieves domain events using a mapper and a recorder."""

    def __init__(
        self,
        mapper: Mapper[TAggregateID],
        recorder: AggregateRecorder,
    ):
        self.mapper: Mapper[TAggregateID] = mapper
        self.recorder = recorder

    def put(
        self, domain_events: Sequence[DomainEventProtocol[TAggregateID]], **kwargs: Any
    ) -> list[Recording[TAggregateID]]:
        """Store domain events in their aggregate sequences.

        Extra keyword arguments are passed to the recorder's
        ``insert_events()``. When the recorder returns notification IDs, one
        :class:`Recording` is returned per event. Otherwise the returned list
        is empty.
        """
        to_stored = self.mapper.to_stored_event
        stored_events = [to_stored(event) for event in domain_events]
        notification_ids = self.recorder.insert_events(stored_events, **kwargs)
        if not notification_ids:
            return []
        assert len(notification_ids) == len(stored_events)
        return [
            Recording(
                event,
                Notification(
                    originator_id=stored.originator_id,
                    originator_version=stored.originator_version,
                    topic=stored.topic,
                    state=stored.state,
                    id=notification_id,
                ),
            )
            for event, stored, notification_id in zip(
                domain_events, stored_events, notification_ids, strict=True
            )
        ]

    def get(
        self,
        originator_id: TAggregateID,
        *,
        gt: int | None = None,
        lte: int | None = None,
        desc: bool = False,
        limit: int | None = None,
    ) -> Iterator[DomainEventProtocol[TAggregateID]]:
        """Retrieve domain events from an aggregate sequence.

        Stored events are converted into domain events lazily, as the
        returned iterator is consumed.
        """
        stored_events = self.recorder.select_events(
            originator_id=originator_id,
            gt=gt,
            lte=lte,
            desc=desc,
            limit=limit,
        )
        return map(self.mapper.to_domain_event, stored_events)
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
# Type variable for tracking recorder classes. It is bound to, and defaults
# to, TrackingRecorder.
TTrackingRecorder = TypeVar(
    "TTrackingRecorder",
    bound=TrackingRecorder,
    default=TrackingRecorder,
)
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
class InfrastructureFactoryError(EventSourcingError):
    """Signals that an infrastructure factory could not be resolved or
    constructed from the configured persistence topic.
    """
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
class BaseInfrastructureFactory(ABC, Generic[TTrackingRecorder]):
|
|
667
|
+
"""Abstract base class for infrastructure factories."""
|
|
668
|
+
|
|
669
|
+
PERSISTENCE_MODULE = "PERSISTENCE_MODULE"
|
|
670
|
+
TRANSCODER_TOPIC = "TRANSCODER_TOPIC"
|
|
671
|
+
CIPHER_TOPIC = "CIPHER_TOPIC"
|
|
672
|
+
COMPRESSOR_TOPIC = "COMPRESSOR_TOPIC"
|
|
673
|
+
|
|
674
|
+
def __init__(self, env: Environment | EnvType | None):
|
|
675
|
+
"""Initialises infrastructure factory object with given application name."""
|
|
676
|
+
self.env = env if isinstance(env, Environment) else Environment(env=env)
|
|
677
|
+
self._is_entered = False
|
|
678
|
+
|
|
679
|
+
def __enter__(self) -> Self:
|
|
680
|
+
self._is_entered = True
|
|
681
|
+
return self
|
|
682
|
+
|
|
683
|
+
def __exit__(
|
|
684
|
+
self,
|
|
685
|
+
exc_type: type[BaseException] | None,
|
|
686
|
+
exc_val: BaseException | None,
|
|
687
|
+
exc_tb: TracebackType | None,
|
|
688
|
+
) -> None:
|
|
689
|
+
self._is_entered = False
|
|
690
|
+
|
|
691
|
+
def close(self) -> None:
|
|
692
|
+
"""Closes any database connections, and anything else that needs closing."""
|
|
693
|
+
|
|
694
|
+
@classmethod
|
|
695
|
+
def construct(
|
|
696
|
+
cls: type[Self],
|
|
697
|
+
env: Environment | None = None,
|
|
698
|
+
) -> Self:
|
|
699
|
+
"""Constructs concrete infrastructure factory for given
|
|
700
|
+
named application. Reads and resolves persistence
|
|
701
|
+
topic from environment variable 'PERSISTENCE_MODULE'.
|
|
702
|
+
"""
|
|
703
|
+
factory_cls: type[Self]
|
|
704
|
+
if env is None:
|
|
705
|
+
env = Environment()
|
|
706
|
+
topic = (
|
|
707
|
+
env.get(
|
|
708
|
+
cls.PERSISTENCE_MODULE,
|
|
709
|
+
"",
|
|
710
|
+
)
|
|
711
|
+
or env.get(
|
|
712
|
+
"INFRASTRUCTURE_FACTORY", # Legacy.
|
|
713
|
+
"",
|
|
714
|
+
)
|
|
715
|
+
or env.get(
|
|
716
|
+
"FACTORY_TOPIC", # Legacy.
|
|
717
|
+
"",
|
|
718
|
+
)
|
|
719
|
+
or "eventsourcing.popo"
|
|
720
|
+
)
|
|
721
|
+
try:
|
|
722
|
+
obj: type[Self] | ModuleType = resolve_topic(topic)
|
|
723
|
+
except TopicError as e:
|
|
724
|
+
msg = (
|
|
725
|
+
"Failed to resolve persistence module topic: "
|
|
726
|
+
f"'{topic}' from environment "
|
|
727
|
+
f"variable '{cls.PERSISTENCE_MODULE}'"
|
|
728
|
+
)
|
|
729
|
+
raise InfrastructureFactoryError(msg) from e
|
|
730
|
+
|
|
731
|
+
if isinstance(obj, ModuleType):
|
|
732
|
+
# Find the factory in the module.
|
|
733
|
+
factory_classes = set[type[Self]]()
|
|
734
|
+
for member in obj.__dict__.values():
|
|
735
|
+
# Look for classes...
|
|
736
|
+
if not isinstance(member, type):
|
|
737
|
+
continue
|
|
738
|
+
# Issue with Python 3.9 and 3.10.
|
|
739
|
+
if isinstance(member, GenericAlias):
|
|
740
|
+
continue # pragma: no cover (for Python > 3.10 only)
|
|
741
|
+
if not issubclass(member, cls):
|
|
742
|
+
continue
|
|
743
|
+
if getattr(member, "__parameters__", None):
|
|
744
|
+
continue
|
|
745
|
+
factory_classes.add(member)
|
|
746
|
+
|
|
747
|
+
if len(factory_classes) == 1:
|
|
748
|
+
factory_cls = next(iter(factory_classes))
|
|
749
|
+
else:
|
|
750
|
+
msg = (
|
|
751
|
+
f"Found {len(factory_classes)} infrastructure factory classes in"
|
|
752
|
+
f" '{topic}', expected 1."
|
|
753
|
+
)
|
|
754
|
+
raise InfrastructureFactoryError(msg)
|
|
755
|
+
elif isinstance(obj, type) and issubclass(obj, cls):
|
|
756
|
+
factory_cls = obj
|
|
757
|
+
else:
|
|
758
|
+
msg = (
|
|
759
|
+
f"Topic '{topic}' didn't resolve to a persistence module "
|
|
760
|
+
f"or infrastructure factory class: {obj}"
|
|
761
|
+
)
|
|
762
|
+
raise InfrastructureFactoryError(msg)
|
|
763
|
+
return factory_cls(env=env)
|
|
764
|
+
|
|
765
|
+
def transcoder(
|
|
766
|
+
self,
|
|
767
|
+
) -> Transcoder:
|
|
768
|
+
"""Constructs a transcoder."""
|
|
769
|
+
transcoder_topic = self.env.get(self.TRANSCODER_TOPIC)
|
|
770
|
+
if transcoder_topic:
|
|
771
|
+
transcoder_class: type[Transcoder] = resolve_topic(transcoder_topic)
|
|
772
|
+
else:
|
|
773
|
+
transcoder_class = JSONTranscoder
|
|
774
|
+
return transcoder_class()
|
|
775
|
+
|
|
776
|
+
def cipher(self) -> Cipher | None:
    """Constructs a cipher, if one is configured.

    The cipher class is resolved from the topic held in the environment
    variable named by ``self.CIPHER_TOPIC``. If that is unset but 'CIPHER_KEY'
    has a value, the topic "eventsourcing.cipher:AESCipher" is used.
    The resolved class is constructed with the factory's environment.
    Returns None when no topic can be determined.
    """
    topic = self.env.get(self.CIPHER_TOPIC)
    fallback_topic = "eventsourcing.cipher:AESCipher"
    # A key without an explicit topic implies the default cipher.
    if self.env.get("CIPHER_KEY") and not topic:
        topic = fallback_topic

    if not topic:
        return None

    cipher_cls: type[Cipher] = resolve_topic(topic)
    return cipher_cls(self.env)
|
|
792
|
+
|
|
793
|
+
def compressor(self) -> Compressor | None:
    """Constructs a compressor, if one is configured.

    The topic is read from the environment variable named by
    ``self.COMPRESSOR_TOPIC``. If it resolves to a class, that class is
    instantiated with no arguments. Otherwise the resolved object itself is
    used as the compressor. Returns None when no topic is set.
    """
    topic = self.env.get(self.COMPRESSOR_TOPIC)
    if not topic:
        return None
    resolved: type[Compressor] | Compressor = resolve_topic(topic)
    return resolved() if isinstance(resolved, type) else resolved
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
class InfrastructureFactory(BaseInfrastructureFactory[TTrackingRecorder]):
    """Abstract base class for Application factories."""

    MAPPER_TOPIC = "MAPPER_TOPIC"
    IS_SNAPSHOTTING_ENABLED = "IS_SNAPSHOTTING_ENABLED"
    APPLICATION_RECORDER_TOPIC = "APPLICATION_RECORDER_TOPIC"
    TRACKING_RECORDER_TOPIC = "TRACKING_RECORDER_TOPIC"
    PROCESS_RECORDER_TOPIC = "PROCESS_RECORDER_TOPIC"

    def mapper(
        self,
        transcoder: Transcoder | None = None,
        mapper_class: type[Mapper[TAggregateID]] | None = None,
    ) -> Mapper[TAggregateID]:
        """Constructs a mapper.

        If 'mapper_class' is None, the class is resolved from the
        'MAPPER_TOPIC' environment variable, or defaults to ``Mapper``.
        If 'transcoder' is None, one is constructed by ``self.transcoder()``.
        The mapper is also given this factory's cipher and compressor.
        """
        # Resolve MAPPER_TOPIC if no given class.
        if mapper_class is None:
            mapper_topic = self.env.get(self.MAPPER_TOPIC)
            mapper_class = (
                resolve_topic(mapper_topic) if mapper_topic else Mapper[TAggregateID]
            )

        # Check we have a mapper class (possibly a parameterised alias).
        assert mapper_class is not None
        origin_mapper_class = typing.get_origin(mapper_class) or mapper_class
        assert isinstance(origin_mapper_class, type), mapper_class
        assert issubclass(origin_mapper_class, Mapper), mapper_class

        # Use an explicit None check so a supplied transcoder is never
        # replaced just because it happens to evaluate as falsy.
        if transcoder is None:
            transcoder = self.transcoder()

        # Construct and return a mapper.
        return mapper_class(
            transcoder=transcoder,
            cipher=self.cipher(),
            compressor=self.compressor(),
        )

    def event_store(
        self,
        mapper: Mapper[TAggregateID] | None = None,
        recorder: AggregateRecorder | None = None,
    ) -> EventStore[TAggregateID]:
        """Constructs an event store.

        Missing arguments default to ``self.mapper()`` and
        ``self.application_recorder()`` respectively.
        """
        if mapper is None:
            mapper = self.mapper()
        if recorder is None:
            recorder = self.application_recorder()
        return EventStore(mapper=mapper, recorder=recorder)

    @abstractmethod
    def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
        """Constructs an aggregate recorder."""

    @abstractmethod
    def application_recorder(self) -> ApplicationRecorder:
        """Constructs an application recorder."""

    @abstractmethod
    def tracking_recorder(
        self, tracking_recorder_class: type[TTrackingRecorder] | None = None
    ) -> TTrackingRecorder:
        """Constructs a tracking recorder."""

    @abstractmethod
    def process_recorder(self) -> ProcessRecorder:
        """Constructs a process recorder."""

    def is_snapshotting_enabled(self) -> bool:
        """Decides whether or not snapshotting is enabled by
        reading environment variable 'IS_SNAPSHOTTING_ENABLED'.
        Snapshotting is not enabled by default.
        """
        return strtobool(self.env.get(self.IS_SNAPSHOTTING_ENABLED, "no"))
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
@dataclass(eq=True, frozen=True)
class Tracking:
    """Immutable position of a domain event :class:`Notification`
    within the notification log of the named application.
    """

    # Name of the application whose notification log is referred to.
    application_name: str
    # ID of the notification within that log.
    notification_id: int
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
Params = Sequence[Any] | Mapping[str, Any]
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
class Cursor(ABC):
|
|
896
|
+
@abstractmethod
|
|
897
|
+
def execute(self, statement: str | bytes, params: Params | None = None) -> None:
|
|
898
|
+
"""Executes given statement."""
|
|
899
|
+
|
|
900
|
+
@abstractmethod
|
|
901
|
+
def fetchall(self) -> Any:
|
|
902
|
+
"""Fetches all results."""
|
|
903
|
+
|
|
904
|
+
@abstractmethod
|
|
905
|
+
def fetchone(self) -> Any:
|
|
906
|
+
"""Fetches one result."""
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
TCursor = TypeVar("TCursor", bound=Cursor)
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
class Connection(ABC, Generic[TCursor]):
    """Abstract base for pooled database connections.

    A new connection starts out locked "in use". When 'max_age' is given,
    a daemon timer marks the connection as closing after that many seconds,
    then closes it once it is no longer in use.
    """

    def __init__(self, max_age: float | None = None) -> None:
        self._closed = False
        self._closing = Event()
        self._close_lock = Lock()
        # Locked from the start, so the max-age timer cannot close the
        # connection before whoever created it has finished using it.
        self.in_use = Lock()
        self.in_use.acquire()
        self._max_age_timer: Timer | None = None
        if max_age is not None:
            timer = Timer(interval=max_age, function=self._close_when_not_in_use)
            timer.daemon = True
            # Assign before starting: the callback reads _max_age_timer.
            self._max_age_timer = timer
            timer.start()
        self.is_writer: bool | None = None

    @property
    def closed(self) -> bool:
        return self._closed

    @property
    def closing(self) -> bool:
        return self._closing.is_set()

    @abstractmethod
    def commit(self) -> None:
        """Commits transaction."""

    @abstractmethod
    def rollback(self) -> None:
        """Rolls back transaction."""

    @abstractmethod
    def cursor(self) -> TCursor:
        """Creates new cursor."""

    def close(self) -> None:
        """Closes the connection, serialised by a private lock."""
        with self._close_lock:
            self._close()

    @abstractmethod
    def _close(self) -> None:
        """Marks the connection closed and cancels any max-age timer.

        Subclasses should close the underlying connection and call this.
        """
        self._closed = True
        timer = self._max_age_timer
        if timer is not None:
            timer.cancel()

    def _close_when_not_in_use(self) -> None:
        """Timer callback: flag closing, then close once no longer in use."""
        self._closing.set()
        with self.in_use:
            if self._closed:
                return
            self.close()


TConnection = TypeVar("TConnection", bound=Connection[Any])
|
|
968
|
+
|
|
969
|
+
|
|
970
|
+
class ConnectionPoolClosedError(EventSourcingError):
    """Raised on attempts to use a connection pool after it has been closed."""


class ConnectionNotFromPoolError(EventSourcingError):
    """Raised when a connection is returned to a pool that did not issue it."""


class ConnectionUnavailableError(OperationalError, TimeoutError):
    """Raised when a connection pool cannot issue a connection
    before the request's timeout expires.
    """
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
class ConnectionPool(ABC, Generic[TConnection]):
    """Abstract base class for pools of database connections.

    Connections are issued by get_connection() and returned by
    put_connection(). Subclasses create connections by implementing
    _create_connection().
    """

    def __init__(
        self,
        *,
        pool_size: int = 5,
        max_overflow: int = 10,
        pool_timeout: float = 30.0,
        max_age: float | None = None,
        pre_ping: bool = False,
        mutually_exclusive_read_write: bool = False,
    ) -> None:
        """Initialises a new connection pool.

        The 'pool_size' argument specifies the maximum number of connections
        that will be put into the pool when connections are returned. The
        default value is 5.

        The 'max_overflow' argument specifies the additional number of
        connections that can be issued by the pool, above the 'pool_size'.
        The default value is 10.

        The 'pool_timeout' argument specifies the maximum time in seconds
        to keep requests for connections waiting. Requests are kept
        waiting if the number of connections currently in use is not less
        than the sum of 'pool_size' and 'max_overflow'. The default value
        is 30.0.

        The 'max_age' argument specifies the time in seconds until a
        connection will automatically be closed. Connections are only closed
        in this way when they are not in use. A connection that is in use
        when its max age is reached is closed when it is returned. The
        default value is None, meaning connections will not be automatically
        closed in this way.

        The 'pre_ping' argument specifies whether a connection is checked
        by executing "SELECT 1" before being issued. Connections failing the
        check are closed and another connection is obtained. The default
        value is False.

        The 'mutually_exclusive_read_write' argument specifies whether
        requests for connections for writing will be kept waiting whilst
        connections for reading are in use. It also specifies whether
        requests for connections for reading will be kept waiting whilst a
        connection for writing is in use. The default value is False,
        meaning reading and writing will not be mutually exclusive in this
        way.
        """
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.pool_timeout = pool_timeout
        self.max_age = max_age
        self.pre_ping = pre_ping
        self._pool: deque[TConnection] = deque()
        self._in_use: dict[int, TConnection] = {}
        self._get_semaphore = Semaphore()
        self._put_condition = Condition()
        self._no_readers = Condition()
        self._num_readers: int = 0
        self._writer_lock = Lock()
        self._num_writers: int = 0
        self._mutually_exclusive_read_write = mutually_exclusive_read_write
        self._closed = False

    @property
    def closed(self) -> bool:
        return self._closed

    @property
    def num_in_use(self) -> int:
        """Indicates the total number of connections currently in use."""
        with self._put_condition:
            return self._num_in_use

    @property
    def _num_in_use(self) -> int:
        return len(self._in_use)

    @property
    def num_in_pool(self) -> int:
        """Indicates the number of connections currently in the pool."""
        with self._put_condition:
            return self._num_in_pool

    @property
    def _num_in_pool(self) -> int:
        return len(self._pool)

    @property
    def _is_pool_full(self) -> bool:
        return self._num_in_pool >= self.pool_size

    @property
    def _is_use_full(self) -> bool:
        return self._num_in_use >= self.pool_size + self.max_overflow

    def get_connection(
        self, timeout: float | None = None, *, is_writer: bool | None = None
    ) -> TConnection:
        """Issues connections, or raises ConnectionUnavailableError.

        Requests wait their turn on a semaphore. This aims to provide
        "fairness", meaning connections are issued in the order they are
        requested.

        The 'timeout' argument overrides the timeout specified
        by the constructor argument 'pool_timeout'. The default
        value is None, meaning the 'pool_timeout' argument will
        not be overridden.

        The optional 'is_writer' argument can be used to request
        a connection for writing (true), or a connection for
        reading (false). If the value of this argument is None,
        which is the default, the writing and reading interlocking
        mechanism is not activated. Only one connection for writing
        will be issued at a time. Requests for connections for writing
        are kept waiting whilst another connection for writing is in use.

        If reading and writing are mutually exclusive, requests for
        connections for writing are kept waiting whilst connections
        for reading are in use. Requests for connections for reading
        are kept waiting whilst a connection for writing is in use.

        Raises ConnectionPoolClosedError if the pool is closed.

        Raises ConnectionUnavailableError if a connection cannot be issued
        within the timeout. If no connection is issued, any writer lock or
        reader count taken by this call is released again.
        """
        # Make sure we aren't dealing with a closed pool.
        if self._closed:
            raise ConnectionPoolClosedError

        # Decide the timeout for getting a connection.
        timeout = self.pool_timeout if timeout is None else timeout

        # Remember when we started trying to get a connection.
        started = time()

        # Join queue of threads waiting to get a connection ("fairness").
        if self._get_semaphore.acquire(timeout=timeout):
            try:
                # If connection is for writing, get write lock and wait for no readers.
                if is_writer is True:
                    if not self._writer_lock.acquire(
                        timeout=self._time_remaining(timeout, started)
                    ):
                        msg = "Timed out waiting for return of writer"
                        raise ConnectionUnavailableError(msg)
                    if self._mutually_exclusive_read_write:
                        with self._no_readers:
                            if self._num_readers > 0 and not self._no_readers.wait(
                                timeout=self._time_remaining(timeout, started)
                            ):
                                self._writer_lock.release()
                                msg = "Timed out waiting for return of reader"
                                raise ConnectionUnavailableError(msg)
                    self._num_writers += 1

                # If connection is for reading, and writing excludes reading,
                # then wait for the writer lock, and increment number of readers.
                elif is_writer is False:
                    if self._mutually_exclusive_read_write:
                        if not self._writer_lock.acquire(
                            timeout=self._time_remaining(timeout, started)
                        ):
                            msg = "Timed out waiting for return of writer"
                            raise ConnectionUnavailableError(msg)
                        self._writer_lock.release()
                    with self._no_readers:
                        self._num_readers += 1

                # Actually try to get a connection within the time remaining.
                try:
                    conn = self._get_connection(
                        timeout=self._time_remaining(timeout, started)
                    )
                except BaseException:
                    # No connection will be issued (and so none will be put
                    # back), so undo the bookkeeping above. Otherwise the
                    # writer lock, or reader count, would leak forever.
                    self._release_read_write_claim(is_writer)
                    raise

                # Remember if this connection is for reading or writing.
                conn.is_writer = is_writer

                # Return the connection.
                return conn
            finally:
                self._get_semaphore.release()
        else:
            # Timed out waiting for semaphore.
            msg = "Timed out waiting for connection pool semaphore"
            raise ConnectionUnavailableError(msg)

    def _get_connection(self, timeout: float = 0.0) -> TConnection:
        """Gets or creates connections from pool within given
        time, otherwise raises ConnectionUnavailableError.

        Waits for connections to be returned if the pool
        is fully used. And optionally ensures a connection
        is usable before returning a connection for use.

        Tracks which connections are in use.
        """
        started = time()
        # Get lock on tracking usage of connections.
        with self._put_condition:
            # Try to get a connection from the pool.
            try:
                conn = self._pool.popleft()
            except IndexError:
                # Pool is empty, but are connections fully used?
                if self._is_use_full:
                    # Fully used, so wait for a connection to be returned.
                    if self._put_condition.wait(
                        timeout=self._time_remaining(timeout, started)
                    ):
                        # Connection has been returned, so try again.
                        return self._get_connection(
                            timeout=self._time_remaining(timeout, started)
                        )
                    # Timed out waiting for a connection to be returned.
                    msg = "Timed out waiting for return of connection"
                    raise ConnectionUnavailableError(msg) from None
                # Not fully used, so create a new connection.
                conn = self._create_connection()

                # Connection should be pre-locked for use (avoids timer race).
                assert conn.in_use.locked()

            else:
                # Got unused connection from pool, so lock for use.
                conn.in_use.acquire()

                # Check the connection wasn't closed by the timer.
                if conn.closed:
                    return self._get_connection(
                        timeout=self._time_remaining(timeout, started)
                    )

            # Check the connection is actually usable.
            if self.pre_ping:
                try:
                    conn.cursor().execute("SELECT 1")
                except Exception:
                    # Probably connection is closed on server,
                    # but just try to make sure it is closed.
                    conn.close()

                    # Try again to get a connection.
                    return self._get_connection(
                        timeout=self._time_remaining(timeout, started)
                    )

            # Track the connection is now being used.
            self._in_use[id(conn)] = conn

            # Return the connection.
            return conn

    def put_connection(self, conn: TConnection) -> None:
        """Returns connections to the pool, or closes connection
        if the pool is full.

        Unlocks write lock after writer has returned, and
        updates count of readers when readers are returned.

        Notifies waiters when connections have been returned,
        and when there are no longer any readers.

        Raises ConnectionPoolClosedError if the pool is closed, and
        ConnectionNotFromPoolError if the connection is not in use
        in this pool.
        """
        # Start forgetting if this connection was for reading or writing.
        is_writer, conn.is_writer = conn.is_writer, None

        # Get a lock on tracking usage of connections.
        with self._put_condition:
            # Make sure we aren't dealing with a closed pool
            if self._closed:
                msg = "Pool is closed"
                raise ConnectionPoolClosedError(msg)

            # Make sure we are dealing with a connection from this pool.
            try:
                del self._in_use[id(conn)]
            except KeyError:
                msg = "Connection not in use in this pool"
                raise ConnectionNotFromPoolError(msg) from None

            if not conn.closed:
                # Put open connection in pool if not full and not expiring.
                if not conn.closing and not self._is_pool_full:
                    self._pool.append(conn)
                # Otherwise, close the connection (pool is full or timer fired).
                else:
                    conn.close()

            # Unlock the connection for subsequent use (and for closing by the timer).
            conn.in_use.release()

            # Release the writer lock, or decrement the number of readers.
            self._release_read_write_claim(is_writer)

            # Notify a thread that is waiting for a connection to be returned.
            self._put_condition.notify()

    def _release_read_write_claim(self, is_writer: bool | None) -> None:
        """Undoes the writer lock / reader count taken by get_connection()."""
        # If the connection was for writing, unlock the writer lock.
        if is_writer is True:
            self._num_writers -= 1
            self._writer_lock.release()

        # Or if it was for reading, decrement the number of readers.
        elif is_writer is False:
            with self._no_readers:
                self._num_readers -= 1
                if self._num_readers == 0 and self._mutually_exclusive_read_write:
                    self._no_readers.notify()

    @abstractmethod
    def _create_connection(self) -> TConnection:
        """Create a new connection.

        Subclasses should implement this method by
        creating a database connection of the type
        being pooled.
        """

    def close(self) -> None:
        """Close the connection pool.

        Closes connections that are in use and connections in the pool.
        Closing an already closed pool does nothing.
        """
        with self._put_condition:
            if self._closed:
                return
            for conn in self._in_use.values():
                conn.close()
            while True:
                try:
                    conn = self._pool.popleft()
                except IndexError:  # noqa: PERF203
                    break
                else:
                    conn.close()
            self._closed = True

    @staticmethod
    def _time_remaining(timeout: float, started: float) -> float:
        """Seconds left of 'timeout' since 'started', never negative."""
        return max(0.0, timeout + started - time())

    def __del__(self) -> None:
        self.close()
|
|
1310
|
+
|
|
1311
|
+
|
|
1312
|
+
TApplicationRecorder_co = TypeVar(
    "TApplicationRecorder_co",
    bound=ApplicationRecorder,
    covariant=True,
)


class Subscription(Iterator[Notification], Generic[TApplicationRecorder_co]):
    """Iterator of an application recorder's notifications.

    Can be used as a context manager, which may be entered only once.
    Exiting the context stops the subscription.
    """

    def __init__(
        self,
        recorder: TApplicationRecorder_co,
        gt: int | None = None,
        topics: Sequence[str] = (),
    ) -> None:
        self._recorder = recorder
        # Notifications after this ID are wanted (None means from the start).
        self._last_notification_id = gt
        self._topics = topics
        self._has_been_entered = False
        self._has_been_stopped = False

    def __enter__(self) -> Self:
        if not self._has_been_entered:
            self._has_been_entered = True
            return self
        msg = "Already entered subscription context manager"
        raise ProgrammingError(msg)

    def __exit__(self, *args: object, **kwargs: Any) -> None:
        if self._has_been_entered:
            self.stop()
            return
        msg = "Not already entered subscription context manager"
        raise ProgrammingError(msg)

    def stop(self) -> None:
        """Stops the subscription."""
        self._has_been_stopped = True

    def __iter__(self) -> Self:
        return self

    @abstractmethod
    def __next__(self) -> Notification:
        """Returns the next Notification object in the application sequence."""
|
|
1353
|
+
|
|
1354
|
+
|
|
1355
|
+
class ListenNotifySubscription(Subscription[TApplicationRecorder_co]):
    """Subscription that pulls notifications on a background thread.

    The thread selects already recorded notifications, then selects again
    each time the ``_has_been_notified`` event is set. Batches are passed
    to the iterating thread through a bounded queue.
    """

    def __init__(
        self,
        recorder: TApplicationRecorder_co,
        gt: int | None = None,
        topics: Sequence[str] = (),
    ) -> None:
        super().__init__(recorder=recorder, gt=gt, topics=topics)
        self._select_limit = 500
        self._notifications: Sequence[Notification] = []
        self._notifications_index: int = 0
        self._notifications_queue: Queue[Sequence[Notification]] = Queue(maxsize=10)
        self._has_been_notified = Event()
        self._thread_error: BaseException | None = None
        self._pull_thread = Thread(target=self._loop_on_pull)
        self._pull_thread.start()

    def __exit__(self, *args: object, **kwargs: Any) -> None:
        try:
            super().__exit__(*args, **kwargs)
        finally:
            # Only join once stopped: an unstopped pull thread keeps waiting
            # for notifications, so the join would never return.
            if self._has_been_stopped:
                self._pull_thread.join()

    def stop(self) -> None:
        """Stops the subscription.

        Queued batches are discarded, since nothing is returned by __next__
        once stopped. This frees space so that neither this method nor a
        pull thread blocked on a full queue can hang. An empty batch is then
        queued, to wake an iterating thread waiting for the next batch.
        """
        from queue import Empty, Full

        super().stop()
        while True:
            try:
                self._notifications_queue.get_nowait()
            except Empty:
                break
        try:
            self._notifications_queue.put_nowait([])
        except Full:
            # Refilled concurrently, so no iterating thread is blocked on get().
            pass
        self._has_been_notified.set()

    def __next__(self) -> Notification:
        # If necessary, get a new list of notifications from the recorder.
        if (
            self._notifications_index == len(self._notifications)
            and not self._has_been_stopped
        ):
            self._notifications = self._notifications_queue.get()
            self._notifications_index = 0

        # Stop the iteration if necessary, maybe raise thread error.
        if self._has_been_stopped or not self._notifications:
            if self._thread_error is not None:
                raise self._thread_error
            raise StopIteration

        # Return a notification from previously obtained list.
        notification = self._notifications[self._notifications_index]
        self._notifications_index += 1
        return notification

    def _loop_on_pull(self) -> None:
        """Pull thread: pulls existing, then newly notified, notifications.

        Any error is recorded (first one wins) and the subscription stopped,
        so that the iterating thread can re-raise it.
        """
        try:
            self._pull()  # Already recorded events.
            while not self._has_been_stopped:
                self._has_been_notified.wait()
                self._pull()  # Newly recorded events.
        except BaseException as e:
            if self._thread_error is None:
                self._thread_error = e
            self.stop()

    def _pull(self) -> None:
        """Selects notifications in batches until a short batch is returned."""
        while not self._has_been_stopped:
            self._has_been_notified.clear()
            notifications = self._recorder.select_notifications(
                start=self._last_notification_id or 0,
                limit=self._select_limit,
                topics=self._topics,
                inclusive_of_start=False,
            )
            if len(notifications) > 0:
                self._notifications_queue.put(notifications)
                self._last_notification_id = notifications[-1].id
            if len(notifications) < self._select_limit:
                break