dbos 0.25.0a16__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dbos/_sys_db.py CHANGED
@@ -1,20 +1,23 @@
1
1
  import datetime
2
+ import json
2
3
  import logging
3
4
  import os
4
5
  import re
5
6
  import threading
6
7
  import time
8
+ import uuid
7
9
  from enum import Enum
8
10
  from typing import (
9
11
  TYPE_CHECKING,
10
12
  Any,
13
+ Callable,
11
14
  Dict,
12
15
  List,
13
16
  Literal,
14
17
  Optional,
15
18
  Sequence,
16
- Set,
17
19
  TypedDict,
20
+ TypeVar,
18
21
  )
19
22
 
20
23
  import psycopg
@@ -25,7 +28,7 @@ from alembic.config import Config
25
28
  from sqlalchemy.exc import DBAPIError
26
29
  from sqlalchemy.sql import func
27
30
 
28
- from dbos._utils import GlobalParams
31
+ from dbos._utils import INTERNAL_QUEUE_NAME, GlobalParams
29
32
 
30
33
  from . import _serialization
31
34
  from ._context import get_local_dbos_context
@@ -34,6 +37,8 @@ from ._error import (
34
37
  DBOSConflictingWorkflowError,
35
38
  DBOSDeadLetterQueueError,
36
39
  DBOSNonExistentWorkflowError,
40
+ DBOSUnexpectedStepError,
41
+ DBOSWorkflowCancelledError,
37
42
  DBOSWorkflowConflictIDError,
38
43
  )
39
44
  from ._logger import dbos_logger
@@ -60,6 +65,50 @@ WorkflowStatuses = Literal[
60
65
  ]
61
66
 
62
67
 
68
+ class WorkflowStatus:
69
+ # The workflow ID
70
+ workflow_id: str
71
+ # The workflow status. Must be one of ENQUEUED, PENDING, SUCCESS, ERROR, CANCELLED, or RETRIES_EXCEEDED
72
+ status: str
73
+ # The name of the workflow function
74
+ name: str
75
+ # The name of the workflow's class, if any
76
+ class_name: Optional[str]
77
+ # The name with which the workflow's class instance was configured, if any
78
+ config_name: Optional[str]
79
+ # The user who ran the workflow, if specified
80
+ authenticated_user: Optional[str]
81
+ # The role with which the workflow ran, if specified
82
+ assumed_role: Optional[str]
83
+ # All roles which the authenticated user could assume
84
+ authenticated_roles: Optional[list[str]]
85
+ # The deserialized workflow input object
86
+ input: Optional[_serialization.WorkflowInputs]
87
+ # The workflow's output, if any
88
+ output: Optional[Any] = None
89
+ # The error the workflow threw, if any
90
+ error: Optional[Exception] = None
91
+ # Workflow start time, as a Unix epoch timestamp in ms
92
+ created_at: Optional[int]
93
+ # Last time the workflow status was updated, as a Unix epoch timestamp in ms
94
+ updated_at: Optional[int]
95
+ # If this workflow was enqueued, on which queue
96
+ queue_name: Optional[str]
97
+ # The executor to most recently executed this workflow
98
+ executor_id: Optional[str]
99
+ # The application version on which this workflow was started
100
+ app_version: Optional[str]
101
+
102
+ # INTERNAL FIELDS
103
+
104
+ # The ID of the application executing this workflow
105
+ app_id: Optional[str]
106
+ # The number of times this workflow's execution has been attempted
107
+ recovery_attempts: Optional[int]
108
+ # The HTTP request that triggered the workflow, if known
109
+ request: Optional[str]
110
+
111
+
63
112
  class WorkflowStatusInternal(TypedDict):
64
113
  workflow_uuid: str
65
114
  status: WorkflowStatuses
@@ -79,6 +128,11 @@ class WorkflowStatusInternal(TypedDict):
79
128
  app_version: Optional[str]
80
129
  app_id: Optional[str]
81
130
  recovery_attempts: Optional[int]
131
+ # The start-to-close timeout of the workflow in ms
132
+ workflow_timeout_ms: Optional[int]
133
+ # The deadline of a workflow, computed by adding its timeout to its start time.
134
+ # Deadlines propagate to children. When the deadline is reached, the workflow is cancelled.
135
+ workflow_deadline_epoch_ms: Optional[int]
82
136
 
83
137
 
84
138
  class RecordedResult(TypedDict):
@@ -128,6 +182,9 @@ class GetWorkflowsInput:
128
182
  self.sort_desc: bool = (
129
183
  False # If true, sort by created_at in DESC order. Default false (in ASC order).
130
184
  )
185
+ self.workflow_id_prefix: Optional[str] = (
186
+ None # If set, search for workflow IDs starting with this string
187
+ )
131
188
 
132
189
 
133
190
  class GetQueuedWorkflowsInput(TypedDict):
@@ -141,11 +198,6 @@ class GetQueuedWorkflowsInput(TypedDict):
141
198
  sort_desc: Optional[bool] # Sort by created_at in DESC or ASC order
142
199
 
143
200
 
144
- class GetWorkflowsOutput:
145
- def __init__(self, workflow_uuids: List[str]):
146
- self.workflow_uuids = workflow_uuids
147
-
148
-
149
201
  class GetPendingWorkflowsOutput:
150
202
  def __init__(self, *, workflow_uuid: str, queue_name: Optional[str] = None):
151
203
  self.workflow_uuid: str = workflow_uuid
@@ -278,12 +330,14 @@ class SystemDatabase:
278
330
  def insert_workflow_status(
279
331
  self,
280
332
  status: WorkflowStatusInternal,
333
+ conn: sa.Connection,
281
334
  *,
282
- max_recovery_attempts: int = DEFAULT_MAX_RECOVERY_ATTEMPTS,
283
- ) -> WorkflowStatuses:
335
+ max_recovery_attempts: Optional[int],
336
+ ) -> tuple[WorkflowStatuses, Optional[int]]:
284
337
  if self._debug_mode:
285
338
  raise Exception("called insert_workflow_status in debug mode")
286
339
  wf_status: WorkflowStatuses = status["status"]
340
+ workflow_deadline_epoch_ms: Optional[int] = status["workflow_deadline_epoch_ms"]
287
341
 
288
342
  cmd = (
289
343
  pg.insert(SystemSchema.workflow_status)
@@ -306,6 +360,8 @@ class SystemDatabase:
306
360
  recovery_attempts=(
307
361
  1 if wf_status != WorkflowStatusString.ENQUEUED.value else 0
308
362
  ),
363
+ workflow_timeout_ms=status["workflow_timeout_ms"],
364
+ workflow_deadline_epoch_ms=status["workflow_deadline_epoch_ms"],
309
365
  )
310
366
  .on_conflict_do_update(
311
367
  index_elements=["workflow_uuid"],
@@ -319,10 +375,9 @@ class SystemDatabase:
319
375
  )
320
376
  )
321
377
 
322
- cmd = cmd.returning(SystemSchema.workflow_status.c.recovery_attempts, SystemSchema.workflow_status.c.status, SystemSchema.workflow_status.c.name, SystemSchema.workflow_status.c.class_name, SystemSchema.workflow_status.c.config_name, SystemSchema.workflow_status.c.queue_name) # type: ignore
378
+ cmd = cmd.returning(SystemSchema.workflow_status.c.recovery_attempts, SystemSchema.workflow_status.c.status, SystemSchema.workflow_status.c.workflow_deadline_epoch_ms, SystemSchema.workflow_status.c.name, SystemSchema.workflow_status.c.class_name, SystemSchema.workflow_status.c.config_name, SystemSchema.workflow_status.c.queue_name) # type: ignore
323
379
 
324
- with self.engine.begin() as c:
325
- results = c.execute(cmd)
380
+ results = conn.execute(cmd)
326
381
 
327
382
  row = results.fetchone()
328
383
  if row is not None:
@@ -330,51 +385,58 @@ class SystemDatabase:
330
385
  # A mismatch indicates a workflow starting with the same UUID but different functions, which would throw an exception.
331
386
  recovery_attempts: int = row[0]
332
387
  wf_status = row[1]
388
+ workflow_deadline_epoch_ms = row[2]
333
389
  err_msg: Optional[str] = None
334
- if row[2] != status["name"]:
335
- err_msg = f"Workflow already exists with a different function name: {row[2]}, but the provided function name is: {status['name']}"
336
- elif row[3] != status["class_name"]:
337
- err_msg = f"Workflow already exists with a different class name: {row[3]}, but the provided class name is: {status['class_name']}"
338
- elif row[4] != status["config_name"]:
339
- err_msg = f"Workflow already exists with a different config name: {row[4]}, but the provided config name is: {status['config_name']}"
340
- elif row[5] != status["queue_name"]:
390
+ if row[3] != status["name"]:
391
+ err_msg = f"Workflow already exists with a different function name: {row[3]}, but the provided function name is: {status['name']}"
392
+ elif row[4] != status["class_name"]:
393
+ err_msg = f"Workflow already exists with a different class name: {row[4]}, but the provided class name is: {status['class_name']}"
394
+ elif row[5] != status["config_name"]:
395
+ err_msg = f"Workflow already exists with a different config name: {row[5]}, but the provided config name is: {status['config_name']}"
396
+ elif row[6] != status["queue_name"]:
341
397
  # This is a warning because a different queue name is not necessarily an error.
342
398
  dbos_logger.warning(
343
- f"Workflow already exists in queue: {row[5]}, but the provided queue name is: {status['queue_name']}. The queue is not updated."
399
+ f"Workflow already exists in queue: {row[6]}, but the provided queue name is: {status['queue_name']}. The queue is not updated."
344
400
  )
345
401
  if err_msg is not None:
346
402
  raise DBOSConflictingWorkflowError(status["workflow_uuid"], err_msg)
347
403
 
348
404
  # Every time we start executing a workflow (and thus attempt to insert its status), we increment `recovery_attempts` by 1.
349
405
  # When this number becomes equal to `maxRetries + 1`, we mark the workflow as `RETRIES_EXCEEDED`.
350
- if recovery_attempts > max_recovery_attempts + 1:
351
- with self.engine.begin() as c:
352
- c.execute(
353
- sa.delete(SystemSchema.workflow_queue).where(
354
- SystemSchema.workflow_queue.c.workflow_uuid
355
- == status["workflow_uuid"]
356
- )
406
+ if (
407
+ (wf_status != "SUCCESS" and wf_status != "ERROR")
408
+ and max_recovery_attempts is not None
409
+ and recovery_attempts > max_recovery_attempts + 1
410
+ ):
411
+ delete_cmd = sa.delete(SystemSchema.workflow_queue).where(
412
+ SystemSchema.workflow_queue.c.workflow_uuid
413
+ == status["workflow_uuid"]
414
+ )
415
+ conn.execute(delete_cmd)
416
+
417
+ dlq_cmd = (
418
+ sa.update(SystemSchema.workflow_status)
419
+ .where(
420
+ SystemSchema.workflow_status.c.workflow_uuid
421
+ == status["workflow_uuid"]
357
422
  )
358
- c.execute(
359
- sa.update(SystemSchema.workflow_status)
360
- .where(
361
- SystemSchema.workflow_status.c.workflow_uuid
362
- == status["workflow_uuid"]
363
- )
364
- .where(
365
- SystemSchema.workflow_status.c.status
366
- == WorkflowStatusString.PENDING.value
367
- )
368
- .values(
369
- status=WorkflowStatusString.RETRIES_EXCEEDED.value,
370
- queue_name=None,
371
- )
423
+ .where(
424
+ SystemSchema.workflow_status.c.status
425
+ == WorkflowStatusString.PENDING.value
426
+ )
427
+ .values(
428
+ status=WorkflowStatusString.RETRIES_EXCEEDED.value,
429
+ queue_name=None,
372
430
  )
431
+ )
432
+ conn.execute(dlq_cmd)
433
+ # Need to commit here because we're throwing an exception
434
+ conn.commit()
373
435
  raise DBOSDeadLetterQueueError(
374
436
  status["workflow_uuid"], max_recovery_attempts
375
437
  )
376
438
 
377
- return wf_status
439
+ return wf_status, workflow_deadline_epoch_ms
378
440
 
379
441
  def update_workflow_status(
380
442
  self,
@@ -432,6 +494,18 @@ class SystemDatabase:
432
494
  if self._debug_mode:
433
495
  raise Exception("called cancel_workflow in debug mode")
434
496
  with self.engine.begin() as c:
497
+ # Check the status of the workflow. If it is complete, do nothing.
498
+ row = c.execute(
499
+ sa.select(
500
+ SystemSchema.workflow_status.c.status,
501
+ ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id)
502
+ ).fetchone()
503
+ if (
504
+ row is None
505
+ or row[0] == WorkflowStatusString.SUCCESS.value
506
+ or row[0] == WorkflowStatusString.ERROR.value
507
+ ):
508
+ return
435
509
  # Remove the workflow from the queues table so it does not block the table
436
510
  c.execute(
437
511
  sa.delete(SystemSchema.workflow_queue).where(
@@ -447,13 +521,12 @@ class SystemDatabase:
447
521
  )
448
522
  )
449
523
 
450
- def resume_workflow(
451
- self,
452
- workflow_id: str,
453
- ) -> None:
524
+ def resume_workflow(self, workflow_id: str) -> None:
454
525
  if self._debug_mode:
455
526
  raise Exception("called resume_workflow in debug mode")
456
527
  with self.engine.begin() as c:
528
+ # Execute with snapshot isolation in case of concurrent calls on the same workflow
529
+ c.execute(sa.text("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"))
457
530
  # Check the status of the workflow. If it is complete, do nothing.
458
531
  row = c.execute(
459
532
  sa.select(
@@ -472,13 +545,113 @@ class SystemDatabase:
472
545
  SystemSchema.workflow_queue.c.workflow_uuid == workflow_id
473
546
  )
474
547
  )
475
- # Set the workflow's status to PENDING and clear its recovery attempts.
548
+ # Enqueue the workflow on the internal queue
549
+ c.execute(
550
+ pg.insert(SystemSchema.workflow_queue).values(
551
+ workflow_uuid=workflow_id,
552
+ queue_name=INTERNAL_QUEUE_NAME,
553
+ )
554
+ )
555
+ # Set the workflow's status to ENQUEUED and clear its recovery attempts and deadline.
476
556
  c.execute(
477
557
  sa.update(SystemSchema.workflow_status)
478
558
  .where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id)
479
- .values(status=WorkflowStatusString.PENDING.value, recovery_attempts=0)
559
+ .values(
560
+ status=WorkflowStatusString.ENQUEUED.value,
561
+ recovery_attempts=0,
562
+ workflow_deadline_epoch_ms=None,
563
+ )
480
564
  )
481
565
 
566
+ def get_max_function_id(self, workflow_uuid: str) -> Optional[int]:
567
+ with self.engine.begin() as conn:
568
+ max_function_id_row = conn.execute(
569
+ sa.select(
570
+ sa.func.max(SystemSchema.operation_outputs.c.function_id)
571
+ ).where(SystemSchema.operation_outputs.c.workflow_uuid == workflow_uuid)
572
+ ).fetchone()
573
+
574
+ max_function_id = max_function_id_row[0] if max_function_id_row else None
575
+
576
+ return max_function_id
577
+
578
+ def fork_workflow(
579
+ self, original_workflow_id: str, forked_workflow_id: str, start_step: int = 1
580
+ ) -> str:
581
+
582
+ status = self.get_workflow_status(original_workflow_id)
583
+ if status is None:
584
+ raise Exception(f"Workflow {original_workflow_id} not found")
585
+ inputs = self.get_workflow_inputs(original_workflow_id)
586
+ if inputs is None:
587
+ raise Exception(f"Workflow {original_workflow_id} not found")
588
+
589
+ with self.engine.begin() as c:
590
+ # Create an entry for the forked workflow with the same
591
+ # initial values as the original.
592
+ c.execute(
593
+ pg.insert(SystemSchema.workflow_status).values(
594
+ workflow_uuid=forked_workflow_id,
595
+ status=WorkflowStatusString.ENQUEUED.value,
596
+ name=status["name"],
597
+ class_name=status["class_name"],
598
+ config_name=status["config_name"],
599
+ application_version=status["app_version"],
600
+ application_id=status["app_id"],
601
+ request=status["request"],
602
+ authenticated_user=status["authenticated_user"],
603
+ authenticated_roles=status["authenticated_roles"],
604
+ assumed_role=status["assumed_role"],
605
+ queue_name=INTERNAL_QUEUE_NAME,
606
+ )
607
+ )
608
+ # Copy the original workflow's inputs into the forked workflow
609
+ c.execute(
610
+ pg.insert(SystemSchema.workflow_inputs).values(
611
+ workflow_uuid=forked_workflow_id,
612
+ inputs=_serialization.serialize_args(inputs),
613
+ )
614
+ )
615
+
616
+ if start_step > 1:
617
+
618
+ # Copy the original workflow's outputs into the forked workflow
619
+ insert_stmt = sa.insert(SystemSchema.operation_outputs).from_select(
620
+ [
621
+ "workflow_uuid",
622
+ "function_id",
623
+ "output",
624
+ "error",
625
+ "function_name",
626
+ "child_workflow_id",
627
+ ],
628
+ sa.select(
629
+ sa.literal(forked_workflow_id).label("workflow_uuid"),
630
+ SystemSchema.operation_outputs.c.function_id,
631
+ SystemSchema.operation_outputs.c.output,
632
+ SystemSchema.operation_outputs.c.error,
633
+ SystemSchema.operation_outputs.c.function_name,
634
+ SystemSchema.operation_outputs.c.child_workflow_id,
635
+ ).where(
636
+ (
637
+ SystemSchema.operation_outputs.c.workflow_uuid
638
+ == original_workflow_id
639
+ )
640
+ & (SystemSchema.operation_outputs.c.function_id < start_step)
641
+ ),
642
+ )
643
+
644
+ c.execute(insert_stmt)
645
+
646
+ # Enqueue the forked workflow on the internal queue
647
+ c.execute(
648
+ pg.insert(SystemSchema.workflow_queue).values(
649
+ workflow_uuid=forked_workflow_id,
650
+ queue_name=INTERNAL_QUEUE_NAME,
651
+ )
652
+ )
653
+ return forked_workflow_id
654
+
482
655
  def get_workflow_status(
483
656
  self, workflow_uuid: str
484
657
  ) -> Optional[WorkflowStatusInternal]:
@@ -500,6 +673,8 @@ class SystemDatabase:
500
673
  SystemSchema.workflow_status.c.updated_at,
501
674
  SystemSchema.workflow_status.c.application_version,
502
675
  SystemSchema.workflow_status.c.application_id,
676
+ SystemSchema.workflow_status.c.workflow_deadline_epoch_ms,
677
+ SystemSchema.workflow_status.c.workflow_timeout_ms,
503
678
  ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid)
504
679
  ).fetchone()
505
680
  if row is None:
@@ -523,12 +698,12 @@ class SystemDatabase:
523
698
  "updated_at": row[12],
524
699
  "app_version": row[13],
525
700
  "app_id": row[14],
701
+ "workflow_deadline_epoch_ms": row[15],
702
+ "workflow_timeout_ms": row[16],
526
703
  }
527
704
  return status
528
705
 
529
- def await_workflow_result_internal(self, workflow_uuid: str) -> dict[str, Any]:
530
- polling_interval_secs: float = 1.000
531
-
706
+ def await_workflow_result(self, workflow_id: str) -> Any:
532
707
  while True:
533
708
  with self.engine.begin() as c:
534
709
  row = c.execute(
@@ -536,44 +711,26 @@ class SystemDatabase:
536
711
  SystemSchema.workflow_status.c.status,
537
712
  SystemSchema.workflow_status.c.output,
538
713
  SystemSchema.workflow_status.c.error,
539
- ).where(
540
- SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid
541
- )
714
+ ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id)
542
715
  ).fetchone()
543
716
  if row is not None:
544
717
  status = row[0]
545
- if status == str(WorkflowStatusString.SUCCESS.value):
546
- return {
547
- "status": status,
548
- "output": row[1],
549
- "workflow_uuid": workflow_uuid,
550
- }
551
-
552
- elif status == str(WorkflowStatusString.ERROR.value):
553
- return {
554
- "status": status,
555
- "error": row[2],
556
- "workflow_uuid": workflow_uuid,
557
- }
558
-
718
+ if status == WorkflowStatusString.SUCCESS.value:
719
+ output = row[1]
720
+ return _serialization.deserialize(output)
721
+ elif status == WorkflowStatusString.ERROR.value:
722
+ error = row[2]
723
+ raise _serialization.deserialize_exception(error)
724
+ elif status == WorkflowStatusString.CANCELLED.value:
725
+ # Raise a normal exception here, not the cancellation exception
726
+ # because the awaiting workflow is not being cancelled.
727
+ raise Exception(f"Awaited workflow {workflow_id} was cancelled")
559
728
  else:
560
729
  pass # CB: I guess we're assuming the WF will show up eventually.
561
-
562
- time.sleep(polling_interval_secs)
563
-
564
- def await_workflow_result(self, workflow_uuid: str) -> Any:
565
- stat = self.await_workflow_result_internal(workflow_uuid)
566
- if not stat:
567
- return None
568
- status: str = stat["status"]
569
- if status == str(WorkflowStatusString.SUCCESS.value):
570
- return _serialization.deserialize(stat["output"])
571
- elif status == str(WorkflowStatusString.ERROR.value):
572
- raise _serialization.deserialize_exception(stat["error"])
573
- return None
730
+ time.sleep(1)
574
731
 
575
732
  def update_workflow_inputs(
576
- self, workflow_uuid: str, inputs: str, conn: Optional[sa.Connection] = None
733
+ self, workflow_uuid: str, inputs: str, conn: sa.Connection
577
734
  ) -> None:
578
735
  if self._debug_mode:
579
736
  raise Exception("called update_workflow_inputs in debug mode")
@@ -590,11 +747,8 @@ class SystemDatabase:
590
747
  )
591
748
  .returning(SystemSchema.workflow_inputs.c.inputs)
592
749
  )
593
- if conn is not None:
594
- row = conn.execute(cmd).fetchone()
595
- else:
596
- with self.engine.begin() as c:
597
- row = c.execute(cmd).fetchone()
750
+
751
+ row = conn.execute(cmd).fetchone()
598
752
  if row is not None and row[0] != inputs:
599
753
  # In a distributed environment, scheduled workflows are enqueued multiple times with slightly different timestamps
600
754
  if not workflow_uuid.startswith("sched-"):
@@ -621,8 +775,37 @@ class SystemDatabase:
621
775
  )
622
776
  return inputs
623
777
 
624
- def get_workflows(self, input: GetWorkflowsInput) -> GetWorkflowsOutput:
625
- query = sa.select(SystemSchema.workflow_status.c.workflow_uuid)
778
+ def get_workflows(
779
+ self, input: GetWorkflowsInput, get_request: bool = False
780
+ ) -> List[WorkflowStatus]:
781
+ """
782
+ Retrieve a list of workflows result and inputs based on the input criteria. The result is a list of external-facing workflow status objects.
783
+ """
784
+ query = sa.select(
785
+ SystemSchema.workflow_status.c.workflow_uuid,
786
+ SystemSchema.workflow_status.c.status,
787
+ SystemSchema.workflow_status.c.name,
788
+ SystemSchema.workflow_status.c.request,
789
+ SystemSchema.workflow_status.c.recovery_attempts,
790
+ SystemSchema.workflow_status.c.config_name,
791
+ SystemSchema.workflow_status.c.class_name,
792
+ SystemSchema.workflow_status.c.authenticated_user,
793
+ SystemSchema.workflow_status.c.authenticated_roles,
794
+ SystemSchema.workflow_status.c.assumed_role,
795
+ SystemSchema.workflow_status.c.queue_name,
796
+ SystemSchema.workflow_status.c.executor_id,
797
+ SystemSchema.workflow_status.c.created_at,
798
+ SystemSchema.workflow_status.c.updated_at,
799
+ SystemSchema.workflow_status.c.application_version,
800
+ SystemSchema.workflow_status.c.application_id,
801
+ SystemSchema.workflow_inputs.c.inputs,
802
+ SystemSchema.workflow_status.c.output,
803
+ SystemSchema.workflow_status.c.error,
804
+ ).join(
805
+ SystemSchema.workflow_inputs,
806
+ SystemSchema.workflow_status.c.workflow_uuid
807
+ == SystemSchema.workflow_inputs.c.workflow_uuid,
808
+ )
626
809
  if input.sort_desc:
627
810
  query = query.order_by(SystemSchema.workflow_status.c.created_at.desc())
628
811
  else:
@@ -655,6 +838,12 @@ class SystemDatabase:
655
838
  query = query.where(
656
839
  SystemSchema.workflow_status.c.workflow_uuid.in_(input.workflow_ids)
657
840
  )
841
+ if input.workflow_id_prefix:
842
+ query = query.where(
843
+ SystemSchema.workflow_status.c.workflow_uuid.startswith(
844
+ input.workflow_id_prefix
845
+ )
846
+ )
658
847
  if input.limit:
659
848
  query = query.limit(input.limit)
660
849
  if input.offset:
@@ -662,18 +851,76 @@ class SystemDatabase:
662
851
 
663
852
  with self.engine.begin() as c:
664
853
  rows = c.execute(query)
665
- workflow_ids = [row[0] for row in rows]
666
854
 
667
- return GetWorkflowsOutput(workflow_ids)
855
+ infos: List[WorkflowStatus] = []
856
+ for row in rows:
857
+ info = WorkflowStatus()
858
+ info.workflow_id = row[0]
859
+ info.status = row[1]
860
+ info.name = row[2]
861
+ info.request = row[3] if get_request else None
862
+ info.recovery_attempts = row[4]
863
+ info.config_name = row[5]
864
+ info.class_name = row[6]
865
+ info.authenticated_user = row[7]
866
+ info.authenticated_roles = (
867
+ json.loads(row[8]) if row[8] is not None else None
868
+ )
869
+ info.assumed_role = row[9]
870
+ info.queue_name = row[10]
871
+ info.executor_id = row[11]
872
+ info.created_at = row[12]
873
+ info.updated_at = row[13]
874
+ info.app_version = row[14]
875
+ info.app_id = row[15]
876
+
877
+ inputs = _serialization.deserialize_args(row[16])
878
+ if inputs is not None:
879
+ info.input = inputs
880
+ if info.status == WorkflowStatusString.SUCCESS.value:
881
+ info.output = _serialization.deserialize(row[17])
882
+ elif info.status == WorkflowStatusString.ERROR.value:
883
+ info.error = _serialization.deserialize_exception(row[18])
884
+
885
+ infos.append(info)
886
+ return infos
668
887
 
669
888
  def get_queued_workflows(
670
- self, input: GetQueuedWorkflowsInput
671
- ) -> GetWorkflowsOutput:
672
-
673
- query = sa.select(SystemSchema.workflow_queue.c.workflow_uuid).join(
674
- SystemSchema.workflow_status,
675
- SystemSchema.workflow_queue.c.workflow_uuid
676
- == SystemSchema.workflow_status.c.workflow_uuid,
889
+ self, input: GetQueuedWorkflowsInput, get_request: bool = False
890
+ ) -> List[WorkflowStatus]:
891
+ """
892
+ Retrieve a list of queued workflows result and inputs based on the input criteria. The result is a list of external-facing workflow status objects.
893
+ """
894
+ query = sa.select(
895
+ SystemSchema.workflow_status.c.workflow_uuid,
896
+ SystemSchema.workflow_status.c.status,
897
+ SystemSchema.workflow_status.c.name,
898
+ SystemSchema.workflow_status.c.request,
899
+ SystemSchema.workflow_status.c.recovery_attempts,
900
+ SystemSchema.workflow_status.c.config_name,
901
+ SystemSchema.workflow_status.c.class_name,
902
+ SystemSchema.workflow_status.c.authenticated_user,
903
+ SystemSchema.workflow_status.c.authenticated_roles,
904
+ SystemSchema.workflow_status.c.assumed_role,
905
+ SystemSchema.workflow_status.c.queue_name,
906
+ SystemSchema.workflow_status.c.executor_id,
907
+ SystemSchema.workflow_status.c.created_at,
908
+ SystemSchema.workflow_status.c.updated_at,
909
+ SystemSchema.workflow_status.c.application_version,
910
+ SystemSchema.workflow_status.c.application_id,
911
+ SystemSchema.workflow_inputs.c.inputs,
912
+ SystemSchema.workflow_status.c.output,
913
+ SystemSchema.workflow_status.c.error,
914
+ ).select_from(
915
+ SystemSchema.workflow_queue.join(
916
+ SystemSchema.workflow_status,
917
+ SystemSchema.workflow_queue.c.workflow_uuid
918
+ == SystemSchema.workflow_status.c.workflow_uuid,
919
+ ).join(
920
+ SystemSchema.workflow_inputs,
921
+ SystemSchema.workflow_queue.c.workflow_uuid
922
+ == SystemSchema.workflow_inputs.c.workflow_uuid,
923
+ )
677
924
  )
678
925
  if input["sort_desc"]:
679
926
  query = query.order_by(SystemSchema.workflow_status.c.created_at.desc())
@@ -710,9 +957,40 @@ class SystemDatabase:
710
957
 
711
958
  with self.engine.begin() as c:
712
959
  rows = c.execute(query)
713
- workflow_uuids = [row[0] for row in rows]
714
960
 
715
- return GetWorkflowsOutput(workflow_uuids)
961
+ infos: List[WorkflowStatus] = []
962
+ for row in rows:
963
+ info = WorkflowStatus()
964
+ info.workflow_id = row[0]
965
+ info.status = row[1]
966
+ info.name = row[2]
967
+ info.request = row[3] if get_request else None
968
+ info.recovery_attempts = row[4]
969
+ info.config_name = row[5]
970
+ info.class_name = row[6]
971
+ info.authenticated_user = row[7]
972
+ info.authenticated_roles = (
973
+ json.loads(row[8]) if row[8] is not None else None
974
+ )
975
+ info.assumed_role = row[9]
976
+ info.queue_name = row[10]
977
+ info.executor_id = row[11]
978
+ info.created_at = row[12]
979
+ info.updated_at = row[13]
980
+ info.app_version = row[14]
981
+ info.app_id = row[15]
982
+
983
+ inputs = _serialization.deserialize_args(row[16])
984
+ if inputs is not None:
985
+ info.input = inputs
986
+ if info.status == WorkflowStatusString.SUCCESS.value:
987
+ info.output = _serialization.deserialize(row[17])
988
+ elif info.status == WorkflowStatusString.ERROR.value:
989
+ info.error = _serialization.deserialize_exception(row[18])
990
+
991
+ infos.append(info)
992
+
993
+ return infos
716
994
 
717
995
  def get_pending_workflows(
718
996
  self, executor_id: str, app_version: str
@@ -844,28 +1122,74 @@ class SystemDatabase:
844
1122
  raise
845
1123
 
846
1124
  def check_operation_execution(
847
- self, workflow_uuid: str, function_id: int, conn: Optional[sa.Connection] = None
1125
+ self,
1126
+ workflow_id: str,
1127
+ function_id: int,
1128
+ function_name: str,
1129
+ *,
1130
+ conn: Optional[sa.Connection] = None,
848
1131
  ) -> Optional[RecordedResult]:
849
- sql = sa.select(
1132
+ # First query: Retrieve the workflow status
1133
+ workflow_status_sql = sa.select(
1134
+ SystemSchema.workflow_status.c.status,
1135
+ ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_id)
1136
+
1137
+ # Second query: Retrieve operation outputs if they exist
1138
+ operation_output_sql = sa.select(
850
1139
  SystemSchema.operation_outputs.c.output,
851
1140
  SystemSchema.operation_outputs.c.error,
1141
+ SystemSchema.operation_outputs.c.function_name,
852
1142
  ).where(
853
- SystemSchema.operation_outputs.c.workflow_uuid == workflow_uuid,
854
- SystemSchema.operation_outputs.c.function_id == function_id,
1143
+ (SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
1144
+ & (SystemSchema.operation_outputs.c.function_id == function_id)
855
1145
  )
856
1146
 
857
- # If in a transaction, use the provided connection
858
- rows: Sequence[Any]
1147
+ # Execute both queries
859
1148
  if conn is not None:
860
- rows = conn.execute(sql).all()
1149
+ workflow_status_rows = conn.execute(workflow_status_sql).all()
1150
+ operation_output_rows = conn.execute(operation_output_sql).all()
861
1151
  else:
862
1152
  with self.engine.begin() as c:
863
- rows = c.execute(sql).all()
864
- if len(rows) == 0:
1153
+ workflow_status_rows = c.execute(workflow_status_sql).all()
1154
+ operation_output_rows = c.execute(operation_output_sql).all()
1155
+
1156
+ # Check if the workflow exists
1157
+ assert (
1158
+ len(workflow_status_rows) > 0
1159
+ ), f"Error: Workflow {workflow_id} does not exist"
1160
+
1161
+ # Get workflow status
1162
+ workflow_status = workflow_status_rows[0][0]
1163
+
1164
+ # If the workflow is cancelled, raise the exception
1165
+ if workflow_status == WorkflowStatusString.CANCELLED.value:
1166
+ raise DBOSWorkflowCancelledError(
1167
+ f"Workflow {workflow_id} is cancelled. Aborting function."
1168
+ )
1169
+
1170
+ # If there are no operation outputs, return None
1171
+ if not operation_output_rows:
865
1172
  return None
1173
+
1174
+ # Extract operation output data
1175
+ output, error, recorded_function_name = (
1176
+ operation_output_rows[0][0],
1177
+ operation_output_rows[0][1],
1178
+ operation_output_rows[0][2],
1179
+ )
1180
+
1181
+ # If the provided and recorded function name are different, throw an exception
1182
+ if function_name != recorded_function_name:
1183
+ raise DBOSUnexpectedStepError(
1184
+ workflow_id=workflow_id,
1185
+ step_id=function_id,
1186
+ expected_name=function_name,
1187
+ recorded_name=recorded_function_name,
1188
+ )
1189
+
866
1190
  result: RecordedResult = {
867
- "output": rows[0][0],
868
- "error": rows[0][1],
1191
+ "output": output,
1192
+ "error": error,
869
1193
  }
870
1194
  return result
871
1195
 
@@ -894,10 +1218,11 @@ class SystemDatabase:
894
1218
  message: Any,
895
1219
  topic: Optional[str] = None,
896
1220
  ) -> None:
1221
+ function_name = "DBOS.send"
897
1222
  topic = topic if topic is not None else _dbos_null_topic
898
1223
  with self.engine.begin() as c:
899
1224
  recorded_output = self.check_operation_execution(
900
- workflow_uuid, function_id, conn=c
1225
+ workflow_uuid, function_id, function_name, conn=c
901
1226
  )
902
1227
  if self._debug_mode and recorded_output is None:
903
1228
  raise Exception(
@@ -930,7 +1255,7 @@ class SystemDatabase:
930
1255
  output: OperationResultInternal = {
931
1256
  "workflow_uuid": workflow_uuid,
932
1257
  "function_id": function_id,
933
- "function_name": "DBOS.send",
1258
+ "function_name": function_name,
934
1259
  "output": None,
935
1260
  "error": None,
936
1261
  }
@@ -944,10 +1269,13 @@ class SystemDatabase:
944
1269
  topic: Optional[str],
945
1270
  timeout_seconds: float = 60,
946
1271
  ) -> Any:
1272
+ function_name = "DBOS.recv"
947
1273
  topic = topic if topic is not None else _dbos_null_topic
948
1274
 
949
1275
  # First, check for previous executions.
950
- recorded_output = self.check_operation_execution(workflow_uuid, function_id)
1276
+ recorded_output = self.check_operation_execution(
1277
+ workflow_uuid, function_id, function_name
1278
+ )
951
1279
  if self._debug_mode and recorded_output is None:
952
1280
  raise Exception("called recv in debug mode without a previous execution")
953
1281
  if recorded_output is not None:
@@ -1024,7 +1352,7 @@ class SystemDatabase:
1024
1352
  {
1025
1353
  "workflow_uuid": workflow_uuid,
1026
1354
  "function_id": function_id,
1027
- "function_name": "DBOS.recv",
1355
+ "function_name": function_name,
1028
1356
  "output": _serialization.serialize(
1029
1357
  message
1030
1358
  ), # None will be serialized to 'null'
@@ -1098,7 +1426,10 @@ class SystemDatabase:
1098
1426
  seconds: float,
1099
1427
  skip_sleep: bool = False,
1100
1428
  ) -> float:
1101
- recorded_output = self.check_operation_execution(workflow_uuid, function_id)
1429
+ function_name = "DBOS.sleep"
1430
+ recorded_output = self.check_operation_execution(
1431
+ workflow_uuid, function_id, function_name
1432
+ )
1102
1433
  end_time: float
1103
1434
  if self._debug_mode and recorded_output is None:
1104
1435
  raise Exception("called sleep in debug mode without a previous execution")
@@ -1115,7 +1446,7 @@ class SystemDatabase:
1115
1446
  {
1116
1447
  "workflow_uuid": workflow_uuid,
1117
1448
  "function_id": function_id,
1118
- "function_name": "DBOS.sleep",
1449
+ "function_name": function_name,
1119
1450
  "output": _serialization.serialize(end_time),
1120
1451
  "error": None,
1121
1452
  }
@@ -1134,9 +1465,10 @@ class SystemDatabase:
1134
1465
  key: str,
1135
1466
  message: Any,
1136
1467
  ) -> None:
1468
+ function_name = "DBOS.setEvent"
1137
1469
  with self.engine.begin() as c:
1138
1470
  recorded_output = self.check_operation_execution(
1139
- workflow_uuid, function_id, conn=c
1471
+ workflow_uuid, function_id, function_name, conn=c
1140
1472
  )
1141
1473
  if self._debug_mode and recorded_output is None:
1142
1474
  raise Exception(
@@ -1163,7 +1495,7 @@ class SystemDatabase:
1163
1495
  output: OperationResultInternal = {
1164
1496
  "workflow_uuid": workflow_uuid,
1165
1497
  "function_id": function_id,
1166
- "function_name": "DBOS.setEvent",
1498
+ "function_name": function_name,
1167
1499
  "output": None,
1168
1500
  "error": None,
1169
1501
  }
@@ -1176,6 +1508,7 @@ class SystemDatabase:
1176
1508
  timeout_seconds: float = 60,
1177
1509
  caller_ctx: Optional[GetEventWorkflowContext] = None,
1178
1510
  ) -> Any:
1511
+ function_name = "DBOS.getEvent"
1179
1512
  get_sql = sa.select(
1180
1513
  SystemSchema.workflow_events.c.value,
1181
1514
  ).where(
@@ -1185,7 +1518,7 @@ class SystemDatabase:
1185
1518
  # Check for previous executions only if it's in a workflow
1186
1519
  if caller_ctx is not None:
1187
1520
  recorded_output = self.check_operation_execution(
1188
- caller_ctx["workflow_uuid"], caller_ctx["function_id"]
1521
+ caller_ctx["workflow_uuid"], caller_ctx["function_id"], function_name
1189
1522
  )
1190
1523
  if self._debug_mode and recorded_output is None:
1191
1524
  raise Exception(
@@ -1244,7 +1577,7 @@ class SystemDatabase:
1244
1577
  {
1245
1578
  "workflow_uuid": caller_ctx["workflow_uuid"],
1246
1579
  "function_id": caller_ctx["function_id"],
1247
- "function_name": "DBOS.getEvent",
1580
+ "function_name": function_name,
1248
1581
  "output": _serialization.serialize(
1249
1582
  value
1250
1583
  ), # None will be serialized to 'null'
@@ -1253,18 +1586,17 @@ class SystemDatabase:
1253
1586
  )
1254
1587
  return value
1255
1588
 
1256
def enqueue(self, workflow_id: str, queue_name: str, conn: sa.Connection) -> None:
    """Add a workflow to a queue, using the caller-supplied connection.

    The insert is idempotent: a conflicting (already-enqueued) row is
    silently ignored via ON CONFLICT DO NOTHING.
    """
    if self._debug_mode:
        raise Exception("called enqueue in debug mode")
    insert_stmt = (
        pg.insert(SystemSchema.workflow_queue)
        .values(workflow_uuid=workflow_id, queue_name=queue_name)
        .on_conflict_do_nothing()
    )
    conn.execute(insert_stmt)
 
1269
1601
  def start_queued_workflows(
1270
1602
  self, queue: "Queue", executor_id: str, app_version: str
@@ -1403,6 +1735,17 @@ class SystemDatabase:
1403
1735
  status=WorkflowStatusString.PENDING.value,
1404
1736
  application_version=app_version,
1405
1737
  executor_id=executor_id,
1738
+ # If a timeout is set, set the deadline on dequeue
1739
+ workflow_deadline_epoch_ms=sa.case(
1740
+ (
1741
+ SystemSchema.workflow_status.c.workflow_timeout_ms.isnot(
1742
+ None
1743
+ ),
1744
+ sa.func.extract("epoch", sa.func.now()) * 1000
1745
+ + SystemSchema.workflow_status.c.workflow_timeout_ms,
1746
+ ),
1747
+ else_=SystemSchema.workflow_status.c.workflow_deadline_epoch_ms,
1748
+ ),
1406
1749
  )
1407
1750
  )
1408
1751
  if res.rowcount > 0:
@@ -1483,6 +1826,66 @@ class SystemDatabase:
1483
1826
  )
1484
1827
  return True
1485
1828
 
1829
T = TypeVar("T")

def call_function_as_step(self, fn: Callable[[], T], function_name: str) -> T:
    """Execute `fn` with step-like durability semantics.

    Outside a workflow the function is simply invoked. Inside a workflow,
    a previously recorded result (or exception) for this step is replayed
    instead of re-running `fn`; otherwise `fn` runs and its serialized
    result is checkpointed. Calling this inside a transaction is an error.
    """
    ctx = get_local_dbos_context()
    if ctx and ctx.is_transaction():
        raise Exception(f"Invalid call to `{function_name}` inside a transaction")
    in_workflow = bool(ctx and ctx.is_workflow())
    if in_workflow:
        assert ctx is not None
        # Advance the step counter, then look for a prior recording.
        ctx.function_id += 1
        recorded = self.check_operation_execution(
            ctx.workflow_id, ctx.function_id, function_name
        )
        if recorded is not None:
            # Replay: prefer a recorded output, then a recorded error.
            if recorded["output"] is not None:
                replayed: SystemDatabase.T = _serialization.deserialize(
                    recorded["output"]
                )
                return replayed
            if recorded["error"] is not None:
                raise _serialization.deserialize_exception(recorded["error"])
            raise Exception(
                f"Recorded output and error are both None for {function_name}"
            )
    result = fn()
    if in_workflow:
        assert ctx is not None
        # Checkpoint the fresh result so replays can skip re-execution.
        self.record_operation_result(
            {
                "workflow_uuid": ctx.workflow_id,
                "function_id": ctx.function_id,
                "function_name": function_name,
                "output": _serialization.serialize(result),
                "error": None,
            }
        )
    return result
1864
+
1865
def init_workflow(
    self,
    status: WorkflowStatusInternal,
    inputs: str,
    *,
    max_recovery_attempts: Optional[int],
) -> tuple[WorkflowStatuses, Optional[int]]:
    """Record a workflow's status and inputs atomically.

    Runs status insertion, input recording, and (when applicable) queue
    insertion in one transaction, returning the resulting status string
    and the workflow deadline (epoch ms), if any.
    """
    with self.engine.begin() as conn:
        wf_status, workflow_deadline_epoch_ms = self.insert_workflow_status(
            status, conn, max_recovery_attempts=max_recovery_attempts
        )
        # TODO: Modify the inputs if they were changed by `update_workflow_inputs`
        self.update_workflow_inputs(status["workflow_uuid"], inputs, conn)

        queue_name = status["queue_name"]
        # Only freshly-enqueued workflows belong on the queue table.
        if queue_name is not None and wf_status == WorkflowStatusString.ENQUEUED.value:
            self.enqueue(status["workflow_uuid"], queue_name, conn)
    return wf_status, workflow_deadline_epoch_ms
1888
+
1486
1889
 
1487
1890
  def reset_system_database(config: ConfigFile) -> None:
1488
1891
  sysdb_name = (