eventsourcing 9.5.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1429 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import typing
5
+ from abc import ABC, abstractmethod
6
+ from collections import deque
7
+ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+ from decimal import Decimal
11
+ from functools import lru_cache
12
+ from queue import Queue
13
+ from threading import Condition, Event, Lock, Semaphore, Thread, Timer
14
+ from time import monotonic, sleep, time
15
+ from types import GenericAlias, ModuleType, TracebackType
16
+ from typing import Any, Generic, cast
17
+ from uuid import UUID
18
+
19
+ from typing_extensions import Self, TypeVar
20
+
21
+ from eventsourcing.domain import (
22
+ DomainEventProtocol,
23
+ EventSourcingError,
24
+ HasOriginatorIDVersion,
25
+ TAggregateID,
26
+ )
27
+ from eventsourcing.utils import (
28
+ Environment,
29
+ EnvType,
30
+ TopicError,
31
+ get_topic,
32
+ resolve_topic,
33
+ strtobool,
34
+ )
35
+
36
+
37
+ class Transcoding(ABC):
38
+ """Abstract base class for custom transcodings."""
39
+
40
+ type: type
41
+ name: str
42
+
43
+ @abstractmethod
44
+ def encode(self, obj: Any) -> Any:
45
+ """Encodes given object."""
46
+
47
+ @abstractmethod
48
+ def decode(self, data: Any) -> Any:
49
+ """Decodes encoded object."""
50
+
51
+
52
+ class Transcoder(ABC):
53
+ """Abstract base class for transcoders."""
54
+
55
+ @abstractmethod
56
+ def encode(self, obj: Any) -> bytes:
57
+ """Encodes obj as bytes."""
58
+
59
+ @abstractmethod
60
+ def decode(self, data: bytes) -> Any:
61
+ """Decodes obj from bytes."""
62
+
63
+
64
+ class TranscodingNotRegisteredError(EventSourcingError, TypeError):
65
+ """Raised when a transcoding isn't registered with JSONTranscoder."""
66
+
67
+
68
+ class JSONTranscoder(Transcoder):
69
+ """Extensible transcoder that uses the Python :mod:`json` module."""
70
+
71
+ def __init__(self) -> None:
72
+ self.types: dict[type, Transcoding] = {}
73
+ self.names: dict[str, Transcoding] = {}
74
+ self.encoder = json.JSONEncoder(
75
+ default=self._encode_obj,
76
+ separators=(",", ":"),
77
+ ensure_ascii=False,
78
+ )
79
+ self.decoder = json.JSONDecoder(object_hook=self._decode_obj)
80
+
81
+ def register(self, transcoding: Transcoding) -> None:
82
+ """Registers given transcoding with the transcoder."""
83
+ self.types[transcoding.type] = transcoding
84
+ self.names[transcoding.name] = transcoding
85
+
86
+ def encode(self, obj: Any) -> bytes:
87
+ """Encodes given object as a bytes array."""
88
+ return self.encoder.encode(obj).encode("utf8")
89
+
90
+ def decode(self, data: bytes) -> Any:
91
+ """Decodes bytes array as previously encoded object."""
92
+ return self.decoder.decode(data.decode("utf8"))
93
+
94
+ def _encode_obj(self, o: Any) -> dict[str, Any]:
95
+ try:
96
+ transcoding = self.types[type(o)]
97
+ except KeyError:
98
+ msg = (
99
+ f"Object of type {type(o)} is not "
100
+ "serializable. Please define and register "
101
+ "a custom transcoding for this type."
102
+ )
103
+ raise TranscodingNotRegisteredError(msg) from None
104
+ else:
105
+ return {
106
+ "_type_": transcoding.name,
107
+ "_data_": transcoding.encode(o),
108
+ }
109
+
110
+ def _decode_obj(self, d: dict[str, Any]) -> Any:
111
+ if len(d) == 2:
112
+ try:
113
+ _type_ = d["_type_"]
114
+ except KeyError:
115
+ return d
116
+ else:
117
+ try:
118
+ _data_ = d["_data_"]
119
+ except KeyError:
120
+ return d
121
+ else:
122
+ try:
123
+ transcoding = self.names[cast("str", _type_)]
124
+ except KeyError as e:
125
+ msg = (
126
+ f"Data serialized with name '{cast('str', _type_)}' is not "
127
+ "deserializable. Please register a "
128
+ "custom transcoding for this type."
129
+ )
130
+ raise TranscodingNotRegisteredError(msg) from e
131
+ else:
132
+ return transcoding.decode(_data_)
133
+ else:
134
+ return d
135
+
136
+
137
+ class UUIDAsHex(Transcoding):
138
+ """Transcoding that represents :class:`UUID` objects as hex values."""
139
+
140
+ type = UUID
141
+ name = "uuid_hex"
142
+
143
+ def encode(self, obj: UUID) -> str:
144
+ return obj.hex
145
+
146
+ def decode(self, data: str) -> UUID:
147
+ assert isinstance(data, str)
148
+ return UUID(data)
149
+
150
+
151
+ class DecimalAsStr(Transcoding):
152
+ """Transcoding that represents :class:`Decimal` objects as strings."""
153
+
154
+ type = Decimal
155
+ name = "decimal_str"
156
+
157
+ def encode(self, obj: Decimal) -> str:
158
+ return str(obj)
159
+
160
+ def decode(self, data: str) -> Decimal:
161
+ return Decimal(data)
162
+
163
+
164
+ class DatetimeAsISO(Transcoding):
165
+ """Transcoding that represents :class:`datetime` objects as ISO strings."""
166
+
167
+ type = datetime
168
+ name = "datetime_iso"
169
+
170
+ def encode(self, obj: datetime) -> str:
171
+ return obj.isoformat()
172
+
173
+ def decode(self, data: str) -> datetime:
174
+ assert isinstance(data, str)
175
+ return datetime.fromisoformat(data)
176
+
177
+
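+ # --- Illustrative usage sketch (editorial addition, not part of this module).
+ # It shows how JSONTranscoder and the transcodings defined above combine to
+ # round-trip values that plain JSON cannot represent.
+ example_transcoder = JSONTranscoder()
+ example_transcoder.register(UUIDAsHex())
+ example_transcoder.register(DecimalAsStr())
+ example_transcoder.register(DatetimeAsISO())
+
+ example_id = UUID("12345678123456781234567812345678")
+ encoded = example_transcoder.encode({"id": example_id, "price": Decimal("9.99")})
+ assert isinstance(encoded, bytes)
+ assert example_transcoder.decode(encoded) == {"id": example_id, "price": Decimal("9.99")}
+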
178
+ @dataclass(frozen=True)
179
+ class StoredEvent:
180
+ """Frozen dataclass that represents :class:`~eventsourcing.domain.DomainEvent`
181
+ objects, such as aggregate :class:`~eventsourcing.domain.Aggregate.Event`
182
+ objects and :class:`~eventsourcing.domain.Snapshot` objects.
183
+ """
184
+
185
+ originator_id: UUID | str
186
+ """ID of the originating aggregate."""
187
+ originator_version: int
188
+ """Position in an aggregate sequence."""
189
+ topic: str
190
+ """Topic of a domain event object class."""
191
+ state: bytes
192
+ """Serialised state of a domain event object."""
193
+
194
+
195
+ class Compressor(ABC):
196
+ """Base class for compressors."""
197
+
198
+ @abstractmethod
199
+ def compress(self, data: bytes) -> bytes:
200
+ """Compress bytes."""
201
+
202
+ @abstractmethod
203
+ def decompress(self, data: bytes) -> bytes:
204
+ """Decompress bytes."""
205
+
206
+
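+ # --- Illustrative sketch (editorial addition, not part of this module): a
+ # minimal Compressor implementation based on the standard-library zlib module.
+ import zlib
+
+
+ class ZlibCompressor(Compressor):
+     def compress(self, data: bytes) -> bytes:
+         """Compresses bytes with zlib."""
+         return zlib.compress(data)
+
+     def decompress(self, data: bytes) -> bytes:
+         """Decompresses zlib-compressed bytes."""
+         return zlib.decompress(data)
+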
207
+ class Cipher(ABC):
208
+ """Base class for ciphers."""
209
+
210
+ @abstractmethod
211
+ def __init__(self, environment: Environment):
212
+ """Initialises cipher with given environment."""
213
+
214
+ @abstractmethod
215
+ def encrypt(self, plaintext: bytes) -> bytes:
216
+ """Return ciphertext for given plaintext."""
217
+
218
+ @abstractmethod
219
+ def decrypt(self, ciphertext: bytes) -> bytes:
220
+ """Return plaintext for given ciphertext."""
221
+
222
+
223
+ class MapperDeserialisationError(EventSourcingError, ValueError):
224
+ """Raised when deserialization fails in a Mapper."""
225
+
226
+
227
+ TAggregateIDType = TypeVar("TAggregateIDType", type[UUID], type[str])
228
+
229
+
230
+ class Mapper(Generic[TAggregateID]):
231
+ """Converts between domain event objects and :class:`StoredEvent` objects.
232
+
233
+ Uses a :class:`Transcoder`, and optionally a cryptographic cipher and compressor.
234
+ """
235
+
236
+ def __init__(
237
+ self,
238
+ transcoder: Transcoder,
239
+ compressor: Compressor | None = None,
240
+ cipher: Cipher | None = None,
241
+ ):
242
+ self.transcoder = transcoder
243
+ self.compressor = compressor
244
+ self.cipher = cipher
245
+
246
+ def to_stored_event(
247
+ self, domain_event: DomainEventProtocol[TAggregateID]
248
+ ) -> StoredEvent:
249
+ """Converts the given domain event to a :class:`StoredEvent` object."""
250
+ topic = get_topic(domain_event.__class__)
251
+ event_state = domain_event.__dict__.copy()
252
+ originator_id = event_state.pop("originator_id")
253
+ originator_version = event_state.pop("originator_version")
254
+ class_version = getattr(type(domain_event), "class_version", 1)
255
+ if class_version > 1:
256
+ event_state["class_version"] = class_version
257
+ stored_state = self.transcoder.encode(event_state)
258
+ if self.compressor:
259
+ stored_state = self.compressor.compress(stored_state)
260
+ if self.cipher:
261
+ stored_state = self.cipher.encrypt(stored_state)
262
+ return StoredEvent(
263
+ originator_id=originator_id,
264
+ originator_version=originator_version,
265
+ topic=topic,
266
+ state=stored_state,
267
+ )
268
+
269
+ def to_domain_event(
270
+ self, stored_event: StoredEvent
271
+ ) -> DomainEventProtocol[TAggregateID]:
272
+ """Converts the given :class:`StoredEvent` to a domain event object."""
273
+ cls = resolve_topic(stored_event.topic)
274
+
275
+ stored_state = stored_event.state
276
+ try:
277
+ if self.cipher:
278
+ stored_state = self.cipher.decrypt(stored_state)
279
+ if self.compressor:
280
+ stored_state = self.compressor.decompress(stored_state)
281
+ event_state: dict[str, Any] = self.transcoder.decode(stored_state)
282
+ except Exception as e:
283
+ msg = (
284
+ f"Failed to deserialise state of stored event with "
285
+ f"topic '{stored_event.topic}', "
286
+ f"originator_id '{stored_event.originator_id}' and "
287
+ f"originator_version {stored_event.originator_version}: {e}"
288
+ )
289
+ raise MapperDeserialisationError(msg) from e
290
+
291
+ id_convertor = find_id_convertor(
292
+ cls, cast(Hashable, type(stored_event.originator_id))
293
+ )
294
+ # print("ID of convertor:", id(convertor))
295
+ event_state["originator_id"] = id_convertor(stored_event.originator_id)
296
+ event_state["originator_version"] = stored_event.originator_version
297
+ class_version = getattr(cls, "class_version", 1)
298
+ from_version = event_state.pop("class_version", 1)
299
+ while from_version < class_version:
300
+ getattr(cls, f"upcast_v{from_version}_v{from_version + 1}")(event_state)
301
+ from_version += 1
302
+
303
+ domain_event = object.__new__(cls)
304
+ domain_event.__dict__.update(event_state)
305
+ return domain_event
306
+
307
+
308
+ @lru_cache
309
+ def find_id_convertor(
310
+ domain_event_cls: type[object], originator_id_cls: type[UUID | str]
311
+ ) -> Callable[[UUID | str], UUID | str]:
312
+ # Try to find the originator_id type.
313
+ if issubclass(domain_event_cls, HasOriginatorIDVersion):
314
+ # For classes that inherit CanMutateAggregate, and don't use a different
315
+ # mapper, assume they aren't overriding __init_subclass__ in a way
316
+ # that prevents 'originator_id_type' being found from type arguments and
317
+ # set on the class.
318
+ # TODO: Write a test where a custom class does override __init_subclass__
319
+ # so that the next line will cause an AssertionError. Then fix this code.
320
+ if domain_event_cls.originator_id_type is None:
321
+ msg = "originator_id_type cannot be None"
322
+ raise TypeError(msg)
323
+ originator_id_type = domain_event_cls.originator_id_type
324
+ else:
325
+ # Otherwise look for annotations.
326
+ for cls in domain_event_cls.__mro__:
327
+ try:
328
+ annotation = cls.__annotations__["originator_id"]
329
+ except (KeyError, AttributeError): # noqa: PERF203
330
+ continue
331
+ else:
332
+ valid_annotations = {
333
+ str: str,
334
+ UUID: UUID,
335
+ "str": str,
336
+ "UUID": UUID,
337
+ "uuid.UUID": UUID,
338
+ }
339
+ if annotation not in valid_annotations:
340
+ msg = f"originator_id annotation on {cls} is neither UUID nor str"
341
+ raise TypeError(msg)
342
+ assert annotation in valid_annotations, annotation
343
+ originator_id_type = valid_annotations[annotation]
344
+ break
345
+ else:
346
+ msg = (
347
+ f"Neither event class {domain_event_cls} "
348
+ "nor its bases have an originator_id annotation"
349
+ )
350
+ raise TypeError(msg)
351
+
352
+ if originator_id_cls is str and originator_id_type is UUID:
353
+ convertor = str_to_uuid_convertor
354
+ else:
355
+ convertor = pass_through_convertor
356
+ return convertor
357
+
358
+
359
+ def str_to_uuid_convertor(originator_id: UUID | str) -> UUID | str:
360
+ assert isinstance(originator_id, str)
361
+ return UUID(originator_id)
362
+
363
+
364
+ def pass_through_convertor(originator_id: UUID | str) -> UUID | str:
365
+ return originator_id
366
+
367
+
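+ # --- Illustrative sketch (editorial addition, not part of this module): a
+ # Mapper round trip using a hypothetical event-like frozen dataclass. Its
+ # 'originator_id: UUID' annotation is what find_id_convertor() inspects above.
+ from uuid import uuid4
+
+
+ @dataclass(frozen=True)
+ class SomethingHappened:  # Hypothetical example event class.
+     originator_id: UUID
+     originator_version: int
+     what: str
+
+
+ mapper_transcoder = JSONTranscoder()
+ mapper_transcoder.register(UUIDAsHex())
+ example_mapper = Mapper(transcoder=mapper_transcoder)
+
+ example_event = SomethingHappened(uuid4(), 1, "example")
+ stored = example_mapper.to_stored_event(example_event)
+ assert isinstance(stored, StoredEvent) and isinstance(stored.state, bytes)
+ assert example_mapper.to_domain_event(stored) == example_event
+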
368
+ class RecordConflictError(EventSourcingError):
369
+ """Legacy exception, replaced with IntegrityError."""
370
+
371
+
372
+ class PersistenceError(EventSourcingError):
373
+ """The base class of the other exceptions in this module.
374
+
375
+ Exception class names follow https://www.python.org/dev/peps/pep-0249/#exceptions
376
+ """
377
+
378
+
379
+ class InterfaceError(PersistenceError):
380
+ """Exception raised for errors that are related to the database
381
+ interface rather than the database itself.
382
+ """
383
+
384
+
385
+ class DatabaseError(PersistenceError):
386
+ """Exception raised for errors that are related to the database."""
387
+
388
+
389
+ class DataError(DatabaseError):
390
+ """Exception raised for errors that are due to problems with the
391
+ processed data like division by zero, numeric value out of range, etc.
392
+ """
393
+
394
+
395
+ class OperationalError(DatabaseError):
396
+ """Exception raised for errors that are related to the database's
397
+ operation and not necessarily under the control of the programmer,
398
+ e.g. an unexpected disconnect occurs, the data source name is not
399
+ found, a transaction could not be processed, a memory allocation
400
+ error occurred during processing, etc.
401
+ """
402
+
403
+
404
+ class IntegrityError(DatabaseError, RecordConflictError):
405
+ """Exception raised when the relational integrity of the
406
+ database is affected, e.g. a foreign key check fails.
407
+ """
408
+
409
+
410
+ class InternalError(DatabaseError):
411
+ """Exception raised when the database encounters an internal
412
+ error, e.g. the cursor is not valid anymore, the transaction
413
+ is out of sync, etc.
414
+ """
415
+
416
+
417
+ class ProgrammingError(DatabaseError):
418
+ """Exception raised for database programming errors, e.g. table
419
+ not found or already exists, syntax error in the SQL statement,
420
+ wrong number of parameters specified, etc.
421
+ """
422
+
423
+
424
+ class NotSupportedError(DatabaseError):
425
+ """Exception raised in case a method or database API was used
426
+ which is not supported by the database, e.g. calling the
427
+ rollback() method on a connection that does not support
428
+ transactions or has transactions turned off.
429
+ """
430
+
431
+
432
+ class WaitInterruptedError(PersistenceError):
433
+ """Raised when waiting for a tracking record is interrupted."""
434
+
435
+
436
+ class Recorder:
437
+ pass
438
+
439
+
440
+ class AggregateRecorder(Recorder, ABC):
441
+ """Abstract base class for inserting and selecting stored events."""
442
+
443
+ @abstractmethod
444
+ def insert_events(
445
+ self, stored_events: Sequence[StoredEvent], **kwargs: Any
446
+ ) -> Sequence[int] | None:
447
+ """Writes stored events into database."""
448
+
449
+ @abstractmethod
450
+ def select_events(
451
+ self,
452
+ originator_id: UUID | str,
453
+ *,
454
+ gt: int | None = None,
455
+ lte: int | None = None,
456
+ desc: bool = False,
457
+ limit: int | None = None,
458
+ ) -> Sequence[StoredEvent]:
459
+ """Reads stored events from database."""
460
+
461
+
462
+ @dataclass(frozen=True)
463
+ class Notification(StoredEvent):
464
+ """Frozen dataclass that represents domain event notifications."""
465
+
466
+ id: int
467
+ """Position in an application sequence."""
468
+
469
+
470
+ class ApplicationRecorder(AggregateRecorder):
471
+ """Abstract base class for recording events in both aggregate
472
+ and application sequences.
473
+ """
474
+
475
+ @abstractmethod
476
+ def select_notifications(
477
+ self,
478
+ start: int | None,
479
+ limit: int,
480
+ stop: int | None = None,
481
+ topics: Sequence[str] = (),
482
+ *,
483
+ inclusive_of_start: bool = True,
484
+ ) -> Sequence[Notification]:
485
+ """Returns a list of Notification objects representing events from an
486
+ application sequence. If `inclusive_of_start` is True (the default),
487
+ the returned Notification objects will have IDs greater than or equal
488
+ to `start` and less than or equal to `stop`. If `inclusive_of_start`
489
+ is False, the Notification objects will have IDs greater than `start`
490
+ and less than or equal to `stop`.
491
+ """
492
+
493
+ @abstractmethod
494
+ def max_notification_id(self) -> int | None:
495
+ """Returns the largest notification ID in an application sequence,
496
+ or None if no stored events have been recorded.
497
+ """
498
+
499
+ @abstractmethod
500
+ def subscribe(
501
+ self, gt: int | None = None, topics: Sequence[str] = ()
502
+ ) -> Subscription[ApplicationRecorder]:
503
+ """Returns an iterator of Notification objects representing events from an
504
+ application sequence.
505
+
506
+ The iterator will block after the last recorded event has been yielded, but
507
+ will then continue yielding newly recorded events when they are recorded.
508
+
509
+ Notifications will have IDs greater than the optional `gt` argument.
510
+ """
511
+
512
+
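+ # --- Illustrative usage sketch (editorial addition, not part of this module):
+ # following an application sequence with subscribe(). A concrete recorder is
+ # assumed to come from a persistence module such as 'eventsourcing.popo'.
+ def follow(recorder: ApplicationRecorder) -> None:
+     """Prints notifications as they are recorded, blocking for new events."""
+     with recorder.subscribe(gt=None) as subscription:
+         for notification in subscription:
+             print(notification.id, notification.topic)
+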
513
+ class TrackingRecorder(Recorder, ABC):
514
+ """Abstract base class for recorders that record tracking
515
+ objects atomically with other state.
516
+ """
517
+
518
+ @abstractmethod
519
+ def insert_tracking(self, tracking: Tracking) -> None:
520
+ """Records a tracking object."""
521
+
522
+ @abstractmethod
523
+ def max_tracking_id(self, application_name: str) -> int | None:
524
+ """Returns the largest notification ID across all recorded tracking objects
525
+ for the named application, or None if no tracking objects have been recorded.
526
+ """
527
+
528
+ def has_tracking_id(
529
+ self, application_name: str, notification_id: int | None
530
+ ) -> bool:
531
+ """Returns True if the given notification_id is None, or if a tracking
532
+ object with the given application_name and a notification ID greater
533
+ than or equal to the given notification_id has been recorded.
534
+ """
535
+ if notification_id is None:
536
+ return True
537
+ max_tracking_id = self.max_tracking_id(application_name)
538
+ return max_tracking_id is not None and max_tracking_id >= notification_id
539
+
540
+ def wait(
541
+ self,
542
+ application_name: str,
543
+ notification_id: int | None,
544
+ timeout: float = 1.0,
545
+ interrupt: Event | None = None,
546
+ ) -> None:
547
+ """Block until a tracking object with the given application name and a
548
+ notification ID greater than or equal to the given value has been recorded.
549
+
550
+ Polls max_tracking_id() with exponential backoff until the timeout
551
+ is reached, or until the optional interrupt event is set.
552
+
553
+ The timeout argument should be a floating point number specifying a
554
+ timeout for the operation in seconds (or fractions thereof). The default
555
+ is 1.0 seconds.
556
+
557
+ Raises TimeoutError if the timeout is reached.
558
+
559
+ Raises WaitInterruptedError if `interrupt` is set before `timeout` is reached.
560
+ """
561
+ deadline = monotonic() + timeout
562
+ sleep_interval_ms = 100.0
563
+ max_sleep_interval_ms = 800.0
564
+ while True:
565
+ if self.has_tracking_id(application_name, notification_id):
566
+ break
567
+ if interrupt:
568
+ if interrupt.wait(timeout=sleep_interval_ms / 1000):
569
+ raise WaitInterruptedError
570
+ else:
571
+ sleep(sleep_interval_ms / 1000)
572
+ remaining = deadline - monotonic()
573
+ if remaining < 0:
574
+ msg = (
575
+ f"Timed out waiting for notification {notification_id} "
576
+ f"from application '{application_name}' to be processed"
577
+ )
578
+ raise TimeoutError(msg)
579
+ sleep_interval_ms = min(
580
+ sleep_interval_ms * 2, remaining * 1000, max_sleep_interval_ms
581
+ )
582
+
583
+
584
+ class ProcessRecorder(TrackingRecorder, ApplicationRecorder, ABC):
585
+ pass
586
+
587
+
588
+ @dataclass(frozen=True)
589
+ class Recording(Generic[TAggregateID]):
590
+ """Represents the recording of a domain event."""
591
+
592
+ domain_event: DomainEventProtocol[TAggregateID]
593
+ """The domain event that has been recorded."""
594
+ notification: Notification
595
+ """A Notification that represents the domain event in the application sequence."""
596
+
597
+
598
+ class EventStore(Generic[TAggregateID]):
599
+ """Stores and retrieves domain events."""
600
+
601
+ def __init__(
602
+ self,
603
+ mapper: Mapper[TAggregateID],
604
+ recorder: AggregateRecorder,
605
+ ):
606
+ self.mapper: Mapper[TAggregateID] = mapper
607
+ self.recorder = recorder
608
+
609
+ def put(
610
+ self, domain_events: Sequence[DomainEventProtocol[TAggregateID]], **kwargs: Any
611
+ ) -> list[Recording[TAggregateID]]:
612
+ """Stores domain events in aggregate sequence."""
613
+ stored_events = list(map(self.mapper.to_stored_event, domain_events))
614
+ recordings = []
615
+ notification_ids = self.recorder.insert_events(stored_events, **kwargs)
616
+ if notification_ids:
617
+ assert len(notification_ids) == len(stored_events)
618
+ for d, s, n_id in zip(
619
+ domain_events, stored_events, notification_ids, strict=True
620
+ ):
621
+ recordings.append(
622
+ Recording(
623
+ d,
624
+ Notification(
625
+ originator_id=s.originator_id,
626
+ originator_version=s.originator_version,
627
+ topic=s.topic,
628
+ state=s.state,
629
+ id=n_id,
630
+ ),
631
+ )
632
+ )
633
+ return recordings
634
+
635
+ def get(
636
+ self,
637
+ originator_id: TAggregateID,
638
+ *,
639
+ gt: int | None = None,
640
+ lte: int | None = None,
641
+ desc: bool = False,
642
+ limit: int | None = None,
643
+ ) -> Iterator[DomainEventProtocol[TAggregateID]]:
644
+ """Retrieves domain events from aggregate sequence."""
645
+ return map(
646
+ self.mapper.to_domain_event,
647
+ self.recorder.select_events(
648
+ originator_id=originator_id,
649
+ gt=gt,
650
+ lte=lte,
651
+ desc=desc,
652
+ limit=limit,
653
+ ),
654
+ )
655
+
656
+
657
+ TTrackingRecorder = TypeVar(
658
+ "TTrackingRecorder", bound=TrackingRecorder, default=TrackingRecorder
659
+ )
660
+
661
+
662
+ class InfrastructureFactoryError(EventSourcingError):
663
+ """Raised when an infrastructure factory cannot be created."""
664
+
665
+
666
+ class BaseInfrastructureFactory(ABC, Generic[TTrackingRecorder]):
667
+ """Abstract base class for infrastructure factories."""
668
+
669
+ PERSISTENCE_MODULE = "PERSISTENCE_MODULE"
670
+ TRANSCODER_TOPIC = "TRANSCODER_TOPIC"
671
+ CIPHER_TOPIC = "CIPHER_TOPIC"
672
+ COMPRESSOR_TOPIC = "COMPRESSOR_TOPIC"
673
+
674
+ def __init__(self, env: Environment | EnvType | None):
675
+ """Initialises infrastructure factory object with the given environment."""
676
+ self.env = env if isinstance(env, Environment) else Environment(env=env)
677
+ self._is_entered = False
678
+
679
+ def __enter__(self) -> Self:
680
+ self._is_entered = True
681
+ return self
682
+
683
+ def __exit__(
684
+ self,
685
+ exc_type: type[BaseException] | None,
686
+ exc_val: BaseException | None,
687
+ exc_tb: TracebackType | None,
688
+ ) -> None:
689
+ self._is_entered = False
690
+
691
+ def close(self) -> None:
692
+ """Closes any database connections, and anything else that needs closing."""
693
+
694
+ @classmethod
695
+ def construct(
696
+ cls: type[Self],
697
+ env: Environment | None = None,
698
+ ) -> Self:
699
+ """Constructs a concrete infrastructure factory for the given
700
+ environment. Reads and resolves the persistence module
701
+ topic from environment variable 'PERSISTENCE_MODULE'.
702
+ """
703
+ factory_cls: type[Self]
704
+ if env is None:
705
+ env = Environment()
706
+ topic = (
707
+ env.get(
708
+ cls.PERSISTENCE_MODULE,
709
+ "",
710
+ )
711
+ or env.get(
712
+ "INFRASTRUCTURE_FACTORY", # Legacy.
713
+ "",
714
+ )
715
+ or env.get(
716
+ "FACTORY_TOPIC", # Legacy.
717
+ "",
718
+ )
719
+ or "eventsourcing.popo"
720
+ )
721
+ try:
722
+ obj: type[Self] | ModuleType = resolve_topic(topic)
723
+ except TopicError as e:
724
+ msg = (
725
+ "Failed to resolve persistence module topic: "
726
+ f"'{topic}' from environment "
727
+ f"variable '{cls.PERSISTENCE_MODULE}'"
728
+ )
729
+ raise InfrastructureFactoryError(msg) from e
730
+
731
+ if isinstance(obj, ModuleType):
732
+ # Find the factory in the module.
733
+ factory_classes = set[type[Self]]()
734
+ for member in obj.__dict__.values():
735
+ # Look for classes...
736
+ if not isinstance(member, type):
737
+ continue
738
+ # Issue with Python 3.9 and 3.10.
739
+ if isinstance(member, GenericAlias):
740
+ continue # pragma: no cover (for Python > 3.10 only)
741
+ if not issubclass(member, cls):
742
+ continue
743
+ if getattr(member, "__parameters__", None):
744
+ continue
745
+ factory_classes.add(member)
746
+
747
+ if len(factory_classes) == 1:
748
+ factory_cls = next(iter(factory_classes))
749
+ else:
750
+ msg = (
751
+ f"Found {len(factory_classes)} infrastructure factory classes in"
752
+ f" '{topic}', expected 1."
753
+ )
754
+ raise InfrastructureFactoryError(msg)
755
+ elif isinstance(obj, type) and issubclass(obj, cls):
756
+ factory_cls = obj
757
+ else:
758
+ msg = (
759
+ f"Topic '{topic}' didn't resolve to a persistence module "
760
+ f"or infrastructure factory class: {obj}"
761
+ )
762
+ raise InfrastructureFactoryError(msg)
763
+ return factory_cls(env=env)
764
+
765
+ def transcoder(
766
+ self,
767
+ ) -> Transcoder:
768
+ """Constructs a transcoder."""
769
+ transcoder_topic = self.env.get(self.TRANSCODER_TOPIC)
770
+ if transcoder_topic:
771
+ transcoder_class: type[Transcoder] = resolve_topic(transcoder_topic)
772
+ else:
773
+ transcoder_class = JSONTranscoder
774
+ return transcoder_class()
775
+
776
+ def cipher(self) -> Cipher | None:
777
+ """Reads environment variables 'CIPHER_TOPIC'
778
+ and 'CIPHER_KEY' to decide whether or not
779
+ to construct a cipher.
780
+ """
781
+ cipher_topic = self.env.get(self.CIPHER_TOPIC)
782
+ cipher: Cipher | None = None
783
+ default_cipher_topic = "eventsourcing.cipher:AESCipher"
784
+ if self.env.get("CIPHER_KEY") and not cipher_topic:
785
+ cipher_topic = default_cipher_topic
786
+
787
+ if cipher_topic:
788
+ cipher_cls: type[Cipher] = resolve_topic(cipher_topic)
789
+ cipher = cipher_cls(self.env)
790
+
791
+ return cipher
792
+
793
+ def compressor(self) -> Compressor | None:
794
+ """Reads environment variable 'COMPRESSOR_TOPIC' to
795
+ decide whether or not to construct a compressor.
796
+ """
797
+ compressor: Compressor | None = None
798
+ compressor_topic = self.env.get(self.COMPRESSOR_TOPIC)
799
+ if compressor_topic:
800
+ compressor_cls: type[Compressor] | Compressor = resolve_topic(
801
+ compressor_topic
802
+ )
803
+ if isinstance(compressor_cls, type):
804
+ compressor = compressor_cls()
805
+ else:
806
+ compressor = compressor_cls
807
+ return compressor
808
+
809
+
810
+ class InfrastructureFactory(BaseInfrastructureFactory[TTrackingRecorder]):
811
+ """Abstract base class for Application factories."""
812
+
813
+ MAPPER_TOPIC = "MAPPER_TOPIC"
814
+ IS_SNAPSHOTTING_ENABLED = "IS_SNAPSHOTTING_ENABLED"
815
+ APPLICATION_RECORDER_TOPIC = "APPLICATION_RECORDER_TOPIC"
816
+ TRACKING_RECORDER_TOPIC = "TRACKING_RECORDER_TOPIC"
817
+ PROCESS_RECORDER_TOPIC = "PROCESS_RECORDER_TOPIC"
818
+
819
+ def mapper(
820
+ self,
821
+ transcoder: Transcoder | None = None,
822
+ mapper_class: type[Mapper[TAggregateID]] | None = None,
823
+ ) -> Mapper[TAggregateID]:
824
+ """Constructs a mapper."""
825
+ # Resolve MAPPER_TOPIC if no mapper class is given.
826
+ if mapper_class is None:
827
+ mapper_topic = self.env.get(self.MAPPER_TOPIC)
828
+ mapper_class = (
829
+ resolve_topic(mapper_topic) if mapper_topic else Mapper[TAggregateID]
830
+ )
831
+
832
+ # Check we have a mapper class.
833
+ assert mapper_class is not None
834
+ origin_mapper_class = typing.get_origin(mapper_class) or mapper_class
835
+ assert isinstance(origin_mapper_class, type), mapper_class
836
+ assert issubclass(origin_mapper_class, Mapper), mapper_class
837
+
838
+ # Construct and return a mapper.
839
+ return mapper_class(
840
+ transcoder=transcoder or self.transcoder(),
841
+ cipher=self.cipher(),
842
+ compressor=self.compressor(),
843
+ )
844
+
845
+ def event_store(
846
+ self,
847
+ mapper: Mapper[TAggregateID] | None = None,
848
+ recorder: AggregateRecorder | None = None,
849
+ ) -> EventStore[TAggregateID]:
850
+ """Constructs an event store."""
851
+ return EventStore(
852
+ mapper=mapper or self.mapper(),
853
+ recorder=recorder or self.application_recorder(),
854
+ )
855
+
856
+ @abstractmethod
857
+ def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
858
+ """Constructs an aggregate recorder."""
859
+
860
+ @abstractmethod
861
+ def application_recorder(self) -> ApplicationRecorder:
862
+ """Constructs an application recorder."""
863
+
864
+ @abstractmethod
865
+ def tracking_recorder(
866
+ self, tracking_recorder_class: type[TTrackingRecorder] | None = None
867
+ ) -> TTrackingRecorder:
868
+ """Constructs a tracking recorder."""
869
+
870
+ @abstractmethod
871
+ def process_recorder(self) -> ProcessRecorder:
872
+ """Constructs a process recorder."""
873
+
874
+ def is_snapshotting_enabled(self) -> bool:
875
+ """Decides whether or not snapshotting is enabled by
876
+ reading environment variable 'IS_SNAPSHOTTING_ENABLED'.
877
+ Snapshotting is not enabled by default.
878
+ """
879
+ return strtobool(self.env.get(self.IS_SNAPSHOTTING_ENABLED, "no"))
880
+
881
+
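+ # --- Illustrative usage sketch (editorial addition, not part of this module):
+ # constructing an infrastructure factory from the environment. When no
+ # PERSISTENCE_MODULE is set, construct() falls back to 'eventsourcing.popo'.
+ factory = InfrastructureFactory.construct()
+ recorder = factory.application_recorder()
+ event_store = factory.event_store(recorder=recorder)
+ factory.close()
+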
882
+ @dataclass(frozen=True)
883
+ class Tracking:
884
+ """Frozen dataclass representing the position of a domain
885
+ event :class:`Notification` in an application's notification log.
886
+ """
887
+
888
+ application_name: str
889
+ notification_id: int
890
+
891
+
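+ # --- Illustrative sketch (editorial addition, not part of this module): a
+ # minimal in-memory TrackingRecorder showing how has_tracking_id() and wait(),
+ # defined above, behave once a Tracking object has been recorded.
+ class InMemoryTrackingRecorder(TrackingRecorder):
+     def __init__(self) -> None:
+         self._max_ids: dict[str, int] = {}
+
+     def insert_tracking(self, tracking: Tracking) -> None:
+         # Keep the highest notification ID seen for each upstream application.
+         current = self._max_ids.get(tracking.application_name, 0)
+         self._max_ids[tracking.application_name] = max(
+             current, tracking.notification_id
+         )
+
+     def max_tracking_id(self, application_name: str) -> int | None:
+         return self._max_ids.get(application_name)
+
+
+ tracking_recorder = InMemoryTrackingRecorder()
+ tracking_recorder.insert_tracking(Tracking("upstream_app", 3))
+ assert tracking_recorder.has_tracking_id("upstream_app", 3)
+ tracking_recorder.wait("upstream_app", 3, timeout=0.5)  # Returns immediately.
+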
892
+ Params = Sequence[Any] | Mapping[str, Any]
893
+
894
+
895
+ class Cursor(ABC):
896
+ @abstractmethod
897
+ def execute(self, statement: str | bytes, params: Params | None = None) -> None:
898
+ """Executes given statement."""
899
+
900
+ @abstractmethod
901
+ def fetchall(self) -> Any:
902
+ """Fetches all results."""
903
+
904
+ @abstractmethod
905
+ def fetchone(self) -> Any:
906
+ """Fetches one result."""
907
+
908
+
909
+ TCursor = TypeVar("TCursor", bound=Cursor)
910
+
911
+
912
+ class Connection(ABC, Generic[TCursor]):
913
+ def __init__(self, max_age: float | None = None) -> None:
914
+ self._closed = False
915
+ self._closing = Event()
916
+ self._close_lock = Lock()
917
+ self.in_use = Lock()
918
+ self.in_use.acquire()
919
+ if max_age is not None:
920
+ self._max_age_timer: Timer | None = Timer(
921
+ interval=max_age,
922
+ function=self._close_when_not_in_use,
923
+ )
924
+ self._max_age_timer.daemon = True
925
+ self._max_age_timer.start()
926
+ else:
927
+ self._max_age_timer = None
928
+ self.is_writer: bool | None = None
929
+
930
+ @property
931
+ def closed(self) -> bool:
932
+ return self._closed
933
+
934
+ @property
935
+ def closing(self) -> bool:
936
+ return self._closing.is_set()
937
+
938
+ @abstractmethod
939
+ def commit(self) -> None:
940
+ """Commits transaction."""
941
+
942
+ @abstractmethod
943
+ def rollback(self) -> None:
944
+ """Rolls back transaction."""
945
+
946
+ @abstractmethod
947
+ def cursor(self) -> TCursor:
948
+ """Creates new cursor."""
949
+
950
+ def close(self) -> None:
951
+ with self._close_lock:
952
+ self._close()
953
+
954
+ @abstractmethod
955
+ def _close(self) -> None:
956
+ self._closed = True
957
+ if self._max_age_timer:
958
+ self._max_age_timer.cancel()
959
+
960
+ def _close_when_not_in_use(self) -> None:
961
+ self._closing.set()
962
+ with self.in_use:
963
+ if not self._closed:
964
+ self.close()
965
+
966
+
967
+ TConnection = TypeVar("TConnection", bound=Connection[Any])
968
+
969
+
970
+ class ConnectionPoolClosedError(EventSourcingError):
971
+ """Raised when using a connection pool that is already closed."""
972
+
973
+
974
+ class ConnectionNotFromPoolError(EventSourcingError):
975
+ """Raised when putting a connection in the wrong pool."""
976
+
977
+
978
+ class ConnectionUnavailableError(OperationalError, TimeoutError):
979
+ """Raised when a request to get a connection from a
980
+ connection pool times out.
981
+ """
982
+
983
+
984
+ class ConnectionPool(ABC, Generic[TConnection]):
985
+ def __init__(
986
+ self,
987
+ *,
988
+ pool_size: int = 5,
989
+ max_overflow: int = 10,
990
+ pool_timeout: float = 30.0,
991
+ max_age: float | None = None,
992
+ pre_ping: bool = False,
993
+ mutually_exclusive_read_write: bool = False,
994
+ ) -> None:
995
+ """Initialises a new connection pool.
996
+
997
+ The 'pool_size' argument specifies the maximum number of connections
998
+ that will be put into the pool when connections are returned. The
999
+ default value is 5.
1000
+
1001
+ The 'max_overflow' argument specifies the additional number of
1002
+ connections that can be issued by the pool, above the 'pool_size'.
1003
+ The default value is 10.
1004
+
1005
+ The 'pool_timeout' argument specifies the maximum time in seconds
1006
+ to keep requests for connections waiting. Connections are kept
1007
+ waiting if the number of connections currently in use is not less
1008
+ than the sum of 'pool_size' and 'max_overflow'. The default value
1009
+ is 30.0 seconds.
1010
+
1011
+ The 'max_age' argument specifies the time in seconds until a
1012
+ connection will automatically be closed. Connections are only closed
1013
+ in this way when they are not in use. Connections that are in use will
1014
+ not be closed automatically. The default value is None, meaning
1015
+ connections will not be automatically closed in this way.
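+
+ The 'pre_ping' argument specifies whether a connection will be checked, by
+ executing a simple SQL statement, before it is issued from the pool. Unusable
+ connections are closed and replaced. The default value is False.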
1016
+
1017
+ The 'mutually_exclusive_read_write' argument specifies whether
1018
+ requests for connections for writing will be kept waiting whilst connections for reading
1019
+ are in use. It also specifies whether requests for connections for reading
1020
+ will be kept waiting whilst a connection for writing is in use. The default
1021
+ value is false, meaning reading and writing will not be mutually exclusive
1022
+ in this way.
1023
+ """
1024
+ self.pool_size = pool_size
1025
+ self.max_overflow = max_overflow
1026
+ self.pool_timeout = pool_timeout
1027
+ self.max_age = max_age
1028
+ self.pre_ping = pre_ping
1029
+ self._pool: deque[TConnection] = deque()
1030
+ self._in_use: dict[int, TConnection] = {}
1031
+ self._get_semaphore = Semaphore()
1032
+ self._put_condition = Condition()
1033
+ self._no_readers = Condition()
1034
+ self._num_readers: int = 0
1035
+ self._writer_lock = Lock()
1036
+ self._num_writers: int = 0
1037
+ self._mutually_exclusive_read_write = mutually_exclusive_read_write
1038
+ self._closed = False
1039
+
1040
+ @property
1041
+ def closed(self) -> bool:
1042
+ return self._closed
1043
+
1044
+ @property
1045
+ def num_in_use(self) -> int:
1046
+ """Indicates the total number of connections currently in use."""
1047
+ with self._put_condition:
1048
+ return self._num_in_use
1049
+
1050
+ @property
1051
+ def _num_in_use(self) -> int:
1052
+ return len(self._in_use)
1053
+
1054
+ @property
1055
+ def num_in_pool(self) -> int:
1056
+ """Indicates the number of connections currently in the pool."""
1057
+ with self._put_condition:
1058
+ return self._num_in_pool
1059
+
1060
+ @property
1061
+ def _num_in_pool(self) -> int:
1062
+ return len(self._pool)
1063
+
1064
+ @property
1065
+ def _is_pool_full(self) -> bool:
1066
+ return self._num_in_pool >= self.pool_size
1067
+
1068
+ @property
1069
+ def _is_use_full(self) -> bool:
1070
+ return self._num_in_use >= self.pool_size + self.max_overflow
1071
+
1072
+ def get_connection(
1073
+ self, timeout: float | None = None, *, is_writer: bool | None = None
1074
+ ) -> TConnection:
1075
+ """Issues connections, or raises ConnectionUnavailableError.
1076
+ Provides "fairness" on attempts to get connections, meaning that
1077
+ connections are issued in the same order as they are requested.
1078
+
1079
+ The 'timeout' argument overrides the timeout specified
1080
+ by the constructor argument 'pool_timeout'. The default
1081
+ value is None, meaning the 'pool_timeout' argument will
1082
+ not be overridden.
1083
+
1084
+ The optional 'is_writer' argument can be used to request
1085
+ a connection for writing (true), or to request a connection
1086
+ for reading (false). If the value of this argument is None,
1087
+ which is the default, the writing and reading interlocking
1088
+ mechanism is not activated. Only one connection for writing
1089
+ will be issued, which means requests for connections for
1090
+ writing are kept waiting whilst another connection for writing
1091
+ is in use.
1092
+
1093
+ If reading and writing are mutually exclusive, requests for
1094
+ connections for writing are kept waiting whilst connections
1095
+ for reading are in use, and requests for connections for reading
1096
+ are kept waiting whilst a connection for writing is in use.
1097
+ """
1098
+ # Make sure we aren't dealing with a closed pool.
1099
+ if self._closed:
1100
+ raise ConnectionPoolClosedError
1101
+
1102
+ # Decide the timeout for getting a connection.
1103
+ timeout = self.pool_timeout if timeout is None else timeout
1104
+
1105
+ # Remember when we started trying to get a connection.
1106
+ started = time()
1107
+
1108
+ # Join queue of threads waiting to get a connection ("fairness").
1109
+ if self._get_semaphore.acquire(timeout=timeout):
1110
+ try:
1111
+ # If connection is for writing, get write lock and wait for no readers.
1112
+ if is_writer is True:
1113
+ if not self._writer_lock.acquire(
1114
+ timeout=self._time_remaining(timeout, started)
1115
+ ):
1116
+ msg = "Timed out waiting for return of writer"
1117
+ raise ConnectionUnavailableError(msg)
1118
+ if self._mutually_exclusive_read_write:
1119
+ with self._no_readers:
1120
+ if self._num_readers > 0 and not self._no_readers.wait(
1121
+ timeout=self._time_remaining(timeout, started)
1122
+ ):
1123
+ self._writer_lock.release()
1124
+ msg = "Timed out waiting for return of reader"
1125
+ raise ConnectionUnavailableError(msg)
1126
+ self._num_writers += 1
1127
+
1128
+ # If connection is for reading, and writing excludes reading,
1129
+ # then wait for the writer lock, and increment number of readers.
1130
+ elif is_writer is False:
1131
+ if self._mutually_exclusive_read_write:
1132
+ if not self._writer_lock.acquire(
1133
+ timeout=self._time_remaining(timeout, started)
1134
+ ):
1135
+ msg = "Timed out waiting for return of writer"
1136
+ raise ConnectionUnavailableError(msg)
1137
+ self._writer_lock.release()
1138
+ with self._no_readers:
1139
+ self._num_readers += 1
1140
+
1141
+ # Actually try to get a connection within the time remaining.
1142
+ conn = self._get_connection(
1143
+ timeout=self._time_remaining(timeout, started)
1144
+ )
1145
+
1146
+ # Remember if this connection is for reading or writing.
1147
+ conn.is_writer = is_writer
1148
+
1149
+ # Return the connection.
1150
+ return conn
1151
+ finally:
1152
+ self._get_semaphore.release()
1153
+ else:
1154
+ # Timed out waiting for semaphore.
1155
+ msg = "Timed out waiting for connection pool semaphore"
1156
+ raise ConnectionUnavailableError(msg)
1157
+
1158
+ def _get_connection(self, timeout: float = 0.0) -> TConnection:
1159
+ """Gets or creates connections from pool within given
1160
+ time, otherwise raises a "pool exhausted" error.
1161
+
1162
+ Waits for connections to be returned if the pool
1163
+ is fully used, and optionally ensures a connection
1164
+ is usable before returning a connection for use.
1165
+
1166
+ Tracks use of connections, and number of readers.
1167
+ """
1168
+ started = time()
1169
+ # Get lock on tracking usage of connections.
1170
+ with self._put_condition:
1171
+ # Try to get a connection from the pool.
1172
+ try:
1173
+ conn = self._pool.popleft()
1174
+ except IndexError:
1175
+ # Pool is empty, but are connections fully used?
1176
+ if self._is_use_full:
1177
+ # Fully used, so wait for a connection to be returned.
1178
+ if self._put_condition.wait(
1179
+ timeout=self._time_remaining(timeout, started)
1180
+ ):
1181
+ # Connection has been returned, so try again.
1182
+ return self._get_connection(
1183
+ timeout=self._time_remaining(timeout, started)
1184
+ )
1185
+ # Timed out waiting for a connection to be returned.
1186
+ msg = "Timed out waiting for return of connection"
1187
+ raise ConnectionUnavailableError(msg) from None
1188
+ # Not fully used, so create a new connection.
1189
+ conn = self._create_connection()
1190
+ # print("created another connection")
1191
+
1192
+ # Connection should be pre-locked for use (avoids timer race).
1193
+ assert conn.in_use.locked()
1194
+
1195
+ else:
1196
+ # Got unused connection from pool, so lock for use.
1197
+ conn.in_use.acquire()
1198
+
1199
+ # Check the connection wasn't closed by the timer.
1200
+ if conn.closed:
1201
+ return self._get_connection(
1202
+ timeout=self._time_remaining(timeout, started)
1203
+ )
1204
+
1205
+ # Check the connection is actually usable.
1206
+ if self.pre_ping:
1207
+ try:
1208
+ conn.cursor().execute("SELECT 1")
1209
+ except Exception:
1210
+ # Probably connection is closed on server,
1211
+ # but just try to make sure it is closed.
1212
+ conn.close()
1213
+
1214
+ # Try again to get a connection.
1215
+ return self._get_connection(
1216
+ timeout=self._time_remaining(timeout, started)
1217
+ )
1218
+
1219
+ # Track the connection is now being used.
1220
+ self._in_use[id(conn)] = conn
1221
+
1222
+ # Return the connection.
1223
+ return conn
1224
+
1225
+ def put_connection(self, conn: TConnection) -> None:
1226
+ """Returns connections to the pool, or closes connection
1227
+ if the pool is full.
1228
+
1229
+ Unlocks write lock after writer has returned, and
1230
+ updates count of readers when readers are returned.
1231
+
1232
+ Notifies waiters when connections have been returned,
1233
+ and when there are no longer any readers.
1234
+ """
1235
+ # Start forgetting if this connection was for reading or writing.
1236
+ is_writer, conn.is_writer = conn.is_writer, None
1237
+
1238
+ # Get a lock on tracking usage of connections.
1239
+ with self._put_condition:
1240
+ # Make sure we aren't dealing with a closed pool
1241
+ if self._closed:
1242
+ msg = "Pool is closed"
1243
+ raise ConnectionPoolClosedError(msg)
1244
+
1245
+ # Make sure we are dealing with a connection from this pool.
1246
+ try:
1247
+ del self._in_use[id(conn)]
1248
+ except KeyError:
1249
+ msg = "Connection not in use in this pool"
1250
+ raise ConnectionNotFromPoolError(msg) from None
1251
+
1252
+ if not conn.closed:
1253
+ # Put open connection in pool if not full.
1254
+ if not conn.closing and not self._is_pool_full:
1255
+ self._pool.append(conn)
1256
+ # Close open connection if the pool is full or timer has fired.
1257
+ else:
1258
+ # Otherwise, close the connection.
1259
+ conn.close()
1260
+
1261
+ # Unlock the connection for subsequent use (and for closing by the timer).
1262
+ conn.in_use.release()
1263
+
1264
+ # If the connection was for writing, unlock the writer lock.
1265
+ if is_writer is True:
1266
+ self._num_writers -= 1
1267
+ self._writer_lock.release()
1268
+
1269
+ # Or if it was for reading, decrement the number of readers.
1270
+ elif is_writer is False:
1271
+ with self._no_readers:
1272
+ self._num_readers -= 1
1273
+ if self._num_readers == 0 and self._mutually_exclusive_read_write:
1274
+ self._no_readers.notify()
1275
+
1276
+ # Notify a thread that is waiting for a connection to be returned.
1277
+ self._put_condition.notify()
1278
+
1279
+ @abstractmethod
1280
+ def _create_connection(self) -> TConnection:
1281
+ """Create a new connection.
1282
+
1283
+ Subclasses should implement this method by
1284
+ creating a database connection of the type
1285
+ being pooled.
1286
+ """
1287
+
1288
+ def close(self) -> None:
1289
+ """Close the connection pool."""
1290
+ with self._put_condition:
1291
+ if self._closed:
1292
+ return
1293
+ for conn in self._in_use.values():
1294
+ conn.close()
1295
+ while True:
1296
+ try:
1297
+ conn = self._pool.popleft()
1298
+ except IndexError: # noqa: PERF203
1299
+ break
1300
+ else:
1301
+ conn.close()
1302
+ self._closed = True
1303
+
1304
+ @staticmethod
1305
+ def _time_remaining(timeout: float, started: float) -> float:
1306
+ return max(0.0, timeout + started - time())
1307
+
1308
+ def __del__(self) -> None:
1309
+ self.close()
1310
+
1311
+
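+ # --- Illustrative sketch (editorial addition, not part of this module): a
+ # minimal concrete Cursor, Connection and ConnectionPool built on sqlite3,
+ # showing get_connection(), put_connection() and the 'pre_ping' check.
+ import sqlite3
+
+
+ class SQLiteCursor(Cursor):
+     def __init__(self, sqlite_conn: sqlite3.Connection) -> None:
+         self._cursor = sqlite_conn.cursor()
+
+     def execute(self, statement: str | bytes, params: Params | None = None) -> None:
+         self._cursor.execute(statement, params or ())
+
+     def fetchall(self) -> Any:
+         return self._cursor.fetchall()
+
+     def fetchone(self) -> Any:
+         return self._cursor.fetchone()
+
+
+ class SQLiteConnection(Connection[SQLiteCursor]):
+     def __init__(self, max_age: float | None = None) -> None:
+         super().__init__(max_age=max_age)
+         # check_same_thread=False because the pool may close connections
+         # from its max-age timer thread.
+         self._conn = sqlite3.connect(":memory:", check_same_thread=False)
+
+     def commit(self) -> None:
+         self._conn.commit()
+
+     def rollback(self) -> None:
+         self._conn.rollback()
+
+     def cursor(self) -> SQLiteCursor:
+         return SQLiteCursor(self._conn)
+
+     def _close(self) -> None:
+         super()._close()
+         self._conn.close()
+
+
+ class SQLiteConnectionPool(ConnectionPool[SQLiteConnection]):
+     def _create_connection(self) -> SQLiteConnection:
+         # New connections are pre-locked for use by Connection.__init__.
+         return SQLiteConnection(max_age=self.max_age)
+
+
+ pool = SQLiteConnectionPool(pool_size=2, pre_ping=True)
+ conn = pool.get_connection()
+ try:
+     curs = conn.cursor()
+     curs.execute("SELECT 1")
+     assert curs.fetchone() == (1,)
+ finally:
+     pool.put_connection(conn)
+ pool.close()
+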
1312
+ TApplicationRecorder_co = TypeVar(
1313
+ "TApplicationRecorder_co", bound=ApplicationRecorder, covariant=True
1314
+ )
1315
+
1316
+
1317
+ class Subscription(Iterator[Notification], Generic[TApplicationRecorder_co]):
1318
+ def __init__(
1319
+ self,
1320
+ recorder: TApplicationRecorder_co,
1321
+ gt: int | None = None,
1322
+ topics: Sequence[str] = (),
1323
+ ) -> None:
1324
+ self._recorder = recorder
1325
+ self._last_notification_id = gt
1326
+ self._topics = topics
1327
+ self._has_been_entered = False
1328
+ self._has_been_stopped = False
1329
+
1330
+ def __enter__(self) -> Self:
1331
+ if self._has_been_entered:
1332
+ msg = "Already entered subscription context manager"
1333
+ raise ProgrammingError(msg)
1334
+ self._has_been_entered = True
1335
+ return self
1336
+
1337
+ def __exit__(self, *args: object, **kwargs: Any) -> None:
1338
+ if not self._has_been_entered:
1339
+ msg = "Not already entered subscription context manager"
1340
+ raise ProgrammingError(msg)
1341
+ self.stop()
1342
+
1343
+ def stop(self) -> None:
1344
+ """Stops the subscription."""
1345
+ self._has_been_stopped = True
1346
+
1347
+ def __iter__(self) -> Self:
1348
+ return self
1349
+
1350
+ @abstractmethod
1351
+ def __next__(self) -> Notification:
1352
+ """Returns the next Notification object in the application sequence."""
1353
+
1354
+
1355
+ class ListenNotifySubscription(Subscription[TApplicationRecorder_co]):
1356
+ def __init__(
1357
+ self,
1358
+ recorder: TApplicationRecorder_co,
1359
+ gt: int | None = None,
1360
+ topics: Sequence[str] = (),
1361
+ ) -> None:
1362
+ super().__init__(recorder=recorder, gt=gt, topics=topics)
1363
+ self._select_limit = 500
1364
+ self._notifications: Sequence[Notification] = []
1365
+ self._notifications_index: int = 0
1366
+ self._notifications_queue: Queue[Sequence[Notification]] = Queue(maxsize=10)
1367
+ self._has_been_notified = Event()
1368
+ self._thread_error: BaseException | None = None
1369
+ self._pull_thread = Thread(target=self._loop_on_pull)
1370
+ self._pull_thread.start()
1371
+
1372
+ def __exit__(self, *args: object, **kwargs: Any) -> None:
1373
+ try:
1374
+ super().__exit__(*args, **kwargs)
1375
+ finally:
1376
+ self._pull_thread.join()
1377
+
1378
+ def stop(self) -> None:
1379
+ """Stops the subscription."""
1380
+ super().stop()
1381
+ self._notifications_queue.put([])
1382
+ self._has_been_notified.set()
1383
+
1384
+ def __next__(self) -> Notification:
1385
+ # If necessary, get a new list of notifications from the recorder.
1386
+ if (
1387
+ self._notifications_index == len(self._notifications)
1388
+ and not self._has_been_stopped
1389
+ ):
1390
+ self._notifications = self._notifications_queue.get()
1391
+ self._notifications_index = 0
1392
+
1393
+ # Stop the iteration if necessary, maybe raise thread error.
1394
+ if self._has_been_stopped or not self._notifications:
1395
+ if self._thread_error is not None:
1396
+ raise self._thread_error
1397
+ raise StopIteration
1398
+
1399
+ # Return a notification from previously obtained list.
1400
+ notification = self._notifications[self._notifications_index]
1401
+ self._notifications_index += 1
1402
+ return notification
1403
+
1404
+ def _loop_on_pull(self) -> None:
1405
+ try:
1406
+ self._pull() # Already recorded events.
1407
+ while not self._has_been_stopped:
1408
+ self._has_been_notified.wait()
1409
+ self._pull() # Newly recorded events.
1410
+ except BaseException as e:
1411
+ if self._thread_error is None:
1412
+ self._thread_error = e
1413
+ self.stop()
1414
+
1415
+ def _pull(self) -> None:
1416
+ while not self._has_been_stopped:
1417
+ self._has_been_notified.clear()
1418
+ notifications = self._recorder.select_notifications(
1419
+ start=self._last_notification_id or 0,
1420
+ limit=self._select_limit,
1421
+ topics=self._topics,
1422
+ inclusive_of_start=False,
1423
+ )
1424
+ if len(notifications) > 0:
1425
+ # print("Putting", len(notifications), "notifications into queue")
1426
+ self._notifications_queue.put(notifications)
1427
+ self._last_notification_id = notifications[-1].id
1428
+ if len(notifications) < self._select_limit:
1429
+ break