durabletask 0.4.0 (durabletask-0.4.0-py3-none-any.whl) → 0.5.0 (durabletask-0.5.0-py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of durabletask might be problematic; the original page linked to further details about this release.

durabletask/worker.py CHANGED
@@ -3,11 +3,12 @@
3
3
 
4
4
  import asyncio
5
5
  import inspect
6
+ import json
6
7
  import logging
7
8
  import os
8
9
  import random
9
10
  from concurrent.futures import ThreadPoolExecutor
10
- from datetime import datetime, timedelta
11
+ from datetime import datetime, timedelta, timezone
11
12
  from threading import Event, Thread
12
13
  from types import GeneratorType
13
14
  from enum import Enum
@@ -17,6 +18,11 @@ from packaging.version import InvalidVersion, parse
17
18
  import grpc
18
19
  from google.protobuf import empty_pb2
19
20
 
21
+ from durabletask.internal import helpers
22
+ from durabletask.internal.entity_state_shim import StateShim
23
+ from durabletask.internal.helpers import new_timestamp
24
+ from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
25
+ from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
20
26
  import durabletask.internal.helpers as ph
21
27
  import durabletask.internal.exceptions as pe
22
28
  import durabletask.internal.orchestrator_service_pb2 as pb
@@ -40,6 +46,7 @@ class ConcurrencyOptions:
40
46
  self,
41
47
  maximum_concurrent_activity_work_items: Optional[int] = None,
42
48
  maximum_concurrent_orchestration_work_items: Optional[int] = None,
49
+ maximum_concurrent_entity_work_items: Optional[int] = None,
43
50
  maximum_thread_pool_workers: Optional[int] = None,
44
51
  ):
45
52
  """Initialize concurrency options.
@@ -68,6 +75,12 @@ class ConcurrencyOptions:
68
75
  else default_concurrency
69
76
  )
70
77
 
78
+ self.maximum_concurrent_entity_work_items = (
79
+ maximum_concurrent_entity_work_items
80
+ if maximum_concurrent_entity_work_items is not None
81
+ else default_concurrency
82
+ )
83
+
71
84
  self.maximum_thread_pool_workers = (
72
85
  maximum_thread_pool_workers
73
86
  if maximum_thread_pool_workers is not None
@@ -111,7 +124,7 @@ class VersioningOptions:
111
124
 
112
125
  Args:
113
126
  version: The version of orchestrations that the worker can work on.
114
- default_version: The default version that will be used for starting new orchestrations.
127
+ default_version: The default version that will be used for starting new sub-orchestrations.
115
128
  match_strategy: The versioning strategy for the Durable Task worker.
116
129
  failure_strategy: The versioning failure strategy for the Durable Task worker.
117
130
  """
@@ -124,11 +137,15 @@ class VersioningOptions:
124
137
  class _Registry:
125
138
  orchestrators: dict[str, task.Orchestrator]
126
139
  activities: dict[str, task.Activity]
140
+ entities: dict[str, task.Entity]
141
+ entity_instances: dict[str, DurableEntity]
127
142
  versioning: Optional[VersioningOptions] = None
128
143
 
129
144
  def __init__(self):
130
145
  self.orchestrators = {}
131
146
  self.activities = {}
147
+ self.entities = {}
148
+ self.entity_instances = {}
132
149
 
133
150
  def add_orchestrator(self, fn: task.Orchestrator) -> str:
134
151
  if fn is None:
@@ -168,6 +185,29 @@ class _Registry:
168
185
  def get_activity(self, name: str) -> Optional[task.Activity]:
169
186
  return self.activities.get(name)
170
187
 
188
+ def add_entity(self, fn: task.Entity) -> str:
189
+ if fn is None:
190
+ raise ValueError("An entity function argument is required.")
191
+
192
+ if isinstance(fn, type) and issubclass(fn, DurableEntity):
193
+ name = fn.__name__
194
+ self.add_named_entity(name, fn)
195
+ else:
196
+ name = task.get_name(fn)
197
+ self.add_named_entity(name, fn)
198
+ return name
199
+
200
+ def add_named_entity(self, name: str, fn: task.Entity) -> None:
201
+ if not name:
202
+ raise ValueError("A non-empty entity name is required.")
203
+ if name in self.entities:
204
+ raise ValueError(f"A '{name}' entity already exists.")
205
+
206
+ self.entities[name] = fn
207
+
208
+ def get_entity(self, name: str) -> Optional[task.Entity]:
209
+ return self.entities.get(name)
210
+
171
211
 
172
212
  class OrchestratorNotRegisteredError(ValueError):
173
213
  """Raised when attempting to start an orchestration that is not registered"""
@@ -181,6 +221,12 @@ class ActivityNotRegisteredError(ValueError):
181
221
  pass
182
222
 
183
223
 
224
+ class EntityNotRegisteredError(ValueError):
225
+ """Raised when attempting to call an entity that is not registered"""
226
+
227
+ pass
228
+
229
+
184
230
  class TaskHubGrpcWorker:
185
231
  """A gRPC-based worker for processing durable task orchestrations and activities.
186
232
 
@@ -329,6 +375,14 @@ class TaskHubGrpcWorker:
329
375
  )
330
376
  return self._registry.add_activity(fn)
331
377
 
378
+ def add_entity(self, fn: task.Entity) -> str:
379
+ """Registers an entity function with the worker."""
380
+ if self._is_running:
381
+ raise RuntimeError(
382
+ "Entities cannot be added while the worker is running."
383
+ )
384
+ return self._registry.add_entity(fn)
385
+
332
386
  def use_versioning(self, version: VersioningOptions) -> None:
333
387
  """Initializes versioning options for sub-orchestrators and activities."""
334
388
  if self._is_running:
@@ -490,6 +544,20 @@ class TaskHubGrpcWorker:
490
544
  stub,
491
545
  work_item.completionToken,
492
546
  )
547
+ elif work_item.HasField("entityRequest"):
548
+ self._async_worker_manager.submit_entity_batch(
549
+ self._execute_entity_batch,
550
+ work_item.entityRequest,
551
+ stub,
552
+ work_item.completionToken,
553
+ )
554
+ elif work_item.HasField("entityRequestV2"):
555
+ self._async_worker_manager.submit_entity_batch(
556
+ self._execute_entity_batch,
557
+ work_item.entityRequestV2,
558
+ stub,
559
+ work_item.completionToken
560
+ )
493
561
  elif work_item.HasField("healthPing"):
494
562
  pass
495
563
  else:
@@ -635,22 +703,95 @@ class TaskHubGrpcWorker:
635
703
  f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
636
704
  )
637
705
 
706
+ def _execute_entity_batch(
707
+ self,
708
+ req: Union[pb.EntityBatchRequest, pb.EntityRequest],
709
+ stub: stubs.TaskHubSidecarServiceStub,
710
+ completionToken,
711
+ ):
712
+ if isinstance(req, pb.EntityRequest):
713
+ req, operation_infos = helpers.convert_to_entity_batch_request(req)
714
+
715
+ entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None)
716
+
717
+ instance_id = req.instanceId
718
+
719
+ results: list[pb.OperationResult] = []
720
+ for operation in req.operations:
721
+ start_time = datetime.now(timezone.utc)
722
+ executor = _EntityExecutor(self._registry, self._logger)
723
+ entity_instance_id = EntityInstanceId.parse(instance_id)
724
+ if not entity_instance_id:
725
+ raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.")
726
+
727
+ operation_result = None
728
+
729
+ try:
730
+ entity_result = executor.execute(
731
+ instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value
732
+ )
733
+
734
+ entity_result = ph.get_string_value_or_empty(entity_result)
735
+ operation_result = pb.OperationResult(success=pb.OperationResultSuccess(
736
+ result=entity_result,
737
+ startTimeUtc=new_timestamp(start_time),
738
+ endTimeUtc=new_timestamp(datetime.now(timezone.utc))
739
+ ))
740
+ results.append(operation_result)
741
+
742
+ entity_state.commit()
743
+ except Exception as ex:
744
+ self._logger.exception(ex)
745
+ operation_result = pb.OperationResult(failure=pb.OperationResultFailure(
746
+ failureDetails=ph.new_failure_details(ex),
747
+ startTimeUtc=new_timestamp(start_time),
748
+ endTimeUtc=new_timestamp(datetime.now(timezone.utc))
749
+ ))
750
+ results.append(operation_result)
751
+
752
+ entity_state.rollback()
753
+
754
+ batch_result = pb.EntityBatchResult(
755
+ results=results,
756
+ actions=entity_state.get_operation_actions(),
757
+ entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None,
758
+ failureDetails=None,
759
+ completionToken=completionToken,
760
+ operationInfos=operation_infos,
761
+ )
762
+
763
+ try:
764
+ stub.CompleteEntityTask(batch_result)
765
+ except Exception as ex:
766
+ self._logger.exception(
767
+ f"Failed to deliver entity response for '{entity_instance_id}' of orchestration ID '{instance_id}' to sidecar: {ex}"
768
+ )
769
+
770
+ # TODO: Reset context
771
+
772
+ return batch_result
773
+
638
774
 
639
775
  class _RuntimeOrchestrationContext(task.OrchestrationContext):
640
776
  _generator: Optional[Generator[task.Task, Any, Any]]
641
777
  _previous_task: Optional[task.Task]
642
778
 
643
- def __init__(self, instance_id: str, registry: _Registry):
779
+ def __init__(self, instance_id: str, registry: _Registry, entity_context: OrchestrationEntityContext):
644
780
  self._generator = None
645
781
  self._is_replaying = True
646
782
  self._is_complete = False
647
783
  self._result = None
648
784
  self._pending_actions: dict[int, pb.OrchestratorAction] = {}
649
785
  self._pending_tasks: dict[int, task.CompletableTask] = {}
786
+ # Maps entity ID to task ID
787
+ self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
788
+ # Maps criticalSectionId to task ID
789
+ self._entity_lock_id_map: dict[str, int] = {}
650
790
  self._sequence_number = 0
651
791
  self._current_utc_datetime = datetime(1000, 1, 1)
652
792
  self._instance_id = instance_id
653
793
  self._registry = registry
794
+ self._entity_context = entity_context
654
795
  self._version: Optional[str] = None
655
796
  self._completion_status: Optional[pb.OrchestrationStatus] = None
656
797
  self._received_events: dict[str, list[Any]] = {}
@@ -701,9 +842,15 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
701
842
  if self._is_complete:
702
843
  return
703
844
 
845
+ # If the user code returned without yielding the entity unlock, do that now
846
+ if self._entity_context.is_inside_critical_section:
847
+ self._exit_critical_section()
848
+
704
849
  self._is_complete = True
705
850
  self._completion_status = status
706
- self._pending_actions.clear() # Cancel any pending actions
851
+ # This is probably a bug - an orchestrator may complete with some actions remaining that the user still
852
+ # wants to execute - for example, signaling an entity. So we shouldn't clear the pending actions here.
853
+ # self._pending_actions.clear() # Cancel any pending actions
707
854
 
708
855
  self._result = result
709
856
  result_json: Optional[str] = None
@@ -718,8 +865,14 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
718
865
  if self._is_complete:
719
866
  return
720
867
 
868
+ # If the user code crashed inside a critical section, or did not exit it, do that now
869
+ if self._entity_context.is_inside_critical_section:
870
+ self._exit_critical_section()
871
+
721
872
  self._is_complete = True
722
- self._pending_actions.clear() # Cancel any pending actions
873
+ # We also cannot cancel the pending actions in the failure case - if the user code had released an entity
874
+ # lock, we *must* send that action to the sidecar.
875
+ # self._pending_actions.clear() # Cancel any pending actions
723
876
  self._completion_status = pb.ORCHESTRATION_STATUS_FAILED
724
877
 
725
878
  action = ph.new_complete_orchestration_action(
@@ -734,13 +887,20 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
734
887
  if self._is_complete:
735
888
  return
736
889
 
890
+ # If the user code called continue_as_new while holding an entity lock, unlock it now
891
+ if self._entity_context.is_inside_critical_section:
892
+ self._exit_critical_section()
893
+
737
894
  self._is_complete = True
738
- self._pending_actions.clear() # Cancel any pending actions
895
+ # We also cannot cancel the pending actions in the continue as new case - if the user code had released an
896
+ # entity lock, we *must* send that action to the sidecar.
897
+ # self._pending_actions.clear() # Cancel any pending actions
739
898
  self._completion_status = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
740
899
  self._new_input = new_input
741
900
  self._save_events = save_events
742
901
 
743
902
  def get_actions(self) -> list[pb.OrchestratorAction]:
903
+ current_actions = list(self._pending_actions.values())
744
904
  if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
745
905
  # When continuing-as-new, we only return a single completion action.
746
906
  carryover_events: Optional[list[pb.HistoryEvent]] = None
@@ -765,9 +925,9 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
765
925
  failure_details=None,
766
926
  carryover_events=carryover_events,
767
927
  )
768
- return [action]
769
- else:
770
- return list(self._pending_actions.values())
928
+ # We must return the existing tasks as well, to capture entity unlocks
929
+ current_actions.append(action)
930
+ return current_actions
771
931
 
772
932
  def next_sequence_number(self) -> int:
773
933
  self._sequence_number += 1
@@ -833,6 +993,40 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
833
993
  )
834
994
  return self._pending_tasks.get(id, task.CompletableTask())
835
995
 
996
+ def call_entity(
997
+ self,
998
+ entity_id: EntityInstanceId,
999
+ operation: str,
1000
+ input: Optional[TInput] = None,
1001
+ ) -> task.Task:
1002
+ id = self.next_sequence_number()
1003
+
1004
+ self.call_entity_function_helper(
1005
+ id, entity_id, operation, input=input
1006
+ )
1007
+
1008
+ return self._pending_tasks.get(id, task.CompletableTask())
1009
+
1010
+ def signal_entity(
1011
+ self,
1012
+ entity_id: EntityInstanceId,
1013
+ operation: str,
1014
+ input: Optional[TInput] = None
1015
+ ) -> None:
1016
+ id = self.next_sequence_number()
1017
+
1018
+ self.signal_entity_function_helper(
1019
+ id, entity_id, operation, input
1020
+ )
1021
+
1022
+ def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]:
1023
+ id = self.next_sequence_number()
1024
+
1025
+ self.lock_entities_function_helper(
1026
+ id, entities
1027
+ )
1028
+ return self._pending_tasks.get(id, task.CompletableTask())
1029
+
836
1030
  def call_sub_orchestrator(
837
1031
  self,
838
1032
  orchestrator: task.Orchestrator[TInput, TOutput],
@@ -909,6 +1103,80 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
909
1103
  )
910
1104
  self._pending_tasks[id] = fn_task
911
1105
 
1106
+ def call_entity_function_helper(
1107
+ self,
1108
+ id: Optional[int],
1109
+ entity_id: EntityInstanceId,
1110
+ operation: str,
1111
+ *,
1112
+ input: Optional[TInput] = None,
1113
+ ):
1114
+ if id is None:
1115
+ id = self.next_sequence_number()
1116
+
1117
+ transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False)
1118
+ if not transition_valid:
1119
+ raise RuntimeError(error_message)
1120
+
1121
+ encoded_input = shared.to_json(input) if input is not None else None
1122
+ action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input)
1123
+ self._pending_actions[id] = action
1124
+
1125
+ fn_task = task.CompletableTask()
1126
+ self._pending_tasks[id] = fn_task
1127
+
1128
+ def signal_entity_function_helper(
1129
+ self,
1130
+ id: Optional[int],
1131
+ entity_id: EntityInstanceId,
1132
+ operation: str,
1133
+ input: Optional[TInput]
1134
+ ) -> None:
1135
+ if id is None:
1136
+ id = self.next_sequence_number()
1137
+
1138
+ transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, True)
1139
+
1140
+ if not transition_valid:
1141
+ raise RuntimeError(error_message)
1142
+
1143
+ encoded_input = shared.to_json(input) if input is not None else None
1144
+
1145
+ action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input)
1146
+ self._pending_actions[id] = action
1147
+
1148
+ def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None:
1149
+ if id is None:
1150
+ id = self.next_sequence_number()
1151
+
1152
+ transition_valid, error_message = self._entity_context.validate_acquire_transition()
1153
+ if not transition_valid:
1154
+ raise RuntimeError(error_message)
1155
+
1156
+ critical_section_id = f"{self.instance_id}:{id:04x}"
1157
+
1158
+ request, target = self._entity_context.emit_acquire_message(critical_section_id, entities)
1159
+
1160
+ if not request or not target:
1161
+ raise RuntimeError("Failed to create entity lock request.")
1162
+
1163
+ action = ph.new_lock_entities_action(id, request)
1164
+ self._pending_actions[id] = action
1165
+
1166
+ fn_task = task.CompletableTask[EntityLock]()
1167
+ self._pending_tasks[id] = fn_task
1168
+
1169
+ def _exit_critical_section(self) -> None:
1170
+ if not self._entity_context.is_inside_critical_section:
1171
+ # Possible if the user calls continue_as_new inside the lock - in the success case, we will call
1172
+ # _exit_critical_section both from the EntityLock and the continue_as_new logic. We must keep both calls in
1173
+ # case the user code crashes after calling continue_as_new but before the EntityLock object is exited.
1174
+ return
1175
+ for entity_unlock_message in self._entity_context.emit_lock_release_messages():
1176
+ task_id = self.next_sequence_number()
1177
+ action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
1178
+ self._pending_actions[task_id] = action
1179
+
912
1180
  def wait_for_external_event(self, name: str) -> task.Task:
913
1181
  # Check to see if this event has already been received, in which case we
914
1182
  # can return it immediately. Otherwise, record out intent to receive an
@@ -957,6 +1225,7 @@ class _OrchestrationExecutor:
957
1225
  self._logger = logger
958
1226
  self._is_suspended = False
959
1227
  self._suspended_events: list[pb.HistoryEvent] = []
1228
+ self._entity_state: Optional[OrchestrationEntityContext] = None
960
1229
 
961
1230
  def execute(
962
1231
  self,
@@ -964,13 +1233,14 @@ class _OrchestrationExecutor:
964
1233
  old_events: Sequence[pb.HistoryEvent],
965
1234
  new_events: Sequence[pb.HistoryEvent],
966
1235
  ) -> ExecutionResults:
1236
+ self._entity_state = OrchestrationEntityContext(instance_id)
1237
+
967
1238
  if not new_events:
968
1239
  raise task.OrchestrationStateError(
969
1240
  "The new history event list must have at least one event in it."
970
1241
  )
971
1242
 
972
- ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
973
- version_failure = None
1243
+ ctx = _RuntimeOrchestrationContext(instance_id, self._registry, self._entity_state)
974
1244
  try:
975
1245
  # Rebuild local state by replaying old history into the orchestrator function
976
1246
  self._logger.debug(
@@ -980,23 +1250,6 @@ class _OrchestrationExecutor:
980
1250
  for old_event in old_events:
981
1251
  self.process_event(ctx, old_event)
982
1252
 
983
- # Process versioning if applicable
984
- execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
985
- # We only check versioning if there are executionStarted events - otherwise, on the first replay when
986
- # ctx.version will be Null, we may invalidate orchestrations early depending on the versioning strategy.
987
- if self._registry.versioning and len(execution_started_events) > 0:
988
- version_failure = self.evaluate_orchestration_versioning(
989
- self._registry.versioning,
990
- ctx.version
991
- )
992
- if version_failure:
993
- self._logger.warning(
994
- f"Orchestration version did not meet worker versioning requirements. "
995
- f"Error action = '{self._registry.versioning.failure_strategy}'. "
996
- f"Version error = '{version_failure}'"
997
- )
998
- raise pe.VersionFailureException
999
-
1000
1253
  # Get new actions by executing newly received events into the orchestrator function
1001
1254
  if self._logger.level <= logging.DEBUG:
1002
1255
  summary = _get_new_event_summary(new_events)
@@ -1009,8 +1262,8 @@ class _OrchestrationExecutor:
1009
1262
 
1010
1263
  except pe.VersionFailureException as ex:
1011
1264
  if self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
1012
- if version_failure:
1013
- ctx.set_failed(version_failure)
1265
+ if ex.error_details:
1266
+ ctx.set_failed(ex.error_details)
1014
1267
  else:
1015
1268
  ctx.set_failed(ex)
1016
1269
  elif self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
@@ -1068,6 +1321,19 @@ class _OrchestrationExecutor:
1068
1321
  if event.executionStarted.version:
1069
1322
  ctx._version = event.executionStarted.version.value
1070
1323
 
1324
+ if self._registry.versioning:
1325
+ version_failure = self.evaluate_orchestration_versioning(
1326
+ self._registry.versioning,
1327
+ ctx.version
1328
+ )
1329
+ if version_failure:
1330
+ self._logger.warning(
1331
+ f"Orchestration version did not meet worker versioning requirements. "
1332
+ f"Error action = '{self._registry.versioning.failure_strategy}'. "
1333
+ f"Version error = '{version_failure}'"
1334
+ )
1335
+ raise pe.VersionFailureException(version_failure)
1336
+
1071
1337
  # deserialize the input, if any
1072
1338
  input = None
1073
1339
  if (
@@ -1321,6 +1587,108 @@ class _OrchestrationExecutor:
1321
1587
  pb.ORCHESTRATION_STATUS_TERMINATED,
1322
1588
  is_result_encoded=True,
1323
1589
  )
1590
+ elif event.HasField("entityOperationCalled"):
1591
+ # This history event confirms that the entity operation was successfully scheduled.
1592
+ # Remove the entityOperationCalled event from the pending action list so we don't schedule it again
1593
+ entity_call_id = event.eventId
1594
+ action = ctx._pending_actions.pop(entity_call_id, None)
1595
+ entity_task = ctx._pending_tasks.get(entity_call_id, None)
1596
+ if not action:
1597
+ raise _get_non_determinism_error(
1598
+ entity_call_id, task.get_name(ctx.call_entity)
1599
+ )
1600
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationCalled"):
1601
+ expected_method_name = task.get_name(ctx.call_entity)
1602
+ raise _get_wrong_action_type_error(
1603
+ entity_call_id, expected_method_name, action
1604
+ )
1605
+ entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
1606
+ if not entity_id:
1607
+ raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
1608
+ ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
1609
+ elif event.HasField("entityOperationSignaled"):
1610
+ # This history event confirms that the entity signal was successfully scheduled.
1611
+ # Remove the entityOperationSignaled event from the pending action list so we don't schedule it
1612
+ entity_signal_id = event.eventId
1613
+ action = ctx._pending_actions.pop(entity_signal_id, None)
1614
+ if not action:
1615
+ raise _get_non_determinism_error(
1616
+ entity_signal_id, task.get_name(ctx.signal_entity)
1617
+ )
1618
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationSignaled"):
1619
+ expected_method_name = task.get_name(ctx.signal_entity)
1620
+ raise _get_wrong_action_type_error(
1621
+ entity_signal_id, expected_method_name, action
1622
+ )
1623
+ elif event.HasField("entityLockRequested"):
1624
+ section_id = event.entityLockRequested.criticalSectionId
1625
+ task_id = event.eventId
1626
+ action = ctx._pending_actions.pop(task_id, None)
1627
+ entity_task = ctx._pending_tasks.get(task_id, None)
1628
+ if not action:
1629
+ raise _get_non_determinism_error(
1630
+ task_id, task.get_name(ctx.lock_entities)
1631
+ )
1632
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityLockRequested"):
1633
+ expected_method_name = task.get_name(ctx.lock_entities)
1634
+ raise _get_wrong_action_type_error(
1635
+ task_id, expected_method_name, action
1636
+ )
1637
+ ctx._entity_lock_id_map[section_id] = task_id
1638
+ elif event.HasField("entityUnlockSent"):
1639
+ # Remove the unlock tasks as they have already been processed
1640
+ tasks_to_remove = []
1641
+ for task_id, action in ctx._pending_actions.items():
1642
+ if action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityUnlockSent"):
1643
+ if action.sendEntityMessage.entityUnlockSent.criticalSectionId == event.entityUnlockSent.criticalSectionId:
1644
+ tasks_to_remove.append(task_id)
1645
+ for task_to_remove in tasks_to_remove:
1646
+ ctx._pending_actions.pop(task_to_remove, None)
1647
+ elif event.HasField("entityLockGranted"):
1648
+ section_id = event.entityLockGranted.criticalSectionId
1649
+ task_id = ctx._entity_lock_id_map.pop(section_id, None)
1650
+ if not task_id:
1651
+ # TODO: Should this be an error? When would it ever happen?
1652
+ if not ctx.is_replaying:
1653
+ self._logger.warning(
1654
+ f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
1655
+ )
1656
+ return
1657
+ entity_task = ctx._pending_tasks.pop(task_id, None)
1658
+ if not entity_task:
1659
+ if not ctx.is_replaying:
1660
+ self._logger.warning(
1661
+ f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
1662
+ )
1663
+ return
1664
+ ctx._entity_context.complete_acquire(section_id)
1665
+ entity_task.complete(EntityLock(ctx))
1666
+ ctx.resume()
1667
+ elif event.HasField("entityOperationCompleted"):
1668
+ request_id = event.entityOperationCompleted.requestId
1669
+ entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None))
1670
+ if not entity_id:
1671
+ raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
1672
+ if not task_id:
1673
+ raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
1674
+ entity_task = ctx._pending_tasks.pop(task_id, None)
1675
+ if not entity_task:
1676
+ if not ctx.is_replaying:
1677
+ self._logger.warning(
1678
+ f"{ctx.instance_id}: Ignoring unexpected entityOperationCompleted event with request ID = {request_id}."
1679
+ )
1680
+ return
1681
+ result = None
1682
+ if not ph.is_empty(event.entityOperationCompleted.output):
1683
+ result = shared.from_json(event.entityOperationCompleted.output.value)
1684
+ ctx._entity_context.recover_lock_after_call(entity_id)
1685
+ entity_task.complete(result)
1686
+ ctx.resume()
1687
+ elif event.HasField("entityOperationFailed"):
1688
+ if not ctx.is_replaying:
1689
+ self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
1690
+ self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
1691
+ pass
1324
1692
  else:
1325
1693
  eventType = event.WhichOneof("eventType")
1326
1694
  raise task.OrchestrationStateError(
@@ -1411,6 +1779,60 @@ class _ActivityExecutor:
1411
1779
  return encoded_output
1412
1780
 
1413
1781
 
1782
+ class _EntityExecutor:
1783
+ def __init__(self, registry: _Registry, logger: logging.Logger):
1784
+ self._registry = registry
1785
+ self._logger = logger
1786
+
1787
+ def execute(
1788
+ self,
1789
+ orchestration_id: str,
1790
+ entity_id: EntityInstanceId,
1791
+ operation: str,
1792
+ state: StateShim,
1793
+ encoded_input: Optional[str],
1794
+ ) -> Optional[str]:
1795
+ """Executes an entity function and returns the serialized result, if any."""
1796
+ self._logger.debug(
1797
+ f"{orchestration_id}: Executing entity '{entity_id}'..."
1798
+ )
1799
+ fn = self._registry.get_entity(entity_id.entity)
1800
+ if not fn:
1801
+ raise EntityNotRegisteredError(
1802
+ f"Entity function named '{entity_id.entity}' was not registered!"
1803
+ )
1804
+
1805
+ entity_input = shared.from_json(encoded_input) if encoded_input else None
1806
+ ctx = EntityContext(orchestration_id, operation, state, entity_id)
1807
+
1808
+ if isinstance(fn, type) and issubclass(fn, DurableEntity):
1809
+ if self._registry.entity_instances.get(str(entity_id), None):
1810
+ entity_instance = self._registry.entity_instances[str(entity_id)]
1811
+ else:
1812
+ entity_instance = fn()
1813
+ self._registry.entity_instances[str(entity_id)] = entity_instance
1814
+ if not hasattr(entity_instance, operation):
1815
+ raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
1816
+ method = getattr(entity_instance, operation)
1817
+ if not callable(method):
1818
+ raise TypeError(f"Entity operation '{operation}' is not callable")
1819
+ # Execute the entity method
1820
+ entity_instance._initialize_entity_context(ctx)
1821
+ entity_output = method(entity_input)
1822
+ else:
1823
+ # Execute the entity function
1824
+ entity_output = fn(ctx, entity_input)
1825
+
1826
+ encoded_output = (
1827
+ shared.to_json(entity_output) if entity_output is not None else None
1828
+ )
1829
+ chars = len(encoded_output) if encoded_output else 0
1830
+ self._logger.debug(
1831
+ f"{orchestration_id}: Entity '{entity_id}' completed successfully with {chars} char(s) of encoded output."
1832
+ )
1833
+ return encoded_output
1834
+
1835
+
1414
1836
  def _get_non_determinism_error(
1415
1837
  task_id: int, action_name: str
1416
1838
  ) -> task.NonDeterminismError:
@@ -1502,13 +1924,16 @@ class _AsyncWorkerManager:
1502
1924
  self.concurrency_options = concurrency_options
1503
1925
  self.activity_semaphore = None
1504
1926
  self.orchestration_semaphore = None
1927
+ self.entity_semaphore = None
1505
1928
  # Don't create queues here - defer until we have an event loop
1506
1929
  self.activity_queue: Optional[asyncio.Queue] = None
1507
1930
  self.orchestration_queue: Optional[asyncio.Queue] = None
1931
+ self.entity_batch_queue: Optional[asyncio.Queue] = None
1508
1932
  self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None
1509
1933
  # Store work items when no event loop is available
1510
1934
  self._pending_activity_work: list = []
1511
1935
  self._pending_orchestration_work: list = []
1936
+ self._pending_entity_batch_work: list = []
1512
1937
  self.thread_pool = ThreadPoolExecutor(
1513
1938
  max_workers=concurrency_options.maximum_thread_pool_workers,
1514
1939
  thread_name_prefix="DurableTask",
@@ -1525,7 +1950,7 @@ class _AsyncWorkerManager:
1525
1950
 
1526
1951
  # Check if queues are already properly set up for current loop
1527
1952
  if self._queue_event_loop is current_loop:
1528
- if self.activity_queue is not None and self.orchestration_queue is not None:
1953
+ if self.activity_queue is not None and self.orchestration_queue is not None and self.entity_batch_queue is not None:
1529
1954
  # Queues are already bound to the current loop and exist
1530
1955
  return
1531
1956
 
@@ -1533,6 +1958,7 @@ class _AsyncWorkerManager:
1533
1958
  # First, preserve any existing work items
1534
1959
  existing_activity_items = []
1535
1960
  existing_orchestration_items = []
1961
+ existing_entity_batch_items = []
1536
1962
 
1537
1963
  if self.activity_queue is not None:
1538
1964
  try:
@@ -1550,9 +1976,19 @@ class _AsyncWorkerManager:
1550
1976
  except Exception:
1551
1977
  pass
1552
1978
 
1979
+ if self.entity_batch_queue is not None:
1980
+ try:
1981
+ while not self.entity_batch_queue.empty():
1982
+ existing_entity_batch_items.append(
1983
+ self.entity_batch_queue.get_nowait()
1984
+ )
1985
+ except Exception:
1986
+ pass
1987
+
1553
1988
  # Create fresh queues for the current event loop
1554
1989
  self.activity_queue = asyncio.Queue()
1555
1990
  self.orchestration_queue = asyncio.Queue()
1991
+ self.entity_batch_queue = asyncio.Queue()
1556
1992
  self._queue_event_loop = current_loop
1557
1993
 
1558
1994
  # Restore the work items to the new queues
@@ -1560,16 +1996,21 @@ class _AsyncWorkerManager:
1560
1996
  self.activity_queue.put_nowait(item)
1561
1997
  for item in existing_orchestration_items:
1562
1998
  self.orchestration_queue.put_nowait(item)
1999
+ for item in existing_entity_batch_items:
2000
+ self.entity_batch_queue.put_nowait(item)
1563
2001
 
1564
2002
  # Move pending work items to the queues
1565
2003
  for item in self._pending_activity_work:
1566
2004
  self.activity_queue.put_nowait(item)
1567
2005
  for item in self._pending_orchestration_work:
1568
2006
  self.orchestration_queue.put_nowait(item)
2007
+ for item in self._pending_entity_batch_work:
2008
+ self.entity_batch_queue.put_nowait(item)
1569
2009
 
1570
2010
  # Clear the pending work lists
1571
2011
  self._pending_activity_work.clear()
1572
2012
  self._pending_orchestration_work.clear()
2013
+ self._pending_entity_batch_work.clear()
1573
2014
 
1574
2015
  async def run(self):
1575
2016
  # Reset shutdown flag in case this manager is being reused
@@ -1585,14 +2026,21 @@ class _AsyncWorkerManager:
1585
2026
  self.orchestration_semaphore = asyncio.Semaphore(
1586
2027
  self.concurrency_options.maximum_concurrent_orchestration_work_items
1587
2028
  )
2029
+ self.entity_semaphore = asyncio.Semaphore(
2030
+ self.concurrency_options.maximum_concurrent_entity_work_items
2031
+ )
1588
2032
 
1589
2033
  # Start background consumers for each work type
1590
- if self.activity_queue is not None and self.orchestration_queue is not None:
2034
+ if self.activity_queue is not None and self.orchestration_queue is not None \
2035
+ and self.entity_batch_queue is not None:
1591
2036
  await asyncio.gather(
1592
2037
  self._consume_queue(self.activity_queue, self.activity_semaphore),
1593
2038
  self._consume_queue(
1594
2039
  self.orchestration_queue, self.orchestration_semaphore
1595
2040
  ),
2041
+ self._consume_queue(
2042
+ self.entity_batch_queue, self.entity_semaphore
2043
+ )
1596
2044
  )
1597
2045
 
1598
2046
  async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
@@ -1662,6 +2110,15 @@ class _AsyncWorkerManager:
1662
2110
  # No event loop running, store in pending list
1663
2111
  self._pending_orchestration_work.append(work_item)
1664
2112
 
2113
+ def submit_entity_batch(self, func, *args, **kwargs):
2114
+ work_item = (func, args, kwargs)
2115
+ self._ensure_queues_for_current_loop()
2116
+ if self.entity_batch_queue is not None:
2117
+ self.entity_batch_queue.put_nowait(work_item)
2118
+ else:
2119
+ # No event loop running, store in pending list
2120
+ self._pending_entity_batch_work.append(work_item)
2121
+
1665
2122
  def shutdown(self):
1666
2123
  self._shutdown = True
1667
2124
  self.thread_pool.shutdown(wait=True)