eventsourcing 9.5.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1744 @@
1
+ from __future__ import annotations
2
+
3
+ import traceback
4
+ import zlib
5
+ from abc import ABC, abstractmethod
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from pathlib import Path
8
+ from tempfile import NamedTemporaryFile
9
+ from threading import Event, Thread, get_ident
10
+ from time import sleep
11
+ from timeit import timeit
12
+ from typing import TYPE_CHECKING, Any, Generic, cast
13
+ from unittest import TestCase
14
+ from uuid import UUID, uuid4
15
+
16
+ from typing_extensions import TypeVar
17
+
18
+ from eventsourcing.cipher import AESCipher
19
+ from eventsourcing.compressor import ZlibCompressor
20
+ from eventsourcing.domain import DomainEvent, datetime_now_with_tzinfo
21
+ from eventsourcing.persistence import (
22
+ AggregateRecorder,
23
+ ApplicationRecorder,
24
+ DatetimeAsISO,
25
+ DecimalAsStr,
26
+ InfrastructureFactory,
27
+ IntegrityError,
28
+ JSONTranscoder,
29
+ Mapper,
30
+ Notification,
31
+ ProcessRecorder,
32
+ StoredEvent,
33
+ Tracking,
34
+ TrackingRecorder,
35
+ Transcoder,
36
+ Transcoding,
37
+ UUIDAsHex,
38
+ WaitInterruptedError,
39
+ )
40
+ from eventsourcing.utils import Environment, get_topic
41
+
42
+ if TYPE_CHECKING:
43
+ from collections.abc import Iterator, Sequence
44
+
45
+ from typing_extensions import Never
46
+
47
+
48
class RecorderTestCase(TestCase, ABC):
    """Common base for the recorder test cases defined in this module."""

    # Version number given to the first stored event of a sequence in the tests.
    INITIAL_VERSION = 1

    def new_originator_id(self) -> UUID | str:
        """Return a fresh originator ID (a random UUID)."""
        originator_id = uuid4()
        return originator_id
53
+
54
+
55
class AggregateRecorderTestCase(RecorderTestCase, ABC):
    """Reusable tests for an ``AggregateRecorder`` implementation.

    Concrete subclasses implement ``create_recorder()`` to supply the
    recorder under test.
    """

    @abstractmethod
    def create_recorder(self) -> AggregateRecorder:
        """Construct and return the aggregate recorder to be tested."""

    def test_insert_and_select(self) -> None:
        """Exercise ``insert_events()`` and ``select_events()`` together."""
        # Construct the recorder.
        recorder = self.create_recorder()

        # Check we can call insert_events() with an empty list.
        self.assertIsNone(recorder.insert_events([]))

        # Select stored events, expect empty list.
        originator_id1 = self.new_originator_id()
        self.assertEqual(
            recorder.select_events(originator_id1, desc=True, limit=1),
            [],
        )

        # Write a stored event.
        stored_event1 = StoredEvent(
            originator_id=originator_id1,
            originator_version=self.INITIAL_VERSION,
            topic="topic1",
            state=b"state1",
        )
        self.assertIsNone(recorder.insert_events([stored_event1]))

        # Select stored events, expect list of one.
        stored_events = recorder.select_events(originator_id1)
        self.assertEqual(len(stored_events), 1)
        self.assertEqual(stored_events[0].originator_id, originator_id1)
        self.assertEqual(stored_events[0].originator_version, self.INITIAL_VERSION)
        self.assertEqual(stored_events[0].topic, "topic1")
        self.assertEqual(stored_events[0].state, b"state1")
        self.assertIsInstance(stored_events[0].state, bytes)

        # Check get record conflict error if attempt to store it again.
        with self.assertRaises(IntegrityError):
            recorder.insert_events([stored_event1])

        # Check writing of events is atomic: a batch containing a conflicting
        # event must not leave any of its other events recorded.
        stored_event2 = StoredEvent(
            originator_id=originator_id1,
            originator_version=self.INITIAL_VERSION + 1,
            topic="topic2",
            state=b"state2",
        )
        with self.assertRaises(IntegrityError):
            recorder.insert_events([stored_event1, stored_event2])

        with self.assertRaises(IntegrityError):
            recorder.insert_events([stored_event2, stored_event2])

        # Check still only have one record.
        stored_events = recorder.select_events(originator_id1)
        self.assertEqual(len(stored_events), 1)
        self.assertEqual(stored_events[0].originator_id, stored_event1.originator_id)
        self.assertEqual(
            stored_events[0].originator_version, stored_event1.originator_version
        )
        self.assertEqual(stored_events[0].topic, stored_event1.topic)

        # Check can write two events together.
        stored_event3 = StoredEvent(
            originator_id=originator_id1,
            originator_version=self.INITIAL_VERSION + 2,
            topic="topic3",
            state=b"state3",
        )
        self.assertIsNone(recorder.insert_events([stored_event2, stored_event3]))

        # Check we got what was written: three events, in version order, with
        # topics "topic1".."topic3" and states b"state1"..b"state3".
        stored_events = recorder.select_events(originator_id1)
        self.assertEqual(len(stored_events), 3)
        for offset, stored_event in enumerate(stored_events):
            self.assertEqual(stored_event.originator_id, originator_id1)
            self.assertEqual(
                stored_event.originator_version, self.INITIAL_VERSION + offset
            )
            self.assertEqual(stored_event.topic, f"topic{offset + 1}")
            self.assertEqual(stored_event.state, f"state{offset + 1}".encode())

        # Check we can get the last one recorded (used to get last snapshot).
        stored_events = recorder.select_events(originator_id1, desc=True, limit=1)
        self.assertEqual(len(stored_events), 1)
        self.assertEqual(stored_events[0], stored_event3)

        # Check we can get the last one before a particular version.
        stored_events = recorder.select_events(
            originator_id1, lte=self.INITIAL_VERSION + 1, desc=True, limit=1
        )
        self.assertEqual(len(stored_events), 1)
        self.assertEqual(stored_events[0], stored_event2)

        # Check we can get events between versions (historical state with snapshot).
        stored_events = recorder.select_events(
            originator_id1, gt=self.INITIAL_VERSION, lte=self.INITIAL_VERSION + 1
        )
        self.assertEqual(len(stored_events), 1)
        self.assertEqual(stored_events[0], stored_event2)

        # Check aggregate sequences are distinguished.
        originator_id2 = self.new_originator_id()
        self.assertEqual(recorder.select_events(originator_id2), [])

        # Write a stored event in a different sequence.
        stored_event4 = StoredEvent(
            originator_id=originator_id2,
            originator_version=0,
            topic="topic4",
            state=b"state4",
        )
        recorder.insert_events([stored_event4])
        self.assertEqual(recorder.select_events(originator_id2), [stored_event4])

    def test_performance(self) -> None:
        """Time single-event inserts and print the measured rate."""
        # Construct the recorder.
        recorder = self.create_recorder()

        def insert() -> None:
            # Each call writes one event into a brand new sequence.
            stored_event = StoredEvent(
                originator_id=self.new_originator_id(),
                originator_version=self.INITIAL_VERSION,
                topic="topic1",
                state=b"state1",
            )
            recorder.insert_events([stored_event])

        # Warm up.
        timeit(insert, number=10)

        # Measure.
        number = 100
        duration = timeit(insert, number=number)
        print(
            self,
            f"\n{1000000 * duration / number:.1f} μs per insert, "
            f"{number / duration:.0f} inserts per second",
        )
221
+
222
+
223
# Type variable for the recorder class exercised by a test case; it is bound
# to ApplicationRecorder and defaults to ApplicationRecorder when unspecified.
_TApplicationRecorder = TypeVar(
    "_TApplicationRecorder",
    bound=ApplicationRecorder,
    default=ApplicationRecorder,
)
226
+
227
+
228
class ApplicationRecorderTestCase(
    RecorderTestCase, ABC, Generic[_TApplicationRecorder]
):
    """Reusable tests for an ``ApplicationRecorder`` implementation.

    Concrete subclasses implement ``create_recorder()`` to supply the
    recorder under test.
    """

    # When true, optional_test_insert_subscribe() also asserts the exact
    # notification IDs (1, 2, 3 offset by its 'initial_position' argument).
    EXPECT_CONTIGUOUS_NOTIFICATION_IDS = True

    @abstractmethod
    def create_recorder(self) -> _TApplicationRecorder:
        """Construct and return the application recorder to be tested."""

    def test_insert_select(self) -> None:
        """Exercise insert_events(), select_notifications() and max_notification_id()."""
        # Construct the recorder.
        recorder = self.create_recorder()

        # Check notifications methods work when there aren't any.
        self.assertEqual(len(recorder.select_notifications(start=None, limit=3)), 0)
        self.assertEqual(
            len(recorder.select_notifications(start=None, limit=3, topics=["topic1"])),
            0,
        )

        self.assertIsNone(recorder.max_notification_id())

        # Write two stored events.
        originator_id1 = self.new_originator_id()
        originator_id2 = self.new_originator_id()

        stored_event1 = StoredEvent(
            originator_id=originator_id1,
            originator_version=self.INITIAL_VERSION,
            topic="topic1",
            state=b"state1",
        )
        stored_event2 = StoredEvent(
            originator_id=originator_id1,
            originator_version=self.INITIAL_VERSION + 1,
            topic="topic2",
            state=b"state2",
        )

        notification_ids = recorder.insert_events([stored_event1, stored_event2])
        self.assertEqual(notification_ids, [1, 2])

        # Store a third event.
        stored_event3 = StoredEvent(
            originator_id=originator_id2,
            originator_version=self.INITIAL_VERSION,
            topic="topic3",
            state=b"state3",
        )
        notification_ids = recorder.insert_events([stored_event3])
        self.assertEqual(notification_ids, [3])

        stored_events1 = recorder.select_events(originator_id1)
        stored_events2 = recorder.select_events(originator_id2)

        # Check we got what was written.
        self.assertEqual(len(stored_events1), 2)
        self.assertEqual(len(stored_events2), 1)

        # Check get record conflict error if attempt to store it again.
        with self.assertRaises(IntegrityError):
            recorder.insert_events([stored_event3])

        # Expected (id, originator_id, topic, state) of the three notifications.
        first = (1, originator_id1, "topic1", b"state1")
        second = (2, originator_id1, "topic2", b"state2")
        third = (3, originator_id2, "topic3", b"state3")

        def check(
            notifications: Sequence[Notification],
            expected: list[tuple[int, UUID | str, str, bytes]],
        ) -> None:
            # Same number of notifications, and each one matches field by field.
            self.assertEqual(len(notifications), len(expected))
            for notification, (id_, originator_id, topic, state) in zip(
                notifications, expected
            ):
                self.assertEqual(notification.id, id_)
                self.assertEqual(notification.originator_id, originator_id)
                self.assertEqual(notification.topic, topic)
                self.assertEqual(notification.state, state)

        def check_ids(
            notifications: Sequence[Notification], expected_ids: list[int]
        ) -> None:
            # Same number of notifications, with the expected IDs in order.
            self.assertEqual([n.id for n in notifications], expected_ids)

        # sleep(1)  # Added to make eventsourcing-axon tests work.
        check(
            recorder.select_notifications(start=None, limit=10),
            [first, second, third],
        )
        check(
            recorder.select_notifications(start=1, limit=10),
            [first, second, third],
        )
        check(
            recorder.select_notifications(start=None, stop=2, limit=10),
            [first, second],
        )
        check(
            recorder.select_notifications(start=1, limit=10, inclusive_of_start=False),
            [second, third],
        )
        check(
            recorder.select_notifications(start=2, limit=10, inclusive_of_start=False),
            [third],
        )

        # Check filtering by topics.
        check(
            recorder.select_notifications(
                start=None, limit=10, topics=["topic1", "topic2", "topic3"]
            ),
            [first, second, third],
        )
        check(recorder.select_notifications(1, 10, topics=["topic1"]), [first])
        check(recorder.select_notifications(1, 3, topics=["topic2"]), [second])
        check(recorder.select_notifications(1, 3, topics=["topic3"]), [third])
        check(
            recorder.select_notifications(1, 3, topics=["topic1", "topic3"]),
            [first, third],
        )

        self.assertEqual(recorder.max_notification_id(), 3)

        # Check limit is working
        check_ids(recorder.select_notifications(None, 1), [1])
        check_ids(recorder.select_notifications(2, 1), [2])
        check_ids(recorder.select_notifications(1, 1, inclusive_of_start=False), [2])
        check_ids(recorder.select_notifications(2, 2), [2, 3])
        check_ids(recorder.select_notifications(3, 1), [3])
        check_ids(recorder.select_notifications(3, 1, inclusive_of_start=False), [])

        # Check stop is working.
        check_ids(recorder.select_notifications(start=2, limit=10, stop=2), [2])
        check_ids(recorder.select_notifications(start=1, limit=10, stop=2), [1, 2])
        check_ids(
            recorder.select_notifications(
                start=1, limit=10, stop=2, inclusive_of_start=False
            ),
            [2],
        )

    def test_performance(self) -> None:
        """Time single-event inserts and print the measured rate."""
        # Construct the recorder.
        recorder = self.create_recorder()

        def insert() -> None:
            # Each call writes one event into a brand new sequence.
            stored_event = StoredEvent(
                originator_id=self.new_originator_id(),
                originator_version=self.INITIAL_VERSION,
                topic="topic1",
                state=b"state1",
            )
            recorder.insert_events([stored_event])

        # Warm up.
        timeit(insert, number=10)

        # Measure.
        number = 100
        duration = timeit(insert, number=number)
        print(
            self,
            f"\n{1000000 * duration / number:.1f} μs per insert, "
            f"{number / duration:.0f} inserts per second",
        )

    def test_concurrent_no_conflicts(self, initial_position: int = 0) -> None:
        """Write from a thread pool while two threads read, expecting no errors.

        :param initial_position: passed as ``start`` (exclusive) to the
            ``select_notifications()`` calls made by the reader threads.
        :raises Exception: the first error raised by any writer or reader.
        """
        print(self)

        recorder = self.create_recorder()

        # Errors collected from writer jobs and reader threads.
        errors: list[Exception] = []

        # Per-thread bookkeeping, keyed by thread identifier.
        counts: dict[int, int] = {}
        threads: dict[int, int] = {}
        durations: dict[int, float] = {}

        num_writers = 10
        num_writes_per_writer = 100
        num_events_per_write = 100
        reader_sleep = 0.0001
        writer_sleep = 0.0001

        def insert_events() -> None:
            thread_id = get_ident()
            threads.setdefault(thread_id, len(threads))
            counts.setdefault(thread_id, 0)
            durations.setdefault(thread_id, 0)

            # Each job writes one batch of events into a new sequence.
            originator_id = self.new_originator_id()
            stored_events = [
                StoredEvent(
                    originator_id=originator_id,
                    originator_version=i,
                    topic="topic",
                    state=b"state",
                )
                for i in range(num_events_per_write)
            ]
            started = datetime_now_with_tzinfo()
            try:
                recorder.insert_events(stored_events)

            except Exception as e:  # pragma: no cover
                # Only the first error is reported and recorded.
                if errors:
                    return
                ended = datetime_now_with_tzinfo()
                duration = (ended - started).total_seconds()
                print(f"Error after starting {duration}", e)
                errors.append(e)
            else:
                ended = datetime_now_with_tzinfo()
                duration = (ended - started).total_seconds()
                counts[thread_id] += 1
                durations[thread_id] = max(durations[thread_id], duration)
                sleep(writer_sleep)

        stop_reading = Event()

        def read_continuously() -> None:
            while not stop_reading.is_set():
                try:
                    recorder.select_notifications(
                        start=initial_position, limit=10, inclusive_of_start=False
                    )
                except Exception as e:  # pragma: no cover
                    errors.append(e)
                    return
                sleep(reader_sleep)

        reader_threads = [Thread(target=read_continuously) for _ in range(2)]
        for reader_thread in reader_threads:
            reader_thread.start()

        try:
            # NOTE: 'num_writes_per_writer' jobs are submitted in total and
            # shared between the 'num_writers' pool threads.
            with ThreadPoolExecutor(max_workers=num_writers) as executor:
                futures = []
                for _ in range(num_writes_per_writer):
                    if errors:  # pragma: no cover
                        break
                    futures.append(executor.submit(insert_events))
                for future in futures:
                    if errors:  # pragma: no cover
                        break
                    try:
                        future.result()
                    except Exception as e:  # pragma: no cover
                        errors.append(e)
                        break
        finally:
            # Always stop and wait for the readers, so they cannot outlive the
            # test and so any error they recorded is seen below.
            stop_reading.set()
            for reader_thread in reader_threads:
                reader_thread.join()

        if errors:  # pragma: no cover
            raise errors[0]

        for thread_id, thread_num in threads.items():
            count = counts[thread_id]
            duration = durations[thread_id]
            print(f"Thread {thread_num} wrote {count} times (max dur {duration})")

    def test_concurrent_throughput(self) -> None:
        """Insert batches from a thread pool and print the overall insert rate."""
        print(self)

        recorder = self.create_recorder()

        # Set by any job whose insert raises.
        errors_happened = Event()

        # Match this to the batch page size in postgres insert for max throughput.
        num_events_per_job = 50
        num_jobs = 10
        num_workers = 4

        def insert_events() -> None:
            originator_id = self.new_originator_id()
            stored_events = [
                StoredEvent(
                    originator_id=originator_id,
                    originator_version=i,
                    topic="topic",
                    state=b"state",
                )
                for i in range(num_events_per_job)
            ]

            try:
                recorder.insert_events(stored_events)

            except Exception:  # pragma: no cover
                errors_happened.set()
                tb = traceback.format_exc()
                print(tb)

        # Warm up.
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(insert_events) for _ in range(num_workers)]
            for future in futures:
                future.result()

        # Run.
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            started = datetime_now_with_tzinfo()
            futures = [executor.submit(insert_events) for _ in range(num_jobs)]
            for future in futures:
                future.result()
        ended = datetime_now_with_tzinfo()

        self.assertFalse(errors_happened.is_set(), "There were errors (see above)")
        rate = num_jobs * num_events_per_job / (ended - started).total_seconds()
        print(f"Rate: {rate:.0f} inserts per second")

    def optional_test_insert_subscribe(self, initial_position: int = 0) -> None:
        """Exercise ``subscribe()`` alongside ``insert_events()``.

        The name does not start with ``test``, so this is presumably meant to
        be called explicitly by subclasses whose recorder supports
        subscriptions (NOTE(review): confirm against callers).

        :param initial_position: position passed as ``gt`` to some of the
            subscriptions, and the offset added to the expected notification
            IDs when ``EXPECT_CONTIGUOUS_NOTIFICATION_IDS`` is true.
        """
        recorder = self.create_recorder()

        # Get the max notification ID.
        max_notification_id1 = recorder.max_notification_id()

        # Write two stored events.
        originator_id1 = self.new_originator_id()
        originator_id2 = self.new_originator_id()

        stored_event1 = StoredEvent(
            originator_id=originator_id1,
            originator_version=self.INITIAL_VERSION,
            topic="topic1",
            state=b"state1",
        )
        stored_event2 = StoredEvent(
            originator_id=originator_id1,
            originator_version=self.INITIAL_VERSION + 1,
            topic="topic2",
            state=b"state2",
        )

        notification_ids = recorder.insert_events([stored_event1, stored_event2])
        if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
            self.assertEqual(
                notification_ids, [1 + initial_position, 2 + initial_position]
            )

        # Get the max notification ID.
        max_notification_id2 = recorder.max_notification_id()

        # Start a subscription after the initial position.
        with recorder.subscribe(gt=initial_position) as subscription:

            # Receive one event from the subscription.
            for _ in subscription:
                break

        # Start another subscription after the initial position.
        with recorder.subscribe(gt=initial_position) as subscription:

            # Receive one event from the subscription.
            for _ in subscription:
                break

        # Start a subscription after the position seen before the inserts.
        with recorder.subscribe(gt=max_notification_id1) as subscription:

            # Receive events from the subscription.
            notifications: list[Notification] = []
            for notification in subscription:
                notifications.append(notification)
                if len(notifications) == 2:
                    break

            # Check the events we received are the ones that were written.
            self.assertEqual(
                stored_event1.originator_id, notifications[0].originator_id
            )
            self.assertEqual(
                stored_event1.originator_version, notifications[0].originator_version
            )
            self.assertEqual(
                stored_event2.originator_id, notifications[1].originator_id
            )
            self.assertEqual(
                stored_event2.originator_version, notifications[1].originator_version
            )
            if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
                self.assertEqual(1 + initial_position, notifications[0].id)
                self.assertEqual(2 + initial_position, notifications[1].id)

            # Store a third event.
            stored_event3 = StoredEvent(
                originator_id=originator_id2,
                originator_version=self.INITIAL_VERSION,
                topic="topic3",
                state=b"state3",
            )
            notification_ids = recorder.insert_events([stored_event3])
            if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
                self.assertEqual(notification_ids, [3 + initial_position])

            # Receive the new event from the same subscription.
            for notification in subscription:
                notifications.append(notification)
                if len(notifications) == 3:
                    break

            # Check the events we received are the ones that were written.
            self.assertEqual(
                stored_event3.originator_id, notifications[2].originator_id
            )
            self.assertEqual(
                stored_event3.originator_version, notifications[2].originator_version
            )
            if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
                self.assertEqual(3 + initial_position, notifications[2].id)

        # Start a subscription after the position seen after the first inserts.
        with recorder.subscribe(gt=max_notification_id2) as subscription:

            # Receive events from the subscription.
            notifications = []
            for notification in subscription:
                notifications.append(notification)
                if len(notifications) == 1:
                    break

            # Check the events we received are the ones that were written.
            self.assertEqual(
                stored_event3.originator_id, notifications[0].originator_id
            )

        # Start a subscription, call stop() during iteration.
        with recorder.subscribe(gt=initial_position) as subscription:

            # Receive events from the subscription.
            for i, _ in enumerate(subscription):
                subscription.stop()
                # Shouldn't get here twice...
                self.assertLess(i, 1, "Got here twice")

        # Start a subscription, call stop() before iteration.
        subscription = recorder.subscribe(gt=None)
        with subscription:
            subscription.stop()
            # Receive events from the subscription.
            for _ in subscription:
                # Shouldn't get here...
                self.fail("Got here")

        # Start a subscription, call stop() before entering context manager.
        subscription = recorder.subscribe(gt=None)
        subscription.stop()
        with subscription:
            # Receive events from the subscription.
            for _ in subscription:
                # Shouldn't get here...
                self.fail("Got here")

        # Start a subscription with topics.
        subscription = recorder.subscribe(gt=None, topics=["topic3"])
        with subscription:
            for notification in subscription:
                self.assertEqual(notification.topic, "topic3")
                if (
                    notification.originator_id == stored_event3.originator_id
                    and notification.originator_version
                    == stored_event3.originator_version
                ):
                    break

    def close_db_connection(self, *args: Any) -> None:
        """Do nothing; accepts any positional arguments.

        NOTE(review): presumably a hook for subclasses to override — confirm.
        """
791
+
792
+
793
class TrackingRecorderTestCase(TestCase, ABC):
    """Reusable tests for a ``TrackingRecorder`` implementation.

    Concrete subclasses implement ``create_recorder()`` to supply the
    recorder under test.
    """

    @abstractmethod
    def create_recorder(self) -> TrackingRecorder:
        """Construct and return the tracking recorder to be tested."""

    def test_insert_tracking(self) -> None:
        """Exercise insert_tracking(), max_tracking_id() and has_tracking_id()."""
        recorder = self.create_recorder()

        # Construct tracking objects.
        initial_trackings = [
            Tracking("upstream1", 21),
            Tracking("upstream1", 22),
            Tracking("upstream2", 21),
        ]

        # Insert tracking objects.
        for tracking in initial_trackings:
            recorder.insert_tracking(tracking=tracking)

        # Fail to insert same tracking object twice.
        for tracking in initial_trackings:
            with self.assertRaises(IntegrityError):
                recorder.insert_tracking(tracking=tracking)

        # Get max tracking ID.
        self.assertEqual(recorder.max_tracking_id("upstream1"), 22)
        self.assertEqual(recorder.max_tracking_id("upstream2"), 21)
        self.assertIsNone(recorder.max_tracking_id("upstream3"))

        # Check if an event notification has been processed.
        # Each row is (application name, notification ID, expected truthiness).
        expectations = [
            ("upstream1", None, True),
            ("upstream1", 20, True),
            ("upstream1", 21, True),
            ("upstream1", 22, True),
            ("upstream1", 23, False),
            ("upstream2", None, True),
            ("upstream2", 20, True),
            ("upstream2", 21, True),
            ("upstream2", 22, False),
            # The next two repeat earlier checks.
            ("upstream2", None, True),
            ("upstream2", 22, False),
        ]
        for application_name, notification_id, expected in expectations:
            has_id = recorder.has_tracking_id(application_name, notification_id)
            if expected:
                self.assertTrue(has_id)
            else:
                self.assertFalse(has_id)

        # Construct more tracking objects.
        skipped = Tracking("upstream1", 23)
        latest = Tracking("upstream1", 24)

        recorder.insert_tracking(latest)

        # Can't fill in the gap.
        with self.assertRaises(IntegrityError):
            recorder.insert_tracking(skipped)

        self.assertTrue(recorder.has_tracking_id("upstream1", 23))

    def test_wait(self) -> None:
        """Exercise wait() before and after a tracking record is inserted."""
        recorder = self.create_recorder()

        # Waiting for None returns even though nothing has been recorded.
        recorder.wait("upstream1", None)

        with self.assertRaises(TimeoutError):
            recorder.wait("upstream1", 21, timeout=0.1)

        recorded = Tracking(notification_id=21, application_name="upstream1")
        recorder.insert_tracking(tracking=recorded)

        # Positions up to and including the recorded one return.
        for notification_id in (None, 10, 21):
            recorder.wait("upstream1", notification_id)

        with self.assertRaises(TimeoutError):
            recorder.wait("upstream1", 22, timeout=0.1)

        # An interrupt event that is already set interrupts the wait.
        interrupt = Event()
        interrupt.set()
        with self.assertRaises(WaitInterruptedError):
            recorder.wait("upstream1", 22, interrupt=interrupt)
872
+
873
+
874
class ProcessRecorderTestCase(RecorderTestCase, ABC):
    """Reusable tests for process recorders.

    Subclasses implement :meth:`create_recorder` to supply the recorder
    under test.
    """

    @abstractmethod
    def create_recorder(self) -> ProcessRecorder:
        """Return the process recorder to be tested."""

    def test_insert_select(self) -> None:
        """Check inserting events with tracking and the resulting position."""
        upstream = "upstream_app"
        recorder = self.create_recorder()

        # No position has been recorded yet.
        self.assertIsNone(recorder.max_tracking_id(upstream))

        # Prepare four stored events, two for each of two originators.
        first_id = self.new_originator_id()
        second_id = self.new_originator_id()
        event1 = StoredEvent(
            originator_id=first_id,
            originator_version=1,
            topic="topic1",
            state=b"state1",
        )
        event2 = StoredEvent(
            originator_id=first_id,
            originator_version=2,
            topic="topic2",
            state=b"state2",
        )
        event3 = StoredEvent(
            originator_id=second_id,
            originator_version=1,
            topic="topic3",
            state=b"state3",
        )
        event4 = StoredEvent(
            originator_id=second_id,
            originator_version=2,
            topic="topic4",
            state=b"state4",
        )
        position1 = Tracking(application_name=upstream, notification_id=1)
        position2 = Tracking(application_name=upstream, notification_id=2)

        # Insert the first two events together with the first position.
        recorder.insert_events(
            stored_events=[event1, event2],
            tracking=position1,
        )

        # Storing an already-stored event again is rejected...
        with self.assertRaises(IntegrityError):
            recorder.insert_events(stored_events=[event2], tracking=position2)

        # ...and the failed insert leaves the position unchanged.
        self.assertEqual(recorder.max_tracking_id(upstream), 1)

        # Reusing an already-recorded position is rejected too.
        with self.assertRaises(IntegrityError):
            recorder.insert_events(stored_events=[event3], tracking=position1)

        self.assertEqual(recorder.max_tracking_id(upstream), 1)

        # A new event with a new position is accepted and advances it.
        recorder.insert_events(stored_events=[event3], tracking=position2)
        self.assertEqual(recorder.max_tracking_id(upstream), 2)

        # Inserting without tracking leaves the position where it was.
        recorder.insert_events(stored_events=[event4])
        self.assertEqual(recorder.max_tracking_id(upstream), 2)

    def test_has_tracking_id(self) -> None:
        """Check ``has_tracking_id()`` before and after recording positions."""
        upstream = "upstream_app"
        recorder = self.create_recorder()

        # Fresh recorder: only the "no position" query is satisfied.
        self.assertTrue(recorder.has_tracking_id(upstream, None))
        for notification_id in (1, 2, 3, 4):
            self.assertFalse(recorder.has_tracking_id(upstream, notification_id))

        position1 = Tracking(application_name=upstream, notification_id=1)
        position3 = Tracking(application_name=upstream, notification_id=3)

        # Record position 1 (no events needed).
        recorder.insert_events(stored_events=[], tracking=position1)

        self.assertTrue(recorder.has_tracking_id(upstream, 1))
        for notification_id in (2, 3, 4):
            self.assertFalse(recorder.has_tracking_id(upstream, notification_id))

        # Record position 3, skipping 2.
        recorder.insert_events(stored_events=[], tracking=position3)

        # Everything up to the highest recorded position is reported as tracked.
        for notification_id in (1, 2, 3):
            self.assertTrue(recorder.has_tracking_id(upstream, notification_id))
        self.assertFalse(recorder.has_tracking_id(upstream, 4))

    def test_raises_when_lower_inserted_later(self) -> None:
        """Check a position lower than the recorded maximum is rejected."""
        upstream = "upstream_app"
        recorder = self.create_recorder()

        lower = Tracking(application_name=upstream, notification_id=1)
        higher = Tracking(application_name=upstream, notification_id=2)

        # Record the higher position first.
        recorder.insert_events(stored_events=[], tracking=higher)
        self.assertEqual(recorder.max_tracking_id(upstream), 2)

        # Recording the lower position afterwards must fail...
        with self.assertRaises(IntegrityError):
            recorder.insert_events(stored_events=[], tracking=lower)

        # ...without changing the recorded maximum.
        self.assertEqual(recorder.max_tracking_id(upstream), 2)

    def test_performance(self) -> None:
        """Time inserts of one event plus tracking and print the rate."""
        recorder = self.create_recorder()

        number = 100

        # Each timed call consumes the next notification ID, 1..number.
        pending_ids = iter(range(1, number + 1))

        def insert_events() -> None:
            event = StoredEvent(
                originator_id=self.new_originator_id(),
                originator_version=0,
                topic="topic1",
                state=b"state1",
            )
            position = Tracking(
                application_name="upstream_app",
                notification_id=next(pending_ids),
            )
            recorder.insert_events(stored_events=[event], tracking=position)

        duration = timeit(insert_events, number=number)
        print(
            f"\n{self}",
            f"{1000000 * duration / number:.1f} μs per insert, "
            f"{number / duration:.0f} inserts per second",
        )
1093
+
1094
+
1095
class NonInterleavingNotificationIDsBaseCase(RecorderTestCase, ABC):
    """Check notification IDs of concurrently inserted batches don't interleave.

    Subclasses implement :meth:`create_recorder`; ``insert_num`` is the
    number of events in each of the two competing batches.
    """

    insert_num = 1000

    def test(self) -> None:
        """Insert two batches from racing threads and compare their ID ranges."""
        recorder = self.create_recorder()

        # Remember where the notification log ends before the race.
        start_after = recorder.max_notification_id()

        starting_gun = Event()

        first_id = self.new_originator_id()
        second_id = self.new_originator_id()

        first_batch = self.create_stack(first_id)
        second_batch = self.create_stack(second_id)

        errors = []

        def insert_stack(stack: Sequence[StoredEvent]) -> None:
            # Both threads block here until the main thread releases them.
            try:
                starting_gun.wait()
                recorder.insert_events(stack)
            except Exception as e:
                errors.append(e)

        racers = [
            Thread(target=insert_stack, args=(batch,), daemon=True)
            for batch in (first_batch, second_batch)
        ]
        for racer in racers:
            racer.start()

        starting_gun.set()

        for racer in racers:
            racer.join()

        # Surface the first failure raised inside a worker thread.
        if errors:
            raise errors[0]

        # NOTE: a sleep(1) was once placed here to make the
        # eventsourcing-axon tests work; it is currently disabled.
        notifications = recorder.select_notifications(
            start=start_after,
            limit=2 * self.insert_num,
            inclusive_of_start=False,
        )
        first_ids = [n.id for n in notifications if n.originator_id == first_id]
        second_ids = [n.id for n in notifications if n.originator_id == second_id]
        self.assertEqual(self.insert_num, len(first_ids))
        self.assertEqual(self.insert_num, len(second_ids))

        first_max = max(first_ids)
        second_max = max(second_ids)
        first_min = min(first_ids)
        second_min = min(second_ids)

        # Whichever batch finished last, its whole ID range must lie
        # strictly above the other batch's range.
        if first_max <= second_min:
            self.assertGreater(second_min, first_max)
        else:
            self.assertGreater(first_min, second_max)

    def create_stack(self, originator_id: UUID | str) -> Sequence[StoredEvent]:
        """Return ``insert_num`` stored events for ``originator_id``.

        Versions run from ``0`` to ``insert_num - 1``.
        """
        stack = []
        for version in range(self.insert_num):
            stack.append(
                StoredEvent(
                    originator_id=originator_id,
                    originator_version=version,
                    topic="non-interleaving-test-event",
                    state=b"{}",
                )
            )
        return stack

    @abstractmethod
    def create_recorder(self) -> ApplicationRecorder:
        """Return the application recorder to be tested."""
1173
+
1174
+
1175
# Type of the concrete infrastructure factory a test case is written for.
_TInfrastrutureFactory = TypeVar(
    "_TInfrastrutureFactory", bound=InfrastructureFactory[Any]
)


class InfrastructureFactoryTestCase(ABC, TestCase, Generic[_TInfrastrutureFactory]):
    """Reusable tests for infrastructure factories.

    Subclasses provide ``env`` and implement the abstract hooks below to
    state which classes the factory is expected to produce.
    """

    # Environment the factory is constructed from in setUp().
    env: Environment

    @abstractmethod
    def expected_factory_class(self) -> type[_TInfrastrutureFactory]:
        """Return the factory class expected to be constructed from ``env``."""

    @abstractmethod
    def expected_aggregate_recorder_class(self) -> type[AggregateRecorder]:
        """Return the expected type of ``factory.aggregate_recorder()``."""

    @abstractmethod
    def expected_application_recorder_class(self) -> type[ApplicationRecorder]:
        """Return the expected type of ``factory.application_recorder()``."""

    @abstractmethod
    def expected_tracking_recorder_class(self) -> type[TrackingRecorder]:
        """Return the expected type of ``factory.tracking_recorder()``."""

    @abstractmethod
    def tracking_recorder_subclass(self) -> type[TrackingRecorder]:
        """Return a tracking recorder class to request explicitly."""

    @abstractmethod
    def expected_process_recorder_class(self) -> type[ProcessRecorder]:
        """Return the expected type of ``factory.process_recorder()``."""

    def setUp(self) -> None:
        """Construct the factory from ``env`` and a JSON transcoder."""
        constructed = InfrastructureFactory.construct(self.env)
        self.factory = cast(_TInfrastrutureFactory, constructed)
        self.assertIsInstance(self.factory, self.expected_factory_class())
        self.transcoder = JSONTranscoder()
        for transcoding_class in (UUIDAsHex, DecimalAsStr, DatetimeAsISO):
            self.transcoder.register(transcoding_class())

    def tearDown(self) -> None:
        """Close the factory constructed in setUp()."""
        self.factory.close()

    def test_mapper(self) -> None:
        """Check the default mapper has neither cipher nor compressor."""
        # Design notes.
        #
        # Things to construct: application recorder, snapshot recorder,
        # mapper, event store, snapshot store.
        #
        # Things to make configurable: cipher (and cipher key), compressor,
        # application recorder class (and db uri, and session), snapshot
        # recorder class (and db uri, and session).
        #
        # Common environment: factory topic, cipher topic, cipher key,
        # compressor topic.
        #
        # POPO environment: nothing further.
        #
        # SQLite environment: database topic, table name for stored events,
        # table name for snapshots.

        # Create mapper.
        mapper: Mapper[UUID] = self.factory.mapper(
            transcoder=self.transcoder,
        )
        self.assertIsInstance(mapper, Mapper)
        self.assertIsNone(mapper.cipher)
        self.assertIsNone(mapper.compressor)

    def test_mapper_with_compressor(self) -> None:
        """Check the compressor topic may name a class or a module."""
        topic_key = self.factory.COMPRESSOR_TOPIC

        # Compressor given as the topic of a class.
        self.env[topic_key] = get_topic(ZlibCompressor)
        mapper: Mapper[UUID] = self.factory.mapper(transcoder=self.transcoder)
        self.assertIsInstance(mapper, Mapper)
        self.assertIsInstance(mapper.compressor, ZlibCompressor)
        self.assertIsNone(mapper.cipher)

        # Compressor given as the topic of a module.
        self.env[topic_key] = "zlib"
        mapper = self.factory.mapper(transcoder=self.transcoder)
        self.assertIsInstance(mapper, Mapper)
        self.assertEqual(mapper.compressor, zlib)
        self.assertIsNone(mapper.cipher)

    def test_mapper_with_cipher(self) -> None:
        """Check a cipher needs a key, and a key alone selects a cipher."""
        topic_key = self.factory.CIPHER_TOPIC

        # A cipher topic without a cipher key is an error.
        self.env[topic_key] = get_topic(AESCipher)
        with self.assertRaises(EnvironmentError):
            self.factory.mapper(transcoder=self.transcoder)

        # Setting a key but no topic defers to AES.
        del self.env[topic_key]
        self.env[AESCipher.CIPHER_KEY] = AESCipher.create_key(16)

        mapper: Mapper[UUID] = self.factory.mapper(transcoder=self.transcoder)
        self.assertIsInstance(mapper, Mapper)
        self.assertIsNotNone(mapper.cipher)
        self.assertIsNone(mapper.compressor)

    def test_mapper_with_cipher_and_compressor(
        self,
    ) -> None:
        """Check cipher and compressor can be configured together."""
        self.env[self.factory.COMPRESSOR_TOPIC] = get_topic(ZlibCompressor)

        self.env[self.factory.CIPHER_TOPIC] = get_topic(AESCipher)
        self.env[AESCipher.CIPHER_KEY] = AESCipher.create_key(16)

        mapper: Mapper[UUID] = self.factory.mapper(transcoder=self.transcoder)
        self.assertIsInstance(mapper, Mapper)
        self.assertIsNotNone(mapper.cipher)
        self.assertIsNotNone(mapper.compressor)

    def test_mapper_with_wrong_cipher_key(self) -> None:
        """Check an event encrypted for one app can't be read by another."""
        self.env.name = "App1"
        self.env[self.factory.CIPHER_TOPIC] = get_topic(AESCipher)

        # Give each application name its own cipher key.
        app1_key = AESCipher.create_key(16)
        app2_key = AESCipher.create_key(16)
        self.env["APP1_" + AESCipher.CIPHER_KEY] = app1_key
        self.env["APP2_" + AESCipher.CIPHER_KEY] = app2_key

        app1_mapper: Mapper[UUID] = self.factory.mapper(
            transcoder=self.transcoder,
        )

        # Round-trip an event through the first application's mapper.
        original = DomainEvent(
            originator_id=uuid4(),
            originator_version=1,
            timestamp=DomainEvent.create_timestamp(),
        )
        stored = app1_mapper.to_stored_event(original)
        restored = app1_mapper.to_domain_event(stored)
        self.assertEqual(original.originator_id, restored.originator_id)

        self.env.name = "App2"
        app2_mapper: Mapper[UUID] = self.factory.mapper(
            transcoder=self.transcoder,
        )
        # This should fail because the infrastructure factory
        # should read different cipher keys from the environment.
        with self.assertRaises(ValueError):
            app2_mapper.to_domain_event(stored)

    def test_create_aggregate_recorder(self) -> None:
        """Check the type of recorder returned by ``aggregate_recorder()``."""
        created = self.factory.aggregate_recorder()
        self.assertEqual(type(created), self.expected_aggregate_recorder_class())

        self.assertIsInstance(created, AggregateRecorder)

        # Exercise code path where table is not created.
        self.env["CREATE_TABLE"] = "f"
        created = self.factory.aggregate_recorder()
        self.assertEqual(type(created), self.expected_aggregate_recorder_class())

    def test_create_application_recorder(self) -> None:
        """Check the type of recorder returned by ``application_recorder()``."""
        created = self.factory.application_recorder()
        self.assertEqual(type(created), self.expected_application_recorder_class())
        self.assertIsInstance(created, ApplicationRecorder)

        # Exercise code path where table is not created.
        self.env["CREATE_TABLE"] = "f"
        created = self.factory.application_recorder()
        self.assertEqual(type(created), self.expected_application_recorder_class())

    def test_create_tracking_recorder(self) -> None:
        """Check ``tracking_recorder()`` by default, by argument and by topic."""
        created = self.factory.tracking_recorder()
        self.assertEqual(type(created), self.expected_tracking_recorder_class())
        self.assertIsInstance(created, TrackingRecorder)

        # Exercise code path where table is not created.
        self.env["CREATE_TABLE"] = "f"
        created = self.factory.tracking_recorder()
        self.assertEqual(type(created), self.expected_tracking_recorder_class())

        # Exercise code path where tracking recorder class is specified as arg.
        requested_class = self.tracking_recorder_subclass()
        created = self.factory.tracking_recorder(requested_class)
        self.assertEqual(type(created), requested_class)

        # Exercise code path where tracking recorder class is specified as topic.
        self.factory.env[self.factory.TRACKING_RECORDER_TOPIC] = get_topic(
            requested_class
        )
        created = self.factory.tracking_recorder()
        self.assertEqual(type(created), requested_class)

    def test_create_process_recorder(self) -> None:
        """Check the type of recorder returned by ``process_recorder()``."""
        created = self.factory.process_recorder()
        self.assertEqual(type(created), self.expected_process_recorder_class())
        self.assertIsInstance(created, ProcessRecorder)

        # Exercise code path where table is not created.
        self.env["CREATE_TABLE"] = "f"
        created = self.factory.process_recorder()
        self.assertEqual(type(created), self.expected_process_recorder_class())
1385
+
1386
+
1387
def tmpfile_uris() -> Iterator[str]:
    """Yield an endless series of ``file:`` URIs naming temporary files.

    Each URI names a freshly created temporary file whose name ends with
    ``_eventsourcing_test.db``. The file exists only while the generator
    is suspended on that URI: asking for the next URI, or closing the
    generator, closes and deletes it.

    If the directory ``/Volumes/RAM DISK/`` exists, the files are created
    inside it; otherwise the default temporary directory is used.
    """
    ram_disk_path = Path("/Volumes/RAM DISK/")
    # The directory must be passed as ``dir``. Passing it as ``prefix``
    # would glue it onto the random file name and create the file next to
    # the RAM disk ("/Volumes/RAM DISK<random>...") rather than inside it.
    tmp_dir: str | None = None
    if ram_disk_path.exists():
        tmp_dir = str(ram_disk_path)
    while True:
        with NamedTemporaryFile(
            dir=tmp_dir,
            suffix="_eventsourcing_test.db",
        ) as tmp_file:
            yield "file:" + tmp_file.name
1400
+
1401
+
1402
class CustomType1:
    """Value wrapper around a UUID, compared by exact type and attributes."""

    def __init__(self, value: UUID):
        self.value = value

    def __eq__(self, other: object) -> bool:
        # Only instances of exactly the same class can be equal.
        if type(other) is not type(self):
            return False
        return vars(self) == vars(other)

    def __hash__(self) -> int:
        # Hashing is deliberately unsupported.
        raise NotImplementedError
1411
+
1412
+
1413
class CustomType2:
    """Value wrapper around a ``CustomType1``, compared by type and attributes."""

    def __init__(self, value: CustomType1):
        self.value = value

    def __eq__(self, other: object) -> bool:
        # Only instances of exactly the same class can be equal.
        if type(other) is not type(self):
            return False
        return vars(self) == vars(other)

    def __hash__(self) -> int:
        # Hashing is deliberately unsupported.
        raise NotImplementedError
1422
+
1423
+
1424
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")


class Mydict(dict[_KT, _VT]):  # noqa: PLW1641
    """Dict subclass that is only equal to instances of its own exact type."""

    def __repr__(self) -> str:
        plain_repr = super().__repr__()
        return f"{type(self).__name__}({plain_repr})"

    def __eq__(self, other: object) -> bool:
        # A plain dict with the same items is not considered equal.
        if type(other) is not type(self):
            return False
        return super().__eq__(other)
1434
+
1435
+
1436
_T = TypeVar("_T")


class MyList(list[_T]):  # noqa: PLW1641
    """List subclass that is only equal to instances of its own exact type."""

    def __repr__(self) -> str:
        plain_repr = super().__repr__()
        return f"{type(self).__name__}({plain_repr})"

    def __eq__(self, other: object) -> bool:
        # A plain list with the same items is not considered equal.
        if type(other) is not type(self):
            return False
        return super().__eq__(other)
1445
+
1446
+
1447
class MyStr(str):
    """Str subclass that is only equal to instances of its own exact type."""

    __slots__ = ()

    def __repr__(self) -> str:
        plain_repr = super().__repr__()
        return f"{type(self).__name__}({plain_repr})"

    def __eq__(self, other: object) -> bool:
        # A plain str with the same characters is not considered equal.
        if type(other) is not type(self):
            return False
        return super().__eq__(other)

    def __hash__(self) -> int:
        # Hash like the equivalent plain string.
        plain = str(self)
        return hash(plain)
1458
+
1459
+
1460
class MyInt(int):
    """Int subclass that is only equal to instances of its own exact type."""

    def __repr__(self) -> str:
        plain_repr = super().__repr__()
        return f"{type(self).__name__}({plain_repr})"

    def __eq__(self, other: object) -> bool:
        # A plain int with the same value is not considered equal.
        if type(other) is not type(self):
            return False
        return super().__eq__(other)

    def __hash__(self) -> int:
        # Use the plain integer value as the hash.
        return int(self)
1469
+
1470
+
1471
class MyClass:
    """Empty class, used as an object for which no transcoding is registered."""
1473
+
1474
+
1475
class CustomType1AsDict(Transcoding):
    """Transcoding that represents a ``CustomType1`` by its UUID value."""

    type = CustomType1
    name = "custom_type1_as_dict"

    def encode(self, obj: CustomType1) -> UUID:
        """Return the UUID wrapped by ``obj``."""
        wrapped = obj.value
        return wrapped

    def decode(self, data: UUID) -> CustomType1:
        """Return a ``CustomType1`` wrapping ``data``, which must be a UUID."""
        assert isinstance(data, UUID)
        return CustomType1(data)
1485
+
1486
+
1487
class CustomType2AsDict(Transcoding):
    """Transcoding that represents a ``CustomType2`` by its ``CustomType1``."""

    type = CustomType2
    name = "custom_type2_as_dict"

    def encode(self, obj: CustomType2) -> CustomType1:
        """Return the ``CustomType1`` wrapped by ``obj``."""
        wrapped = obj.value
        return wrapped

    def decode(self, data: CustomType1) -> CustomType2:
        """Return a ``CustomType2`` wrapping ``data``, a ``CustomType1``."""
        assert isinstance(data, CustomType1)
        return CustomType2(value=data)
1497
+
1498
+
1499
class TranscoderTestCase(TestCase):
    """Reusable tests for transcoders.

    Subclasses override :meth:`construct_transcoder` to supply the
    transcoder under test.
    """

    def setUp(self) -> None:
        """Construct the transcoder used by each test."""
        self.transcoder = self.construct_transcoder()

    def construct_transcoder(self) -> Transcoder:
        """Return the transcoder under test; must be overridden."""
        raise NotImplementedError

    def test_str(self) -> None:
        """Check strings encode to the expected bytes and decode back."""
        expectations = (
            ("a", b'"a"'),
            ("abc", b'"abc"'),
            ("a'b", b'''"a'b"'''),
            ('a"b', b'''"a\\"b"'''),
        )
        for text, expected in expectations:
            encoded = self.transcoder.encode(text)
            self.assertEqual(encoded, expected)
            self.assertEqual(text, self.transcoder.decode(encoded))

        # Non-ASCII text is expected as raw UTF-8 bytes.
        unicode_text = "🐈 哈哈"
        encoded = self.transcoder.encode(unicode_text)
        self.assertEqual(b'"\xf0\x9f\x90\x88 \xe5\x93\x88\xe5\x93\x88"', encoded)
        self.assertEqual(unicode_text, self.transcoder.decode(encoded))

        # Check data encoded with ensure_ascii=True can be decoded okay.
        legacy_encoding_with_ensure_ascii_true = b'"\\ud83d\\udc08 \\u54c8\\u54c8"'
        self.assertEqual(
            unicode_text,
            self.transcoder.decode(legacy_encoding_with_ensure_ascii_true),
        )

    def test_dict(self) -> None:
        """Check plain dicts encode to compact JSON and decode back."""
        expectations: list[tuple[dict[str, Any], bytes]] = [
            # Empty dict.
            ({}, b"{}"),
            # dict with single key.
            ({"a": 1}, b'{"a":1}'),
            # dict with many keys.
            ({"a": 1, "b": 2}, b'{"a":1,"b":2}'),
            # Empty dict in dict.
            ({"a": {}}, b'{"a":{}}'),
            # Empty dicts in dict.
            ({"a": {}, "b": {}}, b'{"a":{},"b":{}}'),
            # Empty dict in dict in dict.
            ({"a": {"b": {}}}, b'{"a":{"b":{}}}'),
            # Int in dict in dict in dict.
            ({"a": {"b": {"c": 1}}}, b'{"a":{"b":{"c":1}}}'),
        ]
        for original, expected in expectations:
            encoded = self.transcoder.encode(original)
            self.assertEqual(encoded, expected)
            self.assertEqual(original, self.transcoder.decode(encoded))

        # TODO: Int keys? (e.g. round-tripping {1: "a"}) are not covered yet.

    def test_dict_with_len_2_and__data_(self) -> None:
        """Check a two-key dict containing ``_data_`` round-trips unchanged."""
        original = {"_data_": 1, "something_else": 2}
        encoded = self.transcoder.encode(original)
        self.assertEqual(original, self.transcoder.decode(encoded))

    def test_dict_with_len_2_and__type_(self) -> None:
        """Check a two-key dict containing ``_type_`` round-trips unchanged."""
        original = {"_type_": 1, "something_else": 2}
        encoded = self.transcoder.encode(original)
        self.assertEqual(original, self.transcoder.decode(encoded))

    def test_dict_subclass(self) -> None:
        """Check a dict subclass is tagged with its type and round-trips."""
        original = Mydict({"a": 1})
        encoded = self.transcoder.encode(original)
        self.assertEqual(b'{"_type_":"mydict","_data_":{"a":1}}', encoded)
        restored = self.transcoder.decode(encoded)
        self.assertEqual(original, restored)

    def test_list_subclass(self) -> None:
        """Check a list subclass round-trips to an equal instance."""
        original = MyList((("a", 1),))
        encoded = self.transcoder.encode(original)
        restored = self.transcoder.decode(encoded)
        self.assertEqual(original, restored)

    def test_str_subclass(self) -> None:
        """Check a str subclass round-trips to an equal instance."""
        original = MyStr("a")
        encoded = self.transcoder.encode(original)
        restored = self.transcoder.decode(encoded)
        self.assertEqual(original, restored)

    def test_int_subclass(self) -> None:
        """Check an int subclass round-trips to an equal instance."""
        original = MyInt(3)
        encoded = self.transcoder.encode(original)
        restored = self.transcoder.decode(encoded)
        self.assertEqual(original, restored)

    def test_tuple(self) -> None:
        """Check tuples, including nested ones, round-trip as tuples."""
        # Empty tuple.
        empty = ()
        encoded = self.transcoder.encode(empty)
        self.assertEqual(encoded, b'{"_type_":"tuple_as_list","_data_":[]}')
        self.assertEqual(empty, self.transcoder.decode(encoded))

        nested_tuples: tuple[Any, ...] = (
            ((),),  # Empty tuple in a tuple.
            ((1, 2),),  # Int in tuple in a tuple.
            (("a", "b"),),  # Str in tuple in a tuple.
            ((1, "a"),),  # Int and str in tuple in a tuple.
        )
        for original in nested_tuples:
            encoded = self.transcoder.encode(original)
            self.assertEqual(original, self.transcoder.decode(encoded))

    def test_list(self) -> None:
        """Check lists, including nested ones, round-trip unchanged."""
        nested_lists: list[list[Any]] = [
            [],  # Empty list.
            [[]],  # Empty list in a list.
            [[1, 2]],  # Int in list in a list.
            [["a", "b"]],  # Str in list in a list.
            [[1, "a"]],  # Int and str in list in a list.
        ]
        for original in nested_lists:
            encoded = self.transcoder.encode(original)
            self.assertEqual(original, self.transcoder.decode(encoded))

    def test_mixed(self) -> None:
        """Check mixtures of lists, tuples and dicts round-trip unchanged."""
        mixtures: tuple[Any, ...] = (
            [(1, "a"), {"b": 2}],
            ([1, "a"], {"b": 2}),
            {"a": (1, 2), "b": [3, 4]},
        )
        for original in mixtures:
            encoded = self.transcoder.encode(original)
            self.assertEqual(original, self.transcoder.decode(encoded))

    def test_custom_type_in_dict(self) -> None:
        """Check nested custom types held in a dict round-trip unchanged."""
        # Custom type wrapping a custom type wrapping a UUID, in a dict.
        inner = CustomType1(UUID("b2723fe2c01a40d2875ea3aac6a09ff5"))
        original = {"a": CustomType2(inner)}
        encoded = self.transcoder.encode(original)
        restored = self.transcoder.decode(encoded)
        self.assertEqual(original, restored)

    def test_nested_custom_type(self) -> None:
        """Check the encoding of nested custom types and their decoding."""
        original = CustomType2(CustomType1(UUID("b2723fe2c01a40d2875ea3aac6a09ff5")))
        encoded = self.transcoder.encode(original)
        # Each layer is wrapped in its own {"_type_": ..., "_data_": ...}.
        expected = (
            b'{"_type_":"custom_type2_as_dict","_data_":'
            b'{"_type_":"custom_type1_as_dict","_data_":'
            b'{"_type_":"uuid_hex","_data_":"b2723fe2c01a40d2875ea3aac6a09ff5"}'
            b"}}"
        )
        self.assertEqual(encoded, expected)
        restored = self.transcoder.decode(encoded)
        self.assertIsInstance(restored, CustomType2)
        self.assertIsInstance(restored.value, CustomType1)
        self.assertIsInstance(restored.value.value, UUID)
        self.assertEqual(restored.value.value, original.value.value)

    def test_custom_type_error(self) -> None:
        """Check the errors raised when a needed transcoding isn't registered."""
        not_serializable = (
            "Object of type <class 'eventsourcing.tests.persistence."
            "MyClass'> is not serializable. Please define "
            "and register a custom transcoding for this type."
        )

        # Expect a TypeError when encoding because transcoding not registered,
        # both for a bare object and for one nested in a dict.
        for unencodable in (MyClass(), {"a": MyClass()}):
            with self.assertRaises(TypeError) as cm:
                self.transcoder.encode(unencodable)

            self.assertEqual(cm.exception.args[0], not_serializable)

        # Check we get a TypeError when decoding because transcodings aren't registered.
        unknown_type_data = b'{"_type_":"custom_type3_as_dict","_data_":""}'

        with self.assertRaises(TypeError) as cm:
            self.transcoder.decode(unknown_type_data)

        self.assertEqual(
            cm.exception.args[0],
            "Data serialized with name 'custom_type3_as_dict' is not "
            "deserializable. Please register a custom transcoding for this type.",
        )