dbos 1.13.0a3__py3-none-any.whl → 1.13.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dbos might be problematic. Click here for more details.

dbos/_sys_db.py CHANGED
@@ -4,6 +4,7 @@ import json
4
4
  import random
5
5
  import threading
6
6
  import time
7
+ from abc import ABC, abstractmethod
7
8
  from enum import Enum
8
9
  from typing import (
9
10
  TYPE_CHECKING,
@@ -19,18 +20,15 @@ from typing import (
19
20
  cast,
20
21
  )
21
22
 
22
- import psycopg
23
23
  import sqlalchemy as sa
24
- import sqlalchemy.dialects.postgresql as pg
25
24
  from sqlalchemy.exc import DBAPIError
26
25
  from sqlalchemy.sql import func
27
26
 
28
- from dbos._migration import (
29
- ensure_dbos_schema,
30
- run_alembic_migrations,
31
- run_dbos_migrations,
27
+ from dbos._utils import (
28
+ INTERNAL_QUEUE_NAME,
29
+ retriable_postgres_exception,
30
+ retriable_sqlite_exception,
32
31
  )
33
- from dbos._utils import INTERNAL_QUEUE_NAME, retriable_postgres_exception
34
32
 
35
33
  from . import _serialization
36
34
  from ._context import get_local_dbos_context
@@ -316,10 +314,12 @@ def db_retry(
316
314
  while True:
317
315
  try:
318
316
  return func(*args, **kwargs)
319
- except DBAPIError as e:
317
+ except Exception as e:
320
318
 
321
319
  # Determine if this is a retriable exception
322
- if not retriable_postgres_exception(e):
320
+ if not retriable_postgres_exception(
321
+ e
322
+ ) and not retriable_sqlite_exception(e):
323
323
  raise
324
324
 
325
325
  retries += 1
@@ -339,7 +339,7 @@ def db_retry(
339
339
  return decorator
340
340
 
341
341
 
342
- class SystemDatabase:
342
+ class SystemDatabase(ABC):
343
343
 
344
344
  def __init__(
345
345
  self,
@@ -348,16 +348,13 @@ class SystemDatabase:
348
348
  engine_kwargs: Dict[str, Any],
349
349
  debug_mode: bool = False,
350
350
  ):
351
- # Set driver
352
- url = sa.make_url(system_database_url).set(drivername="postgresql+psycopg")
351
+ import sqlalchemy.dialects.postgresql as pg
352
+ import sqlalchemy.dialects.sqlite as sq
353
353
 
354
- self.engine = sa.create_engine(
355
- url,
356
- **engine_kwargs,
357
- )
354
+ self.dialect = sq if system_database_url.startswith("sqlite") else pg
355
+ self.engine = self._create_engine(system_database_url, engine_kwargs)
358
356
  self._engine_kwargs = engine_kwargs
359
357
 
360
- self.notification_conn: Optional[psycopg.connection.Connection] = None
361
358
  self.notifications_map = ThreadSafeConditionDict()
362
359
  self.workflow_events_map = ThreadSafeConditionDict()
363
360
 
@@ -365,40 +362,29 @@ class SystemDatabase:
365
362
  self._run_background_processes = True
366
363
  self._debug_mode = debug_mode
367
364
 
368
- # Run migrations
365
+ @abstractmethod
366
+ def _create_engine(
367
+ self, system_database_url: str, engine_kwargs: Dict[str, Any]
368
+ ) -> sa.Engine:
369
+ """Create a database engine specific to the database type."""
370
+ pass
371
+
372
+ @abstractmethod
369
373
  def run_migrations(self) -> None:
370
- if self._debug_mode:
371
- dbos_logger.warning("System database migrations are skipped in debug mode.")
372
- return
373
- system_db_url = self.engine.url
374
- sysdb_name = system_db_url.database
375
- # If the system database does not already exist, create it
376
- engine = sa.create_engine(
377
- system_db_url.set(database="postgres"), **self._engine_kwargs
378
- )
379
- with engine.connect() as conn:
380
- conn.execution_options(isolation_level="AUTOCOMMIT")
381
- if not conn.execute(
382
- sa.text("SELECT 1 FROM pg_database WHERE datname=:db_name"),
383
- parameters={"db_name": sysdb_name},
384
- ).scalar():
385
- dbos_logger.info(f"Creating system database {sysdb_name}")
386
- conn.execute(sa.text(f"CREATE DATABASE {sysdb_name}"))
387
- engine.dispose()
388
-
389
- using_dbos_migrations = ensure_dbos_schema(self.engine)
390
- if not using_dbos_migrations:
391
- # Complete the Alembic migrations, create the dbos_migrations table
392
- run_alembic_migrations(self.engine)
393
- run_dbos_migrations(self.engine)
374
+ """Run database migrations specific to the database type."""
375
+ pass
394
376
 
395
377
  # Destroy the pool when finished
396
378
  def destroy(self) -> None:
397
379
  self._run_background_processes = False
398
- if self.notification_conn is not None:
399
- self.notification_conn.close()
380
+ self._cleanup_connections()
400
381
  self.engine.dispose()
401
382
 
383
+ @abstractmethod
384
+ def _cleanup_connections(self) -> None:
385
+ """Clean up database-specific connections."""
386
+ pass
387
+
402
388
  def _insert_workflow_status(
403
389
  self,
404
390
  status: WorkflowStatusInternal,
@@ -406,6 +392,7 @@ class SystemDatabase:
406
392
  *,
407
393
  max_recovery_attempts: Optional[int],
408
394
  ) -> tuple[WorkflowStatuses, Optional[int]]:
395
+ """Insert or update workflow status using PostgreSQL upsert operations."""
409
396
  if self._debug_mode:
410
397
  raise Exception("called insert_workflow_status in debug mode")
411
398
  wf_status: WorkflowStatuses = status["status"]
@@ -421,14 +408,14 @@ class SystemDatabase:
421
408
  ),
422
409
  else_=SystemSchema.workflow_status.c.recovery_attempts,
423
410
  ),
424
- "updated_at": func.extract("epoch", func.now()) * 1000,
411
+ "updated_at": sa.func.extract("epoch", sa.func.now()) * 1000,
425
412
  }
426
413
  # Don't update an existing executor ID when enqueueing a workflow.
427
414
  if wf_status != WorkflowStatusString.ENQUEUED.value:
428
415
  update_values["executor_id"] = status["executor_id"]
429
416
 
430
417
  cmd = (
431
- pg.insert(SystemSchema.workflow_status)
418
+ self.dialect.insert(SystemSchema.workflow_status)
432
419
  .values(
433
420
  workflow_uuid=status["workflow_uuid"],
434
421
  status=status["status"],
@@ -459,13 +446,21 @@ class SystemDatabase:
459
446
  )
460
447
  )
461
448
 
462
- cmd = cmd.returning(SystemSchema.workflow_status.c.recovery_attempts, SystemSchema.workflow_status.c.status, SystemSchema.workflow_status.c.workflow_deadline_epoch_ms, SystemSchema.workflow_status.c.name, SystemSchema.workflow_status.c.class_name, SystemSchema.workflow_status.c.config_name, SystemSchema.workflow_status.c.queue_name) # type: ignore
449
+ cmd = cmd.returning(
450
+ SystemSchema.workflow_status.c.recovery_attempts,
451
+ SystemSchema.workflow_status.c.status,
452
+ SystemSchema.workflow_status.c.workflow_deadline_epoch_ms,
453
+ SystemSchema.workflow_status.c.name,
454
+ SystemSchema.workflow_status.c.class_name,
455
+ SystemSchema.workflow_status.c.config_name,
456
+ SystemSchema.workflow_status.c.queue_name,
457
+ )
463
458
 
464
459
  try:
465
460
  results = conn.execute(cmd)
466
461
  except DBAPIError as dbapi_error:
467
462
  # Unique constraint violation for the deduplication ID
468
- if dbapi_error.orig.sqlstate == "23505": # type: ignore
463
+ if self._is_unique_constraint_violation(dbapi_error):
469
464
  assert status["deduplication_id"] is not None
470
465
  assert status["queue_name"] is not None
471
466
  raise DBOSQueueDeduplicatedError(
@@ -591,7 +586,8 @@ class SystemDatabase:
591
586
  raise Exception("called resume_workflow in debug mode")
592
587
  with self.engine.begin() as c:
593
588
  # Execute with snapshot isolation in case of concurrent calls on the same workflow
594
- c.execute(sa.text("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"))
589
+ if self.engine.dialect.name == "postgresql":
590
+ c.execute(sa.text("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"))
595
591
  # Check the status of the workflow. If it is complete, do nothing.
596
592
  status_row = c.execute(
597
593
  sa.select(
@@ -637,7 +633,7 @@ class SystemDatabase:
637
633
  # Create an entry for the forked workflow with the same
638
634
  # initial values as the original.
639
635
  c.execute(
640
- pg.insert(SystemSchema.workflow_status).values(
636
+ sa.insert(SystemSchema.workflow_status).values(
641
637
  workflow_uuid=forked_workflow_id,
642
638
  status=WorkflowStatusString.ENQUEUED.value,
643
639
  name=status["name"],
@@ -851,7 +847,7 @@ class SystemDatabase:
851
847
  query = query.offset(input.offset)
852
848
 
853
849
  with self.engine.begin() as c:
854
- rows = c.execute(query)
850
+ rows = c.execute(query).fetchall()
855
851
 
856
852
  infos: List[WorkflowStatus] = []
857
853
  for row in rows:
@@ -962,7 +958,7 @@ class SystemDatabase:
962
958
  query = query.offset(input["offset"])
963
959
 
964
960
  with self.engine.begin() as c:
965
- rows = c.execute(query)
961
+ rows = c.execute(query).fetchall()
966
962
 
967
963
  infos: List[WorkflowStatus] = []
968
964
  for row in rows:
@@ -1066,7 +1062,7 @@ class SystemDatabase:
1066
1062
  error = result["error"]
1067
1063
  output = result["output"]
1068
1064
  assert error is None or output is None, "Only one of error or output can be set"
1069
- sql = pg.insert(SystemSchema.operation_outputs).values(
1065
+ sql = sa.insert(SystemSchema.operation_outputs).values(
1070
1066
  workflow_uuid=result["workflow_uuid"],
1071
1067
  function_id=result["function_id"],
1072
1068
  function_name=result["function_name"],
@@ -1076,7 +1072,7 @@ class SystemDatabase:
1076
1072
  try:
1077
1073
  conn.execute(sql)
1078
1074
  except DBAPIError as dbapi_error:
1079
- if dbapi_error.orig.sqlstate == "23505": # type: ignore
1075
+ if self._is_unique_constraint_violation(dbapi_error):
1080
1076
  raise DBOSWorkflowConflictIDError(result["workflow_uuid"])
1081
1077
  raise
1082
1078
 
@@ -1097,7 +1093,7 @@ class SystemDatabase:
1097
1093
  # Because there's no corresponding check, we do nothing on conflict
1098
1094
  # and do not raise a DBOSWorkflowConflictIDError
1099
1095
  sql = (
1100
- pg.insert(SystemSchema.operation_outputs)
1096
+ self.dialect.insert(SystemSchema.operation_outputs)
1101
1097
  .values(
1102
1098
  workflow_uuid=ctx.workflow_id,
1103
1099
  function_id=ctx.function_id,
@@ -1122,7 +1118,7 @@ class SystemDatabase:
1122
1118
  if self._debug_mode:
1123
1119
  raise Exception("called record_child_workflow in debug mode")
1124
1120
 
1125
- sql = pg.insert(SystemSchema.operation_outputs).values(
1121
+ sql = sa.insert(SystemSchema.operation_outputs).values(
1126
1122
  workflow_uuid=parentUUID,
1127
1123
  function_id=functionID,
1128
1124
  function_name=functionName,
@@ -1132,10 +1128,20 @@ class SystemDatabase:
1132
1128
  with self.engine.begin() as c:
1133
1129
  c.execute(sql)
1134
1130
  except DBAPIError as dbapi_error:
1135
- if dbapi_error.orig.sqlstate == "23505": # type: ignore
1131
+ if self._is_unique_constraint_violation(dbapi_error):
1136
1132
  raise DBOSWorkflowConflictIDError(parentUUID)
1137
1133
  raise
1138
1134
 
1135
+ @abstractmethod
1136
+ def _is_unique_constraint_violation(self, dbapi_error: DBAPIError) -> bool:
1137
+ """Check if the error is a unique constraint violation."""
1138
+ pass
1139
+
1140
+ @abstractmethod
1141
+ def _is_foreign_key_violation(self, dbapi_error: DBAPIError) -> bool:
1142
+ """Check if the error is a foreign key violation."""
1143
+ pass
1144
+
1139
1145
  def _check_operation_execution_txn(
1140
1146
  self,
1141
1147
  workflow_id: str,
@@ -1261,15 +1267,14 @@ class SystemDatabase:
1261
1267
 
1262
1268
  try:
1263
1269
  c.execute(
1264
- pg.insert(SystemSchema.notifications).values(
1270
+ sa.insert(SystemSchema.notifications).values(
1265
1271
  destination_uuid=destination_uuid,
1266
1272
  topic=topic,
1267
1273
  message=_serialization.serialize(message),
1268
1274
  )
1269
1275
  )
1270
1276
  except DBAPIError as dbapi_error:
1271
- # Foreign key violation
1272
- if dbapi_error.orig.sqlstate == "23503": # type: ignore
1277
+ if self._is_foreign_key_violation(dbapi_error):
1273
1278
  raise DBOSNonExistentWorkflowError(destination_uuid)
1274
1279
  raise
1275
1280
  output: OperationResultInternal = {
@@ -1344,29 +1349,25 @@ class SystemDatabase:
1344
1349
 
1345
1350
  # Transactionally consume and return the message if it's in the database, otherwise return null.
1346
1351
  with self.engine.begin() as c:
1347
- oldest_entry_cte = (
1348
- sa.select(
1349
- SystemSchema.notifications.c.destination_uuid,
1350
- SystemSchema.notifications.c.topic,
1351
- SystemSchema.notifications.c.message,
1352
- SystemSchema.notifications.c.created_at_epoch_ms,
1353
- )
1354
- .where(
1355
- SystemSchema.notifications.c.destination_uuid == workflow_uuid,
1356
- SystemSchema.notifications.c.topic == topic,
1357
- )
1358
- .order_by(SystemSchema.notifications.c.created_at_epoch_ms.asc())
1359
- .limit(1)
1360
- .cte("oldest_entry")
1361
- )
1362
1352
  delete_stmt = (
1363
1353
  sa.delete(SystemSchema.notifications)
1364
1354
  .where(
1365
- SystemSchema.notifications.c.destination_uuid
1366
- == oldest_entry_cte.c.destination_uuid,
1367
- SystemSchema.notifications.c.topic == oldest_entry_cte.c.topic,
1368
- SystemSchema.notifications.c.created_at_epoch_ms
1369
- == oldest_entry_cte.c.created_at_epoch_ms,
1355
+ SystemSchema.notifications.c.destination_uuid == workflow_uuid,
1356
+ SystemSchema.notifications.c.topic == topic,
1357
+ SystemSchema.notifications.c.message_uuid
1358
+ == (
1359
+ sa.select(SystemSchema.notifications.c.message_uuid)
1360
+ .where(
1361
+ SystemSchema.notifications.c.destination_uuid
1362
+ == workflow_uuid,
1363
+ SystemSchema.notifications.c.topic == topic,
1364
+ )
1365
+ .order_by(
1366
+ SystemSchema.notifications.c.created_at_epoch_ms.asc()
1367
+ )
1368
+ .limit(1)
1369
+ .scalar_subquery()
1370
+ ),
1370
1371
  )
1371
1372
  .returning(SystemSchema.notifications.c.message)
1372
1373
  )
@@ -1388,62 +1389,47 @@ class SystemDatabase:
1388
1389
  )
1389
1390
  return message
1390
1391
 
1392
+ @abstractmethod
1391
1393
  def _notification_listener(self) -> None:
1392
- while self._run_background_processes:
1393
- try:
1394
- # since we're using the psycopg connection directly, we need a url without the "+pycopg" suffix
1395
- url = sa.URL.create(
1396
- "postgresql", **self.engine.url.translate_connect_args()
1397
- )
1398
- # Listen to notifications
1399
- self.notification_conn = psycopg.connect(
1400
- url.render_as_string(hide_password=False), autocommit=True
1401
- )
1394
+ """Listen for database notifications using database-specific mechanisms."""
1395
+ pass
1402
1396
 
1403
- self.notification_conn.execute("LISTEN dbos_notifications_channel")
1404
- self.notification_conn.execute("LISTEN dbos_workflow_events_channel")
1397
+ @staticmethod
1398
+ def reset_system_database(database_url: str) -> None:
1399
+ """Reset the system database by calling the appropriate implementation."""
1400
+ if database_url.startswith("sqlite"):
1401
+ from ._sys_db_sqlite import SQLiteSystemDatabase
1405
1402
 
1406
- while self._run_background_processes:
1407
- gen = self.notification_conn.notifies()
1408
- for notify in gen:
1409
- channel = notify.channel
1410
- dbos_logger.debug(
1411
- f"Received notification on channel: {channel}, payload: {notify.payload}"
1412
- )
1413
- if channel == "dbos_notifications_channel":
1414
- if notify.payload:
1415
- condition = self.notifications_map.get(notify.payload)
1416
- if condition is None:
1417
- # No condition found for this payload
1418
- continue
1419
- condition.acquire()
1420
- condition.notify_all()
1421
- condition.release()
1422
- dbos_logger.debug(
1423
- f"Signaled notifications condition for {notify.payload}"
1424
- )
1425
- elif channel == "dbos_workflow_events_channel":
1426
- if notify.payload:
1427
- condition = self.workflow_events_map.get(notify.payload)
1428
- if condition is None:
1429
- # No condition found for this payload
1430
- continue
1431
- condition.acquire()
1432
- condition.notify_all()
1433
- condition.release()
1434
- dbos_logger.debug(
1435
- f"Signaled workflow_events condition for {notify.payload}"
1436
- )
1437
- else:
1438
- dbos_logger.error(f"Unknown channel: {channel}")
1439
- except Exception as e:
1440
- if self._run_background_processes:
1441
- dbos_logger.warning(f"Notification listener error: {e}")
1442
- time.sleep(1)
1443
- # Then the loop will try to reconnect and restart the listener
1444
- finally:
1445
- if self.notification_conn is not None:
1446
- self.notification_conn.close()
1403
+ SQLiteSystemDatabase._reset_system_database(database_url)
1404
+ else:
1405
+ from ._sys_db_postgres import PostgresSystemDatabase
1406
+
1407
+ PostgresSystemDatabase._reset_system_database(database_url)
1408
+
1409
+ @staticmethod
1410
+ def create(
1411
+ system_database_url: str,
1412
+ engine_kwargs: Dict[str, Any],
1413
+ debug_mode: bool = False,
1414
+ ) -> "SystemDatabase":
1415
+ """Factory method to create the appropriate SystemDatabase implementation based on URL."""
1416
+ if system_database_url.startswith("sqlite"):
1417
+ from ._sys_db_sqlite import SQLiteSystemDatabase
1418
+
1419
+ return SQLiteSystemDatabase(
1420
+ system_database_url=system_database_url,
1421
+ engine_kwargs=engine_kwargs,
1422
+ debug_mode=debug_mode,
1423
+ )
1424
+ else:
1425
+ # Default to PostgreSQL for postgresql://, postgres://, or other URLs
1426
+ from ._sys_db_postgres import PostgresSystemDatabase
1427
+
1428
+ return PostgresSystemDatabase(
1429
+ system_database_url=system_database_url,
1430
+ engine_kwargs=engine_kwargs,
1431
+ debug_mode=debug_mode,
1432
+ )
1447
1433
 
1448
1434
  @db_retry()
1449
1435
  def sleep(
@@ -1507,9 +1493,8 @@ class SystemDatabase:
1507
1493
  return # Already sent before
1508
1494
  else:
1509
1495
  dbos_logger.debug(f"Running set_event, id: {function_id}, key: {key}")
1510
-
1511
1496
  c.execute(
1512
- pg.insert(SystemSchema.workflow_events)
1497
+ self.dialect.insert(SystemSchema.workflow_events)
1513
1498
  .values(
1514
1499
  workflow_uuid=workflow_uuid,
1515
1500
  key=key,
@@ -1631,7 +1616,8 @@ class SystemDatabase:
1631
1616
  limiter_period_ms = int(queue.limiter["period"] * 1000)
1632
1617
  with self.engine.begin() as c:
1633
1618
  # Execute with snapshot isolation to ensure multiple workers respect limits
1634
- c.execute(sa.text("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"))
1619
+ if self.engine.dialect.name == "postgresql":
1620
+ c.execute(sa.text("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"))
1635
1621
 
1636
1622
  # If there is a limiter, compute how many functions have started in its period.
1637
1623
  if queue.limiter is not None:
@@ -2036,36 +2022,3 @@ class SystemDatabase:
2036
2022
  return cutoff_epoch_timestamp_ms, [
2037
2023
  row[0] for row in pending_enqueued_result
2038
2024
  ]
2039
-
2040
-
2041
- def reset_system_database(postgres_db_url: sa.URL, sysdb_name: str) -> None:
2042
- try:
2043
- # Connect to postgres default database
2044
- engine = sa.create_engine(
2045
- postgres_db_url.set(drivername="postgresql+psycopg"),
2046
- connect_args={"connect_timeout": 10},
2047
- )
2048
-
2049
- with engine.connect() as conn:
2050
- # Set autocommit required for database dropping
2051
- conn.execution_options(isolation_level="AUTOCOMMIT")
2052
-
2053
- # Terminate existing connections
2054
- conn.execute(
2055
- sa.text(
2056
- """
2057
- SELECT pg_terminate_backend(pg_stat_activity.pid)
2058
- FROM pg_stat_activity
2059
- WHERE pg_stat_activity.datname = :db_name
2060
- AND pid <> pg_backend_pid()
2061
- """
2062
- ),
2063
- {"db_name": sysdb_name},
2064
- )
2065
-
2066
- # Drop the database
2067
- conn.execute(sa.text(f"DROP DATABASE IF EXISTS {sysdb_name}"))
2068
-
2069
- except sa.exc.SQLAlchemyError as e:
2070
- dbos_logger.error(f"Error resetting system database: {str(e)}")
2071
- raise e
@@ -0,0 +1,173 @@
1
+ import time
2
+ from typing import Any, Dict, Optional
3
+
4
+ import psycopg
5
+ import sqlalchemy as sa
6
+ from sqlalchemy.exc import DBAPIError
7
+
8
+ from dbos._migration import (
9
+ ensure_dbos_schema,
10
+ run_alembic_migrations,
11
+ run_dbos_migrations,
12
+ )
13
+ from dbos._schemas.system_database import SystemSchema
14
+
15
+ from ._logger import dbos_logger
16
+ from ._sys_db import SystemDatabase
17
+
18
+
19
+ class PostgresSystemDatabase(SystemDatabase):
20
+ """PostgreSQL-specific implementation of SystemDatabase."""
21
+
22
+ def __init__(
23
+ self,
24
+ *,
25
+ system_database_url: str,
26
+ engine_kwargs: Dict[str, Any],
27
+ debug_mode: bool = False,
28
+ ):
29
+ super().__init__(
30
+ system_database_url=system_database_url,
31
+ engine_kwargs=engine_kwargs,
32
+ debug_mode=debug_mode,
33
+ )
34
+ self.notification_conn: Optional[psycopg.connection.Connection] = None
35
+
36
+ def _create_engine(
37
+ self, system_database_url: str, engine_kwargs: Dict[str, Any]
38
+ ) -> sa.Engine:
39
+ # TODO: Make the schema dynamic so this isn't needed
40
+ SystemSchema.workflow_status.schema = "dbos"
41
+ SystemSchema.operation_outputs.schema = "dbos"
42
+ SystemSchema.notifications.schema = "dbos"
43
+ SystemSchema.workflow_events.schema = "dbos"
44
+ SystemSchema.streams.schema = "dbos"
45
+ url = sa.make_url(system_database_url).set(drivername="postgresql+psycopg")
46
+ return sa.create_engine(url, **engine_kwargs)
47
+
48
+ def run_migrations(self) -> None:
49
+ """Run PostgreSQL-specific migrations."""
50
+ if self._debug_mode:
51
+ dbos_logger.warning("System database migrations are skipped in debug mode.")
52
+ return
53
+ system_db_url = self.engine.url
54
+ sysdb_name = system_db_url.database
55
+ # If the system database does not already exist, create it
56
+ engine = sa.create_engine(
57
+ system_db_url.set(database="postgres"), **self._engine_kwargs
58
+ )
59
+ with engine.connect() as conn:
60
+ conn.execution_options(isolation_level="AUTOCOMMIT")
61
+ if not conn.execute(
62
+ sa.text("SELECT 1 FROM pg_database WHERE datname=:db_name"),
63
+ parameters={"db_name": sysdb_name},
64
+ ).scalar():
65
+ dbos_logger.info(f"Creating system database {sysdb_name}")
66
+ conn.execute(sa.text(f"CREATE DATABASE {sysdb_name}"))
67
+ engine.dispose()
68
+
69
+ using_dbos_migrations = ensure_dbos_schema(self.engine)
70
+ if not using_dbos_migrations:
71
+ # Complete the Alembic migrations, create the dbos_migrations table
72
+ run_alembic_migrations(self.engine)
73
+ run_dbos_migrations(self.engine)
74
+
75
+ def _cleanup_connections(self) -> None:
76
+ """Clean up PostgreSQL-specific connections."""
77
+ if self.notification_conn is not None:
78
+ self.notification_conn.close()
79
+
80
+ def _is_unique_constraint_violation(self, dbapi_error: DBAPIError) -> bool:
81
+ """Check if the error is a unique constraint violation in PostgreSQL."""
82
+ return dbapi_error.orig.sqlstate == "23505" # type: ignore
83
+
84
+ def _is_foreign_key_violation(self, dbapi_error: DBAPIError) -> bool:
85
+ """Check if the error is a foreign key violation in PostgreSQL."""
86
+ return dbapi_error.orig.sqlstate == "23503" # type: ignore
87
+
88
+ @staticmethod
89
+ def _reset_system_database(database_url: str) -> None:
90
+ """Reset the PostgreSQL system database by dropping it."""
91
+ system_db_url = sa.make_url(database_url)
92
+ sysdb_name = system_db_url.database
93
+
94
+ if sysdb_name is None:
95
+ raise ValueError(f"System database name not found in URL {system_db_url}")
96
+
97
+ try:
98
+ # Connect to postgres default database
99
+ engine = sa.create_engine(
100
+ system_db_url.set(database="postgres", drivername="postgresql+psycopg"),
101
+ connect_args={"connect_timeout": 10},
102
+ )
103
+
104
+ with engine.connect() as conn:
105
+ # Set autocommit required for database dropping
106
+ conn.execution_options(isolation_level="AUTOCOMMIT")
107
+
108
+ # Drop the database
109
+ conn.execute(
110
+ sa.text(f"DROP DATABASE IF EXISTS {sysdb_name} WITH (FORCE)")
111
+ )
112
+ engine.dispose()
113
+ except Exception as e:
114
+ dbos_logger.error(f"Error resetting PostgreSQL system database: {str(e)}")
115
+ raise e
116
+
117
+ def _notification_listener(self) -> None:
118
+ """Listen for PostgreSQL notifications using psycopg."""
119
+ while self._run_background_processes:
120
+ try:
121
+ # since we're using the psycopg connection directly, we need a url without the "+psycopg" suffix
122
+ url = sa.URL.create(
123
+ "postgresql", **self.engine.url.translate_connect_args()
124
+ )
125
+ # Listen to notifications
126
+ self.notification_conn = psycopg.connect(
127
+ url.render_as_string(hide_password=False), autocommit=True
128
+ )
129
+
130
+ self.notification_conn.execute("LISTEN dbos_notifications_channel")
131
+ self.notification_conn.execute("LISTEN dbos_workflow_events_channel")
132
+
133
+ while self._run_background_processes:
134
+ gen = self.notification_conn.notifies()
135
+ for notify in gen:
136
+ channel = notify.channel
137
+ dbos_logger.debug(
138
+ f"Received notification on channel: {channel}, payload: {notify.payload}"
139
+ )
140
+ if channel == "dbos_notifications_channel":
141
+ if notify.payload:
142
+ condition = self.notifications_map.get(notify.payload)
143
+ if condition is None:
144
+ # No condition found for this payload
145
+ continue
146
+ condition.acquire()
147
+ condition.notify_all()
148
+ condition.release()
149
+ dbos_logger.debug(
150
+ f"Signaled notifications condition for {notify.payload}"
151
+ )
152
+ elif channel == "dbos_workflow_events_channel":
153
+ if notify.payload:
154
+ condition = self.workflow_events_map.get(notify.payload)
155
+ if condition is None:
156
+ # No condition found for this payload
157
+ continue
158
+ condition.acquire()
159
+ condition.notify_all()
160
+ condition.release()
161
+ dbos_logger.debug(
162
+ f"Signaled workflow_events condition for {notify.payload}"
163
+ )
164
+ else:
165
+ dbos_logger.error(f"Unknown channel: {channel}")
166
+ except Exception as e:
167
+ if self._run_background_processes:
168
+ dbos_logger.warning(f"Notification listener error: {e}")
169
+ time.sleep(1)
170
+ # Then the loop will try to reconnect and restart the listener
171
+ finally:
172
+ if self.notification_conn is not None:
173
+ self.notification_conn.close()