durabletask 0.2.0b1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of durabletask might be problematic.

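The headline change in 0.3.0 is configurable concurrency: the diff below adds a ConcurrencyOptions class, a matching concurrency_options keyword argument on TaskHubGrpcWorker, and an asyncio-based run loop that replaces the old fixed 16-thread executor. A minimal sketch of how a consumer might use the new options, assuming only the constructor signatures visible in the diff (the localhost:4001 address is illustrative):

    import os

    from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker

    # Per the constructor in the diff, the defaults are 100 work items per CPU
    # for both activities and orchestrations, and cpu_count + 4 thread pool workers.
    cpus = os.cpu_count() or 1
    options = ConcurrencyOptions(
        maximum_concurrent_activity_work_items=50,       # default would be 100 * cpus
        maximum_concurrent_orchestration_work_items=20,  # default would be 100 * cpus
        maximum_thread_pool_workers=cpus + 4,            # same value as the default
    )

    worker = TaskHubGrpcWorker(
        host_address="localhost:4001",  # illustrative sidecar address
        concurrency_options=options,
    )
    assert worker.concurrency_options.maximum_concurrent_activity_work_items == 50

Per the new run loop in the diff, these limits are also forwarded to the sidecar in the GetWorkItemsRequest, so the backend can throttle how many work items it streams to this worker.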
durabletask/worker.py CHANGED
@@ -1,8 +1,12 @@
  # Copyright (c) Microsoft Corporation.
  # Licensed under the MIT License.
 
- import concurrent.futures
+ import asyncio
+ import inspect
  import logging
+ import os
+ import random
+ from concurrent.futures import ThreadPoolExecutor
  from datetime import datetime, timedelta
  from threading import Event, Thread
  from types import GeneratorType
@@ -12,19 +16,63 @@ import grpc
  from google.protobuf import empty_pb2
 
  import durabletask.internal.helpers as ph
- import durabletask.internal.helpers as pbh
  import durabletask.internal.orchestrator_service_pb2 as pb
  import durabletask.internal.orchestrator_service_pb2_grpc as stubs
  import durabletask.internal.shared as shared
  from durabletask import task
  from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
 
- TInput = TypeVar('TInput')
- TOutput = TypeVar('TOutput')
+ TInput = TypeVar("TInput")
+ TOutput = TypeVar("TOutput")
+
+
+ class ConcurrencyOptions:
+ """Configuration options for controlling concurrency of different work item types and the thread pool size.
+
+ This class provides fine-grained control over concurrent processing limits for
+ activities, orchestrations and the thread pool size.
+ """
+
+ def __init__(
+ self,
+ maximum_concurrent_activity_work_items: Optional[int] = None,
+ maximum_concurrent_orchestration_work_items: Optional[int] = None,
+ maximum_thread_pool_workers: Optional[int] = None,
+ ):
+ """Initialize concurrency options.
+
+ Args:
+ maximum_concurrent_activity_work_items: Maximum number of activity work items
+ that can be processed concurrently. Defaults to 100 * processor_count.
+ maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items
+ that can be processed concurrently. Defaults to 100 * processor_count.
+ maximum_thread_pool_workers: Maximum number of thread pool workers to use.
+ """
+ processor_count = os.cpu_count() or 1
+ default_concurrency = 100 * processor_count
+ # see https://docs.python.org/3/library/concurrent.futures.html
+ default_max_workers = processor_count + 4
+
+ self.maximum_concurrent_activity_work_items = (
+ maximum_concurrent_activity_work_items
+ if maximum_concurrent_activity_work_items is not None
+ else default_concurrency
+ )
 
+ self.maximum_concurrent_orchestration_work_items = (
+ maximum_concurrent_orchestration_work_items
+ if maximum_concurrent_orchestration_work_items is not None
+ else default_concurrency
+ )
+
+ self.maximum_thread_pool_workers = (
+ maximum_thread_pool_workers
+ if maximum_thread_pool_workers is not None
+ else default_max_workers
+ )
 
- class _Registry:
 
+ class _Registry:
  orchestrators: dict[str, task.Orchestrator]
  activities: dict[str, task.Activity]
 
@@ -34,7 +82,7 @@ class _Registry:
 
  def add_orchestrator(self, fn: task.Orchestrator) -> str:
  if fn is None:
- raise ValueError('An orchestrator function argument is required.')
+ raise ValueError("An orchestrator function argument is required.")
 
  name = task.get_name(fn)
  self.add_named_orchestrator(name, fn)
@@ -42,7 +90,7 @@ class _Registry:
 
  def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
  if not name:
- raise ValueError('A non-empty orchestrator name is required.')
+ raise ValueError("A non-empty orchestrator name is required.")
  if name in self.orchestrators:
  raise ValueError(f"A '{name}' orchestrator already exists.")
 
@@ -53,7 +101,7 @@ class _Registry:
 
  def add_activity(self, fn: task.Activity) -> str:
  if fn is None:
- raise ValueError('An activity function argument is required.')
+ raise ValueError("An activity function argument is required.")
 
  name = task.get_name(fn)
  self.add_named_activity(name, fn)
@@ -61,7 +109,7 @@ class _Registry:
 
  def add_named_activity(self, name: str, fn: task.Activity) -> None:
  if not name:
- raise ValueError('A non-empty activity name is required.')
+ raise ValueError("A non-empty activity name is required.")
  if name in self.activities:
  raise ValueError(f"A '{name}' activity already exists.")
 
@@ -73,32 +121,125 @@ class _Registry:
 
  class OrchestratorNotRegisteredError(ValueError):
  """Raised when attempting to start an orchestration that is not registered"""
+
  pass
 
 
  class ActivityNotRegisteredError(ValueError):
  """Raised when attempting to call an activity that is not registered"""
+
  pass
 
 
  class TaskHubGrpcWorker:
+ """A gRPC-based worker for processing durable task orchestrations and activities.
+
+ This worker connects to a Durable Task backend service via gRPC to receive and process
+ work items including orchestration functions and activity functions. It provides
+ concurrent execution capabilities with configurable limits and automatic retry handling.
+
+ The worker manages the complete lifecycle:
+ - Registers orchestrator and activity functions
+ - Connects to the gRPC backend service
+ - Receives work items and executes them concurrently
+ - Handles failures, retries, and state management
+ - Provides logging and monitoring capabilities
+
+ Args:
+ host_address (Optional[str], optional): The gRPC endpoint address of the backend service.
+ Defaults to the value from environment variables or localhost.
+ metadata (Optional[list[tuple[str, str]]], optional): gRPC metadata to include with
+ requests. Used for authentication and routing. Defaults to None.
+ log_handler (optional): Custom logging handler for worker logs. Defaults to None.
+ log_formatter (Optional[logging.Formatter], optional): Custom log formatter.
+ Defaults to None.
+ secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
+ Defaults to False.
+ interceptors (Optional[Sequence[shared.ClientInterceptor]], optional): Custom gRPC
+ interceptors to apply to the channel. Defaults to None.
+ concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for
+ controlling worker concurrency limits. If None, default settings are used.
+
+ Attributes:
+ concurrency_options (ConcurrencyOptions): The current concurrency configuration.
+
+ Example:
+ Basic worker setup:
+
+ >>> from durabletask.worker import TaskHubGrpcWorker, ConcurrencyOptions
+ >>>
+ >>> # Create worker with custom concurrency settings
+ >>> concurrency = ConcurrencyOptions(
+ ... maximum_concurrent_activity_work_items=50,
+ ... maximum_concurrent_orchestration_work_items=20
+ ... )
+ >>> worker = TaskHubGrpcWorker(
+ ... host_address="localhost:4001",
+ ... concurrency_options=concurrency
+ ... )
+ >>>
+ >>> # Register functions
+ >>> @worker.add_orchestrator
+ ... def my_orchestrator(context, input):
+ ... result = yield context.call_activity("my_activity", input="hello")
+ ... return result
+ >>>
+ >>> @worker.add_activity
+ ... def my_activity(context, input):
+ ... return f"Processed: {input}"
+ >>>
+ >>> # Start the worker
+ >>> worker.start()
+ >>> # ... worker runs in background thread
+ >>> worker.stop()
+
+ Using as context manager:
+
+ >>> with TaskHubGrpcWorker() as worker:
+ ... worker.add_orchestrator(my_orchestrator)
+ ... worker.add_activity(my_activity)
+ ... worker.start()
+ ... # Worker automatically stops when exiting context
+
+ Raises:
+ RuntimeError: If attempting to add orchestrators/activities while the worker is running,
+ or if starting a worker that is already running.
+ OrchestratorNotRegisteredError: If an orchestration work item references an
+ unregistered orchestrator function.
+ ActivityNotRegisteredError: If an activity work item references an unregistered
+ activity function.
+ """
+
  _response_stream: Optional[grpc.Future] = None
  _interceptors: Optional[list[shared.ClientInterceptor]] = None
 
- def __init__(self, *,
- host_address: Optional[str] = None,
- metadata: Optional[list[tuple[str, str]]] = None,
- log_handler=None,
- log_formatter: Optional[logging.Formatter] = None,
- secure_channel: bool = False,
- interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
+ def __init__(
+ self,
+ *,
+ host_address: Optional[str] = None,
+ metadata: Optional[list[tuple[str, str]]] = None,
+ log_handler=None,
+ log_formatter: Optional[logging.Formatter] = None,
+ secure_channel: bool = False,
+ interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
+ concurrency_options: Optional[ConcurrencyOptions] = None,
+ ):
  self._registry = _Registry()
- self._host_address = host_address if host_address else shared.get_default_host_address()
+ self._host_address = (
+ host_address if host_address else shared.get_default_host_address()
+ )
  self._logger = shared.get_logger("worker", log_handler, log_formatter)
  self._shutdown = Event()
  self._is_running = False
  self._secure_channel = secure_channel
 
+ # Use provided concurrency options or create default ones
+ self._concurrency_options = (
+ concurrency_options
+ if concurrency_options is not None
+ else ConcurrencyOptions()
+ )
+
  # Determine the interceptors to use
  if interceptors is not None:
  self._interceptors = list(interceptors)
@@ -109,6 +250,13 @@ class TaskHubGrpcWorker:
  else:
  self._interceptors = None
 
+ self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
+
+ @property
+ def concurrency_options(self) -> ConcurrencyOptions:
+ """Get the current concurrency options for this worker."""
+ return self._concurrency_options
+
  def __enter__(self):
  return self
 
@@ -118,72 +266,223 @@ class TaskHubGrpcWorker:
  def add_orchestrator(self, fn: task.Orchestrator) -> str:
  """Registers an orchestrator function with the worker."""
  if self._is_running:
- raise RuntimeError('Orchestrators cannot be added while the worker is running.')
+ raise RuntimeError(
+ "Orchestrators cannot be added while the worker is running."
+ )
  return self._registry.add_orchestrator(fn)
 
  def add_activity(self, fn: task.Activity) -> str:
  """Registers an activity function with the worker."""
  if self._is_running:
- raise RuntimeError('Activities cannot be added while the worker is running.')
+ raise RuntimeError(
+ "Activities cannot be added while the worker is running."
+ )
  return self._registry.add_activity(fn)
 
  def start(self):
  """Starts the worker on a background thread and begins listening for work items."""
- channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors)
- stub = stubs.TaskHubSidecarServiceStub(channel)
-
  if self._is_running:
- raise RuntimeError('The worker is already running.')
+ raise RuntimeError("The worker is already running.")
 
  def run_loop():
- # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
- # functions. We'd need to know ahead of time whether a function is async or not.
- # TODO: Max concurrency configuration settings
- with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
- while not self._shutdown.is_set():
- try:
- # send a "Hello" message to the sidecar to ensure that it's listening
- stub.Hello(empty_pb2.Empty())
-
- # stream work items
- self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
- self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')
-
- # The stream blocks until either a work item is received or the stream is canceled
- # by another thread (see the stop() method).
- for work_item in self._response_stream: # type: ignore
- request_type = work_item.WhichOneof('request')
- self._logger.debug(f'Received "{request_type}" work item')
- if work_item.HasField('orchestratorRequest'):
- executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken)
- elif work_item.HasField('activityRequest'):
- executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken)
- elif work_item.HasField('healthPing'):
- pass # no-op
- else:
- self._logger.warning(f'Unexpected work item type: {request_type}')
-
- except grpc.RpcError as rpc_error:
- if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
- self._logger.info(f'Disconnected from {self._host_address}')
- elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
- self._logger.warning(
- f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
- else:
- self._logger.warning(f'Unexpected error: {rpc_error}')
- except Exception as ex:
- self._logger.warning(f'Unexpected error: {ex}')
-
- # CONSIDER: exponential backoff
- self._shutdown.wait(5)
- self._logger.info("No longer listening for work items")
- return
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop.run_until_complete(self._async_run_loop())
 
  self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
  self._runLoop = Thread(target=run_loop)
  self._runLoop.start()
  self._is_running = True
 
+ async def _async_run_loop(self):
+ worker_task = asyncio.create_task(self._async_worker_manager.run())
+ # Connection state management for retry fix
+ current_channel = None
+ current_stub = None
+ current_reader_thread = None
+ conn_retry_count = 0
+ conn_max_retry_delay = 60
+
+ def create_fresh_connection():
+ nonlocal current_channel, current_stub, conn_retry_count
+ if current_channel:
+ try:
+ current_channel.close()
+ except Exception:
+ pass
+ current_channel = None
+ current_stub = None
+ try:
+ current_channel = shared.get_grpc_channel(
+ self._host_address, self._secure_channel, self._interceptors
+ )
+ current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
+ current_stub.Hello(empty_pb2.Empty())
+ conn_retry_count = 0
+ self._logger.info(f"Created fresh connection to {self._host_address}")
+ except Exception as e:
+ self._logger.warning(f"Failed to create connection: {e}")
+ current_channel = None
+ current_stub = None
+ raise
+
+ def invalidate_connection():
+ nonlocal current_channel, current_stub, current_reader_thread
+ # Cancel the response stream first to signal the reader thread to stop
+ if self._response_stream is not None:
+ try:
+ self._response_stream.cancel()
+ except Exception:
+ pass
+ self._response_stream = None
+
+ # Wait for the reader thread to finish
+ if current_reader_thread is not None:
+ try:
+ current_reader_thread.join(timeout=2)
+ if current_reader_thread.is_alive():
+ self._logger.warning("Stream reader thread did not shut down gracefully")
+ except Exception:
+ pass
+ current_reader_thread = None
+
+ # Close the channel
+ if current_channel:
+ try:
+ current_channel.close()
+ except Exception:
+ pass
+ current_channel = None
+ current_stub = None
+
+ def should_invalidate_connection(rpc_error):
+ error_code = rpc_error.code() # type: ignore
+ connection_level_errors = {
+ grpc.StatusCode.UNAVAILABLE,
+ grpc.StatusCode.DEADLINE_EXCEEDED,
+ grpc.StatusCode.CANCELLED,
+ grpc.StatusCode.UNAUTHENTICATED,
+ grpc.StatusCode.ABORTED,
+ }
+ return error_code in connection_level_errors
+
+ while not self._shutdown.is_set():
+ if current_stub is None:
+ try:
+ create_fresh_connection()
+ except Exception:
+ conn_retry_count += 1
+ delay = min(
+ conn_max_retry_delay,
+ (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1),
+ )
+ self._logger.warning(
+ f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})"
+ )
+ if self._shutdown.wait(delay):
+ break
+ continue
+ try:
+ assert current_stub is not None
+ stub = current_stub
+ get_work_items_request = pb.GetWorkItemsRequest(
+ maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items,
+ maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
+ )
+ self._response_stream = stub.GetWorkItems(get_work_items_request)
+ self._logger.info(
+ f"Successfully connected to {self._host_address}. Waiting for work items..."
+ )
+
+ # Use a thread to read from the blocking gRPC stream and forward to asyncio
+ import queue
+
+ work_item_queue = queue.Queue()
+
+ def stream_reader():
+ try:
+ for work_item in self._response_stream:
+ work_item_queue.put(work_item)
+ except Exception as e:
+ work_item_queue.put(e)
+
+ import threading
+
+ current_reader_thread = threading.Thread(target=stream_reader, daemon=True)
+ current_reader_thread.start()
+ loop = asyncio.get_running_loop()
+ while not self._shutdown.is_set():
+ try:
+ work_item = await loop.run_in_executor(
+ None, work_item_queue.get
+ )
+ if isinstance(work_item, Exception):
+ raise work_item
+ request_type = work_item.WhichOneof("request")
+ self._logger.debug(f'Received "{request_type}" work item')
+ if work_item.HasField("orchestratorRequest"):
+ self._async_worker_manager.submit_orchestration(
+ self._execute_orchestrator,
+ work_item.orchestratorRequest,
+ stub,
+ work_item.completionToken,
+ )
+ elif work_item.HasField("activityRequest"):
+ self._async_worker_manager.submit_activity(
+ self._execute_activity,
+ work_item.activityRequest,
+ stub,
+ work_item.completionToken,
+ )
+ elif work_item.HasField("healthPing"):
+ pass
+ else:
+ self._logger.warning(
+ f"Unexpected work item type: {request_type}"
+ )
+ except Exception as e:
+ self._logger.warning(f"Error in work item stream: {e}")
+ raise e
+ current_reader_thread.join(timeout=1)
+ self._logger.info("Work item stream ended normally")
+ except grpc.RpcError as rpc_error:
+ should_invalidate = should_invalidate_connection(rpc_error)
+ if should_invalidate:
+ invalidate_connection()
+ error_code = rpc_error.code() # type: ignore
+ error_details = str(rpc_error)
+
+ if error_code == grpc.StatusCode.CANCELLED:
+ self._logger.info(f"Disconnected from {self._host_address}")
+ break
+ elif error_code == grpc.StatusCode.UNAVAILABLE:
+ # Check if this is a connection timeout scenario
+ if "Timeout occurred" in error_details or "Failed to connect to remote host" in error_details:
+ self._logger.warning(
+ f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection"
+ )
+ else:
+ self._logger.warning(
+ f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying"
+ )
+ elif should_invalidate:
+ self._logger.warning(
+ f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection"
+ )
+ else:
+ self._logger.warning(
+ f"Application-level gRPC error ({error_code}): {rpc_error}"
+ )
+ self._shutdown.wait(1)
+ except Exception as ex:
+ invalidate_connection()
+ self._logger.warning(f"Unexpected error: {ex}")
+ self._shutdown.wait(1)
+ invalidate_connection()
+ self._logger.info("No longer listening for work items")
+ self._async_worker_manager.shutdown()
+ await worker_task
+
  def stop(self):
  """Stops the worker and waits for any pending work items to complete."""
  if not self._is_running:
@@ -195,51 +494,80 @@ class TaskHubGrpcWorker:
  self._response_stream.cancel()
  if self._runLoop is not None:
  self._runLoop.join(timeout=30)
+ self._async_worker_manager.shutdown()
  self._logger.info("Worker shutdown completed")
  self._is_running = False
 
- def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken):
+ def _execute_orchestrator(
+ self,
+ req: pb.OrchestratorRequest,
+ stub: stubs.TaskHubSidecarServiceStub,
+ completionToken,
+ ):
  try:
  executor = _OrchestrationExecutor(self._registry, self._logger)
  result = executor.execute(req.instanceId, req.pastEvents, req.newEvents)
  res = pb.OrchestratorResponse(
  instanceId=req.instanceId,
  actions=result.actions,
- customStatus=pbh.get_string_value(result.encoded_custom_status),
- completionToken=completionToken)
+ customStatus=ph.get_string_value(result.encoded_custom_status),
+ completionToken=completionToken,
+ )
  except Exception as ex:
- self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}")
- failure_details = pbh.new_failure_details(ex)
- actions = [pbh.new_complete_orchestration_action(-1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details)]
- res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions, completionToken=completionToken)
+ self._logger.exception(
+ f"An error occurred while trying to execute instance '{req.instanceId}': {ex}"
+ )
+ failure_details = ph.new_failure_details(ex)
+ actions = [
+ ph.new_complete_orchestration_action(
+ -1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details
+ )
+ ]
+ res = pb.OrchestratorResponse(
+ instanceId=req.instanceId,
+ actions=actions,
+ completionToken=completionToken,
+ )
 
  try:
  stub.CompleteOrchestratorTask(res)
  except Exception as ex:
- self._logger.exception(f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}")
-
- def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken):
+ self._logger.exception(
+ f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
+ )
+
+ def _execute_activity(
+ self,
+ req: pb.ActivityRequest,
+ stub: stubs.TaskHubSidecarServiceStub,
+ completionToken,
+ ):
  instance_id = req.orchestrationInstance.instanceId
  try:
  executor = _ActivityExecutor(self._registry, self._logger)
- result = executor.execute(instance_id, req.name, req.taskId, req.input.value)
+ result = executor.execute(
+ instance_id, req.name, req.taskId, req.input.value
+ )
  res = pb.ActivityResponse(
  instanceId=instance_id,
  taskId=req.taskId,
- result=pbh.get_string_value(result),
- completionToken=completionToken)
+ result=ph.get_string_value(result),
+ completionToken=completionToken,
+ )
  except Exception as ex:
  res = pb.ActivityResponse(
  instanceId=instance_id,
  taskId=req.taskId,
- failureDetails=pbh.new_failure_details(ex),
- completionToken=completionToken)
+ failureDetails=ph.new_failure_details(ex),
+ completionToken=completionToken,
+ )
 
  try:
  stub.CompleteActivityTask(res)
  except Exception as ex:
  self._logger.exception(
- f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}")
+ f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
+ )
 
 
  class _RuntimeOrchestrationContext(task.OrchestrationContext):
@@ -273,7 +601,9 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  def resume(self):
  if self._generator is None:
  # This is never expected unless maybe there's an issue with the history
- raise TypeError("The orchestrator generator is not initialized! Was the orchestration history corrupted?")
+ raise TypeError(
+ "The orchestrator generator is not initialized! Was the orchestration history corrupted?"
+ )
 
  # We can resume the generator only if the previously yielded task
  # has reached a completed state. The only time this won't be the
@@ -294,7 +624,12 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  raise TypeError("The orchestrator generator yielded a non-Task object")
  self._previous_task = next_task
 
- def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False):
+ def set_complete(
+ self,
+ result: Any,
+ status: pb.OrchestrationStatus,
+ is_result_encoded: bool = False,
+ ):
  if self._is_complete:
  return
 
@@ -307,7 +642,8 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  if result is not None:
  result_json = result if is_result_encoded else shared.to_json(result)
  action = ph.new_complete_orchestration_action(
- self.next_sequence_number(), status, result_json)
+ self.next_sequence_number(), status, result_json
+ )
  self._pending_actions[action.id] = action
 
  def set_failed(self, ex: Exception):
@@ -319,7 +655,10 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  self._completion_status = pb.ORCHESTRATION_STATUS_FAILED
 
  action = ph.new_complete_orchestration_action(
- self.next_sequence_number(), pb.ORCHESTRATION_STATUS_FAILED, None, ph.new_failure_details(ex)
+ self.next_sequence_number(),
+ pb.ORCHESTRATION_STATUS_FAILED,
+ None,
+ ph.new_failure_details(ex),
  )
  self._pending_actions[action.id] = action
 
@@ -343,14 +682,21 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  # replayed when the new instance starts.
  for event_name, values in self._received_events.items():
  for event_value in values:
- encoded_value = shared.to_json(event_value) if event_value else None
- carryover_events.append(ph.new_event_raised_event(event_name, encoded_value))
+ encoded_value = (
+ shared.to_json(event_value) if event_value else None
+ )
+ carryover_events.append(
+ ph.new_event_raised_event(event_name, encoded_value)
+ )
  action = ph.new_complete_orchestration_action(
  self.next_sequence_number(),
  pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW,
- result=shared.to_json(self._new_input) if self._new_input is not None else None,
+ result=shared.to_json(self._new_input)
+ if self._new_input is not None
+ else None,
  failure_details=None,
- carryover_events=carryover_events)
+ carryover_events=carryover_events,
+ )
  return [action]
  else:
  return list(self._pending_actions.values())
@@ -367,60 +713,84 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  def current_utc_datetime(self) -> datetime:
  return self._current_utc_datetime
 
- @property
- def is_replaying(self) -> bool:
- return self._is_replaying
-
  @current_utc_datetime.setter
  def current_utc_datetime(self, value: datetime):
  self._current_utc_datetime = value
 
+ @property
+ def is_replaying(self) -> bool:
+ return self._is_replaying
+
  def set_custom_status(self, custom_status: Any) -> None:
- self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None
+ self._encoded_custom_status = (
+ shared.to_json(custom_status) if custom_status is not None else None
+ )
 
  def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task:
  return self.create_timer_internal(fire_at)
 
- def create_timer_internal(self, fire_at: Union[datetime, timedelta],
- retryable_task: Optional[task.RetryableTask] = None) -> task.Task:
+ def create_timer_internal(
+ self,
+ fire_at: Union[datetime, timedelta],
+ retryable_task: Optional[task.RetryableTask] = None,
+ ) -> task.Task:
  id = self.next_sequence_number()
  if isinstance(fire_at, timedelta):
  fire_at = self.current_utc_datetime + fire_at
  action = ph.new_create_timer_action(id, fire_at)
  self._pending_actions[id] = action
 
- timer_task = task.TimerTask()
+ timer_task: task.TimerTask = task.TimerTask()
  if retryable_task is not None:
  timer_task.set_retryable_parent(retryable_task)
  self._pending_tasks[id] = timer_task
  return timer_task
 
- def call_activity(self, activity: Union[task.Activity[TInput, TOutput], str], *,
- input: Optional[TInput] = None,
- retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]:
+ def call_activity(
+ self,
+ activity: Union[task.Activity[TInput, TOutput], str],
+ *,
+ input: Optional[TInput] = None,
+ retry_policy: Optional[task.RetryPolicy] = None,
+ ) -> task.Task[TOutput]:
  id = self.next_sequence_number()
 
- self.call_activity_function_helper(id, activity, input=input, retry_policy=retry_policy,
- is_sub_orch=False)
+ self.call_activity_function_helper(
+ id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False
+ )
  return self._pending_tasks.get(id, task.CompletableTask())
 
- def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput], *,
- input: Optional[TInput] = None,
- instance_id: Optional[str] = None,
- retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]:
+ def call_sub_orchestrator(
+ self,
+ orchestrator: task.Orchestrator[TInput, TOutput],
+ *,
+ input: Optional[TInput] = None,
+ instance_id: Optional[str] = None,
+ retry_policy: Optional[task.RetryPolicy] = None,
+ ) -> task.Task[TOutput]:
  id = self.next_sequence_number()
  orchestrator_name = task.get_name(orchestrator)
- self.call_activity_function_helper(id, orchestrator_name, input=input, retry_policy=retry_policy,
- is_sub_orch=True, instance_id=instance_id)
+ self.call_activity_function_helper(
+ id,
+ orchestrator_name,
+ input=input,
+ retry_policy=retry_policy,
+ is_sub_orch=True,
+ instance_id=instance_id,
+ )
  return self._pending_tasks.get(id, task.CompletableTask())
 
- def call_activity_function_helper(self, id: Optional[int],
- activity_function: Union[task.Activity[TInput, TOutput], str], *,
- input: Optional[TInput] = None,
- retry_policy: Optional[task.RetryPolicy] = None,
- is_sub_orch: bool = False,
- instance_id: Optional[str] = None,
- fn_task: Optional[task.CompletableTask[TOutput]] = None):
+ def call_activity_function_helper(
+ self,
+ id: Optional[int],
+ activity_function: Union[task.Activity[TInput, TOutput], str],
+ *,
+ input: Optional[TInput] = None,
+ retry_policy: Optional[task.RetryPolicy] = None,
+ is_sub_orch: bool = False,
+ instance_id: Optional[str] = None,
+ fn_task: Optional[task.CompletableTask[TOutput]] = None,
+ ):
  if id is None:
  id = self.next_sequence_number()
 
@@ -431,7 +801,11 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  # We just need to take string representation of it.
  encoded_input = str(input)
  if not is_sub_orch:
- name = activity_function if isinstance(activity_function, str) else task.get_name(activity_function)
+ name = (
+ activity_function
+ if isinstance(activity_function, str)
+ else task.get_name(activity_function)
+ )
  action = ph.new_schedule_task_action(id, name, encoded_input)
  else:
  if instance_id is None:
@@ -439,16 +813,21 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  instance_id = f"{self.instance_id}:{id:04x}"
  if not isinstance(activity_function, str):
  raise ValueError("Orchestrator function name must be a string")
- action = ph.new_create_sub_orchestration_action(id, activity_function, instance_id, encoded_input)
+ action = ph.new_create_sub_orchestration_action(
+ id, activity_function, instance_id, encoded_input
+ )
  self._pending_actions[id] = action
 
  if fn_task is None:
  if retry_policy is None:
  fn_task = task.CompletableTask[TOutput]()
  else:
- fn_task = task.RetryableTask[TOutput](retry_policy=retry_policy, action=action,
- start_time=self.current_utc_datetime,
- is_sub_orch=is_sub_orch)
+ fn_task = task.RetryableTask[TOutput](
+ retry_policy=retry_policy,
+ action=action,
+ start_time=self.current_utc_datetime,
+ is_sub_orch=is_sub_orch,
+ )
  self._pending_tasks[id] = fn_task
 
  def wait_for_external_event(self, name: str) -> task.Task:
@@ -457,7 +836,7 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
  # event with the given name so that we can resume the generator when it
  # arrives. If there are multiple events with the same name, we return
  # them in the order they were received.
- external_event_task = task.CompletableTask()
+ external_event_task: task.CompletableTask = task.CompletableTask()
  event_name = name.casefold()
  event_list = self._received_events.get(event_name, None)
  if event_list:
@@ -484,7 +863,9 @@ class ExecutionResults:
  actions: list[pb.OrchestratorAction]
  encoded_custom_status: Optional[str]
 
- def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]):
+ def __init__(
+ self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
+ ):
  self.actions = actions
  self.encoded_custom_status = encoded_custom_status
 
@@ -498,14 +879,23 @@ class _OrchestrationExecutor:
  self._is_suspended = False
  self._suspended_events: list[pb.HistoryEvent] = []
 
- def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent]) -> ExecutionResults:
+ def execute(
+ self,
+ instance_id: str,
+ old_events: Sequence[pb.HistoryEvent],
+ new_events: Sequence[pb.HistoryEvent],
+ ) -> ExecutionResults:
  if not new_events:
- raise task.OrchestrationStateError("The new history event list must have at least one event in it.")
+ raise task.OrchestrationStateError(
+ "The new history event list must have at least one event in it."
+ )
 
  ctx = _RuntimeOrchestrationContext(instance_id)
  try:
  # Rebuild local state by replaying old history into the orchestrator function
- self._logger.debug(f"{instance_id}: Rebuilding local state with {len(old_events)} history event...")
+ self._logger.debug(
+ f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
+ )
  ctx._is_replaying = True
  for old_event in old_events:
  self.process_event(ctx, old_event)
@@ -513,7 +903,9 @@ class _OrchestrationExecutor:
  # Get new actions by executing newly received events into the orchestrator function
  if self._logger.level <= logging.DEBUG:
  summary = _get_new_event_summary(new_events)
- self._logger.debug(f"{instance_id}: Processing {len(new_events)} new event(s): {summary}")
+ self._logger.debug(
+ f"{instance_id}: Processing {len(new_events)} new event(s): {summary}"
+ )
  ctx._is_replaying = False
  for new_event in new_events:
  self.process_event(ctx, new_event)
@@ -525,17 +917,31 @@ class _OrchestrationExecutor:
  if not ctx._is_complete:
  task_count = len(ctx._pending_tasks)
  event_count = len(ctx._pending_events)
- self._logger.info(f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding.")
- elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
- completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status)
- self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}")
+ self._logger.info(
+ f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding."
+ )
+ elif (
+ ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
+ ):
+ completion_status_str = ph.get_orchestration_status_str(
+ ctx._completion_status
+ )
+ self._logger.info(
+ f"{instance_id}: Orchestration completed with status: {completion_status_str}"
+ )
 
  actions = ctx.get_actions()
  if self._logger.level <= logging.DEBUG:
- self._logger.debug(f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}")
- return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status)
+ self._logger.debug(
+ f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}"
+ )
+ return ExecutionResults(
+ actions=actions, encoded_custom_status=ctx._encoded_custom_status
+ )
 
- def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
+ def process_event(
+ self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent
+ ) -> None:
  if self._is_suspended and _is_suspendable(event):
  # We are suspended, so we need to buffer this event until we are resumed
  self._suspended_events.append(event)
@@ -550,14 +956,19 @@ class _OrchestrationExecutor:
  fn = self._registry.get_orchestrator(event.executionStarted.name)
  if fn is None:
  raise OrchestratorNotRegisteredError(
- f"A '{event.executionStarted.name}' orchestrator was not registered.")
+ f"A '{event.executionStarted.name}' orchestrator was not registered."
+ )
 
  # deserialize the input, if any
  input = None
- if event.executionStarted.input is not None and event.executionStarted.input.value != "":
+ if (
+ event.executionStarted.input is not None and event.executionStarted.input.value != ""
+ ):
  input = shared.from_json(event.executionStarted.input.value)
 
- result = fn(ctx, input) # this does not execute the generator, only creates it
+ result = fn(
+ ctx, input
+ ) # this does not execute the generator, only creates it
  if isinstance(result, GeneratorType):
  # Start the orchestrator's generator function
  ctx.run(result)
@@ -570,10 +981,14 @@ class _OrchestrationExecutor:
  timer_id = event.eventId
  action = ctx._pending_actions.pop(timer_id, None)
  if not action:
- raise _get_non_determinism_error(timer_id, task.get_name(ctx.create_timer))
+ raise _get_non_determinism_error(
+ timer_id, task.get_name(ctx.create_timer)
+ )
  elif not action.HasField("createTimer"):
  expected_method_name = task.get_name(ctx.create_timer)
- raise _get_wrong_action_type_error(timer_id, expected_method_name, action)
+ raise _get_wrong_action_type_error(
+ timer_id, expected_method_name, action
+ )
  elif event.HasField("timerFired"):
  timer_id = event.timerFired.timerId
  timer_task = ctx._pending_tasks.pop(timer_id, None)
@@ -581,7 +996,8 @@ class _OrchestrationExecutor:
  # TODO: Should this be an error? When would it ever happen?
  if not ctx._is_replaying:
  self._logger.warning(
- f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}.")
+ f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
+ )
  return
  timer_task.complete(None)
  if timer_task._retryable_parent is not None:
@@ -593,12 +1009,15 @@ class _OrchestrationExecutor:
  else:
  cur_task = activity_action.createSubOrchestration
  instance_id = cur_task.instanceId
- ctx.call_activity_function_helper(id=activity_action.id, activity_function=cur_task.name,
- input=cur_task.input.value,
- retry_policy=timer_task._retryable_parent._retry_policy,
- is_sub_orch=timer_task._retryable_parent._is_sub_orch,
- instance_id=instance_id,
- fn_task=timer_task._retryable_parent)
+ ctx.call_activity_function_helper(
+ id=activity_action.id,
+ activity_function=cur_task.name,
+ input=cur_task.input.value,
+ retry_policy=timer_task._retryable_parent._retry_policy,
+ is_sub_orch=timer_task._retryable_parent._is_sub_orch,
+ instance_id=instance_id,
+ fn_task=timer_task._retryable_parent,
+ )
  else:
  ctx.resume()
  elif event.HasField("taskScheduled"):
@@ -608,16 +1027,21 @@ class _OrchestrationExecutor:
  action = ctx._pending_actions.pop(task_id, None)
  activity_task = ctx._pending_tasks.get(task_id, None)
  if not action:
- raise _get_non_determinism_error(task_id, task.get_name(ctx.call_activity))
+ raise _get_non_determinism_error(
+ task_id, task.get_name(ctx.call_activity)
+ )
  elif not action.HasField("scheduleTask"):
  expected_method_name = task.get_name(ctx.call_activity)
- raise _get_wrong_action_type_error(task_id, expected_method_name, action)
+ raise _get_wrong_action_type_error(
+ task_id, expected_method_name, action
+ )
  elif action.scheduleTask.name != event.taskScheduled.name:
  raise _get_wrong_action_name_error(
  task_id,
  method_name=task.get_name(ctx.call_activity),
  expected_task_name=event.taskScheduled.name,
- actual_task_name=action.scheduleTask.name)
+ actual_task_name=action.scheduleTask.name,
+ )
  elif event.HasField("taskCompleted"):
  # This history event contains the result of a completed activity task.
  task_id = event.taskCompleted.taskScheduledId
@@ -626,7 +1050,8 @@ class _OrchestrationExecutor:
  # TODO: Should this be an error? When would it ever happen?
  if not ctx.is_replaying:
  self._logger.warning(
- f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}.")
+ f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}."
+ )
  return
  result = None
  if not ph.is_empty(event.taskCompleted.result):
@@ -640,7 +1065,8 @@ class _OrchestrationExecutor:
  # TODO: Should this be an error? When would it ever happen?
  if not ctx.is_replaying:
  self._logger.warning(
- f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.")
+ f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}."
+ )
  return
 
  if isinstance(activity_task, task.RetryableTask):
@@ -649,7 +1075,8 @@ class _OrchestrationExecutor:
  if next_delay is None:
  activity_task.fail(
  f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
- event.taskFailed.failureDetails)
+ event.taskFailed.failureDetails,
+ )
  ctx.resume()
  else:
  activity_task.increment_attempt_count()
@@ -657,7 +1084,8 @@ class _OrchestrationExecutor:
  elif isinstance(activity_task, task.CompletableTask):
  activity_task.fail(
  f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
- event.taskFailed.failureDetails)
+ event.taskFailed.failureDetails,
+ )
  ctx.resume()
  else:
  raise TypeError("Unexpected task type")
@@ -667,16 +1095,23 @@ class _OrchestrationExecutor:
  task_id = event.eventId
  action = ctx._pending_actions.pop(task_id, None)
  if not action:
- raise _get_non_determinism_error(task_id, task.get_name(ctx.call_sub_orchestrator))
+ raise _get_non_determinism_error(
+ task_id, task.get_name(ctx.call_sub_orchestrator)
+ )
  elif not action.HasField("createSubOrchestration"):
  expected_method_name = task.get_name(ctx.call_sub_orchestrator)
- raise _get_wrong_action_type_error(task_id, expected_method_name, action)
- elif action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name:
+ raise _get_wrong_action_type_error(
+ task_id, expected_method_name, action
+ )
+ elif (
+ action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name
+ ):
  raise _get_wrong_action_name_error(
  task_id,
  method_name=task.get_name(ctx.call_sub_orchestrator),
  expected_task_name=event.subOrchestrationInstanceCreated.name,
- actual_task_name=action.createSubOrchestration.name)
+ actual_task_name=action.createSubOrchestration.name,
+ )
  elif event.HasField("subOrchestrationInstanceCompleted"):
  task_id = event.subOrchestrationInstanceCompleted.taskScheduledId
  sub_orch_task = ctx._pending_tasks.pop(task_id, None)
@@ -684,11 +1119,14 @@ class _OrchestrationExecutor:
  # TODO: Should this be an error? When would it ever happen?
  if not ctx.is_replaying:
  self._logger.warning(
- f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}.")
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}."
+ )
  return
  result = None
  if not ph.is_empty(event.subOrchestrationInstanceCompleted.result):
- result = shared.from_json(event.subOrchestrationInstanceCompleted.result.value)
+ result = shared.from_json(
+ event.subOrchestrationInstanceCompleted.result.value
+ )
  sub_orch_task.complete(result)
  ctx.resume()
  elif event.HasField("subOrchestrationInstanceFailed"):
@@ -699,7 +1137,8 @@ class _OrchestrationExecutor:
  # TODO: Should this be an error? When would it ever happen?
  if not ctx.is_replaying:
  self._logger.warning(
- f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}.")
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}."
+ )
  return
  if isinstance(sub_orch_task, task.RetryableTask):
  if sub_orch_task._retry_policy is not None:
@@ -707,7 +1146,8 @@ class _OrchestrationExecutor:
  if next_delay is None:
  sub_orch_task.fail(
  f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
- failedEvent.failureDetails)
+ failedEvent.failureDetails,
+ )
  ctx.resume()
  else:
  sub_orch_task.increment_attempt_count()
@@ -715,7 +1155,8 @@ class _OrchestrationExecutor:
  elif isinstance(sub_orch_task, task.CompletableTask):
  sub_orch_task.fail(
  f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
- failedEvent.failureDetails)
+ failedEvent.failureDetails,
+ )
  ctx.resume()
  else:
  raise TypeError("Unexpected sub-orchestration task type")
@@ -744,7 +1185,9 @@ class _OrchestrationExecutor:
  decoded_result = shared.from_json(event.eventRaised.input.value)
  event_list.append(decoded_result)
  if not ctx.is_replaying:
- self._logger.info(f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it.")
+ self._logger.info(
+ f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
+ )
  elif event.HasField("executionSuspended"):
  if not self._is_suspended and not ctx.is_replaying:
  self._logger.info(f"{ctx.instance_id}: Execution suspended.")
@@ -759,11 +1202,21 @@ class _OrchestrationExecutor:
  elif event.HasField("executionTerminated"):
  if not ctx.is_replaying:
  self._logger.info(f"{ctx.instance_id}: Execution terminating.")
- encoded_output = event.executionTerminated.input.value if not ph.is_empty(event.executionTerminated.input) else None
- ctx.set_complete(encoded_output, pb.ORCHESTRATION_STATUS_TERMINATED, is_result_encoded=True)
+ encoded_output = (
+ event.executionTerminated.input.value
+ if not ph.is_empty(event.executionTerminated.input)
+ else None
+ )
+ ctx.set_complete(
+ encoded_output,
+ pb.ORCHESTRATION_STATUS_TERMINATED,
+ is_result_encoded=True,
+ )
  else:
  eventType = event.WhichOneof("eventType")
- raise task.OrchestrationStateError(f"Don't know how to handle event of type '{eventType}'")
+ raise task.OrchestrationStateError(
+ f"Don't know how to handle event of type '{eventType}'"
+ )
  except StopIteration as generatorStopped:
  # The orchestrator generator function completed
  ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
@@ -774,12 +1227,22 @@ class _ActivityExecutor:
774
1227
  self._registry = registry
775
1228
  self._logger = logger
776
1229
 
777
- def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: Optional[str]) -> Optional[str]:
1230
+ def execute(
1231
+ self,
1232
+ orchestration_id: str,
1233
+ name: str,
1234
+ task_id: int,
1235
+ encoded_input: Optional[str],
1236
+ ) -> Optional[str]:
778
1237
  """Executes an activity function and returns the serialized result, if any."""
779
- self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...")
1238
+ self._logger.debug(
1239
+ f"{orchestration_id}/{task_id}: Executing activity '{name}'..."
1240
+ )
780
1241
  fn = self._registry.get_activity(name)
781
1242
  if not fn:
782
- raise ActivityNotRegisteredError(f"Activity function named '{name}' was not registered!")
1243
+ raise ActivityNotRegisteredError(
1244
+ f"Activity function named '{name}' was not registered!"
1245
+ )
783
1246
 
784
1247
  activity_input = shared.from_json(encoded_input) if encoded_input else None
785
1248
  ctx = task.ActivityContext(orchestration_id, task_id)
@@ -787,49 +1250,54 @@ class _ActivityExecutor:
787
1250
  # Execute the activity function
788
1251
  activity_output = fn(ctx, activity_input)
789
1252
 
790
- encoded_output = shared.to_json(activity_output) if activity_output is not None else None
1253
+ encoded_output = (
1254
+ shared.to_json(activity_output) if activity_output is not None else None
1255
+ )
791
1256
  chars = len(encoded_output) if encoded_output else 0
792
1257
  self._logger.debug(
793
- f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output.")
1258
+ f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output."
1259
+ )
794
1260
  return encoded_output
795
1261
 
796
1262
 
797
- def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError:
1263
+ def _get_non_determinism_error(
1264
+ task_id: int, action_name: str
1265
+ ) -> task.NonDeterminismError:
798
1266
  return task.NonDeterminismError(
799
1267
  f"A previous execution called {action_name} with ID={task_id}, but the current "
800
1268
  f"execution doesn't have this action with this ID. This problem occurs when either "
801
1269
  f"the orchestration has non-deterministic logic or if the code was changed after an "
802
- f"instance of this orchestration already started running.")
1270
+ f"instance of this orchestration already started running."
1271
+ )
803
1272
 
804
1273
 
805
1274
  def _get_wrong_action_type_error(
806
- task_id: int,
807
- expected_method_name: str,
808
- action: pb.OrchestratorAction) -> task.NonDeterminismError:
1275
+ task_id: int, expected_method_name: str, action: pb.OrchestratorAction
1276
+ ) -> task.NonDeterminismError:
809
1277
  unexpected_method_name = _get_method_name_for_action(action)
810
1278
  return task.NonDeterminismError(
811
1279
  f"Failed to restore orchestration state due to a history mismatch: A previous execution called "
812
1280
  f"{expected_method_name} with ID={task_id}, but the current execution is instead trying to call "
813
1281
  f"{unexpected_method_name} as part of rebuilding it's history. This kind of mismatch can happen if an "
814
1282
  f"orchestration has non-deterministic logic or if the code was changed after an instance of this "
815
- f"orchestration already started running.")
1283
+ f"orchestration already started running."
1284
+ )
816
1285
 
817
1286
 
818
1287
  def _get_wrong_action_name_error(
819
- task_id: int,
820
- method_name: str,
821
- expected_task_name: str,
822
- actual_task_name: str) -> task.NonDeterminismError:
1288
+ task_id: int, method_name: str, expected_task_name: str, actual_task_name: str
1289
+ ) -> task.NonDeterminismError:
823
1290
  return task.NonDeterminismError(
824
1291
  f"Failed to restore orchestration state due to a history mismatch: A previous execution called "
825
1292
  f"{method_name} with name='{expected_task_name}' and sequence number {task_id}, but the current "
826
1293
  f"execution is instead trying to call {actual_task_name} as part of rebuilding it's history. "
827
1294
  f"This kind of mismatch can happen if an orchestration has non-deterministic logic or if the code "
828
- f"was changed after an instance of this orchestration already started running.")
1295
+ f"was changed after an instance of this orchestration already started running."
1296
+ )
829
1297
 
830
1298
 
831
1299
  def _get_method_name_for_action(action: pb.OrchestratorAction) -> str:
832
- action_type = action.WhichOneof('orchestratorActionType')
1300
+ action_type = action.WhichOneof("orchestratorActionType")
833
1301
  if action_type == "scheduleTask":
834
1302
  return task.get_name(task.OrchestrationContext.call_activity)
835
1303
  elif action_type == "createTimer":
@@ -851,7 +1319,7 @@ def _get_new_event_summary(new_events: Sequence[pb.HistoryEvent]) -> str:
     else:
         counts: dict[str, int] = {}
         for event in new_events:
-            event_type = event.WhichOneof('eventType')
+            event_type = event.WhichOneof("eventType")
             counts[event_type] = counts.get(event_type, 0) + 1
         return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]"

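
Concretely, a batch containing two taskCompleted events and one timerFired event produces the summary string "[taskCompleted=2, timerFired=1]"; the next hunk applies the same counting shape to orchestrator actions.
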
@@ -865,11 +1333,210 @@ def _get_action_summary(new_actions: Sequence[pb.OrchestratorAction]) -> str:
     else:
         counts: dict[str, int] = {}
         for action in new_actions:
-            action_type = action.WhichOneof('orchestratorActionType')
+            action_type = action.WhichOneof("orchestratorActionType")
             counts[action_type] = counts.get(action_type, 0) + 1
         return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]"


 def _is_suspendable(event: pb.HistoryEvent) -> bool:
     """Returns true if the event is one that can be suspended and resumed."""
-    return event.WhichOneof("eventType") not in ["executionResumed", "executionTerminated"]
+    return event.WhichOneof("eventType") not in [
+        "executionResumed",
+        "executionTerminated",
+    ]
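
The class added below is essentially two copies of a standard asyncio idiom: a queue per work-item type, drained by a consumer loop that spawns one task per item, with an asyncio.Semaphore capping how many items execute at once. Stripped of the queue-rebinding and shutdown bookkeeping, the core pattern looks roughly like this (an illustrative sketch, not code from the package; bounded_consumer and coro_factory are invented names):

    import asyncio

    async def bounded_consumer(queue: asyncio.Queue, limit: int) -> None:
        # Gate execution, not dequeueing: items are pulled eagerly, but at
        # most `limit` of them run at the same time.
        semaphore = asyncio.Semaphore(limit)

        async def run_one(coro_factory) -> None:
            async with semaphore:
                try:
                    await coro_factory()
                finally:
                    queue.task_done()

        while True:
            coro_factory = await queue.get()
            asyncio.create_task(run_one(coro_factory))

Acquiring the semaphore inside the spawned task rather than before dequeueing means the queue drains eagerly while execution stays bounded, which is the same choice the class makes.
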
+
+
+class _AsyncWorkerManager:
+    def __init__(self, concurrency_options: ConcurrencyOptions):
+        self.concurrency_options = concurrency_options
+        self.activity_semaphore = None
+        self.orchestration_semaphore = None
+        # Don't create queues here - defer until we have an event loop
+        self.activity_queue: Optional[asyncio.Queue] = None
+        self.orchestration_queue: Optional[asyncio.Queue] = None
+        self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None
+        # Store work items when no event loop is available
+        self._pending_activity_work: list = []
+        self._pending_orchestration_work: list = []
+        self.thread_pool = ThreadPoolExecutor(
+            max_workers=concurrency_options.maximum_thread_pool_workers,
+            thread_name_prefix="DurableTask",
+        )
+        self._shutdown = False
+
+    def _ensure_queues_for_current_loop(self):
+        """Ensure queues are bound to the current event loop."""
+        try:
+            current_loop = asyncio.get_running_loop()
+        except RuntimeError:
+            # No event loop running, can't create queues
+            return
+
+        # Check if queues are already properly set up for current loop
+        if self._queue_event_loop is current_loop:
+            if self.activity_queue is not None and self.orchestration_queue is not None:
+                # Queues are already bound to the current loop and exist
+                return
+
+        # Need to recreate queues for the current event loop
+        # First, preserve any existing work items
+        existing_activity_items = []
+        existing_orchestration_items = []
+
+        if self.activity_queue is not None:
+            try:
+                while not self.activity_queue.empty():
+                    existing_activity_items.append(self.activity_queue.get_nowait())
+            except Exception:
+                pass
+
+        if self.orchestration_queue is not None:
+            try:
+                while not self.orchestration_queue.empty():
+                    existing_orchestration_items.append(
+                        self.orchestration_queue.get_nowait()
+                    )
+            except Exception:
+                pass
+
+        # Create fresh queues for the current event loop
+        self.activity_queue = asyncio.Queue()
+        self.orchestration_queue = asyncio.Queue()
+        self._queue_event_loop = current_loop
+
+        # Restore the work items to the new queues
+        for item in existing_activity_items:
+            self.activity_queue.put_nowait(item)
+        for item in existing_orchestration_items:
+            self.orchestration_queue.put_nowait(item)
+
+        # Move pending work items to the queues
+        for item in self._pending_activity_work:
+            self.activity_queue.put_nowait(item)
+        for item in self._pending_orchestration_work:
+            self.orchestration_queue.put_nowait(item)
+
+        # Clear the pending work lists
+        self._pending_activity_work.clear()
+        self._pending_orchestration_work.clear()
+
+    async def run(self):
+        # Reset shutdown flag in case this manager is being reused
+        self._shutdown = False
+
+        # Ensure queues are properly bound to the current event loop
+        self._ensure_queues_for_current_loop()
+
+        # Create semaphores in the current event loop
+        self.activity_semaphore = asyncio.Semaphore(
+            self.concurrency_options.maximum_concurrent_activity_work_items
+        )
+        self.orchestration_semaphore = asyncio.Semaphore(
+            self.concurrency_options.maximum_concurrent_orchestration_work_items
+        )
+
+        # Start background consumers for each work type
+        if self.activity_queue is not None and self.orchestration_queue is not None:
+            await asyncio.gather(
+                self._consume_queue(self.activity_queue, self.activity_semaphore),
+                self._consume_queue(
+                    self.orchestration_queue, self.orchestration_semaphore
+                ),
+            )
+
+    async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
+        # Set of in-flight tasks (note: the local name `task` below shadows
+        # the imported durabletask `task` module within this method)
+        running_tasks: set[asyncio.Task] = set()
+
+        while True:
+            # Clean up completed tasks
+            done_tasks = {task for task in running_tasks if task.done()}
+            running_tasks -= done_tasks
+
+            # Exit if shutdown is set and the queue is empty and no tasks are running
+            if self._shutdown and queue.empty() and not running_tasks:
+                break
+
+            try:
+                work = await asyncio.wait_for(queue.get(), timeout=1.0)
+            except asyncio.TimeoutError:
+                continue
+
+            func, args, kwargs = work
+            # Create a concurrent task for processing
+            task = asyncio.create_task(
+                self._process_work_item(semaphore, queue, func, args, kwargs)
+            )
+            running_tasks.add(task)
+
+    async def _process_work_item(
+        self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs
+    ):
+        async with semaphore:
+            try:
+                await self._run_func(func, *args, **kwargs)
+            finally:
+                queue.task_done()
+
+    async def _run_func(self, func, *args, **kwargs):
+        if inspect.iscoroutinefunction(func):
+            return await func(*args, **kwargs)
+        else:
+            loop = asyncio.get_running_loop()
+            # Avoid submitting to the executor after shutdown; this probes the
+            # executor's private _shutdown flag, hence the getattr guards.
+            if (
+                getattr(self, "_shutdown", False)
+                and getattr(self, "thread_pool", None)
+                and getattr(self.thread_pool, "_shutdown", False)
+            ):
+                return None
+            return await loop.run_in_executor(
+                self.thread_pool, lambda: func(*args, **kwargs)
+            )
+
+    def submit_activity(self, func, *args, **kwargs):
+        work_item = (func, args, kwargs)
+        self._ensure_queues_for_current_loop()
+        if self.activity_queue is not None:
+            self.activity_queue.put_nowait(work_item)
+        else:
+            # No event loop running, store in pending list
+            self._pending_activity_work.append(work_item)
+
+    def submit_orchestration(self, func, *args, **kwargs):
+        work_item = (func, args, kwargs)
+        self._ensure_queues_for_current_loop()
+        if self.orchestration_queue is not None:
+            self.orchestration_queue.put_nowait(work_item)
+        else:
+            # No event loop running, store in pending list
+            self._pending_orchestration_work.append(work_item)
+
+    def shutdown(self):
+        self._shutdown = True
+        self.thread_pool.shutdown(wait=True)
+
+    def reset_for_new_run(self):
+        """Reset the manager state for a new run."""
+        self._shutdown = False
+        # Drain any existing queues so no items from previous runs remain;
+        # the queues themselves are rebound to the new event loop on demand.
+        if self.activity_queue is not None:
+            try:
+                while not self.activity_queue.empty():
+                    self.activity_queue.get_nowait()
+            except Exception:
+                pass
+        if self.orchestration_queue is not None:
+            try:
+                while not self.orchestration_queue.empty():
+                    self.orchestration_queue.get_nowait()
+            except Exception:
+                pass
+        # Clear pending work lists
+        self._pending_activity_work.clear()
+        self._pending_orchestration_work.clear()
+
+
+# Export public API
+__all__ = ["ConcurrencyOptions", "TaskHubGrpcWorker"]
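
Together with this export list, the evident intent is that a ConcurrencyOptions instance is handed to the worker at construction time. A minimal usage sketch, assuming the worker accepts a concurrency_options keyword argument (the constructor signature is not visible in this hunk):

    from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker

    # Allow many concurrent activities while bounding orchestration replay
    # more tightly; the thread pool size is left at its default.
    options = ConcurrencyOptions(
        maximum_concurrent_activity_work_items=200,
        maximum_concurrent_orchestration_work_items=50,
    )

    worker = TaskHubGrpcWorker(concurrency_options=options)  # assumed kwarg
    worker.start()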