dbos 2.1.0a3__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dbos/_queue.py CHANGED
@@ -43,6 +43,7 @@ class Queue:
         *,  # Disable positional arguments from here on
         worker_concurrency: Optional[int] = None,
         priority_enabled: bool = False,
+        partition_queue: bool = False,
     ) -> None:
         if (
             worker_concurrency is not None
@@ -57,6 +58,7 @@ class Queue:
         self.worker_concurrency = worker_concurrency
         self.limiter = limiter
         self.priority_enabled = priority_enabled
+        self.partition_queue = partition_queue
         from ._dbos import _get_or_create_dbos_registry

         registry = _get_or_create_dbos_registry()
@@ -78,6 +80,18 @@ class Queue:
             raise Exception(
                 f"Priority is not enabled for queue {self.name}. Setting priority will not have any effect."
             )
+        if self.partition_queue and (
+            context is None or context.queue_partition_key is None
+        ):
+            raise Exception(
+                f"A workflow cannot be enqueued on partitioned queue {self.name} without a partition key"
+            )
+        if context and context.queue_partition_key and not self.partition_queue:
+            raise Exception(
+                f"You can only use a partition key on a partition-enabled queue. Key {context.queue_partition_key} was used with non-partitioned queue {self.name}"
+            )
+        if context and context.queue_partition_key and context.deduplication_id:
+            raise Exception("Deduplication is not supported for partitioned queues")

         dbos = _get_dbos_instance()
         return start_workflow(dbos, func, self.name, False, *args, **kwargs)
@@ -105,10 +119,21 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
         queues = dict(dbos._registry.queue_info_map)
         for _, queue in queues.items():
             try:
-                wf_ids = dbos._sys_db.start_queued_workflows(
-                    queue, GlobalParams.executor_id, GlobalParams.app_version
-                )
-                for id in wf_ids:
+                if queue.partition_queue:
+                    dequeued_workflows = []
+                    queue_partition_keys = dbos._sys_db.get_queue_partitions(queue.name)
+                    for key in queue_partition_keys:
+                        dequeued_workflows += dbos._sys_db.start_queued_workflows(
+                            queue,
+                            GlobalParams.executor_id,
+                            GlobalParams.app_version,
+                            key,
+                        )
+                else:
+                    dequeued_workflows = dbos._sys_db.start_queued_workflows(
+                        queue, GlobalParams.executor_id, GlobalParams.app_version, None
+                    )
+                for id in dequeued_workflows:
                     execute_workflow_by_id(dbos, id)
             except OperationalError as e:
                 if isinstance(
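The net effect of the `_queue.py` changes: a queue constructed with `partition_queue=True` requires a partition key at enqueue time, rejects partition keys on non-partitioned queues, forbids combining a key with `deduplication_id`, and is polled once per distinct partition key instead of once per queue. A minimal usage sketch, assuming the key is supplied through the enqueue-time context that the validation above inspects (`SetEnqueueOptions(queue_partition_key=...)` is an assumed spelling, not confirmed by this diff):

from dbos import DBOS, Queue, SetEnqueueOptions

# Partitioned queue: flow control is applied per partition key, since the
# queue thread now calls start_queued_workflows once per distinct key.
user_queue = Queue("per_user_queue", concurrency=1, partition_queue=True)

@DBOS.workflow()
def process_order(order_id: str) -> None:
    ...

# Hypothetical option name; the diff only shows context.queue_partition_key.
with SetEnqueueOptions(queue_partition_key="user-42"):
    handle = user_queue.enqueue(process_order, "order-123")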
dbos/_schemas/system_database.py CHANGED
@@ -77,6 +77,7 @@ class SystemSchema:
         Column("deduplication_id", Text(), nullable=True),
         Column("inputs", Text()),
         Column("priority", Integer(), nullable=False, server_default=text("'0'::int")),
+        Column("queue_partition_key", Text()),
         Index("workflow_status_created_at_index", "created_at"),
         Index("workflow_status_executor_id_index", "executor_id"),
         Index("workflow_status_status_index", "status"),
dbos/_serialization.py CHANGED
@@ -1,6 +1,6 @@
 import base64
 import pickle
-import types
+from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional, Tuple, TypedDict

 from ._logger import dbos_logger
@@ -11,47 +11,31 @@ class WorkflowInputs(TypedDict):
     kwargs: Dict[str, Any]


-def serialize(data: Any) -> str:
-    pickled_data: bytes = pickle.dumps(data)
-    encoded_data: str = base64.b64encode(pickled_data).decode("utf-8")
-    return encoded_data
+class Serializer(ABC):

+    @abstractmethod
+    def serialize(self, data: Any) -> str:
+        pass

-def serialize_args(data: WorkflowInputs) -> str:
-    """Serialize args to a base64-encoded string using pickle."""
-    pickled_data: bytes = pickle.dumps(data)
-    encoded_data: str = base64.b64encode(pickled_data).decode("utf-8")
-    return encoded_data
+    @abstractmethod
+    def deserialize(cls, serialized_data: str) -> Any:
+        pass


-def serialize_exception(data: Exception) -> str:
-    """Serialize an Exception object to a base64-encoded string using pickle."""
-    pickled_data: bytes = pickle.dumps(data)
-    encoded_data: str = base64.b64encode(pickled_data).decode("utf-8")
-    return encoded_data
+class DefaultSerializer(Serializer):

+    def serialize(self, data: Any) -> str:
+        pickled_data: bytes = pickle.dumps(data)
+        encoded_data: str = base64.b64encode(pickled_data).decode("utf-8")
+        return encoded_data

-def deserialize(serialized_data: str) -> Any:
-    """Deserialize a base64-encoded string back to a Python object using pickle."""
-    pickled_data: bytes = base64.b64decode(serialized_data)
-    return pickle.loads(pickled_data)
-
-
-def deserialize_args(serialized_data: str) -> WorkflowInputs:
-    """Deserialize a base64-encoded string back to a Python object list using pickle."""
-    pickled_data: bytes = base64.b64decode(serialized_data)
-    args: WorkflowInputs = pickle.loads(pickled_data)
-    return args
-
-
-def deserialize_exception(serialized_data: str) -> Exception:
-    """Deserialize a base64-encoded string back to a Python Exception using pickle."""
-    pickled_data: bytes = base64.b64decode(serialized_data)
-    exc: Exception = pickle.loads(pickled_data)
-    return exc
+    def deserialize(cls, serialized_data: str) -> Any:
+        pickled_data: bytes = base64.b64decode(serialized_data)
+        return pickle.loads(pickled_data)


 def safe_deserialize(
+    serializer: Serializer,
     workflow_id: str,
     *,
     serialized_input: Optional[str],
@@ -68,7 +52,9 @@ def safe_deserialize(
     input: Optional[WorkflowInputs]
     try:
         input = (
-            deserialize_args(serialized_input) if serialized_input is not None else None
+            serializer.deserialize(serialized_input)
+            if serialized_input is not None
+            else None
         )
     except Exception as e:
         dbos_logger.warning(
@@ -78,7 +64,9 @@ def safe_deserialize(
     output: Optional[Any]
     try:
         output = (
-            deserialize(serialized_output) if serialized_output is not None else None
+            serializer.deserialize(serialized_output)
+            if serialized_output is not None
+            else None
         )
     except Exception as e:
         dbos_logger.warning(
@@ -88,7 +76,7 @@ def safe_deserialize(
     exception: Optional[Exception]
     try:
         exception = (
-            deserialize_exception(serialized_exception)
+            serializer.deserialize(serialized_exception)
             if serialized_exception is not None
             else None
         )
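The rewrite collapses the six module-level pickle helpers into a `Serializer` ABC plus a pickle-based `DefaultSerializer`, and threads the chosen serializer through `safe_deserialize` and (below) the system database, making the persistence format pluggable. A sketch of an alternative implementation, assuming the private `dbos._serialization` import path shown in this diff; note that JSON cannot round-trip arbitrary objects or exceptions the way the default can:

import json
from typing import Any

from dbos._serialization import Serializer

class JSONSerializer(Serializer):
    # Illustrative only: JSON instead of base64-encoded pickle. Values must be
    # JSON-representable, and the `raise e` paths in _sys_db.py additionally
    # expect deserialized errors to be Exception instances.
    def serialize(self, data: Any) -> str:
        return json.dumps(data)

    def deserialize(self, serialized_data: str) -> Any:
        return json.loads(serialized_data)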
dbos/_sys_db.py CHANGED
@@ -30,7 +30,6 @@ from dbos._utils import (
     retriable_sqlite_exception,
 )

-from . import _serialization
 from ._context import get_local_dbos_context
 from ._error import (
     DBOSAwaitedWorkflowCancelledError,
@@ -44,6 +43,7 @@ from ._error import (
 )
 from ._logger import dbos_logger
 from ._schemas.system_database import SystemSchema
+from ._serialization import Serializer, WorkflowInputs, safe_deserialize

 if TYPE_CHECKING:
     from ._queue import Queue
@@ -95,7 +95,7 @@ class WorkflowStatus:
     # All roles which the authenticated user could assume
     authenticated_roles: Optional[list[str]]
     # The deserialized workflow input object
-    input: Optional[_serialization.WorkflowInputs]
+    input: Optional[WorkflowInputs]
     # The workflow's output, if any
     output: Optional[Any] = None
     # The error the workflow threw, if any
@@ -152,6 +152,8 @@ class WorkflowStatusInternal(TypedDict):
     priority: int
     # Serialized workflow inputs
     inputs: str
+    # If this workflow is enqueued on a partitioned queue, its partition key
+    queue_partition_key: Optional[str]


 class EnqueueOptionsInternal(TypedDict):
@@ -161,6 +163,8 @@ class EnqueueOptionsInternal(TypedDict):
     priority: Optional[int]
     # On what version the workflow is enqueued. Current version if not specified.
     app_version: Optional[str]
+    # If the workflow is enqueued on a partitioned queue, its partition key
+    queue_partition_key: Optional[str]


 class RecordedResult(TypedDict):
@@ -341,6 +345,39 @@ def db_retry(

 class SystemDatabase(ABC):

+    @staticmethod
+    def create(
+        system_database_url: str,
+        engine_kwargs: Dict[str, Any],
+        engine: Optional[sa.Engine],
+        schema: Optional[str],
+        serializer: Serializer,
+        debug_mode: bool = False,
+    ) -> "SystemDatabase":
+        """Factory method to create the appropriate SystemDatabase implementation based on URL."""
+        if system_database_url.startswith("sqlite"):
+            from ._sys_db_sqlite import SQLiteSystemDatabase
+
+            return SQLiteSystemDatabase(
+                system_database_url=system_database_url,
+                engine_kwargs=engine_kwargs,
+                engine=engine,
+                schema=schema,
+                serializer=serializer,
+                debug_mode=debug_mode,
+            )
+        else:
+            from ._sys_db_postgres import PostgresSystemDatabase
+
+            return PostgresSystemDatabase(
+                system_database_url=system_database_url,
+                engine_kwargs=engine_kwargs,
+                engine=engine,
+                schema=schema,
+                serializer=serializer,
+                debug_mode=debug_mode,
+            )
+
     def __init__(
         self,
         *,
@@ -348,6 +385,7 @@ class SystemDatabase(ABC):
         engine_kwargs: Dict[str, Any],
         engine: Optional[sa.Engine],
         schema: Optional[str],
+        serializer: Serializer,
         debug_mode: bool = False,
     ):
         import sqlalchemy.dialects.postgresql as pg
@@ -355,6 +393,8 @@

         self.dialect = sq if system_database_url.startswith("sqlite") else pg

+        self.serializer = serializer
+
         if system_database_url.startswith("sqlite"):
             self.schema = None
         else:
@@ -454,6 +494,7 @@ class SystemDatabase(ABC):
                 deduplication_id=status["deduplication_id"],
                 priority=status["priority"],
                 inputs=status["inputs"],
+                queue_partition_key=status["queue_partition_key"],
             )
             .on_conflict_do_update(
                 index_elements=["workflow_uuid"],
@@ -725,6 +766,7 @@ class SystemDatabase(ABC):
                 SystemSchema.workflow_status.c.deduplication_id,
                 SystemSchema.workflow_status.c.priority,
                 SystemSchema.workflow_status.c.inputs,
+                SystemSchema.workflow_status.c.queue_partition_key,
             ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid)
         ).fetchone()
         if row is None:
@@ -752,6 +794,7 @@ class SystemDatabase(ABC):
             "deduplication_id": row[16],
             "priority": row[17],
             "inputs": row[18],
+            "queue_partition_key": row[19],
         }
         return status

@@ -797,10 +840,11 @@ class SystemDatabase(ABC):
             status = row[0]
             if status == WorkflowStatusString.SUCCESS.value:
                 output = row[1]
-                return _serialization.deserialize(output)
+                return self.serializer.deserialize(output)
             elif status == WorkflowStatusString.ERROR.value:
                 error = row[2]
-                raise _serialization.deserialize_exception(error)
+                e: Exception = self.serializer.deserialize(error)
+                raise e
             elif status == WorkflowStatusString.CANCELLED.value:
                 # Raise AwaitedWorkflowCancelledError here, not the cancellation exception
                 # because the awaiting workflow is not being cancelled.
@@ -917,7 +961,8 @@ class SystemDatabase(ABC):
             raw_input = row[17] if load_input else None
             raw_output = row[18] if load_output else None
             raw_error = row[19] if load_output else None
-            inputs, output, exception = _serialization.safe_deserialize(
+            inputs, output, exception = safe_deserialize(
+                self.serializer,
                 info.workflow_id,
                 serialized_input=raw_input,
                 serialized_output=raw_output,
@@ -1028,7 +1073,8 @@ class SystemDatabase(ABC):
             raw_input = row[17] if load_input else None

             # Error and Output are not loaded because they should always be None for queued workflows.
-            inputs, output, exception = _serialization.safe_deserialize(
+            inputs, output, exception = safe_deserialize(
+                self.serializer,
                 info.workflow_id,
                 serialized_input=raw_input,
                 serialized_output=None,
@@ -1079,7 +1125,8 @@ class SystemDatabase(ABC):
         ).fetchall()
         steps = []
         for row in rows:
-            _, output, exception = _serialization.safe_deserialize(
+            _, output, exception = safe_deserialize(
+                self.serializer,
                 workflow_id,
                 serialized_input=None,
                 serialized_output=row[2],
@@ -1278,7 +1325,8 @@ class SystemDatabase(ABC):
         if row is None:
             return None
         elif row[1]:
-            raise _serialization.deserialize_exception(row[1])
+            e: Exception = self.serializer.deserialize(row[1])
+            raise e
         else:
             return str(row[0])

@@ -1317,7 +1365,7 @@ class SystemDatabase(ABC):
                 sa.insert(SystemSchema.notifications).values(
                     destination_uuid=destination_uuid,
                     topic=topic,
-                    message=_serialization.serialize(message),
+                    message=self.serializer.serialize(message),
                 )
             )
         except DBAPIError as dbapi_error:
@@ -1354,7 +1402,7 @@ class SystemDatabase(ABC):
         if recorded_output is not None:
             dbos_logger.debug(f"Replaying recv, id: {function_id}, topic: {topic}")
             if recorded_output["output"] is not None:
-                return _serialization.deserialize(recorded_output["output"])
+                return self.serializer.deserialize(recorded_output["output"])
             else:
                 raise Exception("No output recorded in the last recv")
         else:
@@ -1421,13 +1469,13 @@ class SystemDatabase(ABC):
             rows = c.execute(delete_stmt).fetchall()
             message: Any = None
             if len(rows) > 0:
-                message = _serialization.deserialize(rows[0][0])
+                message = self.serializer.deserialize(rows[0][0])
             self._record_operation_result_txn(
                 {
                     "workflow_uuid": workflow_uuid,
                     "function_id": function_id,
                     "function_name": function_name,
-                    "output": _serialization.serialize(
+                    "output": self.serializer.serialize(
                         message
                     ),  # None will be serialized to 'null'
                     "error": None,
@@ -1453,36 +1501,6 @@ class SystemDatabase(ABC):

         PostgresSystemDatabase._reset_system_database(database_url)

-    @staticmethod
-    def create(
-        system_database_url: str,
-        engine_kwargs: Dict[str, Any],
-        engine: Optional[sa.Engine],
-        schema: Optional[str],
-        debug_mode: bool = False,
-    ) -> "SystemDatabase":
-        """Factory method to create the appropriate SystemDatabase implementation based on URL."""
-        if system_database_url.startswith("sqlite"):
-            from ._sys_db_sqlite import SQLiteSystemDatabase
-
-            return SQLiteSystemDatabase(
-                system_database_url=system_database_url,
-                engine_kwargs=engine_kwargs,
-                engine=engine,
-                schema=schema,
-                debug_mode=debug_mode,
-            )
-        else:
-            from ._sys_db_postgres import PostgresSystemDatabase
-
-            return PostgresSystemDatabase(
-                system_database_url=system_database_url,
-                engine_kwargs=engine_kwargs,
-                engine=engine,
-                schema=schema,
-                debug_mode=debug_mode,
-            )
-
     @db_retry()
     def sleep(
         self,
@@ -1502,7 +1520,7 @@ class SystemDatabase(ABC):
         if recorded_output is not None:
             dbos_logger.debug(f"Replaying sleep, id: {function_id}, seconds: {seconds}")
             assert recorded_output["output"] is not None, "no recorded end time"
-            end_time = _serialization.deserialize(recorded_output["output"])
+            end_time = self.serializer.deserialize(recorded_output["output"])
         else:
             dbos_logger.debug(f"Running sleep, id: {function_id}, seconds: {seconds}")
             end_time = time.time() + seconds
@@ -1512,7 +1530,7 @@ class SystemDatabase(ABC):
                 "workflow_uuid": workflow_uuid,
                 "function_id": function_id,
                 "function_name": function_name,
-                "output": _serialization.serialize(end_time),
+                "output": self.serializer.serialize(end_time),
                 "error": None,
             }
         )
@@ -1550,11 +1568,11 @@ class SystemDatabase(ABC):
                 .values(
                     workflow_uuid=workflow_uuid,
                     key=key,
-                    value=_serialization.serialize(message),
+                    value=self.serializer.serialize(message),
                 )
                 .on_conflict_do_update(
                     index_elements=["workflow_uuid", "key"],
-                    set_={"value": _serialization.serialize(message)},
+                    set_={"value": self.serializer.serialize(message)},
                 )
             )
             output: OperationResultInternal = {
@@ -1578,11 +1596,11 @@ class SystemDatabase(ABC):
                 .values(
                     workflow_uuid=workflow_uuid,
                     key=key,
-                    value=_serialization.serialize(message),
+                    value=self.serializer.serialize(message),
                 )
                 .on_conflict_do_update(
                     index_elements=["workflow_uuid", "key"],
-                    set_={"value": _serialization.serialize(message)},
+                    set_={"value": self.serializer.serialize(message)},
                 )
             )

@@ -1607,7 +1625,7 @@ class SystemDatabase(ABC):
         events: Dict[str, Any] = {}
         for row in rows:
             key = row[0]
-            value = _serialization.deserialize(row[1])
+            value = self.serializer.deserialize(row[1])
             events[key] = value

         return events
@@ -1641,7 +1659,7 @@ class SystemDatabase(ABC):
                     f"Replaying get_event, id: {caller_ctx['function_id']}, key: {key}"
                 )
                 if recorded_output["output"] is not None:
-                    return _serialization.deserialize(recorded_output["output"])
+                    return self.serializer.deserialize(recorded_output["output"])
                 else:
                     raise Exception("No output recorded in the last get_event")
             else:
@@ -1666,7 +1684,7 @@ class SystemDatabase(ABC):

         value: Any = None
         if len(init_recv) > 0:
-            value = _serialization.deserialize(init_recv[0][0])
+            value = self.serializer.deserialize(init_recv[0][0])
         else:
             # Wait for the notification
             actual_timeout = timeout_seconds
@@ -1684,7 +1702,7 @@ class SystemDatabase(ABC):
             with self.engine.begin() as c:
                 final_recv = c.execute(get_sql).fetchall()
                 if len(final_recv) > 0:
-                    value = _serialization.deserialize(final_recv[0][0])
+                    value = self.serializer.deserialize(final_recv[0][0])
             condition.release()
             self.workflow_events_map.pop(payload)

@@ -1695,7 +1713,7 @@ class SystemDatabase(ABC):
                 "workflow_uuid": caller_ctx["workflow_uuid"],
                 "function_id": caller_ctx["function_id"],
                 "function_name": function_name,
-                "output": _serialization.serialize(
+                "output": self.serializer.serialize(
                     value
                 ),  # None will be serialized to 'null'
                 "error": None,
@@ -1703,8 +1721,41 @@
         )
         return value

+    @db_retry()
+    def get_queue_partitions(self, queue_name: str) -> List[str]:
+        """
+        Get all unique partition names associated with a queue for ENQUEUED workflows.
+
+        Args:
+            queue_name: The name of the queue to get partitions for
+
+        Returns:
+            A list of unique partition names for the queue
+        """
+        with self.engine.begin() as c:
+            query = (
+                sa.select(SystemSchema.workflow_status.c.queue_partition_key)
+                .distinct()
+                .where(SystemSchema.workflow_status.c.queue_name == queue_name)
+                .where(
+                    SystemSchema.workflow_status.c.status.in_(
+                        [
+                            WorkflowStatusString.ENQUEUED.value,
+                        ]
+                    )
+                )
+                .where(SystemSchema.workflow_status.c.queue_partition_key.isnot(None))
+            )
+
+            rows = c.execute(query).fetchall()
+            return [row[0] for row in rows]
+
     def start_queued_workflows(
-        self, queue: "Queue", executor_id: str, app_version: str
+        self,
+        queue: "Queue",
+        executor_id: str,
+        app_version: str,
+        queue_partition_key: Optional[str],
     ) -> List[str]:
         if self._debug_mode:
             return []
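Note that every counting and candidate-selection query below now also filters on `queue_partition_key`, so a queue's `concurrency` and `worker_concurrency` budgets appear to be evaluated independently for each partition. A short sketch of the semantics this implies (inferred from the filtered queries; not stated explicitly elsewhere in the diff):

from dbos import Queue

# Assumed behavior: with ten workflows enqueued under "user-a" and ten under
# "user-b", at most one from each partition runs at a time, because the
# scheduler invokes start_queued_workflows separately per partition key.
checkout = Queue("checkout", concurrency=1, partition_queue=True)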
@@ -1723,6 +1774,10 @@
             sa.select(sa.func.count())
             .select_from(SystemSchema.workflow_status)
             .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+            .where(
+                SystemSchema.workflow_status.c.queue_partition_key
+                == queue_partition_key
+            )
             .where(
                 SystemSchema.workflow_status.c.status
                 != WorkflowStatusString.ENQUEUED.value
@@ -1747,6 +1802,10 @@
             )
             .select_from(SystemSchema.workflow_status)
             .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+            .where(
+                SystemSchema.workflow_status.c.queue_partition_key
+                == queue_partition_key
+            )
             .where(
                 SystemSchema.workflow_status.c.status
                 == WorkflowStatusString.PENDING.value
@@ -1788,6 +1847,10 @@
             )
             .select_from(SystemSchema.workflow_status)
             .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+            .where(
+                SystemSchema.workflow_status.c.queue_partition_key
+                == queue_partition_key
+            )
             .where(
                 SystemSchema.workflow_status.c.status
                 == WorkflowStatusString.ENQUEUED.value
@@ -1897,12 +1960,13 @@
             )
             if res is not None:
                 if res["output"] is not None:
-                    resstat: SystemDatabase.T = _serialization.deserialize(
+                    resstat: SystemDatabase.T = self.serializer.deserialize(
                         res["output"]
                     )
                     return resstat
                 elif res["error"] is not None:
-                    raise _serialization.deserialize_exception(res["error"])
+                    e: Exception = self.serializer.deserialize(res["error"])
+                    raise e
                 else:
                     raise Exception(
                         f"Recorded output and error are both None for {function_name}"
@@ -1914,7 +1978,7 @@
                 "workflow_uuid": ctx.workflow_id,
                 "function_id": ctx.function_id,
                 "function_name": function_name,
-                "output": _serialization.serialize(result),
+                "output": self.serializer.serialize(result),
                 "error": None,
             }
         )
@@ -1968,7 +2032,7 @@
             )

             # Serialize the value before storing
-            serialized_value = _serialization.serialize(value)
+            serialized_value = self.serializer.serialize(value)

             # Insert the new stream entry
             c.execute(
@@ -2023,7 +2087,7 @@
             )

             # Serialize the value before storing
-            serialized_value = _serialization.serialize(value)
+            serialized_value = self.serializer.serialize(value)

             # Insert the new stream entry
             c.execute(
@@ -2068,7 +2132,7 @@
             )

             # Deserialize the value before returning
-            return _serialization.deserialize(result[0])
+            return self.serializer.deserialize(result[0])

     def garbage_collect(
         self, cutoff_epoch_timestamp_ms: Optional[int], rows_threshold: Optional[int]
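With the factory hoisted above `__init__` and the serializer made a constructor dependency, wiring a system database by hand looks like the following sketch, based on the `create` signature above and the usage in `dbos/cli/migration.py` below (the runtime normally does this internally; the URL is a placeholder):

from dbos._serialization import DefaultSerializer
from dbos._sys_db import SystemDatabase

sys_db = SystemDatabase.create(
    system_database_url="postgresql://localhost:5432/dbos_sys",  # placeholder
    engine_kwargs={"pool_size": 2},
    engine=None,
    schema=None,  # or a configured Postgres schema name
    serializer=DefaultSerializer(),
)
sys_db.run_migrations()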
dbos/cli/migration.py CHANGED
@@ -4,6 +4,7 @@ import sqlalchemy as sa
 import typer

 from dbos._app_db import ApplicationDatabase
+from dbos._serialization import DefaultSerializer
 from dbos._sys_db import SystemDatabase


@@ -22,6 +23,7 @@ def migrate_dbos_databases(
            },
            engine=None,
            schema=schema,
+            serializer=DefaultSerializer(),
        )
        sys_db.run_migrations()
        if app_database_url:
@@ -33,6 +35,7 @@
                "pool_size": 2,
            },
            schema=schema,
+            serializer=DefaultSerializer(),
        )
        app_db.run_migrations()
    except Exception as e:
{dbos-2.1.0a3.dist-info → dbos-2.2.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dbos
-Version: 2.1.0a3
+Version: 2.2.0
 Summary: Ultra-lightweight durable execution in Python
 Author-Email: "DBOS, Inc." <contact@dbos.dev>
 License: MIT