dbos 0.22.0a10__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dbos might be problematic.

dbos/_sys_db.py CHANGED
@@ -14,9 +14,7 @@ from typing import (
     Optional,
     Sequence,
     Set,
-    Tuple,
     TypedDict,
-    cast,
 )
 
 import psycopg
@@ -27,6 +25,8 @@ from alembic.config import Config
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.sql import func
 
+from dbos._utils import GlobalParams
+
 from . import _serialization
 from ._dbos_config import ConfigFile
 from ._error import (
@@ -66,17 +66,19 @@ class WorkflowStatusInternal(TypedDict):
     name: str
     class_name: Optional[str]
     config_name: Optional[str]
+    authenticated_user: Optional[str]
+    assumed_role: Optional[str]
+    authenticated_roles: Optional[str]  # JSON list of roles
     output: Optional[str]  # JSON (jsonpickle)
+    request: Optional[str]  # JSON (jsonpickle)
     error: Optional[str]  # JSON (jsonpickle)
+    created_at: Optional[int]  # Unix epoch timestamp in ms
+    updated_at: Optional[int]  # Unix epoch timestamp in ms
+    queue_name: Optional[str]
     executor_id: Optional[str]
     app_version: Optional[str]
     app_id: Optional[str]
-    request: Optional[str]  # JSON (jsonpickle)
     recovery_attempts: Optional[int]
-    authenticated_user: Optional[str]
-    assumed_role: Optional[str]
-    authenticated_roles: Optional[str]  # JSON list of roles.
-    queue_name: Optional[str]
 
 
 class RecordedResult(TypedDict):
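Besides reordering, this hunk adds timestamp and queue metadata to `WorkflowStatusInternal`. Per the inline comments, `created_at` and `updated_at` are Unix epoch timestamps in milliseconds, so consumers must scale them before converting. A minimal sketch (the record literal below is illustrative, not taken from the package):

```python
from datetime import datetime, timezone

# Illustrative record shaped like the 0.23.0 WorkflowStatusInternal fields.
status = {
    "workflow_uuid": "wf-123",
    "created_at": 1714752000000,  # Unix epoch timestamp in ms
    "updated_at": 1714752060000,  # Unix epoch timestamp in ms
}

# Millisecond epochs must be divided by 1000 for datetime.
created = datetime.fromtimestamp(status["created_at"] / 1000, tz=timezone.utc)
updated = datetime.fromtimestamp(status["updated_at"] / 1000, tz=timezone.utc)
print(created.isoformat(), updated.isoformat())
```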
@@ -102,19 +104,12 @@ class GetWorkflowsInput:
     Structure for argument to `get_workflows` function.
 
     This specifies the search criteria for workflow retrieval by `get_workflows`.
-
-    Attributes:
-        name(str): The name of the workflow function
-        authenticated_user(str): The name of the user who invoked the function
-        start_time(str): Beginning of search range for time of invocation, in ISO 8601 format
-        end_time(str): End of search range for time of invocation, in ISO 8601 format
-        status(str): Current status of the workflow invocation (see `WorkflowStatusString`)
-        application_version(str): Application version that invoked the workflow
-        limit(int): Limit on number of returned records
-
     """
 
     def __init__(self) -> None:
+        self.workflow_ids: Optional[List[str]] = (
+            None  # Search only in these workflow IDs
+        )
         self.name: Optional[str] = None  # The name of the workflow function
         self.authenticated_user: Optional[str] = None  # The user who ran the workflow.
         self.start_time: Optional[str] = None  # Timestamp in ISO 8601 format
@@ -126,15 +121,23 @@ class GetWorkflowsInput:
         self.limit: Optional[int] = (
             None  # Return up to this many workflows IDs. IDs are ordered by workflow creation time.
         )
+        self.offset: Optional[int] = (
+            None  # Offset into the matching records for pagination
+        )
+        self.sort_desc: bool = (
+            False  # If true, sort by created_at in DESC order. Default false (in ASC order).
+        )
 
 
 class GetQueuedWorkflowsInput(TypedDict):
-    queue_name: Optional[str]
-    status: Optional[str]
+    queue_name: Optional[str]  # Get workflows belonging to this queue
+    status: Optional[str]  # Get workflows with this status
     start_time: Optional[str]  # Timestamp in ISO 8601 format
     end_time: Optional[str]  # Timestamp in ISO 8601 format
     limit: Optional[int]  # Return up to this many workflows IDs.
+    offset: Optional[int]  # Offset into the matching records for pagination
     name: Optional[str]  # The name of the workflow function
+    sort_desc: Optional[bool]  # Sort by created_at in DESC or ASC order
 
 
 class GetWorkflowsOutput:
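Together, the new `workflow_ids`, `offset`, and `sort_desc` fields turn `GetWorkflowsInput` into a pagination-friendly filter. A hedged sketch of how a caller might page through results, assuming a constructed `SystemDatabase` instance `sys_db` and assuming `GetWorkflowsOutput` exposes the matched IDs as `workflow_uuids`:

```python
from dbos._sys_db import GetWorkflowsInput

query = GetWorkflowsInput()
query.status = "SUCCESS"   # a WorkflowStatusString value (assumed here)
query.sort_desc = True     # newest first (ORDER BY created_at DESC)
query.limit = 50           # page size
query.offset = 100         # skip the first two pages

page = sys_db.get_workflows(query)
for workflow_id in page.workflow_uuids:
    print(workflow_id)
```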
@@ -148,25 +151,6 @@ class GetPendingWorkflowsOutput:
         self.queue_name: Optional[str] = queue_name
 
 
-class WorkflowInformation(TypedDict, total=False):
-    workflow_uuid: str
-    status: WorkflowStatuses  # The status of the workflow.
-    name: str  # The name of the workflow function.
-    workflow_class_name: str  # The class name holding the workflow function.
-    workflow_config_name: (
-        str  # The name of the configuration, if the class needs configuration
-    )
-    authenticated_user: str  # The user who ran the workflow. Empty string if not set.
-    assumed_role: str
-    # The role used to run this workflow. Empty string if authorization is not required.
-    authenticated_roles: List[str]
-    # All roles the authenticated user has, if any.
-    input: Optional[_serialization.WorkflowInputs]
-    output: Optional[str]
-    error: Optional[str]
-    request: Optional[str]
-
-
 _dbos_null_topic = "__null__topic__"
 _buffer_flush_batch_size = 100
 _buffer_flush_interval_secs = 1.0
@@ -174,7 +158,7 @@ _buffer_flush_interval_secs = 1.0
 
 class SystemDatabase:
 
-    def __init__(self, config: ConfigFile):
+    def __init__(self, config: ConfigFile, *, debug_mode: bool = False):
         self.config = config
 
         sysdb_name = (
@@ -183,28 +167,27 @@ class SystemDatabase:
             else config["database"]["app_db_name"] + SystemSchema.sysdb_suffix
         )
 
-        # If the system database does not already exist, create it
-        postgres_db_url = sa.URL.create(
-            "postgresql+psycopg",
-            username=config["database"]["username"],
-            password=config["database"]["password"],
-            host=config["database"]["hostname"],
-            port=config["database"]["port"],
-            database="postgres",
-            # fills the "application_name" column in pg_stat_activity
-            query={
-                "application_name": f"dbos_transact_{os.environ.get('DBOS__VMID', 'local')}"
-            },
-        )
-        engine = sa.create_engine(postgres_db_url)
-        with engine.connect() as conn:
-            conn.execution_options(isolation_level="AUTOCOMMIT")
-            if not conn.execute(
-                sa.text("SELECT 1 FROM pg_database WHERE datname=:db_name"),
-                parameters={"db_name": sysdb_name},
-            ).scalar():
-                conn.execute(sa.text(f"CREATE DATABASE {sysdb_name}"))
-        engine.dispose()
+        if not debug_mode:
+            # If the system database does not already exist, create it
+            postgres_db_url = sa.URL.create(
+                "postgresql+psycopg",
+                username=config["database"]["username"],
+                password=config["database"]["password"],
+                host=config["database"]["hostname"],
+                port=config["database"]["port"],
+                database="postgres",
+                # fills the "application_name" column in pg_stat_activity
+                query={"application_name": f"dbos_transact_{GlobalParams.executor_id}"},
+            )
+            engine = sa.create_engine(postgres_db_url)
+            with engine.connect() as conn:
+                conn.execution_options(isolation_level="AUTOCOMMIT")
+                if not conn.execute(
+                    sa.text("SELECT 1 FROM pg_database WHERE datname=:db_name"),
+                    parameters={"db_name": sysdb_name},
+                ).scalar():
+                    conn.execute(sa.text(f"CREATE DATABASE {sysdb_name}"))
+            engine.dispose()
 
         system_db_url = sa.URL.create(
             "postgresql+psycopg",
@@ -214,9 +197,7 @@ class SystemDatabase:
             port=config["database"]["port"],
             database=sysdb_name,
             # fills the "application_name" column in pg_stat_activity
-            query={
-                "application_name": f"dbos_transact_{os.environ.get('DBOS__VMID', 'local')}"
-            },
+            query={"application_name": f"dbos_transact_{GlobalParams.executor_id}"},
         )
 
         # Create a connection pool for the system database
@@ -225,25 +206,41 @@ class SystemDatabase:
         )
 
         # Run a schema migration for the system database
-        migration_dir = os.path.join(
-            os.path.dirname(os.path.realpath(__file__)), "_migrations"
-        )
-        alembic_cfg = Config()
-        alembic_cfg.set_main_option("script_location", migration_dir)
-        logging.getLogger("alembic").setLevel(logging.WARNING)
-        # Alembic requires the % in URL-escaped parameters to itself be escaped to %%.
-        escaped_conn_string = re.sub(
-            r"%(?=[0-9A-Fa-f]{2})",
-            "%%",
-            self.engine.url.render_as_string(hide_password=False),
-        )
-        alembic_cfg.set_main_option("sqlalchemy.url", escaped_conn_string)
-        try:
-            command.upgrade(alembic_cfg, "head")
-        except Exception as e:
-            dbos_logger.warning(
-                f"Exception during system database construction. This is most likely because the system database was configured using a later version of DBOS: {e}"
+        if not debug_mode:
+            migration_dir = os.path.join(
+                os.path.dirname(os.path.realpath(__file__)), "_migrations"
+            )
+            alembic_cfg = Config()
+            alembic_cfg.set_main_option("script_location", migration_dir)
+            logging.getLogger("alembic").setLevel(logging.WARNING)
+            # Alembic requires the % in URL-escaped parameters to itself be escaped to %%.
+            escaped_conn_string = re.sub(
+                r"%(?=[0-9A-Fa-f]{2})",
+                "%%",
+                self.engine.url.render_as_string(hide_password=False),
             )
+            alembic_cfg.set_main_option("sqlalchemy.url", escaped_conn_string)
+            try:
+                command.upgrade(alembic_cfg, "head")
+            except Exception as e:
+                dbos_logger.warning(
+                    f"Exception during system database construction. This is most likely because the system database was configured using a later version of DBOS: {e}"
+                )
+            alembic_cfg = Config()
+            alembic_cfg.set_main_option("script_location", migration_dir)
+            # Alembic requires the % in URL-escaped parameters to itself be escaped to %%.
+            escaped_conn_string = re.sub(
+                r"%(?=[0-9A-Fa-f]{2})",
+                "%%",
+                self.engine.url.render_as_string(hide_password=False),
+            )
+            alembic_cfg.set_main_option("sqlalchemy.url", escaped_conn_string)
+            try:
+                command.upgrade(alembic_cfg, "head")
+            except Exception as e:
+                dbos_logger.warning(
+                    f"Exception during system database construction. This is most likely because the system database was configured using a later version of DBOS: {e}"
+                )
 
         self.notification_conn: Optional[psycopg.connection.Connection] = None
         self.notifications_map: Dict[str, threading.Condition] = {}
@@ -259,6 +256,7 @@ class SystemDatabase:
 
         # Now we can run background processes
         self._run_background_processes = True
+        self._debug_mode = debug_mode
 
     # Destroy the pool when finished
     def destroy(self) -> None:
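The constructor changes all serve the new `debug_mode` flag: when set, `SystemDatabase` neither creates the database nor runs Alembic migrations, and the stored `self._debug_mode` is checked by the write paths below, which raise instead of mutating state. A sketch of read-only use, with a hypothetical `ConfigFile`-shaped dict (the real one is loaded from the application's configuration):

```python
from dbos._sys_db import SystemDatabase

# Hypothetical config; field names follow the ConfigFile usage in this diff.
config = {
    "database": {
        "username": "postgres",
        "password": "dbos",
        "hostname": "localhost",
        "port": 5432,
        "app_db_name": "my_app",
    }
}

# debug_mode=True skips CREATE DATABASE and migrations; mutating methods
# such as enqueue or update_workflow_status raise if called.
sys_db = SystemDatabase(config, debug_mode=True)
try:
    sys_db.enqueue("wf-123", "example-queue")  # raises in debug mode
except Exception as e:
    print(e)
finally:
    sys_db.destroy()
```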
@@ -280,6 +278,8 @@ class SystemDatabase:
         *,
         max_recovery_attempts: int = DEFAULT_MAX_RECOVERY_ATTEMPTS,
     ) -> WorkflowStatuses:
+        if self._debug_mode:
+            raise Exception("called insert_workflow_status in debug mode")
         wf_status: WorkflowStatuses = status["status"]
 
         cmd = (
@@ -307,6 +307,7 @@ class SystemDatabase:
             .on_conflict_do_update(
                 index_elements=["workflow_uuid"],
                 set_=dict(
+                    executor_id=status["executor_id"],
                     recovery_attempts=(
                         SystemSchema.workflow_status.c.recovery_attempts + 1
                     ),
@@ -378,6 +379,8 @@ class SystemDatabase:
         *,
         conn: Optional[sa.Connection] = None,
     ) -> None:
+        if self._debug_mode:
+            raise Exception("called update_workflow_status in debug mode")
         wf_status: WorkflowStatuses = status["status"]
 
         cmd = (
@@ -427,6 +430,8 @@ class SystemDatabase:
         self,
         workflow_id: str,
     ) -> None:
+        if self._debug_mode:
+            raise Exception("called cancel_workflow in debug mode")
         with self.engine.begin() as c:
             # Remove the workflow from the queues table so it does not block the table
             c.execute(
@@ -447,6 +452,8 @@ class SystemDatabase:
         self,
         workflow_id: str,
     ) -> None:
+        if self._debug_mode:
+            raise Exception("called resume_workflow in debug mode")
         with self.engine.begin() as c:
             # Check the status of the workflow. If it is complete, do nothing.
             row = c.execute(
@@ -490,27 +497,33 @@ class SystemDatabase:
                     SystemSchema.workflow_status.c.assumed_role,
                     SystemSchema.workflow_status.c.queue_name,
                     SystemSchema.workflow_status.c.executor_id,
+                    SystemSchema.workflow_status.c.created_at,
+                    SystemSchema.workflow_status.c.updated_at,
+                    SystemSchema.workflow_status.c.application_version,
+                    SystemSchema.workflow_status.c.application_id,
                 ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid)
             ).fetchone()
             if row is None:
                 return None
             status: WorkflowStatusInternal = {
                 "workflow_uuid": workflow_uuid,
-                "status": row[0],
-                "name": row[1],
-                "class_name": row[5],
-                "config_name": row[4],
                 "output": None,
                 "error": None,
-                "app_id": None,
-                "app_version": None,
-                "executor_id": row[10],
+                "status": row[0],
+                "name": row[1],
                 "request": row[2],
                 "recovery_attempts": row[3],
+                "config_name": row[4],
+                "class_name": row[5],
                 "authenticated_user": row[6],
                 "authenticated_roles": row[7],
                 "assumed_role": row[8],
                 "queue_name": row[9],
+                "executor_id": row[10],
+                "created_at": row[11],
+                "updated_at": row[12],
+                "app_version": row[13],
+                "app_id": row[14],
             }
             return status
 
@@ -539,47 +552,6 @@ class SystemDatabase:
             )
         return stat
 
-    def get_workflow_status_w_outputs(
-        self, workflow_uuid: str
-    ) -> Optional[WorkflowStatusInternal]:
-        with self.engine.begin() as c:
-            row = c.execute(
-                sa.select(
-                    SystemSchema.workflow_status.c.status,
-                    SystemSchema.workflow_status.c.name,
-                    SystemSchema.workflow_status.c.request,
-                    SystemSchema.workflow_status.c.output,
-                    SystemSchema.workflow_status.c.error,
-                    SystemSchema.workflow_status.c.config_name,
-                    SystemSchema.workflow_status.c.class_name,
-                    SystemSchema.workflow_status.c.authenticated_user,
-                    SystemSchema.workflow_status.c.authenticated_roles,
-                    SystemSchema.workflow_status.c.assumed_role,
-                    SystemSchema.workflow_status.c.queue_name,
-                ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid)
-            ).fetchone()
-            if row is None:
-                return None
-            status: WorkflowStatusInternal = {
-                "workflow_uuid": workflow_uuid,
-                "status": row[0],
-                "name": row[1],
-                "config_name": row[5],
-                "class_name": row[6],
-                "output": row[3],
-                "error": row[4],
-                "app_id": None,
-                "app_version": None,
-                "executor_id": None,
-                "request": row[2],
-                "recovery_attempts": None,
-                "authenticated_user": row[7],
-                "authenticated_roles": row[8],
-                "assumed_role": row[9],
-                "queue_name": row[10],
-            }
-            return status
-
     def await_workflow_result_internal(self, workflow_uuid: str) -> dict[str, Any]:
         polling_interval_secs: float = 1.000
 
@@ -626,24 +598,12 @@ class SystemDatabase:
                 raise _serialization.deserialize_exception(stat["error"])
             return None
 
-    def get_workflow_info(
-        self, workflow_uuid: str, get_request: bool
-    ) -> Optional[WorkflowInformation]:
-        stat = self.get_workflow_status_w_outputs(workflow_uuid)
-        if stat is None:
-            return None
-        info = cast(WorkflowInformation, stat)
-        input = self.get_workflow_inputs(workflow_uuid)
-        if input is not None:
-            info["input"] = input
-        if not get_request:
-            info.pop("request", None)
-
-        return info
-
     def update_workflow_inputs(
         self, workflow_uuid: str, inputs: str, conn: Optional[sa.Connection] = None
     ) -> None:
+        if self._debug_mode:
+            raise Exception("called update_workflow_inputs in debug mode")
+
         cmd = (
             pg.insert(SystemSchema.workflow_inputs)
             .values(
@@ -689,9 +649,11 @@ class SystemDatabase:
         return inputs
 
     def get_workflows(self, input: GetWorkflowsInput) -> GetWorkflowsOutput:
-        query = sa.select(SystemSchema.workflow_status.c.workflow_uuid).order_by(
-            SystemSchema.workflow_status.c.created_at.asc()
-        )
+        query = sa.select(SystemSchema.workflow_status.c.workflow_uuid)
+        if input.sort_desc:
+            query = query.order_by(SystemSchema.workflow_status.c.created_at.desc())
+        else:
+            query = query.order_by(SystemSchema.workflow_status.c.created_at.asc())
         if input.name:
             query = query.where(SystemSchema.workflow_status.c.name == input.name)
         if input.authenticated_user:
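The sort direction is now applied with the same incremental pattern as the filters. Because SQLAlchemy `select()` statements are immutable, each `order_by`/`where`/`limit`/`offset` call returns a new statement, so conditional query building composes safely. A standalone sketch of the pattern against a stand-in table (not the real `SystemSchema`):

```python
import sqlalchemy as sa

metadata = sa.MetaData()
# Stand-in for SystemSchema.workflow_status.
workflow_status = sa.Table(
    "workflow_status",
    metadata,
    sa.Column("workflow_uuid", sa.Text, primary_key=True),
    sa.Column("name", sa.Text),
    sa.Column("created_at", sa.BigInteger),
)

def build_query(name, sort_desc, limit, offset):
    query = sa.select(workflow_status.c.workflow_uuid)
    # Each method call returns a new Select; nothing is mutated in place.
    if sort_desc:
        query = query.order_by(workflow_status.c.created_at.desc())
    else:
        query = query.order_by(workflow_status.c.created_at.asc())
    if name:
        query = query.where(workflow_status.c.name == name)
    if limit:
        query = query.limit(limit)
    if offset:
        query = query.offset(offset)
    return query

print(build_query("my_workflow", True, 50, 100))  # renders the SELECT
```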
@@ -716,28 +678,34 @@ class SystemDatabase:
                 SystemSchema.workflow_status.c.application_version
                 == input.application_version
             )
+        if input.workflow_ids:
+            query = query.where(
+                SystemSchema.workflow_status.c.workflow_uuid.in_(input.workflow_ids)
+            )
         if input.limit:
             query = query.limit(input.limit)
+        if input.offset:
+            query = query.offset(input.offset)
 
         with self.engine.begin() as c:
             rows = c.execute(query)
-            workflow_uuids = [row[0] for row in rows]
+            workflow_ids = [row[0] for row in rows]
 
-        return GetWorkflowsOutput(workflow_uuids)
+        return GetWorkflowsOutput(workflow_ids)
 
     def get_queued_workflows(
         self, input: GetQueuedWorkflowsInput
     ) -> GetWorkflowsOutput:
 
-        query = (
-            sa.select(SystemSchema.workflow_queue.c.workflow_uuid)
-            .join(
-                SystemSchema.workflow_status,
-                SystemSchema.workflow_queue.c.workflow_uuid
-                == SystemSchema.workflow_status.c.workflow_uuid,
-            )
-            .order_by(SystemSchema.workflow_status.c.created_at.asc())
+        query = sa.select(SystemSchema.workflow_queue.c.workflow_uuid).join(
+            SystemSchema.workflow_status,
+            SystemSchema.workflow_queue.c.workflow_uuid
+            == SystemSchema.workflow_status.c.workflow_uuid,
         )
+        if input["sort_desc"]:
+            query = query.order_by(SystemSchema.workflow_status.c.created_at.desc())
+        else:
+            query = query.order_by(SystemSchema.workflow_status.c.created_at.asc())
 
         if input.get("name"):
             query = query.where(SystemSchema.workflow_status.c.name == input["name"])
@@ -764,6 +732,8 @@ class SystemDatabase:
             )
         if input.get("limit"):
            query = query.limit(input["limit"])
+        if input.get("offset"):
+            query = query.offset(input["offset"])
 
         with self.engine.begin() as c:
             rows = c.execute(query)
@@ -798,6 +768,8 @@ class SystemDatabase:
     def record_operation_result(
         self, result: OperationResultInternal, conn: Optional[sa.Connection] = None
     ) -> None:
+        if self._debug_mode:
+            raise Exception("called record_operation_result in debug mode")
        error = result["error"]
        output = result["output"]
        assert error is None or output is None, "Only one of error or output can be set"
@@ -857,6 +829,11 @@ class SystemDatabase:
             recorded_output = self.check_operation_execution(
                 workflow_uuid, function_id, conn=c
             )
+            if self._debug_mode and recorded_output is None:
+                raise Exception(
+                    "called send in debug mode without a previous execution"
+                )
+
             if recorded_output is not None:
                 dbos_logger.debug(
                     f"Replaying send, id: {function_id}, destination_uuid: {destination_uuid}, topic: {topic}"
@@ -900,6 +877,8 @@ class SystemDatabase:
 
         # First, check for previous executions.
         recorded_output = self.check_operation_execution(workflow_uuid, function_id)
+        if self._debug_mode and recorded_output is None:
+            raise Exception("called recv in debug mode without a previous execution")
         if recorded_output is not None:
             dbos_logger.debug(f"Replaying recv, id: {function_id}, topic: {topic}")
             if recorded_output["output"] is not None:
@@ -1049,6 +1028,9 @@ class SystemDatabase:
     ) -> float:
         recorded_output = self.check_operation_execution(workflow_uuid, function_id)
         end_time: float
+        if self._debug_mode and recorded_output is None:
+            raise Exception("called sleep in debug mode without a previous execution")
+
         if recorded_output is not None:
             dbos_logger.debug(f"Replaying sleep, id: {function_id}, seconds: {seconds}")
             assert recorded_output["output"] is not None, "no recorded end time"
@@ -1083,6 +1065,10 @@ class SystemDatabase:
             recorded_output = self.check_operation_execution(
                 workflow_uuid, function_id, conn=c
             )
+            if self._debug_mode and recorded_output is None:
+                raise Exception(
+                    "called set_event in debug mode without a previous execution"
+                )
             if recorded_output is not None:
                 dbos_logger.debug(f"Replaying set_event, id: {function_id}, key: {key}")
                 return  # Already sent before
@@ -1127,6 +1113,10 @@ class SystemDatabase:
             recorded_output = self.check_operation_execution(
                 caller_ctx["workflow_uuid"], caller_ctx["function_id"]
             )
+            if self._debug_mode and recorded_output is None:
+                raise Exception(
+                    "called get_event in debug mode without a previous execution"
+                )
             if recorded_output is not None:
                 dbos_logger.debug(
                     f"Replaying get_event, id: {caller_ctx['function_id']}, key: {key}"
@@ -1189,6 +1179,9 @@ class SystemDatabase:
         return value
 
     def _flush_workflow_status_buffer(self) -> None:
+        if self._debug_mode:
+            raise Exception("called _flush_workflow_status_buffer in debug mode")
+
         """Export the workflow status buffer to the database, up to the batch size."""
         if len(self._workflow_status_buffer) == 0:
             return
@@ -1219,6 +1212,9 @@ class SystemDatabase:
                     break
 
     def _flush_workflow_inputs_buffer(self) -> None:
+        if self._debug_mode:
+            raise Exception("called _flush_workflow_inputs_buffer in debug mode")
+
         """Export the workflow inputs buffer to the database, up to the batch size."""
         if len(self._workflow_inputs_buffer) == 0:
             return
@@ -1283,6 +1279,8 @@ class SystemDatabase:
         )
 
     def enqueue(self, workflow_id: str, queue_name: str) -> None:
+        if self._debug_mode:
+            raise Exception("called enqueue in debug mode")
         with self.engine.begin() as c:
             c.execute(
                 pg.insert(SystemSchema.workflow_queue)
@@ -1294,6 +1292,9 @@ class SystemDatabase:
             )
 
     def start_queued_workflows(self, queue: "Queue", executor_id: str) -> List[str]:
+        if self._debug_mode:
+            return []
+
         start_time_ms = int(time.time() * 1000)
         if queue.limiter is not None:
             limiter_period_ms = int(queue.limiter["period"] * 1000)
@@ -1323,24 +1324,32 @@ class SystemDatabase:
             # If there is a global or local concurrency limit N, select only the N oldest enqueued
             # functions, else select all of them.
 
-            # First lets figure out how many tasks the worker can dequeue
+            # First lets figure out how many tasks are eligible for dequeue.
+            # This means figuring out how many unstarted tasks are within the local and global concurrency limits
             running_tasks_query = (
                 sa.select(
-                    SystemSchema.workflow_queue.c.executor_id,
+                    SystemSchema.workflow_status.c.executor_id,
                     sa.func.count().label("task_count"),
                 )
+                .select_from(
+                    SystemSchema.workflow_queue.join(
+                        SystemSchema.workflow_status,
+                        SystemSchema.workflow_queue.c.workflow_uuid
+                        == SystemSchema.workflow_status.c.workflow_uuid,
+                    )
+                )
                 .where(SystemSchema.workflow_queue.c.queue_name == queue.name)
                .where(
-                    SystemSchema.workflow_queue.c.executor_id.isnot(
+                    SystemSchema.workflow_queue.c.started_at_epoch_ms.isnot(
                        None
-                    )  # Task is dequeued
+                    )  # Task is started
                )
                .where(
                    SystemSchema.workflow_queue.c.completed_at_epoch_ms.is_(
                        None
-                    )  # Task is not completed
+                    )  # Task is not completed.
                )
-                .group_by(SystemSchema.workflow_queue.c.executor_id)
+                .group_by(SystemSchema.workflow_status.c.executor_id)
            )
            running_tasks_result = c.execute(running_tasks_query).fetchall()
            running_tasks_result_dict = {row[0]: row[1] for row in running_tasks_result}
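Downstream of this query (see the unchanged lines in the next hunks), the per-executor counts become a dequeue budget: the worker's own count is subtracted from its local `worker_concurrency` limit and the total across all executors from the global `concurrency` limit, and the smaller remainder wins. A small sketch of that arithmetic with made-up numbers:

```python
# Per-executor running-task counts, shaped like running_tasks_query's result.
running_tasks_result_dict = {"executor-a": 3, "executor-b": 5}
executor_id = "executor-a"          # this worker
worker_concurrency = 4              # assumed local queue limit
concurrency = 10                    # assumed global queue limit

running_tasks_for_this_worker = running_tasks_result_dict.get(executor_id, 0)

max_tasks = float("inf")
if worker_concurrency is not None:
    max_tasks = max(0, worker_concurrency - running_tasks_for_this_worker)  # 1
if concurrency is not None:
    total_running_tasks = sum(running_tasks_result_dict.values())  # 8
    available_tasks = max(0, concurrency - total_running_tasks)    # 2
    max_tasks = min(max_tasks, available_tasks)                    # 1
print(max_tasks)  # executor-a may start at most 1 more task this pass
```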
@@ -1350,12 +1359,6 @@ class SystemDatabase:
 
             max_tasks = float("inf")
             if queue.worker_concurrency is not None:
-                # Worker local concurrency limit should always be >= running_tasks_for_this_worker
-                # This should never happen but a check + warning doesn't hurt
-                if running_tasks_for_this_worker > queue.worker_concurrency:
-                    dbos_logger.warning(
-                        f"Number of tasks on this worker ({running_tasks_for_this_worker}) exceeds the worker concurrency limit ({queue.worker_concurrency})"
-                    )
                 max_tasks = max(
                     0, queue.worker_concurrency - running_tasks_for_this_worker
                 )
@@ -1370,16 +1373,14 @@ class SystemDatabase:
                 available_tasks = max(0, queue.concurrency - total_running_tasks)
                 max_tasks = min(max_tasks, available_tasks)
 
-            # Lookup tasks
+            # Lookup unstarted/uncompleted tasks (not running)
             query = (
                 sa.select(
                     SystemSchema.workflow_queue.c.workflow_uuid,
-                    SystemSchema.workflow_queue.c.started_at_epoch_ms,
-                    SystemSchema.workflow_queue.c.executor_id,
                 )
                 .where(SystemSchema.workflow_queue.c.queue_name == queue.name)
+                .where(SystemSchema.workflow_queue.c.started_at_epoch_ms == None)
                 .where(SystemSchema.workflow_queue.c.completed_at_epoch_ms == None)
-                .where(SystemSchema.workflow_queue.c.executor_id == None)
                 .order_by(SystemSchema.workflow_queue.c.created_at_epoch_ms.asc())
                 .with_for_update(nowait=True)  # Error out early
             )
@@ -1422,7 +1423,7 @@ class SystemDatabase:
                 c.execute(
                     SystemSchema.workflow_queue.update()
                     .where(SystemSchema.workflow_queue.c.workflow_uuid == id)
-                    .values(started_at_epoch_ms=start_time_ms, executor_id=executor_id)
+                    .values(started_at_epoch_ms=start_time_ms)
                 )
                 ret_ids.append(id)
 
@@ -1444,6 +1445,9 @@ class SystemDatabase:
         return ret_ids
 
     def remove_from_queue(self, workflow_id: str, queue: "Queue") -> None:
+        if self._debug_mode:
+            raise Exception("called remove_from_queue in debug mode")
+
         with self.engine.begin() as c:
             if queue.limiter is None:
                 c.execute(
@@ -1458,18 +1462,39 @@ class SystemDatabase:
                     .values(completed_at_epoch_ms=int(time.time() * 1000))
                 )
 
-    def clear_queue_assignment(self, workflow_id: str) -> None:
-        with self.engine.begin() as c:
-            c.execute(
-                sa.update(SystemSchema.workflow_queue)
-                .where(SystemSchema.workflow_queue.c.workflow_uuid == workflow_id)
-                .values(executor_id=None, started_at_epoch_ms=None)
-            )
-            c.execute(
-                sa.update(SystemSchema.workflow_status)
-                .where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id)
-                .values(executor_id=None, status=WorkflowStatusString.ENQUEUED.value)
-            )
+    def clear_queue_assignment(self, workflow_id: str) -> bool:
+        if self._debug_mode:
+            raise Exception("called clear_queue_assignment in debug mode")
+
+        with self.engine.connect() as conn:
+            with conn.begin() as transaction:
+                # Reset the start time in the queue to mark it as not started
+                res = conn.execute(
+                    sa.update(SystemSchema.workflow_queue)
+                    .where(SystemSchema.workflow_queue.c.workflow_uuid == workflow_id)
+                    .where(
+                        SystemSchema.workflow_queue.c.completed_at_epoch_ms.is_(None)
+                    )
+                    .values(started_at_epoch_ms=None)
+                )
+
+                # If no rows were affected, the workflow is not anymore in the queue or was already completed
+                if res.rowcount == 0:
+                    transaction.rollback()
+                    return False
+
+                # Reset the status of the task to "ENQUEUED"
+                res = conn.execute(
+                    sa.update(SystemSchema.workflow_status)
+                    .where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id)
+                    .values(status=WorkflowStatusString.ENQUEUED.value)
+                )
+                if res.rowcount == 0:
+                    # This should never happen
+                    raise Exception(
+                        f"UNREACHABLE: Workflow {workflow_id} is found in the workflow_queue table but not found in the workflow_status table"
+                    )
+                return True
 
 
 def reset_system_database(config: ConfigFile) -> None:
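Because `clear_queue_assignment` now returns `bool`, a caller can tell whether the workflow was actually handed back to the queue or had already completed or left it. A hedged sketch of caller-side use (the helper below is hypothetical, not part of the package):

```python
def requeue_abandoned_workflow(sys_db, workflow_id: str) -> None:
    # Returns True only if the row was still queued and uncompleted; the
    # start time is cleared and the status reset to ENQUEUED in one transaction.
    if sys_db.clear_queue_assignment(workflow_id):
        print(f"{workflow_id} returned to its queue for another executor")
    else:
        print(f"{workflow_id} already completed or no longer queued")
```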