durabletask 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

durabletask/worker.py CHANGED
@@ -3,11 +3,12 @@

 import asyncio
 import inspect
+import json
 import logging
 import os
 import random
 from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from threading import Event, Thread
 from types import GeneratorType
 from enum import Enum
@@ -17,6 +18,11 @@ from packaging.version import InvalidVersion, parse
 import grpc
 from google.protobuf import empty_pb2

+from durabletask.internal import helpers
+from durabletask.internal.entity_state_shim import StateShim
+from durabletask.internal.helpers import new_timestamp
+from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
+from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
 import durabletask.internal.helpers as ph
 import durabletask.internal.exceptions as pe
 import durabletask.internal.orchestrator_service_pb2 as pb
@@ -40,6 +46,7 @@ class ConcurrencyOptions:
         self,
         maximum_concurrent_activity_work_items: Optional[int] = None,
         maximum_concurrent_orchestration_work_items: Optional[int] = None,
+        maximum_concurrent_entity_work_items: Optional[int] = None,
         maximum_thread_pool_workers: Optional[int] = None,
     ):
         """Initialize concurrency options.
@@ -68,6 +75,12 @@ class ConcurrencyOptions:
             else default_concurrency
         )

+        self.maximum_concurrent_entity_work_items = (
+            maximum_concurrent_entity_work_items
+            if maximum_concurrent_entity_work_items is not None
+            else default_concurrency
+        )
+
         self.maximum_thread_pool_workers = (
             maximum_thread_pool_workers
             if maximum_thread_pool_workers is not None
@@ -124,11 +137,15 @@ class VersioningOptions:
 class _Registry:
     orchestrators: dict[str, task.Orchestrator]
     activities: dict[str, task.Activity]
+    entities: dict[str, task.Entity]
+    entity_instances: dict[str, DurableEntity]
     versioning: Optional[VersioningOptions] = None

     def __init__(self):
         self.orchestrators = {}
         self.activities = {}
+        self.entities = {}
+        self.entity_instances = {}

     def add_orchestrator(self, fn: task.Orchestrator) -> str:
         if fn is None:
@@ -168,6 +185,29 @@ class _Registry:
     def get_activity(self, name: str) -> Optional[task.Activity]:
         return self.activities.get(name)

+    def add_entity(self, fn: task.Entity) -> str:
+        if fn is None:
+            raise ValueError("An entity function argument is required.")
+
+        if isinstance(fn, type) and issubclass(fn, DurableEntity):
+            name = fn.__name__
+            self.add_named_entity(name, fn)
+        else:
+            name = task.get_name(fn)
+            self.add_named_entity(name, fn)
+        return name
+
+    def add_named_entity(self, name: str, fn: task.Entity) -> None:
+        if not name:
+            raise ValueError("A non-empty entity name is required.")
+        if name in self.entities:
+            raise ValueError(f"A '{name}' entity already exists.")
+
+        self.entities[name] = fn
+
+    def get_entity(self, name: str) -> Optional[task.Entity]:
+        return self.entities.get(name)
+

 class OrchestratorNotRegisteredError(ValueError):
     """Raised when attempting to start an orchestration that is not registered"""
@@ -181,6 +221,12 @@ class ActivityNotRegisteredError(ValueError):
     pass


+class EntityNotRegisteredError(ValueError):
+    """Raised when attempting to call an entity that is not registered"""
+
+    pass
+
+
 class TaskHubGrpcWorker:
     """A gRPC-based worker for processing durable task orchestrations and activities.

@@ -329,6 +375,14 @@ class TaskHubGrpcWorker:
             )
         return self._registry.add_activity(fn)

+    def add_entity(self, fn: task.Entity) -> str:
+        """Registers an entity function with the worker."""
+        if self._is_running:
+            raise RuntimeError(
+                "Entities cannot be added while the worker is running."
+            )
+        return self._registry.add_entity(fn)
+
     def use_versioning(self, version: VersioningOptions) -> None:
         """Initializes versioning options for sub-orchestrators and activities."""
         if self._is_running:
@@ -490,6 +544,20 @@ class TaskHubGrpcWorker:
                         stub,
                         work_item.completionToken,
                     )
+                elif work_item.HasField("entityRequest"):
+                    self._async_worker_manager.submit_entity_batch(
+                        self._execute_entity_batch,
+                        work_item.entityRequest,
+                        stub,
+                        work_item.completionToken,
+                    )
+                elif work_item.HasField("entityRequestV2"):
+                    self._async_worker_manager.submit_entity_batch(
+                        self._execute_entity_batch,
+                        work_item.entityRequestV2,
+                        stub,
+                        work_item.completionToken
+                    )
                 elif work_item.HasField("healthPing"):
                     pass
                 else:
@@ -635,22 +703,95 @@ class TaskHubGrpcWorker:
                 f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
             )

+    def _execute_entity_batch(
+        self,
+        req: Union[pb.EntityBatchRequest, pb.EntityRequest],
+        stub: stubs.TaskHubSidecarServiceStub,
+        completionToken,
+    ):
+        if isinstance(req, pb.EntityRequest):
+            req, operation_infos = helpers.convert_to_entity_batch_request(req)
+
+        entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None)
+
+        instance_id = req.instanceId
+
+        results: list[pb.OperationResult] = []
+        for operation in req.operations:
+            start_time = datetime.now(timezone.utc)
+            executor = _EntityExecutor(self._registry, self._logger)
+            entity_instance_id = EntityInstanceId.parse(instance_id)
+            if not entity_instance_id:
+                raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.")
+
+            operation_result = None
+
+            try:
+                entity_result = executor.execute(
+                    instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value
+                )
+
+                entity_result = ph.get_string_value_or_empty(entity_result)
+                operation_result = pb.OperationResult(success=pb.OperationResultSuccess(
+                    result=entity_result,
+                    startTimeUtc=new_timestamp(start_time),
+                    endTimeUtc=new_timestamp(datetime.now(timezone.utc))
+                ))
+                results.append(operation_result)
+
+                entity_state.commit()
+            except Exception as ex:
+                self._logger.exception(ex)
+                operation_result = pb.OperationResult(failure=pb.OperationResultFailure(
+                    failureDetails=ph.new_failure_details(ex),
+                    startTimeUtc=new_timestamp(start_time),
+                    endTimeUtc=new_timestamp(datetime.now(timezone.utc))
+                ))
+                results.append(operation_result)
+
+                entity_state.rollback()
+
+        batch_result = pb.EntityBatchResult(
+            results=results,
+            actions=entity_state.get_operation_actions(),
+            entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None,
+            failureDetails=None,
+            completionToken=completionToken,
+            operationInfos=operation_infos,
+        )
+
+        try:
+            stub.CompleteEntityTask(batch_result)
+        except Exception as ex:
+            self._logger.exception(
+                f"Failed to deliver entity response for '{entity_instance_id}' of orchestration ID '{instance_id}' to sidecar: {ex}"
+            )
+
+        # TODO: Reset context
+
+        return batch_result
+

 class _RuntimeOrchestrationContext(task.OrchestrationContext):
     _generator: Optional[Generator[task.Task, Any, Any]]
     _previous_task: Optional[task.Task]

-    def __init__(self, instance_id: str, registry: _Registry):
+    def __init__(self, instance_id: str, registry: _Registry, entity_context: OrchestrationEntityContext):
         self._generator = None
         self._is_replaying = True
         self._is_complete = False
         self._result = None
         self._pending_actions: dict[int, pb.OrchestratorAction] = {}
         self._pending_tasks: dict[int, task.CompletableTask] = {}
+        # Maps entity ID to task ID
+        self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
+        # Maps criticalSectionId to task ID
+        self._entity_lock_id_map: dict[str, int] = {}
         self._sequence_number = 0
         self._current_utc_datetime = datetime(1000, 1, 1)
         self._instance_id = instance_id
         self._registry = registry
+        self._entity_context = entity_context
         self._version: Optional[str] = None
         self._completion_status: Optional[pb.OrchestrationStatus] = None
         self._received_events: dict[str, list[Any]] = {}
@@ -701,9 +842,15 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
         if self._is_complete:
             return

+        # If the user code returned without yielding the entity unlock, do that now
+        if self._entity_context.is_inside_critical_section:
+            self._exit_critical_section()
+
         self._is_complete = True
         self._completion_status = status
-        self._pending_actions.clear() # Cancel any pending actions
+        # This is probably a bug - an orchestrator may complete with some actions remaining that the user still
+        # wants to execute - for example, signaling an entity. So we shouldn't clear the pending actions here.
+        # self._pending_actions.clear() # Cancel any pending actions

         self._result = result
         result_json: Optional[str] = None
@@ -718,8 +865,14 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
         if self._is_complete:
             return

+        # If the user code crashed inside a critical section, or did not exit it, do that now
+        if self._entity_context.is_inside_critical_section:
+            self._exit_critical_section()
+
         self._is_complete = True
-        self._pending_actions.clear() # Cancel any pending actions
+        # We also cannot cancel the pending actions in the failure case - if the user code had released an entity
+        # lock, we *must* send that action to the sidecar.
+        # self._pending_actions.clear() # Cancel any pending actions
         self._completion_status = pb.ORCHESTRATION_STATUS_FAILED

         action = ph.new_complete_orchestration_action(
@@ -734,13 +887,20 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
         if self._is_complete:
             return

+        # If the user code called continue_as_new while holding an entity lock, unlock it now
+        if self._entity_context.is_inside_critical_section:
+            self._exit_critical_section()
+
         self._is_complete = True
-        self._pending_actions.clear() # Cancel any pending actions
+        # We also cannot cancel the pending actions in the continue as new case - if the user code had released an
+        # entity lock, we *must* send that action to the sidecar.
+        # self._pending_actions.clear() # Cancel any pending actions
         self._completion_status = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
         self._new_input = new_input
         self._save_events = save_events

     def get_actions(self) -> list[pb.OrchestratorAction]:
+        current_actions = list(self._pending_actions.values())
         if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
             # When continuing-as-new, we only return a single completion action.
             carryover_events: Optional[list[pb.HistoryEvent]] = None
@@ -765,9 +925,9 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
                 failure_details=None,
                 carryover_events=carryover_events,
             )
-            return [action]
-        else:
-            return list(self._pending_actions.values())
+            # We must return the existing tasks as well, to capture entity unlocks
+            current_actions.append(action)
+        return current_actions

     def next_sequence_number(self) -> int:
         self._sequence_number += 1
@@ -833,6 +993,40 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
         )
         return self._pending_tasks.get(id, task.CompletableTask())

+    def call_entity(
+        self,
+        entity_id: EntityInstanceId,
+        operation: str,
+        input: Optional[TInput] = None,
+    ) -> task.Task:
+        id = self.next_sequence_number()
+
+        self.call_entity_function_helper(
+            id, entity_id, operation, input=input
+        )
+
+        return self._pending_tasks.get(id, task.CompletableTask())
+
+    def signal_entity(
+        self,
+        entity_id: EntityInstanceId,
+        operation: str,
+        input: Optional[TInput] = None
+    ) -> None:
+        id = self.next_sequence_number()
+
+        self.signal_entity_function_helper(
+            id, entity_id, operation, input
+        )
+
+    def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]:
+        id = self.next_sequence_number()
+
+        self.lock_entities_function_helper(
+            id, entities
+        )
+        return self._pending_tasks.get(id, task.CompletableTask())
+
     def call_sub_orchestrator(
         self,
         orchestrator: task.Orchestrator[TInput, TOutput],
@@ -909,6 +1103,80 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
         )
         self._pending_tasks[id] = fn_task

+    def call_entity_function_helper(
+        self,
+        id: Optional[int],
+        entity_id: EntityInstanceId,
+        operation: str,
+        *,
+        input: Optional[TInput] = None,
+    ):
+        if id is None:
+            id = self.next_sequence_number()
+
+        transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False)
+        if not transition_valid:
+            raise RuntimeError(error_message)
+
+        encoded_input = shared.to_json(input) if input is not None else None
+        action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input)
+        self._pending_actions[id] = action
+
+        fn_task = task.CompletableTask()
+        self._pending_tasks[id] = fn_task
+
+    def signal_entity_function_helper(
+        self,
+        id: Optional[int],
+        entity_id: EntityInstanceId,
+        operation: str,
+        input: Optional[TInput]
+    ) -> None:
+        if id is None:
+            id = self.next_sequence_number()
+
+        transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, True)
+
+        if not transition_valid:
+            raise RuntimeError(error_message)
+
+        encoded_input = shared.to_json(input) if input is not None else None
+
+        action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input)
+        self._pending_actions[id] = action
+
+    def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None:
+        if id is None:
+            id = self.next_sequence_number()
+
+        transition_valid, error_message = self._entity_context.validate_acquire_transition()
+        if not transition_valid:
+            raise RuntimeError(error_message)
+
+        critical_section_id = f"{self.instance_id}:{id:04x}"
+
+        request, target = self._entity_context.emit_acquire_message(critical_section_id, entities)
+
+        if not request or not target:
+            raise RuntimeError("Failed to create entity lock request.")
+
+        action = ph.new_lock_entities_action(id, request)
+        self._pending_actions[id] = action
+
+        fn_task = task.CompletableTask[EntityLock]()
+        self._pending_tasks[id] = fn_task
+
+    def _exit_critical_section(self) -> None:
+        if not self._entity_context.is_inside_critical_section:
+            # Possible if the user calls continue_as_new inside the lock - in the success case, we will call
+            # _exit_critical_section both from the EntityLock and the continue_as_new logic. We must keep both calls in
+            # case the user code crashes after calling continue_as_new but before the EntityLock object is exited.
+            return
+        for entity_unlock_message in self._entity_context.emit_lock_release_messages():
+            task_id = self.next_sequence_number()
+            action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
+            self._pending_actions[task_id] = action
+
     def wait_for_external_event(self, name: str) -> task.Task:
         # Check to see if this event has already been received, in which case we
         # can return it immediately. Otherwise, record out intent to receive an
@@ -957,6 +1225,7 @@ class _OrchestrationExecutor:
         self._logger = logger
         self._is_suspended = False
         self._suspended_events: list[pb.HistoryEvent] = []
+        self._entity_state: Optional[OrchestrationEntityContext] = None

     def execute(
         self,
@@ -964,12 +1233,14 @@ class _OrchestrationExecutor:
         old_events: Sequence[pb.HistoryEvent],
         new_events: Sequence[pb.HistoryEvent],
     ) -> ExecutionResults:
+        self._entity_state = OrchestrationEntityContext(instance_id)
+
         if not new_events:
             raise task.OrchestrationStateError(
                 "The new history event list must have at least one event in it."
            )

-        ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
+        ctx = _RuntimeOrchestrationContext(instance_id, self._registry, self._entity_state)
         try:
             # Rebuild local state by replaying old history into the orchestrator function
             self._logger.debug(
@@ -1316,6 +1587,108 @@ class _OrchestrationExecutor:
                     pb.ORCHESTRATION_STATUS_TERMINATED,
                     is_result_encoded=True,
                 )
+            elif event.HasField("entityOperationCalled"):
+                # This history event confirms that the entity operation was successfully scheduled.
+                # Remove the entityOperationCalled event from the pending action list so we don't schedule it again
+                entity_call_id = event.eventId
+                action = ctx._pending_actions.pop(entity_call_id, None)
+                entity_task = ctx._pending_tasks.get(entity_call_id, None)
+                if not action:
+                    raise _get_non_determinism_error(
+                        entity_call_id, task.get_name(ctx.call_entity)
+                    )
+                elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationCalled"):
+                    expected_method_name = task.get_name(ctx.call_entity)
+                    raise _get_wrong_action_type_error(
+                        entity_call_id, expected_method_name, action
+                    )
+                entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
+                if not entity_id:
+                    raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
+                ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
+            elif event.HasField("entityOperationSignaled"):
+                # This history event confirms that the entity signal was successfully scheduled.
+                # Remove the entityOperationSignaled event from the pending action list so we don't schedule it
+                entity_signal_id = event.eventId
+                action = ctx._pending_actions.pop(entity_signal_id, None)
+                if not action:
+                    raise _get_non_determinism_error(
+                        entity_signal_id, task.get_name(ctx.signal_entity)
+                    )
+                elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationSignaled"):
+                    expected_method_name = task.get_name(ctx.signal_entity)
+                    raise _get_wrong_action_type_error(
+                        entity_signal_id, expected_method_name, action
+                    )
+            elif event.HasField("entityLockRequested"):
+                section_id = event.entityLockRequested.criticalSectionId
+                task_id = event.eventId
+                action = ctx._pending_actions.pop(task_id, None)
+                entity_task = ctx._pending_tasks.get(task_id, None)
+                if not action:
+                    raise _get_non_determinism_error(
+                        task_id, task.get_name(ctx.lock_entities)
+                    )
+                elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityLockRequested"):
+                    expected_method_name = task.get_name(ctx.lock_entities)
+                    raise _get_wrong_action_type_error(
+                        task_id, expected_method_name, action
+                    )
+                ctx._entity_lock_id_map[section_id] = task_id
+            elif event.HasField("entityUnlockSent"):
+                # Remove the unlock tasks as they have already been processed
+                tasks_to_remove = []
+                for task_id, action in ctx._pending_actions.items():
+                    if action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityUnlockSent"):
+                        if action.sendEntityMessage.entityUnlockSent.criticalSectionId == event.entityUnlockSent.criticalSectionId:
+                            tasks_to_remove.append(task_id)
+                for task_to_remove in tasks_to_remove:
+                    ctx._pending_actions.pop(task_to_remove, None)
+            elif event.HasField("entityLockGranted"):
+                section_id = event.entityLockGranted.criticalSectionId
+                task_id = ctx._entity_lock_id_map.pop(section_id, None)
+                if not task_id:
+                    # TODO: Should this be an error? When would it ever happen?
+                    if not ctx.is_replaying:
+                        self._logger.warning(
+                            f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
+                        )
+                    return
+                entity_task = ctx._pending_tasks.pop(task_id, None)
+                if not entity_task:
+                    if not ctx.is_replaying:
+                        self._logger.warning(
+                            f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
+                        )
+                    return
+                ctx._entity_context.complete_acquire(section_id)
+                entity_task.complete(EntityLock(ctx))
+                ctx.resume()
+            elif event.HasField("entityOperationCompleted"):
+                request_id = event.entityOperationCompleted.requestId
+                entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None))
+                if not entity_id:
+                    raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
+                if not task_id:
+                    raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
+                entity_task = ctx._pending_tasks.pop(task_id, None)
+                if not entity_task:
+                    if not ctx.is_replaying:
+                        self._logger.warning(
+                            f"{ctx.instance_id}: Ignoring unexpected entityOperationCompleted event with request ID = {request_id}."
+                        )
+                    return
+                result = None
+                if not ph.is_empty(event.entityOperationCompleted.output):
+                    result = shared.from_json(event.entityOperationCompleted.output.value)
+                ctx._entity_context.recover_lock_after_call(entity_id)
+                entity_task.complete(result)
+                ctx.resume()
+            elif event.HasField("entityOperationFailed"):
+                if not ctx.is_replaying:
+                    self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
+                    self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
+                pass
             else:
                 eventType = event.WhichOneof("eventType")
                 raise task.OrchestrationStateError(
@@ -1406,6 +1779,60 @@ class _ActivityExecutor:
         return encoded_output


+class _EntityExecutor:
+    def __init__(self, registry: _Registry, logger: logging.Logger):
+        self._registry = registry
+        self._logger = logger
+
+    def execute(
+        self,
+        orchestration_id: str,
+        entity_id: EntityInstanceId,
+        operation: str,
+        state: StateShim,
+        encoded_input: Optional[str],
+    ) -> Optional[str]:
+        """Executes an entity function and returns the serialized result, if any."""
+        self._logger.debug(
+            f"{orchestration_id}: Executing entity '{entity_id}'..."
+        )
+        fn = self._registry.get_entity(entity_id.entity)
+        if not fn:
+            raise EntityNotRegisteredError(
+                f"Entity function named '{entity_id.entity}' was not registered!"
+            )
+
+        entity_input = shared.from_json(encoded_input) if encoded_input else None
+        ctx = EntityContext(orchestration_id, operation, state, entity_id)
+
+        if isinstance(fn, type) and issubclass(fn, DurableEntity):
+            if self._registry.entity_instances.get(str(entity_id), None):
+                entity_instance = self._registry.entity_instances[str(entity_id)]
+            else:
+                entity_instance = fn()
+                self._registry.entity_instances[str(entity_id)] = entity_instance
+            if not hasattr(entity_instance, operation):
+                raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
+            method = getattr(entity_instance, operation)
+            if not callable(method):
+                raise TypeError(f"Entity operation '{operation}' is not callable")
+            # Execute the entity method
+            entity_instance._initialize_entity_context(ctx)
+            entity_output = method(entity_input)
+        else:
+            # Execute the entity function
+            entity_output = fn(ctx, entity_input)
+
+        encoded_output = (
+            shared.to_json(entity_output) if entity_output is not None else None
+        )
+        chars = len(encoded_output) if encoded_output else 0
+        self._logger.debug(
+            f"{orchestration_id}: Entity '{entity_id}' completed successfully with {chars} char(s) of encoded output."
+        )
+        return encoded_output
+
+
 def _get_non_determinism_error(
     task_id: int, action_name: str
 ) -> task.NonDeterminismError:
@@ -1497,13 +1924,16 @@ class _AsyncWorkerManager:
         self.concurrency_options = concurrency_options
         self.activity_semaphore = None
         self.orchestration_semaphore = None
+        self.entity_semaphore = None
         # Don't create queues here - defer until we have an event loop
         self.activity_queue: Optional[asyncio.Queue] = None
         self.orchestration_queue: Optional[asyncio.Queue] = None
+        self.entity_batch_queue: Optional[asyncio.Queue] = None
        self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None
         # Store work items when no event loop is available
         self._pending_activity_work: list = []
         self._pending_orchestration_work: list = []
+        self._pending_entity_batch_work: list = []
         self.thread_pool = ThreadPoolExecutor(
             max_workers=concurrency_options.maximum_thread_pool_workers,
             thread_name_prefix="DurableTask",
@@ -1520,7 +1950,7 @@ class _AsyncWorkerManager:

         # Check if queues are already properly set up for current loop
         if self._queue_event_loop is current_loop:
-            if self.activity_queue is not None and self.orchestration_queue is not None:
+            if self.activity_queue is not None and self.orchestration_queue is not None and self.entity_batch_queue is not None:
                 # Queues are already bound to the current loop and exist
                 return

@@ -1528,6 +1958,7 @@ class _AsyncWorkerManager:
         # First, preserve any existing work items
         existing_activity_items = []
         existing_orchestration_items = []
+        existing_entity_batch_items = []

         if self.activity_queue is not None:
             try:
@@ -1545,9 +1976,19 @@ class _AsyncWorkerManager:
             except Exception:
                 pass

+        if self.entity_batch_queue is not None:
+            try:
+                while not self.entity_batch_queue.empty():
+                    existing_entity_batch_items.append(
+                        self.entity_batch_queue.get_nowait()
+                    )
+            except Exception:
+                pass
+
         # Create fresh queues for the current event loop
         self.activity_queue = asyncio.Queue()
         self.orchestration_queue = asyncio.Queue()
+        self.entity_batch_queue = asyncio.Queue()
         self._queue_event_loop = current_loop

         # Restore the work items to the new queues
@@ -1555,16 +1996,21 @@ class _AsyncWorkerManager:
             self.activity_queue.put_nowait(item)
         for item in existing_orchestration_items:
             self.orchestration_queue.put_nowait(item)
+        for item in existing_entity_batch_items:
+            self.entity_batch_queue.put_nowait(item)

         # Move pending work items to the queues
         for item in self._pending_activity_work:
             self.activity_queue.put_nowait(item)
         for item in self._pending_orchestration_work:
             self.orchestration_queue.put_nowait(item)
+        for item in self._pending_entity_batch_work:
+            self.entity_batch_queue.put_nowait(item)

         # Clear the pending work lists
         self._pending_activity_work.clear()
         self._pending_orchestration_work.clear()
+        self._pending_entity_batch_work.clear()

     async def run(self):
         # Reset shutdown flag in case this manager is being reused
@@ -1580,14 +2026,21 @@ class _AsyncWorkerManager:
         self.orchestration_semaphore = asyncio.Semaphore(
             self.concurrency_options.maximum_concurrent_orchestration_work_items
         )
+        self.entity_semaphore = asyncio.Semaphore(
+            self.concurrency_options.maximum_concurrent_entity_work_items
+        )

         # Start background consumers for each work type
-        if self.activity_queue is not None and self.orchestration_queue is not None:
+        if self.activity_queue is not None and self.orchestration_queue is not None \
+                and self.entity_batch_queue is not None:
             await asyncio.gather(
                 self._consume_queue(self.activity_queue, self.activity_semaphore),
                 self._consume_queue(
                     self.orchestration_queue, self.orchestration_semaphore
                 ),
+                self._consume_queue(
+                    self.entity_batch_queue, self.entity_semaphore
+                )
             )

     async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
@@ -1657,6 +2110,15 @@ class _AsyncWorkerManager:
             # No event loop running, store in pending list
             self._pending_orchestration_work.append(work_item)

+    def submit_entity_batch(self, func, *args, **kwargs):
+        work_item = (func, args, kwargs)
+        self._ensure_queues_for_current_loop()
+        if self.entity_batch_queue is not None:
+            self.entity_batch_queue.put_nowait(work_item)
+        else:
+            # No event loop running, store in pending list
+            self._pending_entity_batch_work.append(work_item)
+
     def shutdown(self):
         self._shutdown = True
         self.thread_pool.shutdown(wait=True)
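
Taken together, the new surface in this diff (worker.add_entity, the orchestration context's call_entity / signal_entity / lock_entities, and the maximum_concurrent_entity_work_items option) suggests usage along the following lines. This is only an illustrative sketch inferred from the diff, not documented API: the EntityInstanceId constructor arguments, the get_state/set_state helpers on the entity context, the use of EntityLock as a context manager, and the TaskHubGrpcWorker constructor keyword are assumptions that the diff does not confirm.

import durabletask.task as task
from durabletask.entities import EntityInstanceId
from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker


# Function-style entity, matching the `fn(ctx, entity_input)` call path in _EntityExecutor.
# get_state/set_state are assumed entity-context helpers; they are not shown in this diff.
def counter(ctx, amount):
    current = (ctx.get_state(int) or 0) + (amount or 1)
    ctx.set_state(current)
    return current


def place_order(ctx: task.OrchestrationContext, _):
    # Constructor arguments assumed to be (entity name, entity key).
    entity_id = EntityInstanceId("counter", "order-123")

    # Fire-and-forget: recorded as a sendEntityMessage action, no result task.
    ctx.signal_entity(entity_id, "add", 1)

    # Request/response: the yielded task completes when the
    # entityOperationCompleted history event is replayed.
    total = yield ctx.call_entity(entity_id, "add", 5)

    # Critical section: lock_entities completes with an EntityLock once
    # entityLockGranted arrives; treating it as a context manager is assumed
    # from the "EntityLock object is exited" comments in _exit_critical_section.
    lock = yield ctx.lock_entities([entity_id])
    with lock:
        total = yield ctx.call_entity(entity_id, "add", 10)

    return total


# The new concurrency knob mirrors the existing activity/orchestration limits;
# the constructor keyword name below is assumed.
options = ConcurrencyOptions(maximum_concurrent_entity_work_items=10)
worker = TaskHubGrpcWorker(concurrency_options=options)
worker.add_orchestrator(place_order)
worker.add_entity(counter)  # new in 0.5.0; raises RuntimeError if the worker is already running
worker.start()

A class-based entity is also visible in the diff: a DurableEntity subclass registered with add_entity is stored under its class name, a single instance is cached per entity ID in the registry, and each operation name is dispatched to the method of the same name on that instance.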