dbos 2.2.0a2__tar.gz → 2.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. {dbos-2.2.0a2 → dbos-2.3.0}/PKG-INFO +1 -1
  2. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_client.py +16 -2
  3. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_context.py +8 -0
  4. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_core.py +11 -21
  5. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_dbos_config.py +1 -2
  6. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_kafka.py +6 -4
  7. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_logger.py +23 -16
  8. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_migration.py +12 -2
  9. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_queue.py +29 -4
  10. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_scheduler.py +5 -2
  11. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_schemas/system_database.py +1 -0
  12. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_serialization.py +7 -3
  13. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_sys_db.py +53 -1
  14. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_sys_db_postgres.py +1 -1
  15. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_tracer.py +24 -19
  16. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/cli/cli.py +1 -15
  17. {dbos-2.2.0a2 → dbos-2.3.0}/pyproject.toml +1 -1
  18. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_client.py +32 -0
  19. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_config.py +29 -35
  20. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_dbos.py +60 -0
  21. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_failures.py +14 -1
  22. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_kafka.py +50 -17
  23. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_queue.py +78 -1
  24. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_scheduler.py +13 -0
  25. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_spans.py +1 -5
  26. {dbos-2.2.0a2 → dbos-2.3.0}/LICENSE +0 -0
  27. {dbos-2.2.0a2 → dbos-2.3.0}/README.md +0 -0
  28. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/__init__.py +0 -0
  29. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/__main__.py +0 -0
  30. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_admin_server.py +0 -0
  31. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_app_db.py +0 -0
  32. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_classproperty.py +0 -0
  33. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_conductor/conductor.py +0 -0
  34. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_conductor/protocol.py +0 -0
  35. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_croniter.py +0 -0
  36. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_dbos.py +0 -0
  37. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_debouncer.py +0 -0
  38. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_debug.py +0 -0
  39. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_docker_pg_helper.py +0 -0
  40. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_error.py +0 -0
  41. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_event_loop.py +0 -0
  42. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_fastapi.py +0 -0
  43. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_flask.py +0 -0
  44. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_kafka_message.py +0 -0
  45. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_outcome.py +0 -0
  46. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_recovery.py +0 -0
  47. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_registrations.py +0 -0
  48. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_roles.py +0 -0
  49. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_schemas/__init__.py +0 -0
  50. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_schemas/application_database.py +0 -0
  51. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_sys_db_sqlite.py +0 -0
  52. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_templates/dbos-db-starter/README.md +0 -0
  53. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_templates/dbos-db-starter/__package/__init__.py +0 -0
  54. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_templates/dbos-db-starter/__package/main.py.dbos +0 -0
  55. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_templates/dbos-db-starter/__package/schema.py +0 -0
  56. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_templates/dbos-db-starter/dbos-config.yaml.dbos +0 -0
  57. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_templates/dbos-db-starter/migrations/create_table.py.dbos +0 -0
  58. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_templates/dbos-db-starter/start_postgres_docker.py +0 -0
  59. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_utils.py +0 -0
  60. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/_workflow_commands.py +0 -0
  61. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/cli/_github_init.py +0 -0
  62. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/cli/_template_init.py +0 -0
  63. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/cli/migration.py +0 -0
  64. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/dbos-config.schema.json +0 -0
  65. {dbos-2.2.0a2 → dbos-2.3.0}/dbos/py.typed +0 -0
  66. {dbos-2.2.0a2 → dbos-2.3.0}/tests/__init__.py +0 -0
  67. {dbos-2.2.0a2 → dbos-2.3.0}/tests/atexit_no_ctor.py +0 -0
  68. {dbos-2.2.0a2 → dbos-2.3.0}/tests/atexit_no_launch.py +0 -0
  69. {dbos-2.2.0a2 → dbos-2.3.0}/tests/classdefs.py +0 -0
  70. {dbos-2.2.0a2 → dbos-2.3.0}/tests/client_collateral.py +0 -0
  71. {dbos-2.2.0a2 → dbos-2.3.0}/tests/client_worker.py +0 -0
  72. {dbos-2.2.0a2 → dbos-2.3.0}/tests/conftest.py +0 -0
  73. {dbos-2.2.0a2 → dbos-2.3.0}/tests/dupname_classdefs1.py +0 -0
  74. {dbos-2.2.0a2 → dbos-2.3.0}/tests/dupname_classdefsa.py +0 -0
  75. {dbos-2.2.0a2 → dbos-2.3.0}/tests/more_classdefs.py +0 -0
  76. {dbos-2.2.0a2 → dbos-2.3.0}/tests/queuedworkflow.py +0 -0
  77. {dbos-2.2.0a2 → dbos-2.3.0}/tests/script_without_fastapi.py +0 -0
  78. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_admin_server.py +0 -0
  79. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_async.py +0 -0
  80. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_async_workflow_management.py +0 -0
  81. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_classdecorators.py +0 -0
  82. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_cli.py +0 -0
  83. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_concurrency.py +0 -0
  84. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_croniter.py +0 -0
  85. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_debouncer.py +0 -0
  86. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_debug.py +0 -0
  87. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_docker_secrets.py +0 -0
  88. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_fastapi.py +0 -0
  89. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_fastapi_roles.py +0 -0
  90. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_flask.py +0 -0
  91. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_outcome.py +0 -0
  92. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_package.py +0 -0
  93. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_schema_migration.py +0 -0
  94. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_singleton.py +0 -0
  95. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_sqlalchemy.py +0 -0
  96. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_streaming.py +0 -0
  97. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_workflow_introspection.py +0 -0
  98. {dbos-2.2.0a2 → dbos-2.3.0}/tests/test_workflow_management.py +0 -0
  99. {dbos-2.2.0a2 → dbos-2.3.0}/version/__init__.py +0 -0
{dbos-2.2.0a2 → dbos-2.3.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dbos
-Version: 2.2.0a2
+Version: 2.3.0
 Summary: Ultra-lightweight durable execution in Python
 Author-Email: "DBOS, Inc." <contact@dbos.dev>
 License: MIT
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_client.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import time
 import uuid
 from typing import (
@@ -62,6 +63,9 @@ class EnqueueOptions(_EnqueueOptionsRequired, total=False):
     deduplication_id: str
     priority: int
     max_recovery_attempts: int
+    queue_partition_key: str
+    authenticated_user: str
+    authenticated_roles: list[str]
 
 
 def validate_enqueue_options(options: EnqueueOptions) -> None:
@@ -185,8 +189,16 @@ class DBOSClient:
             "deduplication_id": options.get("deduplication_id"),
             "priority": options.get("priority"),
             "app_version": options.get("app_version"),
+            "queue_partition_key": options.get("queue_partition_key"),
         }
 
+        authenticated_user = options.get("authenticated_user")
+        authenticated_roles = (
+            json.dumps(options.get("authenticated_roles"))
+            if options.get("authenticated_roles")
+            else None
+        )
+
         inputs: WorkflowInputs = {
             "args": args,
             "kwargs": kwargs,
@@ -200,9 +212,9 @@ class DBOSClient:
             "queue_name": queue_name,
             "app_version": enqueue_options_internal["app_version"],
             "config_name": None,
-            "authenticated_user": None,
+            "authenticated_user": authenticated_user,
             "assumed_role": None,
-            "authenticated_roles": None,
+            "authenticated_roles": authenticated_roles,
             "output": None,
             "error": None,
             "created_at": None,
@@ -221,6 +233,7 @@ class DBOSClient:
                 else 0
             ),
             "inputs": self._serializer.serialize(inputs),
+            "queue_partition_key": enqueue_options_internal["queue_partition_key"],
         }
 
         self._sys_db.init_workflow(
@@ -286,6 +299,7 @@ class DBOSClient:
             "deduplication_id": None,
             "priority": 0,
             "inputs": self._serializer.serialize({"args": (), "kwargs": {}}),
+            "queue_partition_key": None,
         }
         with self._sys_db.engine.begin() as conn:
             self._sys_db._insert_workflow_status(
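The client changes above let an external process set the new queue_partition_key option and attach an authenticated user and roles when enqueueing. A minimal sketch, assuming the DBOSClient.enqueue API and top-level EnqueueOptions export from prior releases; the queue name, workflow name, and argument are placeholders:

    from dbos import DBOSClient, EnqueueOptions

    def enqueue_with_identity(client: DBOSClient) -> None:
        options: EnqueueOptions = {
            "queue_name": "my_app_queue",       # placeholder queue
            "workflow_name": "process_order",   # placeholder workflow
            # New in 2.3.0: recorded on the workflow status row
            "authenticated_user": "alice",
            "authenticated_roles": ["admin"],
            # New in 2.3.0: required if the target queue is partition-enabled
            "queue_partition_key": "customer-42",
        }
        handle = client.enqueue(options, 42)
        print(handle.get_result())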
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_context.py
@@ -120,6 +120,8 @@ class DBOSContext:
         self.deduplication_id: Optional[str] = None
         # A user-specified priority for the enqueuing workflow.
         self.priority: Optional[int] = None
+        # If the workflow is enqueued on a partitioned queue, its partition key
+        self.queue_partition_key: Optional[str] = None
 
     def create_child(self) -> DBOSContext:
         rv = DBOSContext()
@@ -479,6 +481,7 @@ class SetEnqueueOptions:
         deduplication_id: Optional[str] = None,
         priority: Optional[int] = None,
         app_version: Optional[str] = None,
+        queue_partition_key: Optional[str] = None,
     ) -> None:
         self.created_ctx = False
         self.deduplication_id: Optional[str] = deduplication_id
@@ -491,6 +494,8 @@ class SetEnqueueOptions:
         self.saved_priority: Optional[int] = None
         self.app_version: Optional[str] = app_version
         self.saved_app_version: Optional[str] = None
+        self.queue_partition_key = queue_partition_key
+        self.saved_queue_partition_key: Optional[str] = None
 
     def __enter__(self) -> SetEnqueueOptions:
         # Code to create a basic context
@@ -505,6 +510,8 @@ class SetEnqueueOptions:
         ctx.priority = self.priority
         self.saved_app_version = ctx.app_version
         ctx.app_version = self.app_version
+        self.saved_queue_partition_key = ctx.queue_partition_key
+        ctx.queue_partition_key = self.queue_partition_key
         return self
 
     def __exit__(
@@ -517,6 +524,7 @@ class SetEnqueueOptions:
         curr_ctx.deduplication_id = self.saved_deduplication_id
         curr_ctx.priority = self.saved_priority
         curr_ctx.app_version = self.saved_app_version
+        curr_ctx.queue_partition_key = self.saved_queue_partition_key
         # Code to clean up the basic context if we created it
         if self.created_ctx:
             _clear_local_dbos_context()
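SetEnqueueOptions now scopes a queue_partition_key alongside deduplication_id, priority, and app_version, saving the previous value on entry and restoring it on exit. A minimal sketch of in-app usage, assuming the decorators and exports from prior releases; the queue and workflow are placeholders (partition_queue itself is added in dbos/_queue.py below):

    from dbos import DBOS, Queue, SetEnqueueOptions

    orders_queue = Queue("orders", partition_queue=True)

    @DBOS.workflow()
    def process_order(order_id: int) -> str:
        return f"processed {order_id}"

    # The partition key applies only inside the with-block; the prior context
    # value is restored by __exit__.
    with SetEnqueueOptions(queue_partition_key="customer-42"):
        handle = orders_queue.enqueue(process_order, 42)
    result = handle.get_result()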
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_core.py
@@ -93,14 +93,6 @@ TEMP_SEND_WF_NAME = "<temp>.temp_send_workflow"
 DEBOUNCER_WORKFLOW_NAME = "_dbos_debouncer_workflow"
 
 
-def check_is_in_coroutine() -> bool:
-    try:
-        asyncio.get_running_loop()
-        return True
-    except RuntimeError:
-        return False
-
-
 class WorkflowHandleFuture(Generic[R]):
 
     def __init__(self, workflow_id: str, future: Future[R], dbos: "DBOS"):
@@ -303,6 +295,11 @@ def _init_workflow(
             else 0
         ),
         "inputs": dbos._serializer.serialize(inputs),
+        "queue_partition_key": (
+            enqueue_options["queue_partition_key"]
+            if enqueue_options is not None
+            else None
+        ),
     }
 
     # Synchronously record the status and inputs for workflows
@@ -571,6 +568,9 @@ def start_workflow(
         deduplication_id=local_ctx.deduplication_id if local_ctx is not None else None,
         priority=local_ctx.priority if local_ctx is not None else None,
         app_version=local_ctx.app_version if local_ctx is not None else None,
+        queue_partition_key=(
+            local_ctx.queue_partition_key if local_ctx is not None else None
+        ),
     )
     new_wf_id, new_wf_ctx = _get_new_wf()
 
@@ -664,6 +664,9 @@ async def start_workflow_async(
         deduplication_id=local_ctx.deduplication_id if local_ctx is not None else None,
         priority=local_ctx.priority if local_ctx is not None else None,
         app_version=local_ctx.app_version if local_ctx is not None else None,
+        queue_partition_key=(
+            local_ctx.queue_partition_key if local_ctx is not None else None
+        ),
     )
     new_wf_id, new_wf_ctx = _get_new_wf()
 
@@ -845,11 +848,6 @@ def workflow_wrapper(
             dbos._sys_db.record_get_result(workflow_id, serialized_r, None)
             return r
 
-        if check_is_in_coroutine() and not inspect.iscoroutinefunction(func):
-            dbos_logger.warning(
-                f"Sync workflow ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Define it as async or use asyncio.to_thread instead."
-            )
-
         outcome = (
             wfOutcome.wrap(init_wf, dbos=dbos)
             .also(DBOSAssumeRole(rr))
@@ -1035,10 +1033,6 @@ def decorate_transaction(
             assert (
                 ctx.is_workflow()
             ), "Transactions must be called from within workflows"
-            if check_is_in_coroutine():
-                dbos_logger.warning(
-                    f"Transaction function ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Use asyncio.to_thread instead."
-                )
             with DBOSAssumeRole(rr):
                 return invoke_tx(*args, **kwargs)
         else:
@@ -1183,10 +1177,6 @@ def decorate_step(
 
     @wraps(func)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
-        if check_is_in_coroutine() and not inspect.iscoroutinefunction(func):
-            dbos_logger.warning(
-                f"Sync step ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Define it as async or use asyncio.to_thread instead."
-            )
         # If the step is called from a workflow, run it as a step.
         # Otherwise, run it as a normal function.
         ctx = get_local_dbos_context()
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_dbos_config.py
@@ -444,6 +444,7 @@ def configure_db_engine_parameters(
 
     # Configure user database engine parameters
    app_engine_kwargs: dict[str, Any] = {
+        "connect_args": {"application_name": "dbos_transact"},
         "pool_timeout": 30,
         "max_overflow": 0,
         "pool_size": 20,
@@ -477,8 +478,6 @@ def is_valid_database_url(database_url: str) -> bool:
         return True
     url = make_url(database_url)
     required_fields = [
-        ("username", "Username must be specified in the connection URL"),
-        ("host", "Host must be specified in the connection URL"),
         ("database", "Database name must be specified in the connection URL"),
     ]
     for field_name, error_message in required_fields:
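With the relaxed validation, only a database name is required in a connection URL; username and host may be omitted so that driver defaults (for example, a local socket with peer auth) apply. A hedged sketch, assuming the DBOSConfig fields from prior releases and a placeholder database name:

    from dbos import DBOS, DBOSConfig

    config: DBOSConfig = {
        "name": "my-app",
        # No username or host: now accepted, falling back to driver defaults
        "database_url": "postgresql:///my_app_db",
    }
    DBOS(config=config)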
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_kafka.py
@@ -1,6 +1,6 @@
 import re
 import threading
-from typing import TYPE_CHECKING, Any, Callable, NoReturn
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, NoReturn
 
 from confluent_kafka import Consumer, KafkaError, KafkaException
 
@@ -15,7 +15,9 @@ from ._kafka_message import KafkaMessage
 from ._logger import dbos_logger
 from ._registrations import get_dbos_func_name
 
-_KafkaConsumerWorkflow = Callable[[KafkaMessage], None]
+_KafkaConsumerWorkflow = (
+    Callable[[KafkaMessage], None] | Callable[[KafkaMessage], Coroutine[Any, Any, None]]
+)
 
 _kafka_queue: Queue
 _in_order_kafka_queues: dict[str, Queue] = {}
@@ -37,8 +39,8 @@ def _kafka_consumer_loop(
     in_order: bool,
 ) -> None:
 
-    def on_error(err: KafkaError) -> NoReturn:
-        raise KafkaException(err)
+    def on_error(err: KafkaError) -> None:
+        dbos_logger.error(f"Exception in Kafka consumer: {err}")
 
     config["error_cb"] = on_error
     if "auto.offset.reset" not in config:
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_logger.py
@@ -68,30 +68,37 @@ def config_logger(config: "ConfigFile") -> None:
     )
     disable_otlp = config.get("telemetry", {}).get("disable_otlp", False)  # type: ignore
 
-    if not disable_otlp and otlp_logs_endpoints:
+    if not disable_otlp:
 
-        from opentelemetry._logs import set_logger_provider
+        from opentelemetry._logs import get_logger_provider, set_logger_provider
         from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
         from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
         from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
         from opentelemetry.sdk.resources import Resource
         from opentelemetry.semconv.attributes.service_attributes import SERVICE_NAME
 
-        log_provider = LoggerProvider(
-            Resource.create(
-                attributes={
-                    SERVICE_NAME: config["name"],
-                }
-            )
-        )
-        set_logger_provider(log_provider)
-        for e in otlp_logs_endpoints:
-            log_provider.add_log_record_processor(
-                BatchLogRecordProcessor(
-                    OTLPLogExporter(endpoint=e),
-                    export_timeout_millis=5000,
+        # Only set up OTLP provider and exporter if endpoints are provided
+        log_provider = get_logger_provider()
+        if otlp_logs_endpoints is not None:
+            if not isinstance(log_provider, LoggerProvider):
+                log_provider = LoggerProvider(
+                    Resource.create(
+                        attributes={
+                            SERVICE_NAME: config["name"],
+                        }
+                    )
+                )
+                set_logger_provider(log_provider)
+
+            for e in otlp_logs_endpoints:
+                log_provider.add_log_record_processor(
+                    BatchLogRecordProcessor(
+                        OTLPLogExporter(endpoint=e),
+                        export_timeout_millis=5000,
+                    )
                 )
-            )
+
+        # Even if no endpoints are provided, we still need a LoggerProvider to create the LoggingHandler
         global _otlp_handler
         _otlp_handler = LoggingHandler(logger_provider=log_provider)
 
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_migration.py
@@ -203,8 +203,14 @@ CREATE TABLE \"{schema}\".event_dispatch_kv (
 """
 
 
+def get_dbos_migration_two(schema: str) -> str:
+    return f"""
+ALTER TABLE \"{schema}\".workflow_status ADD COLUMN queue_partition_key TEXT;
+"""
+
+
 def get_dbos_migrations(schema: str) -> list[str]:
-    return [get_dbos_migration_one(schema)]
+    return [get_dbos_migration_one(schema), get_dbos_migration_two(schema)]
 
 
 def get_sqlite_timestamp_expr() -> str:
@@ -293,4 +299,8 @@ CREATE TABLE streams (
 );
 """
 
-sqlite_migrations = [sqlite_migration_one]
+sqlite_migration_two = """
+ALTER TABLE workflow_status ADD COLUMN queue_partition_key TEXT;
+"""
+
+sqlite_migrations = [sqlite_migration_one, sqlite_migration_two]
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_queue.py
@@ -43,6 +43,7 @@ class Queue:
         *,  # Disable positional arguments from here on
         worker_concurrency: Optional[int] = None,
         priority_enabled: bool = False,
+        partition_queue: bool = False,
     ) -> None:
         if (
             worker_concurrency is not None
@@ -57,6 +58,7 @@ class Queue:
         self.worker_concurrency = worker_concurrency
         self.limiter = limiter
         self.priority_enabled = priority_enabled
+        self.partition_queue = partition_queue
         from ._dbos import _get_or_create_dbos_registry
 
         registry = _get_or_create_dbos_registry()
@@ -78,6 +80,18 @@ class Queue:
             raise Exception(
                 f"Priority is not enabled for queue {self.name}. Setting priority will not have any effect."
             )
+        if self.partition_queue and (
+            context is None or context.queue_partition_key is None
+        ):
+            raise Exception(
+                f"A workflow cannot be enqueued on partitioned queue {self.name} without a partition key"
+            )
+        if context and context.queue_partition_key and not self.partition_queue:
+            raise Exception(
+                f"You can only use a partition key on a partition-enabled queue. Key {context.queue_partition_key} was used with non-partitioned queue {self.name}"
+            )
+        if context and context.queue_partition_key and context.deduplication_id:
+            raise Exception("Deduplication is not supported for partitioned queues")
 
         dbos = _get_dbos_instance()
         return start_workflow(dbos, func, self.name, False, *args, **kwargs)
@@ -105,10 +119,21 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
         queues = dict(dbos._registry.queue_info_map)
         for _, queue in queues.items():
             try:
-                wf_ids = dbos._sys_db.start_queued_workflows(
-                    queue, GlobalParams.executor_id, GlobalParams.app_version
-                )
-                for id in wf_ids:
+                if queue.partition_queue:
+                    dequeued_workflows = []
+                    queue_partition_keys = dbos._sys_db.get_queue_partitions(queue.name)
+                    for key in queue_partition_keys:
+                        dequeued_workflows += dbos._sys_db.start_queued_workflows(
+                            queue,
+                            GlobalParams.executor_id,
+                            GlobalParams.app_version,
+                            key,
+                        )
+                else:
+                    dequeued_workflows = dbos._sys_db.start_queued_workflows(
+                        queue, GlobalParams.executor_id, GlobalParams.app_version, None
+                    )
+                for id in dequeued_workflows:
                     execute_workflow_by_id(dbos, id)
             except OperationalError as e:
                 if isinstance(
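Queue.enqueue now enforces three rules for partitioned queues, and the queue thread polls each active partition separately via get_queue_partitions. A short sketch of the error cases, assuming placeholder queue and workflow names; each flagged call raises:

    from dbos import DBOS, Queue, SetEnqueueOptions

    partitioned = Queue("emails", partition_queue=True)
    unpartitioned = Queue("reports")

    @DBOS.workflow()
    def send_email(recipient: str) -> None:
        ...

    # Raises: a partitioned queue requires a partition key
    partitioned.enqueue(send_email, "a@example.com")

    # Raises: a partition key is only valid on a partition-enabled queue
    with SetEnqueueOptions(queue_partition_key="tenant-1"):
        unpartitioned.enqueue(send_email, "b@example.com")

    # Raises: deduplication is not supported on partitioned queues
    with SetEnqueueOptions(queue_partition_key="tenant-1", deduplication_id="send-once"):
        partitioned.enqueue(send_email, "c@example.com")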
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_scheduler.py
@@ -2,7 +2,7 @@ import random
 import threading
 import traceback
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable, Coroutine
 
 from ._logger import dbos_logger
 from ._queue import Queue
@@ -14,7 +14,10 @@ from ._context import SetWorkflowID
 from ._croniter import croniter  # type: ignore
 from ._registrations import get_dbos_func_name
 
-ScheduledWorkflow = Callable[[datetime, datetime], None]
+ScheduledWorkflow = (
+    Callable[[datetime, datetime], None]
+    | Callable[[datetime, datetime], Coroutine[Any, Any, None]]
+)
 
 
 def scheduler_loop(
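ScheduledWorkflow likewise now admits coroutine functions, so an async scheduled workflow type-checks. A minimal sketch, assuming the @DBOS.scheduled decorator from prior releases:

    from datetime import datetime

    from dbos import DBOS

    @DBOS.scheduled("*/5 * * * *")  # every five minutes
    @DBOS.workflow()
    async def refresh_cache(scheduled_time: datetime, actual_time: datetime) -> None:
        DBOS.logger.info(f"scheduled {scheduled_time}, ran {actual_time}")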
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_schemas/system_database.py
@@ -77,6 +77,7 @@ class SystemSchema:
         Column("deduplication_id", Text(), nullable=True),
         Column("inputs", Text()),
         Column("priority", Integer(), nullable=False, server_default=text("'0'::int")),
+        Column("queue_partition_key", Text()),
         Index("workflow_status_created_at_index", "created_at"),
         Index("workflow_status_executor_id_index", "executor_id"),
         Index("workflow_status_status_index", "status"),
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_serialization.py
@@ -25,9 +25,13 @@ class Serializer(ABC):
 class DefaultSerializer(Serializer):
 
     def serialize(self, data: Any) -> str:
-        pickled_data: bytes = pickle.dumps(data)
-        encoded_data: str = base64.b64encode(pickled_data).decode("utf-8")
-        return encoded_data
+        try:
+            pickled_data: bytes = pickle.dumps(data)
+            encoded_data: str = base64.b64encode(pickled_data).decode("utf-8")
+            return encoded_data
+        except Exception as e:
+            dbos_logger.error(f"Error serializing object: {data}", exc_info=e)
+            raise
 
     def deserialize(cls, serialized_data: str) -> Any:
         pickled_data: bytes = base64.b64decode(serialized_data)
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_sys_db.py
@@ -152,6 +152,8 @@ class WorkflowStatusInternal(TypedDict):
     priority: int
     # Serialized workflow inputs
     inputs: str
+    # If this workflow is enqueued on a partitioned queue, its partition key
+    queue_partition_key: Optional[str]
 
 
 class EnqueueOptionsInternal(TypedDict):
@@ -161,6 +163,8 @@ class EnqueueOptionsInternal(TypedDict):
     priority: Optional[int]
     # On what version the workflow is enqueued. Current version if not specified.
     app_version: Optional[str]
+    # If the workflow is enqueued on a partitioned queue, its partition key
+    queue_partition_key: Optional[str]
 
 
 class RecordedResult(TypedDict):
@@ -490,6 +494,7 @@ class SystemDatabase(ABC):
                 deduplication_id=status["deduplication_id"],
                 priority=status["priority"],
                 inputs=status["inputs"],
+                queue_partition_key=status["queue_partition_key"],
             )
             .on_conflict_do_update(
                 index_elements=["workflow_uuid"],
@@ -761,6 +766,7 @@ class SystemDatabase(ABC):
                     SystemSchema.workflow_status.c.deduplication_id,
                     SystemSchema.workflow_status.c.priority,
                     SystemSchema.workflow_status.c.inputs,
+                    SystemSchema.workflow_status.c.queue_partition_key,
                 ).where(SystemSchema.workflow_status.c.workflow_uuid == workflow_uuid)
             ).fetchone()
             if row is None:
@@ -788,6 +794,7 @@ class SystemDatabase(ABC):
             "deduplication_id": row[16],
             "priority": row[17],
             "inputs": row[18],
+            "queue_partition_key": row[19],
         }
         return status
 
@@ -1714,8 +1721,41 @@ class SystemDatabase(ABC):
         )
         return value
 
+    @db_retry()
+    def get_queue_partitions(self, queue_name: str) -> List[str]:
+        """
+        Get all unique partition names associated with a queue for ENQUEUED workflows.
+
+        Args:
+            queue_name: The name of the queue to get partitions for
+
+        Returns:
+            A list of unique partition names for the queue
+        """
+        with self.engine.begin() as c:
+            query = (
+                sa.select(SystemSchema.workflow_status.c.queue_partition_key)
+                .distinct()
+                .where(SystemSchema.workflow_status.c.queue_name == queue_name)
+                .where(
+                    SystemSchema.workflow_status.c.status.in_(
+                        [
+                            WorkflowStatusString.ENQUEUED.value,
+                        ]
+                    )
+                )
+                .where(SystemSchema.workflow_status.c.queue_partition_key.isnot(None))
+            )
+
+            rows = c.execute(query).fetchall()
+            return [row[0] for row in rows]
+
     def start_queued_workflows(
-        self, queue: "Queue", executor_id: str, app_version: str
+        self,
+        queue: "Queue",
+        executor_id: str,
+        app_version: str,
+        queue_partition_key: Optional[str],
     ) -> List[str]:
         if self._debug_mode:
             return []
@@ -1734,6 +1774,10 @@ class SystemDatabase(ABC):
                 sa.select(sa.func.count())
                 .select_from(SystemSchema.workflow_status)
                 .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+                .where(
+                    SystemSchema.workflow_status.c.queue_partition_key
+                    == queue_partition_key
+                )
                 .where(
                     SystemSchema.workflow_status.c.status
                     != WorkflowStatusString.ENQUEUED.value
@@ -1758,6 +1802,10 @@ class SystemDatabase(ABC):
                 )
                 .select_from(SystemSchema.workflow_status)
                 .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+                .where(
+                    SystemSchema.workflow_status.c.queue_partition_key
+                    == queue_partition_key
+                )
                 .where(
                     SystemSchema.workflow_status.c.status
                     == WorkflowStatusString.PENDING.value
@@ -1799,6 +1847,10 @@ class SystemDatabase(ABC):
                 )
                 .select_from(SystemSchema.workflow_status)
                 .where(SystemSchema.workflow_status.c.queue_name == queue.name)
+                .where(
+                    SystemSchema.workflow_status.c.queue_partition_key
+                    == queue_partition_key
+                )
                 .where(
                     SystemSchema.workflow_status.c.status
                     == WorkflowStatusString.ENQUEUED.value
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_sys_db_postgres.py
@@ -41,7 +41,7 @@ class PostgresSystemDatabase(SystemDatabase):
                 parameters={"db_name": sysdb_name},
             ).scalar():
                 dbos_logger.info(f"Creating system database {sysdb_name}")
-                conn.execute(sa.text(f"CREATE DATABASE {sysdb_name}"))
+                conn.execute(sa.text(f'CREATE DATABASE "{sysdb_name}"'))
             engine.dispose()
         else:
             # If we were provided an engine, validate it can connect
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/_tracer.py
@@ -25,6 +25,10 @@ class DBOSTracer:
     def config(self, config: ConfigFile) -> None:
         self.otlp_attributes = config.get("telemetry", {}).get("otlp_attributes", {})  # type: ignore
         self.disable_otlp = config.get("telemetry", {}).get("disable_otlp", False)  # type: ignore
+        otlp_traces_endpoints = (
+            config.get("telemetry", {}).get("OTLPExporter", {}).get("tracesEndpoint")  # type: ignore
+        )
+
         if not self.disable_otlp:
             from opentelemetry import trace
             from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
@@ -38,25 +42,26 @@ class DBOSTracer:
             )
             from opentelemetry.semconv.attributes.service_attributes import SERVICE_NAME
 
-            if not isinstance(trace.get_tracer_provider(), TracerProvider):
-                resource = Resource(
-                    attributes={
-                        SERVICE_NAME: config["name"],
-                    }
-                )
-
-                provider = TracerProvider(resource=resource)
-                if os.environ.get("DBOS__CONSOLE_TRACES", None) is not None:
-                    processor = BatchSpanProcessor(ConsoleSpanExporter())
-                    provider.add_span_processor(processor)
-                otlp_traces_endpoints = (
-                    config.get("telemetry", {}).get("OTLPExporter", {}).get("tracesEndpoint")  # type: ignore
-                )
-                if otlp_traces_endpoints:
-                    for e in otlp_traces_endpoints:
-                        processor = BatchSpanProcessor(OTLPSpanExporter(endpoint=e))
-                        provider.add_span_processor(processor)
-                trace.set_tracer_provider(provider)
+            tracer_provider = trace.get_tracer_provider()
+
+            # Only set up OTLP provider and exporter if endpoints are provided
+            if otlp_traces_endpoints is not None:
+                if not isinstance(tracer_provider, TracerProvider):
+                    resource = Resource(
+                        attributes={
+                            SERVICE_NAME: config["name"],
+                        }
+                    )
+
+                    tracer_provider = TracerProvider(resource=resource)
+                    if os.environ.get("DBOS__CONSOLE_TRACES", None) is not None:
+                        processor = BatchSpanProcessor(ConsoleSpanExporter())
+                        tracer_provider.add_span_processor(processor)
+                    trace.set_tracer_provider(tracer_provider)
+
+                for e in otlp_traces_endpoints:
+                    processor = BatchSpanProcessor(OTLPSpanExporter(endpoint=e))
+                    tracer_provider.add_span_processor(processor)
 
     def set_provider(self, provider: "Optional[TracerProvider]") -> None:
         self.provider = provider
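The tracer (like the logger above) now reuses an existing OpenTelemetry provider and only attaches OTLP exporters when a tracesEndpoint is configured. A hedged sketch of an application installing its own provider before DBOS launches, using standard OpenTelemetry SDK calls; DBOS would then add its span processors to this provider rather than replacing it:

    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

    # Install an application-level provider; with no tracesEndpoint configured,
    # DBOS now leaves this provider untouched instead of overwriting it.
    provider = TracerProvider()
    provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))
    trace.set_tracer_provider(provider)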
{dbos-2.2.0a2 → dbos-2.3.0}/dbos/cli/cli.py
@@ -140,26 +140,12 @@ def start() -> None:
         Forward kill signals to children.
 
         When we receive a signal, send it to the entire process group of the child.
-        If that doesn't work, SIGKILL them then exit.
         """
         # Send the signal to the child's entire process group
         if process.poll() is None:
             os.killpg(os.getpgid(process.pid), signum)
 
-        # Give some time for the child to terminate
-        for _ in range(10):  # Wait up to 1 second
-            if process.poll() is not None:
-                break
-            time.sleep(0.1)
-
-        # If the child is still running, force kill it
-        if process.poll() is None:
-            try:
-                os.killpg(os.getpgid(process.pid), signal.SIGKILL)
-            except Exception:
-                pass
-
-        # Exit immediately
+        # Exit
         os._exit(process.returncode if process.returncode is not None else 1)
 
     # Configure the single handler only on Unix-like systems.
{dbos-2.2.0a2 → dbos-2.3.0}/pyproject.toml
@@ -34,7 +34,7 @@ classifiers = [
     "Topic :: Software Development :: Libraries :: Python Modules",
     "Framework :: AsyncIO",
 ]
-version = "2.2.0a2"
+version = "2.3.0"
 
 [project.license]
 text = "MIT"