eventsourcing 9.4.3__py3-none-any.whl → 9.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eventsourcing might be problematic; see the registry's advisory page for more details.

@@ -1,20 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
- import os
4
- import sys
5
3
  import traceback
6
4
  import warnings
7
5
  from concurrent.futures import ThreadPoolExecutor
8
6
  from decimal import Decimal
9
7
  from threading import Event, get_ident
10
8
  from time import sleep
11
- from timeit import timeit
12
9
  from typing import TYPE_CHECKING, Any, ClassVar
13
10
  from unittest import TestCase
14
11
  from uuid import UUID, uuid4
15
12
 
16
13
  from eventsourcing.application import AggregateNotFoundError, Application
17
- from eventsourcing.domain import Aggregate, datetime_now_with_tzinfo
14
+ from eventsourcing.domain import Aggregate
18
15
  from eventsourcing.persistence import (
19
16
  InfrastructureFactory,
20
17
  InfrastructureFactoryError,
@@ -28,11 +25,8 @@ from eventsourcing.utils import EnvType, get_topic
28
25
  if TYPE_CHECKING:
29
26
  from datetime import datetime
30
27
 
31
- TIMEIT_FACTOR = int(os.environ.get("TEST_TIMEIT_FACTOR", default="10"))
32
-
33
28
 
34
29
  class ExampleApplicationTestCase(TestCase):
35
- timeit_number: ClassVar[int] = TIMEIT_FACTOR
36
30
  started_ats: ClassVar[dict[type[TestCase], datetime]] = {}
37
31
  counts: ClassVar[dict[type[TestCase], int]] = {}
38
32
  expected_factory_topic: str
@@ -76,7 +70,7 @@ class ExampleApplicationTestCase(TestCase):
76
70
  Decimal("65.00"),
77
71
  )
78
72
 
79
- sleep(1) # Added to make eventsourcing-axon tests work, perhaps not necessary.
73
+ # sleep(1) # Added to make eventsourcing-axon tests work.
80
74
  section = app.notification_log["1,10"]
81
75
  self.assertEqual(len(section.items), 4)
82
76
 
@@ -108,83 +102,6 @@ class ExampleApplicationTestCase(TestCase):
108
102
  self.assertEqual(from_snapshot2.version, Aggregate.INITIAL_VERSION + 3)
109
103
  self.assertEqual(from_snapshot2.balance, Decimal("65.00"))
110
104
 
111
- def test__put_performance(self) -> None:
112
- app = BankAccounts()
113
-
114
- # Open an account.
115
- account_id = app.open_account(
116
- full_name="Alice",
117
- email_address="alice@example.com",
118
- )
119
- account = app.get_account(account_id)
120
-
121
- def put() -> None:
122
- # Credit the account.
123
- account.append_transaction(Decimal("10.00"))
124
- app.save(account)
125
-
126
- # Warm up.
127
- number = 10
128
- timeit(put, number=number)
129
-
130
- duration = timeit(put, number=self.timeit_number)
131
- self.print_time("store events", duration)
132
-
133
- def test__get_performance_with_snapshotting_enabled(self) -> None:
134
- print()
135
- self._test_get_performance(is_snapshotting_enabled=True)
136
-
137
- def test__get_performance_without_snapshotting_enabled(self) -> None:
138
- self._test_get_performance(is_snapshotting_enabled=False)
139
-
140
- def _test_get_performance(self, *, is_snapshotting_enabled: bool) -> None:
141
- app = BankAccounts(
142
- env={"IS_SNAPSHOTTING_ENABLED": "y" if is_snapshotting_enabled else "n"}
143
- )
144
-
145
- # Open an account.
146
- account_id = app.open_account(
147
- full_name="Alice",
148
- email_address="alice@example.com",
149
- )
150
-
151
- def read() -> None:
152
- # Get the account.
153
- app.get_account(account_id)
154
-
155
- # Warm up.
156
- timeit(read, number=10)
157
-
158
- duration = timeit(read, number=self.timeit_number)
159
-
160
- if is_snapshotting_enabled:
161
- test_label = "get with snapshotting"
162
- else:
163
- test_label = "get without snapshotting"
164
- self.print_time(test_label, duration)
165
-
166
- def print_time(self, test_label: str, duration: float) -> None:
167
- cls = type(self)
168
- if cls not in self.started_ats:
169
- self.started_ats[cls] = datetime_now_with_tzinfo()
170
- print(f"{cls.__name__: <29} timeit number: {cls.timeit_number}")
171
- self.counts[cls] = 1
172
- else:
173
- self.counts[cls] += 1
174
-
175
- rate = f"{self.timeit_number / duration:.0f} events/s"
176
- print(
177
- f"{cls.__name__: <29}",
178
- f"{test_label: <21}",
179
- f"{rate: >15}",
180
- f" {1000 * duration / self.timeit_number:.3f} ms/event",
181
- )
182
-
183
- if self.counts[cls] == 3:
184
- cls_duration = datetime_now_with_tzinfo() - cls.started_ats[cls]
185
- print(f"{cls.__name__: <29} timeit duration: {cls_duration}")
186
- sys.stdout.flush()
187
-
188
105
 
189
106
  class EmailAddressAsStr(Transcoding):
190
107
  type = EmailAddress
@@ -197,7 +114,7 @@ class EmailAddressAsStr(Transcoding):
197
114
  return EmailAddress(data)
198
115
 
199
116
 
200
- class BankAccounts(Application):
117
+ class BankAccounts(Application[UUID]):
201
118
  is_snapshotting_enabled = True
202
119
 
203
120
  def register_transcodings(self, transcoder: JSONTranscoder) -> None:
@@ -238,19 +155,19 @@ class ApplicationTestCase(TestCase):
238
155
  def test_name(self) -> None:
239
156
  self.assertEqual(Application.name, "Application")
240
157
 
241
- class MyApplication1(Application):
158
+ class MyApplication1(Application[UUID]):
242
159
  pass
243
160
 
244
161
  self.assertEqual(MyApplication1.name, "MyApplication1")
245
162
 
246
- class MyApplication2(Application):
163
+ class MyApplication2(Application[UUID]):
247
164
  name = "MyBoundedContext"
248
165
 
249
166
  self.assertEqual(MyApplication2.name, "MyBoundedContext")
250
167
 
251
168
  def test_resolve_persistence_topics(self) -> None:
252
169
  # None specified.
253
- app = Application()
170
+ app = Application[UUID]()
254
171
  self.assertIsInstance(app.factory, InfrastructureFactory)
255
172
 
256
173
  # Legacy 'INFRASTRUCTURE_FACTORY'.
@@ -286,7 +203,7 @@ class ApplicationTestCase(TestCase):
286
203
  )
287
204
 
288
205
  def test_save_returns_recording_event(self) -> None:
289
- app = Application()
206
+ app = Application[UUID]()
290
207
 
291
208
  recordings = app.save()
292
209
  self.assertEqual(recordings, [])
@@ -310,7 +227,7 @@ class ApplicationTestCase(TestCase):
310
227
  def test_take_snapshot_raises_assertion_error_if_snapshotting_not_enabled(
311
228
  self,
312
229
  ) -> None:
313
- app = Application()
230
+ app = Application[UUID]()
314
231
  with self.assertRaises(AssertionError) as cm:
315
232
  app.take_snapshot(uuid4())
316
233
  self.assertEqual(
@@ -323,7 +240,7 @@ class ApplicationTestCase(TestCase):
323
240
  )
324
241
 
325
242
  def test_application_with_cached_aggregates_and_fastforward(self) -> None:
326
- app = Application(env={"AGGREGATE_CACHE_MAXSIZE": "10"})
243
+ app = Application[UUID](env={"AGGREGATE_CACHE_MAXSIZE": "10"})
327
244
 
328
245
  aggregate = Aggregate()
329
246
  app.save(aggregate)
@@ -362,7 +279,7 @@ class ApplicationTestCase(TestCase):
362
279
  )
363
280
 
364
281
  def _check_aggregate_fastforwarding_during_contention(self, env: EnvType) -> None:
365
- app = Application(env=env)
282
+ app = Application[UUID](env=env)
366
283
 
367
284
  self.assertEqual(len(app.repository._fastforward_locks_inuse), 0)
368
285
 
@@ -474,7 +391,7 @@ class ApplicationTestCase(TestCase):
474
391
  app.close()
475
392
 
476
393
  def test_application_with_cached_aggregates_not_fastforward(self) -> None:
477
- app = Application(
394
+ app = Application[UUID](
478
395
  env={
479
396
  "AGGREGATE_CACHE_MAXSIZE": "10",
480
397
  "AGGREGATE_CACHE_FASTFORWARD": "f",
@@ -513,7 +430,7 @@ class ApplicationTestCase(TestCase):
513
430
  app.save(aggregate4)
514
431
 
515
432
  def test_application_with_deepcopy_from_cache_arg(self) -> None:
516
- app = Application(
433
+ app = Application[UUID](
517
434
  env={
518
435
  "AGGREGATE_CACHE_MAXSIZE": "10",
519
436
  }
@@ -530,7 +447,7 @@ class ApplicationTestCase(TestCase):
530
447
  self.assertEqual(app.repository.cache.get(aggregate.id).version, 101)
531
448
 
532
449
  def test_application_with_deepcopy_from_cache_attribute(self) -> None:
533
- app = Application(
450
+ app = Application[UUID](
534
451
  env={
535
452
  "AGGREGATE_CACHE_MAXSIZE": "10",
536
453
  }
@@ -549,7 +466,7 @@ class ApplicationTestCase(TestCase):
549
466
 
550
467
  def test_application_log(self) -> None:
551
468
  # Check the old 'log' attribute presents the 'notification log' object.
552
- app = Application()
469
+ app = Application[UUID]()
553
470
 
554
471
  # Verify deprecation warning.
555
472
  with warnings.catch_warnings(record=True) as w:
@@ -40,7 +40,7 @@ from eventsourcing.persistence import (
40
40
  from eventsourcing.utils import Environment, get_topic
41
41
 
42
42
  if TYPE_CHECKING:
43
- from collections.abc import Iterator
43
+ from collections.abc import Iterator, Sequence
44
44
 
45
45
  from typing_extensions import Never
46
46
 
@@ -79,6 +79,7 @@ class AggregateRecorderTestCase(TestCase, ABC):
79
79
 
80
80
  # Select stored events, expect list of one.
81
81
  stored_events = recorder.select_events(originator_id1)
82
+ stored_events = convert_stored_event_originator_ids(stored_events)
82
83
  self.assertEqual(len(stored_events), 1)
83
84
  self.assertEqual(stored_events[0].originator_id, originator_id1)
84
85
  self.assertEqual(stored_events[0].originator_version, self.INITIAL_VERSION)
@@ -105,6 +106,7 @@ class AggregateRecorderTestCase(TestCase, ABC):
105
106
 
106
107
  # Check still only have one record.
107
108
  stored_events = recorder.select_events(originator_id1)
109
+ stored_events = convert_stored_event_originator_ids(stored_events)
108
110
  self.assertEqual(len(stored_events), 1)
109
111
  self.assertEqual(stored_events[0].originator_id, stored_event1.originator_id)
110
112
  self.assertEqual(
@@ -124,6 +126,7 @@ class AggregateRecorderTestCase(TestCase, ABC):
124
126
 
125
127
  # Check we got what was written.
126
128
  stored_events = recorder.select_events(originator_id1)
129
+ stored_events = convert_stored_event_originator_ids(stored_events)
127
130
  self.assertEqual(len(stored_events), 3)
128
131
  self.assertEqual(stored_events[0].originator_id, originator_id1)
129
132
  self.assertEqual(stored_events[0].originator_version, self.INITIAL_VERSION)
@@ -140,6 +143,7 @@ class AggregateRecorderTestCase(TestCase, ABC):
140
143
 
141
144
  # Check we can get the last one recorded (used to get last snapshot).
142
145
  stored_events = recorder.select_events(originator_id1, desc=True, limit=1)
146
+ stored_events = convert_stored_event_originator_ids(stored_events)
143
147
  self.assertEqual(len(stored_events), 1)
144
148
  self.assertEqual(
145
149
  stored_events[0],
@@ -150,6 +154,7 @@ class AggregateRecorderTestCase(TestCase, ABC):
150
154
  stored_events = recorder.select_events(
151
155
  originator_id1, lte=self.INITIAL_VERSION + 1, desc=True, limit=1
152
156
  )
157
+ stored_events = convert_stored_event_originator_ids(stored_events)
153
158
  self.assertEqual(len(stored_events), 1)
154
159
  self.assertEqual(
155
160
  stored_events[0],
@@ -157,12 +162,13 @@ class AggregateRecorderTestCase(TestCase, ABC):
157
162
  )
158
163
 
159
164
  # Check we can get events between versions (historical state with snapshot).
160
- events = recorder.select_events(
165
+ stored_events = recorder.select_events(
161
166
  originator_id1, gt=self.INITIAL_VERSION, lte=self.INITIAL_VERSION + 1
162
167
  )
163
- self.assertEqual(len(events), 1)
168
+ stored_events = convert_stored_event_originator_ids(stored_events)
169
+ self.assertEqual(len(stored_events), 1)
164
170
  self.assertEqual(
165
- events[0],
171
+ stored_events[0],
166
172
  stored_event2,
167
173
  )
168
174
 
@@ -181,8 +187,10 @@ class AggregateRecorderTestCase(TestCase, ABC):
181
187
  state=b"state4",
182
188
  )
183
189
  recorder.insert_events([stored_event4])
190
+ stored_events = recorder.select_events(originator_id2)
191
+ stored_events = convert_stored_event_originator_ids(stored_events)
184
192
  self.assertEqual(
185
- recorder.select_events(originator_id2),
193
+ stored_events,
186
194
  [stored_event4],
187
195
  )
188
196
 
@@ -219,6 +227,39 @@ _TApplicationRecorder = TypeVar(
219
227
  )
220
228
 
221
229
 
230
+ def convert_notification_originator_ids(
231
+ notifications: Sequence[Notification],
232
+ ) -> Sequence[Notification]:
233
+ return [
234
+ Notification(
235
+ originator_id=convert_originator_id(n.originator_id),
236
+ originator_version=n.originator_version,
237
+ topic=n.topic,
238
+ state=n.state,
239
+ id=n.id,
240
+ )
241
+ for n in notifications
242
+ ]
243
+
244
+
245
+ def convert_stored_event_originator_ids(
246
+ stored_events: Sequence[StoredEvent],
247
+ ) -> Sequence[StoredEvent]:
248
+ return [
249
+ StoredEvent(
250
+ originator_id=convert_originator_id(s.originator_id),
251
+ originator_version=s.originator_version,
252
+ topic=s.topic,
253
+ state=s.state,
254
+ )
255
+ for s in stored_events
256
+ ]
257
+
258
+
259
+ def convert_originator_id(originator_id: UUID | str) -> UUID:
260
+ return originator_id if isinstance(originator_id, UUID) else UUID(originator_id)
261
+
262
+
222
263
  class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder]):
223
264
  INITIAL_VERSION = 1
224
265
  EXPECT_CONTIGUOUS_NOTIFICATION_IDS = True
@@ -281,8 +322,9 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
281
322
  with self.assertRaises(IntegrityError):
282
323
  recorder.insert_events([stored_event3])
283
324
 
284
- sleep(1) # Added to make eventsourcing-axon tests work, perhaps not necessary.
325
+ # sleep(1) # Added to make eventsourcing-axon tests work.
285
326
  notifications = recorder.select_notifications(start=None, limit=10)
327
+ notifications = convert_notification_originator_ids(notifications)
286
328
  self.assertEqual(len(notifications), 3)
287
329
  self.assertEqual(notifications[0].id, 1)
288
330
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -298,6 +340,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
298
340
  self.assertEqual(notifications[2].state, b"state3")
299
341
 
300
342
  notifications = recorder.select_notifications(start=1, limit=10)
343
+ notifications = convert_notification_originator_ids(notifications)
301
344
  self.assertEqual(len(notifications), 3)
302
345
  self.assertEqual(notifications[0].id, 1)
303
346
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -313,6 +356,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
313
356
  self.assertEqual(notifications[2].state, b"state3")
314
357
 
315
358
  notifications = recorder.select_notifications(start=None, stop=2, limit=10)
359
+ notifications = convert_notification_originator_ids(notifications)
316
360
  self.assertEqual(len(notifications), 2)
317
361
  self.assertEqual(notifications[0].id, 1)
318
362
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -326,6 +370,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
326
370
  notifications = recorder.select_notifications(
327
371
  start=1, limit=10, inclusive_of_start=False
328
372
  )
373
+ notifications = convert_notification_originator_ids(notifications)
329
374
  self.assertEqual(len(notifications), 2)
330
375
  self.assertEqual(notifications[0].id, 2)
331
376
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -339,6 +384,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
339
384
  notifications = recorder.select_notifications(
340
385
  start=2, limit=10, inclusive_of_start=False
341
386
  )
387
+ notifications = convert_notification_originator_ids(notifications)
342
388
  self.assertEqual(len(notifications), 1)
343
389
  self.assertEqual(notifications[0].id, 3)
344
390
  self.assertEqual(notifications[0].originator_id, originator_id2)
@@ -348,6 +394,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
348
394
  notifications = recorder.select_notifications(
349
395
  start=None, limit=10, topics=["topic1", "topic2", "topic3"]
350
396
  )
397
+ notifications = convert_notification_originator_ids(notifications)
351
398
  self.assertEqual(len(notifications), 3)
352
399
  self.assertEqual(notifications[0].id, 1)
353
400
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -363,6 +410,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
363
410
  self.assertEqual(notifications[2].state, b"state3")
364
411
 
365
412
  notifications = recorder.select_notifications(1, 10, topics=["topic1"])
413
+ notifications = convert_notification_originator_ids(notifications)
366
414
  self.assertEqual(len(notifications), 1)
367
415
  self.assertEqual(notifications[0].id, 1)
368
416
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -370,6 +418,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
370
418
  self.assertEqual(notifications[0].state, b"state1")
371
419
 
372
420
  notifications = recorder.select_notifications(1, 3, topics=["topic2"])
421
+ notifications = convert_notification_originator_ids(notifications)
373
422
  self.assertEqual(len(notifications), 1)
374
423
  self.assertEqual(notifications[0].id, 2)
375
424
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -377,6 +426,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
377
426
  self.assertEqual(notifications[0].state, b"state2")
378
427
 
379
428
  notifications = recorder.select_notifications(1, 3, topics=["topic3"])
429
+ notifications = convert_notification_originator_ids(notifications)
380
430
  self.assertEqual(len(notifications), 1)
381
431
  self.assertEqual(notifications[0].id, 3)
382
432
  self.assertEqual(notifications[0].originator_id, originator_id2)
@@ -384,6 +434,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
384
434
  self.assertEqual(notifications[0].state, b"state3")
385
435
 
386
436
  notifications = recorder.select_notifications(1, 3, topics=["topic1", "topic3"])
437
+ notifications = convert_notification_originator_ids(notifications)
387
438
  self.assertEqual(len(notifications), 2)
388
439
  self.assertEqual(notifications[0].id, 1)
389
440
  self.assertEqual(notifications[0].originator_id, originator_id1)
@@ -397,34 +448,42 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
397
448
 
398
449
  # Check limit is working
399
450
  notifications = recorder.select_notifications(None, 1)
451
+ notifications = convert_notification_originator_ids(notifications)
400
452
  self.assertEqual(len(notifications), 1)
401
453
  self.assertEqual(notifications[0].id, 1)
402
454
 
403
455
  notifications = recorder.select_notifications(2, 1)
456
+ notifications = convert_notification_originator_ids(notifications)
404
457
  self.assertEqual(len(notifications), 1)
405
458
  self.assertEqual(notifications[0].id, 2)
406
459
 
407
460
  notifications = recorder.select_notifications(1, 1, inclusive_of_start=False)
461
+ notifications = convert_notification_originator_ids(notifications)
408
462
  self.assertEqual(len(notifications), 1)
409
463
  self.assertEqual(notifications[0].id, 2)
410
464
 
411
465
  notifications = recorder.select_notifications(2, 2)
466
+ notifications = convert_notification_originator_ids(notifications)
412
467
  self.assertEqual(len(notifications), 2)
413
468
  self.assertEqual(notifications[0].id, 2)
414
469
  self.assertEqual(notifications[1].id, 3)
415
470
 
416
471
  notifications = recorder.select_notifications(3, 1)
472
+ notifications = convert_notification_originator_ids(notifications)
417
473
  self.assertEqual(len(notifications), 1)
418
474
  self.assertEqual(notifications[0].id, 3)
419
475
 
420
476
  notifications = recorder.select_notifications(3, 1, inclusive_of_start=False)
477
+ notifications = convert_notification_originator_ids(notifications)
421
478
  self.assertEqual(len(notifications), 0)
422
479
 
423
480
  notifications = recorder.select_notifications(start=2, limit=10, stop=2)
481
+ notifications = convert_notification_originator_ids(notifications)
424
482
  self.assertEqual(len(notifications), 1)
425
483
  self.assertEqual(notifications[0].id, 2)
426
484
 
427
485
  notifications = recorder.select_notifications(start=1, limit=10, stop=2)
486
+ notifications = convert_notification_originator_ids(notifications)
428
487
  self.assertEqual(len(notifications), 2, len(notifications))
429
488
  self.assertEqual(notifications[0].id, 1)
430
489
  self.assertEqual(notifications[1].id, 2)
@@ -432,6 +491,7 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
432
491
  notifications = recorder.select_notifications(
433
492
  start=1, limit=10, stop=2, inclusive_of_start=False
434
493
  )
494
+ notifications = convert_notification_originator_ids(notifications)
435
495
  self.assertEqual(len(notifications), 1, len(notifications))
436
496
  self.assertEqual(notifications[0].id, 2)
437
497
 
@@ -547,8 +607,8 @@ class ApplicationRecorderTestCase(TestCase, ABC, Generic[_TApplicationRecorder])
547
607
  errors_happened = Event()
548
608
 
549
609
  # Match this to the batch page size in postgres insert for max throughput.
550
- num_events_per_job = 500
551
- num_jobs = 60
610
+ num_events_per_job = 50
611
+ num_jobs = 10
552
612
  num_workers = 4
553
613
 
554
614
  def insert_events() -> None:
@@ -891,6 +951,13 @@ class ProcessRecorderTestCase(TestCase, ABC):
891
951
  tracking=tracking1,
892
952
  )
893
953
 
954
+ # Check get record conflict error if attempt to store same event again.
955
+ with self.assertRaises(IntegrityError):
956
+ recorder.insert_events(
957
+ stored_events=[stored_event2],
958
+ tracking=tracking2,
959
+ )
960
+
894
961
  # Get current position.
895
962
  self.assertEqual(
896
963
  recorder.max_tracking_id("upstream_app"),
@@ -1065,7 +1132,7 @@ class NonInterleavingNotificationIDsBaseCase(ABC, TestCase):
1065
1132
 
1066
1133
  errors = []
1067
1134
 
1068
- def insert_stack(stack: list[StoredEvent]) -> None:
1135
+ def insert_stack(stack: Sequence[StoredEvent]) -> None:
1069
1136
  try:
1070
1137
  race_started.wait()
1071
1138
  recorder.insert_events(stack)
@@ -1086,12 +1153,13 @@ class NonInterleavingNotificationIDsBaseCase(ABC, TestCase):
1086
1153
  if errors:
1087
1154
  raise errors[0]
1088
1155
 
1089
- sleep(1) # Added to make eventsourcing-axon tests work, perhaps not necessary.
1156
+ # sleep(1) # Added to make eventsourcing-axon tests work.
1090
1157
  notifications = recorder.select_notifications(
1091
1158
  start=max_notification_id,
1092
1159
  limit=2 * self.insert_num,
1093
1160
  inclusive_of_start=False,
1094
1161
  )
1162
+ notifications = convert_notification_originator_ids(notifications)
1095
1163
  ids_for_sequence1 = [
1096
1164
  e.id for e in notifications if e.originator_id == originator1_id
1097
1165
  ]
@@ -1111,7 +1179,7 @@ class NonInterleavingNotificationIDsBaseCase(ABC, TestCase):
1111
1179
  else:
1112
1180
  self.assertGreater(min_id_for_sequence2, max_id_for_sequence1)
1113
1181
 
1114
- def create_stack(self, originator_id: UUID) -> list[StoredEvent]:
1182
+ def create_stack(self, originator_id: UUID) -> Sequence[StoredEvent]:
1115
1183
  return [
1116
1184
  StoredEvent(
1117
1185
  originator_id=originator_id,
@@ -1161,7 +1229,7 @@ class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactor
1161
1229
 
1162
1230
  def setUp(self) -> None:
1163
1231
  self.factory = cast(
1164
- "_TInfrastrutureFactory", InfrastructureFactory.construct(self.env)
1232
+ _TInfrastrutureFactory, InfrastructureFactory.construct(self.env)
1165
1233
  )
1166
1234
  self.assertIsInstance(self.factory, self.expected_factory_class())
1167
1235
  self.transcoder = JSONTranscoder()
@@ -1201,7 +1269,7 @@ class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactor
1201
1269
 
1202
1270
  # Create mapper.
1203
1271
 
1204
- mapper = self.factory.mapper(
1272
+ mapper: Mapper[UUID] = self.factory.mapper(
1205
1273
  transcoder=self.transcoder,
1206
1274
  )
1207
1275
  self.assertIsInstance(mapper, Mapper)
@@ -1211,7 +1279,7 @@ class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactor
1211
1279
  def test_createmapper_with_compressor(self) -> None:
1212
1280
  # Create mapper with compressor class as topic.
1213
1281
  self.env[self.factory.COMPRESSOR_TOPIC] = get_topic(ZlibCompressor)
1214
- mapper = self.factory.mapper(transcoder=self.transcoder)
1282
+ mapper: Mapper[UUID] = self.factory.mapper(transcoder=self.transcoder)
1215
1283
  self.assertIsInstance(mapper, Mapper)
1216
1284
  self.assertIsInstance(mapper.compressor, ZlibCompressor)
1217
1285
  self.assertIsNone(mapper.cipher)
@@ -1237,7 +1305,7 @@ class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactor
1237
1305
  self.env[AESCipher.CIPHER_KEY] = cipher_key
1238
1306
 
1239
1307
  # Create mapper with cipher.
1240
- mapper = self.factory.mapper(transcoder=self.transcoder)
1308
+ mapper: Mapper[UUID] = self.factory.mapper(transcoder=self.transcoder)
1241
1309
  self.assertIsInstance(mapper, Mapper)
1242
1310
  self.assertIsNotNone(mapper.cipher)
1243
1311
  self.assertIsNone(mapper.compressor)
@@ -1252,7 +1320,7 @@ class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactor
1252
1320
  cipher_key = AESCipher.create_key(16)
1253
1321
  self.env[AESCipher.CIPHER_KEY] = cipher_key
1254
1322
 
1255
- mapper = self.factory.mapper(transcoder=self.transcoder)
1323
+ mapper: Mapper[UUID] = self.factory.mapper(transcoder=self.transcoder)
1256
1324
  self.assertIsInstance(mapper, Mapper)
1257
1325
  self.assertIsNotNone(mapper.cipher)
1258
1326
  self.assertIsNotNone(mapper.compressor)
@@ -1265,7 +1333,7 @@ class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactor
1265
1333
  self.env["APP1_" + AESCipher.CIPHER_KEY] = cipher_key1
1266
1334
  self.env["APP2_" + AESCipher.CIPHER_KEY] = cipher_key2
1267
1335
 
1268
- mapper1: Mapper = self.factory.mapper(
1336
+ mapper1: Mapper[UUID] = self.factory.mapper(
1269
1337
  transcoder=self.transcoder,
1270
1338
  )
1271
1339
 
@@ -1279,7 +1347,7 @@ class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactor
1279
1347
  self.assertEqual(domain_event.originator_id, copy.originator_id)
1280
1348
 
1281
1349
  self.env.name = "App2"
1282
- mapper2: Mapper = self.factory.mapper(
1350
+ mapper2: Mapper[UUID] = self.factory.mapper(
1283
1351
  transcoder=self.transcoder,
1284
1352
  )
1285
1353
  # This should fail because the infrastructure factory
@@ -1,3 +1,5 @@
1
+ import os
2
+
1
3
  import psycopg
2
4
  from psycopg.sql import SQL, Identifier
3
5
 
@@ -42,10 +44,28 @@ def pg_close_all_connections(
42
44
  pg_conn_cursor.execute(close_all_connections)
43
45
 
44
46
 
45
- def drop_postgres_table(datastore: PostgresDatastore, table_name: str) -> None:
46
- # print(f"Dropping table {datastore.schema}.{table_name}")
47
- statement = SQL("DROP TABLE IF EXISTS {0}.{1}").format(
48
- Identifier(datastore.schema), Identifier(table_name)
49
- )
50
- with datastore.transaction(commit=True) as curs:
51
- curs.execute(statement, prepare=False)
47
+ def drop_tables() -> None:
48
+
49
+ for schema in ["public", "myschema"]:
50
+ datastore = PostgresDatastore(
51
+ dbname=os.environ.get("POSTGRES_DBNAME", "eventsourcing"),
52
+ host=os.environ.get("POSTGRES_HOST", "127.0.0.1"),
53
+ port=os.environ.get("POSTGRES_PORT", "5432"),
54
+ user=os.environ.get("POSTGRES_USER", "eventsourcing"),
55
+ password=os.environ.get("POSTGRES_PASSWORD", "eventsourcing"),
56
+ schema=schema,
57
+ )
58
+ with datastore.transaction(commit=True) as curs:
59
+ select_table_names = SQL(
60
+ "SELECT table_name FROM information_schema.tables "
61
+ "WHERE table_schema = %s"
62
+ )
63
+ fetchall = curs.execute(select_table_names, (datastore.schema,)).fetchall()
64
+ for row in fetchall:
65
+ table_name = row["table_name"]
66
+ # print(f"Dropping table '{table_name}' in schema '{schema}'")
67
+ statement = SQL("DROP TABLE IF EXISTS {0}.{1}").format(
68
+ Identifier(datastore.schema), Identifier(table_name)
69
+ )
70
+ curs.execute(statement, prepare=False)
71
+ # print(f"Dropped table '{table_name}' in schema '{schema}'")
eventsourcing/utils.py CHANGED
@@ -35,7 +35,7 @@ def get_topic(obj: SupportsTopic, /) -> str:
35
35
  try:
36
36
  return _type_cache[obj]
37
37
  except KeyError:
38
- topic = f"{obj.__module__}:{obj.__qualname__}"
38
+ topic = getattr(obj, "TOPIC", f"{obj.__module__}:{obj.__qualname__}")
39
39
  register_topic(topic, obj)
40
40
  _type_cache[obj] = topic
41
41
  return topic