eventsourcing 9.4.0a7__py3-none-any.whl → 9.4.0b1__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.

Potentially problematic release.


This version of eventsourcing might be problematic. Click here for more details.

@@ -40,45 +40,77 @@ class ApplicationSubscription(Iterator[tuple[DomainEventProtocol, Tracking]]):
40
40
  gt: int | None = None,
41
41
  topics: Sequence[str] = (),
42
42
  ):
43
+ """
44
+ Starts subscription to application's stored events using application's recorder.
45
+ """
43
46
  self.name = app.name
44
47
  self.recorder = app.recorder
45
48
  self.mapper = app.mapper
46
49
  self.subscription = self.recorder.subscribe(gt=gt, topics=topics)
47
50
 
51
+ def stop(self) -> None:
52
+ """
53
+ Stops the stored event subscription.
54
+ """
55
+ self.subscription.stop()
56
+
48
57
  def __enter__(self) -> Self:
58
+ """
59
+ Calls __enter__ on the stored event subscription.
60
+ """
49
61
  self.subscription.__enter__()
50
62
  return self
51
63
 
52
64
  def __exit__(self, *args: object, **kwargs: Any) -> None:
65
+ """
66
+ Calls __exit__ on the stored event subscription.
67
+ """
53
68
  self.subscription.__exit__(*args, **kwargs)
54
69
 
55
70
  def __iter__(self) -> Self:
56
71
  return self
57
72
 
58
73
  def __next__(self) -> tuple[DomainEventProtocol, Tracking]:
74
+ """
75
+ Returns the next stored event from the stored event subscription.
76
+ Constructs a tracking object that identifies the position of
77
+ the event in the application sequence, and reconstructs a domain
78
+ event object from the stored event object.
79
+ """
59
80
  notification = next(self.subscription)
60
81
  tracking = Tracking(self.name, notification.id)
61
82
  domain_event = self.mapper.to_domain_event(notification)
62
83
  return domain_event, tracking
63
84
 
64
85
  def __del__(self) -> None:
86
+ """
87
+ Stops the stored event subscription.
88
+ """
65
89
  self.stop()
66
90
 
67
- def stop(self) -> None:
68
- self.subscription.stop()
69
-
70
91
 
71
92
  class Projection(ABC, Generic[TTrackingRecorder]):
72
93
  name: str = ""
73
- """Name of projection, used to pick prefixed environment variables."""
74
- topics: Sequence[str] = ()
75
- """Event topics, used to filter events in database."""
94
+ """
95
+ Name of projection, used to pick prefixed environment
96
+ variables and define database table names.
97
+ """
98
+ topics: tuple[str, ...] = ()
99
+ """
100
+ Filter events in database when subscribing to an application.
101
+ """
76
102
 
77
103
  def __init__(
78
104
  self,
79
- tracking_recorder: TTrackingRecorder,
105
+ view: TTrackingRecorder,
80
106
  ):
81
- self.tracking_recorder = tracking_recorder
107
+ """Initialises a projection instance."""
108
+ self._view = view
109
+
110
+ @property
111
+ def view(self) -> TTrackingRecorder:
112
+ """Materialised view of an event-sourced application."""
113
+ return self._view
82
114
 
83
115
  @singledispatchmethod
84
116
  @abstractmethod
@@ -90,9 +122,6 @@ class Projection(ABC, Generic[TTrackingRecorder]):
90
122
  """
91
123
 
92
124
 
93
- TProjection = TypeVar("TProjection", bound=Projection[Any])
94
-
95
-
96
125
  TApplication = TypeVar("TApplication", bound=Application)
97
126
 
98
127
 
@@ -102,30 +131,40 @@ class ProjectionRunner(Generic[TApplication, TTrackingRecorder]):
102
131
  *,
103
132
  application_class: type[TApplication],
104
133
  projection_class: type[Projection[TTrackingRecorder]],
105
- tracking_recorder_class: type[TTrackingRecorder] | None = None,
134
+ view_class: type[TTrackingRecorder],
106
135
  env: EnvType | None = None,
107
136
  ):
137
+ """
138
+ Constructs application from given application class with given environment.
139
+ Also constructs a materialised view from given class using an infrastructure
140
+ factory constructed with an environment named after the projection. Also
141
+ constructs a projection with the constructed materialised view object.
142
+ Starts a subscription to application and, in a separate event-processing
143
+ thread, calls projection's process_event() method for each event and tracking
144
+ object pair received from the subscription.
145
+ """
146
+ self._is_stopping = Event()
147
+
108
148
  self.app: TApplication = application_class(env)
109
149
 
110
- projection_environment = self._construct_env(
111
- name=projection_class.name or projection_class.__name__, env=env
112
- )
113
- self.projection_factory: InfrastructureFactory[TTrackingRecorder] = (
114
- InfrastructureFactory.construct(env=projection_environment)
115
- )
116
- self.tracking_recorder: TTrackingRecorder = (
117
- self.projection_factory.tracking_recorder(tracking_recorder_class)
150
+ self.view = (
151
+ InfrastructureFactory[TTrackingRecorder]
152
+ .construct(
153
+ env=self._construct_env(
154
+ name=projection_class.name or projection_class.__name__, env=env
155
+ )
156
+ )
157
+ .tracking_recorder(view_class)
118
158
  )
119
159
 
120
160
  self.projection = projection_class(
121
- tracking_recorder=self.tracking_recorder,
161
+ view=self.view,
122
162
  )
123
163
  self.subscription = ApplicationSubscription(
124
164
  app=self.app,
125
- gt=self.tracking_recorder.max_tracking_id(self.app.name),
165
+ gt=self.view.max_tracking_id(self.app.name),
126
166
  topics=self.projection.topics,
127
167
  )
128
- self._is_stopping = Event()
129
168
  self.thread_error: BaseException | None = None
130
169
  self.processing_thread = Thread(
131
170
  target=self._process_events_loop,
@@ -149,6 +188,9 @@ class ProjectionRunner(Generic[TApplication, TTrackingRecorder]):
149
188
  return Environment(name, _env)
150
189
 
151
190
  def stop(self) -> None:
191
+ """
192
+ Stops the application subscription, which will stop the event-processing thread.
193
+ """
152
194
  self._is_stopping.set()
153
195
  self.subscription.stop()
154
196
 
@@ -177,15 +219,23 @@ class ProjectionRunner(Generic[TApplication, TTrackingRecorder]):
177
219
  )
178
220
 
179
221
  is_stopping.set()
180
- subscription.subscription.stop()
222
+ subscription.stop()
181
223
 
182
224
  def run_forever(self, timeout: float | None = None) -> None:
225
+ """
226
+ Blocks until timeout, or until the runner is stopped or errors. Re-raises
227
+ any error otherwise exits normally
228
+ """
183
229
  if self._is_stopping.wait(timeout=timeout) and self.thread_error is not None:
184
230
  raise self.thread_error
185
231
 
186
- def wait(self, notification_id: int, timeout: float = 1.0) -> None:
232
+ def wait(self, notification_id: int | None, timeout: float = 1.0) -> None:
233
+ """
234
+ Blocks until timeout, or until the materialised view has recorded a tracking
235
+ object that is greater than or equal to the given notification ID.
236
+ """
187
237
  try:
188
- self.projection.tracking_recorder.wait(
238
+ self.projection.view.wait(
189
239
  application_name=self.subscription.name,
190
240
  notification_id=notification_id,
191
241
  timeout=timeout,
@@ -199,8 +249,14 @@ class ProjectionRunner(Generic[TApplication, TTrackingRecorder]):
199
249
  return self
200
250
 
201
251
  def __exit__(self, *args: object, **kwargs: Any) -> None:
252
+ """
253
+ Calls stop() and waits for the event-processing thread to exit.
254
+ """
202
255
  self.stop()
203
256
  self.processing_thread.join()
204
257
 
205
258
  def __del__(self) -> None:
259
+ """
260
+ Calls stop().
261
+ """
206
262
  self.stop()
eventsourcing/sqlite.py CHANGED
@@ -530,7 +530,11 @@ class SQLiteTrackingRecorder(SQLiteRecorder, TrackingRecorder):
530
530
  c.execute(self.select_max_tracking_id_statement, params)
531
531
  return c.fetchone()[0]
532
532
 
533
- def has_tracking_id(self, application_name: str, notification_id: int) -> bool:
533
+ def has_tracking_id(
534
+ self, application_name: str, notification_id: int | None
535
+ ) -> bool:
536
+ if notification_id is None:
537
+ return True
534
538
  params = [application_name, notification_id]
535
539
  with self.datastore.transaction(commit=False) as c:
536
540
  c.execute(self.count_tracking_id_statement, params)
eventsourcing/system.py CHANGED
@@ -17,9 +17,11 @@ from eventsourcing.application import (
17
17
  Application,
18
18
  NotificationLog,
19
19
  ProcessingEvent,
20
+ ProgrammingError,
20
21
  Section,
21
22
  TApplication,
22
23
  )
24
+ from eventsourcing.dispatch import singledispatchmethod
23
25
  from eventsourcing.domain import DomainEventProtocol, MutableOrImmutableAggregate
24
26
  from eventsourcing.persistence import (
25
27
  IntegrityError,
@@ -196,7 +198,7 @@ class Follower(Application):
196
198
  self.notify(processing_event.events)
197
199
  self._notify(recordings)
198
200
 
199
- @abstractmethod
201
+ @singledispatchmethod
200
202
  def policy(
201
203
  self,
202
204
  domain_event: DomainEventProtocol,
@@ -379,7 +381,7 @@ class System:
379
381
  return cls
380
382
 
381
383
  @property
382
- def topic(self) -> str | None:
384
+ def topic(self) -> str:
383
385
  """
384
386
  Returns a topic to the system object, if constructed as a module attribute.
385
387
  """
@@ -389,6 +391,9 @@ class System:
389
391
  if value is self:
390
392
  topic = module.__name__ + ":" + name
391
393
  assert resolve_topic(topic) is self
394
+ if topic is None:
395
+ msg = "Unable to compute topic for system object: %s" % self
396
+ raise ProgrammingError(msg)
392
397
  return topic
393
398
 
394
399
 
@@ -423,6 +428,13 @@ class Runner(ABC):
423
428
  Returns an application instance for given application class.
424
429
  """
425
430
 
431
+ def __enter__(self) -> Self:
432
+ self.start()
433
+ return self
434
+
435
+ def __exit__(self, *args: object, **kwargs: Any) -> None:
436
+ self.stop()
437
+
426
438
 
427
439
  class RunnerAlreadyStartedError(Exception):
428
440
  """
@@ -548,13 +560,6 @@ class SingleThreadedRunner(Runner, RecordingEventReceiver):
548
560
  assert isinstance(app, cls)
549
561
  return app
550
562
 
551
- def __enter__(self) -> Self:
552
- self.start()
553
- return self
554
-
555
- def __exit__(self, *args: object, **kwargs: Any) -> None:
556
- self.stop()
557
-
558
563
 
559
564
  class NewSingleThreadedRunner(Runner, RecordingEventReceiver):
560
565
  """
@@ -34,7 +34,7 @@ class ExampleApplicationTestCase(TestCase):
34
34
  counts: ClassVar[dict[type[TestCase], int]] = {}
35
35
  expected_factory_topic: str
36
36
 
37
- def test_example_application(self):
37
+ def test_example_application(self) -> None:
38
38
  app = BankAccounts(env={"IS_SNAPSHOTTING_ENABLED": "y"})
39
39
 
40
40
  self.assertEqual(get_topic(type(app.factory)), self.expected_factory_topic)
@@ -80,30 +80,32 @@ class ExampleApplicationTestCase(TestCase):
80
80
  # Take snapshot (specify version).
81
81
  app.take_snapshot(account_id, version=Aggregate.INITIAL_VERSION + 1)
82
82
 
83
+ assert app.snapshots is not None # for mypy
83
84
  snapshots = list(app.snapshots.get(account_id))
84
85
  self.assertEqual(len(snapshots), 1)
85
86
  self.assertEqual(snapshots[0].originator_version, Aggregate.INITIAL_VERSION + 1)
86
87
 
87
- from_snapshot = app.repository.get(
88
+ from_snapshot1: BankAccount = app.repository.get(
88
89
  account_id, version=Aggregate.INITIAL_VERSION + 2
89
90
  )
90
- self.assertIsInstance(from_snapshot, BankAccount)
91
- self.assertEqual(from_snapshot.version, Aggregate.INITIAL_VERSION + 2)
92
- self.assertEqual(from_snapshot.balance, Decimal("35.00"))
91
+ self.assertIsInstance(from_snapshot1, BankAccount)
92
+ self.assertEqual(from_snapshot1.version, Aggregate.INITIAL_VERSION + 2)
93
+ self.assertEqual(from_snapshot1.balance, Decimal("35.00"))
93
94
 
94
95
  # Take snapshot (don't specify version).
95
96
  app.take_snapshot(account_id)
97
+ assert app.snapshots is not None # for mypy
96
98
  snapshots = list(app.snapshots.get(account_id))
97
99
  self.assertEqual(len(snapshots), 2)
98
100
  self.assertEqual(snapshots[0].originator_version, Aggregate.INITIAL_VERSION + 1)
99
101
  self.assertEqual(snapshots[1].originator_version, Aggregate.INITIAL_VERSION + 3)
100
102
 
101
- from_snapshot = app.repository.get(account_id)
102
- self.assertIsInstance(from_snapshot, BankAccount)
103
- self.assertEqual(from_snapshot.version, Aggregate.INITIAL_VERSION + 3)
104
- self.assertEqual(from_snapshot.balance, Decimal("65.00"))
103
+ from_snapshot2: BankAccount = app.repository.get(account_id)
104
+ self.assertIsInstance(from_snapshot2, BankAccount)
105
+ self.assertEqual(from_snapshot2.version, Aggregate.INITIAL_VERSION + 3)
106
+ self.assertEqual(from_snapshot2.balance, Decimal("65.00"))
105
107
 
106
- def test__put_performance(self):
108
+ def test__put_performance(self) -> None:
107
109
  app = BankAccounts()
108
110
 
109
111
  # Open an account.
@@ -113,7 +115,7 @@ class ExampleApplicationTestCase(TestCase):
113
115
  )
114
116
  account = app.get_account(account_id)
115
117
 
116
- def put():
118
+ def put() -> None:
117
119
  # Credit the account.
118
120
  account.append_transaction(Decimal("10.00"))
119
121
  app.save(account)
@@ -125,14 +127,14 @@ class ExampleApplicationTestCase(TestCase):
125
127
  duration = timeit(put, number=self.timeit_number)
126
128
  self.print_time("store events", duration)
127
129
 
128
- def test__get_performance_with_snapshotting_enabled(self):
130
+ def test__get_performance_with_snapshotting_enabled(self) -> None:
129
131
  print()
130
132
  self._test_get_performance(is_snapshotting_enabled=True)
131
133
 
132
- def test__get_performance_without_snapshotting_enabled(self):
134
+ def test__get_performance_without_snapshotting_enabled(self) -> None:
133
135
  self._test_get_performance(is_snapshotting_enabled=False)
134
136
 
135
- def _test_get_performance(self, *, is_snapshotting_enabled: bool):
137
+ def _test_get_performance(self, *, is_snapshotting_enabled: bool) -> None:
136
138
  app = BankAccounts(
137
139
  env={"IS_SNAPSHOTTING_ENABLED": "y" if is_snapshotting_enabled else "n"}
138
140
  )
@@ -143,7 +145,7 @@ class ExampleApplicationTestCase(TestCase):
143
145
  email_address="alice@example.com",
144
146
  )
145
147
 
146
- def read():
148
+ def read() -> None:
147
149
  # Get the account.
148
150
  app.get_account(account_id)
149
151
 
@@ -158,7 +160,7 @@ class ExampleApplicationTestCase(TestCase):
158
160
  test_label = "get without snapshotting"
159
161
  self.print_time(test_label, duration)
160
162
 
161
- def print_time(self, test_label, duration):
163
+ def print_time(self, test_label: str, duration: float) -> None:
162
164
  cls = type(self)
163
165
  if cls not in self.started_ats:
164
166
  self.started_ats[cls] = datetime.now()
@@ -176,8 +178,8 @@ class ExampleApplicationTestCase(TestCase):
176
178
  )
177
179
 
178
180
  if self.counts[cls] == 3:
179
- duration = datetime.now() - cls.started_ats[cls]
180
- print(f"{cls.__name__: <29} timeit duration: {duration}")
181
+ cls_duration = datetime.now() - cls.started_ats[cls]
182
+ print(f"{cls.__name__: <29} timeit duration: {cls_duration}")
181
183
  sys.stdout.flush()
182
184
 
183
185
 
@@ -199,7 +201,7 @@ class BankAccounts(Application):
199
201
  super().register_transcodings(transcoder)
200
202
  transcoder.register(EmailAddressAsStr())
201
203
 
202
- def open_account(self, full_name, email_address):
204
+ def open_account(self, full_name: str, email_address: str) -> UUID:
203
205
  account = BankAccount.open(
204
206
  full_name=full_name,
205
207
  email_address=email_address,
@@ -218,7 +220,7 @@ class BankAccounts(Application):
218
220
 
219
221
  def get_account(self, account_id: UUID) -> BankAccount:
220
222
  try:
221
- aggregate = self.repository.get(account_id)
223
+ aggregate: BankAccount = self.repository.get(account_id)
222
224
  except AggregateNotFoundError:
223
225
  raise self.AccountNotFoundError(account_id) from None
224
226
  else:
@@ -230,7 +232,7 @@ class BankAccounts(Application):
230
232
 
231
233
 
232
234
  class ApplicationTestCase(TestCase):
233
- def test_name(self):
235
+ def test_name(self) -> None:
234
236
  self.assertEqual(Application.name, "Application")
235
237
 
236
238
  class MyApplication1(Application):
@@ -243,7 +245,7 @@ class ApplicationTestCase(TestCase):
243
245
 
244
246
  self.assertEqual(MyApplication2.name, "MyBoundedContext")
245
247
 
246
- def test_resolve_persistence_topics(self):
248
+ def test_resolve_persistence_topics(self) -> None:
247
249
  # None specified.
248
250
  app = Application()
249
251
  self.assertIsInstance(app.factory, InfrastructureFactory)
@@ -279,7 +281,7 @@ class ApplicationTestCase(TestCase):
279
281
  "eventsourcing.application:Application",
280
282
  )
281
283
 
282
- def test_save_returns_recording_event(self):
284
+ def test_save_returns_recording_event(self) -> None:
283
285
  app = Application()
284
286
 
285
287
  recordings = app.save()
@@ -301,7 +303,9 @@ class ApplicationTestCase(TestCase):
301
303
  self.assertEqual(recordings[0].notification.id, 3)
302
304
  self.assertEqual(recordings[1].notification.id, 4)
303
305
 
304
- def test_take_snapshot_raises_assertion_error_if_snapshotting_not_enabled(self):
306
+ def test_take_snapshot_raises_assertion_error_if_snapshotting_not_enabled(
307
+ self,
308
+ ) -> None:
305
309
  app = Application()
306
310
  with self.assertRaises(AssertionError) as cm:
307
311
  app.take_snapshot(uuid4())
@@ -314,12 +318,13 @@ class ApplicationTestCase(TestCase):
314
318
  "application class.",
315
319
  )
316
320
 
317
- def test_application_with_cached_aggregates_and_fastforward(self):
321
+ def test_application_with_cached_aggregates_and_fastforward(self) -> None:
318
322
  app = Application(env={"AGGREGATE_CACHE_MAXSIZE": "10"})
319
323
 
320
324
  aggregate = Aggregate()
321
325
  app.save(aggregate)
322
326
  # Should not put the aggregate in the cache.
327
+ assert app.repository.cache is not None # for mypy
323
328
  with self.assertRaises(KeyError):
324
329
  self.assertEqual(aggregate, app.repository.cache.get(aggregate.id))
325
330
 
@@ -339,7 +344,7 @@ class ApplicationTestCase(TestCase):
339
344
  app.repository.get(aggregate.id)
340
345
  self.assertEqual(aggregate, app.repository.cache.get(aggregate.id))
341
346
 
342
- def test_application_fastforward_skipping_during_contention(self):
347
+ def test_application_fastforward_skipping_during_contention(self) -> None:
343
348
  app = Application(
344
349
  env={
345
350
  "AGGREGATE_CACHE_MAXSIZE": "10",
@@ -354,18 +359,18 @@ class ApplicationTestCase(TestCase):
354
359
  stopped = Event()
355
360
 
356
361
  # Trigger, save, get, check.
357
- def trigger_save_get_check():
362
+ def trigger_save_get_check() -> None:
358
363
  while not stopped.is_set():
359
364
  try:
360
- aggregate = app.repository.get(aggregate_id)
365
+ aggregate: Aggregate = app.repository.get(aggregate_id)
361
366
  aggregate.trigger_event(Aggregate.Event)
362
367
  saved_version = aggregate.version
363
368
  try:
364
369
  app.save(aggregate)
365
370
  except IntegrityError:
366
371
  continue
367
- cached_version = app.repository.get(aggregate_id).version
368
- if saved_version > cached_version:
372
+ cached: Aggregate = app.repository.get(aggregate_id)
373
+ if saved_version > cached.version:
369
374
  print(f"Skipped fast-forwarding at version {saved_version}")
370
375
  stopped.set()
371
376
  if aggregate.version % 1000 == 0:
@@ -384,7 +389,7 @@ class ApplicationTestCase(TestCase):
384
389
  self.fail("Didn't skip fast forwarding before test timed out...")
385
390
  executor.shutdown()
386
391
 
387
- def test_application_fastforward_blocking_during_contention(self):
392
+ def test_application_fastforward_blocking_during_contention(self) -> None:
388
393
  app = Application(
389
394
  env={
390
395
  "AGGREGATE_CACHE_MAXSIZE": "10",
@@ -398,18 +403,18 @@ class ApplicationTestCase(TestCase):
398
403
  stopped = Event()
399
404
 
400
405
  # Trigger, save, get, check.
401
- def trigger_save_get_check():
406
+ def trigger_save_get_check() -> None:
402
407
  while not stopped.is_set():
403
408
  try:
404
- aggregate = app.repository.get(aggregate_id)
409
+ aggregate: Aggregate = app.repository.get(aggregate_id)
405
410
  aggregate.trigger_event(Aggregate.Event)
406
411
  saved_version = aggregate.version
407
412
  try:
408
413
  app.save(aggregate)
409
414
  except IntegrityError:
410
415
  continue
411
- cached_version = app.repository.get(aggregate_id).version
412
- if saved_version > cached_version:
416
+ cached: Aggregate = app.repository.get(aggregate_id)
417
+ if saved_version > cached.version:
413
418
  print(f"Skipped fast-forwarding at version {saved_version}")
414
419
  stopped.set()
415
420
  if aggregate.version % 1000 == 0:
@@ -429,7 +434,7 @@ class ApplicationTestCase(TestCase):
429
434
  self.fail("Wrongly skipped fast forwarding")
430
435
  executor.shutdown()
431
436
 
432
- def test_application_with_cached_aggregates_not_fastforward(self):
437
+ def test_application_with_cached_aggregates_not_fastforward(self) -> None:
433
438
  app = Application(
434
439
  env={
435
440
  "AGGREGATE_CACHE_MAXSIZE": "10",
@@ -439,11 +444,12 @@ class ApplicationTestCase(TestCase):
439
444
  aggregate = Aggregate()
440
445
  app.save(aggregate)
441
446
  # Should put the aggregate in the cache.
447
+ assert app.repository.cache is not None # for mypy
442
448
  self.assertEqual(aggregate, app.repository.cache.get(aggregate.id))
443
449
  app.repository.get(aggregate.id)
444
450
  self.assertEqual(aggregate, app.repository.cache.get(aggregate.id))
445
451
 
446
- def test_application_with_deepcopy_from_cache_arg(self):
452
+ def test_application_with_deepcopy_from_cache_arg(self) -> None:
447
453
  app = Application(
448
454
  env={
449
455
  "AGGREGATE_CACHE_MAXSIZE": "10",
@@ -452,14 +458,15 @@ class ApplicationTestCase(TestCase):
452
458
  aggregate = Aggregate()
453
459
  app.save(aggregate)
454
460
  self.assertEqual(aggregate.version, 1)
455
- aggregate = app.repository.get(aggregate.id)
456
- aggregate.version = 101
461
+ reconstructed: Aggregate = app.repository.get(aggregate.id)
462
+ reconstructed.version = 101
463
+ assert app.repository.cache is not None # for mypy
457
464
  self.assertEqual(app.repository.cache.get(aggregate.id).version, 1)
458
- aggregate = app.repository.get(aggregate.id, deepcopy_from_cache=False)
459
- aggregate.version = 101
465
+ cached: Aggregate = app.repository.get(aggregate.id, deepcopy_from_cache=False)
466
+ cached.version = 101
460
467
  self.assertEqual(app.repository.cache.get(aggregate.id).version, 101)
461
468
 
462
- def test_application_with_deepcopy_from_cache_attribute(self):
469
+ def test_application_with_deepcopy_from_cache_attribute(self) -> None:
463
470
  app = Application(
464
471
  env={
465
472
  "AGGREGATE_CACHE_MAXSIZE": "10",
@@ -468,15 +475,16 @@ class ApplicationTestCase(TestCase):
468
475
  aggregate = Aggregate()
469
476
  app.save(aggregate)
470
477
  self.assertEqual(aggregate.version, 1)
471
- aggregate = app.repository.get(aggregate.id)
472
- aggregate.version = 101
478
+ reconstructed: Aggregate = app.repository.get(aggregate.id)
479
+ reconstructed.version = 101
480
+ assert app.repository.cache is not None # for mypy
473
481
  self.assertEqual(app.repository.cache.get(aggregate.id).version, 1)
474
482
  app.repository.deepcopy_from_cache = False
475
- aggregate = app.repository.get(aggregate.id)
476
- aggregate.version = 101
483
+ cached: Aggregate = app.repository.get(aggregate.id)
484
+ cached.version = 101
477
485
  self.assertEqual(app.repository.cache.get(aggregate.id).version, 101)
478
486
 
479
- def test_application_log(self):
487
+ def test_application_log(self) -> None:
480
488
  # Check the old 'log' attribute presents the 'notification log' object.
481
489
  app = Application()
482
490
 
@@ -486,6 +494,6 @@ class ApplicationTestCase(TestCase):
486
494
 
487
495
  self.assertEqual(1, len(w))
488
496
  self.assertIs(w[-1].category, DeprecationWarning)
489
- self.assertEqual(
490
- "'log' is deprecated, use 'notifications' instead", w[-1].message.args[0]
497
+ self.assertIn(
498
+ "'log' is deprecated, use 'notifications' instead", str(w[-1].message)
491
499
  )
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
4
  from decimal import Decimal
5
+ from typing import cast
5
6
  from uuid import uuid4
6
7
 
7
8
  from eventsourcing.domain import Aggregate, AggregateCreated, AggregateEvent
@@ -59,6 +60,7 @@ class BankAccount(Aggregate):
59
60
  if self.balance + amount < -self.overdraft_limit:
60
61
  raise InsufficientFundsError({"account_id": self.id})
61
62
 
63
+ @dataclass(frozen=True)
62
64
  class TransactionAppended(AggregateEvent):
63
65
  """
64
66
  Domain event for when transaction
@@ -67,11 +69,11 @@ class BankAccount(Aggregate):
67
69
 
68
70
  amount: Decimal
69
71
 
70
- def apply(self, account: BankAccount) -> None:
72
+ def apply(self, aggregate: Aggregate) -> None:
71
73
  """
72
74
  Increments the account balance.
73
75
  """
74
- account.balance += self.amount
76
+ cast(BankAccount, aggregate).balance += self.amount
75
77
 
76
78
  def set_overdraft_limit(self, overdraft_limit: Decimal) -> None:
77
79
  """
@@ -93,8 +95,8 @@ class BankAccount(Aggregate):
93
95
 
94
96
  overdraft_limit: Decimal
95
97
 
96
- def apply(self, account: BankAccount):
97
- account.overdraft_limit = self.overdraft_limit
98
+ def apply(self, aggregate: Aggregate) -> None:
99
+ cast(BankAccount, aggregate).overdraft_limit = self.overdraft_limit
98
100
 
99
101
  def close(self) -> None:
100
102
  """
@@ -107,8 +109,8 @@ class BankAccount(Aggregate):
107
109
  Domain event for when account is closed.
108
110
  """
109
111
 
110
- def apply(self, account: BankAccount):
111
- account.is_closed = True
112
+ def apply(self, aggregate: Aggregate) -> None:
113
+ cast(BankAccount, aggregate).is_closed = True
112
114
 
113
115
 
114
116
  class AccountClosedError(Exception):