dbos 1.1.0a3 → 1.2.0 (py3-none-any.whl)

dbos/_sys_db.py CHANGED
@@ -1,7 +1,9 @@
 import datetime
+import functools
 import json
 import logging
 import os
+import random
 import re
 import threading
 import time
@@ -17,6 +19,7 @@ from typing import (
     Sequence,
     TypedDict,
     TypeVar,
+    cast,
 )

 import psycopg
@@ -27,11 +30,12 @@ from alembic.config import Config
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.sql import func

-from dbos._utils import INTERNAL_QUEUE_NAME
+from dbos._utils import INTERNAL_QUEUE_NAME, retriable_postgres_exception

 from . import _serialization
 from ._context import get_local_dbos_context
 from ._error import (
+    DBOSAwaitedWorkflowCancelledError,
     DBOSConflictingWorkflowError,
     DBOSDeadLetterQueueError,
     DBOSNonExistentWorkflowError,
@@ -96,6 +100,10 @@ class WorkflowStatus:
     executor_id: Optional[str]
     # The application version on which this workflow was started
     app_version: Optional[str]
+    # The start-to-close timeout of the workflow in ms
+    workflow_timeout_ms: Optional[int]
+    # The deadline of a workflow, computed by adding its timeout to its start time.
+    workflow_deadline_epoch_ms: Optional[int]

     # INTERNAL FIELDS

@@ -263,6 +271,51 @@ class ThreadSafeConditionDict:
                 dbos_logger.warning(f"Key {key} not found in condition dictionary.")


+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def db_retry(
+    initial_backoff: float = 1.0, max_backoff: float = 60.0
+) -> Callable[[F], F]:
+    """
+    If a workflow encounters a database connection issue while performing an operation,
+    block the workflow and retry the operation until it reconnects and succeeds.
+
+    In other words, if DBOS loses its database connection, everything pauses until the connection is recovered,
+    trading off availability for correctness.
+    """
+
+    def decorator(func: F) -> F:
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            retries: int = 0
+            backoff: float = initial_backoff
+            while True:
+                try:
+                    return func(*args, **kwargs)
+                except DBAPIError as e:
+
+                    # Determine if this is a retriable exception
+                    if not retriable_postgres_exception(e):
+                        raise
+
+                    retries += 1
+                    # Calculate backoff with jitter
+                    actual_backoff: float = backoff * (0.5 + random.random())
+                    dbos_logger.warning(
+                        f"Database connection failed: {str(e)}. "
+                        f"Retrying in {actual_backoff:.2f}s (attempt {retries})"
+                    )
+                    # Sleep with backoff
+                    time.sleep(actual_backoff)
+                    # Increase backoff for next attempt (exponential)
+                    backoff = min(backoff * 2, max_backoff)
+
+        return cast(F, wrapper)
+
+    return decorator
+
+
 class SystemDatabase:

     def __init__(
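A quick way to see what `db_retry` does between attempts: the sleep for attempt n is drawn from a jittered exponential schedule. The snippet below reproduces that arithmetic standalone; the helper name is ours, not part of the package.

# Illustrative only: db_retry's backoff arithmetic, outside the decorator.
import random

def backoff_schedule(attempts: int, initial: float = 1.0, cap: float = 60.0) -> list[float]:
    backoff, sleeps = initial, []
    for _ in range(attempts):
        sleeps.append(backoff * (0.5 + random.random()))  # jitter in [0.5x, 1.5x)
        backoff = min(backoff * 2, cap)  # double each attempt, capped at max_backoff
    return sleeps

print(backoff_schedule(5))  # e.g. [0.71, 1.89, 2.35, 7.02, 11.48]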
@@ -360,7 +413,7 @@ class SystemDatabase:
             self.notification_conn.close()
         self.engine.dispose()

-    def insert_workflow_status(
+    def _insert_workflow_status(
         self,
         status: WorkflowStatusInternal,
         conn: sa.Connection,
@@ -372,6 +425,15 @@ class SystemDatabase:
         wf_status: WorkflowStatuses = status["status"]
         workflow_deadline_epoch_ms: Optional[int] = status["workflow_deadline_epoch_ms"]

+        # Values to update when a row already exists for this workflow
+        update_values: dict[str, Any] = {
+            "recovery_attempts": SystemSchema.workflow_status.c.recovery_attempts + 1,
+            "updated_at": func.extract("epoch", func.now()) * 1000,
+        }
+        # Don't update an existing executor ID when enqueueing a workflow.
+        if wf_status != WorkflowStatusString.ENQUEUED.value:
+            update_values["executor_id"] = status["executor_id"]
+
         cmd = (
             pg.insert(SystemSchema.workflow_status)
             .values(
@@ -397,13 +459,7 @@ class SystemDatabase:
             )
             .on_conflict_do_update(
                 index_elements=["workflow_uuid"],
-                set_=dict(
-                    executor_id=status["executor_id"],
-                    recovery_attempts=(
-                        SystemSchema.workflow_status.c.recovery_attempts + 1
-                    ),
-                    updated_at=func.extract("epoch", func.now()) * 1000,
-                ),
+                set_=update_values,
             )
         )

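The effect of routing the conflict clause through `update_values`: re-inserting a row for an existing workflow bumps `recovery_attempts` and `updated_at`, and overwrites `executor_id` only when the workflow is not merely being enqueued. A compilable sketch of the same upsert shape, using a toy table rather than the real `SystemSchema`:

# Illustrative only: the ON CONFLICT DO UPDATE shape, compiled against a toy table.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pg

meta = sa.MetaData()
toy = sa.Table(
    "workflow_status",
    meta,
    sa.Column("workflow_uuid", sa.Text, primary_key=True),
    sa.Column("recovery_attempts", sa.Integer),
)
stmt = pg.insert(toy).values(workflow_uuid="wf-1", recovery_attempts=1)
stmt = stmt.on_conflict_do_update(
    index_elements=["workflow_uuid"],  # conflict target: the primary key
    set_={"recovery_attempts": toy.c.recovery_attempts + 1},
)
print(stmt.compile(dialect=pg.dialect()))
# INSERT INTO workflow_status (...) VALUES (...)
# ON CONFLICT (workflow_uuid) DO UPDATE SET recovery_attempts = workflow_status.recovery_attempts + 1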
@@ -469,53 +525,46 @@ class SystemDatabase:

         return wf_status, workflow_deadline_epoch_ms

+    @db_retry()
     def update_workflow_status(
         self,
         status: WorkflowStatusInternal,
-        *,
-        conn: Optional[sa.Connection] = None,
     ) -> None:
         if self._debug_mode:
             raise Exception("called update_workflow_status in debug mode")
         wf_status: WorkflowStatuses = status["status"]
-
-        cmd = (
-            pg.insert(SystemSchema.workflow_status)
-            .values(
-                workflow_uuid=status["workflow_uuid"],
-                status=status["status"],
-                name=status["name"],
-                class_name=status["class_name"],
-                config_name=status["config_name"],
-                output=status["output"],
-                error=status["error"],
-                executor_id=status["executor_id"],
-                application_version=status["app_version"],
-                application_id=status["app_id"],
-                authenticated_user=status["authenticated_user"],
-                authenticated_roles=status["authenticated_roles"],
-                assumed_role=status["assumed_role"],
-                queue_name=status["queue_name"],
-                recovery_attempts=(
-                    1 if wf_status != WorkflowStatusString.ENQUEUED.value else 0
-                ),
-            )
-            .on_conflict_do_update(
-                index_elements=["workflow_uuid"],
-                set_=dict(
-                    status=status["status"],
-                    output=status["output"],
-                    error=status["error"],
-                    updated_at=func.extract("epoch", func.now()) * 1000,
-                ),
-            )
-        )
-
-        if conn is not None:
-            conn.execute(cmd)
-        else:
-            with self.engine.begin() as c:
-                c.execute(cmd)
+        with self.engine.begin() as c:
+            c.execute(
+                pg.insert(SystemSchema.workflow_status)
+                .values(
+                    workflow_uuid=status["workflow_uuid"],
+                    status=status["status"],
+                    name=status["name"],
+                    class_name=status["class_name"],
+                    config_name=status["config_name"],
+                    output=status["output"],
+                    error=status["error"],
+                    executor_id=status["executor_id"],
+                    application_version=status["app_version"],
+                    application_id=status["app_id"],
+                    authenticated_user=status["authenticated_user"],
+                    authenticated_roles=status["authenticated_roles"],
+                    assumed_role=status["assumed_role"],
+                    queue_name=status["queue_name"],
+                    recovery_attempts=(
+                        1 if wf_status != WorkflowStatusString.ENQUEUED.value else 0
+                    ),
+                )
+                .on_conflict_do_update(
+                    index_elements=["workflow_uuid"],
+                    set_=dict(
+                        status=status["status"],
+                        output=status["output"],
+                        error=status["error"],
+                        updated_at=func.extract("epoch", func.now()) * 1000,
+                    ),
+                )
+            )

     def cancel_workflow(
         self,
@@ -596,18 +645,6 @@ class SystemDatabase:
                 )
             )

-    def get_max_function_id(self, workflow_uuid: str) -> Optional[int]:
-        with self.engine.begin() as conn:
-            max_function_id_row = conn.execute(
-                sa.select(
-                    sa.func.max(SystemSchema.operation_outputs.c.function_id)
-                ).where(SystemSchema.operation_outputs.c.workflow_uuid == workflow_uuid)
-            ).fetchone()
-
-            max_function_id = max_function_id_row[0] if max_function_id_row else None
-
-            return max_function_id
-
     def fork_workflow(
         self,
         original_workflow_id: str,
@@ -693,6 +730,7 @@ class SystemDatabase:
             )
         return forked_workflow_id

+    @db_retry()
     def get_workflow_status(
         self, workflow_uuid: str
     ) -> Optional[WorkflowStatusInternal]:
@@ -742,6 +780,7 @@ class SystemDatabase:
         }
         return status

+    @db_retry()
     def await_workflow_result(self, workflow_id: str) -> Any:
         while True:
             with self.engine.begin() as c:
@@ -761,14 +800,14 @@ class SystemDatabase:
                         error = row[2]
                         raise _serialization.deserialize_exception(error)
                     elif status == WorkflowStatusString.CANCELLED.value:
-                        # Raise a normal exception here, not the cancellation exception
+                        # Raise AwaitedWorkflowCancelledError here, not the cancellation exception
                         # because the awaiting workflow is not being cancelled.
-                        raise Exception(f"Awaited workflow {workflow_id} was cancelled")
+                        raise DBOSAwaitedWorkflowCancelledError(workflow_id)
                 else:
                     pass  # CB: I guess we're assuming the WF will show up eventually.
             time.sleep(1)

-    def update_workflow_inputs(
+    def _update_workflow_inputs(
         self, workflow_uuid: str, inputs: str, conn: sa.Connection
     ) -> None:
         if self._debug_mode:
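For callers, the change above means awaiting a cancelled workflow now surfaces a typed error instead of a bare `Exception`. A hedged usage sketch; `DBOS.retrieve_workflow` and `get_result` are the public handle API in dbos-transact-py, but the error's import path is an internal module and may move:

# Illustrative only: reacting to cancellation of an awaited workflow.
from dbos import DBOS
from dbos._error import DBOSAwaitedWorkflowCancelledError  # internal path, subject to change

handle = DBOS.retrieve_workflow("some-workflow-id")
try:
    result = handle.get_result()
except DBOSAwaitedWorkflowCancelledError:
    result = None  # the awaited workflow was cancelled; the awaiting workflow keeps running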
@@ -798,6 +837,7 @@ class SystemDatabase:

         return

+    @db_retry()
     def get_workflow_inputs(
         self, workflow_uuid: str
     ) -> Optional[_serialization.WorkflowInputs]:
@@ -837,6 +877,8 @@ class SystemDatabase:
             SystemSchema.workflow_inputs.c.inputs,
             SystemSchema.workflow_status.c.output,
             SystemSchema.workflow_status.c.error,
+            SystemSchema.workflow_status.c.workflow_deadline_epoch_ms,
+            SystemSchema.workflow_status.c.workflow_timeout_ms,
         ).join(
             SystemSchema.workflow_inputs,
             SystemSchema.workflow_status.c.workflow_uuid
@@ -918,6 +960,8 @@ class SystemDatabase:
             info.input = inputs
             info.output = output
             info.error = exception
+            info.workflow_deadline_epoch_ms = row[18]
+            info.workflow_timeout_ms = row[19]

             infos.append(info)
         return infos
@@ -947,6 +991,8 @@ class SystemDatabase:
             SystemSchema.workflow_inputs.c.inputs,
             SystemSchema.workflow_status.c.output,
             SystemSchema.workflow_status.c.error,
+            SystemSchema.workflow_status.c.workflow_deadline_epoch_ms,
+            SystemSchema.workflow_status.c.workflow_timeout_ms,
         ).select_from(
             SystemSchema.workflow_queue.join(
                 SystemSchema.workflow_status,
@@ -1024,6 +1070,8 @@ class SystemDatabase:
             info.input = inputs
             info.output = output
             info.error = exception
+            info.workflow_deadline_epoch_ms = row[18]
+            info.workflow_timeout_ms = row[19]

             infos.append(info)

@@ -1083,8 +1131,8 @@ class SystemDatabase:
             for row in rows
         ]

-    def record_operation_result(
-        self, result: OperationResultInternal, conn: Optional[sa.Connection] = None
+    def _record_operation_result_txn(
+        self, result: OperationResultInternal, conn: sa.Connection
     ) -> None:
         if self._debug_mode:
             raise Exception("called record_operation_result in debug mode")
@@ -1099,16 +1147,18 @@ class SystemDatabase:
             error=error,
         )
         try:
-            if conn is not None:
-                conn.execute(sql)
-            else:
-                with self.engine.begin() as c:
-                    c.execute(sql)
+            conn.execute(sql)
         except DBAPIError as dbapi_error:
             if dbapi_error.orig.sqlstate == "23505":  # type: ignore
                 raise DBOSWorkflowConflictIDError(result["workflow_uuid"])
             raise

+    @db_retry()
+    def record_operation_result(self, result: OperationResultInternal) -> None:
+        with self.engine.begin() as c:
+            self._record_operation_result_txn(result, c)
+
+    @db_retry()
     def record_get_result(
         self, result_workflow_id: str, output: Optional[str], error: Optional[str]
     ) -> None:
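This hunk shows the refactoring pattern the diff applies repeatedly: a private `_..._txn` variant composes into a caller's open connection, while the public method owns its transaction and can safely carry `@db_retry()`, since a retry re-runs the whole transaction. A runnable toy version of the split (SQLite stands in for Postgres; names are ours):

# Illustrative only: the public/_txn split, reduced to a toy.
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
with engine.begin() as c:
    c.execute(sa.text("CREATE TABLE outputs (id INTEGER PRIMARY KEY)"))

def _record_txn(conn: sa.Connection, i: int) -> None:
    conn.execute(sa.text("INSERT INTO outputs VALUES (:i)"), {"i": i})  # no commit here

def record(i: int) -> None:  # would carry @db_retry() in _sys_db.py
    with engine.begin() as c:  # owns (and on retry, re-runs) the transaction
        _record_txn(c, i)

record(1)  # standalone call: opens its own transaction
with engine.begin() as c:
    _record_txn(c, 2)  # composed call: joins an existing transaction, e.g. inside send()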
@@ -1134,6 +1184,7 @@ class SystemDatabase:
         with self.engine.begin() as c:
             c.execute(sql)

+    @db_retry()
     def record_child_workflow(
         self,
         parentUUID: str,
@@ -1158,13 +1209,12 @@ class SystemDatabase:
                 raise DBOSWorkflowConflictIDError(parentUUID)
             raise

-    def check_operation_execution(
+    def _check_operation_execution_txn(
         self,
         workflow_id: str,
         function_id: int,
         function_name: str,
-        *,
-        conn: Optional[sa.Connection] = None,
+        conn: sa.Connection,
     ) -> Optional[RecordedResult]:
         # First query: Retrieve the workflow status
         workflow_status_sql = sa.select(
@@ -1182,13 +1232,8 @@ class SystemDatabase:
            )
        )
         # Execute both queries
-        if conn is not None:
-            workflow_status_rows = conn.execute(workflow_status_sql).all()
-            operation_output_rows = conn.execute(operation_output_sql).all()
-        else:
-            with self.engine.begin() as c:
-                workflow_status_rows = c.execute(workflow_status_sql).all()
-                operation_output_rows = c.execute(operation_output_sql).all()
+        workflow_status_rows = conn.execute(workflow_status_sql).all()
+        operation_output_rows = conn.execute(operation_output_sql).all()

         # Check if the workflow exists
         assert (
@@ -1230,6 +1275,16 @@ class SystemDatabase:
         }
         return result

+    @db_retry()
+    def check_operation_execution(
+        self, workflow_id: str, function_id: int, function_name: str
+    ) -> Optional[RecordedResult]:
+        with self.engine.begin() as c:
+            return self._check_operation_execution_txn(
+                workflow_id, function_id, function_name, c
+            )
+
+    @db_retry()
     def check_child_workflow(
         self, workflow_uuid: str, function_id: int
     ) -> Optional[str]:
@@ -1247,6 +1302,7 @@ class SystemDatabase:
             return None
         return str(row[0])

+    @db_retry()
     def send(
         self,
         workflow_uuid: str,
@@ -1258,7 +1314,7 @@ class SystemDatabase:
         function_name = "DBOS.send"
         topic = topic if topic is not None else _dbos_null_topic
         with self.engine.begin() as c:
-            recorded_output = self.check_operation_execution(
+            recorded_output = self._check_operation_execution_txn(
                 workflow_uuid, function_id, function_name, conn=c
             )
             if self._debug_mode and recorded_output is None:
@@ -1296,8 +1352,9 @@ class SystemDatabase:
                 "output": None,
                 "error": None,
             }
-            self.record_operation_result(output, conn=c)
+            self._record_operation_result_txn(output, conn=c)

+    @db_retry()
     def recv(
         self,
         workflow_uuid: str,
@@ -1390,7 +1447,7 @@ class SystemDatabase:
             message: Any = None
             if len(rows) > 0:
                 message = _serialization.deserialize(rows[0][0])
-            self.record_operation_result(
+            self._record_operation_result_txn(
                 {
                     "workflow_uuid": workflow_uuid,
                     "function_id": function_id,
@@ -1454,13 +1511,14 @@ class SystemDatabase:
                            dbos_logger.error(f"Unknown channel: {channel}")
             except Exception as e:
                 if self._run_background_processes:
-                    dbos_logger.error(f"Notification listener error: {e}")
+                    dbos_logger.warning(f"Notification listener error: {e}")
                     time.sleep(1)
                     # Then the loop will try to reconnect and restart the listener
             finally:
                 if self.notification_conn is not None:
                     self.notification_conn.close()

+    @db_retry()
     def sleep(
         self,
         workflow_uuid: str,
@@ -1500,6 +1558,7 @@ class SystemDatabase:
             time.sleep(duration)
         return duration

+    @db_retry()
     def set_event(
         self,
         workflow_uuid: str,
@@ -1509,7 +1568,7 @@ class SystemDatabase:
     ) -> None:
         function_name = "DBOS.setEvent"
         with self.engine.begin() as c:
-            recorded_output = self.check_operation_execution(
+            recorded_output = self._check_operation_execution_txn(
                 workflow_uuid, function_id, function_name, conn=c
             )
             if self._debug_mode and recorded_output is None:
@@ -1541,8 +1600,9 @@ class SystemDatabase:
                 "output": None,
                 "error": None,
             }
-            self.record_operation_result(output, conn=c)
+            self._record_operation_result_txn(output, conn=c)

+    @db_retry()
     def get_event(
         self,
         target_uuid: str,
@@ -1633,7 +1693,7 @@ class SystemDatabase:
             )
         return value

-    def enqueue(
+    def _enqueue(
         self,
         workflow_id: str,
         queue_name: str,
@@ -1709,13 +1769,8 @@ class SystemDatabase:
             if num_recent_queries >= queue.limiter["limit"]:
                 return []

-            # Dequeue functions eligible for this worker and ordered by the time at which they were enqueued.
-            # If there is a global or local concurrency limit N, select only the N oldest enqueued
-            # functions, else select all of them.
-
-            # First lets figure out how many tasks are eligible for dequeue.
-            # This means figuring out how many unstarted tasks are within the local and global concurrency limits
-            running_tasks_query = (
+            # Count how many workflows on this queue are currently PENDING both locally and globally.
+            pending_tasks_query = (
                 sa.select(
                     SystemSchema.workflow_status.c.executor_id,
                     sa.func.count().label("task_count"),
@@ -1729,41 +1784,37 @@ class SystemDatabase:
                 )
                 .where(SystemSchema.workflow_queue.c.queue_name == queue.name)
                 .where(
-                    SystemSchema.workflow_queue.c.started_at_epoch_ms.isnot(
-                        None
-                    )  # Task is started
-                )
-                .where(
-                    SystemSchema.workflow_queue.c.completed_at_epoch_ms.is_(
-                        None
-                    )  # Task is not completed.
+                    SystemSchema.workflow_status.c.status
+                    == WorkflowStatusString.PENDING.value
                 )
                 .group_by(SystemSchema.workflow_status.c.executor_id)
             )
-            running_tasks_result = c.execute(running_tasks_query).fetchall()
-            running_tasks_result_dict = {row[0]: row[1] for row in running_tasks_result}
-            running_tasks_for_this_worker = running_tasks_result_dict.get(
-                executor_id, 0
-            )  # Get count for current executor
+            pending_workflows = c.execute(pending_tasks_query).fetchall()
+            pending_workflows_dict = {row[0]: row[1] for row in pending_workflows}
+            local_pending_workflows = pending_workflows_dict.get(executor_id, 0)

+            # Compute max_tasks, the number of workflows that can be dequeued given local and global concurrency limits.
             max_tasks = float("inf")
             if queue.worker_concurrency is not None:
-                max_tasks = max(
-                    0, queue.worker_concurrency - running_tasks_for_this_worker
-                )
+                # Print a warning if the local concurrency limit is violated
+                if local_pending_workflows > queue.worker_concurrency:
+                    dbos_logger.warning(
+                        f"The number of local pending workflows ({local_pending_workflows}) on queue {queue.name} exceeds the local concurrency limit ({queue.worker_concurrency})"
+                    )
+                max_tasks = max(0, queue.worker_concurrency - local_pending_workflows)
+
             if queue.concurrency is not None:
-                total_running_tasks = sum(running_tasks_result_dict.values())
-                # Queue global concurrency limit should always be >= running_tasks_count
-                # This should never happen but a check + warning doesn't hurt
-                if total_running_tasks > queue.concurrency:
+                global_pending_workflows = sum(pending_workflows_dict.values())
+                # Print a warning if the global concurrency limit is violated
+                if global_pending_workflows > queue.concurrency:
                     dbos_logger.warning(
-                        f"Total running tasks ({total_running_tasks}) exceeds the global concurrency limit ({queue.concurrency})"
+                        f"The total number of pending workflows ({global_pending_workflows}) on queue {queue.name} exceeds the global concurrency limit ({queue.concurrency})"
                     )
-                available_tasks = max(0, queue.concurrency - total_running_tasks)
+                available_tasks = max(0, queue.concurrency - global_pending_workflows)
                 max_tasks = min(max_tasks, available_tasks)

             # Retrieve the first max_tasks workflows in the queue.
-            # Only retrieve workflows of the appropriate version (or without version set)
+            # Only retrieve workflows of the local version (or without version set)
             query = (
                 sa.select(
                     SystemSchema.workflow_queue.c.workflow_uuid,
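The concurrency accounting above reduces to simple arithmetic over the PENDING counts per executor. A standalone paraphrase (function name and argument shapes are ours):

# Illustrative only: max_tasks from local and global concurrency limits.
from typing import Dict, Optional

def compute_max_tasks(
    worker_concurrency: Optional[int],
    concurrency: Optional[int],
    pending_by_executor: Dict[str, int],
    executor_id: str,
) -> float:
    max_tasks = float("inf")  # no limits configured: dequeue everything eligible
    local_pending = pending_by_executor.get(executor_id, 0)
    if worker_concurrency is not None:
        max_tasks = max(0, worker_concurrency - local_pending)
    if concurrency is not None:
        available = max(0, concurrency - sum(pending_by_executor.values()))
        max_tasks = min(max_tasks, available)
    return max_tasks

# Local limit 5 with 3 pending here; global limit 10 with 7 pending overall:
print(compute_max_tasks(5, 10, {"local": 3, "remote": 4}, "local"))  # 2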
@@ -1776,8 +1827,10 @@ class SystemDatabase:
                     )
                 )
                 .where(SystemSchema.workflow_queue.c.queue_name == queue.name)
-                .where(SystemSchema.workflow_queue.c.started_at_epoch_ms == None)
-                .where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
+                .where(
+                    SystemSchema.workflow_status.c.status
+                    == WorkflowStatusString.ENQUEUED.value
+                )
                 .where(
                     sa.or_(
                         SystemSchema.workflow_status.c.application_version
@@ -1806,20 +1859,16 @@ class SystemDatabase:
             ret_ids: list[str] = []

             for id in dequeued_ids:
-                # If we have a limiter, stop starting functions when the number
-                # of functions started this period exceeds the limit.
+                # If we have a limiter, stop dequeueing workflows when the number
+                # of workflows started this period exceeds the limit.
                 if queue.limiter is not None:
                     if len(ret_ids) + num_recent_queries >= queue.limiter["limit"]:
                         break

-                # To start a function, first set its status to PENDING and update its executor ID
-                res = c.execute(
+                # To start a workflow, first set its status to PENDING and update its executor ID
+                c.execute(
                     SystemSchema.workflow_status.update()
                     .where(SystemSchema.workflow_status.c.workflow_uuid == id)
-                    .where(
-                        SystemSchema.workflow_status.c.status
-                        == WorkflowStatusString.ENQUEUED.value
-                    )
                     .values(
                         status=WorkflowStatusString.PENDING.value,
                         application_version=app_version,
@@ -1827,8 +1876,13 @@ class SystemDatabase:
                         # If a timeout is set, set the deadline on dequeue
                         workflow_deadline_epoch_ms=sa.case(
                             (
-                                SystemSchema.workflow_status.c.workflow_timeout_ms.isnot(
-                                    None
+                                sa.and_(
+                                    SystemSchema.workflow_status.c.workflow_timeout_ms.isnot(
+                                        None
+                                    ),
+                                    SystemSchema.workflow_status.c.workflow_deadline_epoch_ms.is_(
+                                        None
+                                    ),
                                 ),
                                 sa.func.extract("epoch", sa.func.now()) * 1000
                                 + SystemSchema.workflow_status.c.workflow_timeout_ms,
@@ -1837,16 +1891,15 @@ class SystemDatabase:
                         ),
                     )
                 )
-                if res.rowcount > 0:
-                    # Then give it a start time and assign the executor ID
-                    c.execute(
-                        SystemSchema.workflow_queue.update()
-                        .where(SystemSchema.workflow_queue.c.workflow_uuid == id)
-                        .values(started_at_epoch_ms=start_time_ms)
-                    )
-                    ret_ids.append(id)
+                # Then give it a start time
+                c.execute(
+                    SystemSchema.workflow_queue.update()
+                    .where(SystemSchema.workflow_queue.c.workflow_uuid == id)
+                    .values(started_at_epoch_ms=start_time_ms)
+                )
+                ret_ids.append(id)

-            # If we have a limiter, garbage-collect all completed functions started
+            # If we have a limiter, garbage-collect all completed workflows started
             # before the period. If there's no limiter, there's no need--they were
             # deleted on completion.
             if queue.limiter is not None:
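The `sa.case` in the dequeue update now arms a deadline only when the workflow has a timeout and no deadline was already recorded; otherwise the stored value is kept (the `else_` branch sits outside these hunks, so treat that as our reading). In plain Python the rule looks like this:

# Illustrative only: a plain-Python paraphrase of the deadline case expression.
import time
from typing import Optional

def deadline_on_dequeue(timeout_ms: Optional[int], existing_ms: Optional[float]) -> Optional[float]:
    if timeout_ms is not None and existing_ms is None:
        return time.time() * 1000 + timeout_ms  # start the clock at dequeue
    return existing_ms  # keep a deadline carried over, e.g. across recovery

print(deadline_on_dequeue(30_000, None) is not None)  # True: fresh deadline armed
print(deadline_on_dequeue(30_000, 1234.0))            # 1234.0: existing deadline preserved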
@@ -1863,6 +1916,7 @@ class SystemDatabase:
             # Return the IDs of all functions we started
             return ret_ids

+    @db_retry()
     def remove_from_queue(self, workflow_id: str, queue: "Queue") -> None:
         if self._debug_mode:
             raise Exception("called remove_from_queue in debug mode")
@@ -1951,6 +2005,7 @@ class SystemDatabase:
             )
         return result

+    @db_retry()
     def init_workflow(
         self,
         status: WorkflowStatusInternal,
@@ -1963,17 +2018,17 @@ class SystemDatabase:
         Synchronously record the status and inputs for workflows in a single transaction
         """
         with self.engine.begin() as conn:
-            wf_status, workflow_deadline_epoch_ms = self.insert_workflow_status(
+            wf_status, workflow_deadline_epoch_ms = self._insert_workflow_status(
                 status, conn, max_recovery_attempts=max_recovery_attempts
             )
             # TODO: Modify the inputs if they were changed by `update_workflow_inputs`
-            self.update_workflow_inputs(status["workflow_uuid"], inputs, conn)
+            self._update_workflow_inputs(status["workflow_uuid"], inputs, conn)

             if (
                 status["queue_name"] is not None
                 and wf_status == WorkflowStatusString.ENQUEUED.value
             ):
-                self.enqueue(
+                self._enqueue(
                     status["workflow_uuid"],
                     status["queue_name"],
                     conn,
@@ -1981,6 +2036,14 @@ class SystemDatabase:
                 )
         return wf_status, workflow_deadline_epoch_ms

+    def check_connection(self) -> None:
+        try:
+            with self.engine.begin() as conn:
+                conn.execute(sa.text("SELECT 1")).fetchall()
+        except Exception as e:
+            dbos_logger.error(f"Error connecting to the DBOS system database: {e}")
+            raise
+

 def reset_system_database(postgres_db_url: sa.URL, sysdb_name: str) -> None:
     try:
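Finally, `check_connection` gives embedders a cheap readiness probe over the same engine. A hedged usage sketch; `sys_db` stands in for a constructed `SystemDatabase`, whose construction this diff does not show:

# Illustrative only: using check_connection() as a readiness probe.
def ready(sys_db: "SystemDatabase") -> bool:
    try:
        sys_db.check_connection()  # runs "SELECT 1"; logs and re-raises on failure
        return True
    except Exception:
        return False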