durabletask 0.1.0a1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
durabletask/worker.py CHANGED
@@ -1,40 +1,155 @@
1
1
  # Copyright (c) Microsoft Corporation.
2
2
  # Licensed under the MIT License.
3
3
 
4
- import concurrent.futures
4
+ import asyncio
5
+ import inspect
6
+ import json
5
7
  import logging
6
- from dataclasses import dataclass
7
- from datetime import datetime, timedelta
8
+ import os
9
+ import random
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from datetime import datetime, timedelta, timezone
8
12
  from threading import Event, Thread
9
13
  from types import GeneratorType
10
- from typing import Any, Dict, Generator, List, Sequence, TypeVar, Union
14
+ from enum import Enum
15
+ from typing import Any, Generator, Optional, Sequence, TypeVar, Union
16
+ from packaging.version import InvalidVersion, parse
11
17
 
12
18
  import grpc
13
19
  from google.protobuf import empty_pb2
14
20
 
21
+ from durabletask.internal import helpers
22
+ from durabletask.internal.entity_state_shim import StateShim
23
+ from durabletask.internal.helpers import new_timestamp
24
+ from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
25
+ from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
15
26
  import durabletask.internal.helpers as ph
16
- import durabletask.internal.helpers as pbh
27
+ import durabletask.internal.exceptions as pe
17
28
  import durabletask.internal.orchestrator_service_pb2 as pb
18
29
  import durabletask.internal.orchestrator_service_pb2_grpc as stubs
19
30
  import durabletask.internal.shared as shared
20
31
  from durabletask import task
32
+ from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
33
+
34
+ TInput = TypeVar("TInput")
35
+ TOutput = TypeVar("TOutput")
36
+
37
+
38
+ class ConcurrencyOptions:
39
+ """Configuration options for controlling concurrency of different work item types and the thread pool size.
40
+
41
+ This class provides fine-grained control over concurrent processing limits for
42
+ activities, orchestrations and the thread pool size.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ maximum_concurrent_activity_work_items: Optional[int] = None,
48
+ maximum_concurrent_orchestration_work_items: Optional[int] = None,
49
+ maximum_concurrent_entity_work_items: Optional[int] = None,
50
+ maximum_thread_pool_workers: Optional[int] = None,
51
+ ):
52
+ """Initialize concurrency options.
53
+
54
+ Args:
55
+ maximum_concurrent_activity_work_items: Maximum number of activity work items
56
+ that can be processed concurrently. Defaults to 100 * processor_count.
57
+ maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items
58
+ that can be processed concurrently. Defaults to 100 * processor_count.
59
+ maximum_thread_pool_workers: Maximum number of thread pool workers to use.
60
+ """
61
+ processor_count = os.cpu_count() or 1
62
+ default_concurrency = 100 * processor_count
63
+ # see https://docs.python.org/3/library/concurrent.futures.html
64
+ default_max_workers = processor_count + 4
65
+
66
+ self.maximum_concurrent_activity_work_items = (
67
+ maximum_concurrent_activity_work_items
68
+ if maximum_concurrent_activity_work_items is not None
69
+ else default_concurrency
70
+ )
21
71
 
22
- TInput = TypeVar('TInput')
23
- TOutput = TypeVar('TOutput')
72
+ self.maximum_concurrent_orchestration_work_items = (
73
+ maximum_concurrent_orchestration_work_items
74
+ if maximum_concurrent_orchestration_work_items is not None
75
+ else default_concurrency
76
+ )
24
77
 
78
+ self.maximum_concurrent_entity_work_items = (
79
+ maximum_concurrent_entity_work_items
80
+ if maximum_concurrent_entity_work_items is not None
81
+ else default_concurrency
82
+ )
25
83
 
26
- class _Registry:
84
+ self.maximum_thread_pool_workers = (
85
+ maximum_thread_pool_workers
86
+ if maximum_thread_pool_workers is not None
87
+ else default_max_workers
88
+ )
89
+
90
+
91
+ class VersionMatchStrategy(Enum):
92
+ """Enumeration for version matching strategies."""
93
+
94
+ NONE = 1
95
+ STRICT = 2
96
+ CURRENT_OR_OLDER = 3
97
+
98
+
99
+ class VersionFailureStrategy(Enum):
100
+ """Enumeration for version failure strategies."""
101
+
102
+ REJECT = 1
103
+ FAIL = 2
104
+
105
+
106
+ class VersioningOptions:
107
+ """Configuration options for orchestrator and activity versioning.
108
+
109
+ This class provides options to control how versioning is handled for orchestrators
110
+ and activities, including whether to use the default version and how to compare versions.
111
+ """
112
+
113
+ version: Optional[str] = None
114
+ default_version: Optional[str] = None
115
+ match_strategy: Optional[VersionMatchStrategy] = None
116
+ failure_strategy: Optional[VersionFailureStrategy] = None
27
117
 
28
- orchestrators: Dict[str, task.Orchestrator]
29
- activities: Dict[str, task.Activity]
118
+ def __init__(self, version: Optional[str] = None,
119
+ default_version: Optional[str] = None,
120
+ match_strategy: Optional[VersionMatchStrategy] = None,
121
+ failure_strategy: Optional[VersionFailureStrategy] = None
122
+ ):
123
+ """Initialize versioning options.
124
+
125
+ Args:
126
+ version: The version of orchestrations that the worker can work on.
127
+ default_version: The default version that will be used for starting new sub-orchestrations.
128
+ match_strategy: The versioning strategy for the Durable Task worker.
129
+ failure_strategy: The versioning failure strategy for the Durable Task worker.
130
+ """
131
+ self.version = version
132
+ self.default_version = default_version
133
+ self.match_strategy = match_strategy
134
+ self.failure_strategy = failure_strategy
135
+
136
+
137
+ class _Registry:
138
+ orchestrators: dict[str, task.Orchestrator]
139
+ activities: dict[str, task.Activity]
140
+ entities: dict[str, task.Entity]
141
+ entity_instances: dict[str, DurableEntity]
142
+ versioning: Optional[VersioningOptions] = None
30
143
 
31
144
  def __init__(self):
32
- self.orchestrators = dict[str, task.Orchestrator]()
33
- self.activities = dict[str, task.Activity]()
145
+ self.orchestrators = {}
146
+ self.activities = {}
147
+ self.entities = {}
148
+ self.entity_instances = {}
34
149
 
35
150
  def add_orchestrator(self, fn: task.Orchestrator) -> str:
36
151
  if fn is None:
37
- raise ValueError('An orchestrator function argument is required.')
152
+ raise ValueError("An orchestrator function argument is required.")
38
153
 
39
154
  name = task.get_name(fn)
40
155
  self.add_named_orchestrator(name, fn)
@@ -42,18 +157,18 @@ class _Registry:
42
157
 
43
158
  def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
44
159
  if not name:
45
- raise ValueError('A non-empty orchestrator name is required.')
160
+ raise ValueError("A non-empty orchestrator name is required.")
46
161
  if name in self.orchestrators:
47
162
  raise ValueError(f"A '{name}' orchestrator already exists.")
48
163
 
49
164
  self.orchestrators[name] = fn
50
165
 
51
- def get_orchestrator(self, name: str) -> Union[task.Orchestrator, None]:
166
+ def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]:
52
167
  return self.orchestrators.get(name)
53
168
 
54
169
  def add_activity(self, fn: task.Activity) -> str:
55
170
  if fn is None:
56
- raise ValueError('An activity function argument is required.')
171
+ raise ValueError("An activity function argument is required.")
57
172
 
58
173
  name = task.get_name(fn)
59
174
  self.add_named_activity(name, fn)
@@ -61,39 +176,182 @@ class _Registry:
61
176
 
62
177
  def add_named_activity(self, name: str, fn: task.Activity) -> None:
63
178
  if not name:
64
- raise ValueError('A non-empty activity name is required.')
179
+ raise ValueError("A non-empty activity name is required.")
65
180
  if name in self.activities:
66
181
  raise ValueError(f"A '{name}' activity already exists.")
67
182
 
68
183
  self.activities[name] = fn
69
184
 
70
- def get_activity(self, name: str) -> Union[task.Activity, None]:
185
+ def get_activity(self, name: str) -> Optional[task.Activity]:
71
186
  return self.activities.get(name)
72
187
 
188
+ def add_entity(self, fn: task.Entity) -> str:
189
+ if fn is None:
190
+ raise ValueError("An entity function argument is required.")
191
+
192
+ if isinstance(fn, type) and issubclass(fn, DurableEntity):
193
+ name = fn.__name__
194
+ self.add_named_entity(name, fn)
195
+ else:
196
+ name = task.get_name(fn)
197
+ self.add_named_entity(name, fn)
198
+ return name
199
+
200
+ def add_named_entity(self, name: str, fn: task.Entity) -> None:
201
+ if not name:
202
+ raise ValueError("A non-empty entity name is required.")
203
+ if name in self.entities:
204
+ raise ValueError(f"A '{name}' entity already exists.")
205
+
206
+ self.entities[name] = fn
207
+
208
+ def get_entity(self, name: str) -> Optional[task.Entity]:
209
+ return self.entities.get(name)
210
+
73
211
 
74
212
  class OrchestratorNotRegisteredError(ValueError):
75
213
  """Raised when attempting to start an orchestration that is not registered"""
214
+
76
215
  pass
77
216
 
78
217
 
79
218
  class ActivityNotRegisteredError(ValueError):
80
219
  """Raised when attempting to call an activity that is not registered"""
220
+
81
221
  pass
82
222
 
83
223
 
84
- class TaskHubGrpcWorker:
85
- _response_stream: Union[grpc.Future, None]
224
+ class EntityNotRegisteredError(ValueError):
225
+ """Raised when attempting to call an entity that is not registered"""
226
+
227
+ pass
228
+
86
229
 
87
- def __init__(self, *,
88
- host_address: Union[str, None] = None,
89
- log_handler=None,
90
- log_formatter: Union[logging.Formatter, None] = None):
230
+ class TaskHubGrpcWorker:
231
+ """A gRPC-based worker for processing durable task orchestrations and activities.
232
+
233
+ This worker connects to a Durable Task backend service via gRPC to receive and process
234
+ work items including orchestration functions and activity functions. It provides
235
+ concurrent execution capabilities with configurable limits and automatic retry handling.
236
+
237
+ The worker manages the complete lifecycle:
238
+ - Registers orchestrator and activity functions
239
+ - Connects to the gRPC backend service
240
+ - Receives work items and executes them concurrently
241
+ - Handles failures, retries, and state management
242
+ - Provides logging and monitoring capabilities
243
+
244
+ Args:
245
+ host_address (Optional[str], optional): The gRPC endpoint address of the backend service.
246
+ Defaults to the value from environment variables or localhost.
247
+ metadata (Optional[list[tuple[str, str]]], optional): gRPC metadata to include with
248
+ requests. Used for authentication and routing. Defaults to None.
249
+ log_handler (optional[logging.Handler]): Custom logging handler for worker logs. Defaults to None.
250
+ log_formatter (Optional[logging.Formatter], optional): Custom log formatter.
251
+ Defaults to None.
252
+ secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
253
+ Defaults to False.
254
+ interceptors (Optional[Sequence[shared.ClientInterceptor]], optional): Custom gRPC
255
+ interceptors to apply to the channel. Defaults to None.
256
+ concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for
257
+ controlling worker concurrency limits. If None, default settings are used.
258
+
259
+ Attributes:
260
+ concurrency_options (ConcurrencyOptions): The current concurrency configuration.
261
+
262
+ Example:
263
+ Basic worker setup:
264
+
265
+ >>> from durabletask.worker import TaskHubGrpcWorker, ConcurrencyOptions
266
+ >>>
267
+ >>> # Create worker with custom concurrency settings
268
+ >>> concurrency = ConcurrencyOptions(
269
+ ... maximum_concurrent_activity_work_items=50,
270
+ ... maximum_concurrent_orchestration_work_items=20
271
+ ... )
272
+ >>> worker = TaskHubGrpcWorker(
273
+ ... host_address="localhost:4001",
274
+ ... concurrency_options=concurrency
275
+ ... )
276
+ >>>
277
+ >>> # Register functions
278
+ >>> @worker.add_orchestrator
279
+ ... def my_orchestrator(context, input):
280
+ ... result = yield context.call_activity("my_activity", input="hello")
281
+ ... return result
282
+ >>>
283
+ >>> @worker.add_activity
284
+ ... def my_activity(context, input):
285
+ ... return f"Processed: {input}"
286
+ >>>
287
+ >>> # Start the worker
288
+ >>> worker.start()
289
+ >>> # ... worker runs in background thread
290
+ >>> worker.stop()
291
+
292
+ Using as context manager:
293
+
294
+ >>> with TaskHubGrpcWorker() as worker:
295
+ ... worker.add_orchestrator(my_orchestrator)
296
+ ... worker.add_activity(my_activity)
297
+ ... worker.start()
298
+ ... # Worker automatically stops when exiting context
299
+
300
+ Raises:
301
+ RuntimeError: If attempting to add orchestrators/activities while the worker is running,
302
+ or if starting a worker that is already running.
303
+ OrchestratorNotRegisteredError: If an orchestration work item references an
304
+ unregistered orchestrator function.
305
+ ActivityNotRegisteredError: If an activity work item references an unregistered
306
+ activity function.
307
+ """
308
+
309
+ _response_stream: Optional[grpc.Future] = None
310
+ _interceptors: Optional[list[shared.ClientInterceptor]] = None
311
+
312
+ def __init__(
313
+ self,
314
+ *,
315
+ host_address: Optional[str] = None,
316
+ metadata: Optional[list[tuple[str, str]]] = None,
317
+ log_handler: Optional[logging.Handler] = None,
318
+ log_formatter: Optional[logging.Formatter] = None,
319
+ secure_channel: bool = False,
320
+ interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
321
+ concurrency_options: Optional[ConcurrencyOptions] = None,
322
+ ):
91
323
  self._registry = _Registry()
92
- self._host_address = host_address if host_address else shared.get_default_host_address()
93
- self._logger = shared.get_logger(log_handler, log_formatter)
324
+ self._host_address = (
325
+ host_address if host_address else shared.get_default_host_address()
326
+ )
327
+ self._logger = shared.get_logger("worker", log_handler, log_formatter)
94
328
  self._shutdown = Event()
95
- self._response_stream = None
96
329
  self._is_running = False
330
+ self._secure_channel = secure_channel
331
+
332
+ # Use provided concurrency options or create default ones
333
+ self._concurrency_options = (
334
+ concurrency_options
335
+ if concurrency_options is not None
336
+ else ConcurrencyOptions()
337
+ )
338
+
339
+ # Determine the interceptors to use
340
+ if interceptors is not None:
341
+ self._interceptors = list(interceptors)
342
+ if metadata:
343
+ self._interceptors.append(DefaultClientInterceptorImpl(metadata))
344
+ elif metadata:
345
+ self._interceptors = [DefaultClientInterceptorImpl(metadata)]
346
+ else:
347
+ self._interceptors = None
348
+
349
+ self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
350
+
351
+ @property
352
+ def concurrency_options(self) -> ConcurrencyOptions:
353
+ """Get the current concurrency options for this worker."""
354
+ return self._concurrency_options
97
355
 
98
356
  def __enter__(self):
99
357
  return self
@@ -104,69 +362,254 @@ class TaskHubGrpcWorker:
104
362
  def add_orchestrator(self, fn: task.Orchestrator) -> str:
105
363
  """Registers an orchestrator function with the worker."""
106
364
  if self._is_running:
107
- raise RuntimeError('Orchestrators cannot be added while the worker is running.')
365
+ raise RuntimeError(
366
+ "Orchestrators cannot be added while the worker is running."
367
+ )
108
368
  return self._registry.add_orchestrator(fn)
109
369
 
110
370
  def add_activity(self, fn: task.Activity) -> str:
111
371
  """Registers an activity function with the worker."""
112
372
  if self._is_running:
113
- raise RuntimeError('Activities cannot be added while the worker is running.')
373
+ raise RuntimeError(
374
+ "Activities cannot be added while the worker is running."
375
+ )
114
376
  return self._registry.add_activity(fn)
115
377
 
378
+ def add_entity(self, fn: task.Entity) -> str:
379
+ """Registers an entity function with the worker."""
380
+ if self._is_running:
381
+ raise RuntimeError(
382
+ "Entities cannot be added while the worker is running."
383
+ )
384
+ return self._registry.add_entity(fn)
385
+
386
+ def use_versioning(self, version: VersioningOptions) -> None:
387
+ """Initializes versioning options for sub-orchestrators and activities."""
388
+ if self._is_running:
389
+ raise RuntimeError("Cannot set default version while the worker is running.")
390
+ self._registry.versioning = version
391
+
116
392
  def start(self):
117
393
  """Starts the worker on a background thread and begins listening for work items."""
118
- channel = shared.get_grpc_channel(self._host_address)
119
- stub = stubs.TaskHubSidecarServiceStub(channel)
120
-
121
394
  if self._is_running:
122
- raise RuntimeError('The worker is already running.')
395
+ raise RuntimeError("The worker is already running.")
123
396
 
124
397
  def run_loop():
125
- # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
126
- # functions. We'd need to know ahead of time whether a function is async or not.
127
- # TODO: Max concurrency configuration settings
128
- with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
129
- while not self._shutdown.is_set():
130
- try:
131
- # send a "Hello" message to the sidecar to ensure that it's listening
132
- stub.Hello(empty_pb2.Empty())
398
+ loop = asyncio.new_event_loop()
399
+ asyncio.set_event_loop(loop)
400
+ loop.run_until_complete(self._async_run_loop())
133
401
 
134
- # stream work items
135
- self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
136
- self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')
402
+ self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
403
+ self._runLoop = Thread(target=run_loop)
404
+ self._runLoop.start()
405
+ self._is_running = True
137
406
 
138
- # The stream blocks until either a work item is received or the stream is canceled
139
- # by another thread (see the stop() method).
407
+ async def _async_run_loop(self):
408
+ worker_task = asyncio.create_task(self._async_worker_manager.run())
409
+ # Connection state management for retry fix
410
+ current_channel = None
411
+ current_stub = None
412
+ current_reader_thread = None
413
+ conn_retry_count = 0
414
+ conn_max_retry_delay = 60
415
+
416
+ def create_fresh_connection():
417
+ nonlocal current_channel, current_stub, conn_retry_count
418
+ if current_channel:
419
+ try:
420
+ current_channel.close()
421
+ except Exception:
422
+ pass
423
+ current_channel = None
424
+ current_stub = None
425
+ try:
426
+ current_channel = shared.get_grpc_channel(
427
+ self._host_address, self._secure_channel, self._interceptors
428
+ )
429
+ current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
430
+ current_stub.Hello(empty_pb2.Empty())
431
+ conn_retry_count = 0
432
+ self._logger.info(f"Created fresh connection to {self._host_address}")
433
+ except Exception as e:
434
+ self._logger.warning(f"Failed to create connection: {e}")
435
+ current_channel = None
436
+ current_stub = None
437
+ raise
438
+
439
+ def invalidate_connection():
440
+ nonlocal current_channel, current_stub, current_reader_thread
441
+ # Cancel the response stream first to signal the reader thread to stop
442
+ if self._response_stream is not None:
443
+ try:
444
+ self._response_stream.cancel()
445
+ except Exception:
446
+ pass
447
+ self._response_stream = None
448
+
449
+ # Wait for the reader thread to finish
450
+ if current_reader_thread is not None:
451
+ try:
452
+ current_reader_thread.join(timeout=2)
453
+ if current_reader_thread.is_alive():
454
+ self._logger.warning("Stream reader thread did not shut down gracefully")
455
+ except Exception:
456
+ pass
457
+ current_reader_thread = None
458
+
459
+ # Close the channel
460
+ if current_channel:
461
+ try:
462
+ current_channel.close()
463
+ except Exception:
464
+ pass
465
+ current_channel = None
466
+ current_stub = None
467
+
468
+ def should_invalidate_connection(rpc_error):
469
+ error_code = rpc_error.code() # type: ignore
470
+ connection_level_errors = {
471
+ grpc.StatusCode.UNAVAILABLE,
472
+ grpc.StatusCode.DEADLINE_EXCEEDED,
473
+ grpc.StatusCode.CANCELLED,
474
+ grpc.StatusCode.UNAUTHENTICATED,
475
+ grpc.StatusCode.ABORTED,
476
+ }
477
+ return error_code in connection_level_errors
478
+
479
+ while not self._shutdown.is_set():
480
+ if current_stub is None:
481
+ try:
482
+ create_fresh_connection()
483
+ except Exception:
484
+ conn_retry_count += 1
485
+ delay = min(
486
+ conn_max_retry_delay,
487
+ (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1),
488
+ )
489
+ self._logger.warning(
490
+ f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})"
491
+ )
492
+ if self._shutdown.wait(delay):
493
+ break
494
+ continue
495
+ try:
496
+ assert current_stub is not None
497
+ stub = current_stub
498
+ get_work_items_request = pb.GetWorkItemsRequest(
499
+ maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items,
500
+ maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
501
+ )
502
+ self._response_stream = stub.GetWorkItems(get_work_items_request)
503
+ self._logger.info(
504
+ f"Successfully connected to {self._host_address}. Waiting for work items..."
505
+ )
506
+
507
+ # Use a thread to read from the blocking gRPC stream and forward to asyncio
508
+ import queue
509
+
510
+ work_item_queue = queue.Queue()
511
+
512
+ def stream_reader():
513
+ try:
140
514
  for work_item in self._response_stream:
141
- request_type = work_item.WhichOneof('request')
142
- self._logger.debug(f'Received "{request_type}" work item')
143
- if work_item.HasField('orchestratorRequest'):
144
- executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub)
145
- elif work_item.HasField('activityRequest'):
146
- executor.submit(self._execute_activity, work_item.activityRequest, stub)
147
- else:
148
- self._logger.warning(f'Unexpected work item type: {request_type}')
149
-
150
- except grpc.RpcError as rpc_error:
151
- if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
152
- self._logger.warning(f'Disconnected from {self._host_address}')
153
- elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
154
- self._logger.warning(
155
- f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
156
- else:
157
- self._logger.warning(f'Unexpected error: {rpc_error}')
158
- except Exception as ex:
159
- self._logger.warning(f'Unexpected error: {ex}')
515
+ work_item_queue.put(work_item)
516
+ except Exception as e:
517
+ work_item_queue.put(e)
160
518
 
161
- # CONSIDER: exponential backoff
162
- self._shutdown.wait(5)
163
- self._logger.info("No longer listening for work items")
164
- return
519
+ import threading
165
520
 
166
- self._logger.info(f"starting gRPC worker that connects to {self._host_address}")
167
- self._runLoop = Thread(target=run_loop)
168
- self._runLoop.start()
169
- self._is_running = True
521
+ current_reader_thread = threading.Thread(target=stream_reader, daemon=True)
522
+ current_reader_thread.start()
523
+ loop = asyncio.get_running_loop()
524
+ while not self._shutdown.is_set():
525
+ try:
526
+ work_item = await loop.run_in_executor(
527
+ None, work_item_queue.get
528
+ )
529
+ if isinstance(work_item, Exception):
530
+ raise work_item
531
+ request_type = work_item.WhichOneof("request")
532
+ self._logger.debug(f'Received "{request_type}" work item')
533
+ if work_item.HasField("orchestratorRequest"):
534
+ self._async_worker_manager.submit_orchestration(
535
+ self._execute_orchestrator,
536
+ self._cancel_orchestrator,
537
+ work_item.orchestratorRequest,
538
+ stub,
539
+ work_item.completionToken,
540
+ )
541
+ elif work_item.HasField("activityRequest"):
542
+ self._async_worker_manager.submit_activity(
543
+ self._execute_activity,
544
+ self._cancel_activity,
545
+ work_item.activityRequest,
546
+ stub,
547
+ work_item.completionToken,
548
+ )
549
+ elif work_item.HasField("entityRequest"):
550
+ self._async_worker_manager.submit_entity_batch(
551
+ self._execute_entity_batch,
552
+ self._cancel_entity_batch,
553
+ work_item.entityRequest,
554
+ stub,
555
+ work_item.completionToken,
556
+ )
557
+ elif work_item.HasField("entityRequestV2"):
558
+ self._async_worker_manager.submit_entity_batch(
559
+ self._execute_entity_batch,
560
+ self._cancel_entity_batch,
561
+ work_item.entityRequestV2,
562
+ stub,
563
+ work_item.completionToken
564
+ )
565
+ elif work_item.HasField("healthPing"):
566
+ pass
567
+ else:
568
+ self._logger.warning(
569
+ f"Unexpected work item type: {request_type}"
570
+ )
571
+ except Exception as e:
572
+ self._logger.warning(f"Error in work item stream: {e}")
573
+ raise e
574
+ current_reader_thread.join(timeout=1)
575
+ self._logger.info("Work item stream ended normally")
576
+ except grpc.RpcError as rpc_error:
577
+ should_invalidate = should_invalidate_connection(rpc_error)
578
+ if should_invalidate:
579
+ invalidate_connection()
580
+ error_code = rpc_error.code() # type: ignore
581
+ error_details = str(rpc_error)
582
+
583
+ if error_code == grpc.StatusCode.CANCELLED:
584
+ self._logger.info(f"Disconnected from {self._host_address}")
585
+ break
586
+ elif error_code == grpc.StatusCode.UNAVAILABLE:
587
+ # Check if this is a connection timeout scenario
588
+ if "Timeout occurred" in error_details or "Failed to connect to remote host" in error_details:
589
+ self._logger.warning(
590
+ f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection"
591
+ )
592
+ else:
593
+ self._logger.warning(
594
+ f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying"
595
+ )
596
+ elif should_invalidate:
597
+ self._logger.warning(
598
+ f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection"
599
+ )
600
+ else:
601
+ self._logger.warning(
602
+ f"Application-level gRPC error ({error_code}): {rpc_error}"
603
+ )
604
+ self._shutdown.wait(1)
605
+ except Exception as ex:
606
+ invalidate_connection()
607
+ self._logger.warning(f"Unexpected error: {ex}")
608
+ self._shutdown.wait(1)
609
+ invalidate_connection()
610
+ self._logger.info("No longer listening for work items")
611
+ self._async_worker_manager.shutdown()
612
+ await worker_task
170
613
 
171
614
  def stop(self):
172
615
  """Stops the worker and waits for any pending work items to complete."""
@@ -179,70 +622,226 @@ class TaskHubGrpcWorker:
179
622
  self._response_stream.cancel()
180
623
  if self._runLoop is not None:
181
624
  self._runLoop.join(timeout=30)
625
+ self._async_worker_manager.shutdown()
182
626
  self._logger.info("Worker shutdown completed")
183
627
  self._is_running = False
184
628
 
185
- def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub):
629
+ def _execute_orchestrator(
630
+ self,
631
+ req: pb.OrchestratorRequest,
632
+ stub: stubs.TaskHubSidecarServiceStub,
633
+ completionToken,
634
+ ):
186
635
  try:
187
636
  executor = _OrchestrationExecutor(self._registry, self._logger)
188
- actions = executor.execute(req.instanceId, req.pastEvents, req.newEvents)
189
- res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions)
637
+ result = executor.execute(req.instanceId, req.pastEvents, req.newEvents)
638
+ res = pb.OrchestratorResponse(
639
+ instanceId=req.instanceId,
640
+ actions=result.actions,
641
+ customStatus=ph.get_string_value(result.encoded_custom_status),
642
+ completionToken=completionToken,
643
+ )
644
+ except pe.AbandonOrchestrationError:
645
+ self._logger.info(
646
+ f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'"
647
+ )
648
+ stub.AbandonTaskOrchestratorWorkItem(
649
+ pb.AbandonOrchestrationTaskRequest(
650
+ completionToken=completionToken
651
+ )
652
+ )
653
+ return
190
654
  except Exception as ex:
191
- self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}")
192
- failure_details = pbh.new_failure_details(ex)
193
- actions = [pbh.new_complete_orchestration_action(-1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details)]
194
- res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions)
655
+ self._logger.exception(
656
+ f"An error occurred while trying to execute instance '{req.instanceId}': {ex}"
657
+ )
658
+ failure_details = ph.new_failure_details(ex)
659
+ actions = [
660
+ ph.new_complete_orchestration_action(
661
+ -1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details
662
+ )
663
+ ]
664
+ res = pb.OrchestratorResponse(
665
+ instanceId=req.instanceId,
666
+ actions=actions,
667
+ completionToken=completionToken,
668
+ )
195
669
 
196
670
  try:
197
671
  stub.CompleteOrchestratorTask(res)
198
672
  except Exception as ex:
199
- self._logger.exception(f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}")
200
-
201
- def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub):
673
+ self._logger.exception(
674
+ f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
675
+ )
676
+
677
+ def _cancel_orchestrator(
678
+ self,
679
+ req: pb.OrchestratorRequest,
680
+ stub: stubs.TaskHubSidecarServiceStub,
681
+ completionToken,
682
+ ):
683
+ stub.AbandonTaskOrchestratorWorkItem(
684
+ pb.AbandonOrchestrationTaskRequest(
685
+ completionToken=completionToken
686
+ )
687
+ )
688
+ self._logger.info(f"Cancelled orchestration task for invocation ID: {req.instanceId}")
689
+
690
+ def _execute_activity(
691
+ self,
692
+ req: pb.ActivityRequest,
693
+ stub: stubs.TaskHubSidecarServiceStub,
694
+ completionToken,
695
+ ):
202
696
  instance_id = req.orchestrationInstance.instanceId
203
697
  try:
204
698
  executor = _ActivityExecutor(self._registry, self._logger)
205
- result = executor.execute(instance_id, req.name, req.taskId, req.input.value)
699
+ result = executor.execute(
700
+ instance_id, req.name, req.taskId, req.input.value
701
+ )
206
702
  res = pb.ActivityResponse(
207
703
  instanceId=instance_id,
208
704
  taskId=req.taskId,
209
- result=pbh.get_string_value(result))
705
+ result=ph.get_string_value(result),
706
+ completionToken=completionToken,
707
+ )
210
708
  except Exception as ex:
211
709
  res = pb.ActivityResponse(
212
710
  instanceId=instance_id,
213
711
  taskId=req.taskId,
214
- failureDetails=pbh.new_failure_details(ex))
712
+ failureDetails=ph.new_failure_details(ex),
713
+ completionToken=completionToken,
714
+ )
215
715
 
216
716
  try:
217
717
  stub.CompleteActivityTask(res)
218
718
  except Exception as ex:
219
719
  self._logger.exception(
220
- f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}")
221
-
720
+ f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
721
+ )
722
+
723
+ def _cancel_activity(
724
+ self,
725
+ req: pb.ActivityRequest,
726
+ stub: stubs.TaskHubSidecarServiceStub,
727
+ completionToken,
728
+ ):
729
+ stub.AbandonTaskActivityWorkItem(
730
+ pb.AbandonActivityTaskRequest(
731
+ completionToken=completionToken
732
+ )
733
+ )
734
+ self._logger.info(f"Cancelled activity task for task ID: {req.taskId} on orchestration ID: {req.orchestrationInstance.instanceId}")
735
+
736
+ def _execute_entity_batch(
737
+ self,
738
+ req: Union[pb.EntityBatchRequest, pb.EntityRequest],
739
+ stub: stubs.TaskHubSidecarServiceStub,
740
+ completionToken,
741
+ ):
742
+ if isinstance(req, pb.EntityRequest):
743
+ req, operation_infos = helpers.convert_to_entity_batch_request(req)
744
+
745
+ entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None)
746
+
747
+ instance_id = req.instanceId
748
+
749
+ results: list[pb.OperationResult] = []
750
+ for operation in req.operations:
751
+ start_time = datetime.now(timezone.utc)
752
+ executor = _EntityExecutor(self._registry, self._logger)
753
+ entity_instance_id = EntityInstanceId.parse(instance_id)
754
+ if not entity_instance_id:
755
+ raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.")
756
+
757
+ operation_result = None
758
+
759
+ try:
760
+ entity_result = executor.execute(
761
+ instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value
762
+ )
763
+
764
+ entity_result = ph.get_string_value_or_empty(entity_result)
765
+ operation_result = pb.OperationResult(success=pb.OperationResultSuccess(
766
+ result=entity_result,
767
+ startTimeUtc=new_timestamp(start_time),
768
+ endTimeUtc=new_timestamp(datetime.now(timezone.utc))
769
+ ))
770
+ results.append(operation_result)
771
+
772
+ entity_state.commit()
773
+ except Exception as ex:
774
+ self._logger.exception(ex)
775
+ operation_result = pb.OperationResult(failure=pb.OperationResultFailure(
776
+ failureDetails=ph.new_failure_details(ex),
777
+ startTimeUtc=new_timestamp(start_time),
778
+ endTimeUtc=new_timestamp(datetime.now(timezone.utc))
779
+ ))
780
+ results.append(operation_result)
781
+
782
+ entity_state.rollback()
783
+
784
+ batch_result = pb.EntityBatchResult(
785
+ results=results,
786
+ actions=entity_state.get_operation_actions(),
787
+ entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None,
788
+ failureDetails=None,
789
+ completionToken=completionToken,
790
+ operationInfos=operation_infos,
791
+ )
222
792
 
223
- @dataclass
224
- class _ExternalEvent:
225
- name: str
226
- data: Any
793
+ try:
794
+ stub.CompleteEntityTask(batch_result)
795
+ except Exception as ex:
796
+ self._logger.exception(
797
+ f"Failed to deliver entity response for '{entity_instance_id}' of orchestration ID '{instance_id}' to sidecar: {ex}"
798
+ )
799
+
800
+ # TODO: Reset context
801
+
802
+ return batch_result
803
+
804
+ def _cancel_entity_batch(
805
+ self,
806
+ req: Union[pb.EntityBatchRequest, pb.EntityRequest],
807
+ stub: stubs.TaskHubSidecarServiceStub,
808
+ completionToken,
809
+ ):
810
+ stub.AbandonTaskEntityWorkItem(
811
+ pb.AbandonEntityTaskRequest(
812
+ completionToken=completionToken
813
+ )
814
+ )
815
+ self._logger.info(f"Cancelled entity batch task for instance ID: {req.instanceId}")
227
816
 
228
817
 
229
818
  class _RuntimeOrchestrationContext(task.OrchestrationContext):
230
- _generator: Union[Generator[task.Task, Any, Any], None]
231
- _previous_task: Union[task.Task, None]
819
+ _generator: Optional[Generator[task.Task, Any, Any]]
820
+ _previous_task: Optional[task.Task]
232
821
 
233
- def __init__(self, instance_id: str):
822
+ def __init__(self, instance_id: str, registry: _Registry):
234
823
  self._generator = None
235
824
  self._is_replaying = True
236
825
  self._is_complete = False
237
826
  self._result = None
238
- self._pending_actions = dict[int, pb.OrchestratorAction]()
239
- self._pending_tasks = dict[int, task.CompletableTask]()
827
+ self._pending_actions: dict[int, pb.OrchestratorAction] = {}
828
+ self._pending_tasks: dict[int, task.CompletableTask] = {}
829
+ # Maps entity ID to task ID
830
+ self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
831
+ # Maps criticalSectionId to task ID
832
+ self._entity_lock_id_map: dict[str, int] = {}
240
833
  self._sequence_number = 0
241
834
  self._current_utc_datetime = datetime(1000, 1, 1)
242
835
  self._instance_id = instance_id
243
- self._completion_status: Union[pb.OrchestrationStatus, None] = None
244
- self._received_events: Dict[str, List[_ExternalEvent]] = {}
245
- self._pending_events: Dict[str, List[task.CompletableTask]] = {}
836
+ self._registry = registry
837
+ self._entity_context = OrchestrationEntityContext(instance_id)
838
+ self._version: Optional[str] = None
839
+ self._completion_status: Optional[pb.OrchestrationStatus] = None
840
+ self._received_events: dict[str, list[Any]] = {}
841
+ self._pending_events: dict[str, list[task.CompletableTask]] = {}
842
+ self._new_input: Optional[Any] = None
843
+ self._save_events = False
844
+ self._encoded_custom_status: Optional[str] = None
246
845
 
247
846
  def run(self, generator: Generator[task.Task, Any, Any]):
248
847
  self._generator = generator
@@ -254,55 +853,124 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
254
853
  def resume(self):
255
854
  if self._generator is None:
256
855
  # This is never expected unless maybe there's an issue with the history
257
- raise TypeError("The orchestrator generator is not initialized! Was the orchestration history corrupted?")
856
+ raise TypeError(
857
+ "The orchestrator generator is not initialized! Was the orchestration history corrupted?"
858
+ )
258
859
 
259
860
  # We can resume the generator only if the previously yielded task
260
861
  # has reached a completed state. The only time this won't be the
261
862
  # case is if the user yielded on a WhenAll task and there are still
262
863
  # outstanding child tasks that need to be completed.
263
- if self._previous_task is not None:
864
+ while self._previous_task is not None and self._previous_task.is_complete:
865
+ next_task = None
264
866
  if self._previous_task.is_failed:
265
- # Raise the failure as an exception to the generator. The orchestrator can then either
266
- # handle the exception or allow it to fail the orchestration.
267
- self._generator.throw(self._previous_task.get_exception())
268
- elif self._previous_task.is_complete:
269
- while True:
270
- # Resume the generator. This will either return a Task or raise StopIteration if it's done.
271
- # CONSIDER: Should we check for possible infinite loops here?
272
- next_task = self._generator.send(self._previous_task.get_result())
273
- if not isinstance(next_task, task.Task):
274
- raise TypeError("The orchestrator generator yielded a non-Task object")
275
- self._previous_task = next_task
276
- # If a completed task was returned, then we can keep running the generator function.
277
- if not self._previous_task.is_complete:
278
- break
279
-
280
- def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False):
867
+ # Raise the failure as an exception to the generator.
868
+ # The orchestrator can then either handle the exception or allow it to fail the orchestration.
869
+ next_task = self._generator.throw(self._previous_task.get_exception())
870
+ else:
871
+ # Resume the generator with the previous result.
872
+ # This will either return a Task or raise StopIteration if it's done.
873
+ next_task = self._generator.send(self._previous_task.get_result())
874
+
875
+ if not isinstance(next_task, task.Task):
876
+ raise TypeError("The orchestrator generator yielded a non-Task object")
877
+ self._previous_task = next_task
878
+
879
+ def set_complete(
880
+ self,
881
+ result: Any,
882
+ status: pb.OrchestrationStatus,
883
+ is_result_encoded: bool = False,
884
+ ):
281
885
  if self._is_complete:
282
886
  return
283
887
 
888
+ # If the user code returned without yielding the entity unlock, do that now
889
+ if self._entity_context.is_inside_critical_section:
890
+ self._exit_critical_section()
891
+
284
892
  self._is_complete = True
893
+ self._completion_status = status
894
+ # This is probably a bug - an orchestrator may complete with some actions remaining that the user still
895
+ # wants to execute - for example, signaling an entity. So we shouldn't clear the pending actions here.
896
+ # self._pending_actions.clear() # Cancel any pending actions
897
+
285
898
  self._result = result
286
- result_json: Union[str, None] = None
899
+ result_json: Optional[str] = None
287
900
  if result is not None:
288
901
  result_json = result if is_result_encoded else shared.to_json(result)
289
902
  action = ph.new_complete_orchestration_action(
290
- self.next_sequence_number(), status, result_json)
903
+ self.next_sequence_number(), status, result_json
904
+ )
291
905
  self._pending_actions[action.id] = action
292
906
 
293
- def set_failed(self, ex: Exception):
907
+ def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]):
294
908
  if self._is_complete:
295
909
  return
296
910
 
911
+ # If the user code crashed inside a critical section, or did not exit it, do that now
912
+ if self._entity_context.is_inside_critical_section:
913
+ self._exit_critical_section()
914
+
297
915
  self._is_complete = True
298
- self._pending_actions.clear() # Cancel any pending actions
916
+ # We also cannot cancel the pending actions in the failure case - if the user code had released an entity
917
+ # lock, we *must* send that action to the sidecar.
918
+ # self._pending_actions.clear() # Cancel any pending actions
919
+ self._completion_status = pb.ORCHESTRATION_STATUS_FAILED
920
+
299
921
  action = ph.new_complete_orchestration_action(
300
- self.next_sequence_number(), pb.ORCHESTRATION_STATUS_FAILED, None, ph.new_failure_details(ex)
922
+ self.next_sequence_number(),
923
+ pb.ORCHESTRATION_STATUS_FAILED,
924
+ None,
925
+ ph.new_failure_details(ex) if isinstance(ex, Exception) else ex,
301
926
  )
302
927
  self._pending_actions[action.id] = action
303
928
 
304
- def get_actions(self) -> List[pb.OrchestratorAction]:
305
- return list(self._pending_actions.values())
929
+ def set_continued_as_new(self, new_input: Any, save_events: bool):
930
+ if self._is_complete:
931
+ return
932
+
933
+ # If the user code called continue_as_new while holding an entity lock, unlock it now
934
+ if self._entity_context.is_inside_critical_section:
935
+ self._exit_critical_section()
936
+
937
+ self._is_complete = True
938
+ # We also cannot cancel the pending actions in the continue as new case - if the user code had released an
939
+ # entity lock, we *must* send that action to the sidecar.
940
+ # self._pending_actions.clear() # Cancel any pending actions
941
+ self._completion_status = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
942
+ self._new_input = new_input
943
+ self._save_events = save_events
944
+
945
+ def get_actions(self) -> list[pb.OrchestratorAction]:
946
+ current_actions = list(self._pending_actions.values())
947
+ if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
948
+ # When continuing-as-new, we only return a single completion action.
949
+ carryover_events: Optional[list[pb.HistoryEvent]] = None
950
+ if self._save_events:
951
+ carryover_events = []
952
+ # We need to save the current set of pending events so that they can be
953
+ # replayed when the new instance starts.
954
+ for event_name, values in self._received_events.items():
955
+ for event_value in values:
956
+ encoded_value = (
957
+ shared.to_json(event_value) if event_value else None
958
+ )
959
+ carryover_events.append(
960
+ ph.new_event_raised_event(event_name, encoded_value)
961
+ )
962
+ action = ph.new_complete_orchestration_action(
963
+ self.next_sequence_number(),
964
+ pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW,
965
+ result=shared.to_json(self._new_input)
966
+ if self._new_input is not None
967
+ else None,
968
+ failure_details=None,
969
+ carryover_events=carryover_events,
970
+ )
971
+ # We must return the existing tasks as well, to capture entity unlocks
972
+ current_actions.append(action)
973
+ return current_actions
306
974
 
307
975
  def next_sequence_number(self) -> int:
308
976
  self._sequence_number += 1
@@ -312,56 +980,248 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
312
980
  def instance_id(self) -> str:
313
981
  return self._instance_id
314
982
 
983
+ @property
984
+ def version(self) -> Optional[str]:
985
+ return self._version
986
+
315
987
  @property
316
988
  def current_utc_datetime(self) -> datetime:
317
989
  return self._current_utc_datetime
318
990
 
991
+ @current_utc_datetime.setter
992
+ def current_utc_datetime(self, value: datetime):
993
+ self._current_utc_datetime = value
994
+
319
995
  @property
320
996
  def is_replaying(self) -> bool:
321
997
  return self._is_replaying
322
998
 
323
- @current_utc_datetime.setter
324
- def current_utc_datetime(self, value: datetime):
325
- self._current_utc_datetime = value
999
+ def set_custom_status(self, custom_status: Any) -> None:
1000
+ self._encoded_custom_status = (
1001
+ shared.to_json(custom_status) if custom_status is not None else None
1002
+ )
326
1003
 
327
1004
  def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task:
1005
+ return self.create_timer_internal(fire_at)
1006
+
1007
+ def create_timer_internal(
1008
+ self,
1009
+ fire_at: Union[datetime, timedelta],
1010
+ retryable_task: Optional[task.RetryableTask] = None,
1011
+ ) -> task.Task:
328
1012
  id = self.next_sequence_number()
329
1013
  if isinstance(fire_at, timedelta):
330
1014
  fire_at = self.current_utc_datetime + fire_at
331
1015
  action = ph.new_create_timer_action(id, fire_at)
332
1016
  self._pending_actions[id] = action
333
1017
 
334
- timer_task = task.CompletableTask()
1018
+ timer_task: task.TimerTask = task.TimerTask()
1019
+ if retryable_task is not None:
1020
+ timer_task.set_retryable_parent(retryable_task)
335
1021
  self._pending_tasks[id] = timer_task
336
1022
  return timer_task
337
1023
 
338
- def call_activity(self, activity: Union[task.Activity[TInput, TOutput], str], *,
339
- input: Union[TInput, None] = None) -> task.Task[TOutput]:
1024
+ def call_activity(
1025
+ self,
1026
+ activity: Union[task.Activity[TInput, TOutput], str],
1027
+ *,
1028
+ input: Optional[TInput] = None,
1029
+ retry_policy: Optional[task.RetryPolicy] = None,
1030
+ tags: Optional[dict[str, str]] = None,
1031
+ ) -> task.Task[TOutput]:
1032
+ id = self.next_sequence_number()
1033
+
1034
+ self.call_activity_function_helper(
1035
+ id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags
1036
+ )
1037
+ return self._pending_tasks.get(id, task.CompletableTask())
1038
+
1039
+ def call_entity(
1040
+ self,
1041
+ entity_id: EntityInstanceId,
1042
+ operation: str,
1043
+ input: Optional[TInput] = None,
1044
+ ) -> task.Task:
1045
+ id = self.next_sequence_number()
1046
+
1047
+ self.call_entity_function_helper(
1048
+ id, entity_id, operation, input=input
1049
+ )
1050
+
1051
+ return self._pending_tasks.get(id, task.CompletableTask())
1052
+
1053
+ def signal_entity(
1054
+ self,
1055
+ entity_id: EntityInstanceId,
1056
+ operation: str,
1057
+ input: Optional[TInput] = None
1058
+ ) -> None:
340
1059
  id = self.next_sequence_number()
341
- name = activity if isinstance(activity, str) else task.get_name(activity)
342
- encoded_input = shared.to_json(input) if input else None
343
- action = ph.new_schedule_task_action(id, name, encoded_input)
344
- self._pending_actions[id] = action
345
1060
 
346
- activity_task = task.CompletableTask[TOutput]()
347
- self._pending_tasks[id] = activity_task
348
- return activity_task
1061
+ self.signal_entity_function_helper(
1062
+ id, entity_id, operation, input
1063
+ )
1064
+
1065
+ def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]:
1066
+ id = self.next_sequence_number()
349
1067
 
350
- def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput], *,
351
- input: Union[TInput, None] = None,
352
- instance_id: Union[str, None] = None) -> task.Task[TOutput]:
1068
+ self.lock_entities_function_helper(
1069
+ id, entities
1070
+ )
1071
+ return self._pending_tasks.get(id, task.CompletableTask())
1072
+
1073
+ def call_sub_orchestrator(
1074
+ self,
1075
+ orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
1076
+ *,
1077
+ input: Optional[TInput] = None,
1078
+ instance_id: Optional[str] = None,
1079
+ retry_policy: Optional[task.RetryPolicy] = None,
1080
+ version: Optional[str] = None,
1081
+ ) -> task.Task[TOutput]:
353
1082
  id = self.next_sequence_number()
354
- name = task.get_name(orchestrator)
355
- if instance_id is None:
356
- # Create a deteministic instance ID based on the parent instance ID
357
- instance_id = f"{self.instance_id}:{id:04x}"
358
- encoded_input = shared.to_json(input) if input else None
359
- action = ph.new_create_sub_orchestration_action(id, name, instance_id, encoded_input)
1083
+ if isinstance(orchestrator, str):
1084
+ orchestrator_name = orchestrator
1085
+ else:
1086
+ orchestrator_name = task.get_name(orchestrator)
1087
+ default_version = self._registry.versioning.default_version if self._registry.versioning else None
1088
+ orchestrator_version = version if version else default_version
1089
+ self.call_activity_function_helper(
1090
+ id,
1091
+ orchestrator_name,
1092
+ input=input,
1093
+ retry_policy=retry_policy,
1094
+ is_sub_orch=True,
1095
+ instance_id=instance_id,
1096
+ version=orchestrator_version
1097
+ )
1098
+ return self._pending_tasks.get(id, task.CompletableTask())
1099
+
1100
+ def call_activity_function_helper(
1101
+ self,
1102
+ id: Optional[int],
1103
+ activity_function: Union[task.Activity[TInput, TOutput], str],
1104
+ *,
1105
+ input: Optional[TInput] = None,
1106
+ retry_policy: Optional[task.RetryPolicy] = None,
1107
+ tags: Optional[dict[str, str]] = None,
1108
+ is_sub_orch: bool = False,
1109
+ instance_id: Optional[str] = None,
1110
+ fn_task: Optional[task.CompletableTask[TOutput]] = None,
1111
+ version: Optional[str] = None,
1112
+ ):
1113
+ if id is None:
1114
+ id = self.next_sequence_number()
1115
+
1116
+ if fn_task is None:
1117
+ encoded_input = shared.to_json(input) if input is not None else None
1118
+ else:
1119
+ # Here, we don't need to convert the input to JSON because it is already converted.
1120
+ # We just need to take string representation of it.
1121
+ encoded_input = str(input)
1122
+ if not is_sub_orch:
1123
+ name = (
1124
+ activity_function
1125
+ if isinstance(activity_function, str)
1126
+ else task.get_name(activity_function)
1127
+ )
1128
+ action = ph.new_schedule_task_action(id, name, encoded_input, tags)
1129
+ else:
1130
+ if instance_id is None:
1131
+ # Create a deteministic instance ID based on the parent instance ID
1132
+ instance_id = f"{self.instance_id}:{id:04x}"
1133
+ if not isinstance(activity_function, str):
1134
+ raise ValueError("Orchestrator function name must be a string")
1135
+ action = ph.new_create_sub_orchestration_action(
1136
+ id, activity_function, instance_id, encoded_input, version
1137
+ )
1138
+ self._pending_actions[id] = action
1139
+
1140
+ if fn_task is None:
1141
+ if retry_policy is None:
1142
+ fn_task = task.CompletableTask[TOutput]()
1143
+ else:
1144
+ fn_task = task.RetryableTask[TOutput](
1145
+ retry_policy=retry_policy,
1146
+ action=action,
1147
+ start_time=self.current_utc_datetime,
1148
+ is_sub_orch=is_sub_orch,
1149
+ )
1150
+ self._pending_tasks[id] = fn_task
1151
+
1152
+ def call_entity_function_helper(
1153
+ self,
1154
+ id: Optional[int],
1155
+ entity_id: EntityInstanceId,
1156
+ operation: str,
1157
+ *,
1158
+ input: Optional[TInput] = None,
1159
+ ):
1160
+ if id is None:
1161
+ id = self.next_sequence_number()
1162
+
1163
+ transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False)
1164
+ if not transition_valid:
1165
+ raise RuntimeError(error_message)
1166
+
1167
+ encoded_input = shared.to_json(input) if input is not None else None
1168
+ action = ph.new_call_entity_action(id, self.instance_id, entity_id, operation, encoded_input)
1169
+ self._pending_actions[id] = action
1170
+
1171
+ fn_task = task.CompletableTask()
1172
+ self._pending_tasks[id] = fn_task
1173
+
1174
+ def signal_entity_function_helper(
1175
+ self,
1176
+ id: Optional[int],
1177
+ entity_id: EntityInstanceId,
1178
+ operation: str,
1179
+ input: Optional[TInput]
1180
+ ) -> None:
1181
+ if id is None:
1182
+ id = self.next_sequence_number()
1183
+
1184
+ transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, True)
1185
+
1186
+ if not transition_valid:
1187
+ raise RuntimeError(error_message)
1188
+
1189
+ encoded_input = shared.to_json(input) if input is not None else None
1190
+
1191
+ action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input)
360
1192
  self._pending_actions[id] = action
361
1193
 
362
- sub_orch_task = task.CompletableTask[TOutput]()
363
- self._pending_tasks[id] = sub_orch_task
364
- return sub_orch_task
1194
+ def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None:
1195
+ if id is None:
1196
+ id = self.next_sequence_number()
1197
+
1198
+ transition_valid, error_message = self._entity_context.validate_acquire_transition()
1199
+ if not transition_valid:
1200
+ raise RuntimeError(error_message)
1201
+
1202
+ critical_section_id = f"{self.instance_id}:{id:04x}"
1203
+
1204
+ request, target = self._entity_context.emit_acquire_message(critical_section_id, entities)
1205
+
1206
+ if not request or not target:
1207
+ raise RuntimeError("Failed to create entity lock request.")
1208
+
1209
+ action = ph.new_lock_entities_action(id, request)
1210
+ self._pending_actions[id] = action
1211
+
1212
+ fn_task = task.CompletableTask[EntityLock]()
1213
+ self._pending_tasks[id] = fn_task
1214
+
1215
+ def _exit_critical_section(self) -> None:
1216
+ if not self._entity_context.is_inside_critical_section:
1217
+ # Possible if the user calls continue_as_new inside the lock - in the success case, we will call
1218
+ # _exit_critical_section both from the EntityLock and the continue_as_new logic. We must keep both calls in
1219
+ # case the user code crashes after calling continue_as_new but before the EntityLock object is exited.
1220
+ return
1221
+ for entity_unlock_message in self._entity_context.emit_lock_release_messages():
1222
+ task_id = self.next_sequence_number()
1223
+ action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
1224
+ self._pending_actions[task_id] = action
365
1225
 
366
1226
  def wait_for_external_event(self, name: str) -> task.Task:
367
1227
  # Check to see if this event has already been received, in which case we
@@ -369,14 +1229,14 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
369
1229
  # event with the given name so that we can resume the generator when it
370
1230
  # arrives. If there are multiple events with the same name, we return
371
1231
  # them in the order they were received.
372
- external_event_task = task.CompletableTask()
373
- event_name = name.upper()
1232
+ external_event_task: task.CompletableTask = task.CompletableTask()
1233
+ event_name = name.casefold()
374
1234
  event_list = self._received_events.get(event_name, None)
375
1235
  if event_list:
376
- event = event_list.pop(0)
1236
+ event_data = event_list.pop(0)
377
1237
  if not event_list:
378
1238
  del self._received_events[event_name]
379
- external_event_task.complete(event.data)
1239
+ external_event_task.complete(event_data)
380
1240
  else:
381
1241
  task_list = self._pending_events.get(event_name, None)
382
1242
  if not task_list:
@@ -385,25 +1245,59 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
385
1245
  task_list.append(external_event_task)
386
1246
  return external_event_task
387
1247
 
1248
+ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
1249
+ if self._is_complete:
1250
+ return
1251
+
1252
+ self.set_continued_as_new(new_input, save_events)
1253
+
1254
+
1255
+ class ExecutionResults:
1256
+ actions: list[pb.OrchestratorAction]
1257
+ encoded_custom_status: Optional[str]
1258
+
1259
+ def __init__(
1260
+ self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
1261
+ ):
1262
+ self.actions = actions
1263
+ self.encoded_custom_status = encoded_custom_status
1264
+
388
1265
 
389
1266
  class _OrchestrationExecutor:
390
- _generator: Union[task.Orchestrator, None]
1267
+ _generator: Optional[task.Orchestrator] = None
391
1268
 
392
1269
  def __init__(self, registry: _Registry, logger: logging.Logger):
393
1270
  self._registry = registry
394
1271
  self._logger = logger
395
- self._generator = None
396
1272
  self._is_suspended = False
397
- self._suspended_events: List[pb.HistoryEvent] = []
1273
+ self._suspended_events: list[pb.HistoryEvent] = []
1274
+
1275
+ def execute(
1276
+ self,
1277
+ instance_id: str,
1278
+ old_events: Sequence[pb.HistoryEvent],
1279
+ new_events: Sequence[pb.HistoryEvent],
1280
+ ) -> ExecutionResults:
1281
+ orchestration_name = "<unknown>"
1282
+ orchestration_started_events = [e for e in old_events if e.HasField("executionStarted")]
1283
+ if len(orchestration_started_events) >= 1:
1284
+ orchestration_name = orchestration_started_events[0].executionStarted.name
1285
+
1286
+ self._logger.debug(
1287
+ f"{instance_id}: Beginning replay for orchestrator {orchestration_name}..."
1288
+ )
398
1289
 
399
- def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent]) -> List[pb.OrchestratorAction]:
400
1290
  if not new_events:
401
- raise task.OrchestrationStateError("The new history event list must have at least one event in it.")
1291
+ raise task.OrchestrationStateError(
1292
+ "The new history event list must have at least one event in it."
1293
+ )
402
1294
 
403
- ctx = _RuntimeOrchestrationContext(instance_id)
1295
+ ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
404
1296
  try:
405
1297
  # Rebuild local state by replaying old history into the orchestrator function
406
- self._logger.debug(f"{instance_id}: Rebuilding local state with {len(old_events)} history event...")
1298
+ self._logger.debug(
1299
+ f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
1300
+ )
407
1301
  ctx._is_replaying = True
408
1302
  for old_event in old_events:
409
1303
  self.process_event(ctx, old_event)
@@ -411,26 +1305,56 @@ class _OrchestrationExecutor:
411
1305
  # Get new actions by executing newly received events into the orchestrator function
412
1306
  if self._logger.level <= logging.DEBUG:
413
1307
  summary = _get_new_event_summary(new_events)
414
- self._logger.debug(f"{instance_id}: Processing {len(new_events)} new event(s): {summary}")
1308
+ self._logger.debug(
1309
+ f"{instance_id}: Processing {len(new_events)} new event(s): {summary}"
1310
+ )
415
1311
  ctx._is_replaying = False
416
1312
  for new_event in new_events:
417
1313
  self.process_event(ctx, new_event)
418
- if ctx._is_complete:
419
- break
1314
+
1315
+ except pe.VersionFailureException as ex:
1316
+ if self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
1317
+ if ex.error_details:
1318
+ ctx.set_failed(ex.error_details)
1319
+ else:
1320
+ ctx.set_failed(ex)
1321
+ elif self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
1322
+ raise pe.AbandonOrchestrationError
1323
+
420
1324
  except Exception as ex:
421
1325
  # Unhandled exceptions fail the orchestration
1326
+ self._logger.debug(f"{instance_id}: Orchestration {orchestration_name} failed")
422
1327
  ctx.set_failed(ex)
423
1328
 
424
- if ctx._completion_status:
425
- completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status)
426
- self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}")
1329
+ if not ctx._is_complete:
1330
+ task_count = len(ctx._pending_tasks)
1331
+ event_count = len(ctx._pending_events)
1332
+ self._logger.info(
1333
+ f"{instance_id}: Orchestrator {orchestration_name} yielded with {task_count} task(s) "
1334
+ f"and {event_count} event(s) outstanding."
1335
+ )
1336
+ elif (
1337
+ ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
1338
+ ):
1339
+ completion_status_str = ph.get_orchestration_status_str(
1340
+ ctx._completion_status
1341
+ )
1342
+ self._logger.info(
1343
+ f"{instance_id}: Orchestration {orchestration_name} completed with status: {completion_status_str}"
1344
+ )
427
1345
 
428
1346
  actions = ctx.get_actions()
429
1347
  if self._logger.level <= logging.DEBUG:
430
- self._logger.debug(f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}")
431
- return actions
1348
+ self._logger.debug(
1349
+ f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}"
1350
+ )
1351
+ return ExecutionResults(
1352
+ actions=actions, encoded_custom_status=ctx._encoded_custom_status
1353
+ )
432
1354
 
433
- def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
1355
+ def process_event(
1356
+ self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent
1357
+ ) -> None:
434
1358
  if self._is_suspended and _is_suspendable(event):
435
1359
  # We are suspended, so we need to buffer this event until we are resumed
436
1360
  self._suspended_events.append(event)
@@ -445,14 +1369,35 @@ class _OrchestrationExecutor:
445
1369
  fn = self._registry.get_orchestrator(event.executionStarted.name)
446
1370
  if fn is None:
447
1371
  raise OrchestratorNotRegisteredError(
448
- f"A '{event.executionStarted.name}' orchestrator was not registered.")
1372
+ f"A '{event.executionStarted.name}' orchestrator was not registered."
1373
+ )
1374
+
1375
+ if event.executionStarted.version:
1376
+ ctx._version = event.executionStarted.version.value
1377
+
1378
+ if self._registry.versioning:
1379
+ version_failure = self.evaluate_orchestration_versioning(
1380
+ self._registry.versioning,
1381
+ ctx.version
1382
+ )
1383
+ if version_failure:
1384
+ self._logger.warning(
1385
+ f"Orchestration version did not meet worker versioning requirements. "
1386
+ f"Error action = '{self._registry.versioning.failure_strategy}'. "
1387
+ f"Version error = '{version_failure}'"
1388
+ )
1389
+ raise pe.VersionFailureException(version_failure)
449
1390
 
450
1391
  # deserialize the input, if any
451
1392
  input = None
452
- if event.executionStarted.input is not None and event.executionStarted.input.value != "":
1393
+ if (
1394
+ event.executionStarted.input is not None and event.executionStarted.input.value != ""
1395
+ ):
453
1396
  input = shared.from_json(event.executionStarted.input.value)
454
1397
 
455
- result = fn(ctx, input) # this does not execute the generator, only creates it
1398
+ result = fn(
1399
+ ctx, input
1400
+ ) # this does not execute the generator, only creates it
456
1401
  if isinstance(result, GeneratorType):
457
1402
  # Start the orchestrator's generator function
458
1403
  ctx.run(result)
@@ -465,10 +1410,14 @@ class _OrchestrationExecutor:
465
1410
  timer_id = event.eventId
466
1411
  action = ctx._pending_actions.pop(timer_id, None)
467
1412
  if not action:
468
- raise _get_non_determinism_error(timer_id, task.get_name(ctx.create_timer))
1413
+ raise _get_non_determinism_error(
1414
+ timer_id, task.get_name(ctx.create_timer)
1415
+ )
469
1416
  elif not action.HasField("createTimer"):
470
1417
  expected_method_name = task.get_name(ctx.create_timer)
471
- raise _get_wrong_action_type_error(timer_id, expected_method_name, action)
1418
+ raise _get_wrong_action_type_error(
1419
+ timer_id, expected_method_name, action
1420
+ )
472
1421
  elif event.HasField("timerFired"):
473
1422
  timer_id = event.timerFired.timerId
474
1423
  timer_task = ctx._pending_tasks.pop(timer_id, None)
@@ -476,26 +1425,52 @@ class _OrchestrationExecutor:
476
1425
  # TODO: Should this be an error? When would it ever happen?
477
1426
  if not ctx._is_replaying:
478
1427
  self._logger.warning(
479
- f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}.")
1428
+ f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
1429
+ )
480
1430
  return
481
1431
  timer_task.complete(None)
482
- ctx.resume()
1432
+ if timer_task._retryable_parent is not None:
1433
+ activity_action = timer_task._retryable_parent._action
1434
+
1435
+ if not timer_task._retryable_parent._is_sub_orch:
1436
+ cur_task = activity_action.scheduleTask
1437
+ instance_id = None
1438
+ else:
1439
+ cur_task = activity_action.createSubOrchestration
1440
+ instance_id = cur_task.instanceId
1441
+ ctx.call_activity_function_helper(
1442
+ id=activity_action.id,
1443
+ activity_function=cur_task.name,
1444
+ input=cur_task.input.value,
1445
+ retry_policy=timer_task._retryable_parent._retry_policy,
1446
+ is_sub_orch=timer_task._retryable_parent._is_sub_orch,
1447
+ instance_id=instance_id,
1448
+ fn_task=timer_task._retryable_parent,
1449
+ )
1450
+ else:
1451
+ ctx.resume()
483
1452
  elif event.HasField("taskScheduled"):
484
1453
  # This history event confirms that the activity execution was successfully scheduled.
485
1454
  # Remove the taskScheduled event from the pending action list so we don't schedule it again.
486
1455
  task_id = event.eventId
487
1456
  action = ctx._pending_actions.pop(task_id, None)
1457
+ activity_task = ctx._pending_tasks.get(task_id, None)
488
1458
  if not action:
489
- raise _get_non_determinism_error(task_id, task.get_name(ctx.call_activity))
1459
+ raise _get_non_determinism_error(
1460
+ task_id, task.get_name(ctx.call_activity)
1461
+ )
490
1462
  elif not action.HasField("scheduleTask"):
491
1463
  expected_method_name = task.get_name(ctx.call_activity)
492
- raise _get_wrong_action_type_error(task_id, expected_method_name, action)
1464
+ raise _get_wrong_action_type_error(
1465
+ task_id, expected_method_name, action
1466
+ )
493
1467
  elif action.scheduleTask.name != event.taskScheduled.name:
494
1468
  raise _get_wrong_action_name_error(
495
1469
  task_id,
496
1470
  method_name=task.get_name(ctx.call_activity),
497
1471
  expected_task_name=event.taskScheduled.name,
498
- actual_task_name=action.scheduleTask.name)
1472
+ actual_task_name=action.scheduleTask.name,
1473
+ )
499
1474
  elif event.HasField("taskCompleted"):
500
1475
  # This history event contains the result of a completed activity task.
501
1476
  task_id = event.taskCompleted.taskScheduledId
@@ -504,7 +1479,8 @@ class _OrchestrationExecutor:
504
1479
  # TODO: Should this be an error? When would it ever happen?
505
1480
  if not ctx.is_replaying:
506
1481
  self._logger.warning(
507
- f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}.")
1482
+ f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}."
1483
+ )
508
1484
  return
509
1485
  result = None
510
1486
  if not ph.is_empty(event.taskCompleted.result):
@@ -518,28 +1494,53 @@ class _OrchestrationExecutor:
518
1494
  # TODO: Should this be an error? When would it ever happen?
519
1495
  if not ctx.is_replaying:
520
1496
  self._logger.warning(
521
- f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.")
1497
+ f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}."
1498
+ )
522
1499
  return
523
- activity_task.fail(
524
- f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
525
- event.taskFailed.failureDetails)
526
- ctx.resume()
1500
+
1501
+ if isinstance(activity_task, task.RetryableTask):
1502
+ if activity_task._retry_policy is not None:
1503
+ next_delay = activity_task.compute_next_delay()
1504
+ if next_delay is None:
1505
+ activity_task.fail(
1506
+ f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
1507
+ event.taskFailed.failureDetails,
1508
+ )
1509
+ ctx.resume()
1510
+ else:
1511
+ activity_task.increment_attempt_count()
1512
+ ctx.create_timer_internal(next_delay, activity_task)
1513
+ elif isinstance(activity_task, task.CompletableTask):
1514
+ activity_task.fail(
1515
+ f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
1516
+ event.taskFailed.failureDetails,
1517
+ )
1518
+ ctx.resume()
1519
+ else:
1520
+ raise TypeError("Unexpected task type")
527
1521
  elif event.HasField("subOrchestrationInstanceCreated"):
528
1522
  # This history event confirms that the sub-orchestration execution was successfully scheduled.
529
1523
  # Remove the subOrchestrationInstanceCreated event from the pending action list so we don't schedule it again.
530
1524
  task_id = event.eventId
531
1525
  action = ctx._pending_actions.pop(task_id, None)
532
1526
  if not action:
533
- raise _get_non_determinism_error(task_id, task.get_name(ctx.call_sub_orchestrator))
1527
+ raise _get_non_determinism_error(
1528
+ task_id, task.get_name(ctx.call_sub_orchestrator)
1529
+ )
534
1530
  elif not action.HasField("createSubOrchestration"):
535
1531
  expected_method_name = task.get_name(ctx.call_sub_orchestrator)
536
- raise _get_wrong_action_type_error(task_id, expected_method_name, action)
537
- elif action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name:
1532
+ raise _get_wrong_action_type_error(
1533
+ task_id, expected_method_name, action
1534
+ )
1535
+ elif (
1536
+ action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name
1537
+ ):
538
1538
  raise _get_wrong_action_name_error(
539
1539
  task_id,
540
1540
  method_name=task.get_name(ctx.call_sub_orchestrator),
541
1541
  expected_task_name=event.subOrchestrationInstanceCreated.name,
542
- actual_task_name=action.createSubOrchestration.name)
1542
+ actual_task_name=action.createSubOrchestration.name,
1543
+ )
543
1544
  elif event.HasField("subOrchestrationInstanceCompleted"):
544
1545
  task_id = event.subOrchestrationInstanceCompleted.taskScheduledId
545
1546
  sub_orch_task = ctx._pending_tasks.pop(task_id, None)
@@ -547,11 +1548,14 @@ class _OrchestrationExecutor:
547
1548
  # TODO: Should this be an error? When would it ever happen?
548
1549
  if not ctx.is_replaying:
549
1550
  self._logger.warning(
550
- f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}.")
1551
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}."
1552
+ )
551
1553
  return
552
1554
  result = None
553
1555
  if not ph.is_empty(event.subOrchestrationInstanceCompleted.result):
554
- result = shared.from_json(event.subOrchestrationInstanceCompleted.result.value)
1556
+ result = shared.from_json(
1557
+ event.subOrchestrationInstanceCompleted.result.value
1558
+ )
555
1559
  sub_orch_task.complete(result)
556
1560
  ctx.resume()
557
1561
  elif event.HasField("subOrchestrationInstanceFailed"):
@@ -562,19 +1566,36 @@ class _OrchestrationExecutor:
562
1566
  # TODO: Should this be an error? When would it ever happen?
563
1567
  if not ctx.is_replaying:
564
1568
  self._logger.warning(
565
- f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}.")
1569
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}."
1570
+ )
566
1571
  return
567
- sub_orch_task.fail(
568
- f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
569
- failedEvent.failureDetails)
570
- ctx.resume()
1572
+ if isinstance(sub_orch_task, task.RetryableTask):
1573
+ if sub_orch_task._retry_policy is not None:
1574
+ next_delay = sub_orch_task.compute_next_delay()
1575
+ if next_delay is None:
1576
+ sub_orch_task.fail(
1577
+ f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
1578
+ failedEvent.failureDetails,
1579
+ )
1580
+ ctx.resume()
1581
+ else:
1582
+ sub_orch_task.increment_attempt_count()
1583
+ ctx.create_timer_internal(next_delay, sub_orch_task)
1584
+ elif isinstance(sub_orch_task, task.CompletableTask):
1585
+ sub_orch_task.fail(
1586
+ f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
1587
+ failedEvent.failureDetails,
1588
+ )
1589
+ ctx.resume()
1590
+ else:
1591
+ raise TypeError("Unexpected sub-orchestration task type")
571
1592
  elif event.HasField("eventRaised"):
572
1593
  # event names are case-insensitive
573
- event_name = event.eventRaised.name.upper()
1594
+ event_name = event.eventRaised.name.casefold()
574
1595
  if not ctx.is_replaying:
575
- self._logger.info(f"Event raised: {event_name}")
1596
+ self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
576
1597
  task_list = ctx._pending_events.get(event_name, None)
577
- decoded_result: Union[Any, None] = None
1598
+ decoded_result: Optional[Any] = None
578
1599
  if task_list:
579
1600
  event_task = task_list.pop(0)
580
1601
  if not ph.is_empty(event.eventRaised.input):
@@ -591,9 +1612,11 @@ class _OrchestrationExecutor:
591
1612
  ctx._received_events[event_name] = event_list
592
1613
  if not ph.is_empty(event.eventRaised.input):
593
1614
  decoded_result = shared.from_json(event.eventRaised.input.value)
594
- event_list.append(_ExternalEvent(event.eventRaised.name, decoded_result))
1615
+ event_list.append(decoded_result)
595
1616
  if not ctx.is_replaying:
596
- self._logger.info(f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it.")
1617
+ self._logger.info(
1618
+ f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1619
+ )
597
1620
  elif event.HasField("executionSuspended"):
598
1621
  if not self._is_suspended and not ctx.is_replaying:
599
1622
  self._logger.info(f"{ctx.instance_id}: Execution suspended.")
@@ -608,27 +1631,191 @@ class _OrchestrationExecutor:
608
1631
  elif event.HasField("executionTerminated"):
609
1632
  if not ctx.is_replaying:
610
1633
  self._logger.info(f"{ctx.instance_id}: Execution terminating.")
611
- encoded_output = event.executionTerminated.input.value if not ph.is_empty(event.executionTerminated.input) else None
612
- ctx.set_complete(encoded_output, pb.ORCHESTRATION_STATUS_TERMINATED, is_result_encoded=True)
1634
+ encoded_output = (
1635
+ event.executionTerminated.input.value
1636
+ if not ph.is_empty(event.executionTerminated.input)
1637
+ else None
1638
+ )
1639
+ ctx.set_complete(
1640
+ encoded_output,
1641
+ pb.ORCHESTRATION_STATUS_TERMINATED,
1642
+ is_result_encoded=True,
1643
+ )
1644
+ elif event.HasField("entityOperationCalled"):
1645
+ # This history event confirms that the entity operation was successfully scheduled.
1646
+ # Remove the entityOperationCalled event from the pending action list so we don't schedule it again
1647
+ entity_call_id = event.eventId
1648
+ action = ctx._pending_actions.pop(entity_call_id, None)
1649
+ entity_task = ctx._pending_tasks.get(entity_call_id, None)
1650
+ if not action:
1651
+ raise _get_non_determinism_error(
1652
+ entity_call_id, task.get_name(ctx.call_entity)
1653
+ )
1654
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationCalled"):
1655
+ expected_method_name = task.get_name(ctx.call_entity)
1656
+ raise _get_wrong_action_type_error(
1657
+ entity_call_id, expected_method_name, action
1658
+ )
1659
+ entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
1660
+ if not entity_id:
1661
+ raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
1662
+ ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
1663
+ elif event.HasField("entityOperationSignaled"):
1664
+ # This history event confirms that the entity signal was successfully scheduled.
1665
+ # Remove the entityOperationSignaled event from the pending action list so we don't schedule it
1666
+ entity_signal_id = event.eventId
1667
+ action = ctx._pending_actions.pop(entity_signal_id, None)
1668
+ if not action:
1669
+ raise _get_non_determinism_error(
1670
+ entity_signal_id, task.get_name(ctx.signal_entity)
1671
+ )
1672
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationSignaled"):
1673
+ expected_method_name = task.get_name(ctx.signal_entity)
1674
+ raise _get_wrong_action_type_error(
1675
+ entity_signal_id, expected_method_name, action
1676
+ )
1677
+ elif event.HasField("entityLockRequested"):
1678
+ section_id = event.entityLockRequested.criticalSectionId
1679
+ task_id = event.eventId
1680
+ action = ctx._pending_actions.pop(task_id, None)
1681
+ entity_task = ctx._pending_tasks.get(task_id, None)
1682
+ if not action:
1683
+ raise _get_non_determinism_error(
1684
+ task_id, task.get_name(ctx.lock_entities)
1685
+ )
1686
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityLockRequested"):
1687
+ expected_method_name = task.get_name(ctx.lock_entities)
1688
+ raise _get_wrong_action_type_error(
1689
+ task_id, expected_method_name, action
1690
+ )
1691
+ ctx._entity_lock_id_map[section_id] = task_id
1692
+ elif event.HasField("entityUnlockSent"):
1693
+ # Remove the unlock tasks as they have already been processed
1694
+ tasks_to_remove = []
1695
+ for task_id, action in ctx._pending_actions.items():
1696
+ if action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityUnlockSent"):
1697
+ if action.sendEntityMessage.entityUnlockSent.criticalSectionId == event.entityUnlockSent.criticalSectionId:
1698
+ tasks_to_remove.append(task_id)
1699
+ for task_to_remove in tasks_to_remove:
1700
+ ctx._pending_actions.pop(task_to_remove, None)
1701
+ elif event.HasField("entityLockGranted"):
1702
+ section_id = event.entityLockGranted.criticalSectionId
1703
+ task_id = ctx._entity_lock_id_map.pop(section_id, None)
1704
+ if not task_id:
1705
+ # TODO: Should this be an error? When would it ever happen?
1706
+ if not ctx.is_replaying:
1707
+ self._logger.warning(
1708
+ f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
1709
+ )
1710
+ return
1711
+ entity_task = ctx._pending_tasks.pop(task_id, None)
1712
+ if not entity_task:
1713
+ if not ctx.is_replaying:
1714
+ self._logger.warning(
1715
+ f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
1716
+ )
1717
+ return
1718
+ ctx._entity_context.complete_acquire(section_id)
1719
+ entity_task.complete(EntityLock(ctx))
1720
+ ctx.resume()
1721
+ elif event.HasField("entityOperationCompleted"):
1722
+ request_id = event.entityOperationCompleted.requestId
1723
+ entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None))
1724
+ if not entity_id:
1725
+ raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
1726
+ if not task_id:
1727
+ raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
1728
+ entity_task = ctx._pending_tasks.pop(task_id, None)
1729
+ if not entity_task:
1730
+ if not ctx.is_replaying:
1731
+ self._logger.warning(
1732
+ f"{ctx.instance_id}: Ignoring unexpected entityOperationCompleted event with request ID = {request_id}."
1733
+ )
1734
+ return
1735
+ result = None
1736
+ if not ph.is_empty(event.entityOperationCompleted.output):
1737
+ result = shared.from_json(event.entityOperationCompleted.output.value)
1738
+ ctx._entity_context.recover_lock_after_call(entity_id)
1739
+ entity_task.complete(result)
1740
+ ctx.resume()
1741
+ elif event.HasField("entityOperationFailed"):
1742
+ if not ctx.is_replaying:
1743
+ self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
1744
+ self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
1745
+ pass
613
1746
  else:
614
1747
  eventType = event.WhichOneof("eventType")
615
- raise task.OrchestrationStateError(f"Don't know how to handle event of type '{eventType}'")
1748
+ raise task.OrchestrationStateError(
1749
+ f"Don't know how to handle event of type '{eventType}'"
1750
+ )
616
1751
  except StopIteration as generatorStopped:
617
1752
  # The orchestrator generator function completed
618
1753
  ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
619
1754
 
1755
+ def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
1756
+ if versioning is None:
1757
+ return None
1758
+ version_comparison = self.compare_versions(orchestration_version, versioning.version)
1759
+ if versioning.match_strategy == VersionMatchStrategy.NONE:
1760
+ return None
1761
+ elif versioning.match_strategy == VersionMatchStrategy.STRICT:
1762
+ if version_comparison != 0:
1763
+ return pb.TaskFailureDetails(
1764
+ errorType="VersionMismatch",
1765
+ errorMessage=f"The orchestration version '{orchestration_version}' does not match the worker version '{versioning.version}'.",
1766
+ isNonRetriable=True,
1767
+ )
1768
+ elif versioning.match_strategy == VersionMatchStrategy.CURRENT_OR_OLDER:
1769
+ if version_comparison > 0:
1770
+ return pb.TaskFailureDetails(
1771
+ errorType="VersionMismatch",
1772
+ errorMessage=f"The orchestration version '{orchestration_version}' is greater than the worker version '{versioning.version}'.",
1773
+ isNonRetriable=True,
1774
+ )
1775
+ else:
1776
+ # If there is a type of versioning we don't understand, it is better to treat it as a versioning failure.
1777
+ return pb.TaskFailureDetails(
1778
+ errorType="VersionMismatch",
1779
+ errorMessage=f"The version match strategy '{versioning.match_strategy}' is unknown.",
1780
+ isNonRetriable=True,
1781
+ )
1782
+
1783
+ def compare_versions(self, source_version: Optional[str], default_version: Optional[str]) -> int:
1784
+ if not source_version and not default_version:
1785
+ return 0
1786
+ if not source_version:
1787
+ return -1
1788
+ if not default_version:
1789
+ return 1
1790
+ try:
1791
+ source_version_parsed = parse(source_version)
1792
+ default_version_parsed = parse(default_version)
1793
+ return (source_version_parsed > default_version_parsed) - (source_version_parsed < default_version_parsed)
1794
+ except InvalidVersion:
1795
+ return (source_version > default_version) - (source_version < default_version)
1796
+
620
1797
 
621
1798
  class _ActivityExecutor:
622
1799
  def __init__(self, registry: _Registry, logger: logging.Logger):
623
1800
  self._registry = registry
624
1801
  self._logger = logger
625
1802
 
626
- def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: Union[str, None]) -> Union[str, None]:
1803
+ def execute(
1804
+ self,
1805
+ orchestration_id: str,
1806
+ name: str,
1807
+ task_id: int,
1808
+ encoded_input: Optional[str],
1809
+ ) -> Optional[str]:
627
1810
  """Executes an activity function and returns the serialized result, if any."""
628
- self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...")
1811
+ self._logger.debug(
1812
+ f"{orchestration_id}/{task_id}: Executing activity '{name}'..."
1813
+ )
629
1814
  fn = self._registry.get_activity(name)
630
1815
  if not fn:
631
- raise ActivityNotRegisteredError(f"Activity function named '{name}' was not registered!")
1816
+ raise ActivityNotRegisteredError(
1817
+ f"Activity function named '{name}' was not registered!"
1818
+ )
632
1819
 
633
1820
  activity_input = shared.from_json(encoded_input) if encoded_input else None
634
1821
  ctx = task.ActivityContext(orchestration_id, task_id)
@@ -636,49 +1823,108 @@ class _ActivityExecutor:
636
1823
  # Execute the activity function
637
1824
  activity_output = fn(ctx, activity_input)
638
1825
 
639
- encoded_output = shared.to_json(activity_output) if activity_output is not None else None
1826
+ encoded_output = (
1827
+ shared.to_json(activity_output) if activity_output is not None else None
1828
+ )
640
1829
  chars = len(encoded_output) if encoded_output else 0
641
1830
  self._logger.debug(
642
- f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output.")
1831
+ f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output."
1832
+ )
643
1833
  return encoded_output
644
1834
 
645
1835
 
646
- def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError:
1836
+ class _EntityExecutor:
1837
+ def __init__(self, registry: _Registry, logger: logging.Logger):
1838
+ self._registry = registry
1839
+ self._logger = logger
1840
+
1841
+ def execute(
1842
+ self,
1843
+ orchestration_id: str,
1844
+ entity_id: EntityInstanceId,
1845
+ operation: str,
1846
+ state: StateShim,
1847
+ encoded_input: Optional[str],
1848
+ ) -> Optional[str]:
1849
+ """Executes an entity function and returns the serialized result, if any."""
1850
+ self._logger.debug(
1851
+ f"{orchestration_id}: Executing entity '{entity_id}'..."
1852
+ )
1853
+ fn = self._registry.get_entity(entity_id.entity)
1854
+ if not fn:
1855
+ raise EntityNotRegisteredError(
1856
+ f"Entity function named '{entity_id.entity}' was not registered!"
1857
+ )
1858
+
1859
+ entity_input = shared.from_json(encoded_input) if encoded_input else None
1860
+ ctx = EntityContext(orchestration_id, operation, state, entity_id)
1861
+
1862
+ if isinstance(fn, type) and issubclass(fn, DurableEntity):
1863
+ if self._registry.entity_instances.get(str(entity_id), None):
1864
+ entity_instance = self._registry.entity_instances[str(entity_id)]
1865
+ else:
1866
+ entity_instance = fn()
1867
+ self._registry.entity_instances[str(entity_id)] = entity_instance
1868
+ if not hasattr(entity_instance, operation):
1869
+ raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
1870
+ method = getattr(entity_instance, operation)
1871
+ if not callable(method):
1872
+ raise TypeError(f"Entity operation '{operation}' is not callable")
1873
+ # Execute the entity method
1874
+ entity_instance._initialize_entity_context(ctx)
1875
+ entity_output = method(entity_input)
1876
+ else:
1877
+ # Execute the entity function
1878
+ entity_output = fn(ctx, entity_input)
1879
+
1880
+ encoded_output = (
1881
+ shared.to_json(entity_output) if entity_output is not None else None
1882
+ )
1883
+ chars = len(encoded_output) if encoded_output else 0
1884
+ self._logger.debug(
1885
+ f"{orchestration_id}: Entity '{entity_id}' completed successfully with {chars} char(s) of encoded output."
1886
+ )
1887
+ return encoded_output
1888
+
1889
+
1890
+ def _get_non_determinism_error(
1891
+ task_id: int, action_name: str
1892
+ ) -> task.NonDeterminismError:
647
1893
  return task.NonDeterminismError(
648
1894
  f"A previous execution called {action_name} with ID={task_id}, but the current "
649
1895
  f"execution doesn't have this action with this ID. This problem occurs when either "
650
1896
  f"the orchestration has non-deterministic logic or if the code was changed after an "
651
- f"instance of this orchestration already started running.")
1897
+ f"instance of this orchestration already started running."
1898
+ )
652
1899
 
653
1900
 
654
1901
  def _get_wrong_action_type_error(
655
- task_id: int,
656
- expected_method_name: str,
657
- action: pb.OrchestratorAction) -> task.NonDeterminismError:
1902
+ task_id: int, expected_method_name: str, action: pb.OrchestratorAction
1903
+ ) -> task.NonDeterminismError:
658
1904
  unexpected_method_name = _get_method_name_for_action(action)
659
1905
  return task.NonDeterminismError(
660
1906
  f"Failed to restore orchestration state due to a history mismatch: A previous execution called "
661
1907
  f"{expected_method_name} with ID={task_id}, but the current execution is instead trying to call "
662
1908
  f"{unexpected_method_name} as part of rebuilding it's history. This kind of mismatch can happen if an "
663
1909
  f"orchestration has non-deterministic logic or if the code was changed after an instance of this "
664
- f"orchestration already started running.")
1910
+ f"orchestration already started running."
1911
+ )
665
1912
 
666
1913
 
667
1914
  def _get_wrong_action_name_error(
668
- task_id: int,
669
- method_name: str,
670
- expected_task_name: str,
671
- actual_task_name: str) -> task.NonDeterminismError:
1915
+ task_id: int, method_name: str, expected_task_name: str, actual_task_name: str
1916
+ ) -> task.NonDeterminismError:
672
1917
  return task.NonDeterminismError(
673
1918
  f"Failed to restore orchestration state due to a history mismatch: A previous execution called "
674
1919
  f"{method_name} with name='{expected_task_name}' and sequence number {task_id}, but the current "
675
1920
  f"execution is instead trying to call {actual_task_name} as part of rebuilding it's history. "
676
1921
  f"This kind of mismatch can happen if an orchestration has non-deterministic logic or if the code "
677
- f"was changed after an instance of this orchestration already started running.")
1922
+ f"was changed after an instance of this orchestration already started running."
1923
+ )
678
1924
 
679
1925
 
680
1926
  def _get_method_name_for_action(action: pb.OrchestratorAction) -> str:
681
- action_type = action.WhichOneof('orchestratorActionType')
1927
+ action_type = action.WhichOneof("orchestratorActionType")
682
1928
  if action_type == "scheduleTask":
683
1929
  return task.get_name(task.OrchestrationContext.call_activity)
684
1930
  elif action_type == "createTimer":
@@ -698,9 +1944,9 @@ def _get_new_event_summary(new_events: Sequence[pb.HistoryEvent]) -> str:
698
1944
  elif len(new_events) == 1:
699
1945
  return f"[{new_events[0].WhichOneof('eventType')}]"
700
1946
  else:
701
- counts = dict[str, int]()
1947
+ counts: dict[str, int] = {}
702
1948
  for event in new_events:
703
- event_type = event.WhichOneof('eventType')
1949
+ event_type = event.WhichOneof("eventType")
704
1950
  counts[event_type] = counts.get(event_type, 0) + 1
705
1951
  return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]"
706
1952
 
@@ -712,13 +1958,302 @@ def _get_action_summary(new_actions: Sequence[pb.OrchestratorAction]) -> str:
712
1958
  elif len(new_actions) == 1:
713
1959
  return f"[{new_actions[0].WhichOneof('orchestratorActionType')}]"
714
1960
  else:
715
- counts = dict[str, int]()
1961
+ counts: dict[str, int] = {}
716
1962
  for action in new_actions:
717
- action_type = action.WhichOneof('orchestratorActionType')
1963
+ action_type = action.WhichOneof("orchestratorActionType")
718
1964
  counts[action_type] = counts.get(action_type, 0) + 1
719
1965
  return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]"
720
1966
 
721
1967
 
722
1968
def _is_suspendable(event: pb.HistoryEvent) -> bool:
    """Return whether ``event`` is one that can be suspended and resumed.

    Every history event type qualifies except ``executionResumed`` and
    ``executionTerminated``.
    """
    non_suspendable_types = ("executionResumed", "executionTerminated")
    event_type = event.WhichOneof("eventType")
    return event_type not in non_suspendable_types
1974
+
1975
+
1976
class _AsyncWorkerManager:
    """Runs activity, orchestration, and entity-batch work items on asyncio consumers.

    Every work item is a ``(func, cancellation_func, args, kwargs)`` tuple. ``func`` runs under a
    per-kind ``asyncio.Semaphore`` sized from ``concurrency_options``. Coroutine functions are
    awaited directly; other callables run on a shared ``ThreadPoolExecutor``. If ``func`` raises,
    ``cancellation_func`` is called with the same arguments so the item can be abandoned.

    Queues are created lazily inside a running event loop. They are re-created, keeping their
    contents, when the manager is used from a different loop. Work submitted while no loop is
    running is parked in the ``_pending_*_work`` lists until the queues are next bound to a loop.
    """

    def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
        self.concurrency_options = concurrency_options
        self._logger = logger

        # Semaphores are created in run() so that they belong to the running event loop.
        self.activity_semaphore = None
        self.orchestration_semaphore = None
        self.entity_semaphore = None
        # Don't create queues here - defer until we have an event loop
        self.activity_queue: Optional[asyncio.Queue] = None
        self.orchestration_queue: Optional[asyncio.Queue] = None
        self.entity_batch_queue: Optional[asyncio.Queue] = None
        self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None
        # Store work items when no event loop is available
        self._pending_activity_work: list = []
        self._pending_orchestration_work: list = []
        self._pending_entity_batch_work: list = []
        self.thread_pool = self._create_thread_pool()
        self._shutdown = False

    def _create_thread_pool(self) -> ThreadPoolExecutor:
        """Build the executor used to run synchronous (non-coroutine) work functions."""
        return ThreadPoolExecutor(
            max_workers=self.concurrency_options.maximum_thread_pool_workers,
            thread_name_prefix="DurableTask",
        )

    def _ensure_thread_pool(self) -> None:
        """Replace the thread pool if a previous shutdown() has closed it."""
        # ThreadPoolExecutor offers no public "is shut down" query. This reads the same private
        # flag that _run_func already checks.
        if getattr(self.thread_pool, "_shutdown", False):
            self.thread_pool = self._create_thread_pool()

    def _ensure_queues_for_current_loop(self):
        """Ensure the work queues are bound to the currently running event loop.

        Does nothing when no loop is running. When the queues are missing or belong to a
        different loop, fresh queues are created for the current loop. Items still in the old
        queues are carried over first, then any work parked in the pending lists.
        """
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running, can't create queues
            return

        if (
            self._queue_event_loop is current_loop
            and self.activity_queue is not None
            and self.orchestration_queue is not None
            and self.entity_batch_queue is not None
        ):
            # Queues are already bound to the current loop and exist
            return

        # Preserve any work still sitting in queues created for a previous loop.
        existing_activity_items = self._drain_queue(self.activity_queue)
        existing_orchestration_items = self._drain_queue(self.orchestration_queue)
        existing_entity_batch_items = self._drain_queue(self.entity_batch_queue)

        # Create fresh queues for the current event loop
        self.activity_queue = asyncio.Queue()
        self.orchestration_queue = asyncio.Queue()
        self.entity_batch_queue = asyncio.Queue()
        self._queue_event_loop = current_loop

        for queue, carried_over, pending in (
            (self.activity_queue, existing_activity_items, self._pending_activity_work),
            (self.orchestration_queue, existing_orchestration_items, self._pending_orchestration_work),
            (self.entity_batch_queue, existing_entity_batch_items, self._pending_entity_batch_work),
        ):
            for item in carried_over:
                queue.put_nowait(item)
            for item in pending:
                queue.put_nowait(item)
            pending.clear()

    def _drain_queue(self, queue: Optional[asyncio.Queue]) -> list:
        """Remove and return every item currently in ``queue`` (``[]`` when it is ``None``).

        Unexpected errors are logged rather than raised. Items drained before the error are
        still returned.
        """
        items: list = []
        if queue is None:
            return items
        try:
            while True:
                items.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            pass
        except Exception as drain_exception:
            self._logger.warning(
                f"Error while moving work items to a new event loop queue: {drain_exception}"
            )
        return items

    async def run(self):
        """Consume the activity, orchestration, and entity-batch queues until shutdown.

        Returns once shutdown has been requested, every queue is empty, and no work item is
        still running. If the consumers raise, all still-queued work items are passed to their
        cancellation functions and the manager shuts itself down.
        """
        # Reset shutdown flag in case this manager is being reused
        self._shutdown = False
        # shutdown() closes the thread pool. Without a fresh one, every synchronous work item
        # on a reused manager would fail with "cannot schedule new futures after shutdown".
        self._ensure_thread_pool()

        # Ensure queues are properly bound to the current event loop
        self._ensure_queues_for_current_loop()

        # Create semaphores in the current event loop
        activity_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_activity_work_items
        )
        orchestration_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_orchestration_work_items
        )
        entity_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_entity_work_items
        )
        self.activity_semaphore = activity_semaphore
        self.orchestration_semaphore = orchestration_semaphore
        self.entity_semaphore = entity_semaphore

        activity_queue = self.activity_queue
        orchestration_queue = self.orchestration_queue
        entity_batch_queue = self.entity_batch_queue

        # Start background consumers for each work type
        try:
            if (
                activity_queue is not None
                and orchestration_queue is not None
                and entity_batch_queue is not None
            ):
                await asyncio.gather(
                    self._consume_queue(activity_queue, activity_semaphore),
                    self._consume_queue(orchestration_queue, orchestration_semaphore),
                    self._consume_queue(entity_batch_queue, entity_semaphore),
                )
        except Exception as queue_exception:
            self._logger.error(f"Shutting down worker - Uncaught error in worker manager: {queue_exception}")
            await self._cancel_queued_work(
                self.activity_queue,
                self._logger.error,
                "Uncaught error while cancelling activity work item",
                "Activity",
            )
            await self._cancel_queued_work(
                self.orchestration_queue,
                self._logger.error,
                "Uncaught error while cancelling orchestration work item",
                "Orchestration",
            )
            await self._cancel_queued_work(
                self.entity_batch_queue,
                self._logger.error,
                "Uncaught error while cancelling entity batch work item",
                "Entity batch",
            )
            self.shutdown()

    async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
        """Dispatch items from ``queue`` as concurrent tasks until shutdown and fully drained.

        Concurrency is bounded by ``semaphore`` inside ``_process_work_item``. Tasks may
        therefore be created faster than they are allowed to run.
        """
        # Holding references keeps in-flight tasks alive until they finish.
        running_tasks: set[asyncio.Task] = set()

        while True:
            # Forget tasks that have already finished.
            running_tasks = {work_task for work_task in running_tasks if not work_task.done()}

            # Exit if shutdown is set and the queue is empty and no tasks are running
            if self._shutdown and queue.empty() and not running_tasks:
                break

            try:
                # Wake up periodically so the shutdown condition above is re-checked.
                work = await asyncio.wait_for(queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            func, cancellation_func, args, kwargs = work
            work_task = asyncio.create_task(
                self._process_work_item(semaphore, queue, func, cancellation_func, args, kwargs)
            )
            running_tasks.add(work_task)

    async def _process_work_item(
        self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, cancellation_func, args, kwargs
    ):
        """Run one work item under ``semaphore``, abandoning it via ``cancellation_func`` on error.

        Never raises an ``Exception``: failures of both ``func`` and ``cancellation_func`` are
        logged so the task does not end with an unretrieved exception.
        """
        async with semaphore:
            try:
                await self._run_func(func, *args, **kwargs)
            except Exception as work_exception:
                self._logger.error(f"Uncaught error while processing work item, item will be abandoned: {work_exception}")
                try:
                    await self._run_func(cancellation_func, *args, **kwargs)
                except Exception as cancellation_exception:
                    self._logger.error(f"Uncaught error while abandoning work item: {cancellation_exception}")
            finally:
                queue.task_done()

    async def _run_func(self, func, *args, **kwargs):
        """Call ``func`` with the given arguments and return its result.

        Coroutine functions are awaited on the current loop. Anything else runs on the thread
        pool. Returns ``None`` without calling ``func`` when a synchronous call is requested
        after shutdown has closed the thread pool.
        """
        if inspect.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        loop = asyncio.get_running_loop()
        # Avoid submitting to executor after shutdown
        if (
            getattr(self, "_shutdown", False)
            and getattr(self, "thread_pool", None)
            and getattr(self.thread_pool, "_shutdown", False)
        ):
            return None
        return await loop.run_in_executor(self.thread_pool, lambda: func(*args, **kwargs))

    def submit_activity(self, func, cancellation_func, *args, **kwargs):
        """Queue an activity work item. Raises ``RuntimeError`` after shutdown."""
        self._submit("activity_queue", self._pending_activity_work, func, cancellation_func, args, kwargs)

    def submit_orchestration(self, func, cancellation_func, *args, **kwargs):
        """Queue an orchestration work item. Raises ``RuntimeError`` after shutdown."""
        self._submit(
            "orchestration_queue", self._pending_orchestration_work, func, cancellation_func, args, kwargs
        )

    def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
        """Queue an entity-batch work item. Raises ``RuntimeError`` after shutdown."""
        self._submit(
            "entity_batch_queue", self._pending_entity_batch_work, func, cancellation_func, args, kwargs
        )

    def _submit(self, queue_attr: str, pending: list, func, cancellation_func, args, kwargs):
        """Put a work item on the queue named ``queue_attr``, or park it in ``pending``.

        The queue attribute is read after ``_ensure_queues_for_current_loop`` because that call
        may replace the queue objects.
        """
        if self._shutdown:
            raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
        work_item = (func, cancellation_func, args, kwargs)
        self._ensure_queues_for_current_loop()
        queue = getattr(self, queue_attr)
        if queue is not None:
            queue.put_nowait(work_item)
        else:
            # No event loop running, store in pending list
            pending.append(work_item)

    def shutdown(self):
        """Stop accepting new work and shut down the thread pool, waiting for its threads."""
        self._shutdown = True
        self.thread_pool.shutdown(wait=True)

    async def _cancel_work_item(self, work_item, log_error, error_prefix: str, item_label: Optional[str] = None):
        """Invoke ``work_item``'s cancellation function, logging instead of raising on failure.

        If cancellation fails, ``log_error`` receives ``"<error_prefix>: <exception>"``. When
        ``item_label`` is given and cancellation succeeds, the item's arguments are logged at
        error level as ``"<item_label> work item args: ..."``.
        """
        try:
            _, cancellation_func, args, kwargs = work_item
            await self._run_func(cancellation_func, *args, **kwargs)
        except Exception as cancellation_exception:
            log_error(f"{error_prefix}: {cancellation_exception}")
            return
        if item_label is not None:
            self._logger.error(f"{item_label} work item args: {args}, kwargs: {kwargs}")

    async def _cancel_queued_work(
        self, queue: Optional[asyncio.Queue], log_error, error_prefix: str, item_label: Optional[str] = None
    ):
        """Pop every item from ``queue`` and cancel each one independently.

        A failure on one item does not stop the rest from being cancelled.
        """
        while queue is not None:
            try:
                work_item = queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            await self._cancel_work_item(work_item, log_error, error_prefix, item_label)

    async def reset_for_new_run(self):
        """Reset the manager state for a new run.

        Clears the shutdown flag and replaces a closed thread pool. Every work item left in the
        queues or pending lists is passed to its cancellation function, so no items from a
        previous run remain. Cancellation errors are logged as warnings and do not stop the
        remaining items from being cancelled.
        """
        self._shutdown = False
        self._ensure_thread_pool()
        for queue, pending, description in (
            (self.activity_queue, self._pending_activity_work, "activity"),
            (self.orchestration_queue, self._pending_orchestration_work, "orchestration"),
            (self.entity_batch_queue, self._pending_entity_batch_work, "entity"),
        ):
            error_prefix = f"Error while clearing {description} queue during reset"
            await self._cancel_queued_work(queue, self._logger.warning, error_prefix)
            # Snapshot then clear, so work submitted for the new run during the awaits below
            # is not cancelled by mistake.
            stale_items = list(pending)
            pending.clear()
            for work_item in stale_items:
                await self._cancel_work_item(work_item, self._logger.warning, error_prefix)
2256
+
2257
+
2258
# Public names exported by ``from durabletask.worker import *``.
__all__ = [
    "ConcurrencyOptions",
    "TaskHubGrpcWorker",
]