dbos 2.1.0a2__py3-none-any.whl → 2.4.0a7__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of dbos might be problematic.

dbos/_sys_db.py CHANGED
@@ -30,7 +30,6 @@ from dbos._utils import (
     retriable_sqlite_exception,
 )
 
-from . import _serialization
 from ._context import get_local_dbos_context
 from ._error import (
     DBOSAwaitedWorkflowCancelledError,
@@ -44,6 +43,7 @@ from ._error import (
 )
 from ._logger import dbos_logger
 from ._schemas.system_database import SystemSchema
+from ._serialization import Serializer, WorkflowInputs, safe_deserialize
 
 if TYPE_CHECKING:
     from ._queue import Queue
@@ -95,7 +95,7 @@ class WorkflowStatus:
     # All roles which the authenticated user could assume
     authenticated_roles: Optional[list[str]]
     # The deserialized workflow input object
-    input: Optional[_serialization.WorkflowInputs]
+    input: Optional[WorkflowInputs]
     # The workflow's output, if any
     output: Optional[Any] = None
     # The error the workflow threw, if any
@@ -106,7 +106,7 @@ class WorkflowStatus:
     updated_at: Optional[int]
     # If this workflow was enqueued, on which queue
     queue_name: Optional[str]
-    # The executor to most recently executed this workflow
+    # The executor that most recently executed this workflow
     executor_id: Optional[str]
     # The application version on which this workflow was started
     app_version: Optional[str]
@@ -114,6 +114,14 @@ class WorkflowStatus:
     workflow_timeout_ms: Optional[int]
     # The deadline of a workflow, computed by adding its timeout to its start time.
     workflow_deadline_epoch_ms: Optional[int]
+    # Unique ID for deduplication on a queue
+    deduplication_id: Optional[str]
+    # Priority of the workflow on the queue, ranging from 1 to 2,147,483,647. Default 0 (highest priority).
+    priority: Optional[int]
+    # If this workflow is enqueued on a partitioned queue, its partition key
+    queue_partition_key: Optional[str]
+    # If this workflow was forked from another, that workflow's ID.
+    forked_from: Optional[str]
 
     # INTERNAL FIELDS
 
@@ -141,17 +149,13 @@ class WorkflowStatusInternal(TypedDict):
     app_version: Optional[str]
     app_id: Optional[str]
     recovery_attempts: Optional[int]
-    # The start-to-close timeout of the workflow in ms
     workflow_timeout_ms: Optional[int]
-    # The deadline of a workflow, computed by adding its timeout to its start time.
-    # Deadlines propagate to children. When the deadline is reached, the workflow is cancelled.
     workflow_deadline_epoch_ms: Optional[int]
-    # Unique ID for deduplication on a queue
     deduplication_id: Optional[str]
-    # Priority of the workflow on the queue, starting from 1 ~ 2,147,483,647. Default 0 (highest priority).
     priority: int
-    # Serialized workflow inputs
     inputs: str
+    queue_partition_key: Optional[str]
+    forked_from: Optional[str]
 
 
 class EnqueueOptionsInternal(TypedDict):
@@ -161,6 +165,8 @@ class EnqueueOptionsInternal(TypedDict):
     priority: Optional[int]
     # On what version the workflow is enqueued. Current version if not specified.
     app_version: Optional[str]
+    # If the workflow is enqueued on a partitioned queue, its partition key
+    queue_partition_key: Optional[str]
 
 
 class RecordedResult(TypedDict):
@@ -174,6 +180,7 @@ class OperationResultInternal(TypedDict):
     function_name: str
     output: Optional[str]  # JSON (jsonpickle)
     error: Optional[str]  # JSON (jsonpickle)
+    started_at_epoch_ms: int
 
 
 class GetEventWorkflowContext(TypedDict):
@@ -190,42 +197,34 @@ class GetWorkflowsInput:
     """
 
     def __init__(self) -> None:
-        self.workflow_ids: Optional[List[str]] = (
-            None  # Search only in these workflow IDs
-        )
-        self.name: Optional[str] = None  # The name of the workflow function
-        self.authenticated_user: Optional[str] = None  # The user who ran the workflow.
-        self.start_time: Optional[str] = None  # Timestamp in ISO 8601 format
-        self.end_time: Optional[str] = None  # Timestamp in ISO 8601 format
-        self.status: Optional[List[str]] = (
-            None  # Get workflows with one of these statuses
-        )
-        self.application_version: Optional[str] = (
-            None  # The application version that ran this workflow. = None
-        )
-        self.limit: Optional[int] = (
-            None  # Return up to this many workflows IDs. IDs are ordered by workflow creation time.
-        )
-        self.offset: Optional[int] = (
-            None  # Offset into the matching records for pagination
-        )
-        self.sort_desc: bool = (
-            False  # If true, sort by created_at in DESC order. Default false (in ASC order).
-        )
-        self.workflow_id_prefix: Optional[str] = (
-            None  # If set, search for workflow IDs starting with this string
-        )
-
-
-class GetQueuedWorkflowsInput(TypedDict):
-    queue_name: Optional[str]  # Get workflows belonging to this queue
-    status: Optional[list[str]]  # Get workflows with one of these statuses
-    start_time: Optional[str]  # Timestamp in ISO 8601 format
-    end_time: Optional[str]  # Timestamp in ISO 8601 format
-    limit: Optional[int]  # Return up to this many workflows IDs.
-    offset: Optional[int]  # Offset into the matching records for pagination
-    name: Optional[str]  # The name of the workflow function
-    sort_desc: Optional[bool]  # Sort by created_at in DESC or ASC order
+        # Search only in these workflow IDs
+        self.workflow_ids: Optional[List[str]] = None
+        # The name of the workflow function
+        self.name: Optional[str] = None
+        # The user who ran the workflow.
+        self.authenticated_user: Optional[str] = None
+        # Timestamp in ISO 8601 format
+        self.start_time: Optional[str] = None
+        # Timestamp in ISO 8601 format
+        self.end_time: Optional[str] = None
+        # Get workflows with one of these statuses
+        self.status: Optional[List[str]] = None
+        # The application version that ran this workflow.
+        self.application_version: Optional[str] = None
+        # Get workflows forked from this workflow ID.
+        self.forked_from: Optional[str] = None
+        # Return up to this many workflow IDs. IDs are ordered by workflow creation time.
+        self.limit: Optional[int] = None
+        # Offset into the matching records for pagination
+        self.offset: Optional[int] = None
+        # If true, sort by created_at in DESC order. Default false (in ASC order).
+        self.sort_desc: bool = False
+        # Search only for workflow IDs starting with this string
+        self.workflow_id_prefix: Optional[str] = None
+        # Search only for workflows enqueued on this queue
+        self.queue_name: Optional[str] = None
+        # Search only currently enqueued workflows
+        self.queues_only: bool = False
 
 
 class GetPendingWorkflowsOutput:
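
The old GetQueuedWorkflowsInput TypedDict is gone; its filters are folded into GetWorkflowsInput as the queue_name and queues_only fields. A minimal sketch of building a queued-workflow query under the new shape — only the field names come from the diff above; the consuming query method and the "payments" queue name are illustrative assumptions:

    from dbos._sys_db import GetWorkflowsInput

    query = GetWorkflowsInput()
    query.queue_name = "payments"  # only workflows enqueued on this queue
    query.queues_only = True       # restrict to ENQUEUED/PENDING workflows
    query.limit = 50
    query.sort_desc = True         # newest first
    # statuses = sys_db.get_workflows(query)  # hypothetical consumer of this input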
@@ -245,6 +244,10 @@ class StepInfo(TypedDict):
     error: Optional[Exception]
     # If the step starts or retrieves the result of a workflow, its ID
     child_workflow_id: Optional[str]
+    # The Unix epoch timestamp at which this step started
+    started_at_epoch_ms: Optional[int]
+    # The Unix epoch timestamp at which this step completed
+    completed_at_epoch_ms: Optional[int]
 
 
 _dbos_null_topic = "__null__topic__"
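
StepInfo now carries per-step start and completion timestamps. A small sketch (not from the package) of deriving step latencies from those fields; it assumes only the StepInfo keys shown above:

    from typing import List, Optional


    def step_duration_ms(step: dict) -> Optional[int]:
        started = step.get("started_at_epoch_ms")
        completed = step.get("completed_at_epoch_ms")
        if started is None or completed is None:
            return None  # both fields are Optional; older records may lack them
        return completed - started


    def slowest_steps(steps: List[dict], n: int = 5) -> List[dict]:
        # Rank steps by elapsed time, skipping those without timestamps.
        timed = [s for s in steps if step_duration_ms(s) is not None]
        return sorted(timed, key=step_duration_ms, reverse=True)[:n]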
@@ -341,6 +344,42 @@ def db_retry(
 
 class SystemDatabase(ABC):
 
+    @staticmethod
+    def create(
+        system_database_url: str,
+        engine_kwargs: Dict[str, Any],
+        engine: Optional[sa.Engine],
+        schema: Optional[str],
+        serializer: Serializer,
+        executor_id: Optional[str],
+        debug_mode: bool = False,
+    ) -> "SystemDatabase":
+        """Factory method to create the appropriate SystemDatabase implementation based on URL."""
+        if system_database_url.startswith("sqlite"):
+            from ._sys_db_sqlite import SQLiteSystemDatabase
+
+            return SQLiteSystemDatabase(
+                system_database_url=system_database_url,
+                engine_kwargs=engine_kwargs,
+                engine=engine,
+                schema=schema,
+                serializer=serializer,
+                executor_id=executor_id,
+                debug_mode=debug_mode,
+            )
+        else:
+            from ._sys_db_postgres import PostgresSystemDatabase
+
+            return PostgresSystemDatabase(
+                system_database_url=system_database_url,
+                engine_kwargs=engine_kwargs,
+                engine=engine,
+                schema=schema,
+                serializer=serializer,
+                executor_id=executor_id,
+                debug_mode=debug_mode,
+            )
+
     def __init__(
         self,
         *,
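
The factory moves up here and gains serializer and executor_id parameters, so every SystemDatabase is now constructed around a caller-supplied Serializer. A minimal sketch of invoking it; JsonPickleSerializer is a hypothetical implementation (jsonpickle is inferred only from the "JSON (jsonpickle)" comments in this file), not the package's own class:

    import jsonpickle

    from dbos._sys_db import SystemDatabase


    class JsonPickleSerializer:
        # Hypothetical Serializer: provides the serialize/deserialize
        # methods this file calls on self.serializer.
        def serialize(self, obj: object) -> str:
            return jsonpickle.encode(obj, unpicklable=True)

        def deserialize(self, data: str) -> object:
            return jsonpickle.decode(data)


    sys_db = SystemDatabase.create(
        system_database_url="sqlite:///dbos_system.sqlite",  # sqlite URL selects SQLiteSystemDatabase
        engine_kwargs={},
        engine=None,
        schema=None,  # the SQLite path forces schema to None in __init__
        serializer=JsonPickleSerializer(),
        executor_id=None,
    )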
@@ -348,6 +387,8 @@ class SystemDatabase(ABC):
         engine_kwargs: Dict[str, Any],
         engine: Optional[sa.Engine],
         schema: Optional[str],
+        serializer: Serializer,
+        executor_id: Optional[str],
         debug_mode: bool = False,
     ):
         import sqlalchemy.dialects.postgresql as pg
@@ -355,6 +396,8 @@ class SystemDatabase(ABC):
 
         self.dialect = sq if system_database_url.startswith("sqlite") else pg
 
+        self.serializer = serializer
+
         if system_database_url.startswith("sqlite"):
             self.schema = None
         else:
@@ -371,6 +414,8 @@ class SystemDatabase(ABC):
 
         self.notifications_map = ThreadSafeConditionDict()
         self.workflow_events_map = ThreadSafeConditionDict()
+        self.executor_id = executor_id
+
         self._listener_thread_lock = threading.Lock()
 
         # Now we can run background processes
@@ -454,6 +499,7 @@ class SystemDatabase(ABC):
                 deduplication_id=status["deduplication_id"],
                 priority=status["priority"],
                 inputs=status["inputs"],
+                queue_partition_key=status["queue_partition_key"],
             )
             .on_conflict_do_update(
                 index_elements=["workflow_uuid"],
@@ -665,6 +711,7 @@ class SystemDatabase(ABC):
                 assumed_role=status["assumed_role"],
                 queue_name=INTERNAL_QUEUE_NAME,
                 inputs=status["inputs"],
+                forked_from=original_workflow_id,
             )
         )
 
@@ -725,6 +772,8 @@ class SystemDatabase(ABC):
                 SystemSchema.workflow_status.c.deduplication_id,
                 SystemSchema.workflow_status.c.priority,
                 SystemSchema.workflow_status.c.inputs,
+                SystemSchema.workflow_status.c.queue_partition_key,
+                SystemSchema.workflow_status.c.forked_from,
             ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid)
         ).fetchone()
         if row is None:
@@ -752,6 +801,8 @@ class SystemDatabase(ABC):
             "deduplication_id": row[16],
             "priority": row[17],
             "inputs": row[18],
+            "queue_partition_key": row[19],
+            "forked_from": row[20],
         }
         return status
 
@@ -797,10 +848,11 @@ class SystemDatabase(ABC):
             status = row[0]
             if status == WorkflowStatusString.SUCCESS.value:
                 output = row[1]
-                return _serialization.deserialize(output)
+                return self.serializer.deserialize(output)
             elif status == WorkflowStatusString.ERROR.value:
                 error = row[2]
-                raise _serialization.deserialize_exception(error)
+                e: Exception = self.serializer.deserialize(error)
+                raise e
             elif status == WorkflowStatusString.CANCELLED.value:
                 # Raise AwaitedWorkflowCancelledError here, not the cancellation exception
                 # because the awaiting workflow is not being cancelled.
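
The dedicated deserialize_exception helper is replaced by a plain deserialize-then-raise, which relies on the serializer reviving a real Exception instance. A standalone illustration of that round trip, again assuming a jsonpickle-style serializer rather than the package's actual implementation:

    import jsonpickle

    original = ValueError("workflow failed")
    stored = jsonpickle.encode(original)  # what the error column would hold
    revived = jsonpickle.decode(stored)
    assert isinstance(revived, ValueError)
    try:
        raise revived  # mirrors the `raise e` path in the hunk above
    except ValueError as e:
        print(f"re-raised: {e}")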
@@ -837,6 +889,10 @@ class SystemDatabase(ABC):
             SystemSchema.workflow_status.c.application_id,
             SystemSchema.workflow_status.c.workflow_deadline_epoch_ms,
             SystemSchema.workflow_status.c.workflow_timeout_ms,
+            SystemSchema.workflow_status.c.deduplication_id,
+            SystemSchema.workflow_status.c.priority,
+            SystemSchema.workflow_status.c.queue_partition_key,
+            SystemSchema.workflow_status.c.forked_from,
         ]
         if load_input:
             load_columns.append(SystemSchema.workflow_status.c.inputs)
@@ -844,7 +900,15 @@ class SystemDatabase(ABC):
             load_columns.append(SystemSchema.workflow_status.c.output)
             load_columns.append(SystemSchema.workflow_status.c.error)
 
-        query = sa.select(*load_columns)
+        if input.queues_only:
+            query = sa.select(*load_columns).where(
+                sa.and_(
+                    SystemSchema.workflow_status.c.queue_name.isnot(None),
+                    SystemSchema.workflow_status.c.status.in_(["ENQUEUED", "PENDING"]),
+                )
+            )
+        else:
+            query = sa.select(*load_columns)
         if input.sort_desc:
             query = query.order_by(SystemSchema.workflow_status.c.created_at.desc())
         else:
@@ -873,6 +937,10 @@ class SystemDatabase(ABC):
                 SystemSchema.workflow_status.c.application_version
                 == input.application_version
             )
+        if input.forked_from:
+            query = query.where(
+                SystemSchema.workflow_status.c.forked_from == input.forked_from
+            )
         if input.workflow_ids:
             query = query.where(
                 SystemSchema.workflow_status.c.workflow_uuid.in_(input.workflow_ids)
@@ -883,6 +951,10 @@ class SystemDatabase(ABC):
                     input.workflow_id_prefix
                 )
             )
+        if input.queue_name:
+            query = query.where(
+                SystemSchema.workflow_status.c.queue_name == input.queue_name
+            )
         if input.limit:
             query = query.limit(input.limit)
         if input.offset:
@@ -892,6 +964,7 @@ class SystemDatabase(ABC):
             rows = c.execute(query).fetchall()
 
         infos: List[WorkflowStatus] = []
+        workflow_ids: List[str] = []
         for row in rows:
             info = WorkflowStatus()
             info.workflow_id = row[0]
@@ -913,11 +986,16 @@ class SystemDatabase(ABC):
             info.app_id = row[14]
             info.workflow_deadline_epoch_ms = row[15]
             info.workflow_timeout_ms = row[16]
-
-            raw_input = row[17] if load_input else None
-            raw_output = row[18] if load_output else None
-            raw_error = row[19] if load_output else None
-            inputs, output, exception = _serialization.safe_deserialize(
+            info.deduplication_id = row[17]
+            info.priority = row[18]
+            info.queue_partition_key = row[19]
+            info.forked_from = row[20]
+
+            raw_input = row[21] if load_input else None
+            raw_output = row[22] if load_output else None
+            raw_error = row[23] if load_output else None
+            inputs, output, exception = safe_deserialize(
+                self.serializer,
                 info.workflow_id,
                 serialized_input=raw_input,
                 serialized_output=raw_output,
@@ -927,121 +1005,10 @@ class SystemDatabase(ABC):
             info.output = output
             info.error = exception
 
+            workflow_ids.append(info.workflow_id)
             infos.append(info)
         return infos
 
-    def get_queued_workflows(
-        self,
-        input: GetQueuedWorkflowsInput,
-        *,
-        load_input: bool = True,
-    ) -> List[WorkflowStatus]:
-        """
-        Retrieve a list of queued workflows result and inputs based on the input criteria. The result is a list of external-facing workflow status objects.
-        """
-        load_columns = [
-            SystemSchema.workflow_status.c.workflow_uuid,
-            SystemSchema.workflow_status.c.status,
-            SystemSchema.workflow_status.c.name,
-            SystemSchema.workflow_status.c.recovery_attempts,
-            SystemSchema.workflow_status.c.config_name,
-            SystemSchema.workflow_status.c.class_name,
-            SystemSchema.workflow_status.c.authenticated_user,
-            SystemSchema.workflow_status.c.authenticated_roles,
-            SystemSchema.workflow_status.c.assumed_role,
-            SystemSchema.workflow_status.c.queue_name,
-            SystemSchema.workflow_status.c.executor_id,
-            SystemSchema.workflow_status.c.created_at,
-            SystemSchema.workflow_status.c.updated_at,
-            SystemSchema.workflow_status.c.application_version,
-            SystemSchema.workflow_status.c.application_id,
-            SystemSchema.workflow_status.c.workflow_deadline_epoch_ms,
-            SystemSchema.workflow_status.c.workflow_timeout_ms,
-        ]
-        if load_input:
-            load_columns.append(SystemSchema.workflow_status.c.inputs)
-
-        query = sa.select(*load_columns).where(
-            sa.and_(
-                SystemSchema.workflow_status.c.queue_name.isnot(None),
-                SystemSchema.workflow_status.c.status.in_(["ENQUEUED", "PENDING"]),
-            )
-        )
-        if input["sort_desc"]:
-            query = query.order_by(SystemSchema.workflow_status.c.created_at.desc())
-        else:
-            query = query.order_by(SystemSchema.workflow_status.c.created_at.asc())
-
-        if input.get("name"):
-            query = query.where(SystemSchema.workflow_status.c.name == input["name"])
-
-        if input.get("queue_name"):
-            query = query.where(
-                SystemSchema.workflow_status.c.queue_name == input["queue_name"]
-            )
-
-        status = input.get("status", None)
-        if status:
-            query = query.where(SystemSchema.workflow_status.c.status.in_(status))
-        if "start_time" in input and input["start_time"] is not None:
-            query = query.where(
-                SystemSchema.workflow_status.c.created_at
-                >= datetime.datetime.fromisoformat(input["start_time"]).timestamp()
-                * 1000
-            )
-        if "end_time" in input and input["end_time"] is not None:
-            query = query.where(
-                SystemSchema.workflow_status.c.created_at
-                <= datetime.datetime.fromisoformat(input["end_time"]).timestamp() * 1000
-            )
-        if input.get("limit"):
-            query = query.limit(input["limit"])
-        if input.get("offset"):
-            query = query.offset(input["offset"])
-
-        with self.engine.begin() as c:
-            rows = c.execute(query).fetchall()
-
-        infos: List[WorkflowStatus] = []
-        for row in rows:
-            info = WorkflowStatus()
-            info.workflow_id = row[0]
-            info.status = row[1]
-            info.name = row[2]
-            info.recovery_attempts = row[3]
-            info.config_name = row[4]
-            info.class_name = row[5]
-            info.authenticated_user = row[6]
-            info.authenticated_roles = (
-                json.loads(row[7]) if row[7] is not None else None
-            )
-            info.assumed_role = row[8]
-            info.queue_name = row[9]
-            info.executor_id = row[10]
-            info.created_at = row[11]
-            info.updated_at = row[12]
-            info.app_version = row[13]
-            info.app_id = row[14]
-            info.workflow_deadline_epoch_ms = row[15]
-            info.workflow_timeout_ms = row[16]
-
-            raw_input = row[17] if load_input else None
-
-            # Error and Output are not loaded because they should always be None for queued workflows.
-            inputs, output, exception = _serialization.safe_deserialize(
-                info.workflow_id,
-                serialized_input=raw_input,
-                serialized_output=None,
-                serialized_exception=None,
-            )
-            info.input = inputs
-            info.output = output
-            info.error = exception
-
-            infos.append(info)
-
-        return infos
-
     def get_pending_workflows(
         self, executor_id: str, app_version: str
     ) -> list[GetPendingWorkflowsOutput]:
@@ -1075,11 +1042,14 @@ class SystemDatabase(ABC):
                     SystemSchema.operation_outputs.c.output,
                     SystemSchema.operation_outputs.c.error,
                     SystemSchema.operation_outputs.c.child_workflow_id,
+                    SystemSchema.operation_outputs.c.started_at_epoch_ms,
+                    SystemSchema.operation_outputs.c.completed_at_epoch_ms,
                 ).where(SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
             ).fetchall()
         steps = []
         for row in rows:
-            _, output, exception = _serialization.safe_deserialize(
+            _, output, exception = safe_deserialize(
+                self.serializer,
                 workflow_id,
                 serialized_input=None,
                 serialized_output=row[2],
@@ -1091,6 +1061,8 @@ class SystemDatabase(ABC):
                 output=output,
                 error=exception,
                 child_workflow_id=row[4],
+                started_at_epoch_ms=row[5],
+                completed_at_epoch_ms=row[6],
             )
             steps.append(step)
         return steps
@@ -1103,10 +1075,33 @@ class SystemDatabase(ABC):
         error = result["error"]
         output = result["output"]
         assert error is None or output is None, "Only one of error or output can be set"
+        wf_executor_id_row = conn.execute(
+            sa.select(
+                SystemSchema.workflow_status.c.executor_id,
+            ).where(
+                SystemSchema.workflow_status.c.workflow_uuid == result["workflow_uuid"]
+            )
+        ).fetchone()
+        assert wf_executor_id_row is not None
+        wf_executor_id = wf_executor_id_row[0]
+        if self.executor_id is not None and wf_executor_id != self.executor_id:
+            dbos_logger.debug(
+                f'Resetting executor_id from {wf_executor_id} to {self.executor_id} for workflow {result["workflow_uuid"]}'
+            )
+            conn.execute(
+                sa.update(SystemSchema.workflow_status)
+                .values(executor_id=self.executor_id)
+                .where(
+                    SystemSchema.workflow_status.c.workflow_uuid
+                    == result["workflow_uuid"]
+                )
+            )
         sql = sa.insert(SystemSchema.operation_outputs).values(
             workflow_uuid=result["workflow_uuid"],
             function_id=result["function_id"],
             function_name=result["function_name"],
+            started_at_epoch_ms=result["started_at_epoch_ms"],
+            completed_at_epoch_ms=int(time.time() * 1000),
             output=output,
             error=error,
         )
@@ -1278,7 +1273,8 @@ class SystemDatabase(ABC):
         if row is None:
             return None
         elif row[1]:
-            raise _serialization.deserialize_exception(row[1])
+            e: Exception = self.serializer.deserialize(row[1])
+            raise e
         else:
             return str(row[0])
 
@@ -1292,6 +1288,7 @@ class SystemDatabase(ABC):
         topic: Optional[str] = None,
     ) -> None:
         function_name = "DBOS.send"
+        start_time = int(time.time() * 1000)
         topic = topic if topic is not None else _dbos_null_topic
         with self.engine.begin() as c:
             recorded_output = self._check_operation_execution_txn(
@@ -1317,7 +1314,7 @@ class SystemDatabase(ABC):
                     sa.insert(SystemSchema.notifications).values(
                         destination_uuid=destination_uuid,
                         topic=topic,
-                        message=_serialization.serialize(message),
+                        message=self.serializer.serialize(message),
                     )
                 )
             except DBAPIError as dbapi_error:
@@ -1328,6 +1325,7 @@ class SystemDatabase(ABC):
                 "workflow_uuid": workflow_uuid,
                 "function_id": function_id,
                 "function_name": function_name,
+                "started_at_epoch_ms": start_time,
                 "output": None,
                 "error": None,
             }
@@ -1343,6 +1341,7 @@ class SystemDatabase(ABC):
         timeout_seconds: float = 60,
     ) -> Any:
         function_name = "DBOS.recv"
+        start_time = int(time.time() * 1000)
         topic = topic if topic is not None else _dbos_null_topic
 
         # First, check for previous executions.
@@ -1354,7 +1353,7 @@ class SystemDatabase(ABC):
         if recorded_output is not None:
             dbos_logger.debug(f"Replaying recv, id: {function_id}, topic: {topic}")
             if recorded_output["output"] is not None:
-                return _serialization.deserialize(recorded_output["output"])
+                return self.serializer.deserialize(recorded_output["output"])
             else:
                 raise Exception("No output recorded in the last recv")
         else:
@@ -1421,13 +1420,14 @@ class SystemDatabase(ABC):
             rows = c.execute(delete_stmt).fetchall()
             message: Any = None
             if len(rows) > 0:
-                message = _serialization.deserialize(rows[0][0])
+                message = self.serializer.deserialize(rows[0][0])
             self._record_operation_result_txn(
                 {
                     "workflow_uuid": workflow_uuid,
                     "function_id": function_id,
                     "function_name": function_name,
-                    "output": _serialization.serialize(
+                    "started_at_epoch_ms": start_time,
+                    "output": self.serializer.serialize(
                         message
                     ),  # None will be serialized to 'null'
                     "error": None,
@@ -1453,36 +1453,6 @@ class SystemDatabase(ABC):
 
         PostgresSystemDatabase._reset_system_database(database_url)
 
-    @staticmethod
-    def create(
-        system_database_url: str,
-        engine_kwargs: Dict[str, Any],
-        engine: Optional[sa.Engine],
-        schema: Optional[str],
-        debug_mode: bool = False,
-    ) -> "SystemDatabase":
-        """Factory method to create the appropriate SystemDatabase implementation based on URL."""
-        if system_database_url.startswith("sqlite"):
-            from ._sys_db_sqlite import SQLiteSystemDatabase
-
-            return SQLiteSystemDatabase(
-                system_database_url=system_database_url,
-                engine_kwargs=engine_kwargs,
-                engine=engine,
-                schema=schema,
-                debug_mode=debug_mode,
-            )
-        else:
-            from ._sys_db_postgres import PostgresSystemDatabase
-
-            return PostgresSystemDatabase(
-                system_database_url=system_database_url,
-                engine_kwargs=engine_kwargs,
-                engine=engine,
-                schema=schema,
-                debug_mode=debug_mode,
-            )
-
     @db_retry()
     def sleep(
         self,
@@ -1492,6 +1462,7 @@ class SystemDatabase(ABC):
         skip_sleep: bool = False,
     ) -> float:
         function_name = "DBOS.sleep"
+        start_time = int(time.time() * 1000)
         recorded_output = self.check_operation_execution(
             workflow_uuid, function_id, function_name
         )
@@ -1502,7 +1473,7 @@ class SystemDatabase(ABC):
         if recorded_output is not None:
             dbos_logger.debug(f"Replaying sleep, id: {function_id}, seconds: {seconds}")
             assert recorded_output["output"] is not None, "no recorded end time"
-            end_time = _serialization.deserialize(recorded_output["output"])
+            end_time = self.serializer.deserialize(recorded_output["output"])
         else:
             dbos_logger.debug(f"Running sleep, id: {function_id}, seconds: {seconds}")
             end_time = time.time() + seconds
@@ -1512,7 +1483,8 @@ class SystemDatabase(ABC):
                 "workflow_uuid": workflow_uuid,
                 "function_id": function_id,
                 "function_name": function_name,
-                "output": _serialization.serialize(end_time),
+                "started_at_epoch_ms": start_time,
+                "output": self.serializer.serialize(end_time),
                 "error": None,
             }
         )
@@ -1532,6 +1504,7 @@ class SystemDatabase(ABC):
         message: Any,
     ) -> None:
         function_name = "DBOS.setEvent"
+        start_time = int(time.time() * 1000)
         with self.engine.begin() as c:
             recorded_output = self._check_operation_execution_txn(
                 workflow_uuid, function_id, function_name, conn=c
@@ -1550,17 +1523,18 @@ class SystemDatabase(ABC):
                 .values(
                     workflow_uuid=workflow_uuid,
                     key=key,
-                    value=_serialization.serialize(message),
+                    value=self.serializer.serialize(message),
                 )
                 .on_conflict_do_update(
                     index_elements=["workflow_uuid", "key"],
-                    set_={"value": _serialization.serialize(message)},
+                    set_={"value": self.serializer.serialize(message)},
                 )
             )
         output: OperationResultInternal = {
             "workflow_uuid": workflow_uuid,
             "function_id": function_id,
             "function_name": function_name,
+            "started_at_epoch_ms": start_time,
             "output": None,
             "error": None,
         }
@@ -1578,11 +1552,11 @@ class SystemDatabase(ABC):
                 .values(
                     workflow_uuid=workflow_uuid,
                     key=key,
-                    value=_serialization.serialize(message),
+                    value=self.serializer.serialize(message),
                 )
                 .on_conflict_do_update(
                     index_elements=["workflow_uuid", "key"],
-                    set_={"value": _serialization.serialize(message)},
+                    set_={"value": self.serializer.serialize(message)},
                 )
             )
 
@@ -1607,7 +1581,7 @@ class SystemDatabase(ABC):
         events: Dict[str, Any] = {}
         for row in rows:
             key = row[0]
-            value = _serialization.deserialize(row[1])
+            value = self.serializer.deserialize(row[1])
             events[key] = value
 
         return events
@@ -1621,6 +1595,7 @@ class SystemDatabase(ABC):
         caller_ctx: Optional[GetEventWorkflowContext] = None,
     ) -> Any:
         function_name = "DBOS.getEvent"
+        start_time = int(time.time() * 1000)
         get_sql = sa.select(
             SystemSchema.workflow_events.c.value,
         ).where(
@@ -1641,7 +1616,7 @@ class SystemDatabase(ABC):
                 f"Replaying get_event, id: {caller_ctx['function_id']}, key: {key}"
             )
             if recorded_output["output"] is not None:
-                return _serialization.deserialize(recorded_output["output"])
+                return self.serializer.deserialize(recorded_output["output"])
             else:
                 raise Exception("No output recorded in the last get_event")
         else:
@@ -1666,7 +1641,7 @@ class SystemDatabase(ABC):
 
         value: Any = None
         if len(init_recv) > 0:
-            value = _serialization.deserialize(init_recv[0][0])
+            value = self.serializer.deserialize(init_recv[0][0])
         else:
             # Wait for the notification
             actual_timeout = timeout_seconds
@@ -1684,7 +1659,7 @@ class SystemDatabase(ABC):
             with self.engine.begin() as c:
                 final_recv = c.execute(get_sql).fetchall()
                 if len(final_recv) > 0:
-                    value = _serialization.deserialize(final_recv[0][0])
+                    value = self.serializer.deserialize(final_recv[0][0])
             condition.release()
             self.workflow_events_map.pop(payload)
 
@@ -1695,7 +1670,8 @@ class SystemDatabase(ABC):
                 "workflow_uuid": caller_ctx["workflow_uuid"],
                 "function_id": caller_ctx["function_id"],
                 "function_name": function_name,
-                "output": _serialization.serialize(
+                "started_at_epoch_ms": start_time,
+                "output": self.serializer.serialize(
                     value
                 ),  # None will be serialized to 'null'
                 "error": None,
@@ -1703,8 +1679,41 @@ class SystemDatabase(ABC):
         )
         return value
 
+    @db_retry()
+    def get_queue_partitions(self, queue_name: str) -> List[str]:
+        """
+        Get all unique partition names associated with a queue for ENQUEUED workflows.
+
+        Args:
+            queue_name: The name of the queue to get partitions for
+
+        Returns:
+            A list of unique partition names for the queue
+        """
+        with self.engine.begin() as c:
+            query = (
+                sa.select(SystemSchema.workflow_status.c.queue_partition_key)
+                .distinct()
+                .where(SystemSchema.workflow_status.c.queue_name == queue_name)
+                .where(
+                    SystemSchema.workflow_status.c.status.in_(
+                        [
+                            WorkflowStatusString.ENQUEUED.value,
+                        ]
+                    )
+                )
+                .where(SystemSchema.workflow_status.c.queue_partition_key.isnot(None))
+            )
+
+            rows = c.execute(query).fetchall()
+            return [row[0] for row in rows]
+
     def start_queued_workflows(
-        self, queue: "Queue", executor_id: str, app_version: str
+        self,
+        queue: "Queue",
+        executor_id: str,
+        app_version: str,
+        queue_partition_key: Optional[str],
     ) -> List[str]:
         if self._debug_mode:
             return []
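
start_queued_workflows now dequeues one partition at a time, with queue_partition_key threaded through every concurrency and fetch query below. A sketch of how a runner might drive the partition-aware path; the loop itself, the executor and version strings, and the queue object are assumptions, while get_queue_partitions and start_queued_workflows come from this diff:

    # Hypothetical polling pass over a partitioned queue.
    for partition in sys_db.get_queue_partitions(queue.name):
        started = sys_db.start_queued_workflows(
            queue=queue,
            executor_id="executor-1",
            app_version="v1",
            queue_partition_key=partition,
        )
        for workflow_id in started:
            print(f"dequeued {workflow_id} from partition {partition}")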
@@ -1723,6 +1732,10 @@ class SystemDatabase(ABC):
                 sa.select(sa.func.count())
                 .select_from(SystemSchema.workflow_status)
                 .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+                .where(
+                    SystemSchema.workflow_status.c.queue_partition_key
+                    == queue_partition_key
+                )
                 .where(
                     SystemSchema.workflow_status.c.status
                     != WorkflowStatusString.ENQUEUED.value
@@ -1747,6 +1760,10 @@ class SystemDatabase(ABC):
                 )
                 .select_from(SystemSchema.workflow_status)
                 .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+                .where(
+                    SystemSchema.workflow_status.c.queue_partition_key
+                    == queue_partition_key
+                )
                 .where(
                     SystemSchema.workflow_status.c.status
                     == WorkflowStatusString.PENDING.value
@@ -1788,6 +1805,10 @@ class SystemDatabase(ABC):
                 )
                 .select_from(SystemSchema.workflow_status)
                 .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+                .where(
+                    SystemSchema.workflow_status.c.queue_partition_key
+                    == queue_partition_key
+                )
                 .where(
                     SystemSchema.workflow_status.c.status
                     == WorkflowStatusString.ENQUEUED.value
@@ -1888,6 +1909,7 @@ class SystemDatabase(ABC):
 
     def call_function_as_step(self, fn: Callable[[], T], function_name: str) -> T:
         ctx = get_local_dbos_context()
+        start_time = int(time.time() * 1000)
        if ctx and ctx.is_transaction():
             raise Exception(f"Invalid call to `{function_name}` inside a transaction")
         if ctx and ctx.is_workflow():
@@ -1897,12 +1919,13 @@ class SystemDatabase(ABC):
             )
             if res is not None:
                 if res["output"] is not None:
-                    resstat: SystemDatabase.T = _serialization.deserialize(
+                    resstat: SystemDatabase.T = self.serializer.deserialize(
                         res["output"]
                     )
                     return resstat
                 elif res["error"] is not None:
-                    raise _serialization.deserialize_exception(res["error"])
+                    e: Exception = self.serializer.deserialize(res["error"])
+                    raise e
                 else:
                     raise Exception(
                         f"Recorded output and error are both None for {function_name}"
@@ -1914,7 +1937,8 @@ class SystemDatabase(ABC):
                 "workflow_uuid": ctx.workflow_id,
                 "function_id": ctx.function_id,
                 "function_name": function_name,
-                "output": _serialization.serialize(result),
+                "started_at_epoch_ms": start_time,
+                "output": self.serializer.serialize(result),
                 "error": None,
             }
         )
@@ -1968,7 +1992,7 @@ class SystemDatabase(ABC):
             )
 
             # Serialize the value before storing
-            serialized_value = _serialization.serialize(value)
+            serialized_value = self.serializer.serialize(value)
 
             # Insert the new stream entry
             c.execute(
@@ -1992,6 +2016,7 @@ class SystemDatabase(ABC):
             if value == _dbos_stream_closed_sentinel
             else "DBOS.writeStream"
         )
+        start_time = int(time.time() * 1000)
 
         with self.engine.begin() as c:
 
@@ -2023,7 +2048,7 @@ class SystemDatabase(ABC):
             )
 
             # Serialize the value before storing
-            serialized_value = _serialization.serialize(value)
+            serialized_value = self.serializer.serialize(value)
 
             # Insert the new stream entry
             c.execute(
@@ -2038,6 +2063,7 @@ class SystemDatabase(ABC):
                     "workflow_uuid": workflow_uuid,
                     "function_id": function_id,
                     "function_name": function_name,
+                    "started_at_epoch_ms": start_time,
                     "output": None,
                     "error": None,
                 }
@@ -2068,7 +2094,7 @@ class SystemDatabase(ABC):
             )
 
             # Deserialize the value before returning
-            return _serialization.deserialize(result[0])
+            return self.serializer.deserialize(result[0])
 
     def garbage_collect(
         self, cutoff_epoch_timestamp_ms: Optional[int], rows_threshold: Optional[int]