eventsourcing 9.3.4__py3-none-any.whl → 9.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of eventsourcing might be problematic. Click here for more details.

eventsourcing/postgres.py CHANGED
@@ -1,14 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import contextlib
3
4
  import logging
5
+ from asyncio import CancelledError
4
6
  from contextlib import contextmanager
5
- from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Sequence
7
+ from threading import Thread
8
+ from typing import TYPE_CHECKING, Any, Callable, cast
6
9
 
7
10
  import psycopg
8
11
  import psycopg.errors
9
12
  import psycopg_pool
10
- from psycopg import Connection, Cursor
13
+ from psycopg import Connection, Cursor, Error
14
+ from psycopg.generators import notifies
11
15
  from psycopg.rows import DictRow, dict_row
16
+ from psycopg.sql import SQL, Composed, Identifier
17
+ from typing_extensions import TypeVar
12
18
 
13
19
  from eventsourcing.persistence import (
14
20
  AggregateRecorder,
@@ -19,6 +25,7 @@ from eventsourcing.persistence import (
19
25
  IntegrityError,
20
26
  InterfaceError,
21
27
  InternalError,
28
+ ListenNotifySubscription,
22
29
  Notification,
23
30
  NotSupportedError,
24
31
  OperationalError,
@@ -26,18 +33,30 @@ from eventsourcing.persistence import (
26
33
  ProcessRecorder,
27
34
  ProgrammingError,
28
35
  StoredEvent,
36
+ Subscription,
29
37
  Tracking,
38
+ TrackingRecorder,
30
39
  )
31
40
  from eventsourcing.utils import Environment, resolve_topic, retry, strtobool
32
41
 
33
- if TYPE_CHECKING: # pragma: nocover
42
+ if TYPE_CHECKING:
43
+ from collections.abc import Iterator, Sequence
34
44
  from uuid import UUID
35
45
 
46
+ from psycopg.abc import Query
36
47
  from typing_extensions import Self
37
48
 
38
49
  logging.getLogger("psycopg.pool").setLevel(logging.CRITICAL)
39
50
  logging.getLogger("psycopg").setLevel(logging.CRITICAL)
40
51
 
52
+ # Copy of "private" psycopg.errors._NO_TRACEBACK (in case it changes)
53
+ # From psycopg: "Don't show a complete traceback upon raising these exception.
54
+ # Usually the traceback starts from internal functions (for instance in the
55
+ # server communication callbacks) but, for the end user, it's more important
56
+ # to get the high level information about where the exception was raised, for
57
+ # instance in a certain `Cursor.execute()`."
58
+ NO_TRACEBACK = (Error, KeyboardInterrupt, CancelledError)
59
+
41
60
 
42
61
  class ConnectionPool(psycopg_pool.ConnectionPool[Any]):
43
62
  def __init__(
@@ -56,11 +75,11 @@ class ConnectionPool(psycopg_pool.ConnectionPool[Any]):
56
75
 
57
76
 
58
77
  class PostgresDatastore:
59
- def __init__(
78
+ def __init__( # noqa: PLR0913
60
79
  self,
61
80
  dbname: str,
62
81
  host: str,
63
- port: str,
82
+ port: str | int,
64
83
  user: str,
65
84
  password: str,
66
85
  *,
@@ -95,22 +114,26 @@ class PostgresDatastore:
95
114
  min_size=pool_size,
96
115
  max_size=pool_size + max_overflow,
97
116
  open=False,
98
- configure=self.after_connect,
117
+ configure=self.after_connect_func(),
99
118
  timeout=connect_timeout,
100
119
  max_waiting=max_waiting,
101
120
  max_lifetime=conn_max_age,
102
121
  check=check,
103
122
  )
104
123
  self.lock_timeout = lock_timeout
105
- self.schema = schema.strip()
124
+ self.schema = schema.strip() or "public"
106
125
 
107
- def after_connect(self, conn: Connection[DictRow]) -> None:
108
- conn.autocommit = True
109
- conn.cursor().execute(
110
- "SET idle_in_transaction_session_timeout = "
111
- f"'{self.idle_in_transaction_session_timeout}s'"
126
+ def after_connect_func(self) -> Callable[[Connection[Any]], None]:
127
+ statement = SQL("SET idle_in_transaction_session_timeout = '{0}s'").format(
128
+ self.idle_in_transaction_session_timeout
112
129
  )
113
130
 
131
+ def after_connect(conn: Connection[DictRow]) -> None:
132
+ conn.autocommit = True
133
+ conn.cursor().execute(statement)
134
+
135
+ return after_connect
136
+
114
137
  @contextmanager
115
138
  def get_connection(self) -> Iterator[Connection[DictRow]]:
116
139
  try:
@@ -147,86 +170,97 @@ class PostgresDatastore:
147
170
 
148
171
  @contextmanager
149
172
  def transaction(self, *, commit: bool = False) -> Iterator[Cursor[DictRow]]:
150
- conn: Connection[DictRow]
151
173
  with self.get_connection() as conn, conn.transaction(force_rollback=not commit):
152
174
  yield conn.cursor()
153
175
 
154
176
  def close(self) -> None:
155
177
  self.pool.close()
156
178
 
157
- def __del__(self) -> None:
158
- self.close()
159
-
160
179
  def __enter__(self) -> Self:
161
180
  return self
162
181
 
163
182
  def __exit__(self, *args: object, **kwargs: Any) -> None:
164
183
  self.close()
165
184
 
185
+ def __del__(self) -> None:
186
+ self.close()
187
+
188
+
189
+ class PostgresRecorder:
190
+ """Base class for recorders that use PostgreSQL."""
166
191
 
167
- class PostgresAggregateRecorder(AggregateRecorder):
168
192
  def __init__(
169
193
  self,
170
194
  datastore: PostgresDatastore,
171
- events_table_name: str,
172
195
  ):
173
- self.check_table_name_length(events_table_name, datastore.schema)
174
196
  self.datastore = datastore
197
+ self.create_table_statements = self.construct_create_table_statements()
198
+
199
+ def construct_create_table_statements(self) -> list[Composed]:
200
+ return []
201
+
202
+ def check_table_name_length(self, table_name: str) -> None:
203
+ if len(table_name) > 63:
204
+ msg = f"Table name too long: {table_name}"
205
+ raise ProgrammingError(msg)
206
+
207
+ def create_table(self) -> None:
208
+ with self.datastore.transaction(commit=True) as curs:
209
+ for statement in self.create_table_statements:
210
+ curs.execute(statement, prepare=False)
211
+
212
+
213
+ class PostgresAggregateRecorder(PostgresRecorder, AggregateRecorder):
214
+ def __init__(
215
+ self,
216
+ datastore: PostgresDatastore,
217
+ *,
218
+ events_table_name: str = "stored_events",
219
+ ):
220
+ super().__init__(datastore)
221
+ self.check_table_name_length(events_table_name)
175
222
  self.events_table_name = events_table_name
176
223
  # Index names can't be qualified names, but
177
224
  # are created in the same schema as the table.
178
- if "." in self.events_table_name:
179
- unqualified_table_name = self.events_table_name.split(".")[-1]
180
- else:
181
- unqualified_table_name = self.events_table_name
182
225
  self.notification_id_index_name = (
183
- f"{unqualified_table_name}_notification_id_idx "
226
+ f"{self.events_table_name}_notification_id_idx"
184
227
  )
185
-
186
- self.create_table_statements = self.construct_create_table_statements()
187
- self.insert_events_statement = (
188
- f"INSERT INTO {self.events_table_name} VALUES (%s, %s, %s, %s)"
189
- )
190
- self.select_events_statement = (
191
- f"SELECT * FROM {self.events_table_name} WHERE originator_id = %s"
228
+ self.create_table_statements.append(
229
+ SQL(
230
+ "CREATE TABLE IF NOT EXISTS {0}.{1} ("
231
+ "originator_id uuid NOT NULL, "
232
+ "originator_version bigint NOT NULL, "
233
+ "topic text, "
234
+ "state bytea, "
235
+ "PRIMARY KEY "
236
+ "(originator_id, originator_version)) "
237
+ "WITH (autovacuum_enabled=false)"
238
+ ).format(
239
+ Identifier(self.datastore.schema),
240
+ Identifier(self.events_table_name),
241
+ )
192
242
  )
193
- self.lock_table_statements: List[str] = []
194
243
 
195
- @staticmethod
196
- def check_table_name_length(table_name: str, schema_name: str) -> None:
197
- schema_prefix = schema_name + "."
198
- if table_name.startswith(schema_prefix):
199
- unqualified_table_name = table_name[len(schema_prefix) :]
200
- else:
201
- unqualified_table_name = table_name
202
- if len(unqualified_table_name) > 63:
203
- msg = f"Table name too long: {unqualified_table_name}"
204
- raise ProgrammingError(msg)
244
+ self.insert_events_statement = SQL(
245
+ "INSERT INTO {0}.{1} VALUES (%s, %s, %s, %s)"
246
+ ).format(
247
+ Identifier(self.datastore.schema),
248
+ Identifier(self.events_table_name),
249
+ )
205
250
 
206
- def construct_create_table_statements(self) -> List[str]:
207
- statement = (
208
- "CREATE TABLE IF NOT EXISTS "
209
- f"{self.events_table_name} ("
210
- "originator_id uuid NOT NULL, "
211
- "originator_version bigint NOT NULL, "
212
- "topic text, "
213
- "state bytea, "
214
- "PRIMARY KEY "
215
- "(originator_id, originator_version)) "
216
- "WITH (autovacuum_enabled=false)"
251
+ self.select_events_statement = SQL(
252
+ "SELECT * FROM {0}.{1} WHERE originator_id = %s"
253
+ ).format(
254
+ Identifier(self.datastore.schema),
255
+ Identifier(self.events_table_name),
217
256
  )
218
- return [statement]
219
257
 
220
- def create_table(self) -> None:
221
- with self.datastore.transaction(commit=True) as curs:
222
- for statement in self.create_table_statements:
223
- curs.execute(statement, prepare=False)
258
+ self.lock_table_statements: list[Query] = []
224
259
 
225
260
  @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
226
261
  def insert_events(
227
- self, stored_events: List[StoredEvent], **kwargs: Any
262
+ self, stored_events: list[StoredEvent], **kwargs: Any
228
263
  ) -> Sequence[int] | None:
229
- conn: Connection[DictRow]
230
264
  exc: Exception | None = None
231
265
  notification_ids: Sequence[int] | None = None
232
266
  with self.datastore.get_connection() as conn:
@@ -255,24 +289,26 @@ class PostgresAggregateRecorder(AggregateRecorder):
255
289
 
256
290
  def _insert_events(
257
291
  self,
258
- c: Cursor[DictRow],
259
- stored_events: List[StoredEvent],
260
- **kwargs: Any,
292
+ curs: Cursor[DictRow],
293
+ stored_events: list[StoredEvent],
294
+ **_: Any,
261
295
  ) -> None:
262
296
  pass
263
297
 
264
298
  def _insert_stored_events(
265
299
  self,
266
- c: Cursor[DictRow],
267
- stored_events: List[StoredEvent],
300
+ curs: Cursor[DictRow],
301
+ stored_events: list[StoredEvent],
268
302
  **_: Any,
269
303
  ) -> None:
270
304
  # Only do something if there is something to do.
271
305
  if len(stored_events) > 0:
272
- self._lock_table(c)
306
+ self._lock_table(curs)
307
+
308
+ self._notify_channel(curs)
273
309
 
274
310
  # Insert events.
275
- c.executemany(
311
+ curs.executemany(
276
312
  query=self.insert_events_statement,
277
313
  params_seq=[
278
314
  (
@@ -283,16 +319,19 @@ class PostgresAggregateRecorder(AggregateRecorder):
283
319
  )
284
320
  for stored_event in stored_events
285
321
  ],
286
- returning="RETURNING" in self.insert_events_statement,
322
+ returning="RETURNING" in self.insert_events_statement.as_string(),
287
323
  )
288
324
 
289
- def _lock_table(self, c: Cursor[DictRow]) -> None:
325
+ def _lock_table(self, curs: Cursor[DictRow]) -> None:
326
+ pass
327
+
328
+ def _notify_channel(self, curs: Cursor[DictRow]) -> None:
290
329
  pass
291
330
 
292
331
  def _fetch_ids_after_insert_events(
293
332
  self,
294
- c: Cursor[DictRow],
295
- stored_events: List[StoredEvent],
333
+ curs: Cursor[DictRow],
334
+ stored_events: list[StoredEvent],
296
335
  **kwargs: Any,
297
336
  ) -> Sequence[int] | None:
298
337
  return None
@@ -306,23 +345,23 @@ class PostgresAggregateRecorder(AggregateRecorder):
306
345
  lte: int | None = None,
307
346
  desc: bool = False,
308
347
  limit: int | None = None,
309
- ) -> List[StoredEvent]:
348
+ ) -> list[StoredEvent]:
310
349
  statement = self.select_events_statement
311
- params: List[Any] = [originator_id]
350
+ params: list[Any] = [originator_id]
312
351
  if gt is not None:
313
352
  params.append(gt)
314
- statement += " AND originator_version > %s"
353
+ statement += SQL(" AND originator_version > %s")
315
354
  if lte is not None:
316
355
  params.append(lte)
317
- statement += " AND originator_version <= %s"
318
- statement += " ORDER BY originator_version"
356
+ statement += SQL(" AND originator_version <= %s")
357
+ statement += SQL(" ORDER BY originator_version")
319
358
  if desc is False:
320
- statement += " ASC"
359
+ statement += SQL(" ASC")
321
360
  else:
322
- statement += " DESC"
361
+ statement += SQL(" DESC")
323
362
  if limit is not None:
324
363
  params.append(limit)
325
- statement += " LIMIT %s"
364
+ statement += SQL(" LIMIT %s")
326
365
 
327
366
  with self.datastore.get_connection() as conn, conn.cursor() as curs:
328
367
  curs.execute(statement, params, prepare=True)
@@ -341,65 +380,108 @@ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder
341
380
  def __init__(
342
381
  self,
343
382
  datastore: PostgresDatastore,
383
+ *,
344
384
  events_table_name: str = "stored_events",
345
385
  ):
346
- super().__init__(datastore, events_table_name)
347
- self.insert_events_statement += " RETURNING notification_id"
348
- self.max_notification_id_statement = (
349
- f"SELECT MAX(notification_id) FROM {self.events_table_name}"
386
+ super().__init__(datastore, events_table_name=events_table_name)
387
+ self.create_table_statements[-1] = SQL(
388
+ "CREATE TABLE IF NOT EXISTS {0}.{1} ("
389
+ "originator_id uuid NOT NULL, "
390
+ "originator_version bigint NOT NULL, "
391
+ "topic text, "
392
+ "state bytea, "
393
+ "notification_id bigserial, "
394
+ "PRIMARY KEY "
395
+ "(originator_id, originator_version)) "
396
+ "WITH (autovacuum_enabled=false)"
397
+ ).format(
398
+ Identifier(self.datastore.schema),
399
+ Identifier(self.events_table_name),
350
400
  )
351
- self.lock_table_statements = [
352
- f"SET LOCAL lock_timeout = '{self.datastore.lock_timeout}s'",
353
- f"LOCK TABLE {self.events_table_name} IN EXCLUSIVE MODE",
354
- ]
355
401
 
356
- def construct_create_table_statements(self) -> List[str]:
357
- return [
358
- (
359
- "CREATE TABLE IF NOT EXISTS "
360
- f"{self.events_table_name} ("
361
- "originator_id uuid NOT NULL, "
362
- "originator_version bigint NOT NULL, "
363
- "topic text, "
364
- "state bytea, "
365
- "notification_id bigserial, "
366
- "PRIMARY KEY "
367
- "(originator_id, originator_version)) "
368
- "WITH (autovacuum_enabled=false)"
369
- ),
370
- (
371
- "CREATE UNIQUE INDEX IF NOT EXISTS "
372
- f"{self.notification_id_index_name}"
373
- f"ON {self.events_table_name} (notification_id ASC);"
402
+ self.create_table_statements.append(
403
+ SQL(
404
+ "CREATE UNIQUE INDEX IF NOT EXISTS {0} "
405
+ "ON {1}.{2} (notification_id ASC);"
406
+ ).format(
407
+ Identifier(self.notification_id_index_name),
408
+ Identifier(self.datastore.schema),
409
+ Identifier(self.events_table_name),
410
+ )
411
+ )
412
+
413
+ self.channel_name = self.events_table_name.replace(".", "_")
414
+ self.insert_events_statement = self.insert_events_statement + SQL(
415
+ " RETURNING notification_id"
416
+ )
417
+
418
+ self.max_notification_id_statement = SQL(
419
+ "SELECT MAX(notification_id) FROM {0}.{1}"
420
+ ).format(
421
+ Identifier(self.datastore.schema),
422
+ Identifier(self.events_table_name),
423
+ )
424
+
425
+ self.lock_table_statements = [
426
+ SQL("SET LOCAL lock_timeout = '{0}s'").format(self.datastore.lock_timeout),
427
+ SQL("LOCK TABLE {0}.{1} IN EXCLUSIVE MODE").format(
428
+ Identifier(self.datastore.schema),
429
+ Identifier(self.events_table_name),
374
430
  ),
375
431
  ]
376
432
 
377
433
  @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
378
434
  def select_notifications(
379
435
  self,
380
- start: int,
436
+ start: int | None,
381
437
  limit: int,
382
438
  stop: int | None = None,
383
439
  topics: Sequence[str] = (),
384
- ) -> List[Notification]:
385
- """
386
- Returns a list of event notifications
440
+ *,
441
+ inclusive_of_start: bool = True,
442
+ ) -> list[Notification]:
443
+ """Returns a list of event notifications
387
444
  from 'start', limited by 'limit'.
388
445
  """
389
-
390
- params: List[int | str | Sequence[str]] = [start]
391
- statement = f"SELECT * FROM {self.events_table_name} WHERE notification_id>=%s"
446
+ params: list[int | str | Sequence[str]] = []
447
+ statement = SQL("SELECT * FROM {0}.{1}").format(
448
+ Identifier(self.datastore.schema),
449
+ Identifier(self.events_table_name),
450
+ )
451
+ has_where = False
452
+ if start is not None:
453
+ statement += SQL(" WHERE")
454
+ has_where = True
455
+ params.append(start)
456
+ if inclusive_of_start:
457
+ statement += SQL(" notification_id>=%s")
458
+ else:
459
+ statement += SQL(" notification_id>%s")
392
460
 
393
461
  if stop is not None:
462
+ if not has_where:
463
+ has_where = True
464
+ statement += SQL(" WHERE")
465
+ else:
466
+ statement += SQL(" AND")
467
+
394
468
  params.append(stop)
395
- statement += " AND notification_id <= %s"
469
+ statement += SQL(" notification_id <= %s")
396
470
 
397
471
  if topics:
472
+ # Check sequence and ensure list of strings.
473
+ assert isinstance(topics, (tuple, list)), topics
474
+ topics = list(topics) if isinstance(topics, tuple) else topics
475
+ assert all(isinstance(t, str) for t in topics), topics
476
+ if not has_where:
477
+ statement += SQL(" WHERE")
478
+ else:
479
+ statement += SQL(" AND")
398
480
  params.append(topics)
399
- statement += " AND topic = ANY(%s)"
481
+ statement += SQL(" topic = ANY(%s)")
400
482
 
401
483
  params.append(limit)
402
- statement += " ORDER BY notification_id LIMIT %s"
484
+ statement += SQL(" ORDER BY notification_id LIMIT %s")
403
485
 
404
486
  connection = self.datastore.get_connection()
405
487
  with connection as conn, conn.cursor() as curs:
@@ -416,18 +498,15 @@ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder
416
498
  ]
417
499
 
418
500
  @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
419
- def max_notification_id(self) -> int:
420
- """
421
- Returns the maximum notification ID.
422
- """
423
- conn: Connection[DictRow]
501
+ def max_notification_id(self) -> int | None:
502
+ """Returns the maximum notification ID."""
424
503
  with self.datastore.get_connection() as conn, conn.cursor() as curs:
425
504
  curs.execute(self.max_notification_id_statement)
426
505
  fetchone = curs.fetchone()
427
506
  assert fetchone is not None
428
- return fetchone["max"] or 0
507
+ return fetchone["max"]
429
508
 
430
- def _lock_table(self, c: Cursor[DictRow]) -> None:
509
+ def _lock_table(self, curs: Cursor[DictRow]) -> None:
431
510
  # Acquire "EXCLUSIVE" table lock, to serialize transactions that insert
432
511
  # stored events, so that readers don't pass over gaps that are filled in
433
512
  # later. We want each transaction that will be issued with notifications
@@ -449,70 +528,148 @@ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder
449
528
  # https://stackoverflow.com/questions/45866187/guarantee-monotonicity-of
450
529
  # -postgresql-serial-column-values-by-commit-order
451
530
  for lock_statement in self.lock_table_statements:
452
- c.execute(lock_statement, prepare=True)
531
+ curs.execute(lock_statement, prepare=True)
532
+
533
+ def _notify_channel(self, curs: Cursor[DictRow]) -> None:
534
+ curs.execute(SQL("NOTIFY {0}").format(Identifier(self.channel_name)))
453
535
 
454
536
  def _fetch_ids_after_insert_events(
455
537
  self,
456
- c: Cursor[DictRow],
457
- stored_events: List[StoredEvent],
538
+ curs: Cursor[DictRow],
539
+ stored_events: list[StoredEvent],
458
540
  **kwargs: Any,
459
541
  ) -> Sequence[int] | None:
460
- notification_ids: List[int] = []
542
+ notification_ids: list[int] = []
461
543
  len_events = len(stored_events)
462
544
  if len_events:
463
- if (
464
- (c.statusmessage == "SET")
465
- and c.nextset()
466
- and (c.statusmessage == "LOCK TABLE")
467
- ):
468
- while c.nextset() and len(notification_ids) != len_events:
469
- row = c.fetchone()
545
+ while curs.nextset() and len(notification_ids) != len_events:
546
+ if curs.statusmessage and curs.statusmessage.startswith("INSERT"):
547
+ row = curs.fetchone()
470
548
  assert row is not None
471
549
  notification_ids.append(row["notification_id"])
472
550
  if len(notification_ids) != len(stored_events):
473
- msg = "Couldn't get all notification IDs"
551
+ msg = "Couldn't get all notification IDs "
552
+ msg += f"(got {len(notification_ids)}, expected {len(stored_events)})"
474
553
  raise ProgrammingError(msg)
475
554
  return notification_ids
476
555
 
556
+ def subscribe(
557
+ self, gt: int | None = None, topics: Sequence[str] = ()
558
+ ) -> Subscription[ApplicationRecorder]:
559
+ return PostgresSubscription(recorder=self, gt=gt, topics=topics)
560
+
561
+
562
+ class PostgresSubscription(ListenNotifySubscription[PostgresApplicationRecorder]):
563
+ def __init__(
564
+ self,
565
+ recorder: PostgresApplicationRecorder,
566
+ gt: int | None = None,
567
+ topics: Sequence[str] = (),
568
+ ) -> None:
569
+ assert isinstance(recorder, PostgresApplicationRecorder)
570
+ super().__init__(recorder=recorder, gt=gt, topics=topics)
571
+ self._listen_thread = Thread(target=self._listen)
572
+ self._listen_thread.start()
477
573
 
478
- class PostgresProcessRecorder(PostgresApplicationRecorder, ProcessRecorder):
574
+ def __exit__(self, *args: object, **kwargs: Any) -> None:
575
+ super().__exit__(*args, **kwargs)
576
+ self._listen_thread.join()
577
+
578
+ def _listen(self) -> None:
579
+ try:
580
+ with self._recorder.datastore.get_connection() as conn:
581
+ conn.execute(
582
+ SQL("LISTEN {0}").format(Identifier(self._recorder.channel_name))
583
+ )
584
+ while not self._has_been_stopped and not self._thread_error:
585
+ # This block simplifies psycopg's conn.notifies(), because
586
+ # we aren't interested in the actual notify messages, and
587
+ # also we want to stop consuming notify messages when the
588
+ # subscription has an error or is otherwise stopped.
589
+ with conn.lock:
590
+ try:
591
+ if conn.wait(notifies(conn.pgconn), interval=0.1):
592
+ self._has_been_notified.set()
593
+ except NO_TRACEBACK as ex: # pragma: no cover
594
+ raise ex.with_traceback(None) from None
595
+
596
+ except BaseException as e:
597
+ if self._thread_error is None:
598
+ self._thread_error = e
599
+ self.stop()
600
+
601
+
602
+ class PostgresTrackingRecorder(PostgresRecorder, TrackingRecorder):
479
603
  def __init__(
480
604
  self,
481
605
  datastore: PostgresDatastore,
482
- events_table_name: str,
483
- tracking_table_name: str,
606
+ *,
607
+ tracking_table_name: str = "notification_tracking",
608
+ **kwargs: Any,
484
609
  ):
485
- self.check_table_name_length(tracking_table_name, datastore.schema)
610
+ super().__init__(datastore, **kwargs)
611
+ self.check_table_name_length(tracking_table_name)
486
612
  self.tracking_table_name = tracking_table_name
487
- super().__init__(datastore, events_table_name)
488
- self.insert_tracking_statement = (
489
- f"INSERT INTO {self.tracking_table_name} VALUES (%s, %s)"
613
+ self.create_table_statements.append(
614
+ SQL(
615
+ "CREATE TABLE IF NOT EXISTS {0}.{1} ("
616
+ "application_name text, "
617
+ "notification_id bigint, "
618
+ "PRIMARY KEY "
619
+ "(application_name, notification_id))"
620
+ ).format(
621
+ Identifier(self.datastore.schema),
622
+ Identifier(self.tracking_table_name),
623
+ )
490
624
  )
491
- self.max_tracking_id_statement = (
492
- "SELECT MAX(notification_id) "
493
- f"FROM {self.tracking_table_name} "
494
- "WHERE application_name=%s"
625
+
626
+ self.insert_tracking_statement = SQL(
627
+ "INSERT INTO {0}.{1} VALUES (%s, %s)"
628
+ ).format(
629
+ Identifier(self.datastore.schema),
630
+ Identifier(self.tracking_table_name),
495
631
  )
496
- self.count_tracking_id_statement = (
497
- "SELECT COUNT(*) "
498
- f"FROM {self.tracking_table_name} "
632
+
633
+ self.max_tracking_id_statement = SQL(
634
+ "SELECT MAX(notification_id) FROM {0}.{1} WHERE application_name=%s"
635
+ ).format(
636
+ Identifier(self.datastore.schema),
637
+ Identifier(self.tracking_table_name),
638
+ )
639
+
640
+ self.count_tracking_id_statement = SQL(
641
+ "SELECT COUNT(*) FROM {0}.{1} "
499
642
  "WHERE application_name=%s AND notification_id=%s"
643
+ ).format(
644
+ Identifier(self.datastore.schema),
645
+ Identifier(self.tracking_table_name),
500
646
  )
501
647
 
502
- def construct_create_table_statements(self) -> List[str]:
503
- statements = super().construct_create_table_statements()
504
- statements.append(
505
- "CREATE TABLE IF NOT EXISTS "
506
- f"{self.tracking_table_name} ("
507
- "application_name text, "
508
- "notification_id bigint, "
509
- "PRIMARY KEY "
510
- "(application_name, notification_id))"
648
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
649
+ def insert_tracking(self, tracking: Tracking) -> None:
650
+ with (
651
+ self.datastore.get_connection() as conn,
652
+ conn.transaction(),
653
+ conn.cursor() as curs,
654
+ ):
655
+ self._insert_tracking(curs, tracking)
656
+
657
+ def _insert_tracking(
658
+ self,
659
+ curs: Cursor[DictRow],
660
+ tracking: Tracking,
661
+ ) -> None:
662
+ curs.execute(
663
+ query=self.insert_tracking_statement,
664
+ params=(
665
+ tracking.application_name,
666
+ tracking.notification_id,
667
+ ),
668
+ prepare=True,
511
669
  )
512
- return statements
513
670
 
514
671
  @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
515
- def max_tracking_id(self, application_name: str) -> int:
672
+ def max_tracking_id(self, application_name: str) -> int | None:
516
673
  with self.datastore.get_connection() as conn, conn.cursor() as curs:
517
674
  curs.execute(
518
675
  query=self.max_tracking_id_statement,
@@ -521,11 +678,14 @@ class PostgresProcessRecorder(PostgresApplicationRecorder, ProcessRecorder):
521
678
  )
522
679
  fetchone = curs.fetchone()
523
680
  assert fetchone is not None
524
- return fetchone["max"] or 0
681
+ return fetchone["max"]
525
682
 
526
683
  @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
527
- def has_tracking_id(self, application_name: str, notification_id: int) -> bool:
528
- conn: Connection[DictRow]
684
+ def has_tracking_id(
685
+ self, application_name: str, notification_id: int | None
686
+ ) -> bool:
687
+ if notification_id is None:
688
+ return True
529
689
  with self.datastore.get_connection() as conn, conn.cursor() as curs:
530
690
  curs.execute(
531
691
  query=self.count_tracking_id_statement,
@@ -536,26 +696,43 @@ class PostgresProcessRecorder(PostgresApplicationRecorder, ProcessRecorder):
536
696
  assert fetchone is not None
537
697
  return bool(fetchone["count"])
538
698
 
699
+
700
+ TPostgresTrackingRecorder = TypeVar(
701
+ "TPostgresTrackingRecorder",
702
+ bound=PostgresTrackingRecorder,
703
+ default=PostgresTrackingRecorder,
704
+ )
705
+
706
+
707
+ class PostgresProcessRecorder(
708
+ PostgresTrackingRecorder, PostgresApplicationRecorder, ProcessRecorder
709
+ ):
710
+ def __init__(
711
+ self,
712
+ datastore: PostgresDatastore,
713
+ *,
714
+ events_table_name: str = "stored_events",
715
+ tracking_table_name: str = "notification_tracking",
716
+ ):
717
+ super().__init__(
718
+ datastore,
719
+ tracking_table_name=tracking_table_name,
720
+ events_table_name=events_table_name,
721
+ )
722
+
539
723
  def _insert_events(
540
724
  self,
541
- c: Cursor[DictRow],
542
- stored_events: List[StoredEvent],
725
+ curs: Cursor[DictRow],
726
+ stored_events: list[StoredEvent],
543
727
  **kwargs: Any,
544
728
  ) -> None:
545
- tracking: Tracking | None = kwargs.get("tracking", None)
729
+ tracking: Tracking | None = kwargs.get("tracking")
546
730
  if tracking is not None:
547
- c.execute(
548
- query=self.insert_tracking_statement,
549
- params=(
550
- tracking.application_name,
551
- tracking.notification_id,
552
- ),
553
- prepare=True,
554
- )
555
- super()._insert_events(c, stored_events, **kwargs)
731
+ self._insert_tracking(curs, tracking=tracking)
732
+ super()._insert_events(curs, stored_events, **kwargs)
556
733
 
557
734
 
558
- class Factory(InfrastructureFactory):
735
+ class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
559
736
  POSTGRES_DBNAME = "POSTGRES_DBNAME"
560
737
  POSTGRES_HOST = "POSTGRES_HOST"
561
738
  POSTGRES_PORT = "POSTGRES_PORT"
@@ -577,6 +754,7 @@ class Factory(InfrastructureFactory):
577
754
 
578
755
  aggregate_recorder_class = PostgresAggregateRecorder
579
756
  application_recorder_class = PostgresApplicationRecorder
757
+ tracking_recorder_class = PostgresTrackingRecorder
580
758
  process_recorder_class = PostgresProcessRecorder
581
759
 
582
760
  def __init__(self, env: Environment):
@@ -588,6 +766,10 @@ class Factory(InfrastructureFactory):
588
766
  "in environment with key "
589
767
  f"'{self.POSTGRES_DBNAME}'"
590
768
  )
769
+ # TODO: Indicate both keys here, also for other environment variables.
770
+ # ) + " or ".join(
771
+ # [f"'{key}'" for key in self.env.create_keys(self.POSTGRES_DBNAME)]
772
+ # )
591
773
  raise OSError(msg)
592
774
 
593
775
  host = self.env.get(self.POSTGRES_HOST)
@@ -753,8 +935,6 @@ class Factory(InfrastructureFactory):
753
935
  def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
754
936
  prefix = self.env.name.lower() or "stored"
755
937
  events_table_name = prefix + "_" + purpose
756
- if self.datastore.schema:
757
- events_table_name = f"{self.datastore.schema}.{events_table_name}"
758
938
  recorder = type(self).aggregate_recorder_class(
759
939
  datastore=self.datastore,
760
940
  events_table_name=events_table_name,
@@ -766,9 +946,16 @@ class Factory(InfrastructureFactory):
766
946
  def application_recorder(self) -> ApplicationRecorder:
767
947
  prefix = self.env.name.lower() or "stored"
768
948
  events_table_name = prefix + "_events"
769
- if self.datastore.schema:
770
- events_table_name = f"{self.datastore.schema}.{events_table_name}"
771
- recorder = type(self).application_recorder_class(
949
+ application_recorder_topic = self.env.get(self.APPLICATION_RECORDER_TOPIC)
950
+ if application_recorder_topic:
951
+ application_recorder_class: type[PostgresApplicationRecorder] = (
952
+ resolve_topic(application_recorder_topic)
953
+ )
954
+ assert issubclass(application_recorder_class, PostgresApplicationRecorder)
955
+ else:
956
+ application_recorder_class = type(self).application_recorder_class
957
+
958
+ recorder = application_recorder_class(
772
959
  datastore=self.datastore,
773
960
  events_table_name=events_table_name,
774
961
  )
@@ -776,15 +963,45 @@ class Factory(InfrastructureFactory):
776
963
  recorder.create_table()
777
964
  return recorder
778
965
 
966
+ def tracking_recorder(
967
+ self, tracking_recorder_class: type[TPostgresTrackingRecorder] | None = None
968
+ ) -> TPostgresTrackingRecorder:
969
+ prefix = self.env.name.lower() or "notification"
970
+ tracking_table_name = prefix + "_tracking"
971
+ if tracking_recorder_class is None:
972
+ tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
973
+ if tracking_recorder_topic:
974
+ tracking_recorder_class = resolve_topic(tracking_recorder_topic)
975
+ else:
976
+ tracking_recorder_class = cast(
977
+ "type[TPostgresTrackingRecorder]",
978
+ type(self).tracking_recorder_class,
979
+ )
980
+ assert tracking_recorder_class is not None
981
+ assert issubclass(tracking_recorder_class, PostgresTrackingRecorder)
982
+ recorder = tracking_recorder_class(
983
+ datastore=self.datastore,
984
+ tracking_table_name=tracking_table_name,
985
+ )
986
+ if self.env_create_table():
987
+ recorder.create_table()
988
+ return recorder
989
+
779
990
  def process_recorder(self) -> ProcessRecorder:
780
991
  prefix = self.env.name.lower() or "stored"
781
992
  events_table_name = prefix + "_events"
782
993
  prefix = self.env.name.lower() or "notification"
783
994
  tracking_table_name = prefix + "_tracking"
784
- if self.datastore.schema:
785
- events_table_name = f"{self.datastore.schema}.{events_table_name}"
786
- tracking_table_name = f"{self.datastore.schema}.{tracking_table_name}"
787
- recorder = type(self).process_recorder_class(
995
+ process_recorder_topic = self.env.get(self.PROCESS_RECORDER_TOPIC)
996
+ if process_recorder_topic:
997
+ process_recorder_class: type[PostgresTrackingRecorder] = resolve_topic(
998
+ process_recorder_topic
999
+ )
1000
+ assert issubclass(process_recorder_class, PostgresProcessRecorder)
1001
+ else:
1002
+ process_recorder_class = type(self).process_recorder_class
1003
+
1004
+ recorder = process_recorder_class(
788
1005
  datastore=self.datastore,
789
1006
  events_table_name=events_table_name,
790
1007
  tracking_table_name=tracking_table_name,
@@ -794,8 +1011,8 @@ class Factory(InfrastructureFactory):
794
1011
  return recorder
795
1012
 
796
1013
  def close(self) -> None:
797
- if hasattr(self, "datastore"):
1014
+ with contextlib.suppress(AttributeError):
798
1015
  self.datastore.close()
799
1016
 
800
- def __del__(self) -> None:
801
- self.close()
1017
+
1018
+ Factory = PostgresFactory