dbos 1.15.0a9__py3-none-any.whl → 2.4.0a5__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
dbos/_core.py CHANGED
@@ -23,7 +23,6 @@ from typing import (
23
23
  from dbos._outcome import Immediate, NoResult, Outcome, Pending
24
24
  from dbos._utils import GlobalParams, retriable_postgres_exception
25
25
 
26
- from . import _serialization
27
26
  from ._app_db import ApplicationDatabase, TransactionResultInternal
28
27
  from ._context import (
29
28
  DBOSAssumeRole,
@@ -94,14 +93,6 @@ TEMP_SEND_WF_NAME = "<temp>.temp_send_workflow"
94
93
  DEBOUNCER_WORKFLOW_NAME = "_dbos_debouncer_workflow"
95
94
 
96
95
 
97
- def check_is_in_coroutine() -> bool:
98
- try:
99
- asyncio.get_running_loop()
100
- return True
101
- except RuntimeError:
102
- return False
103
-
104
-
105
96
  class WorkflowHandleFuture(Generic[R]):
106
97
 
107
98
  def __init__(self, workflow_id: str, future: Future[R], dbos: "DBOS"):
@@ -116,10 +107,10 @@ class WorkflowHandleFuture(Generic[R]):
116
107
  try:
117
108
  r = self.future.result()
118
109
  except Exception as e:
119
- serialized_e = _serialization.serialize_exception(e)
110
+ serialized_e = self.dbos._serializer.serialize(e)
120
111
  self.dbos._sys_db.record_get_result(self.workflow_id, None, serialized_e)
121
112
  raise
122
- serialized_r = _serialization.serialize(r)
113
+ serialized_r = self.dbos._serializer.serialize(r)
123
114
  self.dbos._sys_db.record_get_result(self.workflow_id, serialized_r, None)
124
115
  return r
125
116
 
@@ -143,10 +134,10 @@ class WorkflowHandlePolling(Generic[R]):
143
134
  try:
144
135
  r: R = self.dbos._sys_db.await_workflow_result(self.workflow_id)
145
136
  except Exception as e:
146
- serialized_e = _serialization.serialize_exception(e)
137
+ serialized_e = self.dbos._serializer.serialize(e)
147
138
  self.dbos._sys_db.record_get_result(self.workflow_id, None, serialized_e)
148
139
  raise
149
- serialized_r = _serialization.serialize(r)
140
+ serialized_r = self.dbos._serializer.serialize(r)
150
141
  self.dbos._sys_db.record_get_result(self.workflow_id, serialized_r, None)
151
142
  return r
152
143
 
@@ -171,7 +162,7 @@ class WorkflowHandleAsyncTask(Generic[R]):
171
162
  try:
172
163
  r = await self.task
173
164
  except Exception as e:
174
- serialized_e = _serialization.serialize_exception(e)
165
+ serialized_e = self.dbos._serializer.serialize(e)
175
166
  await asyncio.to_thread(
176
167
  self.dbos._sys_db.record_get_result,
177
168
  self.workflow_id,
@@ -179,7 +170,7 @@ class WorkflowHandleAsyncTask(Generic[R]):
179
170
  serialized_e,
180
171
  )
181
172
  raise
182
- serialized_r = _serialization.serialize(r)
173
+ serialized_r = self.dbos._serializer.serialize(r)
183
174
  await asyncio.to_thread(
184
175
  self.dbos._sys_db.record_get_result, self.workflow_id, serialized_r, None
185
176
  )
@@ -207,7 +198,7 @@ class WorkflowHandleAsyncPolling(Generic[R]):
207
198
  self.dbos._sys_db.await_workflow_result, self.workflow_id
208
199
  )
209
200
  except Exception as e:
210
- serialized_e = _serialization.serialize_exception(e)
201
+ serialized_e = self.dbos._serializer.serialize(e)
211
202
  await asyncio.to_thread(
212
203
  self.dbos._sys_db.record_get_result,
213
204
  self.workflow_id,
@@ -215,7 +206,7 @@ class WorkflowHandleAsyncPolling(Generic[R]):
215
206
  serialized_e,
216
207
  )
217
208
  raise
218
- serialized_r = _serialization.serialize(r)
209
+ serialized_r = self.dbos._serializer.serialize(r)
219
210
  await asyncio.to_thread(
220
211
  self.dbos._sys_db.record_get_result, self.workflow_id, serialized_r, None
221
212
  )
@@ -303,7 +294,13 @@ def _init_workflow(
303
294
  if enqueue_options is not None
304
295
  else 0
305
296
  ),
306
- "inputs": _serialization.serialize_args(inputs),
297
+ "inputs": dbos._serializer.serialize(inputs),
298
+ "queue_partition_key": (
299
+ enqueue_options["queue_partition_key"]
300
+ if enqueue_options is not None
301
+ else None
302
+ ),
303
+ "forked_from": None,
307
304
  }
308
305
 
309
306
  # Synchronously record the status and inputs for workflows
@@ -319,7 +316,8 @@ def _init_workflow(
319
316
  "function_id": ctx.parent_workflow_fid,
320
317
  "function_name": wf_name,
321
318
  "output": None,
322
- "error": _serialization.serialize_exception(e),
319
+ "error": dbos._serializer.serialize(e),
320
+ "started_at_epoch_ms": int(time.time() * 1000),
323
321
  }
324
322
  dbos._sys_db.record_operation_result(result)
325
323
  raise
@@ -378,7 +376,7 @@ def _get_wf_invoke_func(
378
376
  dbos._sys_db.update_workflow_outcome(
379
377
  status["workflow_uuid"],
380
378
  "SUCCESS",
381
- output=_serialization.serialize(output),
379
+ output=dbos._serializer.serialize(output),
382
380
  )
383
381
  return output
384
382
  except DBOSWorkflowConflictIDError:
@@ -392,7 +390,7 @@ def _get_wf_invoke_func(
392
390
  dbos._sys_db.update_workflow_outcome(
393
391
  status["workflow_uuid"],
394
392
  "ERROR",
395
- error=_serialization.serialize_exception(error),
393
+ error=dbos._serializer.serialize(error),
396
394
  )
397
395
  raise
398
396
  finally:
@@ -464,7 +462,7 @@ def execute_workflow_by_id(dbos: "DBOS", workflow_id: str) -> "WorkflowHandle[An
464
462
  status = dbos._sys_db.get_workflow_status(workflow_id)
465
463
  if not status:
466
464
  raise DBOSRecoveryError(workflow_id, "Workflow status not found")
467
- inputs = _serialization.deserialize_args(status["inputs"])
465
+ inputs: WorkflowInputs = dbos._serializer.deserialize(status["inputs"])
468
466
  wf_func = dbos._registry.workflow_info_map.get(status["name"], None)
469
467
  if not wf_func:
470
468
  raise DBOSWorkflowFunctionNotFoundError(
@@ -572,6 +570,9 @@ def start_workflow(
572
570
  deduplication_id=local_ctx.deduplication_id if local_ctx is not None else None,
573
571
  priority=local_ctx.priority if local_ctx is not None else None,
574
572
  app_version=local_ctx.app_version if local_ctx is not None else None,
573
+ queue_partition_key=(
574
+ local_ctx.queue_partition_key if local_ctx is not None else None
575
+ ),
575
576
  )
576
577
  new_wf_id, new_wf_ctx = _get_new_wf()
577
578
 
@@ -665,6 +666,9 @@ async def start_workflow_async(
665
666
  deduplication_id=local_ctx.deduplication_id if local_ctx is not None else None,
666
667
  priority=local_ctx.priority if local_ctx is not None else None,
667
668
  app_version=local_ctx.app_version if local_ctx is not None else None,
669
+ queue_partition_key=(
670
+ local_ctx.queue_partition_key if local_ctx is not None else None
671
+ ),
668
672
  )
669
673
  new_wf_id, new_wf_ctx = _get_new_wf()
670
674
 
@@ -837,20 +841,15 @@ def workflow_wrapper(
837
841
  try:
838
842
  r = func()
839
843
  except Exception as e:
840
- serialized_e = _serialization.serialize_exception(e)
844
+ serialized_e = dbos._serializer.serialize(e)
841
845
  assert workflow_id is not None
842
846
  dbos._sys_db.record_get_result(workflow_id, None, serialized_e)
843
847
  raise
844
- serialized_r = _serialization.serialize(r)
848
+ serialized_r = dbos._serializer.serialize(r)
845
849
  assert workflow_id is not None
846
850
  dbos._sys_db.record_get_result(workflow_id, serialized_r, None)
847
851
  return r
848
852
 
849
- if check_is_in_coroutine() and not inspect.iscoroutinefunction(func):
850
- dbos_logger.warning(
851
- f"Sync workflow ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Define it as async or use asyncio.to_thread instead."
852
- )
853
-
854
853
  outcome = (
855
854
  wfOutcome.wrap(init_wf, dbos=dbos)
856
855
  .also(DBOSAssumeRole(rr))
@@ -948,15 +947,15 @@ def decorate_transaction(
948
947
  f"Replaying transaction, id: {ctx.function_id}, name: {attributes['name']}"
949
948
  )
950
949
  if recorded_output["error"]:
951
- deserialized_error = (
952
- _serialization.deserialize_exception(
950
+ deserialized_error: Exception = (
951
+ dbos._serializer.deserialize(
953
952
  recorded_output["error"]
954
953
  )
955
954
  )
956
955
  has_recorded_error = True
957
956
  raise deserialized_error
958
957
  elif recorded_output["output"]:
959
- return _serialization.deserialize(
958
+ return dbos._serializer.deserialize(
960
959
  recorded_output["output"]
961
960
  )
962
961
  else:
@@ -969,7 +968,9 @@ def decorate_transaction(
969
968
  )
970
969
 
971
970
  output = func(*args, **kwargs)
972
- txn_output["output"] = _serialization.serialize(output)
971
+ txn_output["output"] = dbos._serializer.serialize(
972
+ output
973
+ )
973
974
  assert (
974
975
  ctx.sql_session is not None
975
976
  ), "Cannot find a database connection"
@@ -1010,8 +1011,8 @@ def decorate_transaction(
1010
1011
  finally:
1011
1012
  # Don't record the error if it was already recorded
1012
1013
  if txn_error and not has_recorded_error:
1013
- txn_output["error"] = (
1014
- _serialization.serialize_exception(txn_error)
1014
+ txn_output["error"] = dbos._serializer.serialize(
1015
+ txn_error
1015
1016
  )
1016
1017
  dbos._app_db.record_transaction_error(txn_output)
1017
1018
  return output
@@ -1034,10 +1035,6 @@ def decorate_transaction(
1034
1035
  assert (
1035
1036
  ctx.is_workflow()
1036
1037
  ), "Transactions must be called from within workflows"
1037
- if check_is_in_coroutine():
1038
- dbos_logger.warning(
1039
- f"Transaction function ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Use asyncio.to_thread instead."
1040
- )
1041
1038
  with DBOSAssumeRole(rr):
1042
1039
  return invoke_tx(*args, **kwargs)
1043
1040
  else:
@@ -1123,15 +1120,16 @@ def decorate_step(
1123
1120
  "function_name": step_name,
1124
1121
  "output": None,
1125
1122
  "error": None,
1123
+ "started_at_epoch_ms": int(time.time() * 1000),
1126
1124
  }
1127
1125
 
1128
1126
  try:
1129
1127
  output = func()
1130
1128
  except Exception as error:
1131
- step_output["error"] = _serialization.serialize_exception(error)
1129
+ step_output["error"] = dbos._serializer.serialize(error)
1132
1130
  dbos._sys_db.record_operation_result(step_output)
1133
1131
  raise
1134
- step_output["output"] = _serialization.serialize(output)
1132
+ step_output["output"] = dbos._serializer.serialize(output)
1135
1133
  dbos._sys_db.record_operation_result(step_output)
1136
1134
  return output
1137
1135
 
@@ -1147,13 +1145,13 @@ def decorate_step(
1147
1145
  f"Replaying step, id: {ctx.function_id}, name: {attributes['name']}"
1148
1146
  )
1149
1147
  if recorded_output["error"] is not None:
1150
- deserialized_error = _serialization.deserialize_exception(
1148
+ deserialized_error: Exception = dbos._serializer.deserialize(
1151
1149
  recorded_output["error"]
1152
1150
  )
1153
1151
  raise deserialized_error
1154
1152
  elif recorded_output["output"] is not None:
1155
1153
  return cast(
1156
- R, _serialization.deserialize(recorded_output["output"])
1154
+ R, dbos._serializer.deserialize(recorded_output["output"])
1157
1155
  )
1158
1156
  else:
1159
1157
  raise Exception("Output and error are both None")
@@ -1182,10 +1180,6 @@ def decorate_step(
1182
1180
 
1183
1181
  @wraps(func)
1184
1182
  def wrapper(*args: Any, **kwargs: Any) -> Any:
1185
- if check_is_in_coroutine() and not inspect.iscoroutinefunction(func):
1186
- dbos_logger.warning(
1187
- f"Sync step ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Define it as async or use asyncio.to_thread instead."
1188
- )
1189
1183
  # If the step is called from a workflow, run it as a step.
1190
1184
  # Otherwise, run it as a normal function.
1191
1185
  ctx = get_local_dbos_context()
@@ -1278,21 +1272,24 @@ def recv(dbos: "DBOS", topic: Optional[str] = None, timeout_seconds: float = 60)
1278
1272
  def set_event(dbos: "DBOS", key: str, value: Any) -> None:
1279
1273
  cur_ctx = get_local_dbos_context()
1280
1274
  if cur_ctx is not None:
1281
- # Must call it within a workflow
1282
- assert (
1283
- cur_ctx.is_workflow()
1284
- ), "set_event() must be called from within a workflow"
1285
- attributes: TracedAttributes = {
1286
- "name": "set_event",
1287
- }
1288
- with EnterDBOSStep(attributes):
1289
- ctx = assert_current_dbos_context()
1290
- dbos._sys_db.set_event(
1291
- ctx.workflow_id, ctx.curr_step_function_id, key, value
1275
+ if cur_ctx.is_workflow():
1276
+ # If called from a workflow function, run as a step
1277
+ attributes: TracedAttributes = {
1278
+ "name": "set_event",
1279
+ }
1280
+ with EnterDBOSStep(attributes):
1281
+ ctx = assert_current_dbos_context()
1282
+ dbos._sys_db.set_event_from_workflow(
1283
+ ctx.workflow_id, ctx.curr_step_function_id, key, value
1284
+ )
1285
+ elif cur_ctx.is_step():
1286
+ dbos._sys_db.set_event_from_step(cur_ctx.workflow_id, key, value)
1287
+ else:
1288
+ raise DBOSException(
1289
+ "set_event() must be called from within a workflow or step"
1292
1290
  )
1293
1291
  else:
1294
- # Cannot call it from outside of a workflow
1295
- raise DBOSException("set_event() must be called from within a workflow")
1292
+ raise DBOSException("set_event() must be called from within a workflow or step")
1296
1293
 
1297
1294
 
1298
1295
  def get_event(
dbos/_dbos.py CHANGED
@@ -31,6 +31,7 @@ from typing import (
31
31
 
32
32
  from dbos._conductor.conductor import ConductorWebsocket
33
33
  from dbos._debouncer import debouncer_workflow
34
+ from dbos._serialization import DefaultSerializer, Serializer
34
35
  from dbos._sys_db import SystemDatabase, WorkflowStatus
35
36
  from dbos._utils import INTERNAL_QUEUE_NAME, GlobalParams
36
37
  from dbos._workflow_commands import fork_workflow, list_queued_workflows, list_workflows
@@ -341,6 +342,8 @@ class DBOS:
341
342
  self.conductor_websocket: Optional[ConductorWebsocket] = None
342
343
  self._background_event_loop: BackgroundEventLoop = BackgroundEventLoop()
343
344
  self._active_workflows_set: set[str] = set()
345
+ serializer = config.get("serializer")
346
+ self._serializer: Serializer = serializer if serializer else DefaultSerializer()
344
347
 
345
348
  # Globally set the application version and executor ID.
346
349
  # In DBOS Cloud, instead use the values supplied through environment variables.
@@ -449,28 +452,35 @@ class DBOS:
449
452
  assert self._config["database"]["sys_db_engine_kwargs"] is not None
450
453
  # Get the schema configuration, use "dbos" as default
451
454
  schema = self._config.get("dbos_system_schema", "dbos")
455
+ dbos_logger.debug("Creating system database")
452
456
  self._sys_db_field = SystemDatabase.create(
453
457
  system_database_url=get_system_database_url(self._config),
454
458
  engine_kwargs=self._config["database"]["sys_db_engine_kwargs"],
455
459
  engine=self._config["system_database_engine"],
456
460
  debug_mode=debug_mode,
457
461
  schema=schema,
462
+ serializer=self._serializer,
463
+ executor_id=GlobalParams.executor_id,
458
464
  )
459
465
  assert self._config["database"]["db_engine_kwargs"] is not None
460
466
  if self._config["database_url"]:
467
+ dbos_logger.debug("Creating application database")
461
468
  self._app_db_field = ApplicationDatabase.create(
462
469
  database_url=self._config["database_url"],
463
470
  engine_kwargs=self._config["database"]["db_engine_kwargs"],
464
471
  debug_mode=debug_mode,
465
472
  schema=schema,
473
+ serializer=self._serializer,
466
474
  )
467
475
 
468
476
  if debug_mode:
469
477
  return
470
478
 
471
479
  # Run migrations for the system and application databases
480
+ dbos_logger.debug("Running system database migrations")
472
481
  self._sys_db.run_migrations()
473
482
  if self._app_db:
483
+ dbos_logger.debug("Running application database migrations")
474
484
  self._app_db.run_migrations()
475
485
 
476
486
  admin_port = self._config.get("runtimeConfig", {}).get("admin_port")
@@ -481,10 +491,12 @@ class DBOS:
481
491
  )
482
492
  if run_admin_server:
483
493
  try:
494
+ dbos_logger.debug("Starting admin server")
484
495
  self._admin_server_field = AdminServer(dbos=self, port=admin_port)
485
496
  except Exception as e:
486
497
  dbos_logger.warning(f"Failed to start admin server: {e}")
487
498
 
499
+ dbos_logger.debug("Retrieving local pending workflows for recovery")
488
500
  workflow_ids = self._sys_db.get_pending_workflows(
489
501
  GlobalParams.executor_id, GlobalParams.app_version
490
502
  )
@@ -500,6 +512,7 @@ class DBOS:
500
512
  self._executor.submit(startup_recovery_thread, self, workflow_ids)
501
513
 
502
514
  # Listen to notifications
515
+ dbos_logger.debug("Starting notifications listener thread")
503
516
  notification_listener_thread = threading.Thread(
504
517
  target=self._sys_db._notification_listener,
505
518
  daemon=True,
@@ -511,6 +524,7 @@ class DBOS:
511
524
  self._registry.get_internal_queue()
512
525
 
513
526
  # Start the queue thread
527
+ dbos_logger.debug("Starting queue thread")
514
528
  evt = threading.Event()
515
529
  self.background_thread_stop_events.append(evt)
516
530
  bg_queue_thread = threading.Thread(
@@ -526,6 +540,7 @@ class DBOS:
526
540
  self.conductor_url = f"wss://{dbos_domain}/conductor/v1alpha1"
527
541
  evt = threading.Event()
528
542
  self.background_thread_stop_events.append(evt)
543
+ dbos_logger.debug("Starting Conductor thread")
529
544
  self.conductor_websocket = ConductorWebsocket(
530
545
  self,
531
546
  conductor_url=self.conductor_url,
@@ -536,6 +551,7 @@ class DBOS:
536
551
  self._background_threads.append(self.conductor_websocket)
537
552
 
538
553
  # Grab any pollers that were deferred and start them
554
+ dbos_logger.debug("Starting event receivers")
539
555
  for evt, func, args, kwargs in self._registry.pollers:
540
556
  self.poller_stop_events.append(evt)
541
557
  poller_thread = threading.Thread(
@@ -1112,7 +1128,9 @@ class DBOS:
1112
1128
  end_time: Optional[str] = None,
1113
1129
  name: Optional[str] = None,
1114
1130
  app_version: Optional[str] = None,
1131
+ forked_from: Optional[str] = None,
1115
1132
  user: Optional[str] = None,
1133
+ queue_name: Optional[str] = None,
1116
1134
  limit: Optional[int] = None,
1117
1135
  offset: Optional[int] = None,
1118
1136
  sort_desc: bool = False,
@@ -1129,6 +1147,7 @@ class DBOS:
1129
1147
  end_time=end_time,
1130
1148
  name=name,
1131
1149
  app_version=app_version,
1150
+ forked_from=forked_from,
1132
1151
  user=user,
1133
1152
  limit=limit,
1134
1153
  offset=offset,
@@ -1136,6 +1155,7 @@ class DBOS:
1136
1155
  workflow_id_prefix=workflow_id_prefix,
1137
1156
  load_input=load_input,
1138
1157
  load_output=load_output,
1158
+ queue_name=queue_name,
1139
1159
  )
1140
1160
 
1141
1161
  return _get_dbos_instance()._sys_db.call_function_as_step(
@@ -1152,6 +1172,7 @@ class DBOS:
1152
1172
  end_time: Optional[str] = None,
1153
1173
  name: Optional[str] = None,
1154
1174
  app_version: Optional[str] = None,
1175
+ forked_from: Optional[str] = None,
1155
1176
  user: Optional[str] = None,
1156
1177
  limit: Optional[int] = None,
1157
1178
  offset: Optional[int] = None,
@@ -1169,6 +1190,7 @@ class DBOS:
1169
1190
  end_time=end_time,
1170
1191
  name=name,
1171
1192
  app_version=app_version,
1193
+ forked_from=forked_from,
1172
1194
  user=user,
1173
1195
  limit=limit,
1174
1196
  offset=offset,
@@ -1184,6 +1206,7 @@ class DBOS:
1184
1206
  *,
1185
1207
  queue_name: Optional[str] = None,
1186
1208
  status: Optional[Union[str, List[str]]] = None,
1209
+ forked_from: Optional[str] = None,
1187
1210
  start_time: Optional[str] = None,
1188
1211
  end_time: Optional[str] = None,
1189
1212
  name: Optional[str] = None,
@@ -1197,6 +1220,7 @@ class DBOS:
1197
1220
  _get_dbos_instance()._sys_db,
1198
1221
  queue_name=queue_name,
1199
1222
  status=status,
1223
+ forked_from=forked_from,
1200
1224
  start_time=start_time,
1201
1225
  end_time=end_time,
1202
1226
  name=name,
@@ -1216,6 +1240,7 @@ class DBOS:
1216
1240
  *,
1217
1241
  queue_name: Optional[str] = None,
1218
1242
  status: Optional[Union[str, List[str]]] = None,
1243
+ forked_from: Optional[str] = None,
1219
1244
  start_time: Optional[str] = None,
1220
1245
  end_time: Optional[str] = None,
1221
1246
  name: Optional[str] = None,
@@ -1229,6 +1254,7 @@ class DBOS:
1229
1254
  cls.list_queued_workflows,
1230
1255
  queue_name=queue_name,
1231
1256
  status=status,
1257
+ forked_from=forked_from,
1232
1258
  start_time=start_time,
1233
1259
  end_time=end_time,
1234
1260
  name=name,
dbos/_dbos_config.py CHANGED
@@ -7,6 +7,8 @@ import sqlalchemy as sa
7
7
  import yaml
8
8
  from sqlalchemy import make_url
9
9
 
10
+ from dbos._serialization import Serializer
11
+
10
12
  from ._error import DBOSInitializationError
11
13
  from ._logger import dbos_logger
12
14
  from ._schemas.system_database import SystemSchema
@@ -37,6 +39,7 @@ class DBOSConfig(TypedDict, total=False):
37
39
  enable_otlp (bool): If True, enable built-in DBOS OTLP tracing and logging.
38
40
  system_database_engine (sa.Engine): A custom system database engine. If provided, DBOS will not create an engine but use this instead.
39
41
  conductor_key (str): An API key for DBOS Conductor. Pass this in to connect your process to Conductor.
42
+ serializer (Serializer): A custom serializer and deserializer DBOS uses when storing program data in the system database
40
43
  """
41
44
 
42
45
  name: str
@@ -57,6 +60,7 @@ class DBOSConfig(TypedDict, total=False):
57
60
  enable_otlp: Optional[bool]
58
61
  system_database_engine: Optional[sa.Engine]
59
62
  conductor_key: Optional[str]
63
+ serializer: Optional[Serializer]
60
64
 
61
65
 
62
66
  class RuntimeConfig(TypedDict, total=False):
@@ -67,16 +71,6 @@ class RuntimeConfig(TypedDict, total=False):
67
71
 
68
72
 
69
73
  class DatabaseConfig(TypedDict, total=False):
70
- """
71
- Internal data structure containing the DBOS database configuration.
72
- Attributes:
73
- sys_db_name (str): System database name
74
- sys_db_pool_size (int): System database pool size
75
- db_engine_kwargs (Dict[str, Any]): SQLAlchemy engine kwargs
76
- migrate (List[str]): Migration commands to run on startup
77
- dbos_system_schema (str): Schema name for DBOS system tables. Defaults to "dbos".
78
- """
79
-
80
74
  sys_db_pool_size: Optional[int]
81
75
  db_engine_kwargs: Optional[Dict[str, Any]]
82
76
  sys_db_engine_kwargs: Optional[Dict[str, Any]]
@@ -450,6 +444,7 @@ def configure_db_engine_parameters(
450
444
 
451
445
  # Configure user database engine parameters
452
446
  app_engine_kwargs: dict[str, Any] = {
447
+ "connect_args": {"application_name": "dbos_transact"},
453
448
  "pool_timeout": 30,
454
449
  "max_overflow": 0,
455
450
  "pool_size": 20,
@@ -483,8 +478,6 @@ def is_valid_database_url(database_url: str) -> bool:
483
478
  return True
484
479
  url = make_url(database_url)
485
480
  required_fields = [
486
- ("username", "Username must be specified in the connection URL"),
487
- ("host", "Host must be specified in the connection URL"),
488
481
  ("database", "Database name must be specified in the connection URL"),
489
482
  ]
490
483
  for field_name, error_message in required_fields:
dbos/_debouncer.py CHANGED
@@ -86,6 +86,7 @@ def debouncer_workflow(
86
86
  dbos = _get_dbos_instance()
87
87
 
88
88
  workflow_inputs: WorkflowInputs = {"args": args, "kwargs": kwargs}
89
+
89
90
  # Every time the debounced workflow is called, a message is sent to this workflow.
90
91
  # It waits until debounce_period_sec have passed since the last message or until
91
92
  # debounce_timeout_sec has elapsed.
@@ -95,7 +96,10 @@ def debouncer_workflow(
95
96
  if options["debounce_timeout_sec"]
96
97
  else math.inf
97
98
  )
98
- debounce_deadline_epoch_sec = dbos._sys_db.call_function_as_step(get_debounce_deadline_epoch_sec, "get_debounce_deadline_epoch_sec")
99
+
100
+ debounce_deadline_epoch_sec = dbos._sys_db.call_function_as_step(
101
+ get_debounce_deadline_epoch_sec, "get_debounce_deadline_epoch_sec"
102
+ )
99
103
  debounce_period_sec = initial_debounce_period_sec
100
104
  while time.time() < debounce_deadline_epoch_sec:
101
105
  time_until_deadline = max(debounce_deadline_epoch_sec - time.time(), 0)
dbos/_kafka.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import re
2
2
  import threading
3
- from typing import TYPE_CHECKING, Any, Callable, NoReturn
3
+ from typing import TYPE_CHECKING, Any, Callable, Coroutine, NoReturn
4
4
 
5
5
  from confluent_kafka import Consumer, KafkaError, KafkaException
6
6
 
@@ -15,7 +15,9 @@ from ._kafka_message import KafkaMessage
15
15
  from ._logger import dbos_logger
16
16
  from ._registrations import get_dbos_func_name
17
17
 
18
- _KafkaConsumerWorkflow = Callable[[KafkaMessage], None]
18
+ _KafkaConsumerWorkflow = (
19
+ Callable[[KafkaMessage], None] | Callable[[KafkaMessage], Coroutine[Any, Any, None]]
20
+ )
19
21
 
20
22
  _kafka_queue: Queue
21
23
  _in_order_kafka_queues: dict[str, Queue] = {}
@@ -37,8 +39,8 @@ def _kafka_consumer_loop(
37
39
  in_order: bool,
38
40
  ) -> None:
39
41
 
40
- def on_error(err: KafkaError) -> NoReturn:
41
- raise KafkaException(err)
42
+ def on_error(err: KafkaError) -> None:
43
+ dbos_logger.error(f"Exception in Kafka consumer: {err}")
42
44
 
43
45
  config["error_cb"] = on_error
44
46
  if "auto.offset.reset" not in config:
dbos/_logger.py CHANGED
@@ -68,30 +68,37 @@ def config_logger(config: "ConfigFile") -> None:
68
68
  )
69
69
  disable_otlp = config.get("telemetry", {}).get("disable_otlp", False) # type: ignore
70
70
 
71
- if not disable_otlp and otlp_logs_endpoints:
71
+ if not disable_otlp:
72
72
 
73
- from opentelemetry._logs import set_logger_provider
73
+ from opentelemetry._logs import get_logger_provider, set_logger_provider
74
74
  from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
75
75
  from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
76
76
  from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
77
77
  from opentelemetry.sdk.resources import Resource
78
78
  from opentelemetry.semconv.attributes.service_attributes import SERVICE_NAME
79
79
 
80
- log_provider = LoggerProvider(
81
- Resource.create(
82
- attributes={
83
- SERVICE_NAME: config["name"],
84
- }
85
- )
86
- )
87
- set_logger_provider(log_provider)
88
- for e in otlp_logs_endpoints:
89
- log_provider.add_log_record_processor(
90
- BatchLogRecordProcessor(
91
- OTLPLogExporter(endpoint=e),
92
- export_timeout_millis=5000,
80
+ # Only set up OTLP provider and exporter if endpoints are provided
81
+ log_provider = get_logger_provider()
82
+ if otlp_logs_endpoints is not None:
83
+ if not isinstance(log_provider, LoggerProvider):
84
+ log_provider = LoggerProvider(
85
+ Resource.create(
86
+ attributes={
87
+ SERVICE_NAME: config["name"],
88
+ }
89
+ )
90
+ )
91
+ set_logger_provider(log_provider)
92
+
93
+ for e in otlp_logs_endpoints:
94
+ log_provider.add_log_record_processor(
95
+ BatchLogRecordProcessor(
96
+ OTLPLogExporter(endpoint=e),
97
+ export_timeout_millis=5000,
98
+ )
93
99
  )
94
- )
100
+
101
+ # Even if no endpoints are provided, we still need a LoggerProvider to create the LoggingHandler
95
102
  global _otlp_handler
96
103
  _otlp_handler = LoggingHandler(logger_provider=log_provider)
97
104