durabletask 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of durabletask might be problematic. Click here for more details.

durabletask/worker.py CHANGED
@@ -1,32 +1,130 @@
1
1
  # Copyright (c) Microsoft Corporation.
2
2
  # Licensed under the MIT License.
3
3
 
4
- import concurrent.futures
4
+ import asyncio
5
+ import inspect
5
6
  import logging
7
+ import os
8
+ import random
9
+ from concurrent.futures import ThreadPoolExecutor
6
10
  from datetime import datetime, timedelta
7
11
  from threading import Event, Thread
8
12
  from types import GeneratorType
13
+ from enum import Enum
9
14
  from typing import Any, Generator, Optional, Sequence, TypeVar, Union
15
+ from packaging.version import InvalidVersion, parse
10
16
 
11
17
  import grpc
12
18
  from google.protobuf import empty_pb2
13
19
 
14
20
  import durabletask.internal.helpers as ph
15
- import durabletask.internal.helpers as pbh
21
+ import durabletask.internal.exceptions as pe
16
22
  import durabletask.internal.orchestrator_service_pb2 as pb
17
23
  import durabletask.internal.orchestrator_service_pb2_grpc as stubs
18
24
  import durabletask.internal.shared as shared
19
25
  from durabletask import task
20
26
  from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
21
27
 
22
- TInput = TypeVar('TInput')
23
- TOutput = TypeVar('TOutput')
28
+ TInput = TypeVar("TInput")
29
+ TOutput = TypeVar("TOutput")
30
+
31
+
32
+ class ConcurrencyOptions:
33
+ """Configuration options for controlling concurrency of different work item types and the thread pool size.
34
+
35
+ This class provides fine-grained control over concurrent processing limits for
36
+ activities, orchestrations and the thread pool size.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ maximum_concurrent_activity_work_items: Optional[int] = None,
42
+ maximum_concurrent_orchestration_work_items: Optional[int] = None,
43
+ maximum_thread_pool_workers: Optional[int] = None,
44
+ ):
45
+ """Initialize concurrency options.
46
+
47
+ Args:
48
+ maximum_concurrent_activity_work_items: Maximum number of activity work items
49
+ that can be processed concurrently. Defaults to 100 * processor_count.
50
+ maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items
51
+ that can be processed concurrently. Defaults to 100 * processor_count.
52
+ maximum_thread_pool_workers: Maximum number of thread pool workers to use.
53
+ """
54
+ processor_count = os.cpu_count() or 1
55
+ default_concurrency = 100 * processor_count
56
+ # see https://docs.python.org/3/library/concurrent.futures.html
57
+ default_max_workers = processor_count + 4
58
+
59
+ self.maximum_concurrent_activity_work_items = (
60
+ maximum_concurrent_activity_work_items
61
+ if maximum_concurrent_activity_work_items is not None
62
+ else default_concurrency
63
+ )
24
64
 
65
+ self.maximum_concurrent_orchestration_work_items = (
66
+ maximum_concurrent_orchestration_work_items
67
+ if maximum_concurrent_orchestration_work_items is not None
68
+ else default_concurrency
69
+ )
25
70
 
26
- class _Registry:
71
+ self.maximum_thread_pool_workers = (
72
+ maximum_thread_pool_workers
73
+ if maximum_thread_pool_workers is not None
74
+ else default_max_workers
75
+ )
76
+
77
+
78
+ class VersionMatchStrategy(Enum):
79
+ """Enumeration for version matching strategies."""
80
+
81
+ NONE = 1
82
+ STRICT = 2
83
+ CURRENT_OR_OLDER = 3
84
+
85
+
86
+ class VersionFailureStrategy(Enum):
87
+ """Enumeration for version failure strategies."""
88
+
89
+ REJECT = 1
90
+ FAIL = 2
91
+
92
+
93
+ class VersioningOptions:
94
+ """Configuration options for orchestrator and activity versioning.
27
95
 
96
+ This class provides options to control how versioning is handled for orchestrators
97
+ and activities, including whether to use the default version and how to compare versions.
98
+ """
99
+
100
+ version: Optional[str] = None
101
+ default_version: Optional[str] = None
102
+ match_strategy: Optional[VersionMatchStrategy] = None
103
+ failure_strategy: Optional[VersionFailureStrategy] = None
104
+
105
+ def __init__(self, version: Optional[str] = None,
106
+ default_version: Optional[str] = None,
107
+ match_strategy: Optional[VersionMatchStrategy] = None,
108
+ failure_strategy: Optional[VersionFailureStrategy] = None
109
+ ):
110
+ """Initialize versioning options.
111
+
112
+ Args:
113
+ version: The version of orchestrations that the worker can work on.
114
+ default_version: The default version that will be used for starting new orchestrations.
115
+ match_strategy: The versioning strategy for the Durable Task worker.
116
+ failure_strategy: The versioning failure strategy for the Durable Task worker.
117
+ """
118
+ self.version = version
119
+ self.default_version = default_version
120
+ self.match_strategy = match_strategy
121
+ self.failure_strategy = failure_strategy
122
+
123
+
124
+ class _Registry:
28
125
  orchestrators: dict[str, task.Orchestrator]
29
126
  activities: dict[str, task.Activity]
127
+ versioning: Optional[VersioningOptions] = None
30
128
 
31
129
  def __init__(self):
32
130
  self.orchestrators = {}
@@ -34,7 +132,7 @@ class _Registry:
34
132
 
35
133
  def add_orchestrator(self, fn: task.Orchestrator) -> str:
36
134
  if fn is None:
37
- raise ValueError('An orchestrator function argument is required.')
135
+ raise ValueError("An orchestrator function argument is required.")
38
136
 
39
137
  name = task.get_name(fn)
40
138
  self.add_named_orchestrator(name, fn)
@@ -42,7 +140,7 @@ class _Registry:
42
140
 
43
141
  def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
44
142
  if not name:
45
- raise ValueError('A non-empty orchestrator name is required.')
143
+ raise ValueError("A non-empty orchestrator name is required.")
46
144
  if name in self.orchestrators:
47
145
  raise ValueError(f"A '{name}' orchestrator already exists.")
48
146
 
@@ -53,7 +151,7 @@ class _Registry:
53
151
 
54
152
  def add_activity(self, fn: task.Activity) -> str:
55
153
  if fn is None:
56
- raise ValueError('An activity function argument is required.')
154
+ raise ValueError("An activity function argument is required.")
57
155
 
58
156
  name = task.get_name(fn)
59
157
  self.add_named_activity(name, fn)
@@ -61,7 +159,7 @@ class _Registry:
61
159
 
62
160
  def add_named_activity(self, name: str, fn: task.Activity) -> None:
63
161
  if not name:
64
- raise ValueError('A non-empty activity name is required.')
162
+ raise ValueError("A non-empty activity name is required.")
65
163
  if name in self.activities:
66
164
  raise ValueError(f"A '{name}' activity already exists.")
67
165
 
@@ -73,32 +171,125 @@ class _Registry:
73
171
 
74
172
  class OrchestratorNotRegisteredError(ValueError):
75
173
  """Raised when attempting to start an orchestration that is not registered"""
174
+
76
175
  pass
77
176
 
78
177
 
79
178
  class ActivityNotRegisteredError(ValueError):
80
179
  """Raised when attempting to call an activity that is not registered"""
180
+
81
181
  pass
82
182
 
83
183
 
84
184
  class TaskHubGrpcWorker:
185
+ """A gRPC-based worker for processing durable task orchestrations and activities.
186
+
187
+ This worker connects to a Durable Task backend service via gRPC to receive and process
188
+ work items including orchestration functions and activity functions. It provides
189
+ concurrent execution capabilities with configurable limits and automatic retry handling.
190
+
191
+ The worker manages the complete lifecycle:
192
+ - Registers orchestrator and activity functions
193
+ - Connects to the gRPC backend service
194
+ - Receives work items and executes them concurrently
195
+ - Handles failures, retries, and state management
196
+ - Provides logging and monitoring capabilities
197
+
198
+ Args:
199
+ host_address (Optional[str], optional): The gRPC endpoint address of the backend service.
200
+ Defaults to the value from environment variables or localhost.
201
+ metadata (Optional[list[tuple[str, str]]], optional): gRPC metadata to include with
202
+ requests. Used for authentication and routing. Defaults to None.
203
+ log_handler (optional): Custom logging handler for worker logs. Defaults to None.
204
+ log_formatter (Optional[logging.Formatter], optional): Custom log formatter.
205
+ Defaults to None.
206
+ secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
207
+ Defaults to False.
208
+ interceptors (Optional[Sequence[shared.ClientInterceptor]], optional): Custom gRPC
209
+ interceptors to apply to the channel. Defaults to None.
210
+ concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for
211
+ controlling worker concurrency limits. If None, default settings are used.
212
+
213
+ Attributes:
214
+ concurrency_options (ConcurrencyOptions): The current concurrency configuration.
215
+
216
+ Example:
217
+ Basic worker setup:
218
+
219
+ >>> from durabletask.worker import TaskHubGrpcWorker, ConcurrencyOptions
220
+ >>>
221
+ >>> # Create worker with custom concurrency settings
222
+ >>> concurrency = ConcurrencyOptions(
223
+ ... maximum_concurrent_activity_work_items=50,
224
+ ... maximum_concurrent_orchestration_work_items=20
225
+ ... )
226
+ >>> worker = TaskHubGrpcWorker(
227
+ ... host_address="localhost:4001",
228
+ ... concurrency_options=concurrency
229
+ ... )
230
+ >>>
231
+ >>> # Register functions
232
+ >>> @worker.add_orchestrator
233
+ ... def my_orchestrator(context, input):
234
+ ... result = yield context.call_activity("my_activity", input="hello")
235
+ ... return result
236
+ >>>
237
+ >>> @worker.add_activity
238
+ ... def my_activity(context, input):
239
+ ... return f"Processed: {input}"
240
+ >>>
241
+ >>> # Start the worker
242
+ >>> worker.start()
243
+ >>> # ... worker runs in background thread
244
+ >>> worker.stop()
245
+
246
+ Using as context manager:
247
+
248
+ >>> with TaskHubGrpcWorker() as worker:
249
+ ... worker.add_orchestrator(my_orchestrator)
250
+ ... worker.add_activity(my_activity)
251
+ ... worker.start()
252
+ ... # Worker automatically stops when exiting context
253
+
254
+ Raises:
255
+ RuntimeError: If attempting to add orchestrators/activities while the worker is running,
256
+ or if starting a worker that is already running.
257
+ OrchestratorNotRegisteredError: If an orchestration work item references an
258
+ unregistered orchestrator function.
259
+ ActivityNotRegisteredError: If an activity work item references an unregistered
260
+ activity function.
261
+ """
262
+
85
263
  _response_stream: Optional[grpc.Future] = None
86
264
  _interceptors: Optional[list[shared.ClientInterceptor]] = None
87
265
 
88
- def __init__(self, *,
89
- host_address: Optional[str] = None,
90
- metadata: Optional[list[tuple[str, str]]] = None,
91
- log_handler=None,
92
- log_formatter: Optional[logging.Formatter] = None,
93
- secure_channel: bool = False,
94
- interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
266
+ def __init__(
267
+ self,
268
+ *,
269
+ host_address: Optional[str] = None,
270
+ metadata: Optional[list[tuple[str, str]]] = None,
271
+ log_handler=None,
272
+ log_formatter: Optional[logging.Formatter] = None,
273
+ secure_channel: bool = False,
274
+ interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
275
+ concurrency_options: Optional[ConcurrencyOptions] = None,
276
+ ):
95
277
  self._registry = _Registry()
96
- self._host_address = host_address if host_address else shared.get_default_host_address()
278
+ self._host_address = (
279
+ host_address if host_address else shared.get_default_host_address()
280
+ )
97
281
  self._logger = shared.get_logger("worker", log_handler, log_formatter)
98
282
  self._shutdown = Event()
99
283
  self._is_running = False
100
284
  self._secure_channel = secure_channel
101
285
 
286
+ # Use provided concurrency options or create default ones
287
+ self._concurrency_options = (
288
+ concurrency_options
289
+ if concurrency_options is not None
290
+ else ConcurrencyOptions()
291
+ )
292
+
102
293
  # Determine the interceptors to use
103
294
  if interceptors is not None:
104
295
  self._interceptors = list(interceptors)
@@ -109,6 +300,13 @@ class TaskHubGrpcWorker:
109
300
  else:
110
301
  self._interceptors = None
111
302
 
303
+ self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
304
+
305
+ @property
306
+ def concurrency_options(self) -> ConcurrencyOptions:
307
+ """Get the current concurrency options for this worker."""
308
+ return self._concurrency_options
309
+
112
310
  def __enter__(self):
113
311
  return self
114
312
 
@@ -118,72 +316,229 @@ class TaskHubGrpcWorker:
118
316
  def add_orchestrator(self, fn: task.Orchestrator) -> str:
119
317
  """Registers an orchestrator function with the worker."""
120
318
  if self._is_running:
121
- raise RuntimeError('Orchestrators cannot be added while the worker is running.')
319
+ raise RuntimeError(
320
+ "Orchestrators cannot be added while the worker is running."
321
+ )
122
322
  return self._registry.add_orchestrator(fn)
123
323
 
124
324
  def add_activity(self, fn: task.Activity) -> str:
125
325
  """Registers an activity function with the worker."""
126
326
  if self._is_running:
127
- raise RuntimeError('Activities cannot be added while the worker is running.')
327
+ raise RuntimeError(
328
+ "Activities cannot be added while the worker is running."
329
+ )
128
330
  return self._registry.add_activity(fn)
129
331
 
332
+ def use_versioning(self, version: VersioningOptions) -> None:
333
+ """Initializes versioning options for sub-orchestrators and activities."""
334
+ if self._is_running:
335
+ raise RuntimeError("Cannot set default version while the worker is running.")
336
+ self._registry.versioning = version
337
+
130
338
  def start(self):
131
339
  """Starts the worker on a background thread and begins listening for work items."""
132
- channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors)
133
- stub = stubs.TaskHubSidecarServiceStub(channel)
134
-
135
340
  if self._is_running:
136
- raise RuntimeError('The worker is already running.')
341
+ raise RuntimeError("The worker is already running.")
137
342
 
138
343
  def run_loop():
139
- # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
140
- # functions. We'd need to know ahead of time whether a function is async or not.
141
- # TODO: Max concurrency configuration settings
142
- with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
143
- while not self._shutdown.is_set():
144
- try:
145
- # send a "Hello" message to the sidecar to ensure that it's listening
146
- stub.Hello(empty_pb2.Empty())
147
-
148
- # stream work items
149
- self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
150
- self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')
151
-
152
- # The stream blocks until either a work item is received or the stream is canceled
153
- # by another thread (see the stop() method).
154
- for work_item in self._response_stream: # type: ignore
155
- request_type = work_item.WhichOneof('request')
156
- self._logger.debug(f'Received "{request_type}" work item')
157
- if work_item.HasField('orchestratorRequest'):
158
- executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken)
159
- elif work_item.HasField('activityRequest'):
160
- executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken)
161
- elif work_item.HasField('healthPing'):
162
- pass # no-op
163
- else:
164
- self._logger.warning(f'Unexpected work item type: {request_type}')
165
-
166
- except grpc.RpcError as rpc_error:
167
- if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
168
- self._logger.info(f'Disconnected from {self._host_address}')
169
- elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
170
- self._logger.warning(
171
- f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
172
- else:
173
- self._logger.warning(f'Unexpected error: {rpc_error}')
174
- except Exception as ex:
175
- self._logger.warning(f'Unexpected error: {ex}')
176
-
177
- # CONSIDER: exponential backoff
178
- self._shutdown.wait(5)
179
- self._logger.info("No longer listening for work items")
180
- return
344
+ loop = asyncio.new_event_loop()
345
+ asyncio.set_event_loop(loop)
346
+ loop.run_until_complete(self._async_run_loop())
181
347
 
182
348
  self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
183
349
  self._runLoop = Thread(target=run_loop)
184
350
  self._runLoop.start()
185
351
  self._is_running = True
186
352
 
353
+ async def _async_run_loop(self):
354
+ worker_task = asyncio.create_task(self._async_worker_manager.run())
355
+ # Connection state management for retry fix
356
+ current_channel = None
357
+ current_stub = None
358
+ current_reader_thread = None
359
+ conn_retry_count = 0
360
+ conn_max_retry_delay = 60
361
+
362
+ def create_fresh_connection():
363
+ nonlocal current_channel, current_stub, conn_retry_count
364
+ if current_channel:
365
+ try:
366
+ current_channel.close()
367
+ except Exception:
368
+ pass
369
+ current_channel = None
370
+ current_stub = None
371
+ try:
372
+ current_channel = shared.get_grpc_channel(
373
+ self._host_address, self._secure_channel, self._interceptors
374
+ )
375
+ current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
376
+ current_stub.Hello(empty_pb2.Empty())
377
+ conn_retry_count = 0
378
+ self._logger.info(f"Created fresh connection to {self._host_address}")
379
+ except Exception as e:
380
+ self._logger.warning(f"Failed to create connection: {e}")
381
+ current_channel = None
382
+ current_stub = None
383
+ raise
384
+
385
+ def invalidate_connection():
386
+ nonlocal current_channel, current_stub, current_reader_thread
387
+ # Cancel the response stream first to signal the reader thread to stop
388
+ if self._response_stream is not None:
389
+ try:
390
+ self._response_stream.cancel()
391
+ except Exception:
392
+ pass
393
+ self._response_stream = None
394
+
395
+ # Wait for the reader thread to finish
396
+ if current_reader_thread is not None:
397
+ try:
398
+ current_reader_thread.join(timeout=2)
399
+ if current_reader_thread.is_alive():
400
+ self._logger.warning("Stream reader thread did not shut down gracefully")
401
+ except Exception:
402
+ pass
403
+ current_reader_thread = None
404
+
405
+ # Close the channel
406
+ if current_channel:
407
+ try:
408
+ current_channel.close()
409
+ except Exception:
410
+ pass
411
+ current_channel = None
412
+ current_stub = None
413
+
414
+ def should_invalidate_connection(rpc_error):
415
+ error_code = rpc_error.code() # type: ignore
416
+ connection_level_errors = {
417
+ grpc.StatusCode.UNAVAILABLE,
418
+ grpc.StatusCode.DEADLINE_EXCEEDED,
419
+ grpc.StatusCode.CANCELLED,
420
+ grpc.StatusCode.UNAUTHENTICATED,
421
+ grpc.StatusCode.ABORTED,
422
+ }
423
+ return error_code in connection_level_errors
424
+
425
+ while not self._shutdown.is_set():
426
+ if current_stub is None:
427
+ try:
428
+ create_fresh_connection()
429
+ except Exception:
430
+ conn_retry_count += 1
431
+ delay = min(
432
+ conn_max_retry_delay,
433
+ (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1),
434
+ )
435
+ self._logger.warning(
436
+ f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})"
437
+ )
438
+ if self._shutdown.wait(delay):
439
+ break
440
+ continue
441
+ try:
442
+ assert current_stub is not None
443
+ stub = current_stub
444
+ get_work_items_request = pb.GetWorkItemsRequest(
445
+ maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items,
446
+ maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
447
+ )
448
+ self._response_stream = stub.GetWorkItems(get_work_items_request)
449
+ self._logger.info(
450
+ f"Successfully connected to {self._host_address}. Waiting for work items..."
451
+ )
452
+
453
+ # Use a thread to read from the blocking gRPC stream and forward to asyncio
454
+ import queue
455
+
456
+ work_item_queue = queue.Queue()
457
+
458
+ def stream_reader():
459
+ try:
460
+ for work_item in self._response_stream:
461
+ work_item_queue.put(work_item)
462
+ except Exception as e:
463
+ work_item_queue.put(e)
464
+
465
+ import threading
466
+
467
+ current_reader_thread = threading.Thread(target=stream_reader, daemon=True)
468
+ current_reader_thread.start()
469
+ loop = asyncio.get_running_loop()
470
+ while not self._shutdown.is_set():
471
+ try:
472
+ work_item = await loop.run_in_executor(
473
+ None, work_item_queue.get
474
+ )
475
+ if isinstance(work_item, Exception):
476
+ raise work_item
477
+ request_type = work_item.WhichOneof("request")
478
+ self._logger.debug(f'Received "{request_type}" work item')
479
+ if work_item.HasField("orchestratorRequest"):
480
+ self._async_worker_manager.submit_orchestration(
481
+ self._execute_orchestrator,
482
+ work_item.orchestratorRequest,
483
+ stub,
484
+ work_item.completionToken,
485
+ )
486
+ elif work_item.HasField("activityRequest"):
487
+ self._async_worker_manager.submit_activity(
488
+ self._execute_activity,
489
+ work_item.activityRequest,
490
+ stub,
491
+ work_item.completionToken,
492
+ )
493
+ elif work_item.HasField("healthPing"):
494
+ pass
495
+ else:
496
+ self._logger.warning(
497
+ f"Unexpected work item type: {request_type}"
498
+ )
499
+ except Exception as e:
500
+ self._logger.warning(f"Error in work item stream: {e}")
501
+ raise e
502
+ current_reader_thread.join(timeout=1)
503
+ self._logger.info("Work item stream ended normally")
504
+ except grpc.RpcError as rpc_error:
505
+ should_invalidate = should_invalidate_connection(rpc_error)
506
+ if should_invalidate:
507
+ invalidate_connection()
508
+ error_code = rpc_error.code() # type: ignore
509
+ error_details = str(rpc_error)
510
+
511
+ if error_code == grpc.StatusCode.CANCELLED:
512
+ self._logger.info(f"Disconnected from {self._host_address}")
513
+ break
514
+ elif error_code == grpc.StatusCode.UNAVAILABLE:
515
+ # Check if this is a connection timeout scenario
516
+ if "Timeout occurred" in error_details or "Failed to connect to remote host" in error_details:
517
+ self._logger.warning(
518
+ f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection"
519
+ )
520
+ else:
521
+ self._logger.warning(
522
+ f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying"
523
+ )
524
+ elif should_invalidate:
525
+ self._logger.warning(
526
+ f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection"
527
+ )
528
+ else:
529
+ self._logger.warning(
530
+ f"Application-level gRPC error ({error_code}): {rpc_error}"
531
+ )
532
+ self._shutdown.wait(1)
533
+ except Exception as ex:
534
+ invalidate_connection()
535
+ self._logger.warning(f"Unexpected error: {ex}")
536
+ self._shutdown.wait(1)
537
+ invalidate_connection()
538
+ self._logger.info("No longer listening for work items")
539
+ self._async_worker_manager.shutdown()
540
+ await worker_task
541
+
187
542
  def stop(self):
188
543
  """Stops the worker and waits for any pending work items to complete."""
189
544
  if not self._is_running:
@@ -195,58 +550,97 @@ class TaskHubGrpcWorker:
195
550
  self._response_stream.cancel()
196
551
  if self._runLoop is not None:
197
552
  self._runLoop.join(timeout=30)
553
+ self._async_worker_manager.shutdown()
198
554
  self._logger.info("Worker shutdown completed")
199
555
  self._is_running = False
200
556
 
201
- def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken):
557
+ def _execute_orchestrator(
558
+ self,
559
+ req: pb.OrchestratorRequest,
560
+ stub: stubs.TaskHubSidecarServiceStub,
561
+ completionToken,
562
+ ):
202
563
  try:
203
564
  executor = _OrchestrationExecutor(self._registry, self._logger)
204
565
  result = executor.execute(req.instanceId, req.pastEvents, req.newEvents)
205
566
  res = pb.OrchestratorResponse(
206
567
  instanceId=req.instanceId,
207
568
  actions=result.actions,
208
- customStatus=pbh.get_string_value(result.encoded_custom_status),
209
- completionToken=completionToken)
569
+ customStatus=ph.get_string_value(result.encoded_custom_status),
570
+ completionToken=completionToken,
571
+ )
572
+ except pe.AbandonOrchestrationError:
573
+ self._logger.info(
574
+ f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'"
575
+ )
576
+ stub.AbandonTaskOrchestratorWorkItem(
577
+ pb.AbandonOrchestrationTaskRequest(
578
+ completionToken=completionToken
579
+ )
580
+ )
581
+ return
210
582
  except Exception as ex:
211
- self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}")
212
- failure_details = pbh.new_failure_details(ex)
213
- actions = [pbh.new_complete_orchestration_action(-1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details)]
214
- res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions, completionToken=completionToken)
583
+ self._logger.exception(
584
+ f"An error occurred while trying to execute instance '{req.instanceId}': {ex}"
585
+ )
586
+ failure_details = ph.new_failure_details(ex)
587
+ actions = [
588
+ ph.new_complete_orchestration_action(
589
+ -1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details
590
+ )
591
+ ]
592
+ res = pb.OrchestratorResponse(
593
+ instanceId=req.instanceId,
594
+ actions=actions,
595
+ completionToken=completionToken,
596
+ )
215
597
 
216
598
  try:
217
599
  stub.CompleteOrchestratorTask(res)
218
600
  except Exception as ex:
219
- self._logger.exception(f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}")
220
-
221
- def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken):
601
+ self._logger.exception(
602
+ f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
603
+ )
604
+
605
+ def _execute_activity(
606
+ self,
607
+ req: pb.ActivityRequest,
608
+ stub: stubs.TaskHubSidecarServiceStub,
609
+ completionToken,
610
+ ):
222
611
  instance_id = req.orchestrationInstance.instanceId
223
612
  try:
224
613
  executor = _ActivityExecutor(self._registry, self._logger)
225
- result = executor.execute(instance_id, req.name, req.taskId, req.input.value)
614
+ result = executor.execute(
615
+ instance_id, req.name, req.taskId, req.input.value
616
+ )
226
617
  res = pb.ActivityResponse(
227
618
  instanceId=instance_id,
228
619
  taskId=req.taskId,
229
- result=pbh.get_string_value(result),
230
- completionToken=completionToken)
620
+ result=ph.get_string_value(result),
621
+ completionToken=completionToken,
622
+ )
231
623
  except Exception as ex:
232
624
  res = pb.ActivityResponse(
233
625
  instanceId=instance_id,
234
626
  taskId=req.taskId,
235
- failureDetails=pbh.new_failure_details(ex),
236
- completionToken=completionToken)
627
+ failureDetails=ph.new_failure_details(ex),
628
+ completionToken=completionToken,
629
+ )
237
630
 
238
631
  try:
239
632
  stub.CompleteActivityTask(res)
240
633
  except Exception as ex:
241
634
  self._logger.exception(
242
- f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}")
635
+ f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
636
+ )
243
637
 
244
638
 
245
639
  class _RuntimeOrchestrationContext(task.OrchestrationContext):
246
640
  _generator: Optional[Generator[task.Task, Any, Any]]
247
641
  _previous_task: Optional[task.Task]
248
642
 
249
- def __init__(self, instance_id: str):
643
+ def __init__(self, instance_id: str, registry: _Registry):
250
644
  self._generator = None
251
645
  self._is_replaying = True
252
646
  self._is_complete = False
@@ -256,6 +650,8 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
256
650
  self._sequence_number = 0
257
651
  self._current_utc_datetime = datetime(1000, 1, 1)
258
652
  self._instance_id = instance_id
653
+ self._registry = registry
654
+ self._version: Optional[str] = None
259
655
  self._completion_status: Optional[pb.OrchestrationStatus] = None
260
656
  self._received_events: dict[str, list[Any]] = {}
261
657
  self._pending_events: dict[str, list[task.CompletableTask]] = {}
@@ -273,7 +669,9 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
273
669
  def resume(self):
274
670
  if self._generator is None:
275
671
  # This is never expected unless maybe there's an issue with the history
276
- raise TypeError("The orchestrator generator is not initialized! Was the orchestration history corrupted?")
672
+ raise TypeError(
673
+ "The orchestrator generator is not initialized! Was the orchestration history corrupted?"
674
+ )
277
675
 
278
676
  # We can resume the generator only if the previously yielded task
279
677
  # has reached a completed state. The only time this won't be the
@@ -294,7 +692,12 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
294
692
  raise TypeError("The orchestrator generator yielded a non-Task object")
295
693
  self._previous_task = next_task
296
694
 
297
- def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False):
695
+ def set_complete(
696
+ self,
697
+ result: Any,
698
+ status: pb.OrchestrationStatus,
699
+ is_result_encoded: bool = False,
700
+ ):
298
701
  if self._is_complete:
299
702
  return
300
703
 
@@ -307,10 +710,11 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
307
710
  if result is not None:
308
711
  result_json = result if is_result_encoded else shared.to_json(result)
309
712
  action = ph.new_complete_orchestration_action(
310
- self.next_sequence_number(), status, result_json)
713
+ self.next_sequence_number(), status, result_json
714
+ )
311
715
  self._pending_actions[action.id] = action
312
716
 
313
- def set_failed(self, ex: Exception):
717
+ def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]):
314
718
  if self._is_complete:
315
719
  return
316
720
 
@@ -319,7 +723,10 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
319
723
  self._completion_status = pb.ORCHESTRATION_STATUS_FAILED
320
724
 
321
725
  action = ph.new_complete_orchestration_action(
322
- self.next_sequence_number(), pb.ORCHESTRATION_STATUS_FAILED, None, ph.new_failure_details(ex)
726
+ self.next_sequence_number(),
727
+ pb.ORCHESTRATION_STATUS_FAILED,
728
+ None,
729
+ ph.new_failure_details(ex) if isinstance(ex, Exception) else ex,
323
730
  )
324
731
  self._pending_actions[action.id] = action
325
732
 
@@ -343,14 +750,21 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
343
750
  # replayed when the new instance starts.
344
751
  for event_name, values in self._received_events.items():
345
752
  for event_value in values:
346
- encoded_value = shared.to_json(event_value) if event_value else None
347
- carryover_events.append(ph.new_event_raised_event(event_name, encoded_value))
753
+ encoded_value = (
754
+ shared.to_json(event_value) if event_value else None
755
+ )
756
+ carryover_events.append(
757
+ ph.new_event_raised_event(event_name, encoded_value)
758
+ )
348
759
  action = ph.new_complete_orchestration_action(
349
760
  self.next_sequence_number(),
350
761
  pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW,
351
- result=shared.to_json(self._new_input) if self._new_input is not None else None,
762
+ result=shared.to_json(self._new_input)
763
+ if self._new_input is not None
764
+ else None,
352
765
  failure_details=None,
353
- carryover_events=carryover_events)
766
+ carryover_events=carryover_events,
767
+ )
354
768
  return [action]
355
769
  else:
356
770
  return list(self._pending_actions.values())
@@ -364,63 +778,98 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
364
778
  return self._instance_id
365
779
 
366
780
  @property
367
- def current_utc_datetime(self) -> datetime:
368
- return self._current_utc_datetime
781
+ def version(self) -> Optional[str]:
782
+ return self._version
369
783
 
370
784
  @property
371
- def is_replaying(self) -> bool:
372
- return self._is_replaying
785
+ def current_utc_datetime(self) -> datetime:
786
+ return self._current_utc_datetime
373
787
 
374
788
  @current_utc_datetime.setter
375
789
  def current_utc_datetime(self, value: datetime):
376
790
  self._current_utc_datetime = value
377
791
 
792
+ @property
793
+ def is_replaying(self) -> bool:
794
+ return self._is_replaying
795
+
378
796
  def set_custom_status(self, custom_status: Any) -> None:
379
- self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None
797
+ self._encoded_custom_status = (
798
+ shared.to_json(custom_status) if custom_status is not None else None
799
+ )
380
800
 
381
801
  def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task:
382
802
  return self.create_timer_internal(fire_at)
383
803
 
384
- def create_timer_internal(self, fire_at: Union[datetime, timedelta],
385
- retryable_task: Optional[task.RetryableTask] = None) -> task.Task:
804
+ def create_timer_internal(
805
+ self,
806
+ fire_at: Union[datetime, timedelta],
807
+ retryable_task: Optional[task.RetryableTask] = None,
808
+ ) -> task.Task:
386
809
  id = self.next_sequence_number()
387
810
  if isinstance(fire_at, timedelta):
388
811
  fire_at = self.current_utc_datetime + fire_at
389
812
  action = ph.new_create_timer_action(id, fire_at)
390
813
  self._pending_actions[id] = action
391
814
 
392
- timer_task = task.TimerTask()
815
+ timer_task: task.TimerTask = task.TimerTask()
393
816
  if retryable_task is not None:
394
817
  timer_task.set_retryable_parent(retryable_task)
395
818
  self._pending_tasks[id] = timer_task
396
819
  return timer_task
397
820
 
398
- def call_activity(self, activity: Union[task.Activity[TInput, TOutput], str], *,
399
- input: Optional[TInput] = None,
400
- retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]:
821
+ def call_activity(
822
+ self,
823
+ activity: Union[task.Activity[TInput, TOutput], str],
824
+ *,
825
+ input: Optional[TInput] = None,
826
+ retry_policy: Optional[task.RetryPolicy] = None,
827
+ tags: Optional[dict[str, str]] = None,
828
+ ) -> task.Task[TOutput]:
401
829
  id = self.next_sequence_number()
402
830
 
403
- self.call_activity_function_helper(id, activity, input=input, retry_policy=retry_policy,
404
- is_sub_orch=False)
831
+ self.call_activity_function_helper(
832
+ id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags
833
+ )
405
834
  return self._pending_tasks.get(id, task.CompletableTask())
406
835
 
407
- def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput], *,
408
- input: Optional[TInput] = None,
409
- instance_id: Optional[str] = None,
410
- retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]:
836
+ def call_sub_orchestrator(
837
+ self,
838
+ orchestrator: task.Orchestrator[TInput, TOutput],
839
+ *,
840
+ input: Optional[TInput] = None,
841
+ instance_id: Optional[str] = None,
842
+ retry_policy: Optional[task.RetryPolicy] = None,
843
+ version: Optional[str] = None,
844
+ ) -> task.Task[TOutput]:
411
845
  id = self.next_sequence_number()
412
846
  orchestrator_name = task.get_name(orchestrator)
413
- self.call_activity_function_helper(id, orchestrator_name, input=input, retry_policy=retry_policy,
414
- is_sub_orch=True, instance_id=instance_id)
847
+ default_version = self._registry.versioning.default_version if self._registry.versioning else None
848
+ orchestrator_version = version if version else default_version
849
+ self.call_activity_function_helper(
850
+ id,
851
+ orchestrator_name,
852
+ input=input,
853
+ retry_policy=retry_policy,
854
+ is_sub_orch=True,
855
+ instance_id=instance_id,
856
+ version=orchestrator_version
857
+ )
415
858
  return self._pending_tasks.get(id, task.CompletableTask())
416
859
 
417
- def call_activity_function_helper(self, id: Optional[int],
418
- activity_function: Union[task.Activity[TInput, TOutput], str], *,
419
- input: Optional[TInput] = None,
420
- retry_policy: Optional[task.RetryPolicy] = None,
421
- is_sub_orch: bool = False,
422
- instance_id: Optional[str] = None,
423
- fn_task: Optional[task.CompletableTask[TOutput]] = None):
860
+ def call_activity_function_helper(
861
+ self,
862
+ id: Optional[int],
863
+ activity_function: Union[task.Activity[TInput, TOutput], str],
864
+ *,
865
+ input: Optional[TInput] = None,
866
+ retry_policy: Optional[task.RetryPolicy] = None,
867
+ tags: Optional[dict[str, str]] = None,
868
+ is_sub_orch: bool = False,
869
+ instance_id: Optional[str] = None,
870
+ fn_task: Optional[task.CompletableTask[TOutput]] = None,
871
+ version: Optional[str] = None,
872
+ ):
424
873
  if id is None:
425
874
  id = self.next_sequence_number()
426
875
 
@@ -431,24 +880,33 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
431
880
  # We just need to take string representation of it.
432
881
  encoded_input = str(input)
433
882
  if not is_sub_orch:
434
- name = activity_function if isinstance(activity_function, str) else task.get_name(activity_function)
435
- action = ph.new_schedule_task_action(id, name, encoded_input)
883
+ name = (
884
+ activity_function
885
+ if isinstance(activity_function, str)
886
+ else task.get_name(activity_function)
887
+ )
888
+ action = ph.new_schedule_task_action(id, name, encoded_input, tags)
436
889
  else:
437
890
  if instance_id is None:
438
891
  # Create a deteministic instance ID based on the parent instance ID
439
892
  instance_id = f"{self.instance_id}:{id:04x}"
440
893
  if not isinstance(activity_function, str):
441
894
  raise ValueError("Orchestrator function name must be a string")
442
- action = ph.new_create_sub_orchestration_action(id, activity_function, instance_id, encoded_input)
895
+ action = ph.new_create_sub_orchestration_action(
896
+ id, activity_function, instance_id, encoded_input, version
897
+ )
443
898
  self._pending_actions[id] = action
444
899
 
445
900
  if fn_task is None:
446
901
  if retry_policy is None:
447
902
  fn_task = task.CompletableTask[TOutput]()
448
903
  else:
449
- fn_task = task.RetryableTask[TOutput](retry_policy=retry_policy, action=action,
450
- start_time=self.current_utc_datetime,
451
- is_sub_orch=is_sub_orch)
904
+ fn_task = task.RetryableTask[TOutput](
905
+ retry_policy=retry_policy,
906
+ action=action,
907
+ start_time=self.current_utc_datetime,
908
+ is_sub_orch=is_sub_orch,
909
+ )
452
910
  self._pending_tasks[id] = fn_task
453
911
 
454
912
  def wait_for_external_event(self, name: str) -> task.Task:
@@ -457,7 +915,7 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
457
915
  # event with the given name so that we can resume the generator when it
458
916
  # arrives. If there are multiple events with the same name, we return
459
917
  # them in the order they were received.
460
- external_event_task = task.CompletableTask()
918
+ external_event_task: task.CompletableTask = task.CompletableTask()
461
919
  event_name = name.casefold()
462
920
  event_list = self._received_events.get(event_name, None)
463
921
  if event_list:
@@ -484,7 +942,9 @@ class ExecutionResults:
484
942
  actions: list[pb.OrchestratorAction]
485
943
  encoded_custom_status: Optional[str]
486
944
 
487
- def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]):
945
+ def __init__(
946
+ self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
947
+ ):
488
948
  self.actions = actions
489
949
  self.encoded_custom_status = encoded_custom_status
490
950
 
@@ -498,26 +958,64 @@ class _OrchestrationExecutor:
498
958
  self._is_suspended = False
499
959
  self._suspended_events: list[pb.HistoryEvent] = []
500
960
 
501
- def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent]) -> ExecutionResults:
961
+ def execute(
962
+ self,
963
+ instance_id: str,
964
+ old_events: Sequence[pb.HistoryEvent],
965
+ new_events: Sequence[pb.HistoryEvent],
966
+ ) -> ExecutionResults:
502
967
  if not new_events:
503
- raise task.OrchestrationStateError("The new history event list must have at least one event in it.")
968
+ raise task.OrchestrationStateError(
969
+ "The new history event list must have at least one event in it."
970
+ )
504
971
 
505
- ctx = _RuntimeOrchestrationContext(instance_id)
972
+ ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
973
+ version_failure = None
506
974
  try:
507
975
  # Rebuild local state by replaying old history into the orchestrator function
508
- self._logger.debug(f"{instance_id}: Rebuilding local state with {len(old_events)} history event...")
976
+ self._logger.debug(
977
+ f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
978
+ )
509
979
  ctx._is_replaying = True
510
980
  for old_event in old_events:
511
981
  self.process_event(ctx, old_event)
512
982
 
983
+ # Process versioning if applicable
984
+ execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
985
+ # We only check versioning if there are executionStarted events - otherwise, on the first replay when
986
+ # ctx.version will be Null, we may invalidate orchestrations early depending on the versioning strategy.
987
+ if self._registry.versioning and len(execution_started_events) > 0:
988
+ version_failure = self.evaluate_orchestration_versioning(
989
+ self._registry.versioning,
990
+ ctx.version
991
+ )
992
+ if version_failure:
993
+ self._logger.warning(
994
+ f"Orchestration version did not meet worker versioning requirements. "
995
+ f"Error action = '{self._registry.versioning.failure_strategy}'. "
996
+ f"Version error = '{version_failure}'"
997
+ )
998
+ raise pe.VersionFailureException
999
+
513
1000
  # Get new actions by executing newly received events into the orchestrator function
514
1001
  if self._logger.level <= logging.DEBUG:
515
1002
  summary = _get_new_event_summary(new_events)
516
- self._logger.debug(f"{instance_id}: Processing {len(new_events)} new event(s): {summary}")
1003
+ self._logger.debug(
1004
+ f"{instance_id}: Processing {len(new_events)} new event(s): {summary}"
1005
+ )
517
1006
  ctx._is_replaying = False
518
1007
  for new_event in new_events:
519
1008
  self.process_event(ctx, new_event)
520
1009
 
1010
+ except pe.VersionFailureException as ex:
1011
+ if self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
1012
+ if version_failure:
1013
+ ctx.set_failed(version_failure)
1014
+ else:
1015
+ ctx.set_failed(ex)
1016
+ elif self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
1017
+ raise pe.AbandonOrchestrationError
1018
+
521
1019
  except Exception as ex:
522
1020
  # Unhandled exceptions fail the orchestration
523
1021
  ctx.set_failed(ex)
@@ -525,17 +1023,31 @@ class _OrchestrationExecutor:
525
1023
  if not ctx._is_complete:
526
1024
  task_count = len(ctx._pending_tasks)
527
1025
  event_count = len(ctx._pending_events)
528
- self._logger.info(f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding.")
529
- elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
530
- completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status)
531
- self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}")
1026
+ self._logger.info(
1027
+ f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding."
1028
+ )
1029
+ elif (
1030
+ ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
1031
+ ):
1032
+ completion_status_str = ph.get_orchestration_status_str(
1033
+ ctx._completion_status
1034
+ )
1035
+ self._logger.info(
1036
+ f"{instance_id}: Orchestration completed with status: {completion_status_str}"
1037
+ )
532
1038
 
533
1039
  actions = ctx.get_actions()
534
1040
  if self._logger.level <= logging.DEBUG:
535
- self._logger.debug(f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}")
536
- return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status)
1041
+ self._logger.debug(
1042
+ f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}"
1043
+ )
1044
+ return ExecutionResults(
1045
+ actions=actions, encoded_custom_status=ctx._encoded_custom_status
1046
+ )
537
1047
 
538
- def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
1048
+ def process_event(
1049
+ self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent
1050
+ ) -> None:
539
1051
  if self._is_suspended and _is_suspendable(event):
540
1052
  # We are suspended, so we need to buffer this event until we are resumed
541
1053
  self._suspended_events.append(event)
@@ -550,14 +1062,22 @@ class _OrchestrationExecutor:
550
1062
  fn = self._registry.get_orchestrator(event.executionStarted.name)
551
1063
  if fn is None:
552
1064
  raise OrchestratorNotRegisteredError(
553
- f"A '{event.executionStarted.name}' orchestrator was not registered.")
1065
+ f"A '{event.executionStarted.name}' orchestrator was not registered."
1066
+ )
1067
+
1068
+ if event.executionStarted.version:
1069
+ ctx._version = event.executionStarted.version.value
554
1070
 
555
1071
  # deserialize the input, if any
556
1072
  input = None
557
- if event.executionStarted.input is not None and event.executionStarted.input.value != "":
1073
+ if (
1074
+ event.executionStarted.input is not None and event.executionStarted.input.value != ""
1075
+ ):
558
1076
  input = shared.from_json(event.executionStarted.input.value)
559
1077
 
560
- result = fn(ctx, input) # this does not execute the generator, only creates it
1078
+ result = fn(
1079
+ ctx, input
1080
+ ) # this does not execute the generator, only creates it
561
1081
  if isinstance(result, GeneratorType):
562
1082
  # Start the orchestrator's generator function
563
1083
  ctx.run(result)
@@ -570,10 +1090,14 @@ class _OrchestrationExecutor:
570
1090
  timer_id = event.eventId
571
1091
  action = ctx._pending_actions.pop(timer_id, None)
572
1092
  if not action:
573
- raise _get_non_determinism_error(timer_id, task.get_name(ctx.create_timer))
1093
+ raise _get_non_determinism_error(
1094
+ timer_id, task.get_name(ctx.create_timer)
1095
+ )
574
1096
  elif not action.HasField("createTimer"):
575
1097
  expected_method_name = task.get_name(ctx.create_timer)
576
- raise _get_wrong_action_type_error(timer_id, expected_method_name, action)
1098
+ raise _get_wrong_action_type_error(
1099
+ timer_id, expected_method_name, action
1100
+ )
577
1101
  elif event.HasField("timerFired"):
578
1102
  timer_id = event.timerFired.timerId
579
1103
  timer_task = ctx._pending_tasks.pop(timer_id, None)
@@ -581,7 +1105,8 @@ class _OrchestrationExecutor:
581
1105
  # TODO: Should this be an error? When would it ever happen?
582
1106
  if not ctx._is_replaying:
583
1107
  self._logger.warning(
584
- f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}.")
1108
+ f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
1109
+ )
585
1110
  return
586
1111
  timer_task.complete(None)
587
1112
  if timer_task._retryable_parent is not None:
@@ -593,12 +1118,15 @@ class _OrchestrationExecutor:
593
1118
  else:
594
1119
  cur_task = activity_action.createSubOrchestration
595
1120
  instance_id = cur_task.instanceId
596
- ctx.call_activity_function_helper(id=activity_action.id, activity_function=cur_task.name,
597
- input=cur_task.input.value,
598
- retry_policy=timer_task._retryable_parent._retry_policy,
599
- is_sub_orch=timer_task._retryable_parent._is_sub_orch,
600
- instance_id=instance_id,
601
- fn_task=timer_task._retryable_parent)
1121
+ ctx.call_activity_function_helper(
1122
+ id=activity_action.id,
1123
+ activity_function=cur_task.name,
1124
+ input=cur_task.input.value,
1125
+ retry_policy=timer_task._retryable_parent._retry_policy,
1126
+ is_sub_orch=timer_task._retryable_parent._is_sub_orch,
1127
+ instance_id=instance_id,
1128
+ fn_task=timer_task._retryable_parent,
1129
+ )
602
1130
  else:
603
1131
  ctx.resume()
604
1132
  elif event.HasField("taskScheduled"):
@@ -608,16 +1136,21 @@ class _OrchestrationExecutor:
608
1136
  action = ctx._pending_actions.pop(task_id, None)
609
1137
  activity_task = ctx._pending_tasks.get(task_id, None)
610
1138
  if not action:
611
- raise _get_non_determinism_error(task_id, task.get_name(ctx.call_activity))
1139
+ raise _get_non_determinism_error(
1140
+ task_id, task.get_name(ctx.call_activity)
1141
+ )
612
1142
  elif not action.HasField("scheduleTask"):
613
1143
  expected_method_name = task.get_name(ctx.call_activity)
614
- raise _get_wrong_action_type_error(task_id, expected_method_name, action)
1144
+ raise _get_wrong_action_type_error(
1145
+ task_id, expected_method_name, action
1146
+ )
615
1147
  elif action.scheduleTask.name != event.taskScheduled.name:
616
1148
  raise _get_wrong_action_name_error(
617
1149
  task_id,
618
1150
  method_name=task.get_name(ctx.call_activity),
619
1151
  expected_task_name=event.taskScheduled.name,
620
- actual_task_name=action.scheduleTask.name)
1152
+ actual_task_name=action.scheduleTask.name,
1153
+ )
621
1154
  elif event.HasField("taskCompleted"):
622
1155
  # This history event contains the result of a completed activity task.
623
1156
  task_id = event.taskCompleted.taskScheduledId
@@ -626,7 +1159,8 @@ class _OrchestrationExecutor:
626
1159
  # TODO: Should this be an error? When would it ever happen?
627
1160
  if not ctx.is_replaying:
628
1161
  self._logger.warning(
629
- f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}.")
1162
+ f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}."
1163
+ )
630
1164
  return
631
1165
  result = None
632
1166
  if not ph.is_empty(event.taskCompleted.result):
@@ -640,7 +1174,8 @@ class _OrchestrationExecutor:
640
1174
  # TODO: Should this be an error? When would it ever happen?
641
1175
  if not ctx.is_replaying:
642
1176
  self._logger.warning(
643
- f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.")
1177
+ f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}."
1178
+ )
644
1179
  return
645
1180
 
646
1181
  if isinstance(activity_task, task.RetryableTask):
@@ -649,7 +1184,8 @@ class _OrchestrationExecutor:
649
1184
  if next_delay is None:
650
1185
  activity_task.fail(
651
1186
  f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
652
- event.taskFailed.failureDetails)
1187
+ event.taskFailed.failureDetails,
1188
+ )
653
1189
  ctx.resume()
654
1190
  else:
655
1191
  activity_task.increment_attempt_count()
@@ -657,7 +1193,8 @@ class _OrchestrationExecutor:
657
1193
  elif isinstance(activity_task, task.CompletableTask):
658
1194
  activity_task.fail(
659
1195
  f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
660
- event.taskFailed.failureDetails)
1196
+ event.taskFailed.failureDetails,
1197
+ )
661
1198
  ctx.resume()
662
1199
  else:
663
1200
  raise TypeError("Unexpected task type")
@@ -667,16 +1204,23 @@ class _OrchestrationExecutor:
667
1204
  task_id = event.eventId
668
1205
  action = ctx._pending_actions.pop(task_id, None)
669
1206
  if not action:
670
- raise _get_non_determinism_error(task_id, task.get_name(ctx.call_sub_orchestrator))
1207
+ raise _get_non_determinism_error(
1208
+ task_id, task.get_name(ctx.call_sub_orchestrator)
1209
+ )
671
1210
  elif not action.HasField("createSubOrchestration"):
672
1211
  expected_method_name = task.get_name(ctx.call_sub_orchestrator)
673
- raise _get_wrong_action_type_error(task_id, expected_method_name, action)
674
- elif action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name:
1212
+ raise _get_wrong_action_type_error(
1213
+ task_id, expected_method_name, action
1214
+ )
1215
+ elif (
1216
+ action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name
1217
+ ):
675
1218
  raise _get_wrong_action_name_error(
676
1219
  task_id,
677
1220
  method_name=task.get_name(ctx.call_sub_orchestrator),
678
1221
  expected_task_name=event.subOrchestrationInstanceCreated.name,
679
- actual_task_name=action.createSubOrchestration.name)
1222
+ actual_task_name=action.createSubOrchestration.name,
1223
+ )
680
1224
  elif event.HasField("subOrchestrationInstanceCompleted"):
681
1225
  task_id = event.subOrchestrationInstanceCompleted.taskScheduledId
682
1226
  sub_orch_task = ctx._pending_tasks.pop(task_id, None)
@@ -684,11 +1228,14 @@ class _OrchestrationExecutor:
684
1228
  # TODO: Should this be an error? When would it ever happen?
685
1229
  if not ctx.is_replaying:
686
1230
  self._logger.warning(
687
- f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}.")
1231
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}."
1232
+ )
688
1233
  return
689
1234
  result = None
690
1235
  if not ph.is_empty(event.subOrchestrationInstanceCompleted.result):
691
- result = shared.from_json(event.subOrchestrationInstanceCompleted.result.value)
1236
+ result = shared.from_json(
1237
+ event.subOrchestrationInstanceCompleted.result.value
1238
+ )
692
1239
  sub_orch_task.complete(result)
693
1240
  ctx.resume()
694
1241
  elif event.HasField("subOrchestrationInstanceFailed"):
@@ -699,7 +1246,8 @@ class _OrchestrationExecutor:
699
1246
  # TODO: Should this be an error? When would it ever happen?
700
1247
  if not ctx.is_replaying:
701
1248
  self._logger.warning(
702
- f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}.")
1249
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}."
1250
+ )
703
1251
  return
704
1252
  if isinstance(sub_orch_task, task.RetryableTask):
705
1253
  if sub_orch_task._retry_policy is not None:
@@ -707,7 +1255,8 @@ class _OrchestrationExecutor:
707
1255
  if next_delay is None:
708
1256
  sub_orch_task.fail(
709
1257
  f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
710
- failedEvent.failureDetails)
1258
+ failedEvent.failureDetails,
1259
+ )
711
1260
  ctx.resume()
712
1261
  else:
713
1262
  sub_orch_task.increment_attempt_count()
@@ -715,7 +1264,8 @@ class _OrchestrationExecutor:
715
1264
  elif isinstance(sub_orch_task, task.CompletableTask):
716
1265
  sub_orch_task.fail(
717
1266
  f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
718
- failedEvent.failureDetails)
1267
+ failedEvent.failureDetails,
1268
+ )
719
1269
  ctx.resume()
720
1270
  else:
721
1271
  raise TypeError("Unexpected sub-orchestration task type")
@@ -744,7 +1294,9 @@ class _OrchestrationExecutor:
744
1294
  decoded_result = shared.from_json(event.eventRaised.input.value)
745
1295
  event_list.append(decoded_result)
746
1296
  if not ctx.is_replaying:
747
- self._logger.info(f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it.")
1297
+ self._logger.info(
1298
+ f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1299
+ )
748
1300
  elif event.HasField("executionSuspended"):
749
1301
  if not self._is_suspended and not ctx.is_replaying:
750
1302
  self._logger.info(f"{ctx.instance_id}: Execution suspended.")
@@ -759,27 +1311,89 @@ class _OrchestrationExecutor:
759
1311
  elif event.HasField("executionTerminated"):
760
1312
  if not ctx.is_replaying:
761
1313
  self._logger.info(f"{ctx.instance_id}: Execution terminating.")
762
- encoded_output = event.executionTerminated.input.value if not ph.is_empty(event.executionTerminated.input) else None
763
- ctx.set_complete(encoded_output, pb.ORCHESTRATION_STATUS_TERMINATED, is_result_encoded=True)
1314
+ encoded_output = (
1315
+ event.executionTerminated.input.value
1316
+ if not ph.is_empty(event.executionTerminated.input)
1317
+ else None
1318
+ )
1319
+ ctx.set_complete(
1320
+ encoded_output,
1321
+ pb.ORCHESTRATION_STATUS_TERMINATED,
1322
+ is_result_encoded=True,
1323
+ )
764
1324
  else:
765
1325
  eventType = event.WhichOneof("eventType")
766
- raise task.OrchestrationStateError(f"Don't know how to handle event of type '{eventType}'")
1326
+ raise task.OrchestrationStateError(
1327
+ f"Don't know how to handle event of type '{eventType}'"
1328
+ )
767
1329
  except StopIteration as generatorStopped:
768
1330
  # The orchestrator generator function completed
769
1331
  ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
770
1332
 
1333
+ def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
1334
+ if versioning is None:
1335
+ return None
1336
+ version_comparison = self.compare_versions(orchestration_version, versioning.version)
1337
+ if versioning.match_strategy == VersionMatchStrategy.NONE:
1338
+ return None
1339
+ elif versioning.match_strategy == VersionMatchStrategy.STRICT:
1340
+ if version_comparison != 0:
1341
+ return pb.TaskFailureDetails(
1342
+ errorType="VersionMismatch",
1343
+ errorMessage=f"The orchestration version '{orchestration_version}' does not match the worker version '{versioning.version}'.",
1344
+ isNonRetriable=True,
1345
+ )
1346
+ elif versioning.match_strategy == VersionMatchStrategy.CURRENT_OR_OLDER:
1347
+ if version_comparison > 0:
1348
+ return pb.TaskFailureDetails(
1349
+ errorType="VersionMismatch",
1350
+ errorMessage=f"The orchestration version '{orchestration_version}' is greater than the worker version '{versioning.version}'.",
1351
+ isNonRetriable=True,
1352
+ )
1353
+ else:
1354
+ # If there is a type of versioning we don't understand, it is better to treat it as a versioning failure.
1355
+ return pb.TaskFailureDetails(
1356
+ errorType="VersionMismatch",
1357
+ errorMessage=f"The version match strategy '{versioning.match_strategy}' is unknown.",
1358
+ isNonRetriable=True,
1359
+ )
1360
+
1361
+ def compare_versions(self, source_version: Optional[str], default_version: Optional[str]) -> int:
1362
+ if not source_version and not default_version:
1363
+ return 0
1364
+ if not source_version:
1365
+ return -1
1366
+ if not default_version:
1367
+ return 1
1368
+ try:
1369
+ source_version_parsed = parse(source_version)
1370
+ default_version_parsed = parse(default_version)
1371
+ return (source_version_parsed > default_version_parsed) - (source_version_parsed < default_version_parsed)
1372
+ except InvalidVersion:
1373
+ return (source_version > default_version) - (source_version < default_version)
1374
+
771
1375
 
772
1376
  class _ActivityExecutor:
773
1377
  def __init__(self, registry: _Registry, logger: logging.Logger):
774
1378
  self._registry = registry
775
1379
  self._logger = logger
776
1380
 
777
- def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: Optional[str]) -> Optional[str]:
1381
+ def execute(
1382
+ self,
1383
+ orchestration_id: str,
1384
+ name: str,
1385
+ task_id: int,
1386
+ encoded_input: Optional[str],
1387
+ ) -> Optional[str]:
778
1388
  """Executes an activity function and returns the serialized result, if any."""
779
- self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...")
1389
+ self._logger.debug(
1390
+ f"{orchestration_id}/{task_id}: Executing activity '{name}'..."
1391
+ )
780
1392
  fn = self._registry.get_activity(name)
781
1393
  if not fn:
782
- raise ActivityNotRegisteredError(f"Activity function named '{name}' was not registered!")
1394
+ raise ActivityNotRegisteredError(
1395
+ f"Activity function named '{name}' was not registered!"
1396
+ )
783
1397
 
784
1398
  activity_input = shared.from_json(encoded_input) if encoded_input else None
785
1399
  ctx = task.ActivityContext(orchestration_id, task_id)
@@ -787,49 +1401,54 @@ class _ActivityExecutor:
787
1401
  # Execute the activity function
788
1402
  activity_output = fn(ctx, activity_input)
789
1403
 
790
- encoded_output = shared.to_json(activity_output) if activity_output is not None else None
1404
+ encoded_output = (
1405
+ shared.to_json(activity_output) if activity_output is not None else None
1406
+ )
791
1407
  chars = len(encoded_output) if encoded_output else 0
792
1408
  self._logger.debug(
793
- f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output.")
1409
+ f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output."
1410
+ )
794
1411
  return encoded_output
795
1412
 
796
1413
 
797
- def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError:
1414
+ def _get_non_determinism_error(
1415
+ task_id: int, action_name: str
1416
+ ) -> task.NonDeterminismError:
798
1417
  return task.NonDeterminismError(
799
1418
  f"A previous execution called {action_name} with ID={task_id}, but the current "
800
1419
  f"execution doesn't have this action with this ID. This problem occurs when either "
801
1420
  f"the orchestration has non-deterministic logic or if the code was changed after an "
802
- f"instance of this orchestration already started running.")
1421
+ f"instance of this orchestration already started running."
1422
+ )
803
1423
 
804
1424
 
805
1425
  def _get_wrong_action_type_error(
806
- task_id: int,
807
- expected_method_name: str,
808
- action: pb.OrchestratorAction) -> task.NonDeterminismError:
1426
+ task_id: int, expected_method_name: str, action: pb.OrchestratorAction
1427
+ ) -> task.NonDeterminismError:
809
1428
  unexpected_method_name = _get_method_name_for_action(action)
810
1429
  return task.NonDeterminismError(
811
1430
  f"Failed to restore orchestration state due to a history mismatch: A previous execution called "
812
1431
  f"{expected_method_name} with ID={task_id}, but the current execution is instead trying to call "
813
1432
  f"{unexpected_method_name} as part of rebuilding it's history. This kind of mismatch can happen if an "
814
1433
  f"orchestration has non-deterministic logic or if the code was changed after an instance of this "
815
- f"orchestration already started running.")
1434
+ f"orchestration already started running."
1435
+ )
816
1436
 
817
1437
 
818
1438
  def _get_wrong_action_name_error(
819
- task_id: int,
820
- method_name: str,
821
- expected_task_name: str,
822
- actual_task_name: str) -> task.NonDeterminismError:
1439
+ task_id: int, method_name: str, expected_task_name: str, actual_task_name: str
1440
+ ) -> task.NonDeterminismError:
823
1441
  return task.NonDeterminismError(
824
1442
  f"Failed to restore orchestration state due to a history mismatch: A previous execution called "
825
1443
  f"{method_name} with name='{expected_task_name}' and sequence number {task_id}, but the current "
826
1444
  f"execution is instead trying to call {actual_task_name} as part of rebuilding it's history. "
827
1445
  f"This kind of mismatch can happen if an orchestration has non-deterministic logic or if the code "
828
- f"was changed after an instance of this orchestration already started running.")
1446
+ f"was changed after an instance of this orchestration already started running."
1447
+ )
829
1448
 
830
1449
 
831
1450
  def _get_method_name_for_action(action: pb.OrchestratorAction) -> str:
832
- action_type = action.WhichOneof('orchestratorActionType')
1451
+ action_type = action.WhichOneof("orchestratorActionType")
833
1452
  if action_type == "scheduleTask":
834
1453
  return task.get_name(task.OrchestrationContext.call_activity)
835
1454
  elif action_type == "createTimer":
@@ -851,7 +1470,7 @@ def _get_new_event_summary(new_events: Sequence[pb.HistoryEvent]) -> str:
851
1470
  else:
852
1471
  counts: dict[str, int] = {}
853
1472
  for event in new_events:
854
- event_type = event.WhichOneof('eventType')
1473
+ event_type = event.WhichOneof("eventType")
855
1474
  counts[event_type] = counts.get(event_type, 0) + 1
856
1475
  return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]"
857
1476
 
@@ -865,11 +1484,210 @@ def _get_action_summary(new_actions: Sequence[pb.OrchestratorAction]) -> str:
865
1484
  else:
866
1485
  counts: dict[str, int] = {}
867
1486
  for action in new_actions:
868
- action_type = action.WhichOneof('orchestratorActionType')
1487
+ action_type = action.WhichOneof("orchestratorActionType")
869
1488
  counts[action_type] = counts.get(action_type, 0) + 1
870
1489
  return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]"
871
1490
 
872
1491
 
873
1492
def _is_suspendable(event: pb.HistoryEvent) -> bool:
    """Returns true if the event is one that can be suspended and resumed.

    Every event qualifies except those whose ``eventType`` oneof is
    ``executionResumed`` or ``executionTerminated``.
    """
    event_type = event.WhichOneof("eventType")
    if event_type == "executionResumed" or event_type == "executionTerminated":
        return False
    return True
1498
+
1499
+
1500
class _AsyncWorkerManager:
    """Dispatch activity and orchestration work items on an asyncio event loop.

    A work item is a ``(func, args, kwargs)`` tuple. Each of the two work types
    has its own queue and its own semaphore, sized from ``concurrency_options``.
    Coroutine functions are awaited directly; any other callable is executed on
    the manager's thread pool.

    Queues are created lazily, only while an event loop is running. Work that
    is submitted when no loop is running is parked in pending lists and moved
    into the queues the next time a running loop is detected.
    """

    def __init__(self, concurrency_options: "ConcurrencyOptions"):
        """Store the concurrency limits and create the worker thread pool.

        Args:
            concurrency_options: Object providing ``maximum_thread_pool_workers``,
                ``maximum_concurrent_activity_work_items`` and
                ``maximum_concurrent_orchestration_work_items``.
        """
        self.concurrency_options = concurrency_options
        # Semaphores are created in run(), inside the loop that will use them.
        self.activity_semaphore = None
        self.orchestration_semaphore = None
        # Don't create queues here - defer until we have an event loop
        self.activity_queue: Optional[asyncio.Queue] = None
        self.orchestration_queue: Optional[asyncio.Queue] = None
        self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None
        # Store work items when no event loop is available
        self._pending_activity_work: list = []
        self._pending_orchestration_work: list = []
        self.thread_pool = self._new_thread_pool()
        # Set by shutdown(); tells the manager the pool must be rebuilt on reuse.
        self._thread_pool_closed = False
        self._shutdown = False

    def _new_thread_pool(self) -> ThreadPoolExecutor:
        """Create a thread pool sized from the concurrency options."""
        return ThreadPoolExecutor(
            max_workers=self.concurrency_options.maximum_thread_pool_workers,
            thread_name_prefix="DurableTask",
        )

    def _thread_pool_is_closed(self) -> bool:
        """Return True if the current thread pool can no longer accept work."""
        # The executor's private ``_shutdown`` flag is consulted only as a
        # fallback, to notice a pool that was closed directly instead of
        # through shutdown().
        return self._thread_pool_closed or bool(
            getattr(self.thread_pool, "_shutdown", False)
        )

    def _ensure_thread_pool(self):
        """Replace the thread pool with a fresh one if it has been shut down.

        A ThreadPoolExecutor cannot be restarted, so a manager that is reused
        after shutdown() needs a new pool before it can run synchronous work.
        """
        if self._thread_pool_is_closed():
            self.thread_pool = self._new_thread_pool()
            self._thread_pool_closed = False

    @staticmethod
    def _drain_queue(queue: Optional[asyncio.Queue]) -> list:
        """Remove and return every item currently in ``queue`` (``[]`` for None)."""
        drained: list = []
        if queue is None:
            return drained
        while True:
            try:
                drained.append(queue.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    def _ensure_queues_for_current_loop(self):
        """Ensure queues are bound to the current event loop.

        Does nothing when no loop is running. When the queues are missing or
        were created under a different loop, new queues are created and filled
        with the items of the old queues followed by any pending work.
        """
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running, can't create queues
            return

        queues_ready = (
            self.activity_queue is not None and self.orchestration_queue is not None
        )
        if self._queue_event_loop is current_loop and queues_ready:
            # Queues are already bound to the current loop and exist
            return

        # Preserve work that is still sitting in the old queues.
        carried_activity = self._drain_queue(self.activity_queue)
        carried_orchestration = self._drain_queue(self.orchestration_queue)

        # Create fresh queues for the current event loop
        self.activity_queue = asyncio.Queue()
        self.orchestration_queue = asyncio.Queue()
        self._queue_event_loop = current_loop

        # Previously queued items go first, then the parked pending work.
        for item in carried_activity + self._pending_activity_work:
            self.activity_queue.put_nowait(item)
        for item in carried_orchestration + self._pending_orchestration_work:
            self.orchestration_queue.put_nowait(item)

        self._pending_activity_work.clear()
        self._pending_orchestration_work.clear()

    async def run(self):
        """Consume both work queues until shutdown() is called and work drains."""
        # Reset shutdown state in case this manager is being reused; that
        # includes rebuilding the thread pool that shutdown() closed.
        self._shutdown = False
        self._ensure_thread_pool()

        # Ensure queues are properly bound to the current event loop
        self._ensure_queues_for_current_loop()

        # Create semaphores in the current event loop
        self.activity_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_activity_work_items
        )
        self.orchestration_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_orchestration_work_items
        )

        # Start background consumers for each work type
        if self.activity_queue is not None and self.orchestration_queue is not None:
            await asyncio.gather(
                self._consume_queue(self.activity_queue, self.activity_semaphore),
                self._consume_queue(
                    self.orchestration_queue, self.orchestration_semaphore
                ),
            )

    async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
        """Pull work items from ``queue`` and process each in its own task.

        Returns once shutdown has been requested, the queue is empty and no
        spawned task is still running.
        """
        in_flight: set[asyncio.Task] = set()

        while True:
            # Forget tasks that have finished
            in_flight = {job for job in in_flight if not job.done()}

            # Exit if shutdown is set and the queue is empty and no tasks are running
            if self._shutdown and queue.empty() and not in_flight:
                break

            try:
                # The timeout lets the loop re-check the shutdown flag regularly.
                work = await asyncio.wait_for(queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            func, args, kwargs = work
            # Process concurrently; the semaphore bounds how many run at once.
            in_flight.add(
                asyncio.create_task(
                    self._process_work_item(semaphore, queue, func, args, kwargs)
                )
            )

    async def _process_work_item(
        self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs
    ):
        """Run one work item under ``semaphore`` and mark it done on ``queue``."""
        async with semaphore:
            try:
                await self._run_func(func, *args, **kwargs)
            finally:
                queue.task_done()

    async def _run_func(self, func, *args, **kwargs):
        """Invoke ``func`` and return its result.

        Coroutine functions are awaited on the event loop. Other callables run
        on the thread pool; if the manager has been shut down and the pool is
        closed, the call is skipped and ``None`` is returned.
        """
        if inspect.iscoroutinefunction(func):
            return await func(*args, **kwargs)

        # Avoid submitting to executor after shutdown
        if self._shutdown and self._thread_pool_is_closed():
            return None

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.thread_pool, lambda: func(*args, **kwargs)
        )

    def submit_activity(self, func, *args, **kwargs):
        """Queue an activity work item, or park it if no event loop is running."""
        work_item = (func, args, kwargs)
        self._ensure_queues_for_current_loop()
        if self.activity_queue is not None:
            self.activity_queue.put_nowait(work_item)
        else:
            # No event loop running, store in pending list
            self._pending_activity_work.append(work_item)

    def submit_orchestration(self, func, *args, **kwargs):
        """Queue an orchestration work item, or park it if no event loop is running."""
        work_item = (func, args, kwargs)
        self._ensure_queues_for_current_loop()
        if self.orchestration_queue is not None:
            self.orchestration_queue.put_nowait(work_item)
        else:
            # No event loop running, store in pending list
            self._pending_orchestration_work.append(work_item)

    def shutdown(self):
        """Request the consumers to stop and wait for the thread pool to finish."""
        self._shutdown = True
        # Flag the pool as closed before the blocking wait so that work items
        # arriving meanwhile are skipped rather than submitted to a dying pool.
        self._thread_pool_closed = True
        self.thread_pool.shutdown(wait=True)

    def reset_for_new_run(self):
        """Reset the manager state for a new run.

        Clears the shutdown flag, rebuilds the thread pool if it was closed,
        and discards every queued and pending work item from earlier runs.
        """
        self._shutdown = False
        self._ensure_thread_pool()
        # Empty the existing queues in place; they are re-bound to the next
        # event loop when needed.
        self._drain_queue(self.activity_queue)
        self._drain_queue(self.orchestration_queue)
        # Clear pending work lists
        self._pending_activity_work.clear()
        self._pending_orchestration_work.clear()
1690
+
1691
+
1692
# Names exported by a star-import of this module.
__all__ = [
    "ConcurrencyOptions",
    "TaskHubGrpcWorker",
]