eventsourcing-9.4.5-py3-none-any.whl → eventsourcing-9.5.0a0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eventsourcing might be problematic; see the package's registry page for more details.

eventsourcing/domain.py CHANGED
@@ -17,6 +17,7 @@ from typing import (
17
17
  Callable,
18
18
  ClassVar,
19
19
  Generic,
20
+ Optional,
20
21
  Protocol,
21
22
  TypeVar,
22
23
  Union,
@@ -31,6 +32,7 @@ from warnings import warn
31
32
 
32
33
  from eventsourcing.utils import (
33
34
  TopicError,
35
+ construct_topic,
34
36
  get_method_name,
35
37
  get_topic,
36
38
  register_topic,
@@ -237,35 +239,41 @@ class CanCreateTimestamp:
237
239
  TAggregate = TypeVar("TAggregate", bound="BaseAggregate[Any]")
238
240
 
239
241
 
240
- class HasOriginatorIDVersion(Generic[TAggregateID_co]):
242
+ class HasOriginatorIDVersion(Generic[TAggregateID]):
241
243
  """Declares ``originator_id`` and ``originator_version`` attributes."""
242
244
 
243
- originator_id: TAggregateID_co
245
+ originator_id: TAggregateID
244
246
  """UUID identifying an aggregate to which the event belongs."""
245
247
  originator_version: int
246
248
  """Integer identifying the version of the aggregate when the event occurred."""
247
249
 
248
- type_originator_id: ClassVar[type[Union[UUID, str]]] # noqa: UP007
250
+ originator_id_type: ClassVar[Optional[type[Union[UUID, str]]]] = None # noqa: UP007
249
251
 
250
252
  def __init_subclass__(cls) -> None:
251
253
  cls.find_originator_id_type(HasOriginatorIDVersion)
254
+ super().__init_subclass__()
252
255
 
253
256
  @classmethod
254
257
  def find_originator_id_type(cls: type, generic_cls: type) -> None:
255
- """Store the type argument of TAggregateID_co on the subclass."""
256
- for orig_base in cls.__orig_bases__: # type: ignore[attr-defined]
257
- type_originator_id = orig_base.__dict__.get("type_originator_id", "")
258
- if type_originator_id in (UUID, str):
259
- cls.type_originator_id = type_originator_id # type: ignore[attr-defined]
260
- break
261
- if get_origin(orig_base) is generic_cls:
262
- type_originator_id = get_args(orig_base)[0]
263
- if type_originator_id in (UUID, str):
264
- cls.type_originator_id = type_originator_id # type: ignore[attr-defined]
265
- break
258
+ """Store the type argument of TAggregateID on the subclass."""
259
+ if "originator_id_type" not in cls.__dict__:
260
+ for orig_base in cls.__orig_bases__: # type: ignore[attr-defined]
261
+ if "originator_id_type" in orig_base.__dict__:
262
+ cls.originator_id_type = orig_base.__dict__["originator_id_type"] # type: ignore[attr-defined]
263
+ elif get_origin(orig_base) is generic_cls:
264
+ originator_id_type = get_args(orig_base)[0]
265
+ if originator_id_type in (UUID, str):
266
+ cls.originator_id_type = originator_id_type # type: ignore[attr-defined]
267
+ break
268
+ if originator_id_type is Any:
269
+ continue
270
+ if isinstance(originator_id_type, TypeVar):
271
+ continue
272
+ msg = f"Aggregate ID type arg cannot be {originator_id_type}"
273
+ raise TypeError(msg)
266
274
 
267
275
 
268
- class CanMutateAggregate(HasOriginatorIDVersion[TAggregateID_co], CanCreateTimestamp):
276
+ class CanMutateAggregate(HasOriginatorIDVersion[TAggregateID], CanCreateTimestamp):
269
277
  """Implements a :py:func:`~eventsourcing.domain.CanMutateAggregate.mutate`
270
278
  method that evolves the state of an aggregate.
271
279
  """
@@ -276,6 +284,7 @@ class CanMutateAggregate(HasOriginatorIDVersion[TAggregateID_co], CanCreateTimes
276
284
 
277
285
  def __init_subclass__(cls) -> None:
278
286
  cls.find_originator_id_type(CanMutateAggregate)
287
+ super().__init_subclass__()
279
288
 
280
289
  def mutate(self, aggregate: TAggregate | None) -> TAggregate | None:
281
290
  """Validates and adjusts the attributes of the given ``aggregate``
@@ -333,7 +342,7 @@ class CanMutateAggregate(HasOriginatorIDVersion[TAggregateID_co], CanCreateTimes
333
342
  return self.__dict__
334
343
 
335
344
 
336
- class CanInitAggregate(CanMutateAggregate[TAggregateID_co]):
345
+ class CanInitAggregate(CanMutateAggregate[TAggregateID]):
337
346
  """Implements a :func:`~eventsourcing.domain.CanMutateAggregate.mutate`
338
347
  method that constructs the initial state of an aggregate.
339
348
  """
@@ -343,6 +352,7 @@ class CanInitAggregate(CanMutateAggregate[TAggregateID_co]):
343
352
 
344
353
  def __init_subclass__(cls) -> None:
345
354
  cls.find_originator_id_type(CanInitAggregate)
355
+ super().__init_subclass__()
346
356
 
347
357
  def mutate(self, aggregate: TAggregate | None) -> TAggregate | None:
348
358
  """Constructs an aggregate instance according to the attributes of an event.
@@ -360,7 +370,7 @@ class CanInitAggregate(CanMutateAggregate[TAggregateID_co]):
360
370
 
361
371
  # Pick out event attributes for the aggregate base class init method.
362
372
  self_dict = self._as_dict()
363
- base_kwargs = _filter_kwargs_for_method_params(
373
+ base_kwargs = filter_kwargs_for_method_params(
364
374
  self_dict, type(agg).__base_init__
365
375
  )
366
376
 
@@ -369,7 +379,7 @@ class CanInitAggregate(CanMutateAggregate[TAggregateID_co]):
369
379
  agg.__base_init__(**base_kwargs)
370
380
 
371
381
  # Pick out event attributes for aggregate subclass class init method.
372
- init_kwargs = _filter_kwargs_for_method_params(self_dict, type(agg).__init__)
382
+ init_kwargs = filter_kwargs_for_method_params(self_dict, type(agg).__init__)
373
383
 
374
384
  # Provide the aggregate id, if the __init__ method expects it.
375
385
  if aggregate_class in _init_mentions_id:
@@ -450,7 +460,7 @@ class LogEvent(DomainEvent):
450
460
  """
451
461
 
452
462
 
453
- def _filter_kwargs_for_method_params(
463
+ def filter_kwargs_for_method_params(
454
464
  kwargs: dict[str, Any], method: Callable[..., Any]
455
465
  ) -> dict[str, Any]:
456
466
  names = _spec_filter_kwargs_for_method_params(method)
@@ -463,8 +473,14 @@ def _spec_filter_kwargs_for_method_params(method: Callable[..., Any]) -> set[str
463
473
  return set(method_signature.parameters)
464
474
 
465
475
 
476
+ class AbstractDCBEvent:
477
+ pass
478
+
479
+
466
480
  if TYPE_CHECKING:
467
- EventSpecType = Union[str, type[CanMutateAggregate[Any]]]
481
+ EventSpecType = Union[ # noqa: PYI055
482
+ str, type[CanMutateAggregate[Any]], type[AbstractDCBEvent]
483
+ ]
468
484
 
469
485
  CallableType = Callable[..., None]
470
486
  DecoratableType = Union[CallableType, property]
@@ -479,7 +495,9 @@ class CommandMethodDecorator:
479
495
  event_topic: str | None = None,
480
496
  ):
481
497
  self.is_name_inferred_from_method = False
482
- self.given_event_cls: type[CanMutateAggregate[Any]] | None = None
498
+ self.given_event_cls: (
499
+ type[CanMutateAggregate[Any] | AbstractDCBEvent] | None
500
+ ) = None
483
501
  self.event_cls_name: str | None = None
484
502
  self.decorated_property: property | None = None
485
503
  self.is_property_setter = False
@@ -496,9 +514,13 @@ class CommandMethodDecorator:
496
514
 
497
515
  # Event class has been specified.
498
516
  elif isinstance(event_spec, type) and issubclass(
499
- event_spec, CanMutateAggregate
517
+ event_spec, (CanMutateAggregate, AbstractDCBEvent)
500
518
  ):
501
- if event_spec in _given_event_classes:
519
+ # Guard against associating more than one method body with any given class.
520
+ if (
521
+ issubclass(event_spec, CanMutateAggregate)
522
+ and event_spec in _given_event_classes
523
+ ):
502
524
  name = event_spec.__name__
503
525
  msg = f"{name} event class used in more than one decorator"
504
526
  raise TypeError(msg)
@@ -542,6 +564,10 @@ class CommandMethodDecorator:
542
564
  # Remember the decorated obj as the decorated method.
543
565
  self.decorated_func = decorated_obj
544
566
 
567
+ if self.decorated_func.__name__ == "_":
568
+ underscore_method_decorators.append(
569
+ (construct_topic(self.decorated_func), self)
570
+ )
545
571
  # If necessary, derive an event class name from the method.
546
572
  if not self.given_event_cls and not self.event_cls_name:
547
573
  original_method_name = self.decorated_func.__name__
@@ -599,6 +625,10 @@ class CommandMethodDecorator:
599
625
  self, instance: BaseAggregate[Any] | None, owner: type[BaseAggregate[Any]]
600
626
  ) -> BoundCommandMethodDecorator | UnboundCommandMethodDecorator | property | Any:
601
627
  """Descriptor protocol for getting decorated method or property."""
628
+ if self.decorated_func.__name__ == "_":
629
+ msg = "Underscore 'non-command' methods cannot be used to trigger events."
630
+ raise ProgrammingError(msg)
631
+
602
632
  # If we are decorating a property, then delegate to the property's __get__.
603
633
  if self.decorated_property:
604
634
  return self.decorated_property.__get__(instance, owner)
@@ -611,6 +641,12 @@ class CommandMethodDecorator:
611
641
  if instance:
612
642
  return BoundCommandMethodDecorator(self, instance)
613
643
 
644
+ if "SPHINX_BUILD" in os.environ: # pragma: no cover
645
+ # Sphinx hack: use the original function when sphinx is running so that the
646
+ # documentation ends up with the correct function signatures.
647
+ # See 'SPHINX_BUILD' in conf.py.
648
+ return self.decorated_func
649
+
614
650
  # Return an "unbound" command method decorator if we have no instance.
615
651
  return UnboundCommandMethodDecorator(self)
616
652
 
@@ -630,7 +666,7 @@ def event(arg: TDecoratableType, /) -> TDecoratableType:
630
666
 
631
667
  @overload
632
668
  def event(
633
- arg: type[CanMutateAggregate[Any]], /
669
+ arg: type[CanMutateAggregate[Any] | AbstractDCBEvent], /
634
670
  ) -> Callable[[TDecoratableType], TDecoratableType]:
635
671
  """Signature for calling ``@event`` decorator with event class."""
636
672
 
@@ -711,7 +747,10 @@ def event(
711
747
  if (
712
748
  arg is None
713
749
  or isinstance(arg, str)
714
- or (isinstance(arg, type) and issubclass(arg, CanMutateAggregate))
750
+ or (
751
+ isinstance(arg, type)
752
+ and issubclass(arg, (CanMutateAggregate, AbstractDCBEvent))
753
+ )
715
754
  ):
716
755
  event_spec = arg
717
756
 
@@ -764,14 +803,22 @@ class UnboundCommandMethodDecorator:
764
803
  )
765
804
 
766
805
 
806
+ class CanTriggerEvent(Protocol):
807
+ def trigger_event(
808
+ self,
809
+ event_class: type[Any],
810
+ **kwargs: Any,
811
+ ) -> None:
812
+ pass # pragma: no cover
813
+
814
+
767
815
  class BoundCommandMethodDecorator:
768
- """Binds a CommandMethodDecorator with an aggregate instance so calls to
769
- decorated command methods can be intercepted and will trigger an event.
816
+ """Binds a CommandMethodDecorator with an object instance that can trigger
817
+ events, so that calls to decorated command methods can be intercepted and
818
+ will trigger a "decorated func caller" event.
770
819
  """
771
820
 
772
- def __init__(
773
- self, event_decorator: CommandMethodDecorator, aggregate: BaseAggregate[Any]
774
- ):
821
+ def __init__(self, event_decorator: CommandMethodDecorator, obj: CanTriggerEvent):
775
822
  """:param CommandMethodDecorator event_decorator:
776
823
  :param Aggregate aggregate:
777
824
  """
@@ -781,29 +828,41 @@ class BoundCommandMethodDecorator:
781
828
  self.__qualname__ = event_decorator.decorated_func.__qualname__
782
829
  self.__annotations__ = event_decorator.decorated_func.__annotations__
783
830
  self.__doc__ = event_decorator.decorated_func.__doc__
784
- self.aggregate = aggregate
831
+ self.obj = obj
785
832
 
786
833
  def trigger(self, *args: Any, **kwargs: Any) -> None:
787
834
  kwargs = _coerce_args_to_kwargs(
788
835
  self.event_decorator.decorated_func, args, kwargs
789
836
  )
790
- event_cls = decorator_event_classes[self.event_decorator]
791
- kwargs = _filter_kwargs_for_method_params(kwargs, event_cls)
792
- self.aggregate.trigger_event(event_cls, **kwargs)
837
+ try:
838
+ event_cls = decorated_func_callers[self.event_decorator]
839
+ except KeyError as e: # pragma: no cover
840
+ msg = (
841
+ f"Event class not registered for event decorator on "
842
+ f"{self.event_decorator.decorated_func.__qualname__}"
843
+ )
844
+ raise KeyError(msg) from e
845
+ kwargs = filter_kwargs_for_method_params(kwargs, event_cls)
846
+ assert issubclass(event_cls, AbstractDecoratedFuncCaller), event_cls
847
+ self.obj.trigger_event(event_cls, **kwargs)
793
848
 
794
849
  def __call__(self, *args: Any, **kwargs: Any) -> None:
795
850
  self.trigger(*args, **kwargs)
796
851
 
797
852
 
798
- class DecoratorEvent(CanMutateAggregate[Any]):
853
+ class AbstractDecoratedFuncCaller:
854
+ pass
855
+
856
+
857
+ class DecoratedFuncCaller(CanMutateAggregate[Any], AbstractDecoratedFuncCaller):
799
858
  def apply(self, aggregate: BaseAggregate[Any]) -> None:
800
859
  """Applies event to aggregate by calling method decorated by @event."""
801
860
  # Identify the function that was decorated.
802
- decorated_func = _decorated_funcs[type(self)]
861
+ decorated_func = decorated_funcs[type(self)]
803
862
 
804
863
  # Select event attributes mentioned in function signature.
805
864
  self_dict = self._as_dict()
806
- kwargs = _filter_kwargs_for_method_params(self_dict, decorated_func)
865
+ kwargs = filter_kwargs_for_method_params(self_dict, decorated_func)
807
866
 
808
867
  # Call the original method with event attribute values.
809
868
  decorated_method = decorated_func.__get__(aggregate, type(aggregate))
@@ -813,12 +872,22 @@ class DecoratorEvent(CanMutateAggregate[Any]):
813
872
  super().apply(aggregate)
814
873
 
815
874
 
816
- _given_event_classes: set[type] = set()
817
- _decorated_funcs: dict[type, CallableType] = {}
875
+ # This helps enforce single usage of original event classes in decorators.
876
+ _given_event_classes = set[type]()
877
+
878
+ # This keeps track of the "created" event classes for an aggregate.
818
879
  _created_event_classes: dict[type, list[type[CanInitAggregate[Any]]]] = {}
819
880
 
881
+ # This remembers which event class to trigger when a decorated method is called.
882
+ decorated_func_callers: dict[
883
+ CommandMethodDecorator, type[AbstractDecoratedFuncCaller]
884
+ ] = {}
885
+
886
+ # This remembers which decorated func a decorated func caller should call.
887
+ decorated_funcs: dict[type, CallableType] = {}
820
888
 
821
- decorator_event_classes: dict[CommandMethodDecorator, type[DecoratorEvent]] = {}
889
+ # This keeps track of decorated "non-command" projection-only methods called "_".
890
+ underscore_method_decorators: list[tuple[str, CommandMethodDecorator]] = []
822
891
 
823
892
 
824
893
  def _raise_type_error_if_func_has_variable_params(method: CallableType) -> None:
@@ -1075,7 +1144,7 @@ class BaseAggregate(Generic[TAggregateID], metaclass=MetaAggregate):
1075
1144
  cls: type[Self],
1076
1145
  event_class: type[CanInitAggregate[TAggregateID]],
1077
1146
  *,
1078
- id: UUID | str | None = None, # noqa: A002
1147
+ id: TAggregateID | None = None, # noqa: A002
1079
1148
  **kwargs: Any,
1080
1149
  ) -> Self:
1081
1150
  """Constructs a new aggregate object instance."""
@@ -1253,7 +1322,10 @@ class BaseAggregate(Generic[TAggregateID], metaclass=MetaAggregate):
1253
1322
  cls.__name__ in _module.__dict__
1254
1323
  and ENVVAR_DISABLE_REDEFINITION_CHECK not in os.environ
1255
1324
  ):
1256
- msg = f"Name '{cls.__name__}' already defined in '{cls.__module__}' module"
1325
+ msg = (
1326
+ f"Name '{cls.__name__}' of {cls} already defined in "
1327
+ f"'{cls.__module__}' module: {_module.__dict__[cls.__name__]}"
1328
+ )
1257
1329
  raise ProgrammingError(msg)
1258
1330
 
1259
1331
  # Get the class annotations.
@@ -1343,24 +1415,41 @@ class BaseAggregate(Generic[TAggregateID], metaclass=MetaAggregate):
1343
1415
  if name.lower() == name:
1344
1416
  continue
1345
1417
 
1346
- # Only consider "event" classes (implement "CanMutateAggregate" protocol).
1418
+ # Don't subclass if not "CanMutateAggregate".
1347
1419
  if not isinstance(value, type) or not issubclass(value, CanMutateAggregate):
1348
1420
  continue
1349
1421
 
1422
+ # # Don't subclass generic classes (we don't have a type argument).
1423
+ # # TODO: Maybe also prohibit triggering such things?
1424
+ # if value.__dict__.get("__parameters__", ()):
1425
+ # continue
1426
+
1350
1427
  # Check we have a base event class.
1351
1428
  if base_event_cls is None:
1352
1429
  raise base_event_class_not_defined_error
1353
1430
 
1354
1431
  # Redefine events that aren't already subclass of the base event class.
1355
1432
  if not issubclass(value, base_event_cls):
1433
+ # Identify base classes that were redefined, to preserve hierarchy.
1434
+ redefined_bases = []
1435
+ for base in value.__bases__:
1436
+ if base in redefined_event_classes:
1437
+ redefined_bases.append(redefined_event_classes[base])
1438
+ elif "__pydantic_generic_metadata__" in base.__dict__:
1439
+ pydantic_metadata = base.__dict__[
1440
+ "__pydantic_generic_metadata__"
1441
+ ]
1442
+ for i, key in enumerate(pydantic_metadata):
1443
+ if key == "origin":
1444
+ origin = base.__bases__[i]
1445
+ if origin in redefined_event_classes:
1446
+ redefined_bases.append(
1447
+ redefined_event_classes[origin]
1448
+ )
1449
+
1356
1450
  # Decide base classes of redefined event class: it must be
1357
1451
  # a subclass of the original class, all redefined classes that
1358
1452
  # were in its bases, and the aggregate's base event class.
1359
- redefined_bases = [
1360
- redefined_event_classes[b]
1361
- for b in value.__bases__
1362
- if b in redefined_event_classes
1363
- ]
1364
1453
  event_class_bases = (
1365
1454
  value,
1366
1455
  *redefined_bases,
@@ -1584,7 +1673,7 @@ class BaseAggregate(Generic[TAggregateID], metaclass=MetaAggregate):
1584
1673
  # the subclassing of events above? Maybe do this first?
1585
1674
  event_cls = cls._define_event_class(
1586
1675
  event_decorator.given_event_cls.__name__,
1587
- (DecoratorEvent, given_subclass),
1676
+ (DecoratedFuncCaller, given_subclass),
1588
1677
  None,
1589
1678
  )
1590
1679
 
@@ -1605,20 +1694,20 @@ class BaseAggregate(Generic[TAggregateID], metaclass=MetaAggregate):
1605
1694
  # Define event class from signature of original method.
1606
1695
  event_cls = cls._define_event_class(
1607
1696
  event_decorator.event_cls_name,
1608
- (DecoratorEvent, base_event_cls),
1697
+ (DecoratedFuncCaller, base_event_cls),
1609
1698
  event_decorator.decorated_func,
1610
1699
  event_topic=event_decorator.event_topic,
1611
1700
  )
1612
1701
 
1613
1702
  # Cache the decorated method for the event class to use.
1614
- _decorated_funcs[event_cls] = event_decorator.decorated_func
1703
+ decorated_funcs[event_cls] = event_decorator.decorated_func
1615
1704
 
1616
1705
  # Set the event class as an attribute of the aggregate class.
1617
1706
  setattr(cls, event_cls.__name__, event_cls)
1618
1707
 
1619
1708
  # Remember which event class to trigger.
1620
- decorator_event_classes[event_decorator] = cast(
1621
- "type[DecoratorEvent]", event_cls
1709
+ decorated_func_callers[event_decorator] = cast(
1710
+ type[DecoratedFuncCaller], event_cls
1622
1711
  )
1623
1712
 
1624
1713
  # Check any create_id() method defined on this class is static or class method.
@@ -1735,17 +1824,13 @@ class SnapshotProtocol(DomainEventProtocol[TAggregateID_co], Protocol):
1735
1824
  """Snapshots have a 'take()' class method."""
1736
1825
 
1737
1826
 
1738
- TCanSnapshotAggregate = TypeVar(
1739
- "TCanSnapshotAggregate", bound="CanSnapshotAggregate[Any]"
1740
- )
1741
-
1742
-
1743
- class CanSnapshotAggregate(HasOriginatorIDVersion[TAggregateID_co], CanCreateTimestamp):
1827
+ class CanSnapshotAggregate(HasOriginatorIDVersion[TAggregateID], CanCreateTimestamp):
1744
1828
  topic: str
1745
1829
  state: Any
1746
1830
 
1747
1831
  def __init_subclass__(cls) -> None:
1748
1832
  cls.find_originator_id_type(CanSnapshotAggregate)
1833
+ super().__init_subclass__()
1749
1834
 
1750
1835
  # def __init__(
1751
1836
  # self,
@@ -1760,7 +1845,7 @@ class CanSnapshotAggregate(HasOriginatorIDVersion[TAggregateID_co], CanCreateTim
1760
1845
  @classmethod
1761
1846
  def take(
1762
1847
  cls,
1763
- aggregate: MutableOrImmutableAggregate[TAggregateID_co],
1848
+ aggregate: MutableOrImmutableAggregate[TAggregateID],
1764
1849
  ) -> Self:
1765
1850
  """Creates a snapshot of the given :class:`Aggregate` object."""
1766
1851
  aggregate_state = dict(aggregate.__dict__)
@@ -1779,9 +1864,9 @@ class CanSnapshotAggregate(HasOriginatorIDVersion[TAggregateID_co], CanCreateTim
1779
1864
  state=aggregate_state, # pyright: ignore[reportCallIssue]
1780
1865
  )
1781
1866
 
1782
- def mutate(self, _: None) -> BaseAggregate[TAggregateID_co]:
1867
+ def mutate(self, _: None) -> BaseAggregate[TAggregateID]:
1783
1868
  """Reconstructs the snapshotted :class:`Aggregate` object."""
1784
- cls = cast(type[BaseAggregate[TAggregateID_co]], resolve_topic(self.topic))
1869
+ cls = cast(type[BaseAggregate[TAggregateID]], resolve_topic(self.topic))
1785
1870
  aggregate_state = dict(self.state)
1786
1871
  from_version = aggregate_state.pop("class_version", 1)
1787
1872
  class_version = getattr(cls, "class_version", 1)
@@ -13,10 +13,10 @@ from queue import Queue
13
13
  from threading import Condition, Event, Lock, Semaphore, Thread, Timer
14
14
  from time import monotonic, sleep, time
15
15
  from types import GenericAlias, ModuleType
16
- from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
16
+ from typing import Any, Callable, Generic, Union, cast
17
17
  from uuid import UUID
18
18
 
19
- from typing_extensions import TypeVar
19
+ from typing_extensions import Self, TypeVar
20
20
 
21
21
  from eventsourcing.domain import (
22
22
  DomainEventProtocol,
@@ -33,9 +33,6 @@ from eventsourcing.utils import (
33
33
  strtobool,
34
34
  )
35
35
 
36
- if TYPE_CHECKING:
37
- from typing_extensions import Self
38
-
39
36
 
40
37
  class Transcoding(ABC):
41
38
  """Abstract base class for custom transcodings."""
@@ -291,7 +288,7 @@ class Mapper(Generic[TAggregateID]):
291
288
  )
292
289
  raise MapperDeserialisationError(msg) from e
293
290
 
294
- id_convertor = _find_id_convertor(
291
+ id_convertor = find_id_convertor(
295
292
  cls, cast(Hashable, type(stored_event.originator_id))
296
293
  )
297
294
  # print("ID of convertor:", id(convertor))
@@ -309,33 +306,53 @@ class Mapper(Generic[TAggregateID]):
309
306
 
310
307
 
311
308
  @lru_cache
312
- def _find_id_convertor(
309
+ def find_id_convertor(
313
310
  domain_event_cls: type[object], originator_id_cls: type[UUID | str]
314
311
  ) -> Callable[[UUID | str], UUID | str]:
315
312
  # Try to find the originator_id type.
316
- type_originator_id: type[UUID | str] = UUID
317
313
  if issubclass(domain_event_cls, HasOriginatorIDVersion):
318
- type_originator_id = domain_event_cls.type_originator_id
314
+ # For classes that inherit CanMutateAggregate, and don't use a different
315
+ # mapper, then assume they aren't overriding __init_subclass__ is a way
316
+ # that prevents 'originator_id_type' being found from type arguments and
317
+ # set on the class.
318
+ # TODO: Write a test where a custom class does override __init_subclass__
319
+ # so that the next line will cause an AssertionError. Then fix this code.
320
+ if domain_event_cls.originator_id_type is None:
321
+ msg = "originator_id_type cannot be None"
322
+ raise TypeError(msg)
323
+ originator_id_type = domain_event_cls.originator_id_type
319
324
  else:
320
- try:
321
- # Look on plain simple annotations.
322
- originator_id_annotation = typing.get_type_hints(
323
- domain_event_cls, globalns=globals()
324
- ).get("originator_id", None)
325
- assert originator_id_annotation in [UUID, str]
326
- type_originator_id = cast(type[Union[UUID, str]], originator_id_annotation)
327
- except NameError:
328
- pass
329
-
330
- if originator_id_cls is str and type_originator_id is UUID:
325
+ # Otherwise look for annotations.
326
+ for cls in domain_event_cls.__mro__:
327
+ try:
328
+ annotation = cls.__annotations__["originator_id"]
329
+ except (KeyError, AttributeError): # noqa: PERF203
330
+ continue
331
+ else:
332
+ valid_annotations = {
333
+ str: str,
334
+ UUID: UUID,
335
+ "str": str,
336
+ "UUID": UUID,
337
+ "uuid.UUID": UUID,
338
+ }
339
+ if annotation not in valid_annotations:
340
+ msg = f"originator_id annotation on {cls} is not either UUID or str"
341
+ raise TypeError(msg)
342
+ assert annotation in valid_annotations, annotation
343
+ originator_id_type = valid_annotations[annotation]
344
+ break
345
+ else:
346
+ msg = (
347
+ f"Neither event class {domain_event_cls}"
348
+ f"nor its bases have an originator_id annotation"
349
+ )
350
+ raise TypeError(msg)
351
+
352
+ if originator_id_cls is str and originator_id_type is UUID:
331
353
  convertor = str_to_uuid_convertor
332
354
  else:
333
355
  convertor = pass_through_convertor
334
- # print(
335
- # f"Decided {convertor.__name__} "
336
- # f"for {domain_event_cls.__name__} "
337
- # f"and {originator_id_cls.__name__}."
338
- # )
339
356
  return convertor
340
357
 
341
358
 
@@ -659,14 +676,14 @@ class InfrastructureFactory(ABC, Generic[TTrackingRecorder]):
659
676
 
660
677
  @classmethod
661
678
  def construct(
662
- cls: type[InfrastructureFactory[TTrackingRecorder]],
679
+ cls: type[Self],
663
680
  env: Environment | None = None,
664
- ) -> InfrastructureFactory[TTrackingRecorder]:
681
+ ) -> Self:
665
682
  """Constructs concrete infrastructure factory for given
666
683
  named application. Reads and resolves persistence
667
684
  topic from environment variable 'PERSISTENCE_MODULE'.
668
685
  """
669
- factory_cls: type[InfrastructureFactory[TTrackingRecorder]]
686
+ factory_cls: type[Self]
670
687
  if env is None:
671
688
  env = Environment()
672
689
  topic = (
@@ -685,9 +702,7 @@ class InfrastructureFactory(ABC, Generic[TTrackingRecorder]):
685
702
  or "eventsourcing.popo"
686
703
  )
687
704
  try:
688
- obj: type[InfrastructureFactory[TTrackingRecorder]] | ModuleType = (
689
- resolve_topic(topic)
690
- )
705
+ obj: type[Self] | ModuleType = resolve_topic(topic)
691
706
  except TopicError as e:
692
707
  msg = (
693
708
  "Failed to resolve persistence module topic: "
@@ -698,29 +713,29 @@ class InfrastructureFactory(ABC, Generic[TTrackingRecorder]):
698
713
 
699
714
  if isinstance(obj, ModuleType):
700
715
  # Find the factory in the module.
701
- factory_classes: list[type[InfrastructureFactory[TTrackingRecorder]]] = []
716
+ factory_classes = set[type[Self]]()
702
717
  for member in obj.__dict__.values():
703
- if (
704
- member is not InfrastructureFactory
705
- and isinstance(member, type) # Look for classes...
706
- and isinstance(member, type) # Look for classes...
707
- and not isinstance(
708
- member, GenericAlias
709
- ) # Issue with Python 3.9 and 3.10.
710
- and issubclass(member, InfrastructureFactory) # Ignore base class.
711
- and member not in factory_classes # Forgive aliases.
712
- ):
713
- factory_classes.append(member)
718
+ # Look for classes...
719
+ if not isinstance(member, type):
720
+ continue
721
+ # Issue with Python 3.9 and 3.10.
722
+ if isinstance(member, GenericAlias):
723
+ continue # pragma: no cover (for Python > 3.10 only)
724
+ if not issubclass(member, cls):
725
+ continue
726
+ if getattr(member, "__parameters__", None):
727
+ continue
728
+ factory_classes.add(member)
714
729
 
715
730
  if len(factory_classes) == 1:
716
- factory_cls = factory_classes[0]
731
+ factory_cls = next(iter(factory_classes))
717
732
  else:
718
733
  msg = (
719
734
  f"Found {len(factory_classes)} infrastructure factory classes in"
720
735
  f" '{topic}', expected 1."
721
736
  )
722
737
  raise InfrastructureFactoryError(msg)
723
- elif isinstance(obj, type) and issubclass(obj, InfrastructureFactory):
738
+ elif isinstance(obj, type) and issubclass(obj, cls):
724
739
  factory_cls = obj
725
740
  else:
726
741
  msg = (
eventsourcing/popo.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import contextlib
4
4
  from collections import defaultdict
5
- from threading import Event, Lock
5
+ from threading import Event, RLock
6
6
  from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from eventsourcing.persistence import (
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
27
27
 
28
28
  class POPORecorder:
29
29
  def __init__(self) -> None:
30
- self._database_lock = Lock()
30
+ self._database_lock = RLock()
31
31
 
32
32
 
33
33
  class POPOAggregateRecorder(POPORecorder, AggregateRecorder):