eventsourcing-9.4.0a8-py3-none-any.whl → eventsourcing-9.4.0b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of eventsourcing might be problematic.

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator, Sequence
@@ -26,6 +27,7 @@ from eventsourcing.domain import (
     DomainEventProtocol,
     EventSourcingError,
     MutableOrImmutableAggregate,
+    SDomainEvent,
     Snapshot,
     SnapshotProtocol,
     TDomainEvent,
@@ -352,19 +354,18 @@ class Repository:
         return aggregate
 
     def _use_fastforward_lock(self, aggregate_id: UUID) -> Lock:
+        lock: Lock | None = None
         with self._fastforward_locks_lock:
-            try:
+            num_users = 0
+            with contextlib.suppress(KeyError):
                 lock, num_users = self._fastforward_locks_inuse[aggregate_id]
-            except KeyError:
-                try:
+            if lock is None:
+                with contextlib.suppress(KeyError):
                     lock = self._fastforward_locks_cache.get(aggregate_id, evict=True)
-                except KeyError:
-                    lock = Lock()
-                finally:
-                    num_users = 0
-            finally:
-                num_users += 1
-                self._fastforward_locks_inuse[aggregate_id] = (lock, num_users)
+            if lock is None:
+                lock = Lock()
+            num_users += 1
+            self._fastforward_locks_inuse[aggregate_id] = (lock, num_users)
         return lock
 
     def _disuse_fastforward_lock(self, aggregate_id: UUID) -> None:
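Note: the rewrite above leans on `contextlib.suppress`, which is equivalent to a `try`/`except ... pass` block. A minimal sketch of the pattern (illustrative, not package code):

```python
import contextlib

cache: dict[str, int] = {"a": 1}

value = None
with contextlib.suppress(KeyError):
    value = cache["b"]  # raises KeyError, so the assignment is skipped

assert value is None  # the exception was swallowed, value is unchanged
```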
@@ -610,12 +611,10 @@ class Application:
     name = "Application"
     env: ClassVar[dict[str, str]] = {}
     is_snapshotting_enabled: bool = False
-    snapshotting_intervals: ClassVar[
-        dict[type[MutableOrImmutableAggregate], int] | None
-    ] = None
+    snapshotting_intervals: ClassVar[dict[type[MutableOrImmutableAggregate], int]] = {}
     snapshotting_projectors: ClassVar[
-        dict[type[MutableOrImmutableAggregate], ProjectorFunction[Any, Any]] | None
-    ] = None
+        dict[type[MutableOrImmutableAggregate], ProjectorFunction[Any, Any]]
+    ] = {}
     snapshot_class: type[SnapshotProtocol] = Snapshot
     log_section_size = 10
     notify_topics: Sequence[str] = []
@@ -817,12 +816,9 @@ class Application:
                 continue
             interval = self.snapshotting_intervals.get(type(aggregate))
             if interval is not None and event.originator_version % interval == 0:
-                if (
-                    self.snapshotting_projectors
-                    and type(aggregate) in self.snapshotting_projectors
-                ):
+                try:
                     projector_func = self.snapshotting_projectors[type(aggregate)]
-                else:
+                except KeyError:
                     projector_func = project_aggregate
                 if projector_func is project_aggregate and not isinstance(
                     event, CanMutateProtocol
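Note: this EAFP rewrite works together with the earlier change that defaults `snapshotting_projectors` to `{}` rather than `None`: an empty dict supports direct lookup, so no truthiness guard is needed. A small illustrative sketch with hypothetical names:

```python
projectors: dict[type, str] = {}  # previously: dict | None = None

# Before: a None-and-membership guard was required.
if projectors and int in projectors:
    func = projectors[int]
else:
    func = "project_aggregate"

# After: a plain try/except KeyError lookup suffices.
try:
    func = projectors[int]
except KeyError:
    func = "project_aggregate"
```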
@@ -947,10 +943,10 @@ class EventSourcedLog(Generic[TDomainEvent]):
 
     def _trigger_event(
         self,
-        logged_cls: type[T] | None,
+        logged_cls: type[SDomainEvent],
         next_originator_version: int | None = None,
         **kwargs: Any,
-    ) -> T:
+    ) -> SDomainEvent:
         """
         Constructs and returns a new log event.
         """
@@ -961,7 +957,7 @@ class EventSourcedLog(Generic[TDomainEvent]):
         else:
             next_originator_version = last_logged.originator_version + 1
 
-        return logged_cls(  # type: ignore
+        return logged_cls(
             originator_id=self.originator_id,
             originator_version=next_originator_version,
             timestamp=datetime_now_with_tzinfo(),
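Note: replacing the unbound `T` with `SDomainEvent` (a `TypeVar` bound to `DomainEventProtocol`, added in domain.py below) is what lets the `# type: ignore` go away: together with the protocol's new `__init__`, a checker can verify the `logged_cls(...)` call and infer that the return type matches the class passed in. A reduced sketch of the pattern, with hypothetical names:

```python
from typing import Any, Protocol, TypeVar

class EventLike(Protocol):
    # Declaring __init__ on the protocol is what allows cls(...) below
    # to type-check without an ignore comment.
    def __init__(self, *args: Any, **kwargs: Any) -> None: ...

S = TypeVar("S", bound=EventLike)

def trigger(cls: type[S], **kwargs: Any) -> S:
    return cls(**kwargs)  # inferred as S, i.e. the caller's concrete class
```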
eventsourcing/cipher.py CHANGED
@@ -5,7 +5,9 @@ from base64 import b64decode, b64encode
 from typing import TYPE_CHECKING
 
 from Crypto.Cipher import AES
-from Crypto.Cipher._mode_gcm import GcmMode
+from Crypto.Cipher._mode_gcm import (
+    GcmMode,  # pyright: ignore [reportPrivateImportUsage]
+)
 from Crypto.Cipher.AES import key_size
 
 from eventsourcing.persistence import Cipher
eventsourcing/dispatch.py CHANGED
@@ -1,14 +1,50 @@
 from __future__ import annotations
 
-from functools import singledispatchmethod as _singledispatchmethod
+import functools
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload
 
+_T = TypeVar("_T")
+_S = TypeVar("_S")
 
-class singledispatchmethod(_singledispatchmethod):  # noqa: N801
-    def __init__(self, func):
+if TYPE_CHECKING:
+
+    class _singledispatchmethod(functools.singledispatchmethod[_T]):  # noqa: N801
+        pass
+
+else:
+
+    class _singledispatchmethod(  # noqa: N801
+        functools.singledispatchmethod, Generic[_T]
+    ):
+        pass
+
+
+class singledispatchmethod(_singledispatchmethod[_T]):  # noqa: N801
+    def __init__(self, func: Callable[..., _T]) -> None:
         super().__init__(func)
-        self.deferred_registrations = []
+        self.deferred_registrations: list[
+            tuple[type[Any] | Callable[..., _T], Callable[..., _T] | None]
+        ] = []
+
+    @overload
+    def register(
+        self, cls: type[Any], method: None = None
+    ) -> Callable[[Callable[..., _T]], Callable[..., _T]]: ...  # pragma: no cover
+    @overload
+    def register(
+        self, cls: Callable[..., _T], method: None = None
+    ) -> Callable[..., _T]: ...  # pragma: no cover
 
+    @overload
+    def register(
+        self, cls: type[Any], method: Callable[..., _T]
+    ) -> Callable[..., _T]: ...  # pragma: no cover
 
-    def register(self, cls, method=None):
+    def register(
+        self,
+        cls: type[Any] | Callable[..., _T],
+        method: Callable[..., _T] | None = None,
+    ) -> Callable[[Callable[..., _T]], Callable[..., _T]] | Callable[..., _T]:
         """generic_method.register(cls, func) -> func
 
         Registers a new implementation for the given *cls* on a *generic_method*.
@@ -22,17 +58,22 @@ class singledispatchmethod(_singledispatchmethod):  # noqa: N801
 
         # for globals in typing.get_type_hints() in Python 3.8 and 3.9
         if not hasattr(cls, "__wrapped__"):
-            cls.__wrapped__ = cls.__func__
+            cls.__dict__["__wrapped__"] = cls.__func__
+            # cls.__wrapped__ = cls.__func__
 
         try:
-            return self.dispatcher.register(cls, func=method)
+            return self.dispatcher.register(cast(type[Any], cls), func=method)
         except NameError:
-            self.deferred_registrations.append([cls, method])
+            self.deferred_registrations.append(
+                (cls, method)  # pyright: ignore [reportArgumentType]
+            )
             # TODO: Fix this....
-            return method or cls
+            return method or cls  # pyright: ignore [reportReturnType]
 
-    def __get__(self, obj, cls=None):
+    def __get__(self, obj: _S, cls: type[_S] | None = None) -> Callable[..., _T]:
         for registered_cls, registered_method in self.deferred_registrations:
-            self.dispatcher.register(registered_cls, func=registered_method)
+            self.dispatcher.register(
+                cast(type[Any], registered_cls), func=registered_method
+            )
         self.deferred_registrations = []
         return super().__get__(obj, cls=cls)
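Note: the overloads mirror how the decorator is used in practice. A minimal usage sketch (illustrative, assuming the class behaves like `functools.singledispatchmethod`):

```python
from eventsourcing.dispatch import singledispatchmethod

class Handler:
    @singledispatchmethod
    def handle(self, obj: object) -> str:
        return "default"

    @handle.register
    def _(self, obj: int) -> str:
        # Dispatch target is inferred from the annotation on obj.
        return f"int: {obj}"

h = Handler()
assert h.handle("text") == "default"
assert h.handle(42) == "int: 42"
```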
eventsourcing/domain.py CHANGED
@@ -88,20 +88,26 @@ class DomainEventProtocol(Protocol):
     kinds of domain event classes, such as Pydantic classes.
     """
 
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        pass  # pragma: no cover
+
     @property
     def originator_id(self) -> UUID:
         """
         UUID identifying an aggregate to which the event belongs.
         """
+        raise NotImplementedError  # pragma: no cover
 
     @property
     def originator_version(self) -> int:
         """
         Integer identifying the version of the aggregate when the event occurred.
         """
+        raise NotImplementedError  # pragma: no cover
 
 
 TDomainEvent = TypeVar("TDomainEvent", bound=DomainEventProtocol)
+SDomainEvent = TypeVar("SDomainEvent", bound=DomainEventProtocol)
 
 
 class MutableAggregateProtocol(Protocol):
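Note: protocol method bodies are never executed through the protocol itself, which is presumably why the new `raise NotImplementedError` statements carry `# pragma: no cover` markers; they only run if a class explicitly subclasses the protocol and calls up. Protocols also cannot be instantiated directly:

```python
from typing import Protocol
from uuid import UUID

class EventProto(Protocol):
    @property
    def originator_id(self) -> UUID:
        raise NotImplementedError  # never reached via the protocol itself

try:
    EventProto()
except TypeError as e:
    print(e)  # "Protocols cannot be instantiated"
```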
@@ -120,18 +126,21 @@ class MutableAggregateProtocol(Protocol):
         """
         Mutable aggregates have a read-only ID that is a UUID.
         """
+        raise NotImplementedError  # pragma: no cover
 
     @property
     def version(self) -> int:
         """
         Mutable aggregates have a read-write version that is an int.
         """
+        raise NotImplementedError  # pragma: no cover
 
     @version.setter
     def version(self, value: int) -> None:
         """
         Mutable aggregates have a read-write version that is an int.
         """
+        raise NotImplementedError  # pragma: no cover
 
 
 class ImmutableAggregateProtocol(Protocol):
@@ -150,12 +159,14 @@ class ImmutableAggregateProtocol(Protocol):
         """
         Immutable aggregates have a read-only ID that is a UUID.
         """
+        raise NotImplementedError  # pragma: no cover
 
     @property
     def version(self) -> int:
         """
         Immutable aggregates have a read-only version that is an int.
         """
+        raise NotImplementedError  # pragma: no cover
 
 
 MutableOrImmutableAggregate = Union[
@@ -180,6 +191,7 @@ class CollectEventsProtocol(Protocol):
         """
         Returns a sequence of events.
         """
+        raise NotImplementedError  # pragma: no cover
 
 
 @runtime_checkable
@@ -233,7 +245,7 @@ class CanCreateTimestamp:
         return datetime_now_with_tzinfo()
 
 
-TAggregate = TypeVar("TAggregate", bound="Aggregate")
+TAggregate = TypeVar("TAggregate", bound="BaseAggregate")
 
 
 class HasOriginatorIDVersion:
@@ -300,7 +312,7 @@ class CanMutateAggregate(HasOriginatorIDVersion, CanCreateTimestamp):
         # Return the mutated aggregate.
         return aggregate
 
-    def apply(self, aggregate: Aggregate) -> None:
+    def apply(self, aggregate: Any) -> None:
         """
         Applies the domain event to its aggregate.
 
@@ -934,9 +946,9 @@ def _raise_missing_names_type_error(missing_names: list[str], msg: str) -> None:
     raise TypeError(msg)
 
 
-_annotations_mention_id: set[type[Aggregate]] = set()
-_init_mentions_id: set[type[Aggregate]] = set()
-_create_id_param_names: dict[type[Aggregate], list[str]] = defaultdict(list)
+_annotations_mention_id: set[type[BaseAggregate]] = set()
+_init_mentions_id: set[type[BaseAggregate]] = set()
+_create_id_param_names: dict[type[BaseAggregate], list[str]] = defaultdict(list)
 
 
 class MetaAggregate(EventsourcingType, Generic[TAggregate], type):
@@ -1011,19 +1023,13 @@ class MetaAggregate(EventsourcingType, Generic[TAggregate], type):
     _created_event_class: type[CanInitAggregate]
 
 
-class Aggregate(metaclass=MetaAggregate):
+class BaseAggregate(metaclass=MetaAggregate):
     """
     Base class for aggregates.
     """
 
     INITIAL_VERSION = 1
 
-    class Event(AggregateEvent):
-        pass
-
-    class Created(Event, AggregateCreated):
-        pass
-
     @staticmethod
     def create_id(*_: Any, **__: Any) -> UUID:
         """
@@ -1081,7 +1087,7 @@ class Aggregate(metaclass=MetaAggregate):
 
         assert agg is not None
         # Append the domain event to pending list.
-        agg.pending_events.append(created_event)
+        agg._pending_events.append(created_event)
         # Return the aggregate.
         return agg
 
@@ -1197,7 +1203,7 @@
         return f"{type(self).__name__}({', '.join(attrs)})"
 
     def __init_subclass__(
-        cls: type[Aggregate], *, created_event_name: str | None = None
+        cls: type[BaseAggregate], *, created_event_name: str | None = None
     ) -> None:
         """
         Initialises aggregate subclass by defining __init__ method and event classes.
@@ -1211,8 +1217,10 @@
         except KeyError:
             pass
 
-        if class_annotations or any(
-            dataclasses.is_dataclass(base) for base in cls.__bases__
+        if (
+            class_annotations
+            or cls in _annotations_mention_id
+            or any(dataclasses.is_dataclass(base) for base in cls.__bases__)
         ):
             dataclasses.dataclass(eq=False, repr=False)(cls)
 
@@ -1223,7 +1231,9 @@
             base_event_cls = cls.__dict__[base_event_name]
         except KeyError:
             base_event_cls = cls._define_event_class(
-                base_event_name, (cls.Event,), None
+                name=base_event_name,
+                bases=(getattr(cls, base_event_name, AggregateEvent),),
+                apply_method=None,
             )
             setattr(cls, base_event_name, base_event_cls)
@@ -1482,9 +1492,12 @@
         setattr(cls, name, sub_class)
 
 
-# Special case for the Aggregate class because
-# it's not processed by Aggregate.__init_subclass__.
-_created_event_classes[Aggregate] = [Aggregate.Created]
+class Aggregate(BaseAggregate):
+    class Event(AggregateEvent):
+        pass
+
+    class Created(Event, AggregateCreated):
+        pass
 
 
 @overload
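Note: after this split, `BaseAggregate` carries the metaclass machinery while the nested `Event` and `Created` base classes live only on `Aggregate`, so the documented subclassing style keeps working unchanged. A sketch in the style of the project's README:

```python
from eventsourcing.domain import Aggregate, event

class Dog(Aggregate):
    @event("Registered")
    def __init__(self, name: str) -> None:
        self.name = name
        self.tricks: list[str] = []

    @event("TrickAdded")
    def add_trick(self, trick: str) -> None:
        self.tricks.append(trick)

dog = Dog("Fido")
dog.add_trick("roll over")
assert dog.tricks == ["roll over"]
```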
@@ -1578,6 +1591,7 @@ class SnapshotProtocol(DomainEventProtocol, Protocol):
         """
         Snapshots have a read-only 'state'.
         """
+        raise NotImplementedError  # pragma: no cover
 
     # TODO: Improve on this 'Any'.
     @classmethod
@@ -1594,6 +1608,16 @@ class CanSnapshotAggregate(HasOriginatorIDVersion, CanCreateTimestamp):
     topic: str
     state: Any
 
+    def __init__(
+        self,
+        originator_id: UUID,
+        originator_version: int,
+        timestamp: datetime,
+        topic: str,
+        state: Any,
+    ) -> None:
+        raise NotImplementedError  # pragma: no cover
+
     @classmethod
     def take(
         cls: type[TCanSnapshotAggregate],
@@ -1610,7 +1634,7 @@ class CanSnapshotAggregate(HasOriginatorIDVersion, CanCreateTimestamp):
         aggregate_state.pop("_id")
         aggregate_state.pop("_version")
         aggregate_state.pop("_pending_events")
-        return cls(  # type: ignore
+        return cls(
             originator_id=aggregate.id,
             originator_version=aggregate.version,
             timestamp=cls.create_timestamp(),
@@ -141,7 +141,7 @@ class NotificationLogJSONClient(NotificationLog):
         self,
         start: int | None,
         limit: int,
-        _: int | None = None,
+        stop: int | None = None,
         topics: Sequence[str] = (),
         *,
         inclusive_of_start: bool = True,
@@ -517,10 +517,10 @@ class TrackingRecorder(Recorder, ABC):
         interrupt: Event | None = None,
     ) -> None:
         """
-        Block until a tracking object with the given application name and
-        notification ID has been recorded.
+        Block until a tracking object with the given application name and a
+        notification ID greater than equal to the given value has been recorded.
 
-        Polls has_tracking_id() with exponential backoff until the timeout
+        Polls max_tracking_id() with exponential backoff until the timeout
         is reached, or until the optional interrupt event is set.
 
         The timeout argument should be a floating point number specifying a
@@ -534,7 +534,10 @@
         deadline = monotonic() + timeout
         delay_ms = 1.0
         while True:
-            if self.has_tracking_id(application_name, notification_id):
+            max_tracking_id = self.max_tracking_id(application_name)
+            if notification_id is None or (
+                max_tracking_id is not None and max_tracking_id >= notification_id
+            ):
                 break
             if interrupt:
                 if interrupt.wait(timeout=delay_ms / 1000):
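Note: the loop polls `max_tracking_id()` with exponential backoff, as the revised docstring says. A standalone sketch of that polling shape (the package's exact growth factor and cap may differ):

```python
from time import monotonic, sleep
from typing import Callable

def wait_until(predicate: Callable[[], bool], timeout: float = 5.0) -> None:
    deadline = monotonic() + timeout
    delay_ms = 1.0
    while not predicate():
        remaining = deadline - monotonic()
        if remaining <= 0:
            raise TimeoutError
        # Sleep for the current delay, capped by the remaining time...
        sleep(min(delay_ms / 1000, remaining))
        # ...then double the delay for the next round.
        delay_ms = min(delay_ms * 2, 100.0)
```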
@@ -751,6 +754,7 @@ class InfrastructureFactory(ABC, Generic[TTrackingRecorder]):
         mapper_topic = self.env.get(self.MAPPER_TOPIC)
         mapper_class = resolve_topic(mapper_topic) if mapper_topic else Mapper
 
+        assert isinstance(mapper_class, type) and issubclass(mapper_class, Mapper)
         return mapper_class(
             transcoder=transcoder or self.transcoder(),
             cipher=self.cipher(),
eventsourcing/postgres.py CHANGED
@@ -5,7 +5,7 @@ import logging
 from asyncio import CancelledError
 from contextlib import contextmanager
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, cast
 
 import psycopg
 import psycopg.errors
@@ -13,6 +13,8 @@ import psycopg_pool
 from psycopg import Connection, Cursor, Error
 from psycopg.generators import notifies
 from psycopg.rows import DictRow, dict_row
+from psycopg.sql import SQL, Composed, Identifier
+from typing_extensions import TypeVar
 
 from eventsourcing.persistence import (
     AggregateRecorder,
@@ -41,6 +43,7 @@ if TYPE_CHECKING:
     from collections.abc import Iterator, Sequence
     from uuid import UUID
 
+    from psycopg.abc import Query
     from typing_extensions import Self
 
 logging.getLogger("psycopg.pool").setLevel(logging.CRITICAL)
@@ -118,12 +121,11 @@ class PostgresDatastore:
             check=check,
         )
         self.lock_timeout = lock_timeout
-        self.schema = schema.strip()
+        self.schema = schema.strip() or "public"
 
     def after_connect_func(self) -> Callable[[Connection[Any]], None]:
-        statement = (
-            "SET idle_in_transaction_session_timeout = "
-            f"'{self.idle_in_transaction_session_timeout}s'"
+        statement = SQL("SET idle_in_transaction_session_timeout = '{0}s'").format(
+            self.idle_in_transaction_session_timeout
         )
 
         def after_connect(conn: Connection[DictRow]) -> None:
@@ -168,7 +170,6 @@
 
     @contextmanager
     def transaction(self, *, commit: bool = False) -> Iterator[Cursor[DictRow]]:
-        conn: Connection[DictRow]
         with self.get_connection() as conn, conn.transaction(force_rollback=not commit):
             yield conn.cursor()
 
@@ -195,17 +196,12 @@ class PostgresRecorder:
         self.datastore = datastore
         self.create_table_statements = self.construct_create_table_statements()
 
-    def construct_create_table_statements(self) -> list[str]:
+    def construct_create_table_statements(self) -> list[Composed]:
         return []
 
     def check_table_name_length(self, table_name: str) -> None:
-        schema_prefix = self.datastore.schema + "."
-        if table_name.startswith(schema_prefix):
-            unqualified_table_name = table_name[len(schema_prefix) :]
-        else:
-            unqualified_table_name = table_name
-        if len(unqualified_table_name) > 63:
-            msg = f"Table name too long: {unqualified_table_name}"
+        if len(table_name) > 63:
+            msg = f"Table name too long: {table_name}"
             raise ProgrammingError(msg)
 
     def create_table(self) -> None:
@@ -226,38 +222,45 @@ class PostgresAggregateRecorder(PostgresRecorder, AggregateRecorder):
         self.events_table_name = events_table_name
         # Index names can't be qualified names, but
         # are created in the same schema as the table.
-        if "." in self.events_table_name:
-            unqualified_table_name = self.events_table_name.split(".")[-1]
-        else:
-            unqualified_table_name = self.events_table_name
         self.notification_id_index_name = (
-            f"{unqualified_table_name}_notification_id_idx "
+            f"{self.events_table_name}_notification_id_idx"
         )
         self.create_table_statements.append(
-            "CREATE TABLE IF NOT EXISTS "
-            f"{self.events_table_name} ("
-            "originator_id uuid NOT NULL, "
-            "originator_version bigint NOT NULL, "
-            "topic text, "
-            "state bytea, "
-            "PRIMARY KEY "
-            "(originator_id, originator_version)) "
-            "WITH (autovacuum_enabled=false)"
+            SQL(
+                "CREATE TABLE IF NOT EXISTS {0}.{1} ("
+                "originator_id uuid NOT NULL, "
+                "originator_version bigint NOT NULL, "
+                "topic text, "
+                "state bytea, "
+                "PRIMARY KEY "
+                "(originator_id, originator_version)) "
+                "WITH (autovacuum_enabled=false)"
+            ).format(
+                Identifier(self.datastore.schema),
+                Identifier(self.events_table_name),
+            )
         )
 
-        self.insert_events_statement = (
-            f"INSERT INTO {self.events_table_name} VALUES (%s, %s, %s, %s)"
+        self.insert_events_statement = SQL(
+            "INSERT INTO {0}.{1} VALUES (%s, %s, %s, %s)"
+        ).format(
+            Identifier(self.datastore.schema),
+            Identifier(self.events_table_name),
         )
-        self.select_events_statement = (
-            f"SELECT * FROM {self.events_table_name} WHERE originator_id = %s"
+
+        self.select_events_statement = SQL(
+            "SELECT * FROM {0}.{1} WHERE originator_id = %s"
+        ).format(
+            Identifier(self.datastore.schema),
+            Identifier(self.events_table_name),
         )
-        self.lock_table_statements: list[str] = []
+
+        self.lock_table_statements: list[Query] = []
 
     @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
     def insert_events(
         self, stored_events: list[StoredEvent], **kwargs: Any
     ) -> Sequence[int] | None:
-        conn: Connection[DictRow]
         exc: Exception | None = None
        notification_ids: Sequence[int] | None = None
         with self.datastore.get_connection() as conn:
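Note: the move from f-strings to `psycopg.sql` composition means schema and table names are quoted as identifiers server-side, so they cannot break or inject into the statement. A small illustration of the API used above:

```python
from psycopg.sql import SQL, Identifier

stmt = SQL("SELECT * FROM {0}.{1} WHERE originator_id = %s").format(
    Identifier("public"), Identifier("stored_events")
)
# With psycopg >= 3.2, as_string() works without a connection:
print(stmt.as_string())
# SELECT * FROM "public"."stored_events" WHERE originator_id = %s
```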
@@ -316,7 +319,7 @@
                 )
                 for stored_event in stored_events
             ],
-            returning="RETURNING" in self.insert_events_statement,
+            returning="RETURNING" in self.insert_events_statement.as_string(),
         )
 
     def _lock_table(self, curs: Cursor[DictRow]) -> None:
@@ -347,18 +350,18 @@
         params: list[Any] = [originator_id]
         if gt is not None:
             params.append(gt)
-            statement += " AND originator_version > %s"
+            statement += SQL(" AND originator_version > %s")
         if lte is not None:
             params.append(lte)
-            statement += " AND originator_version <= %s"
-        statement += " ORDER BY originator_version"
+            statement += SQL(" AND originator_version <= %s")
+        statement += SQL(" ORDER BY originator_version")
         if desc is False:
-            statement += " ASC"
+            statement += SQL(" ASC")
         else:
-            statement += " DESC"
+            statement += SQL(" DESC")
         if limit is not None:
             params.append(limit)
-            statement += " LIMIT %s"
+            statement += SQL(" LIMIT %s")
 
         with self.datastore.get_connection() as conn, conn.cursor() as curs:
             curs.execute(statement, params, prepare=True)
@@ -381,9 +384,8 @@ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder
         events_table_name: str = "stored_events",
     ):
         super().__init__(datastore, events_table_name=events_table_name)
-        self.create_table_statements[-1] = (
-            "CREATE TABLE IF NOT EXISTS "
-            f"{self.events_table_name} ("
+        self.create_table_statements[-1] = SQL(
+            "CREATE TABLE IF NOT EXISTS {0}.{1} ("
             "originator_id uuid NOT NULL, "
             "originator_version bigint NOT NULL, "
             "topic text, "
@@ -392,20 +394,40 @@ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder
             "PRIMARY KEY "
             "(originator_id, originator_version)) "
             "WITH (autovacuum_enabled=false)"
+        ).format(
+            Identifier(self.datastore.schema),
+            Identifier(self.events_table_name),
         )
+
         self.create_table_statements.append(
-            "CREATE UNIQUE INDEX IF NOT EXISTS "
-            f"{self.notification_id_index_name}"
-            f"ON {self.events_table_name} (notification_id ASC);"
+            SQL(
+                "CREATE UNIQUE INDEX IF NOT EXISTS {0} "
+                "ON {1}.{2} (notification_id ASC);"
+            ).format(
+                Identifier(self.notification_id_index_name),
+                Identifier(self.datastore.schema),
+                Identifier(self.events_table_name),
+            )
         )
+
         self.channel_name = self.events_table_name.replace(".", "_")
-        self.insert_events_statement += " RETURNING notification_id"
-        self.max_notification_id_statement = (
-            f"SELECT MAX(notification_id) FROM {self.events_table_name}"
+        self.insert_events_statement = self.insert_events_statement + SQL(
+            " RETURNING notification_id"
+        )
+
+        self.max_notification_id_statement = SQL(
+            "SELECT MAX(notification_id) FROM {0}.{1}"
+        ).format(
+            Identifier(self.datastore.schema),
+            Identifier(self.events_table_name),
         )
+
         self.lock_table_statements = [
-            f"SET LOCAL lock_timeout = '{self.datastore.lock_timeout}s'",
-            f"LOCK TABLE {self.events_table_name} IN EXCLUSIVE MODE",
+            SQL("SET LOCAL lock_timeout = '{0}s'").format(self.datastore.lock_timeout),
+            SQL("LOCK TABLE {0}.{1} IN EXCLUSIVE MODE").format(
+                Identifier(self.datastore.schema),
+                Identifier(self.events_table_name),
+            ),
         ]
 
     @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
@@ -424,37 +446,44 @@ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder
         """
 
         params: list[int | str | Sequence[str]] = []
-        statement = f"SELECT * FROM {self.events_table_name}"
+        statement = SQL("SELECT * FROM {0}.{1}").format(
+            Identifier(self.datastore.schema),
+            Identifier(self.events_table_name),
+        )
         has_where = False
         if start is not None:
-            statement += " WHERE"
+            statement += SQL(" WHERE")
             has_where = True
             params.append(start)
             if inclusive_of_start:
-                statement += " notification_id>=%s"
+                statement += SQL(" notification_id>=%s")
             else:
-                statement += " notification_id>%s"
+                statement += SQL(" notification_id>%s")
 
         if stop is not None:
             if not has_where:
                 has_where = True
-                statement += " WHERE"
+                statement += SQL(" WHERE")
             else:
-                statement += " AND"
+                statement += SQL(" AND")
 
             params.append(stop)
-            statement += " notification_id <= %s"
+            statement += SQL(" notification_id <= %s")
 
         if topics:
+            # Check sequence and ensure list of strings.
+            assert isinstance(topics, (tuple, list)), topics
+            topics = list(topics) if isinstance(topics, tuple) else topics
+            assert all(isinstance(t, str) for t in topics), topics
             if not has_where:
-                statement += " WHERE"
+                statement += SQL(" WHERE")
             else:
-                statement += " AND"
+                statement += SQL(" AND")
             params.append(topics)
-            statement += " topic = ANY(%s)"
+            statement += SQL(" topic = ANY(%s)")
 
         params.append(limit)
-        statement += " ORDER BY notification_id LIMIT %s"
+        statement += SQL(" ORDER BY notification_id LIMIT %s")
 
         connection = self.datastore.get_connection()
         with connection as conn, conn.cursor() as curs:
@@ -475,7 +504,6 @@ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder
         """
         Returns the maximum notification ID.
         """
-        conn: Connection[DictRow]
         with self.datastore.get_connection() as conn, conn.cursor() as curs:
             curs.execute(self.max_notification_id_statement)
             fetchone = curs.fetchone()
@@ -507,7 +535,7 @@
             curs.execute(lock_statement, prepare=True)
 
     def _notify_channel(self, curs: Cursor[DictRow]) -> None:
-        curs.execute("NOTIFY " + self.channel_name)
+        curs.execute(SQL("NOTIFY {0}").format(Identifier(self.channel_name)))
 
     def _fetch_ids_after_insert_events(
         self,
@@ -554,7 +582,9 @@ class PostgresSubscription(ListenNotifySubscription[PostgresApplicationRecorder]
     def _listen(self) -> None:
         try:
             with self._recorder.datastore.get_connection() as conn:
-                conn.execute("LISTEN " + self._recorder.channel_name)
+                conn.execute(
+                    SQL("LISTEN {0}").format(Identifier(self._recorder.channel_name))
+                )
                 while not self._has_been_stopped and not self._thread_error:
                     # This block simplifies psycopg's conn.notifies(), because
                     # we aren't interested in the actual notify messages, and
@@ -585,30 +615,42 @@ class PostgresTrackingRecorder(PostgresRecorder, TrackingRecorder):
         self.check_table_name_length(tracking_table_name)
         self.tracking_table_name = tracking_table_name
         self.create_table_statements.append(
-            "CREATE TABLE IF NOT EXISTS "
-            f"{self.tracking_table_name} ("
-            "application_name text, "
-            "notification_id bigint, "
-            "PRIMARY KEY "
-            "(application_name, notification_id))"
+            SQL(
+                "CREATE TABLE IF NOT EXISTS {0}.{1} ("
+                "application_name text, "
+                "notification_id bigint, "
+                "PRIMARY KEY "
+                "(application_name, notification_id))"
+            ).format(
+                Identifier(self.datastore.schema),
+                Identifier(self.tracking_table_name),
+            )
         )
-        self.insert_tracking_statement = (
-            f"INSERT INTO {self.tracking_table_name} VALUES (%s, %s)"
+
+        self.insert_tracking_statement = SQL(
+            "INSERT INTO {0}.{1} VALUES (%s, %s)"
+        ).format(
+            Identifier(self.datastore.schema),
+            Identifier(self.tracking_table_name),
         )
-        self.max_tracking_id_statement = (
-            "SELECT MAX(notification_id) "
-            f"FROM {self.tracking_table_name} "
-            "WHERE application_name=%s"
+
+        self.max_tracking_id_statement = SQL(
+            "SELECT MAX(notification_id) FROM {0}.{1} WHERE application_name=%s"
+        ).format(
+            Identifier(self.datastore.schema),
+            Identifier(self.tracking_table_name),
         )
-        self.count_tracking_id_statement = (
-            "SELECT COUNT(*) "
-            f"FROM {self.tracking_table_name} "
+
+        self.count_tracking_id_statement = SQL(
+            "SELECT COUNT(*) FROM {0}.{1} "
             "WHERE application_name=%s AND notification_id=%s"
+        ).format(
+            Identifier(self.datastore.schema),
+            Identifier(self.tracking_table_name),
         )
 
     @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
     def insert_tracking(self, tracking: Tracking) -> None:
-        conn: Connection[DictRow]
         with (
             self.datastore.get_connection() as conn,
             conn.transaction(),
@@ -648,7 +690,6 @@
     ) -> bool:
         if notification_id is None:
             return True
-        conn: Connection[DictRow]
         with self.datastore.get_connection() as conn, conn.cursor() as curs:
             curs.execute(
                 query=self.count_tracking_id_statement,
@@ -660,6 +701,13 @@
             return bool(fetchone["count"])
 
 
+TPostgresTrackingRecorder = TypeVar(
+    "TPostgresTrackingRecorder",
+    bound=PostgresTrackingRecorder,
+    default=PostgresTrackingRecorder,
+)
+
+
 class PostgresProcessRecorder(
     PostgresTrackingRecorder, PostgresApplicationRecorder, ProcessRecorder
 ):
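Note: `TypeVar` here comes from `typing_extensions` because the `default=` argument (PEP 696) is used: when a call site gives the checker nothing to bind the variable to, it falls back to `PostgresTrackingRecorder` instead of an unknown type. A reduced sketch:

```python
from typing_extensions import TypeVar

class Recorder: ...
class SpecialRecorder(Recorder): ...

# bound= constrains what the variable may be; default= is what checkers
# assume when the variable cannot be solved from the arguments (PEP 696).
TRec = TypeVar("TRec", bound=Recorder, default=Recorder)
```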
@@ -887,8 +935,6 @@ class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
     def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
         prefix = self.env.name.lower() or "stored"
         events_table_name = prefix + "_" + purpose
-        if self.datastore.schema:
-            events_table_name = f"{self.datastore.schema}.{events_table_name}"
         recorder = type(self).aggregate_recorder_class(
             datastore=self.datastore,
             events_table_name=events_table_name,
@@ -900,9 +946,6 @@
     def application_recorder(self) -> ApplicationRecorder:
         prefix = self.env.name.lower() or "stored"
         events_table_name = prefix + "_events"
-        if self.datastore.schema:
-            events_table_name = f"{self.datastore.schema}.{events_table_name}"
-
         application_recorder_topic = self.env.get(self.APPLICATION_RECORDER_TOPIC)
         if application_recorder_topic:
             application_recorder_class: type[PostgresApplicationRecorder] = (
@@ -921,18 +964,18 @@ class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
         return recorder
 
     def tracking_recorder(
-        self, tracking_recorder_class: type[PostgresTrackingRecorder] | None = None
-    ) -> PostgresTrackingRecorder:
+        self, tracking_recorder_class: type[TPostgresTrackingRecorder] | None = None
+    ) -> TPostgresTrackingRecorder:
         prefix = self.env.name.lower() or "notification"
         tracking_table_name = prefix + "_tracking"
-        if self.datastore.schema:
-            tracking_table_name = f"{self.datastore.schema}.{tracking_table_name}"
         if tracking_recorder_class is None:
             tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
             if tracking_recorder_topic:
                 tracking_recorder_class = resolve_topic(tracking_recorder_topic)
             else:
-                tracking_recorder_class = type(self).tracking_recorder_class
+                tracking_recorder_class = cast(
+                    type[TPostgresTrackingRecorder], type(self).tracking_recorder_class
+                )
         assert tracking_recorder_class is not None
         assert issubclass(tracking_recorder_class, PostgresTrackingRecorder)
         recorder = tracking_recorder_class(
@@ -948,10 +991,6 @@ class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
         events_table_name = prefix + "_events"
         prefix = self.env.name.lower() or "notification"
         tracking_table_name = prefix + "_tracking"
-        if self.datastore.schema:
-            events_table_name = f"{self.datastore.schema}.{events_table_name}"
-            tracking_table_name = f"{self.datastore.schema}.{tracking_table_name}"
-
         process_recorder_topic = self.env.get(self.PROCESS_RECORDER_TOPIC)
         if process_recorder_topic:
             process_recorder_class: type[PostgresTrackingRecorder] = resolve_topic(
@@ -40,39 +40,65 @@ class ApplicationSubscription(Iterator[tuple[DomainEventProtocol, Tracking]]):
         gt: int | None = None,
         topics: Sequence[str] = (),
     ):
+        """
+        Starts subscription to application's stored events using application's recorder.
+        """
         self.name = app.name
         self.recorder = app.recorder
         self.mapper = app.mapper
         self.subscription = self.recorder.subscribe(gt=gt, topics=topics)
 
+    def stop(self) -> None:
+        """
+        Stops the stored event subscription.
+        """
+        self.subscription.stop()
+
     def __enter__(self) -> Self:
+        """
+        Calls __enter__ on the stored event subscription.
+        """
         self.subscription.__enter__()
         return self
 
     def __exit__(self, *args: object, **kwargs: Any) -> None:
+        """
+        Calls __exit__ on the stored event subscription.
+        """
         self.subscription.__exit__(*args, **kwargs)
 
     def __iter__(self) -> Self:
         return self
 
     def __next__(self) -> tuple[DomainEventProtocol, Tracking]:
+        """
+        Returns the next stored event from the stored event subscription.
+        Constructs a tracking object that identifies the position of
+        the event in the application sequence, and reconstructs a domain
+        event object from the stored event object.
+        """
         notification = next(self.subscription)
         tracking = Tracking(self.name, notification.id)
         domain_event = self.mapper.to_domain_event(notification)
         return domain_event, tracking
 
     def __del__(self) -> None:
+        """
+        Stops the stored event subscription.
+        """
         self.stop()
 
-    def stop(self) -> None:
-        self.subscription.stop()
-
 
 class Projection(ABC, Generic[TTrackingRecorder]):
     name: str = ""
-    """Name of projection, used to pick prefixed environment variables."""
-    topics: Sequence[str] = ()
-    """Filter events in database when subscribing to an application."""
+    """
+    Name of projection, used to pick prefixed environment
+    variables and define database table names.
+    """
+    topics: tuple[str, ...] = ()
+    """
+    Filter events in database when subscribing to an application.
+    """
 
     def __init__(
         self,
@@ -104,10 +130,21 @@ class ProjectionRunner(Generic[TApplication, TTrackingRecorder]):
         self,
         *,
         application_class: type[TApplication],
-        view_class: type[TTrackingRecorder],
         projection_class: type[Projection[TTrackingRecorder]],
+        view_class: type[TTrackingRecorder],
         env: EnvType | None = None,
     ):
+        """
+        Constructs application from given application class with given environment.
+        Also constructs a materialised view from given class using an infrastructure
+        factory constructed with an environment named after the projection. Also
+        constructs a projection with the constructed materialised view object.
+        Starts a subscription to application and, in a separate event-processing
+        thread, calls projection's process_event() method for each event and tracking
+        object pair received from the subscription.
+        """
+        self._is_stopping = Event()
+
         self.app: TApplication = application_class(env)
 
         self.view = (
@@ -128,7 +165,6 @@
             gt=self.view.max_tracking_id(self.app.name),
             topics=self.projection.topics,
         )
-        self._is_stopping = Event()
         self.thread_error: BaseException | None = None
         self.processing_thread = Thread(
             target=self._process_events_loop,
@@ -152,6 +188,9 @@
         return Environment(name, _env)
 
     def stop(self) -> None:
+        """
+        Stops the application subscription, which will stop the event-processing thread.
+        """
        self._is_stopping.set()
        self.subscription.stop()
 
@@ -180,13 +219,21 @@
             )
 
             is_stopping.set()
-            subscription.subscription.stop()
+            subscription.stop()
 
     def run_forever(self, timeout: float | None = None) -> None:
+        """
+        Blocks until timeout, or until the runner is stopped or errors. Re-raises
+        any error otherwise exits normally
+        """
         if self._is_stopping.wait(timeout=timeout) and self.thread_error is not None:
             raise self.thread_error
 
     def wait(self, notification_id: int | None, timeout: float = 1.0) -> None:
+        """
+        Blocks until timeout, or until the materialised view has recorded a tracking
+        object that is greater than or equal to the given notification ID.
+        """
         try:
             self.projection.view.wait(
                 application_name=self.subscription.name,
@@ -202,8 +249,14 @@
         return self
 
     def __exit__(self, *args: object, **kwargs: Any) -> None:
+        """
+        Calls stop() and waits for the event-processing thread to exit.
+        """
         self.stop()
         self.processing_thread.join()
 
     def __del__(self) -> None:
+        """
+        Calls stop().
+        """
         self.stop()
eventsourcing/system.py CHANGED
@@ -7,12 +7,11 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from queue import Full, Queue
 from types import FrameType, ModuleType
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Sequence
     from typing_extensions import Self
-    from eventsourcing.dispatch import singledispatchmethod
 
 from eventsourcing.application import (
     Application,
@@ -22,6 +21,7 @@ from eventsourcing.application import (
     Section,
     TApplication,
 )
+from eventsourcing.dispatch import singledispatchmethod
 from eventsourcing.domain import DomainEventProtocol, MutableOrImmutableAggregate
 from eventsourcing.persistence import (
     IntegrityError,
@@ -198,11 +198,8 @@ class Follower(Application):
         self.notify(processing_event.events)
         self._notify(recordings)
 
-    policy: (
-        Callable[[DomainEventProtocol, ProcessingEvent], None] | singledispatchmethod
-    )
-
-    def policy(  # type: ignore[no-redef]
+    @singledispatchmethod
+    def policy(
         self,
         domain_event: DomainEventProtocol,
         processing_event: ProcessingEvent,
@@ -69,11 +69,11 @@ class BankAccount(Aggregate):
 
         amount: Decimal
 
-        def apply(self, account: Aggregate) -> None:
+        def apply(self, aggregate: Aggregate) -> None:
             """
             Increments the account balance.
             """
-            cast(BankAccount, account).balance += self.amount
+            cast(BankAccount, aggregate).balance += self.amount
 
     def set_overdraft_limit(self, overdraft_limit: Decimal) -> None:
         """
@@ -95,8 +95,8 @@
 
         overdraft_limit: Decimal
 
-        def apply(self, account: Aggregate) -> None:
-            cast(BankAccount, account).overdraft_limit = self.overdraft_limit
+        def apply(self, aggregate: Aggregate) -> None:
+            cast(BankAccount, aggregate).overdraft_limit = self.overdraft_limit
 
     def close(self) -> None:
         """
@@ -109,8 +109,8 @@
         Domain event for when account is closed.
         """
 
-        def apply(self, account: Aggregate) -> None:
-            cast(BankAccount, account).is_closed = True
+        def apply(self, aggregate: Aggregate) -> None:
+            cast(BankAccount, aggregate).is_closed = True
 
 
 class AccountClosedError(Exception):
@@ -793,8 +793,16 @@ class TrackingRecorderTestCase(TestCase, ABC):
 
     def test_wait(self) -> None:
         tracking_recorder = self.create_recorder()
+
+        tracking_recorder.wait("upstream1", None)
+
+        with self.assertRaises(TimeoutError):
+            tracking_recorder.wait("upstream1", 21, timeout=0.1)
+
         tracking1 = Tracking(notification_id=21, application_name="upstream1")
         tracking_recorder.insert_tracking(tracking=tracking1)
+        tracking_recorder.wait("upstream1", None)
+        tracking_recorder.wait("upstream1", 10)
         tracking_recorder.wait("upstream1", 21)
         with self.assertRaises(TimeoutError):
             tracking_recorder.wait("upstream1", 22, timeout=0.1)
@@ -1,4 +1,5 @@
 import psycopg
+from psycopg.sql import SQL, Identifier
 
 from eventsourcing.persistence import PersistenceError
 from eventsourcing.postgres import PostgresDatastore
@@ -43,7 +44,10 @@ def pg_close_all_connections(
 
 
 def drop_postgres_table(datastore: PostgresDatastore, table_name: str) -> None:
-    statement = f"DROP TABLE {table_name}"
+    statement = SQL("DROP TABLE {0}.{1}").format(
+        Identifier(datastore.schema), Identifier(table_name)
+    )
+    # print(f"Dropping table {datastore.schema}.{table_name}")
     try:
         with datastore.transaction(commit=True) as curs:
             curs.execute(statement, prepare=False)
eventsourcing/utils.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import importlib
 import sys
-from collections.abc import Iterator, Mapping, Sequence
+from collections.abc import Iterator, Mapping
 from functools import wraps
 from inspect import isfunction
 from random import random
@@ -129,7 +129,7 @@
 
 
 def retry(
-    exc: type[Exception] | Sequence[type[Exception]] = Exception,
+    exc: type[Exception] | tuple[type[Exception], ...] = Exception,
     max_attempts: int = 1,
     wait: float = 0,
     stall: float = 0,
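Note: the annotation narrows from `Sequence` to `tuple` because that is what `except` actually accepts at runtime: an exception class or a tuple of classes, not an arbitrary sequence. For example:

```python
errors: tuple[type[Exception], ...] = (ValueError, KeyError)

try:
    raise KeyError("x")
except errors:  # a tuple works here
    pass

# A list in the same position would raise:
# TypeError: catching classes that do not inherit from BaseException is not allowed
```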
@@ -235,18 +235,23 @@ class Environment(dict[str, str]):
         super().__init__(env or {})
         self.name = name
 
+    @overload  # type: ignore[override]
+    def get(self, __key: str) -> str | None: ...  # pragma: no cover
+
     @overload
-    def get(self, key: str) -> str | None: ...  # pragma: no cover
+    def get(self, __key: str, __default: str) -> str: ...  # pragma: no cover
 
     @overload
-    def get(self, key: str, default: str | T) -> str | T: ...  # pragma: no cover
+    def get(self, __key: str, __default: T) -> str | T: ...  # pragma: no cover
 
-    def get(self, key: str, default: str | T | None = None) -> str | T | None:
-        for _key in self.create_keys(key):
+    def get(  # pyright: ignore [reportIncompatibleMethodOverride]
+        self, __key: str, __default: str | T | None = None
+    ) -> str | T | None:
+        for _key in self.create_keys(__key):
             value = super().get(_key, None)
             if value is not None:
                 return value
-        return default
+        return __default
 
     def create_keys(self, key: str) -> list[str]:
         keys = []
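Note: the double-underscore parameter names follow the old typeshed convention for marking parameters positional-only (pre-PEP 570), which is how `dict.get` itself is declared; this is what makes the override signature-compatible. A runtime sketch of the convention:

```python
class Env(dict):
    # __key and __default are treated as positional-only by type checkers;
    # at runtime, name mangling also prevents passing them as keywords.
    def get(self, __key, __default=None):
        return super().get(__key, __default)

e = Env(A="1")
assert e.get("A") == "1"
assert e.get("B", "fallback") == "fallback"
```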
@@ -1,13 +1,13 @@
 Metadata-Version: 2.3
 Name: eventsourcing
-Version: 9.4.0a8
+Version: 9.4.0b1
 Summary: Event sourcing in Python
 License: BSD 3-Clause
 Keywords: event sourcing,event store,domain driven design,domain-driven design,ddd,cqrs,cqs
 Author: John Bywater
 Author-email: john.bywater@appropriatesoftware.net
 Requires-Python: >=3.9, !=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*, !=3.8.*
-Classifier: Development Status :: 3 - Alpha
+Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Education
 Classifier: Intended Audience :: Science/Research
@@ -0,0 +1,26 @@
+eventsourcing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+eventsourcing/application.py,sha256=lVgKXCeGA36CsUW7qgkkABX0mCUBvUH-QGQFtOYwmUw,35796
+eventsourcing/cipher.py,sha256=R6bvq7Zcsd7NvV1o8WghdBcnqu0IDkOxlfxUSk1txeQ,3320
+eventsourcing/compressor.py,sha256=IdvrJUB9B2td871oifInv4lGXmHwYL9d69MbHHCr7uI,421
+eventsourcing/cryptography.py,sha256=ZsQFyeyMZysADqKy38ECV71j6EMMSbo3VQO7oRnC1h0,2994
+eventsourcing/dispatch.py,sha256=3eVnGCagnn_CENSnTKonMt2kZ1eoHm8abdK7YRVONbU,2736
+eventsourcing/domain.py,sha256=WwDwo-IxYrC3fEXu_5E2x-Vk4s1Ye9IaZSQLcUGhQqw,59889
+eventsourcing/interface.py,sha256=uCoV9ARAu229SGwp169yeSLbB8wDLKDwWcnQdvOXOQM,5141
+eventsourcing/persistence.py,sha256=hCS1vCtS5X_LUtycuDpF6d6Os7zjQ3qTnGf7W39DWV4,46139
+eventsourcing/popo.py,sha256=xBUnnPuQ_wdF0ErU9AApRPwlkB0CJePMbWCz6qn3U1M,9654
+eventsourcing/postgres.py,sha256=hHYpzvZc7nANXE1uQZ3gA-XKk5X9rB6nJ_zouOkpidk,37537
+eventsourcing/projection.py,sha256=niCF9L5LsdeA0kgbKC3sE7p65rOF5du-YpdibkAhNU0,8588
+eventsourcing/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+eventsourcing/sqlite.py,sha256=UVgG0JCmYr7xN8HHnoiMPpZtZ3d5NTfB-iq0MlZGax4,22051
+eventsourcing/system.py,sha256=tyqGaGUE6CUGPWUZf27_26-Zl3PLsm3vK017Wxn7IM0,47279
+eventsourcing/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+eventsourcing/tests/application.py,sha256=DEncXtCsX0X6Udua9GbGpua0xTGshnDPIThFfzuiBog,18025
+eventsourcing/tests/domain.py,sha256=LkvGFYFc6jJR0PYjO2hAGPw0TNW8vUzdTuR8K5xDEQ0,3385
+eventsourcing/tests/persistence.py,sha256=DePUevT4uYNTcIzVZoO6uzyyrUI2sOK6CSioiJRcD6I,58518
+eventsourcing/tests/postgres_utils.py,sha256=0ywklGp6cXZ5PmV8ANVkwSHsZZCl5zTmOk7iG-RmrCE,1548
+eventsourcing/utils.py,sha256=1mG24CXb4oRaumB6NMaH3QqxtHEiWTiJEWzFDRBf6nc,8537
+eventsourcing-9.4.0b1.dist-info/AUTHORS,sha256=8aHOM4UbNZcKlD-cHpFRcM6RWyCqtwtxRev6DeUgVRs,137
+eventsourcing-9.4.0b1.dist-info/LICENSE,sha256=CQEQzcZO8AWXL5i3hIo4yVKrYjh2FBz6hCM7kpXWpw4,1512
+eventsourcing-9.4.0b1.dist-info/METADATA,sha256=Q4NfLZMDdQ9EKyxbXmk3_exmsHcD2PfOcSh-ldO3R-8,9796
+eventsourcing-9.4.0b1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+eventsourcing-9.4.0b1.dist-info/RECORD,,
@@ -1,26 +0,0 @@
-eventsourcing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-eventsourcing/application.py,sha256=YF1CAUZrqb51vzT95mk6mC0tWCHdR_N6-rn1s6SKL8g,35909
-eventsourcing/cipher.py,sha256=AjgOlOv9FF6xyXiFnHwcK6NX5IJ3nPHFL5GzIyWozyg,3265
-eventsourcing/compressor.py,sha256=IdvrJUB9B2td871oifInv4lGXmHwYL9d69MbHHCr7uI,421
-eventsourcing/cryptography.py,sha256=ZsQFyeyMZysADqKy38ECV71j6EMMSbo3VQO7oRnC1h0,2994
-eventsourcing/dispatch.py,sha256=yYSpT-jqc6l_wTdqEnfPJJfvsZN2Ta8g2anrVPWIcqQ,1412
-eventsourcing/domain.py,sha256=FUvCklB-8BGEzoalk3IcuEeyDZrRzPhzgI96L4I9OxM,58980
-eventsourcing/interface.py,sha256=kObA7ouzLD4YpJMjhfPVmRUcDzhbK0bbFKXy75EscHU,5138
-eventsourcing/persistence.py,sha256=zsd_RL_H_gc64Ka45Vapd92eVZXrwfanrGNE36zp9tg,45877
-eventsourcing/popo.py,sha256=xBUnnPuQ_wdF0ErU9AApRPwlkB0CJePMbWCz6qn3U1M,9654
-eventsourcing/postgres.py,sha256=mMdorzwSbDDNjK5KQDNVLy48JrI_18KoRUkDtOI8_wM,36506
-eventsourcing/projection.py,sha256=pmSywj7DkUFIoOMuo8RJv7snB5jNdODSupMgQikEZto,6677
-eventsourcing/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-eventsourcing/sqlite.py,sha256=UVgG0JCmYr7xN8HHnoiMPpZtZ3d5NTfB-iq0MlZGax4,22051
-eventsourcing/system.py,sha256=UgIwaU35drvaq9Nvs16zsZ4UMPjZ_PlFclvEhIWIVFA,47400
-eventsourcing/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-eventsourcing/tests/application.py,sha256=DEncXtCsX0X6Udua9GbGpua0xTGshnDPIThFfzuiBog,18025
-eventsourcing/tests/domain.py,sha256=1FTpfG5w-UMbqh_R-6LKDZY0JKb_nZtOIEGrdYDEUyY,3373
-eventsourcing/tests/persistence.py,sha256=mb29R8-CJCNroSmrfYYxsN354zmJKQ--ROrjPDhjKTc,58256
-eventsourcing/tests/postgres_utils.py,sha256=jS1Ac5Yj4IPx-bsL2IRlytcd5oa0l6SiFcprHxqWdVQ,1371
-eventsourcing/utils.py,sha256=0DlFnDmvGwCSWiEQ_h5GLgyjAgHJbsVtdf-GLQvoH7I,8350
-eventsourcing-9.4.0a8.dist-info/AUTHORS,sha256=8aHOM4UbNZcKlD-cHpFRcM6RWyCqtwtxRev6DeUgVRs,137
-eventsourcing-9.4.0a8.dist-info/LICENSE,sha256=CQEQzcZO8AWXL5i3hIo4yVKrYjh2FBz6hCM7kpXWpw4,1512
-eventsourcing-9.4.0a8.dist-info/METADATA,sha256=22EfMdqFS1lzyD7ZDa4t3tfgl1b3mgeBwhXMYkzTCJg,9797
-eventsourcing-9.4.0a8.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-eventsourcing-9.4.0a8.dist-info/RECORD,,