dbos 1.12.0a2__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (40)
  1. dbos/_alembic_migrations/versions/471b60d64126_dbos_migrations.py +35 -0
  2. dbos/_app_db.py +215 -80
  3. dbos/_client.py +30 -15
  4. dbos/_context.py +4 -0
  5. dbos/_core.py +7 -8
  6. dbos/_dbos.py +28 -18
  7. dbos/_dbos_config.py +124 -50
  8. dbos/_fastapi.py +3 -1
  9. dbos/_logger.py +3 -1
  10. dbos/_migration.py +322 -0
  11. dbos/_sys_db.py +122 -200
  12. dbos/_sys_db_postgres.py +173 -0
  13. dbos/_sys_db_sqlite.py +182 -0
  14. dbos/_tracer.py +5 -1
  15. dbos/_utils.py +10 -1
  16. dbos/cli/cli.py +238 -100
  17. dbos/cli/migration.py +2 -2
  18. dbos/dbos-config.schema.json +4 -0
  19. {dbos-1.12.0a2.dist-info → dbos-1.13.0.dist-info}/METADATA +1 -1
  20. dbos-1.13.0.dist-info/RECORD +78 -0
  21. dbos-1.12.0a2.dist-info/RECORD +0 -74
  22. /dbos/{_migrations → _alembic_migrations}/env.py +0 -0
  23. /dbos/{_migrations → _alembic_migrations}/script.py.mako +0 -0
  24. /dbos/{_migrations → _alembic_migrations}/versions/01ce9f07bd10_streaming.py +0 -0
  25. /dbos/{_migrations → _alembic_migrations}/versions/04ca4f231047_workflow_queues_executor_id.py +0 -0
  26. /dbos/{_migrations → _alembic_migrations}/versions/27ac6900c6ad_add_queue_dedup.py +0 -0
  27. /dbos/{_migrations → _alembic_migrations}/versions/50f3227f0b4b_fix_job_queue.py +0 -0
  28. /dbos/{_migrations → _alembic_migrations}/versions/5c361fc04708_added_system_tables.py +0 -0
  29. /dbos/{_migrations → _alembic_migrations}/versions/66478e1b95e5_consolidate_queues.py +0 -0
  30. /dbos/{_migrations → _alembic_migrations}/versions/83f3732ae8e7_workflow_timeout.py +0 -0
  31. /dbos/{_migrations → _alembic_migrations}/versions/933e86bdac6a_add_queue_priority.py +0 -0
  32. /dbos/{_migrations → _alembic_migrations}/versions/a3b18ad34abe_added_triggers.py +0 -0
  33. /dbos/{_migrations → _alembic_migrations}/versions/d76646551a6b_job_queue_limiter.py +0 -0
  34. /dbos/{_migrations → _alembic_migrations}/versions/d76646551a6c_workflow_queue.py +0 -0
  35. /dbos/{_migrations → _alembic_migrations}/versions/d994145b47b6_consolidate_inputs.py +0 -0
  36. /dbos/{_migrations → _alembic_migrations}/versions/eab0cc1d9a14_job_queue.py +0 -0
  37. /dbos/{_migrations → _alembic_migrations}/versions/f4b9b32ba814_functionname_childid_op_outputs.py +0 -0
  38. {dbos-1.12.0a2.dist-info → dbos-1.13.0.dist-info}/WHEEL +0 -0
  39. {dbos-1.12.0a2.dist-info → dbos-1.13.0.dist-info}/entry_points.txt +0 -0
  40. {dbos-1.12.0a2.dist-info → dbos-1.13.0.dist-info}/licenses/LICENSE +0 -0
dbos/_alembic_migrations/versions/471b60d64126_dbos_migrations.py ADDED
@@ -0,0 +1,35 @@
+"""dbos_migrations
+
+Revision ID: 471b60d64126
+Revises: 01ce9f07bd10
+Create Date: 2025-08-21 14:22:31.455266
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "471b60d64126"
+down_revision: Union[str, None] = "01ce9f07bd10"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Create dbos_migrations table
+    op.create_table(
+        "dbos_migrations",
+        sa.Column("version", sa.BigInteger(), nullable=False),
+        sa.PrimaryKeyConstraint("version"),
+        schema="dbos",
+    )
+
+    # Insert initial version 1
+    op.execute("INSERT INTO dbos.dbos_migrations (version) VALUES (1)")
+
+
+def downgrade() -> None:
+    op.drop_table("dbos_migrations", schema="dbos")
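
This revision introduces a single-row version counter, presumably advanced by the new dbos/_migration.py module instead of adding one Alembic revision per change. A minimal sketch of reading that counter (the helper below is hypothetical, not part of the release; it assumes a SQLAlchemy engine pointed at the system database):

import sqlalchemy as sa

def read_dbos_migration_version(engine: sa.Engine) -> int:
    """Hypothetical helper: return the version seeded by this revision (initially 1)."""
    with engine.connect() as conn:
        version = conn.execute(
            sa.text("SELECT version FROM dbos.dbos_migrations")
        ).scalar_one()
    return int(version)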
dbos/_app_db.py CHANGED
@@ -1,11 +1,14 @@
+from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, TypedDict
 
+import psycopg
 import sqlalchemy as sa
-import sqlalchemy.dialects.postgresql as pg
 from sqlalchemy import inspect, text
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.orm import Session, sessionmaker
 
+from dbos._migration import get_sqlite_timestamp_expr
+
 from . import _serialization
 from ._error import DBOSUnexpectedStepError, DBOSWorkflowConflictIDError
 from ._logger import dbos_logger
@@ -29,7 +32,7 @@ class RecordedResult(TypedDict):
     error: Optional[str]  # JSON (jsonpickle)
 
 
-class ApplicationDatabase:
+class ApplicationDatabase(ABC):
 
     def __init__(
         self,
@@ -38,95 +41,37 @@ class ApplicationDatabase:
         engine_kwargs: Dict[str, Any],
         debug_mode: bool = False,
     ):
-        app_db_url = sa.make_url(database_url).set(drivername="postgresql+psycopg")
-
-        if engine_kwargs is None:
-            engine_kwargs = {}
-
-        self.engine = sa.create_engine(
-            app_db_url,
-            **engine_kwargs,
-        )
+        self.engine = self._create_engine(database_url, engine_kwargs)
         self._engine_kwargs = engine_kwargs
         self.sessionmaker = sessionmaker(bind=self.engine)
         self.debug_mode = debug_mode
 
-    def run_migrations(self) -> None:
-        if self.debug_mode:
-            dbos_logger.warning(
-                "Application database migrations are skipped in debug mode."
-            )
-            return
-        # Check if the database exists
-        app_db_url = self.engine.url
-        postgres_db_engine = sa.create_engine(
-            app_db_url.set(database="postgres"),
-            **self._engine_kwargs,
-        )
-        with postgres_db_engine.connect() as conn:
-            conn.execution_options(isolation_level="AUTOCOMMIT")
-            if not conn.execute(
-                sa.text("SELECT 1 FROM pg_database WHERE datname=:db_name"),
-                parameters={"db_name": app_db_url.database},
-            ).scalar():
-                conn.execute(sa.text(f"CREATE DATABASE {app_db_url.database}"))
-        postgres_db_engine.dispose()
-
-        # Create the dbos schema and transaction_outputs table in the application database
-        with self.engine.begin() as conn:
-            # Check if schema exists first
-            schema_exists = conn.execute(
-                sa.text(
-                    "SELECT 1 FROM information_schema.schemata WHERE schema_name = :schema_name"
-                ),
-                parameters={"schema_name": ApplicationSchema.schema},
-            ).scalar()
-
-            if not schema_exists:
-                schema_creation_query = sa.text(
-                    f"CREATE SCHEMA {ApplicationSchema.schema}"
-                )
-                conn.execute(schema_creation_query)
-
-        inspector = inspect(self.engine)
-        if not inspector.has_table(
-            "transaction_outputs", schema=ApplicationSchema.schema
-        ):
-            ApplicationSchema.metadata_obj.create_all(self.engine)
-        else:
-            columns = inspector.get_columns(
-                "transaction_outputs", schema=ApplicationSchema.schema
-            )
-            column_names = [col["name"] for col in columns]
+    @abstractmethod
+    def _create_engine(
+        self, database_url: str, engine_kwargs: Dict[str, Any]
+    ) -> sa.Engine:
+        """Create a database engine specific to the database type."""
+        pass
 
-            if "function_name" not in column_names:
-                # Column missing, alter table to add it
-                with self.engine.connect() as conn:
-                    conn.execute(
-                        text(
-                            f"""
-                            ALTER TABLE {ApplicationSchema.schema}.transaction_outputs
-                            ADD COLUMN function_name TEXT NOT NULL DEFAULT '';
-                            """
-                        )
-                    )
-                    conn.commit()
+    @abstractmethod
+    def run_migrations(self) -> None:
+        """Run database migrations specific to the database type."""
+        pass
 
     def destroy(self) -> None:
         self.engine.dispose()
 
-    @staticmethod
     def record_transaction_output(
-        session: Session, output: TransactionResultInternal
+        self, session: Session, output: TransactionResultInternal
     ) -> None:
         try:
             session.execute(
-                pg.insert(ApplicationSchema.transaction_outputs).values(
+                sa.insert(ApplicationSchema.transaction_outputs).values(
                     workflow_uuid=output["workflow_uuid"],
                     function_id=output["function_id"],
                     output=output["output"],
                     error=None,
-                    txn_id=sa.text("(select pg_current_xact_id_if_assigned()::text)"),
+                    txn_id="",
                     txn_snapshot=output["txn_snapshot"],
                     executor_id=(
                         output["executor_id"] if output["executor_id"] else None
@@ -135,7 +80,7 @@ class ApplicationDatabase:
                 )
             )
         except DBAPIError as dbapi_error:
-            if dbapi_error.orig.sqlstate == "23505":  # type: ignore
+            if self._is_unique_constraint_violation(dbapi_error):
                 raise DBOSWorkflowConflictIDError(output["workflow_uuid"])
             raise
 
@@ -145,14 +90,12 @@ class ApplicationDatabase:
         try:
             with self.engine.begin() as conn:
                 conn.execute(
-                    pg.insert(ApplicationSchema.transaction_outputs).values(
+                    sa.insert(ApplicationSchema.transaction_outputs).values(
                         workflow_uuid=output["workflow_uuid"],
                         function_id=output["function_id"],
                         output=None,
                         error=output["error"],
-                        txn_id=sa.text(
-                            "(select pg_current_xact_id_if_assigned()::text)"
-                        ),
+                        txn_id="",
                         txn_snapshot=output["txn_snapshot"],
                         executor_id=(
                             output["executor_id"] if output["executor_id"] else None
@@ -161,7 +104,7 @@ class ApplicationDatabase:
                 )
             )
         except DBAPIError as dbapi_error:
-            if dbapi_error.orig.sqlstate == "23505":  # type: ignore
+            if self._is_unique_constraint_violation(dbapi_error):
                 raise DBOSWorkflowConflictIDError(output["workflow_uuid"])
             raise
 
@@ -283,3 +226,195 @@ class ApplicationDatabase:
             )
 
             c.execute(delete_query)
+
+    @abstractmethod
+    def _is_unique_constraint_violation(self, dbapi_error: DBAPIError) -> bool:
+        """Check if the error is a unique constraint violation."""
+        pass
+
+    @abstractmethod
+    def _is_serialization_error(self, dbapi_error: DBAPIError) -> bool:
+        """Check if the error is a serialization/concurrency error."""
+        pass
+
+    @staticmethod
+    def create(
+        database_url: str,
+        engine_kwargs: Dict[str, Any],
+        debug_mode: bool = False,
+    ) -> "ApplicationDatabase":
+        """Factory method to create the appropriate ApplicationDatabase implementation based on URL."""
+        if database_url.startswith("sqlite"):
+            return SQLiteApplicationDatabase(
+                database_url=database_url,
+                engine_kwargs=engine_kwargs,
+                debug_mode=debug_mode,
+            )
+        else:
+            # Default to PostgreSQL for postgresql://, postgres://, or other URLs
+            return PostgresApplicationDatabase(
+                database_url=database_url,
+                engine_kwargs=engine_kwargs,
+                debug_mode=debug_mode,
+            )
+
+
+class PostgresApplicationDatabase(ApplicationDatabase):
+    """PostgreSQL-specific implementation of ApplicationDatabase."""
+
+    def _create_engine(
+        self, database_url: str, engine_kwargs: Dict[str, Any]
+    ) -> sa.Engine:
+        """Create a PostgreSQL engine."""
+        app_db_url = sa.make_url(database_url).set(drivername="postgresql+psycopg")
+
+        if engine_kwargs is None:
+            engine_kwargs = {}
+
+        # TODO: Make the schema dynamic so this isn't needed
+        ApplicationSchema.transaction_outputs.schema = "dbos"
+
+        return sa.create_engine(
+            app_db_url,
+            **engine_kwargs,
+        )
+
+    def run_migrations(self) -> None:
+        if self.debug_mode:
+            dbos_logger.warning(
+                "Application database migrations are skipped in debug mode."
+            )
+            return
+        # Check if the database exists
+        app_db_url = self.engine.url
+        postgres_db_engine = sa.create_engine(
+            app_db_url.set(database="postgres"),
+            **self._engine_kwargs,
+        )
+        with postgres_db_engine.connect() as conn:
+            conn.execution_options(isolation_level="AUTOCOMMIT")
+            if not conn.execute(
+                sa.text("SELECT 1 FROM pg_database WHERE datname=:db_name"),
+                parameters={"db_name": app_db_url.database},
+            ).scalar():
+                conn.execute(sa.text(f"CREATE DATABASE {app_db_url.database}"))
+        postgres_db_engine.dispose()
+
+        # Create the dbos schema and transaction_outputs table in the application database
+        with self.engine.begin() as conn:
+            # Check if schema exists first
+            schema_exists = conn.execute(
+                sa.text(
+                    "SELECT 1 FROM information_schema.schemata WHERE schema_name = :schema_name"
+                ),
+                parameters={"schema_name": ApplicationSchema.schema},
+            ).scalar()
+
+            if not schema_exists:
+                schema_creation_query = sa.text(
+                    f"CREATE SCHEMA {ApplicationSchema.schema}"
+                )
+                conn.execute(schema_creation_query)
+
+        inspector = inspect(self.engine)
+        if not inspector.has_table(
+            "transaction_outputs", schema=ApplicationSchema.schema
+        ):
+            ApplicationSchema.metadata_obj.create_all(self.engine)
+        else:
+            columns = inspector.get_columns(
+                "transaction_outputs", schema=ApplicationSchema.schema
+            )
+            column_names = [col["name"] for col in columns]
+
+            if "function_name" not in column_names:
+                # Column missing, alter table to add it
+                with self.engine.connect() as conn:
+                    conn.execute(
+                        text(
+                            f"""
+                            ALTER TABLE {ApplicationSchema.schema}.transaction_outputs
+                            ADD COLUMN function_name TEXT NOT NULL DEFAULT '';
+                            """
+                        )
+                    )
+                    conn.commit()
+
+    def _is_unique_constraint_violation(self, dbapi_error: DBAPIError) -> bool:
+        """Check if the error is a unique constraint violation in PostgreSQL."""
+        return dbapi_error.orig.sqlstate == "23505"  # type: ignore
+
+    def _is_serialization_error(self, dbapi_error: DBAPIError) -> bool:
+        """Check if the error is a serialization/concurrency error in PostgreSQL."""
+        # 40001: serialization_failure (MVCC conflict)
+        # 40P01: deadlock_detected
+        driver_error = dbapi_error.orig
+        return (
+            driver_error is not None
+            and isinstance(driver_error, psycopg.OperationalError)
+            and driver_error.sqlstate in ("40001", "40P01")
+        )
+
+
+class SQLiteApplicationDatabase(ApplicationDatabase):
+    """SQLite-specific implementation of ApplicationDatabase."""
+
+    def _create_engine(
+        self, database_url: str, engine_kwargs: Dict[str, Any]
+    ) -> sa.Engine:
+        """Create a SQLite engine."""
+        # TODO: Make the schema dynamic so this isn't needed
+        ApplicationSchema.transaction_outputs.schema = None
+        return sa.create_engine(database_url)
+
+    def run_migrations(self) -> None:
+        if self.debug_mode:
+            dbos_logger.warning(
+                "Application database migrations are skipped in debug mode."
+            )
+            return
+
+        with self.engine.begin() as conn:
+            # Check if table exists
+            result = conn.execute(
+                sa.text(
+                    "SELECT name FROM sqlite_master WHERE type='table' AND name='transaction_outputs'"
+                )
+            ).fetchone()
+
+            if result is None:
+                conn.execute(
+                    sa.text(
+                        f"""
+                        CREATE TABLE transaction_outputs (
+                            workflow_uuid TEXT NOT NULL,
+                            function_id INTEGER NOT NULL,
+                            output TEXT,
+                            error TEXT,
+                            txn_id TEXT,
+                            txn_snapshot TEXT NOT NULL,
+                            executor_id TEXT,
+                            function_name TEXT NOT NULL DEFAULT '',
+                            created_at BIGINT NOT NULL DEFAULT {get_sqlite_timestamp_expr()},
+                            PRIMARY KEY (workflow_uuid, function_id)
+                        )
+                        """
+                    )
+                )
+                conn.execute(
+                    sa.text(
+                        "CREATE INDEX transaction_outputs_created_at_index ON transaction_outputs (created_at)"
+                    )
+                )
+
+    def _is_unique_constraint_violation(self, dbapi_error: DBAPIError) -> bool:
+        """Check if the error is a unique constraint violation in SQLite."""
+        return "UNIQUE constraint failed" in str(dbapi_error.orig)
+
+    def _is_serialization_error(self, dbapi_error: DBAPIError) -> bool:
+        """Check if the error is a serialization/concurrency error in SQLite."""
+        # SQLite database is locked or busy errors
+        error_msg = str(dbapi_error.orig).lower()
+        return (
+            "database is locked" in error_msg or "database table is locked" in error_msg
+        )
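
With the abstract base and the `create` factory, callers pick a backend purely through the URL scheme. A rough usage sketch based on the factory above (the connection URL and pool settings are placeholders):

from dbos._app_db import ApplicationDatabase

# "sqlite..." URLs select SQLiteApplicationDatabase; anything else
# (postgresql://, postgres://) falls through to PostgresApplicationDatabase.
app_db = ApplicationDatabase.create(
    database_url="postgresql://localhost:5432/my_app",  # placeholder
    engine_kwargs={"pool_timeout": 30, "max_overflow": 0, "pool_size": 2},
)
app_db.run_migrations()  # backend-specific schema/table setup
app_db.destroy()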
dbos/_client.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
 
 from dbos._app_db import ApplicationDatabase
 from dbos._context import MaxPriority, MinPriority
+from dbos._sys_db import SystemDatabase
 
 if sys.version_info < (3, 11):
     from typing_extensions import NotRequired
@@ -24,7 +25,11 @@ else:
 
 from dbos import _serialization
 from dbos._dbos import WorkflowHandle, WorkflowHandleAsync
-from dbos._dbos_config import get_system_database_url, is_valid_database_url
+from dbos._dbos_config import (
+    get_application_database_url,
+    get_system_database_url,
+    is_valid_database_url,
+)
 from dbos._error import DBOSException, DBOSNonExistentWorkflowError
 from dbos._registrations import DEFAULT_MAX_RECOVERY_ATTEMPTS
 from dbos._serialization import WorkflowInputs
@@ -112,21 +117,32 @@ class WorkflowHandleClientAsyncPolling(Generic[R]):
 class DBOSClient:
     def __init__(
         self,
-        database_url: str,
+        database_url: Optional[str] = None,  # DEPRECATED
         *,
         system_database_url: Optional[str] = None,
-        system_database: Optional[str] = None,
+        application_database_url: Optional[str] = None,
+        system_database: Optional[str] = None,  # DEPRECATED
     ):
-        assert is_valid_database_url(database_url)
+        application_database_url = get_application_database_url(
+            {
+                "system_database_url": system_database_url,
+                "database_url": (
+                    database_url if database_url else application_database_url
+                ),
+            }
+        )
+        system_database_url = get_system_database_url(
+            {
+                "system_database_url": system_database_url,
+                "database_url": application_database_url,
+                "database": {"sys_db_name": system_database},
+            }
+        )
+        assert is_valid_database_url(system_database_url)
+        assert is_valid_database_url(application_database_url)
         # We only create database connections but do not run migrations
-        self._sys_db = SystemDatabase(
-            system_database_url=get_system_database_url(
-                {
-                    "system_database_url": system_database_url,
-                    "database_url": database_url,
-                    "database": {"sys_db_name": system_database},
-                }
-            ),
+        self._sys_db = SystemDatabase.create(
+            system_database_url=system_database_url,
             engine_kwargs={
                 "pool_timeout": 30,
                 "max_overflow": 0,
@@ -134,15 +150,14 @@ class DBOSClient:
             },
         )
         self._sys_db.check_connection()
-        self._app_db = ApplicationDatabase(
-            database_url=database_url,
+        self._app_db = ApplicationDatabase.create(
+            database_url=application_database_url,
             engine_kwargs={
                 "pool_timeout": 30,
                 "max_overflow": 0,
                 "pool_size": 2,
             },
         )
-        self._db_url = database_url
 
     def destroy(self) -> None:
         self._sys_db.destroy()
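
With `database_url` and `system_database` deprecated, a client can name the system and application databases separately; either URL can be derived from the other through the config helpers above. A sketch of the new constructor call (the URLs are placeholders):

from dbos import DBOSClient

client = DBOSClient(
    system_database_url="postgresql://localhost:5432/my_app_dbos_sys",  # placeholder
    application_database_url="postgresql://localhost:5432/my_app",      # placeholder
)
# ... enqueue workflows, retrieve handles, etc. ...
client.destroy()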
dbos/_context.py CHANGED
@@ -221,6 +221,8 @@ class DBOSContext:
         return None
 
     def _start_span(self, attributes: TracedAttributes) -> None:
+        if dbos_tracer.disable_otlp:
+            return
         attributes["operationUUID"] = (
             self.workflow_id if len(self.workflow_id) > 0 else None
         )
@@ -246,6 +248,8 @@ class DBOSContext:
         cm.__enter__()
 
     def _end_span(self, exc_value: Optional[BaseException]) -> None:
+        if dbos_tracer.disable_otlp:
+            return
         context_span = self.context_spans.pop()
         if exc_value is None:
             context_span.span.set_status(Status(StatusCode.OK))
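
Note that the two guards come as a pair: when OTLP is disabled, `_start_span` never pushes onto `context_spans`, so `_end_span` must return before it tries to pop. A standalone illustration of that invariant (not DBOS code; the stub names are made up):

# Illustration only: paired early returns keep the span stack balanced.
class TracerStub:
    disable_otlp = True  # stand-in for dbos_tracer.disable_otlp

spans: list[str] = []

def start_span(name: str) -> None:
    if TracerStub.disable_otlp:
        return  # nothing is pushed...
    spans.append(name)

def end_span() -> None:
    if TracerStub.disable_otlp:
        return  # ...so nothing may be popped
    spans.pop()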
dbos/_core.py CHANGED
@@ -356,6 +356,7 @@ def _get_wf_invoke_func(
             )
             return recorded_result
         try:
+            dbos._active_workflows_set.add(status["workflow_uuid"])
             output = func()
             if not dbos.debug_mode:
                 dbos._sys_db.update_workflow_outcome(
@@ -378,6 +379,8 @@ def _get_wf_invoke_func(
                 error=_serialization.serialize_exception(error),
             )
             raise
+        finally:
+            dbos._active_workflows_set.discard(status["workflow_uuid"])
 
     return persist
 
@@ -947,18 +950,14 @@ def decorate_transaction(
                         assert (
                             ctx.sql_session is not None
                         ), "Cannot find a database connection"
-                        ApplicationDatabase.record_transaction_output(
+                        dbos._app_db.record_transaction_output(
                             ctx.sql_session, txn_output
                         )
                         break
                     except DBAPIError as dbapi_error:
-                        driver_error = cast(
-                            Optional[psycopg.OperationalError], dbapi_error.orig
-                        )
-                        if retriable_postgres_exception(dbapi_error) or (
-                            driver_error is not None
-                            and driver_error.sqlstate == "40001"
-                        ):
+                        if retriable_postgres_exception(
+                            dbapi_error
+                        ) or dbos._app_db._is_serialization_error(dbapi_error):
                             # Retry on serialization failure
                             span = ctx.get_current_span()
                             if span:
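
Because error classification now goes through the active `ApplicationDatabase`, the same transaction retry loop covers Postgres serialization failures (40001/40P01) and SQLite lock errors. A rough sketch of that pattern, not the actual DBOS internals (`run_with_retries` and `MAX_RETRIES` are illustrative names):

from sqlalchemy.exc import DBAPIError

MAX_RETRIES = 5  # illustrative bound, not a DBOS constant

def run_with_retries(app_db, attempt_txn):
    """Retry a transaction when the backend reports a concurrency conflict."""
    for attempt in range(MAX_RETRIES):
        try:
            return attempt_txn()
        except DBAPIError as dbapi_error:
            # Backend-specific classification: 40001/40P01 on Postgres,
            # "database is locked" on SQLite.
            if app_db._is_serialization_error(dbapi_error) and attempt < MAX_RETRIES - 1:
                continue
            raise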
dbos/_dbos.py CHANGED
@@ -32,7 +32,7 @@ from opentelemetry.trace import Span
 from rich import print
 
 from dbos._conductor.conductor import ConductorWebsocket
-from dbos._sys_db import WorkflowStatus
+from dbos._sys_db import SystemDatabase, WorkflowStatus
 from dbos._utils import INTERNAL_QUEUE_NAME, GlobalParams
 from dbos._workflow_commands import fork_workflow, list_queued_workflows, list_workflows
 
@@ -70,7 +70,6 @@ from ._sys_db import (
     SystemDatabase,
     WorkflowStatus,
     _dbos_stream_closed_sentinel,
-    reset_system_database,
     workflow_is_active,
 )
 from ._tracer import DBOSTracer, dbos_tracer
@@ -80,7 +79,6 @@ if TYPE_CHECKING:
     from ._kafka import _KafkaConsumerWorkflow
     from flask import Flask
 
-from sqlalchemy import make_url
 from sqlalchemy.orm import Session
 
 if sys.version_info < (3, 10):
@@ -293,16 +291,24 @@ class DBOS:
         return _dbos_global_instance
 
     @classmethod
-    def destroy(cls, *, destroy_registry: bool = False) -> None:
+    def destroy(
+        cls,
+        *,
+        destroy_registry: bool = False,
+        workflow_completion_timeout_sec: int = 0,
+    ) -> None:
         global _dbos_global_instance
         if _dbos_global_instance is not None:
-            _dbos_global_instance._destroy()
+            _dbos_global_instance._destroy(
+                workflow_completion_timeout_sec=workflow_completion_timeout_sec,
+            )
             _dbos_global_instance = None
         if destroy_registry:
             global _dbos_global_registry
             _dbos_global_registry = None
         GlobalParams.app_version = os.environ.get("DBOS__APPVERSION", "")
         GlobalParams.executor_id = os.environ.get("DBOS__VMID", "local")
+        dbos_logger.info("DBOS successfully shut down")
 
     def __init__(
         self,
@@ -337,6 +343,7 @@ class DBOS:
         self.conductor_key: Optional[str] = conductor_key
         self.conductor_websocket: Optional[ConductorWebsocket] = None
         self._background_event_loop: BackgroundEventLoop = BackgroundEventLoop()
+        self._active_workflows_set: set[str] = set()
 
         # Globally set the application version and executor ID.
         # In DBOS Cloud, instead use the values supplied through environment variables.
@@ -448,13 +455,13 @@ class DBOS:
         self._background_event_loop.start()
         assert self._config["database_url"] is not None
         assert self._config["database"]["sys_db_engine_kwargs"] is not None
-        self._sys_db_field = SystemDatabase(
+        self._sys_db_field = SystemDatabase.create(
             system_database_url=get_system_database_url(self._config),
             engine_kwargs=self._config["database"]["sys_db_engine_kwargs"],
             debug_mode=debug_mode,
         )
         assert self._config["database"]["db_engine_kwargs"] is not None
-        self._app_db_field = ApplicationDatabase(
+        self._app_db_field = ApplicationDatabase.create(
             database_url=self._config["database_url"],
             engine_kwargs=self._config["database"]["db_engine_kwargs"],
             debug_mode=debug_mode,
@@ -580,20 +587,25 @@ class DBOS:
             not self._launched
         ), "The system database cannot be reset after DBOS is launched. Resetting the system database is a destructive operation that should only be used in a test environment."
 
-        sysdb_name = self._config["database"]["sys_db_name"]
-        assert sysdb_name is not None
+        SystemDatabase.reset_system_database(get_system_database_url(self._config))
 
-        assert self._config["database_url"] is not None
-        pg_db_url = make_url(self._config["database_url"]).set(database="postgres")
-
-        reset_system_database(pg_db_url, sysdb_name)
-
-    def _destroy(self) -> None:
+    def _destroy(self, *, workflow_completion_timeout_sec: int) -> None:
         self._initialized = False
         for event in self.poller_stop_events:
             event.set()
         for event in self.background_thread_stop_events:
             event.set()
+        if workflow_completion_timeout_sec > 0:
+            deadline = time.time() + workflow_completion_timeout_sec
+            while time.time() < deadline:
+                time.sleep(1)
+                active_workflows = len(self._active_workflows_set)
+                if active_workflows > 0:
+                    dbos_logger.info(
+                        f"Attempting to shut down DBOS. {active_workflows} workflows remain active. IDs: {self._active_workflows_set}"
+                    )
+                else:
+                    break
         self._background_event_loop.stop()
         if self._sys_db_field is not None:
             self._sys_db_field.destroy()
@@ -609,10 +621,8 @@ class DBOS:
             and self.conductor_websocket.websocket is not None
         ):
             self.conductor_websocket.websocket.close()
-        # CB - This needs work, some things ought to stop before DBs are tossed out,
-        # on the other hand it hangs to move it
         if self._executor_field is not None:
-            self._executor_field.shutdown(cancel_futures=True)
+            self._executor_field.shutdown(wait=False, cancel_futures=True)
             self._executor_field = None
         for bg_thread in self._background_threads:
             bg_thread.join()
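
Taken together with the `_active_workflows_set` bookkeeping in `_core.py`, the new `workflow_completion_timeout_sec` argument gives `DBOS.destroy` a bounded grace period before the executor is shut down with `wait=False`. A minimal usage sketch (the 30-second value is arbitrary):

from dbos import DBOS

# After DBOS has been launched and workflows have been started:
# poll once per second for up to 30 seconds, logging the IDs of any
# still-active workflows, then proceed with the normal teardown.
DBOS.destroy(workflow_completion_timeout_sec=30)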