eventsourcing 9.5.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1441 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import logging
5
+ from asyncio import CancelledError
6
+ from contextlib import contextmanager
7
+ from threading import Thread
8
+ from typing import TYPE_CHECKING, Any, Generic, Literal, NamedTuple, cast
9
+
10
+ import psycopg
11
+ import psycopg.errors
12
+ import psycopg_pool
13
+ from psycopg import Connection, Cursor, Error
14
+ from psycopg.errors import DuplicateObject
15
+ from psycopg.generators import notifies
16
+ from psycopg.rows import DictRow, dict_row
17
+ from psycopg.sql import SQL, Composed, Identifier
18
+ from psycopg.types.composite import CompositeInfo, register_composite
19
+ from psycopg_pool.abc import (
20
+ CT,
21
+ ConnectFailedCB,
22
+ ConnectionCB,
23
+ ConninfoParam,
24
+ KwargsParam,
25
+ )
26
+ from typing_extensions import TypeVar
27
+
28
+ from eventsourcing.persistence import (
29
+ AggregateRecorder,
30
+ ApplicationRecorder,
31
+ BaseInfrastructureFactory,
32
+ DatabaseError,
33
+ DataError,
34
+ InfrastructureFactory,
35
+ IntegrityError,
36
+ InterfaceError,
37
+ InternalError,
38
+ ListenNotifySubscription,
39
+ Notification,
40
+ NotSupportedError,
41
+ OperationalError,
42
+ PersistenceError,
43
+ ProcessRecorder,
44
+ ProgrammingError,
45
+ StoredEvent,
46
+ Subscription,
47
+ Tracking,
48
+ TrackingRecorder,
49
+ TTrackingRecorder,
50
+ )
51
+ from eventsourcing.utils import Environment, EnvType, resolve_topic, retry, strtobool
52
+
53
+ if TYPE_CHECKING:
54
+ from collections.abc import Callable, Iterator, Sequence
55
+ from types import TracebackType
56
+ from uuid import UUID
57
+
58
+ from psycopg.abc import Query
59
+ from typing_extensions import Self
60
+
61
+
62
+ logging.getLogger("psycopg.pool").setLevel(logging.ERROR)
63
+ logging.getLogger("psycopg").setLevel(logging.ERROR)
64
+
65
+ # Copy of "private" psycopg.errors._NO_TRACEBACK (in case it changes)
66
+ # From psycopg: "Don't show a complete traceback upon raising these exception.
67
+ # Usually the traceback starts from internal functions (for instance in the
68
+ # server communication callbacks) but, for the end user, it's more important
69
+ # to get the high level information about where the exception was raised, for
70
+ # instance in a certain `Cursor.execute()`."
71
+ NO_TRACEBACK = (Error, KeyboardInterrupt, CancelledError)
72
+
73
+
74
+ class PgStoredEvent(NamedTuple):
75
+ originator_id: UUID | str
76
+ originator_version: int
77
+ topic: str
78
+ state: bytes
79
+
80
+
81
+ class ConnectionPool(psycopg_pool.ConnectionPool[CT], Generic[CT]):
82
+ def __init__( # noqa: PLR0913
83
+ self,
84
+ conninfo: ConninfoParam = "",
85
+ *,
86
+ connection_class: type[CT] = cast(type[CT], Connection), # noqa: B008
87
+ kwargs: KwargsParam | None = None,
88
+ min_size: int = 4,
89
+ max_size: int | None = None,
90
+ open: bool | None = None, # noqa: A002
91
+ configure: ConnectionCB[CT] | None = None,
92
+ check: ConnectionCB[CT] | None = None,
93
+ reset: ConnectionCB[CT] | None = None,
94
+ name: str | None = None,
95
+ close_returns: bool = False,
96
+ timeout: float = 30.0,
97
+ max_waiting: int = 0,
98
+ max_lifetime: float = 60 * 60.0,
99
+ max_idle: float = 10 * 60.0,
100
+ reconnect_timeout: float = 5 * 60.0,
101
+ reconnect_failed: ConnectFailedCB | None = None,
102
+ num_workers: int = 3,
103
+ get_password_func: Callable[[], str] | None = None,
104
+ ) -> None:
105
+ self.get_password_func = get_password_func
106
+ super().__init__(
107
+ conninfo,
108
+ connection_class=connection_class,
109
+ kwargs=kwargs,
110
+ min_size=min_size,
111
+ max_size=max_size,
112
+ open=open,
113
+ configure=configure,
114
+ check=check,
115
+ reset=reset,
116
+ name=name,
117
+ close_returns=close_returns,
118
+ timeout=timeout,
119
+ max_waiting=max_waiting,
120
+ max_lifetime=max_lifetime,
121
+ max_idle=max_idle,
122
+ reconnect_timeout=reconnect_timeout,
123
+ reconnect_failed=reconnect_failed,
124
+ num_workers=num_workers,
125
+ )
126
+
127
+ def _connect(self, timeout: float | None = None) -> CT:
128
+ if self.get_password_func:
129
+ assert isinstance(self.kwargs, dict)
130
+ self.kwargs["password"] = self.get_password_func()
131
+ return super()._connect(timeout=timeout)
132
+
133
+
134
+ class PostgresDatastore:
135
+ def __init__( # noqa: PLR0913
136
+ self,
137
+ dbname: str,
138
+ host: str,
139
+ port: str | int,
140
+ user: str,
141
+ password: str,
142
+ *,
143
+ connect_timeout: float = 5.0,
144
+ idle_in_transaction_session_timeout: float = 0,
145
+ pool_size: int = 1,
146
+ max_overflow: int = 0,
147
+ max_waiting: int = 0,
148
+ conn_max_age: float = 60 * 60.0,
149
+ pre_ping: bool = False,
150
+ lock_timeout: int = 0,
151
+ schema: str = "",
152
+ pool_open_timeout: float | None = None,
153
+ get_password_func: Callable[[], str] | None = None,
154
+ single_row_tracking: bool = True,
155
+ originator_id_type: Literal["uuid", "text"] = "uuid",
156
+ enable_db_functions: bool = False,
157
+ ):
158
+ self.idle_in_transaction_session_timeout = idle_in_transaction_session_timeout
159
+ self.pre_ping = pre_ping
160
+ self.pool_open_timeout = pool_open_timeout
161
+ self.single_row_tracking = single_row_tracking
162
+ self.lock_timeout = lock_timeout
163
+ self.schema = schema.strip() or "public"
164
+ if originator_id_type.lower() not in ("uuid", "text"):
165
+ msg = (
166
+ f"Invalid originator_id_type '{originator_id_type}', "
167
+ f"must be 'uuid' or 'text'"
168
+ )
169
+ raise ValueError(msg)
170
+ self.originator_id_type = originator_id_type.lower()
171
+
172
+ self.enable_db_functions = enable_db_functions
173
+
174
+ check = ConnectionPool.check_connection if pre_ping else None
175
+ self.db_type_names = set[str]()
176
+ self.psycopg_type_adapters: dict[str, CompositeInfo] = {}
177
+ self.psycopg_python_types: dict[str, Any] = {}
178
+ self.pool = ConnectionPool(
179
+ get_password_func=get_password_func,
180
+ connection_class=Connection[DictRow],
181
+ kwargs={
182
+ "dbname": dbname,
183
+ "host": host,
184
+ "port": port,
185
+ "user": user,
186
+ "password": password,
187
+ "row_factory": dict_row,
188
+ },
189
+ min_size=pool_size,
190
+ max_size=pool_size + max_overflow,
191
+ open=False,
192
+ configure=self.after_connect_func(),
193
+ timeout=connect_timeout,
194
+ max_waiting=max_waiting,
195
+ max_lifetime=conn_max_age,
196
+ check=check, # pyright: ignore [reportArgumentType]
197
+ )
198
+
199
+ def after_connect_func(self) -> Callable[[Connection[Any]], None]:
200
+ set_idle_in_transaction_session_timeout_statement = SQL(
201
+ "SET idle_in_transaction_session_timeout = '{0}ms'"
202
+ ).format(int(self.idle_in_transaction_session_timeout * 1000))
203
+
204
+ # Avoid passing a bound method to the pool,
205
+ # to avoid creating a circular ref to self.
206
+ def after_connect(conn: Connection[DictRow]) -> None:
207
+ # Put connection in auto-commit mode.
208
+ conn.autocommit = True
209
+
210
+ # Set idle in transaction session timeout.
211
+ conn.cursor().execute(set_idle_in_transaction_session_timeout_statement)
212
+
213
+ return after_connect
214
+
215
+ def register_type_adapters(self) -> None:
216
+ # Construct and/or register composite type adapters.
217
+ unregistered_names = [
218
+ name
219
+ for name in self.db_type_names
220
+ if name not in self.psycopg_type_adapters
221
+ ]
222
+ if not unregistered_names:
223
+ return
224
+ with self.get_connection() as conn:
225
+ for name in unregistered_names:
226
+ # Construct type adapter from database info.
227
+ info = CompositeInfo.fetch(conn, f"{self.schema}.{name}")
228
+ if info is None:
229
+ continue
230
+ # Register the type adapter centrally.
231
+ register_composite(info, conn)
232
+ # Cache the python type for our own use.
233
+ self.psycopg_type_adapters[name] = info
234
+ assert info.python_type is not None, info
235
+ self.psycopg_python_types[name] = info.python_type
236
+
237
+ @contextmanager
238
+ def get_connection(self) -> Iterator[Connection[DictRow]]:
239
+ try:
240
+ wait = self.pool_open_timeout is not None
241
+ timeout = self.pool_open_timeout or 30.0
242
+ self.pool.open(wait, timeout)
243
+
244
+ with self.pool.connection() as conn:
245
+ # Make sure the connection has the type adapters.
246
+ for info in self.psycopg_type_adapters.values():
247
+ if not conn.adapters.types.get(info.oid):
248
+ register_composite(info, conn)
249
+ # Yield connection.
250
+ yield conn
251
+ except psycopg.InterfaceError as e:
252
+ # conn.close()
253
+ raise InterfaceError(str(e)) from e
254
+ except psycopg.OperationalError as e:
255
+ # conn.close()
256
+ raise OperationalError(str(e)) from e
257
+ except psycopg.DataError as e:
258
+ raise DataError(str(e)) from e
259
+ except psycopg.IntegrityError as e:
260
+ raise IntegrityError(str(e)) from e
261
+ except psycopg.InternalError as e:
262
+ raise InternalError(str(e)) from e
263
+ except psycopg.ProgrammingError as e:
264
+ raise ProgrammingError(str(e)) from e
265
+ except psycopg.NotSupportedError as e:
266
+ raise NotSupportedError(str(e)) from e
267
+ except psycopg.DatabaseError as e:
268
+ raise DatabaseError(str(e)) from e
269
+ except psycopg.Error as e:
270
+ # conn.close()
271
+ raise PersistenceError(str(e)) from e
272
+ except Exception:
273
+ # conn.close()
274
+ raise
275
+
276
+ @contextmanager
277
+ def cursor(self) -> Iterator[Cursor[DictRow]]:
278
+ with self.get_connection() as conn:
279
+ yield conn.cursor()
280
+
281
+ @contextmanager
282
+ def transaction(self, *, commit: bool = False) -> Iterator[Cursor[DictRow]]:
283
+ with self.get_connection() as conn, conn.transaction(force_rollback=not commit):
284
+ yield conn.cursor()
285
+
286
+ def close(self) -> None:
287
+ with contextlib.suppress(AttributeError):
288
+ self.pool.close()
289
+
290
+ def __enter__(self) -> Self:
291
+ self.pool.__enter__()
292
+ return self
293
+
294
+ def __exit__(
295
+ self,
296
+ exc_type: type[BaseException] | None,
297
+ exc_val: BaseException | None,
298
+ exc_tb: TracebackType | None,
299
+ ) -> None:
300
+ self.pool.__exit__(exc_type, exc_val, exc_tb)
301
+
302
+ def __del__(self) -> None:
303
+ self.close()
304
+
305
+
306
+ class PostgresRecorder:
307
+ """Base class for recorders that use PostgreSQL."""
308
+
309
+ MAX_IDENTIFIER_LEN = 63
310
+ # From the PostgreSQL docs: "The system uses no more than NAMEDATALEN-1 bytes
311
+ # of an identifier; longer names can be written in commands, but they will be
312
+ # truncated. By default, NAMEDATALEN is 64 so the maximum identifier length is
313
+ # 63 bytes." https://www.postgresql.org/docs/current/sql-syntax-lexical.html
314
+
315
+ def __init__(
316
+ self,
317
+ datastore: PostgresDatastore,
318
+ ):
319
+ self.datastore = datastore
320
+ self.sql_create_statements: list[Composed] = []
321
+
322
+ @staticmethod
323
+ def check_identifier_length(table_name: str) -> None:
324
+ if len(table_name) > PostgresRecorder.MAX_IDENTIFIER_LEN:
325
+ msg = f"Identifier too long: {table_name}"
326
+ raise ProgrammingError(msg)
327
+
328
+ def create_table(self) -> None:
329
+ # Create composite types.
330
+ for statement in self.sql_create_statements:
331
+ if "CREATE TYPE" in statement.as_string():
332
+ # Run in its own transaction, because there is no 'IF NOT EXISTS' option
333
+ # when creating types; if the type already exists, a DuplicateObject error
334
+ # is raised, terminating the transaction and causing an opaque error.
335
+ with (
336
+ self.datastore.transaction(commit=True) as curs,
337
+ contextlib.suppress(DuplicateObject),
338
+ ):
339
+ curs.execute(statement, prepare=False)
340
+ # try:
341
+ # except psycopg.errors.SyntaxError as e:
342
+ # msg = f"Syntax error: '{e}' in: {statement.as_string()}"
343
+ # raise ProgrammingError(msg) from e
344
+
345
+ # Create tables, indexes, types, functions, and procedures.
346
+ with self.datastore.transaction(commit=True) as curs:
347
+ self._create_table(curs)
348
+
349
+ # Register type adapters.
350
+ self.datastore.register_type_adapters()
351
+
352
+ def _create_table(self, curs: Cursor[DictRow]) -> None:
353
+ for statement in self.sql_create_statements:
354
+ if "CREATE TYPE" not in statement.as_string():
355
+ try:
356
+ curs.execute(statement, prepare=False)
357
+ except psycopg.errors.SyntaxError as e:
358
+ msg = f"Syntax error: '{e}' in: {statement.as_string()}"
359
+ raise ProgrammingError(msg) from e
360
+
361
+
362
+ class PostgresAggregateRecorder(PostgresRecorder, AggregateRecorder):
363
+ def __init__(
364
+ self,
365
+ datastore: PostgresDatastore,
366
+ *,
367
+ events_table_name: str = "stored_events",
368
+ ):
369
+ super().__init__(datastore)
370
+ self.check_identifier_length(events_table_name)
371
+ self.events_table_name = events_table_name
372
+ # Index names can't be qualified names, but
373
+ # are created in the same schema as the table.
374
+ self.notification_id_index_name = (
375
+ f"{self.events_table_name}_notification_id_idx"
376
+ )
377
+
378
+ self.stored_event_type_name = (
379
+ f"stored_event_{self.datastore.originator_id_type}"
380
+ )
381
+ self.datastore.db_type_names.add(self.stored_event_type_name)
382
+ self.datastore.register_type_adapters()
383
+ self.create_table_statement_index = len(self.sql_create_statements)
384
+ self.sql_create_statements.append(
385
+ SQL(
386
+ "CREATE TABLE IF NOT EXISTS {schema}.{table} ("
387
+ "originator_id {originator_id_type} NOT NULL, "
388
+ "originator_version bigint NOT NULL, "
389
+ "topic text, "
390
+ "state bytea, "
391
+ "PRIMARY KEY "
392
+ "(originator_id, originator_version)) "
393
+ "WITH ("
394
+ " autovacuum_enabled = true,"
395
+ " autovacuum_vacuum_threshold = 100000000,"
396
+ " autovacuum_vacuum_scale_factor = 0.5,"
397
+ " autovacuum_analyze_threshold = 1000,"
398
+ " autovacuum_analyze_scale_factor = 0.01"
399
+ ")"
400
+ ).format(
401
+ schema=Identifier(self.datastore.schema),
402
+ table=Identifier(self.events_table_name),
403
+ originator_id_type=Identifier(self.datastore.originator_id_type),
404
+ )
405
+ )
406
+
407
+ self.insert_events_statement = SQL(
408
+ " INSERT INTO {schema}.{table} AS t ("
409
+ " originator_id, originator_version, topic, state)"
410
+ " SELECT originator_id, originator_version, topic, state"
411
+ " FROM unnest(%s::{schema}.{stored_event_type}[])"
412
+ ).format(
413
+ schema=Identifier(self.datastore.schema),
414
+ table=Identifier(self.events_table_name),
415
+ stored_event_type=Identifier(self.stored_event_type_name),
416
+ )
417
+
418
+ self.select_events_statement = SQL(
419
+ "SELECT * FROM {schema}.{table} WHERE originator_id = %s"
420
+ ).format(
421
+ schema=Identifier(self.datastore.schema),
422
+ table=Identifier(self.events_table_name),
423
+ )
424
+
425
+ self.lock_table_statements: list[Query] = []
426
+
427
+ self.sql_create_statements.append(
428
+ SQL(
429
+ "CREATE TYPE {schema}.{name} "
430
+ "AS (originator_id {originator_id_type}, "
431
+ "originator_version bigint, "
432
+ "topic text, "
433
+ "state bytea)"
434
+ ).format(
435
+ schema=Identifier(self.datastore.schema),
436
+ name=Identifier(self.stored_event_type_name),
437
+ originator_id_type=Identifier(self.datastore.originator_id_type),
438
+ )
439
+ )
440
+
441
+ def construct_pg_stored_event(
442
+ self,
443
+ originator_id: UUID | str,
444
+ originator_version: int,
445
+ topic: str,
446
+ state: bytes,
447
+ ) -> PgStoredEvent:
448
+ try:
449
+ return self.datastore.psycopg_python_types[self.stored_event_type_name](
450
+ originator_id, originator_version, topic, state
451
+ )
452
+ except KeyError:
453
+ msg = f"Composite type '{self.stored_event_type_name}' not found"
454
+ raise ProgrammingError(msg) from None
455
+
456
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
457
+ def insert_events(
458
+ self, stored_events: Sequence[StoredEvent], **kwargs: Any
459
+ ) -> Sequence[int] | None:
460
+ # Only do something if there is something to do.
461
+ if len(stored_events) > 0:
462
+ with self.datastore.get_connection() as conn, conn.cursor() as curs:
463
+ assert conn.autocommit
464
+ self._insert_stored_events(curs, stored_events, **kwargs)
465
+ return None
466
+
467
+ def _insert_stored_events(
468
+ self,
469
+ curs: Cursor[DictRow],
470
+ stored_events: Sequence[StoredEvent],
471
+ **_: Any,
472
+ ) -> None:
473
+ # Construct composite type.
474
+ pg_stored_events = [
475
+ self.construct_pg_stored_event(
476
+ stored_event.originator_id,
477
+ stored_event.originator_version,
478
+ stored_event.topic,
479
+ stored_event.state,
480
+ )
481
+ for stored_event in stored_events
482
+ ]
483
+ # Insert events.
484
+ curs.execute(
485
+ query=self.insert_events_statement,
486
+ params=(pg_stored_events,),
487
+ )
488
+
489
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
490
+ def select_events(
491
+ self,
492
+ originator_id: UUID | str,
493
+ *,
494
+ gt: int | None = None,
495
+ lte: int | None = None,
496
+ desc: bool = False,
497
+ limit: int | None = None,
498
+ ) -> Sequence[StoredEvent]:
499
+ statement = self.select_events_statement
500
+ params: list[Any] = [originator_id]
501
+ if gt is not None:
502
+ params.append(gt)
503
+ statement += SQL(" AND originator_version > %s")
504
+ if lte is not None:
505
+ params.append(lte)
506
+ statement += SQL(" AND originator_version <= %s")
507
+ statement += SQL(" ORDER BY originator_version")
508
+ if desc is False:
509
+ statement += SQL(" ASC")
510
+ else:
511
+ statement += SQL(" DESC")
512
+ if limit is not None:
513
+ params.append(limit)
514
+ statement += SQL(" LIMIT %s")
515
+
516
+ with self.datastore.get_connection() as conn, conn.cursor() as curs:
517
+ curs.execute(statement, params, prepare=True)
518
+ return [
519
+ StoredEvent(
520
+ originator_id=row["originator_id"],
521
+ originator_version=row["originator_version"],
522
+ topic=row["topic"],
523
+ state=bytes(row["state"]),
524
+ )
525
+ for row in curs.fetchall()
526
+ ]
527
+
528
+
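A minimal usage sketch (not part of the packaged module) may help to show how the aggregate recorder above is intended to be used; the connection settings below are assumptions for a local database.

from uuid import uuid4

from eventsourcing.persistence import StoredEvent
from eventsourcing.postgres import PostgresAggregateRecorder, PostgresDatastore

# Assumed local connection details.
datastore = PostgresDatastore(
    dbname="eventsourcing",
    host="127.0.0.1",
    port="5432",
    user="eventsourcing",
    password="eventsourcing",
)
recorder = PostgresAggregateRecorder(datastore, events_table_name="stored_events")
recorder.create_table()

# Insert two events for one aggregate, then read back those above version 1.
originator_id = uuid4()
recorder.insert_events(
    [
        StoredEvent(originator_id=originator_id, originator_version=1,
                    topic="example:Created", state=b"{}"),
        StoredEvent(originator_id=originator_id, originator_version=2,
                    topic="example:Updated", state=b"{}"),
    ]
)
events = recorder.select_events(originator_id, gt=1)
assert len(events) == 1 and events[0].originator_version == 2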
529
+ class PostgresApplicationRecorder(PostgresAggregateRecorder, ApplicationRecorder):
530
+ def __init__(
531
+ self,
532
+ datastore: PostgresDatastore,
533
+ *,
534
+ events_table_name: str = "stored_events",
535
+ ):
536
+ super().__init__(datastore, events_table_name=events_table_name)
537
+ self.sql_create_statements[self.create_table_statement_index] = SQL(
538
+ "CREATE TABLE IF NOT EXISTS {schema}.{table} ("
539
+ "originator_id {originator_id_type} NOT NULL, "
540
+ "originator_version bigint NOT NULL, "
541
+ "topic text, "
542
+ "state bytea, "
543
+ "notification_id bigserial, "
544
+ "PRIMARY KEY "
545
+ "(originator_id, originator_version)) "
546
+ "WITH ("
547
+ " autovacuum_enabled = true,"
548
+ " autovacuum_vacuum_threshold = 100000000,"
549
+ " autovacuum_vacuum_scale_factor = 0.5,"
550
+ " autovacuum_analyze_threshold = 1000,"
551
+ " autovacuum_analyze_scale_factor = 0.01"
552
+ ")"
553
+ ).format(
554
+ schema=Identifier(self.datastore.schema),
555
+ table=Identifier(self.events_table_name),
556
+ originator_id_type=Identifier(self.datastore.originator_id_type),
557
+ )
558
+
559
+ self.sql_create_statements.append(
560
+ SQL(
561
+ "CREATE UNIQUE INDEX IF NOT EXISTS {index} "
562
+ "ON {schema}.{table} (notification_id ASC);"
563
+ ).format(
564
+ index=Identifier(self.notification_id_index_name),
565
+ schema=Identifier(self.datastore.schema),
566
+ table=Identifier(self.events_table_name),
567
+ )
568
+ )
569
+
570
+ self.channel_name = self.events_table_name.replace(".", "_")
571
+ self.insert_events_statement += SQL(" RETURNING notification_id")
572
+
573
+ self.max_notification_id_statement = SQL(
574
+ "SELECT MAX(notification_id) FROM {schema}.{table}"
575
+ ).format(
576
+ schema=Identifier(self.datastore.schema),
577
+ table=Identifier(self.events_table_name),
578
+ )
579
+
580
+ self.lock_table_statements = [
581
+ SQL("SET LOCAL lock_timeout = '{0}s'").format(self.datastore.lock_timeout),
582
+ SQL("LOCK TABLE {0}.{1} IN EXCLUSIVE MODE").format(
583
+ Identifier(self.datastore.schema),
584
+ Identifier(self.events_table_name),
585
+ ),
586
+ ]
587
+
588
+ self.pg_function_name_insert_events = (
589
+ f"es_insert_events_{self.datastore.originator_id_type}"
590
+ )
591
+ self.sql_invoke_pg_function_insert_events = SQL(
592
+ "SELECT * FROM {insert_events}((%s))"
593
+ ).format(insert_events=Identifier(self.pg_function_name_insert_events))
594
+
595
+ self.sql_create_pg_function_insert_events = SQL(
596
+ "CREATE OR REPLACE FUNCTION {insert_events}(events {schema}.{event}[]) "
597
+ "RETURNS SETOF bigint "
598
+ "LANGUAGE plpgsql "
599
+ "AS "
600
+ "$BODY$"
601
+ "BEGIN"
602
+ " SET LOCAL lock_timeout = '{lock_timeout}s';"
603
+ " NOTIFY {channel};"
604
+ " RETURN QUERY"
605
+ " INSERT INTO {schema}.{table} AS t ("
606
+ " originator_id, originator_version, topic, state)"
607
+ " SELECT originator_id, originator_version, topic, state"
608
+ " FROM unnest(events)"
609
+ " RETURNING notification_id;"
610
+ "END;"
611
+ "$BODY$"
612
+ ).format(
613
+ insert_events=Identifier(self.pg_function_name_insert_events),
614
+ lock_timeout=self.datastore.lock_timeout,
615
+ channel=Identifier(self.channel_name),
616
+ event=Identifier(self.stored_event_type_name),
617
+ schema=Identifier(self.datastore.schema),
618
+ table=Identifier(self.events_table_name),
619
+ )
620
+ self.create_insert_function_statement_index = len(self.sql_create_statements)
621
+ self.sql_create_statements.append(self.sql_create_pg_function_insert_events)
622
+
623
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
624
+ def insert_events(
625
+ self, stored_events: Sequence[StoredEvent], **kwargs: Any
626
+ ) -> Sequence[int] | None:
627
+ if self.datastore.enable_db_functions:
628
+ pg_stored_events = [
629
+ self.construct_pg_stored_event(
630
+ originator_id=e.originator_id,
631
+ originator_version=e.originator_version,
632
+ topic=e.topic,
633
+ state=e.state,
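As a hedged sketch (not part of the packaged module), the pool's get_password_func hook can be used to refresh a short-lived credential before each new connection is made; the token helper and connection settings below are hypothetical.

from psycopg import Connection
from psycopg.rows import dict_row

from eventsourcing.postgres import ConnectionPool


def fetch_database_token() -> str:
    # Hypothetical source of a rotating password or token.
    return "example-short-lived-token"


pool = ConnectionPool(
    connection_class=Connection,
    kwargs={
        "dbname": "eventsourcing",
        "host": "127.0.0.1",
        "port": "5432",
        "user": "eventsourcing",
        "password": "",  # overwritten by get_password_func in _connect()
        "row_factory": dict_row,
    },
    min_size=1,
    max_size=2,
    open=False,
    get_password_func=fetch_database_token,
)
pool.open(wait=True, timeout=30.0)
with pool.connection() as conn:
    conn.execute("SELECT 1")
pool.close()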
634
+ )
635
+ for e in stored_events
636
+ ]
637
+ with self.datastore.get_connection() as conn, conn.cursor() as curs:
638
+ curs.execute(
639
+ self.sql_invoke_pg_function_insert_events,
640
+ (pg_stored_events,),
641
+ prepare=True,
642
+ )
643
+ return [r[self.pg_function_name_insert_events] for r in curs.fetchall()]
644
+
645
+ exc: Exception | None = None
646
+ notification_ids: Sequence[int] | None = None
647
+ with self.datastore.get_connection() as conn:
648
+ with conn.pipeline() as pipeline, conn.transaction():
649
+ # Do other things first, so they can be pipelined too.
650
+ with conn.cursor() as curs:
651
+ self._insert_events(curs, stored_events, **kwargs)
652
+ # Then use a different cursor for the bulk insert of stored events.
653
+ if len(stored_events) > 0:
654
+ with conn.cursor() as curs:
655
+ try:
656
+ self._insert_stored_events(curs, stored_events, **kwargs)
657
+ # Sync now, so any uniqueness constraint violation causes an
658
+ # IntegrityError to be raised here, rather than an InternalError
659
+ # being raised sometime later e.g. when commit() is called.
660
+ pipeline.sync()
661
+ notification_ids = self._fetch_ids_after_insert_events(
662
+ curs, stored_events, **kwargs
663
+ )
664
+ except Exception as e:
665
+ # Avoid psycopg emitting a pipeline warning.
666
+ exc = e
667
+ if exc:
668
+ # Reraise exception after pipeline context manager has exited.
669
+ raise exc
670
+ return notification_ids
671
+
672
+ def _insert_events(
673
+ self,
674
+ curs: Cursor[DictRow],
675
+ stored_events: Sequence[StoredEvent],
676
+ **_: Any,
677
+ ) -> None:
678
+ pass
679
+
680
+ def _insert_stored_events(
681
+ self,
682
+ curs: Cursor[DictRow],
683
+ stored_events: Sequence[StoredEvent],
684
+ **kwargs: Any,
685
+ ) -> None:
686
+ self._lock_table(curs)
687
+ self._notify_channel(curs)
688
+ super()._insert_stored_events(curs, stored_events, **kwargs)
689
+
690
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
691
+ def select_notifications(
692
+ self,
693
+ start: int | None,
694
+ limit: int,
695
+ stop: int | None = None,
696
+ topics: Sequence[str] = (),
697
+ *,
698
+ inclusive_of_start: bool = True,
699
+ ) -> Sequence[Notification]:
700
+ """Returns a list of event notifications
701
+ from 'start', limited by 'limit'.
702
+ """
703
+ params: list[int | str | Sequence[str]] = []
704
+ statement = SQL("SELECT * FROM {schema}.{table}").format(
705
+ schema=Identifier(self.datastore.schema),
706
+ table=Identifier(self.events_table_name),
707
+ )
708
+ has_where = False
709
+ if start is not None:
710
+ statement += SQL(" WHERE")
711
+ has_where = True
712
+ params.append(start)
713
+ if inclusive_of_start:
714
+ statement += SQL(" notification_id>=%s")
715
+ else:
716
+ statement += SQL(" notification_id>%s")
717
+
718
+ if stop is not None:
719
+ if not has_where:
720
+ has_where = True
721
+ statement += SQL(" WHERE")
722
+ else:
723
+ statement += SQL(" AND")
724
+
725
+ params.append(stop)
726
+ statement += SQL(" notification_id <= %s")
727
+
728
+ if topics:
729
+ # Check sequence and ensure list of strings.
730
+ assert isinstance(topics, (tuple, list)), topics
731
+ topics = list(topics) if isinstance(topics, tuple) else topics
732
+ assert all(isinstance(t, str) for t in topics), topics
733
+ if not has_where:
734
+ statement += SQL(" WHERE")
735
+ else:
736
+ statement += SQL(" AND")
737
+ params.append(topics)
738
+ statement += SQL(" topic = ANY(%s)")
739
+
740
+ params.append(limit)
741
+ statement += SQL(" ORDER BY notification_id LIMIT %s")
742
+
743
+ connection = self.datastore.get_connection()
744
+ with connection as conn, conn.cursor() as curs:
745
+ curs.execute(statement, params, prepare=True)
746
+ return [
747
+ Notification(
748
+ id=row["notification_id"],
749
+ originator_id=row["originator_id"],
750
+ originator_version=row["originator_version"],
751
+ topic=row["topic"],
752
+ state=bytes(row["state"]),
753
+ )
754
+ for row in curs.fetchall()
755
+ ]
756
+
757
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
758
+ def max_notification_id(self) -> int | None:
759
+ """Returns the maximum notification ID."""
760
+ with self.datastore.get_connection() as conn, conn.cursor() as curs:
761
+ curs.execute(self.max_notification_id_statement)
762
+ fetchone = curs.fetchone()
763
+ assert fetchone is not None
764
+ return fetchone["max"]
765
+
766
+ def _lock_table(self, curs: Cursor[DictRow]) -> None:
767
+ # Acquire "EXCLUSIVE" table lock, to serialize transactions that insert
768
+ # stored events, so that readers don't pass over gaps that are filled in
769
+ # later. We want each transaction that will be issued with notification
770
+ # IDs by the notification ID sequence to receive all its notification IDs
771
+ # and then commit, before another transaction is issued with any notification
772
+ # IDs. In other words, we want the insert order to be the same as the commit
773
+ # order. We can accomplish this by locking the table for writes. The
774
+ # EXCLUSIVE lock mode does not block SELECT statements, which acquire an
775
+ # ACCESS SHARE lock, so the stored events table can be read concurrently
776
+ # with writes and other reads. However, INSERT statements normally just
777
+ # acquire ROW EXCLUSIVE locks, which risks the interleaving (within the
778
+ # recorded sequence of notification IDs) of stored events from one transaction
779
+ # with those of another transaction. And since one transaction will always
780
+ # commit before another, using ROW EXCLUSIVE locks creates the possibility
781
+ # that readers tailing a notification log miss items that are inserted later
782
+ # but issued with lower notification IDs.
783
+ # https://www.postgresql.org/docs/current/explicit-locking.html#LOCKING-TABLES
784
+ # https://www.postgresql.org/docs/9.1/sql-lock.html
785
+ # https://stackoverflow.com/questions/45866187/guarantee-monotonicity-of
786
+ # -postgresql-serial-column-values-by-commit-order
787
+ for lock_statement in self.lock_table_statements:
788
+ curs.execute(lock_statement, prepare=True)
789
+
790
+ def _notify_channel(self, curs: Cursor[DictRow]) -> None:
791
+ curs.execute(SQL("NOTIFY {0}").format(Identifier(self.channel_name)))
792
+
793
+ def _fetch_ids_after_insert_events(
794
+ self,
795
+ curs: Cursor[DictRow],
796
+ stored_events: Sequence[StoredEvent],
797
+ **_: Any,
798
+ ) -> Sequence[int] | None:
799
+ notification_ids: list[int] = []
800
+ assert curs.statusmessage and curs.statusmessage.startswith(
801
+ "INSERT"
802
+ ), curs.statusmessage
803
+ try:
804
+ notification_ids = [row["notification_id"] for row in curs.fetchall()]
805
+ except psycopg.ProgrammingError as e:
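The datastore above is also a context manager; here is a small sketch (with assumed connection details, not part of the packaged module) of its cursor and transaction helpers.

from eventsourcing.postgres import PostgresDatastore

with PostgresDatastore(
    dbname="eventsourcing",
    host="127.0.0.1",
    port="5432",
    user="eventsourcing",
    password="eventsourcing",
) as datastore:
    # Rows come back as dicts because the pool is configured with dict_row.
    with datastore.cursor() as curs:
        curs.execute("SELECT 1 AS answer")
        row = curs.fetchone()
        assert row is not None and row["answer"] == 1
    # transaction() rolls back unless commit=True is passed.
    with datastore.transaction(commit=False) as curs:
        curs.execute("SELECT now()")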
806
+ msg = "Couldn't get all notification IDs "
807
+ msg += f"(got {len(notification_ids)}, expected {len(stored_events)})"
808
+ raise ProgrammingError(msg) from e
809
+ return notification_ids
810
+
811
+ def subscribe(
812
+ self, gt: int | None = None, topics: Sequence[str] = ()
813
+ ) -> Subscription[ApplicationRecorder]:
814
+ return PostgresSubscription(recorder=self, gt=gt, topics=topics)
815
+
816
+
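A short sketch (not part of the packaged module) of reading the notification log with the application recorder above, assuming a configured PostgresDatastore named datastore as in the earlier sketches:

recorder = PostgresApplicationRecorder(datastore, events_table_name="stored_events")
recorder.create_table()

last_id = recorder.max_notification_id()  # None while the table is empty
notifications = recorder.select_notifications(start=1, limit=10)
for notification in notifications:
    print(notification.id, notification.topic)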
817
+ class PostgresSubscription(ListenNotifySubscription[PostgresApplicationRecorder]):
818
+ def __init__(
819
+ self,
820
+ recorder: PostgresApplicationRecorder,
821
+ gt: int | None = None,
822
+ topics: Sequence[str] = (),
823
+ ) -> None:
824
+ assert isinstance(recorder, PostgresApplicationRecorder)
825
+ super().__init__(recorder=recorder, gt=gt, topics=topics)
826
+ self._listen_thread = Thread(target=self._listen)
827
+ self._listen_thread.start()
828
+
829
+ def __exit__(self, *args: object, **kwargs: Any) -> None:
830
+ try:
831
+ super().__exit__(*args, **kwargs)
832
+ finally:
833
+ self._listen_thread.join()
834
+
835
+ def _listen(self) -> None:
836
+ try:
837
+ with self._recorder.datastore.get_connection() as conn:
838
+ conn.execute(
839
+ SQL("LISTEN {0}").format(Identifier(self._recorder.channel_name))
840
+ )
841
+ while not self._has_been_stopped and not self._thread_error:
842
+ # This block simplifies psycopg's conn.notifies(), because
843
+ # we aren't interested in the actual notify messages, and
844
+ # also we want to stop consuming notify messages when the
845
+ # subscription has an error or is otherwise stopped.
846
+ with conn.lock:
847
+ try:
848
+ if conn.wait(notifies(conn.pgconn), interval=0.1):
849
+ self._has_been_notified.set()
850
+ except NO_TRACEBACK as ex: # pragma: no cover
851
+ raise ex.with_traceback(None) from None
852
+
853
+ except BaseException as e:
854
+ if self._thread_error is None:
855
+ self._thread_error = e
856
+ self.stop()
857
+
858
+
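A hedged sketch of consuming the subscription returned by subscribe() above; this assumes the base Subscription class is an iterable context manager yielding Notification objects, and that recorder is the application recorder from the previous sketch.

with recorder.subscribe(gt=recorder.max_notification_id()) as subscription:
    for notification in subscription:
        print(notification.id, notification.topic)
        break  # stop after the first new notification in this sketch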
859
+ class PostgresTrackingRecorder(PostgresRecorder, TrackingRecorder):
860
+ def __init__(
861
+ self,
862
+ datastore: PostgresDatastore,
863
+ *,
864
+ tracking_table_name: str = "notification_tracking",
865
+ **kwargs: Any,
866
+ ):
867
+ super().__init__(datastore, **kwargs)
868
+ self.check_identifier_length(tracking_table_name)
869
+ self.tracking_table_name = tracking_table_name
870
+ self.tracking_table_exists: bool = False
871
+ self.tracking_migration_previous: int | None = None
872
+ self.tracking_migration_current: int | None = None
873
+ self.table_migration_identifier = "__migration__"
874
+ self.has_checked_for_multi_row_tracking_table: bool = False
875
+ if self.datastore.single_row_tracking:
876
+ # For single-row tracking.
877
+ self.sql_create_statements.append(
878
+ SQL(
879
+ "CREATE TABLE IF NOT EXISTS {0}.{1} ("
880
+ "application_name text, "
881
+ "notification_id bigint, "
882
+ "PRIMARY KEY "
883
+ "(application_name))"
884
+ "WITH ("
885
+ " autovacuum_enabled = true,"
886
+ " autovacuum_vacuum_threshold = 100000000,"
887
+ " autovacuum_vacuum_scale_factor = 0.5,"
888
+ " autovacuum_analyze_threshold = 1000,"
889
+ " autovacuum_analyze_scale_factor = 0.01"
890
+ ")"
891
+ ).format(
892
+ Identifier(self.datastore.schema),
893
+ Identifier(self.tracking_table_name),
894
+ )
895
+ )
896
+ self.insert_tracking_statement = SQL(
897
+ "INSERT INTO {0}.{1} "
898
+ "VALUES (%(application_name)s, %(notification_id)s) "
899
+ "ON CONFLICT (application_name) DO UPDATE "
900
+ "SET notification_id = %(notification_id)s "
901
+ "WHERE {0}.{1}.notification_id < %(notification_id)s "
902
+ "RETURNING notification_id"
903
+ ).format(
904
+ Identifier(self.datastore.schema),
905
+ Identifier(self.tracking_table_name),
906
+ )
907
+ else:
908
+ # For legacy multi-row tracking.
909
+ self.sql_create_statements.append(
910
+ SQL(
911
+ "CREATE TABLE IF NOT EXISTS {0}.{1} ("
912
+ "application_name text, "
913
+ "notification_id bigint, "
914
+ "PRIMARY KEY "
915
+ "(application_name, notification_id))"
916
+ "WITH ("
917
+ " autovacuum_enabled = true,"
918
+ " autovacuum_vacuum_threshold = 100000000,"
919
+ " autovacuum_vacuum_scale_factor = 0.5,"
920
+ " autovacuum_analyze_threshold = 1000,"
921
+ " autovacuum_analyze_scale_factor = 0.01"
922
+ ")"
923
+ ).format(
924
+ Identifier(self.datastore.schema),
925
+ Identifier(self.tracking_table_name),
926
+ )
927
+ )
928
+ self.insert_tracking_statement = SQL(
929
+ "INSERT INTO {0}.{1} VALUES (%(application_name)s, %(notification_id)s)"
930
+ ).format(
931
+ Identifier(self.datastore.schema),
932
+ Identifier(self.tracking_table_name),
933
+ )
934
+
935
+ self.max_tracking_id_statement = SQL(
936
+ "SELECT MAX(notification_id) FROM {0}.{1} WHERE application_name=%s"
937
+ ).format(
938
+ Identifier(self.datastore.schema),
939
+ Identifier(self.tracking_table_name),
940
+ )
941
+
942
+ def create_table(self) -> None:
943
+ # Get the migration version.
944
+ try:
945
+ self.tracking_migration_current = self.tracking_migration_previous = (
946
+ self.max_tracking_id(self.table_migration_identifier)
947
+ )
948
+ except ProgrammingError:
949
+ pass
950
+ else:
951
+ self.tracking_table_exists = True
952
+ super().create_table()
953
+ if (
954
+ not self.datastore.single_row_tracking
955
+ and self.tracking_migration_current is not None
956
+ ):
957
+ msg = "Can't do multi-row tracking with single-row tracking table"
958
+ raise OperationalError(msg)
959
+
960
+ def _create_table(self, curs: Cursor[DictRow]) -> None:
961
+ max_tracking_ids: dict[str, int] = {}
962
+ if (
963
+ self.datastore.single_row_tracking
964
+ and self.tracking_table_exists
965
+ and not self.tracking_migration_previous
966
+ ):
967
+ # Migrate the table.
968
+ curs.execute(
969
+ SQL("SET LOCAL lock_timeout = '{0}s'").format(
970
+ self.datastore.lock_timeout
971
+ )
972
+ )
973
+ curs.execute(
974
+ SQL("LOCK TABLE {0}.{1} IN ACCESS EXCLUSIVE MODE").format(
975
+ Identifier(self.datastore.schema),
976
+ Identifier(self.tracking_table_name),
977
+ )
978
+ )
979
+
980
+ # Get all application names.
981
+ application_names: list[str] = [
982
+ select_row["application_name"]
983
+ for select_row in curs.execute(
984
+ SQL("SELECT DISTINCT application_name FROM {0}.{1}").format(
985
+ Identifier(self.datastore.schema),
986
+ Identifier(self.tracking_table_name),
987
+ )
988
+ )
989
+ ]
990
+
991
+ # Get max tracking ID for each application name.
992
+ for application_name in application_names:
993
+ curs.execute(self.max_tracking_id_statement, (application_name,))
994
+ max_tracking_id_row = curs.fetchone()
995
+ assert max_tracking_id_row is not None
996
+ max_tracking_ids[application_name] = max_tracking_id_row["max"]
997
+ # Rename the table.
998
+ rename = f"bkup1_{self.tracking_table_name}"[: self.MAX_IDENTIFIER_LEN]
999
+ rename_table_statement = SQL("ALTER TABLE {0}.{1} RENAME TO {2}").format(
1000
+ Identifier(self.datastore.schema),
1001
+ Identifier(self.tracking_table_name),
1002
+ Identifier(rename),
1003
+ )
1004
+ curs.execute(rename_table_statement)
1005
+ # Create the table.
1006
+ super()._create_table(curs)
1007
+ # Maybe insert migration tracking record and application tracking records.
1008
+ if self.datastore.single_row_tracking and (
1009
+ not self.tracking_table_exists
1010
+ or (self.tracking_table_exists and not self.tracking_migration_previous)
1011
+ ):
1012
+ # Assume we just created a table for single-row tracking.
1013
+ self._insert_tracking(curs, Tracking(self.table_migration_identifier, 1))
1014
+ self.tracking_migration_current = 1
1015
+ for application_name, max_tracking_id in max_tracking_ids.items():
1016
+ self._insert_tracking(curs, Tracking(application_name, max_tracking_id))
1017
+
1018
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
1019
+ def insert_tracking(self, tracking: Tracking) -> None:
1020
+ with self.datastore.transaction(commit=True) as curs:
1021
+ self._insert_tracking(curs, tracking)
1022
+
1023
+ def _insert_tracking(
1024
+ self,
1025
+ curs: Cursor[DictRow],
1026
+ tracking: Tracking,
1027
+ ) -> None:
1028
+ self._check_has_multi_row_tracking_table(curs)
1029
+
1030
+ curs.execute(
1031
+ query=self.insert_tracking_statement,
1032
+ params={
1033
+ "application_name": tracking.application_name,
1034
+ "notification_id": tracking.notification_id,
1035
+ },
1036
+ prepare=True,
1037
+ )
1038
+ if self.datastore.single_row_tracking:
1039
+ fetchone = curs.fetchone()
1040
+ if fetchone is None:
1041
+ msg = (
1042
+ "Failed to record tracking for "
1043
+ f"{tracking.application_name} {tracking.notification_id}"
1044
+ )
1045
+ raise IntegrityError(msg)
1046
+
1047
+ def _check_has_multi_row_tracking_table(self, c: Cursor[DictRow]) -> None:
1048
+ if (
1049
+ not self.datastore.single_row_tracking
1050
+ and not self.has_checked_for_multi_row_tracking_table
1051
+ and self._max_tracking_id(self.table_migration_identifier, c)
1052
+ ):
1053
+ msg = "Can't do multi-row tracking with single-row tracking table"
1054
+ raise ProgrammingError(msg)
1055
+ self.has_checked_for_multi_row_tracking_table = True
1056
+
1057
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
1058
+ def max_tracking_id(self, application_name: str) -> int | None:
1059
+ with self.datastore.get_connection() as conn, conn.cursor() as curs:
1060
+ return self._max_tracking_id(application_name, curs)
1061
+
1062
+ def _max_tracking_id(
1063
+ self, application_name: str, curs: Cursor[DictRow]
1064
+ ) -> int | None:
1065
+ curs.execute(
1066
+ query=self.max_tracking_id_statement,
1067
+ params=(application_name,),
1068
+ prepare=True,
1069
+ )
1070
+ fetchone = curs.fetchone()
1071
+ assert fetchone is not None
1072
+ return fetchone["max"]
1073
+
1074
+ @retry((InterfaceError, OperationalError), max_attempts=10, wait=0.2)
1075
+ def has_tracking_id(
1076
+ self, application_name: str, notification_id: int | None
1077
+ ) -> bool:
1078
+ return super().has_tracking_id(application_name, notification_id)
1079
+
1080
+
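A minimal sketch (not part of the packaged module) of recording and querying a follower's position with the tracking recorder above, assuming a configured PostgresDatastore named datastore:

from eventsourcing.persistence import Tracking

tracking_recorder = PostgresTrackingRecorder(datastore)
tracking_recorder.create_table()

tracking_recorder.insert_tracking(
    Tracking(application_name="upstream_app", notification_id=42)
)
assert tracking_recorder.max_tracking_id("upstream_app") == 42
assert tracking_recorder.has_tracking_id("upstream_app", 42)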
1081
+ TPostgresTrackingRecorder = TypeVar(
1082
+ "TPostgresTrackingRecorder",
1083
+ bound=PostgresTrackingRecorder,
1084
+ default=PostgresTrackingRecorder,
1085
+ )
1086
+
1087
+
1088
+ class PostgresProcessRecorder(
1089
+ PostgresTrackingRecorder, PostgresApplicationRecorder, ProcessRecorder
1090
+ ):
1091
+ def __init__(
1092
+ self,
1093
+ datastore: PostgresDatastore,
1094
+ *,
1095
+ events_table_name: str = "stored_events",
1096
+ tracking_table_name: str = "notification_tracking",
1097
+ ):
1098
+ super().__init__(
1099
+ datastore,
1100
+ tracking_table_name=tracking_table_name,
1101
+ events_table_name=events_table_name,
1102
+ )
1103
+
1104
+ def _insert_events(
1105
+ self,
1106
+ curs: Cursor[DictRow],
1107
+ stored_events: Sequence[StoredEvent],
1108
+ **kwargs: Any,
1109
+ ) -> None:
1110
+ tracking: Tracking | None = kwargs.get("tracking")
1111
+ if tracking is not None:
1112
+ self._insert_tracking(curs, tracking=tracking)
1113
+ super()._insert_events(curs, stored_events, **kwargs)
1114
+
1115
+
1116
+ class BasePostgresFactory(BaseInfrastructureFactory[TTrackingRecorder]):
1117
+ POSTGRES_DBNAME = "POSTGRES_DBNAME"
1118
+ POSTGRES_HOST = "POSTGRES_HOST"
1119
+ POSTGRES_PORT = "POSTGRES_PORT"
1120
+ POSTGRES_USER = "POSTGRES_USER"
1121
+ POSTGRES_PASSWORD = "POSTGRES_PASSWORD" # noqa: S105
1122
+ POSTGRES_GET_PASSWORD_TOPIC = "POSTGRES_GET_PASSWORD_TOPIC" # noqa: S105
1123
+ POSTGRES_CONNECT_TIMEOUT = "POSTGRES_CONNECT_TIMEOUT"
1124
+ POSTGRES_CONN_MAX_AGE = "POSTGRES_CONN_MAX_AGE"
1125
+ POSTGRES_PRE_PING = "POSTGRES_PRE_PING"
1126
+ POSTGRES_MAX_WAITING = "POSTGRES_MAX_WAITING"
1127
+ POSTGRES_LOCK_TIMEOUT = "POSTGRES_LOCK_TIMEOUT"
1128
+ POSTGRES_POOL_SIZE = "POSTGRES_POOL_SIZE"
1129
+ POSTGRES_MAX_OVERFLOW = "POSTGRES_MAX_OVERFLOW"
1130
+ POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT = (
1131
+ "POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT"
1132
+ )
1133
+ POSTGRES_SCHEMA = "POSTGRES_SCHEMA"
1134
+ POSTGRES_SINGLE_ROW_TRACKING = "SINGLE_ROW_TRACKING"
1135
+ ORIGINATOR_ID_TYPE = "ORIGINATOR_ID_TYPE"
1136
+ POSTGRES_ENABLE_DB_FUNCTIONS = "POSTGRES_ENABLE_DB_FUNCTIONS"
1137
+ CREATE_TABLE = "CREATE_TABLE"
1138
+
1139
+ def __init__(self, env: Environment | EnvType | None):
1140
+ super().__init__(env)
1141
+ dbname = self.env.get(self.POSTGRES_DBNAME)
1142
+ if dbname is None:
1143
+ msg = (
1144
+ "Postgres database name not found "
1145
+ "in environment with key "
1146
+ f"'{self.POSTGRES_DBNAME}'"
1147
+ )
1148
+ # TODO: Indicate both keys here, also for other environment variables.
1149
+ # ) + " or ".join(
1150
+ # [f"'{key}'" for key in self.env.create_keys(self.POSTGRES_DBNAME)]
1151
+ # )
1152
+ raise OSError(msg)
1153
+
1154
+ host = self.env.get(self.POSTGRES_HOST)
1155
+ if host is None:
1156
+ msg = (
1157
+ "Postgres host not found "
1158
+ "in environment with key "
1159
+ f"'{self.POSTGRES_HOST}'"
1160
+ )
1161
+ raise OSError(msg)
1162
+
1163
+ port = self.env.get(self.POSTGRES_PORT) or "5432"
1164
+
1165
+ user = self.env.get(self.POSTGRES_USER)
1166
+ if user is None:
1167
+ msg = (
1168
+ "Postgres user not found "
1169
+ "in environment with key "
1170
+ f"'{self.POSTGRES_USER}'"
1171
+ )
1172
+ raise OSError(msg)
1173
+
1174
+ get_password_func = None
1175
+ get_password_topic = self.env.get(self.POSTGRES_GET_PASSWORD_TOPIC)
1176
+ if not get_password_topic:
1177
+ password = self.env.get(self.POSTGRES_PASSWORD)
1178
+ if password is None:
1179
+ msg = (
1180
+ "Postgres password not found "
1181
+ "in environment with key "
1182
+ f"'{self.POSTGRES_PASSWORD}'"
1183
+ )
1184
+ raise OSError(msg)
1185
+ else:
1186
+ get_password_func = resolve_topic(get_password_topic)
1187
+ password = ""
1188
+
1189
+ connect_timeout = 30
1190
+ connect_timeout_str = self.env.get(self.POSTGRES_CONNECT_TIMEOUT)
1191
+ if connect_timeout_str:
1192
+ try:
1193
+ connect_timeout = int(connect_timeout_str)
1194
+ except ValueError:
1195
+ msg = (
1196
+ "Postgres environment value for key "
1197
+ f"'{self.POSTGRES_CONNECT_TIMEOUT}' is invalid. "
1198
+ "If set, an integer or empty string is expected: "
1199
+ f"'{connect_timeout_str}'"
1200
+ )
1201
+ raise OSError(msg) from None
1202
+
1203
+ idle_in_transaction_session_timeout_str = (
1204
+ self.env.get(self.POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT) or "5"
1205
+ )
1206
+
1207
+ try:
1208
+ idle_in_transaction_session_timeout = int(
1209
+ idle_in_transaction_session_timeout_str
1210
+ )
1211
+ except ValueError:
1212
+ msg = (
1213
+ "Postgres environment value for key "
1214
+ f"'{self.POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT}' is invalid. "
1215
+ "If set, an integer or empty string is expected: "
1216
+ f"'{idle_in_transaction_session_timeout_str}'"
1217
+ )
1218
+ raise OSError(msg) from None
1219
+
1220
+ pool_size = 5
1221
+ pool_size_str = self.env.get(self.POSTGRES_POOL_SIZE)
1222
+ if pool_size_str:
1223
+ try:
1224
+ pool_size = int(pool_size_str)
1225
+ except ValueError:
1226
+ msg = (
1227
+ "Postgres environment value for key "
1228
+ f"'{self.POSTGRES_POOL_SIZE}' is invalid. "
1229
+ "If set, an integer or empty string is expected: "
1230
+ f"'{pool_size_str}'"
1231
+ )
1232
+ raise OSError(msg) from None
1233
+
1234
+ pool_max_overflow = 10
1235
+ pool_max_overflow_str = self.env.get(self.POSTGRES_MAX_OVERFLOW)
1236
+ if pool_max_overflow_str:
1237
+ try:
1238
+ pool_max_overflow = int(pool_max_overflow_str)
1239
+ except ValueError:
1240
+ msg = (
1241
+ "Postgres environment value for key "
1242
+ f"'{self.POSTGRES_MAX_OVERFLOW}' is invalid. "
1243
+ "If set, an integer or empty string is expected: "
1244
+ f"'{pool_max_overflow_str}'"
1245
+ )
1246
+ raise OSError(msg) from None
1247
+
1248
+ max_waiting = 0
1249
+ max_waiting_str = self.env.get(self.POSTGRES_MAX_WAITING)
1250
+ if max_waiting_str:
1251
+ try:
1252
+ max_waiting = int(max_waiting_str)
1253
+ except ValueError:
1254
+ msg = (
1255
+ "Postgres environment value for key "
1256
+ f"'{self.POSTGRES_MAX_WAITING}' is invalid. "
1257
+ "If set, an integer or empty string is expected: "
1258
+ f"'{max_waiting_str}'"
1259
+ )
1260
+ raise OSError(msg) from None
1261
+
1262
+ conn_max_age = 60 * 60.0
1263
+ conn_max_age_str = self.env.get(self.POSTGRES_CONN_MAX_AGE)
1264
+ if conn_max_age_str:
1265
+ try:
1266
+ conn_max_age = float(conn_max_age_str)
1267
+ except ValueError:
1268
+ msg = (
1269
+ "Postgres environment value for key "
1270
+ f"'{self.POSTGRES_CONN_MAX_AGE}' is invalid. "
1271
+ "If set, a float or empty string is expected: "
1272
+ f"'{conn_max_age_str}'"
1273
+ )
1274
+ raise OSError(msg) from None
1275
+
1276
+ pre_ping = strtobool(self.env.get(self.POSTGRES_PRE_PING) or "no")
1277
+
1278
+ lock_timeout_str = self.env.get(self.POSTGRES_LOCK_TIMEOUT) or "0"
1279
+
1280
+ try:
1281
+ lock_timeout = int(lock_timeout_str)
1282
+ except ValueError:
1283
+ msg = (
1284
+ "Postgres environment value for key "
1285
+ f"'{self.POSTGRES_LOCK_TIMEOUT}' is invalid. "
1286
+ "If set, an integer or empty string is expected: "
1287
+ f"'{lock_timeout_str}'"
1288
+ )
1289
+ raise OSError(msg) from None
1290
+
1291
+ schema = self.env.get(self.POSTGRES_SCHEMA) or ""
1292
+
1293
+ single_row_tracking = strtobool(
1294
+ self.env.get(self.POSTGRES_SINGLE_ROW_TRACKING, "t")
1295
+ )
1296
+
1297
+ originator_id_type = cast(
1298
+ Literal["uuid", "text"],
1299
+ self.env.get(self.ORIGINATOR_ID_TYPE, "uuid"),
1300
+ )
1301
+ if originator_id_type.lower() not in ("uuid", "text"):
1302
+ msg = (
1303
+ f"Invalid {self.ORIGINATOR_ID_TYPE} '{originator_id_type}', "
1304
+ f"must be 'uuid' or 'text'"
1305
+ )
1306
+ raise OSError(msg)
1307
+
1308
+ enable_db_functions = strtobool(
1309
+ self.env.get(self.POSTGRES_ENABLE_DB_FUNCTIONS) or "no"
1310
+ )
1311
+
1312
+ self.datastore = PostgresDatastore(
1313
+ dbname=dbname,
1314
+ host=host,
1315
+ port=port,
1316
+ user=user,
1317
+ password=password,
1318
+ connect_timeout=connect_timeout,
1319
+ idle_in_transaction_session_timeout=idle_in_transaction_session_timeout,
1320
+ pool_size=pool_size,
1321
+ max_overflow=pool_max_overflow,
1322
+ max_waiting=max_waiting,
1323
+ conn_max_age=conn_max_age,
1324
+ pre_ping=pre_ping,
1325
+ lock_timeout=lock_timeout,
1326
+ schema=schema,
1327
+ get_password_func=get_password_func,
1328
+ single_row_tracking=single_row_tracking,
1329
+ originator_id_type=originator_id_type,
1330
+ enable_db_functions=enable_db_functions,
1331
+ )
1332
+
1333
+ def env_create_table(self) -> bool:
1334
+ return strtobool(self.env.get(self.CREATE_TABLE) or "yes")
1335
+
1336
+ def __enter__(self) -> Self:
1337
+ self.datastore.__enter__()
1338
+ return self
1339
+
1340
+ def __exit__(
1341
+ self,
1342
+ exc_type: type[BaseException] | None,
1343
+ exc_val: BaseException | None,
1344
+ exc_tb: TracebackType | None,
1345
+ ) -> None:
1346
+ self.datastore.__exit__(exc_type, exc_val, exc_tb)
1347
+
1348
+ def close(self) -> None:
1349
+ with contextlib.suppress(AttributeError):
1350
+ self.datastore.close()
1351
+
1352
+
1353
+ class PostgresFactory(
1354
+ BasePostgresFactory[PostgresTrackingRecorder],
1355
+ InfrastructureFactory[PostgresTrackingRecorder],
1356
+ ):
1357
+ aggregate_recorder_class = PostgresAggregateRecorder
1358
+ application_recorder_class = PostgresApplicationRecorder
1359
+ tracking_recorder_class = PostgresTrackingRecorder
1360
+ process_recorder_class = PostgresProcessRecorder
1361
+
1362
+ def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
1363
+ prefix = self.env.name.lower() or "stored"
1364
+ events_table_name = prefix + "_" + purpose
1365
+ recorder = type(self).aggregate_recorder_class(
1366
+ datastore=self.datastore,
1367
+ events_table_name=events_table_name,
1368
+ )
1369
+ if self.env_create_table():
1370
+ recorder.create_table()
1371
+ return recorder
1372
+
1373
+ def application_recorder(self) -> ApplicationRecorder:
1374
+ prefix = self.env.name.lower() or "stored"
1375
+ events_table_name = prefix + "_events"
1376
+ application_recorder_topic = self.env.get(self.APPLICATION_RECORDER_TOPIC)
1377
+ if application_recorder_topic:
1378
+ application_recorder_class: type[PostgresApplicationRecorder] = (
1379
+ resolve_topic(application_recorder_topic)
1380
+ )
1381
+ assert issubclass(application_recorder_class, PostgresApplicationRecorder)
1382
+ else:
1383
+ application_recorder_class = type(self).application_recorder_class
1384
+
1385
+ recorder = application_recorder_class(
1386
+ datastore=self.datastore,
1387
+ events_table_name=events_table_name,
1388
+ )
1389
+ if self.env_create_table():
1390
+ recorder.create_table()
1391
+ return recorder
1392
+
1393
+ def tracking_recorder(
1394
+ self, tracking_recorder_class: type[TPostgresTrackingRecorder] | None = None
1395
+ ) -> TPostgresTrackingRecorder:
1396
+ prefix = self.env.name.lower() or "notification"
1397
+ tracking_table_name = prefix + "_tracking"
1398
+ if tracking_recorder_class is None:
1399
+ tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
1400
+ if tracking_recorder_topic:
1401
+ tracking_recorder_class = resolve_topic(tracking_recorder_topic)
1402
+ else:
1403
+ tracking_recorder_class = cast(
1404
+ "type[TPostgresTrackingRecorder]",
1405
+ type(self).tracking_recorder_class,
1406
+ )
1407
+ assert tracking_recorder_class is not None
1408
+ assert issubclass(tracking_recorder_class, PostgresTrackingRecorder)
1409
+ recorder = tracking_recorder_class(
1410
+ datastore=self.datastore,
1411
+ tracking_table_name=tracking_table_name,
1412
+ )
1413
+ if self.env_create_table():
1414
+ recorder.create_table()
1415
+ return recorder
1416
+
1417
+ def process_recorder(self) -> ProcessRecorder:
1418
+ prefix = self.env.name.lower() or "stored"
1419
+ events_table_name = prefix + "_events"
1420
+ prefix = self.env.name.lower() or "notification"
1421
+ tracking_table_name = prefix + "_tracking"
1422
+ process_recorder_topic = self.env.get(self.PROCESS_RECORDER_TOPIC)
1423
+ if process_recorder_topic:
1424
+ process_recorder_class: type[PostgresTrackingRecorder] = resolve_topic(
1425
+ process_recorder_topic
1426
+ )
1427
+ assert issubclass(process_recorder_class, PostgresProcessRecorder)
1428
+ else:
1429
+ process_recorder_class = type(self).process_recorder_class
1430
+
1431
+ recorder = process_recorder_class(
1432
+ datastore=self.datastore,
1433
+ events_table_name=events_table_name,
1434
+ tracking_table_name=tracking_table_name,
1435
+ )
1436
+ if self.env_create_table():
1437
+ recorder.create_table()
1438
+ return recorder
1439
+
1440
+
1441
+ Factory = PostgresFactory
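
A closing sketch (not part of the packaged module) of constructing the factory from an environment mapping; the values are assumptions for a local database, and in normal use the same keys are typically supplied through an application's environment with PERSISTENCE_MODULE pointing at this module.

from eventsourcing.postgres import PostgresFactory

env = {
    "POSTGRES_DBNAME": "eventsourcing",
    "POSTGRES_HOST": "127.0.0.1",
    "POSTGRES_PORT": "5432",
    "POSTGRES_USER": "eventsourcing",
    "POSTGRES_PASSWORD": "eventsourcing",
    "POSTGRES_POOL_SIZE": "5",
    "CREATE_TABLE": "y",
}
factory = PostgresFactory(env)
application_recorder = factory.application_recorder()
process_recorder = factory.process_recorder()
factory.close()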