durabletask 0.0.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
durabletask/worker.py ADDED
@@ -0,0 +1,2327 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import asyncio
5
+ import inspect
6
+ import json
7
+ import logging
8
+ import os
9
+ import random
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from datetime import datetime, timedelta, timezone
12
+ from threading import Event, Thread
13
+ from types import GeneratorType
14
+ from enum import Enum
15
+ from typing import Any, Generator, Optional, Sequence, TypeVar, Union
16
+ import uuid
17
+ from packaging.version import InvalidVersion, parse
18
+
19
+ import grpc
20
+ from google.protobuf import empty_pb2
21
+
22
+ from durabletask.internal import helpers
23
+ from durabletask.internal.ProtoTaskHubSidecarServiceStub import ProtoTaskHubSidecarServiceStub
24
+ from durabletask.internal.entity_state_shim import StateShim
25
+ from durabletask.internal.helpers import new_timestamp
26
+ from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
27
+ from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
28
+ import durabletask.internal.helpers as ph
29
+ import durabletask.internal.exceptions as pe
30
+ import durabletask.internal.orchestrator_service_pb2 as pb
31
+ import durabletask.internal.orchestrator_service_pb2_grpc as stubs
32
+ import durabletask.internal.shared as shared
33
+ from durabletask import task
34
+ from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
35
+
36
+ TInput = TypeVar("TInput")
37
+ TOutput = TypeVar("TOutput")
38
+ DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
39
+
40
+
41
class ConcurrencyOptions:
    """Configuration options for controlling concurrency of different work item types and the thread pool size.

    This class provides fine-grained control over concurrent processing limits for
    activities, orchestrations, entities, and the thread pool size.
    """

    def __init__(
        self,
        maximum_concurrent_activity_work_items: Optional[int] = None,
        maximum_concurrent_orchestration_work_items: Optional[int] = None,
        maximum_concurrent_entity_work_items: Optional[int] = None,
        maximum_thread_pool_workers: Optional[int] = None,
    ):
        """Initialize concurrency options.

        Args:
            maximum_concurrent_activity_work_items: Maximum number of activity work items
                that can be processed concurrently. Defaults to 100 * processor_count.
            maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items
                that can be processed concurrently. Defaults to 100 * processor_count.
            maximum_concurrent_entity_work_items: Maximum number of entity work items
                that can be processed concurrently. Defaults to 100 * processor_count.
            maximum_thread_pool_workers: Maximum number of thread pool workers to use.
                Defaults to processor_count + 4.
        """
        processor_count = os.cpu_count() or 1
        default_concurrency = 100 * processor_count
        # Matches ThreadPoolExecutor's own default max_workers formula;
        # see https://docs.python.org/3/library/concurrent.futures.html
        default_max_workers = processor_count + 4

        def _or_default(value: Optional[int], default: int) -> int:
            # Compare against None (not truthiness) so an explicit 0 is honored.
            return value if value is not None else default

        self.maximum_concurrent_activity_work_items = _or_default(
            maximum_concurrent_activity_work_items, default_concurrency
        )
        self.maximum_concurrent_orchestration_work_items = _or_default(
            maximum_concurrent_orchestration_work_items, default_concurrency
        )
        self.maximum_concurrent_entity_work_items = _or_default(
            maximum_concurrent_entity_work_items, default_concurrency
        )
        self.maximum_thread_pool_workers = _or_default(
            maximum_thread_pool_workers, default_max_workers
        )
92
+
93
+
94
class VersionMatchStrategy(Enum):
    """Enumeration for version matching strategies.

    Selects how the worker compares a work item's version with its own
    configured version (the comparison logic lives elsewhere in this module).
    """

    NONE = 1  # No version matching is performed.
    STRICT = 2  # Versions must match exactly (by name; confirm against executor logic).
    CURRENT_OR_OLDER = 3  # Accept the current version or older (by name; confirm against executor logic).
100
+
101
+
102
class VersionFailureStrategy(Enum):
    """Enumeration for version failure strategies.

    Selects what the worker does when a work item's version does not satisfy
    the configured match strategy (the handling logic lives elsewhere in this module).
    """

    REJECT = 1  # Reject the work item (by name; presumably leaves it for redelivery — confirm).
    FAIL = 2  # Fail the work item (by name; presumably fails the orchestration — confirm).
107
+
108
+
109
class VersioningOptions:
    """Configuration options for orchestrator and activity versioning.

    Bundles the worker's own version, the default version for new
    sub-orchestrations, and the strategies used to match versions and to
    react when matching fails.
    """

    # Class-level defaults so the attributes are readable even on the class itself.
    version: Optional[str] = None
    default_version: Optional[str] = None
    match_strategy: Optional[VersionMatchStrategy] = None
    failure_strategy: Optional[VersionFailureStrategy] = None

    def __init__(self, version: Optional[str] = None,
                 default_version: Optional[str] = None,
                 match_strategy: Optional[VersionMatchStrategy] = None,
                 failure_strategy: Optional[VersionFailureStrategy] = None
                 ):
        """Initialize versioning options.

        Args:
            version: The version of orchestrations that the worker can work on.
            default_version: The default version used when starting new sub-orchestrations.
            match_strategy: The versioning strategy for the Durable Task worker.
            failure_strategy: The versioning failure strategy for the Durable Task worker.
        """
        self.version = version
        self.default_version = default_version
        self.match_strategy = match_strategy
        self.failure_strategy = failure_strategy
138
+
139
+
140
class _Registry:
    """In-memory registry of orchestrator, activity, and entity functions by name."""

    orchestrators: dict[str, task.Orchestrator]
    activities: dict[str, task.Activity]
    entities: dict[str, task.Entity]
    entity_instances: dict[str, DurableEntity]
    versioning: Optional[VersioningOptions] = None

    def __init__(self):
        self.orchestrators = {}
        self.activities = {}
        self.entities = {}
        self.entity_instances = {}

    @staticmethod
    def _register(store: dict, kind: str, name: str, fn) -> None:
        # Shared validation for all named registrations: non-empty, no duplicates.
        if not name:
            raise ValueError(f"A non-empty {kind} name is required.")
        if name in store:
            raise ValueError(f"A '{name}' {kind} already exists.")
        store[name] = fn

    def add_orchestrator(self, fn: task.Orchestrator) -> str:
        """Register an orchestrator under its derived name; returns that name."""
        if fn is None:
            raise ValueError("An orchestrator function argument is required.")
        name = task.get_name(fn)
        self.add_named_orchestrator(name, fn)
        return name

    def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None:
        """Register an orchestrator under an explicit name."""
        self._register(self.orchestrators, "orchestrator", name, fn)

    def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]:
        """Look up a registered orchestrator, or None if absent."""
        return self.orchestrators.get(name)

    def add_activity(self, fn: task.Activity) -> str:
        """Register an activity under its derived name; returns that name."""
        if fn is None:
            raise ValueError("An activity function argument is required.")
        name = task.get_name(fn)
        self.add_named_activity(name, fn)
        return name

    def add_named_activity(self, name: str, fn: task.Activity) -> None:
        """Register an activity under an explicit name."""
        self._register(self.activities, "activity", name, fn)

    def get_activity(self, name: str) -> Optional[task.Activity]:
        """Look up a registered activity, or None if absent."""
        return self.activities.get(name)

    def add_entity(self, fn: task.Entity) -> str:
        """Register an entity (function or DurableEntity subclass); returns its name."""
        if fn is None:
            raise ValueError("An entity function argument is required.")
        # Class-based entities register under the class name; function entities
        # use the same name derivation as orchestrators/activities.
        if isinstance(fn, type) and issubclass(fn, DurableEntity):
            name = fn.__name__
        else:
            name = task.get_name(fn)
        self.add_named_entity(name, fn)
        return name

    def add_named_entity(self, name: str, fn: task.Entity) -> None:
        """Register an entity under an explicit name."""
        self._register(self.entities, "entity", name, fn)

    def get_entity(self, name: str) -> Optional[task.Entity]:
        """Look up a registered entity, or None if absent."""
        return self.entities.get(name)
213
+
214
+
215
class OrchestratorNotRegisteredError(ValueError):
    """Raised when an orchestration work item names an orchestrator that was never registered."""
219
+
220
+
221
class ActivityNotRegisteredError(ValueError):
    """Raised when an activity work item names an activity that was never registered."""
225
+
226
+
227
class EntityNotRegisteredError(ValueError):
    """Raised when an entity work item names an entity that was never registered."""
231
+
232
+
233
+ class TaskHubGrpcWorker:
234
+ """A gRPC-based worker for processing durable task orchestrations and activities.
235
+
236
+ This worker connects to a Durable Task backend service via gRPC to receive and process
237
+ work items including orchestration functions and activity functions. It provides
238
+ concurrent execution capabilities with configurable limits and automatic retry handling.
239
+
240
+ The worker manages the complete lifecycle:
241
+ - Registers orchestrator and activity functions
242
+ - Connects to the gRPC backend service
243
+ - Receives work items and executes them concurrently
244
+ - Handles failures, retries, and state management
245
+ - Provides logging and monitoring capabilities
246
+
247
+ Args:
248
+ host_address (Optional[str], optional): The gRPC endpoint address of the backend service.
249
+ Defaults to the value from environment variables or localhost.
250
+ metadata (Optional[list[tuple[str, str]]], optional): gRPC metadata to include with
251
+ requests. Used for authentication and routing. Defaults to None.
252
+ log_handler (optional[logging.Handler]): Custom logging handler for worker logs. Defaults to None.
253
+ log_formatter (Optional[logging.Formatter], optional): Custom log formatter.
254
+ Defaults to None.
255
+ secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
256
+ Defaults to False.
257
+ interceptors (Optional[Sequence[shared.ClientInterceptor]], optional): Custom gRPC
258
+ interceptors to apply to the channel. Defaults to None.
259
+ concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for
260
+ controlling worker concurrency limits. If None, default settings are used.
261
+
262
+ Attributes:
263
+ concurrency_options (ConcurrencyOptions): The current concurrency configuration.
264
+
265
+ Example:
266
+ Basic worker setup:
267
+
268
+ >>> from durabletask.worker import TaskHubGrpcWorker, ConcurrencyOptions
269
+ >>>
270
+ >>> # Create worker with custom concurrency settings
271
+ >>> concurrency = ConcurrencyOptions(
272
+ ... maximum_concurrent_activity_work_items=50,
273
+ ... maximum_concurrent_orchestration_work_items=20
274
+ ... )
275
+ >>> worker = TaskHubGrpcWorker(
276
+ ... host_address="localhost:4001",
277
+ ... concurrency_options=concurrency
278
+ ... )
279
+ >>>
280
+ >>> # Register functions
281
+ >>> @worker.add_orchestrator
282
+ ... def my_orchestrator(context, input):
283
+ ... result = yield context.call_activity("my_activity", input="hello")
284
+ ... return result
285
+ >>>
286
+ >>> @worker.add_activity
287
+ ... def my_activity(context, input):
288
+ ... return f"Processed: {input}"
289
+ >>>
290
+ >>> # Start the worker
291
+ >>> worker.start()
292
+ >>> # ... worker runs in background thread
293
+ >>> worker.stop()
294
+
295
+ Using as context manager:
296
+
297
+ >>> with TaskHubGrpcWorker() as worker:
298
+ ... worker.add_orchestrator(my_orchestrator)
299
+ ... worker.add_activity(my_activity)
300
+ ... worker.start()
301
+ ... # Worker automatically stops when exiting context
302
+
303
+ Raises:
304
+ RuntimeError: If attempting to add orchestrators/activities while the worker is running,
305
+ or if starting a worker that is already running.
306
+ OrchestratorNotRegisteredError: If an orchestration work item references an
307
+ unregistered orchestrator function.
308
+ ActivityNotRegisteredError: If an activity work item references an unregistered
309
+ activity function.
310
+ """
311
+
312
+ _response_stream: Optional[grpc.Future] = None
313
+ _interceptors: Optional[list[shared.ClientInterceptor]] = None
314
+
315
+ def __init__(
316
+ self,
317
+ *,
318
+ host_address: Optional[str] = None,
319
+ metadata: Optional[list[tuple[str, str]]] = None,
320
+ log_handler: Optional[logging.Handler] = None,
321
+ log_formatter: Optional[logging.Formatter] = None,
322
+ secure_channel: bool = False,
323
+ interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
324
+ concurrency_options: Optional[ConcurrencyOptions] = None,
325
+ ):
326
+ self._registry = _Registry()
327
+ self._host_address = (
328
+ host_address if host_address else shared.get_default_host_address()
329
+ )
330
+ self._logger = shared.get_logger("worker", log_handler, log_formatter)
331
+ self._shutdown = Event()
332
+ self._is_running = False
333
+ self._secure_channel = secure_channel
334
+
335
+ # Use provided concurrency options or create default ones
336
+ self._concurrency_options = (
337
+ concurrency_options
338
+ if concurrency_options is not None
339
+ else ConcurrencyOptions()
340
+ )
341
+
342
+ # Determine the interceptors to use
343
+ if interceptors is not None:
344
+ self._interceptors = list(interceptors)
345
+ if metadata:
346
+ self._interceptors.append(DefaultClientInterceptorImpl(metadata))
347
+ elif metadata:
348
+ self._interceptors = [DefaultClientInterceptorImpl(metadata)]
349
+ else:
350
+ self._interceptors = None
351
+
352
+ self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
353
+
354
+ @property
355
+ def concurrency_options(self) -> ConcurrencyOptions:
356
+ """Get the current concurrency options for this worker."""
357
+ return self._concurrency_options
358
+
359
+ def __enter__(self):
360
+ return self
361
+
362
+ def __exit__(self, type, value, traceback):
363
+ self.stop()
364
+
365
+ def add_orchestrator(self, fn: task.Orchestrator) -> str:
366
+ """Registers an orchestrator function with the worker."""
367
+ if self._is_running:
368
+ raise RuntimeError(
369
+ "Orchestrators cannot be added while the worker is running."
370
+ )
371
+ return self._registry.add_orchestrator(fn)
372
+
373
+ def add_activity(self, fn: task.Activity) -> str:
374
+ """Registers an activity function with the worker."""
375
+ if self._is_running:
376
+ raise RuntimeError(
377
+ "Activities cannot be added while the worker is running."
378
+ )
379
+ return self._registry.add_activity(fn)
380
+
381
+ def add_entity(self, fn: task.Entity) -> str:
382
+ """Registers an entity function with the worker."""
383
+ if self._is_running:
384
+ raise RuntimeError(
385
+ "Entities cannot be added while the worker is running."
386
+ )
387
+ return self._registry.add_entity(fn)
388
+
389
+ def use_versioning(self, version: VersioningOptions) -> None:
390
+ """Initializes versioning options for sub-orchestrators and activities."""
391
+ if self._is_running:
392
+ raise RuntimeError("Cannot set default version while the worker is running.")
393
+ self._registry.versioning = version
394
+
395
+ def start(self):
396
+ """Starts the worker on a background thread and begins listening for work items."""
397
+ if self._is_running:
398
+ raise RuntimeError("The worker is already running.")
399
+
400
+ def run_loop():
401
+ loop = asyncio.new_event_loop()
402
+ asyncio.set_event_loop(loop)
403
+ loop.run_until_complete(self._async_run_loop())
404
+
405
+ self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
406
+ self._runLoop = Thread(target=run_loop)
407
+ self._runLoop.start()
408
+ self._is_running = True
409
+
410
    async def _async_run_loop(self):
        """Main connection loop: maintains a gRPC connection to the sidecar,
        streams work items, and dispatches them to the async worker manager.

        Runs until ``self._shutdown`` is set. Connection-level failures tear the
        channel down and reconnect with capped exponential backoff; application-level
        gRPC errors are logged and the stream is retried on the same cadence.
        """
        worker_task = asyncio.create_task(self._async_worker_manager.run())
        # Connection state management for retry fix
        current_channel = None
        current_stub = None
        current_reader_thread = None
        conn_retry_count = 0
        conn_max_retry_delay = 60  # seconds; backoff cap

        def create_fresh_connection():
            # Replace any existing channel with a new one and verify it with a
            # Hello round-trip; resets the retry counter on success, re-raises on failure.
            nonlocal current_channel, current_stub, conn_retry_count
            if current_channel:
                try:
                    current_channel.close()
                except Exception:
                    pass
                current_channel = None
                current_stub = None
            try:
                current_channel = shared.get_grpc_channel(
                    self._host_address, self._secure_channel, self._interceptors
                )
                current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
                current_stub.Hello(empty_pb2.Empty())
                conn_retry_count = 0
                self._logger.info(f"Created fresh connection to {self._host_address}")
            except Exception as e:
                self._logger.warning(f"Failed to create connection: {e}")
                current_channel = None
                current_stub = None
                raise

        def invalidate_connection():
            # Tear down the stream, reader thread, and channel (in that order) so the
            # next loop iteration creates a fresh connection.
            nonlocal current_channel, current_stub, current_reader_thread
            # Cancel the response stream first to signal the reader thread to stop
            if self._response_stream is not None:
                try:
                    self._response_stream.cancel()
                except Exception:
                    pass
                self._response_stream = None

            # Wait for the reader thread to finish
            if current_reader_thread is not None:
                try:
                    current_reader_thread.join(timeout=2)
                    if current_reader_thread.is_alive():
                        self._logger.warning("Stream reader thread did not shut down gracefully")
                except Exception:
                    pass
                current_reader_thread = None

            # Close the channel
            if current_channel:
                try:
                    current_channel.close()
                except Exception:
                    pass
            current_channel = None
            current_stub = None

        def should_invalidate_connection(rpc_error):
            # Classify gRPC status codes that indicate the channel itself is bad
            # (vs. an application-level error that can reuse the connection).
            error_code = rpc_error.code()  # type: ignore
            connection_level_errors = {
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.DEADLINE_EXCEEDED,
                grpc.StatusCode.CANCELLED,
                grpc.StatusCode.UNAUTHENTICATED,
                grpc.StatusCode.ABORTED,
            }
            return error_code in connection_level_errors

        while not self._shutdown.is_set():
            # (Re)connect if we have no usable stub, backing off exponentially
            # (capped at conn_max_retry_delay) with jitter.
            if current_stub is None:
                try:
                    create_fresh_connection()
                except Exception:
                    conn_retry_count += 1
                    delay = min(
                        conn_max_retry_delay,
                        (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1),
                    )
                    self._logger.warning(
                        f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})"
                    )
                    # wait() returns True if shutdown was signaled during the sleep.
                    if self._shutdown.wait(delay):
                        break
                    continue
            try:
                assert current_stub is not None
                stub = current_stub
                get_work_items_request = pb.GetWorkItemsRequest(
                    maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items,
                    maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
                )
                self._response_stream = stub.GetWorkItems(get_work_items_request)
                self._logger.info(
                    f"Successfully connected to {self._host_address}. Waiting for work items..."
                )

                # Use a thread to read from the blocking gRPC stream and forward to asyncio
                import queue

                work_item_queue = queue.Queue()

                def stream_reader():
                    # Runs on a daemon thread; forwards each work item (or the
                    # terminating exception) into the queue for the async side.
                    try:
                        for work_item in self._response_stream:
                            work_item_queue.put(work_item)
                    except Exception as e:
                        work_item_queue.put(e)

                import threading

                current_reader_thread = threading.Thread(target=stream_reader, daemon=True)
                current_reader_thread.start()
                loop = asyncio.get_running_loop()
                while not self._shutdown.is_set():
                    try:
                        # Block in the default executor so the event loop stays free.
                        work_item = await loop.run_in_executor(
                            None, work_item_queue.get
                        )
                        # The reader thread funnels stream errors through the queue.
                        if isinstance(work_item, Exception):
                            raise work_item
                        request_type = work_item.WhichOneof("request")
                        self._logger.debug(f'Received "{request_type}" work item')
                        if work_item.HasField("orchestratorRequest"):
                            self._async_worker_manager.submit_orchestration(
                                self._execute_orchestrator,
                                self._cancel_orchestrator,
                                work_item.orchestratorRequest,
                                stub,
                                work_item.completionToken,
                            )
                        elif work_item.HasField("activityRequest"):
                            self._async_worker_manager.submit_activity(
                                self._execute_activity,
                                self._cancel_activity,
                                work_item.activityRequest,
                                stub,
                                work_item.completionToken,
                            )
                        elif work_item.HasField("entityRequest"):
                            self._async_worker_manager.submit_entity_batch(
                                self._execute_entity_batch,
                                self._cancel_entity_batch,
                                work_item.entityRequest,
                                stub,
                                work_item.completionToken,
                            )
                        elif work_item.HasField("entityRequestV2"):
                            self._async_worker_manager.submit_entity_batch(
                                self._execute_entity_batch,
                                self._cancel_entity_batch,
                                work_item.entityRequestV2,
                                stub,
                                work_item.completionToken
                            )
                        elif work_item.HasField("healthPing"):
                            # Keep-alive only; no work to dispatch.
                            pass
                        else:
                            self._logger.warning(
                                f"Unexpected work item type: {request_type}"
                            )
                    except Exception as e:
                        # Re-raise so the outer handlers classify and recover.
                        self._logger.warning(f"Error in work item stream: {e}")
                        raise e
                current_reader_thread.join(timeout=1)
                self._logger.info("Work item stream ended normally")
            except grpc.RpcError as rpc_error:
                should_invalidate = should_invalidate_connection(rpc_error)
                if should_invalidate:
                    invalidate_connection()
                error_code = rpc_error.code()  # type: ignore
                error_details = str(rpc_error)

                if error_code == grpc.StatusCode.CANCELLED:
                    # CANCELLED is how stop() interrupts the stream: exit the loop.
                    self._logger.info(f"Disconnected from {self._host_address}")
                    break
                elif error_code == grpc.StatusCode.UNAVAILABLE:
                    # Check if this is a connection timeout scenario
                    if "Timeout occurred" in error_details or "Failed to connect to remote host" in error_details:
                        self._logger.warning(
                            f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection"
                        )
                    else:
                        self._logger.warning(
                            f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying"
                        )
                elif should_invalidate:
                    self._logger.warning(
                        f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection"
                    )
                else:
                    self._logger.warning(
                        f"Application-level gRPC error ({error_code}): {rpc_error}"
                    )
                self._shutdown.wait(1)
            except Exception as ex:
                # Unknown failure: assume the connection is suspect and rebuild it.
                invalidate_connection()
                self._logger.warning(f"Unexpected error: {ex}")
                self._shutdown.wait(1)
        invalidate_connection()
        self._logger.info("No longer listening for work items")
        self._async_worker_manager.shutdown()
        # Let in-flight work items drain before returning.
        await worker_task
616
+
617
+ def stop(self):
618
+ """Stops the worker and waits for any pending work items to complete."""
619
+ if not self._is_running:
620
+ return
621
+
622
+ self._logger.info("Stopping gRPC worker...")
623
+ self._shutdown.set()
624
+ if self._response_stream is not None:
625
+ self._response_stream.cancel()
626
+ if self._runLoop is not None:
627
+ self._runLoop.join(timeout=30)
628
+ self._async_worker_manager.shutdown()
629
+ self._logger.info("Worker shutdown completed")
630
+ self._is_running = False
631
+
632
    def _execute_orchestrator(
        self,
        req: pb.OrchestratorRequest,
        stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub],
        completionToken,
    ):
        """Replay/execute one orchestration work item and deliver the outcome.

        On success, sends the executor's actions to the sidecar; on
        AbandonOrchestrationError, abandons the work item for redelivery; on any
        other exception, completes the orchestration as FAILED.
        """
        try:
            executor = _OrchestrationExecutor(self._registry, self._logger)
            result = executor.execute(req.instanceId, req.pastEvents, req.newEvents)
            res = pb.OrchestratorResponse(
                instanceId=req.instanceId,
                actions=result.actions,
                customStatus=ph.get_string_value(result.encoded_custom_status),
                completionToken=completionToken,
            )
        except pe.AbandonOrchestrationError:
            # Deliberate abandon: hand the work item back to the sidecar and
            # send no OrchestratorResponse at all.
            self._logger.info(
                f"Abandoning orchestration. InstanceId = '{req.instanceId}'. Completion token = '{completionToken}'"
            )
            stub.AbandonTaskOrchestratorWorkItem(
                pb.AbandonOrchestrationTaskRequest(
                    completionToken=completionToken
                )
            )
            return
        except Exception as ex:
            # Unexpected executor failure: complete the orchestration as FAILED
            # so the failure is recorded rather than retried forever.
            self._logger.exception(
                f"An error occurred while trying to execute instance '{req.instanceId}': {ex}"
            )
            failure_details = ph.new_failure_details(ex)
            actions = [
                ph.new_complete_orchestration_action(
                    -1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details
                )
            ]
            res = pb.OrchestratorResponse(
                instanceId=req.instanceId,
                actions=actions,
                completionToken=completionToken,
            )

        try:
            stub.CompleteOrchestratorTask(res)
        except Exception as ex:
            # Delivery failure is logged but not re-raised; the sidecar will
            # eventually redeliver the work item.
            self._logger.exception(
                f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
            )
679
+
680
+ def _cancel_orchestrator(
681
+ self,
682
+ req: pb.OrchestratorRequest,
683
+ stub: stubs.TaskHubSidecarServiceStub,
684
+ completionToken,
685
+ ):
686
+ stub.AbandonTaskOrchestratorWorkItem(
687
+ pb.AbandonOrchestrationTaskRequest(
688
+ completionToken=completionToken
689
+ )
690
+ )
691
+ self._logger.info(f"Cancelled orchestration task for invocation ID: {req.instanceId}")
692
+
693
+ def _execute_activity(
694
+ self,
695
+ req: pb.ActivityRequest,
696
+ stub: stubs.TaskHubSidecarServiceStub,
697
+ completionToken,
698
+ ):
699
+ instance_id = req.orchestrationInstance.instanceId
700
+ try:
701
+ executor = _ActivityExecutor(self._registry, self._logger)
702
+ result = executor.execute(
703
+ instance_id, req.name, req.taskId, req.input.value
704
+ )
705
+ res = pb.ActivityResponse(
706
+ instanceId=instance_id,
707
+ taskId=req.taskId,
708
+ result=ph.get_string_value(result),
709
+ completionToken=completionToken,
710
+ )
711
+ except Exception as ex:
712
+ res = pb.ActivityResponse(
713
+ instanceId=instance_id,
714
+ taskId=req.taskId,
715
+ failureDetails=ph.new_failure_details(ex),
716
+ completionToken=completionToken,
717
+ )
718
+
719
+ try:
720
+ stub.CompleteActivityTask(res)
721
+ except Exception as ex:
722
+ self._logger.exception(
723
+ f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
724
+ )
725
+
726
+ def _cancel_activity(
727
+ self,
728
+ req: pb.ActivityRequest,
729
+ stub: stubs.TaskHubSidecarServiceStub,
730
+ completionToken,
731
+ ):
732
+ stub.AbandonTaskActivityWorkItem(
733
+ pb.AbandonActivityTaskRequest(
734
+ completionToken=completionToken
735
+ )
736
+ )
737
+ self._logger.info(f"Cancelled activity task for task ID: {req.taskId} on orchestration ID: {req.orchestrationInstance.instanceId}")
738
+
739
+ def _execute_entity_batch(
740
+ self,
741
+ req: Union[pb.EntityBatchRequest, pb.EntityRequest],
742
+ stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub],
743
+ completionToken,
744
+ ):
745
+ operation_infos = None
746
+ if isinstance(req, pb.EntityRequest):
747
+ req, operation_infos = helpers.convert_to_entity_batch_request(req)
748
+
749
+ entity_state = StateShim(shared.from_json(req.entityState.value) if req.entityState.value else None)
750
+
751
+ instance_id = req.instanceId
752
+
753
+ results: list[pb.OperationResult] = []
754
+ for operation in req.operations:
755
+ start_time = datetime.now(timezone.utc)
756
+ executor = _EntityExecutor(self._registry, self._logger)
757
+ entity_instance_id = EntityInstanceId.parse(instance_id)
758
+ if not entity_instance_id:
759
+ raise RuntimeError(f"Invalid entity instance ID '{operation.requestId}' in entity operation request.")
760
+
761
+ operation_result = None
762
+
763
+ try:
764
+ entity_result = executor.execute(
765
+ instance_id, entity_instance_id, operation.operation, entity_state, operation.input.value
766
+ )
767
+
768
+ entity_result = ph.get_string_value_or_empty(entity_result)
769
+ operation_result = pb.OperationResult(success=pb.OperationResultSuccess(
770
+ result=entity_result,
771
+ startTimeUtc=new_timestamp(start_time),
772
+ endTimeUtc=new_timestamp(datetime.now(timezone.utc))
773
+ ))
774
+ results.append(operation_result)
775
+
776
+ entity_state.commit()
777
+ except Exception as ex:
778
+ self._logger.exception(ex)
779
+ operation_result = pb.OperationResult(failure=pb.OperationResultFailure(
780
+ failureDetails=ph.new_failure_details(ex),
781
+ startTimeUtc=new_timestamp(start_time),
782
+ endTimeUtc=new_timestamp(datetime.now(timezone.utc))
783
+ ))
784
+ results.append(operation_result)
785
+
786
+ entity_state.rollback()
787
+
788
+ batch_result = pb.EntityBatchResult(
789
+ results=results,
790
+ actions=entity_state.get_operation_actions(),
791
+ entityState=helpers.get_string_value(shared.to_json(entity_state._current_state)) if entity_state._current_state else None,
792
+ failureDetails=None,
793
+ completionToken=completionToken,
794
+ operationInfos=operation_infos,
795
+ )
796
+
797
+ try:
798
+ stub.CompleteEntityTask(batch_result)
799
+ except Exception as ex:
800
+ self._logger.exception(
801
+ f"Failed to deliver entity response for orchestration ID '{instance_id}' to sidecar: {ex}"
802
+ )
803
+
804
+ # TODO: Reset context
805
+
806
+ return batch_result
807
+
808
+ def _cancel_entity_batch(
809
+ self,
810
+ req: Union[pb.EntityBatchRequest, pb.EntityRequest],
811
+ stub: stubs.TaskHubSidecarServiceStub,
812
+ completionToken,
813
+ ):
814
+ stub.AbandonTaskEntityWorkItem(
815
+ pb.AbandonEntityTaskRequest(
816
+ completionToken=completionToken
817
+ )
818
+ )
819
+ self._logger.info(f"Cancelled entity batch task for instance ID: {req.instanceId}")
820
+
821
+
822
+ class _RuntimeOrchestrationContext(task.OrchestrationContext):
823
+ _generator: Optional[Generator[task.Task, Any, Any]]
824
+ _previous_task: Optional[task.Task]
825
+
826
    def __init__(self, instance_id: str, registry: _Registry):
        """Initialize the runtime state for one orchestration instance's replay/execution."""
        # The orchestrator generator; set later by run().
        self._generator = None
        self._is_replaying = True
        self._is_complete = False
        self._result = None
        # Actions to return to the sidecar, and tasks awaiting completion, keyed by sequence number.
        self._pending_actions: dict[int, pb.OrchestratorAction] = {}
        self._pending_tasks: dict[int, task.CompletableTask] = {}
        # Maps entity ID to task ID
        self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int, Optional[str]]] = {}
        # Maps criticalSectionId to task ID
        self._entity_lock_id_map: dict[str, int] = {}
        # Monotonic counters used to assign deterministic IDs during replay.
        self._sequence_number = 0
        self._new_uuid_counter = 0
        # Placeholder far-past datetime; replaced by timestamps from the history.
        self._current_utc_datetime = datetime(1000, 1, 1)
        self._instance_id = instance_id
        self._registry = registry
        self._entity_context = OrchestrationEntityContext(instance_id)
        self._version: Optional[str] = None
        self._completion_status: Optional[pb.OrchestrationStatus] = None
        # External events already received vs. tasks waiting on events not yet received.
        self._received_events: dict[str, list[Any]] = {}
        self._pending_events: dict[str, list[task.CompletableTask]] = {}
        # Input/state carried across a continue-as-new transition.
        self._new_input: Optional[Any] = None
        self._save_events = False
        self._encoded_custom_status: Optional[str] = None
850
+
851
+ def run(self, generator: Generator[task.Task, Any, Any]):
852
+ self._generator = generator
853
+ # TODO: Do something with this task
854
+ task = next(generator) # this starts the generator
855
+ # TODO: Check if the task is null?
856
+ self._previous_task = task
857
+
858
    def resume(self):
        """Advance the orchestrator generator while its most recent task is complete.

        Failed tasks are thrown into the generator as exceptions; successful ones
        are sent in as results. Stops when the generator yields a task that is not
        yet complete (e.g. a WhenAll with outstanding children), or propagates
        StopIteration when the generator returns.

        Raises:
            TypeError: If the generator was never initialized, or if it yields
                a non-Task object.
        """
        if self._generator is None:
            # This is never expected unless maybe there's an issue with the history
            raise TypeError(
                "The orchestrator generator is not initialized! Was the orchestration history corrupted?"
            )

        # We can resume the generator only if the previously yielded task
        # has reached a completed state. The only time this won't be the
        # case is if the user yielded on a WhenAll task and there are still
        # outstanding child tasks that need to be completed.
        while self._previous_task is not None and self._previous_task.is_complete:
            next_task = None
            if self._previous_task.is_failed:
                # Raise the failure as an exception to the generator.
                # The orchestrator can then either handle the exception or allow it to fail the orchestration.
                next_task = self._generator.throw(self._previous_task.get_exception())
            else:
                # Resume the generator with the previous result.
                # This will either return a Task or raise StopIteration if it's done.
                next_task = self._generator.send(self._previous_task.get_result())

            if not isinstance(next_task, task.Task):
                raise TypeError("The orchestrator generator yielded a non-Task object")
            self._previous_task = next_task
883
+
884
+ def set_complete(
885
+ self,
886
+ result: Any,
887
+ status: pb.OrchestrationStatus,
888
+ is_result_encoded: bool = False,
889
+ ):
890
+ if self._is_complete:
891
+ return
892
+
893
+ # If the user code returned without yielding the entity unlock, do that now
894
+ if self._entity_context.is_inside_critical_section:
895
+ self._exit_critical_section()
896
+
897
+ self._is_complete = True
898
+ self._completion_status = status
899
+ # This is probably a bug - an orchestrator may complete with some actions remaining that the user still
900
+ # wants to execute - for example, signaling an entity. So we shouldn't clear the pending actions here.
901
+ # self._pending_actions.clear() # Cancel any pending actions
902
+
903
+ self._result = result
904
+ result_json: Optional[str] = None
905
+ if result is not None:
906
+ result_json = result if is_result_encoded else shared.to_json(result)
907
+ action = ph.new_complete_orchestration_action(
908
+ self.next_sequence_number(), status, result_json
909
+ )
910
+ self._pending_actions[action.id] = action
911
+
912
+ def set_failed(self, ex: Union[Exception, pb.TaskFailureDetails]):
913
+ if self._is_complete:
914
+ return
915
+
916
+ # If the user code crashed inside a critical section, or did not exit it, do that now
917
+ if self._entity_context.is_inside_critical_section:
918
+ self._exit_critical_section()
919
+
920
+ self._is_complete = True
921
+ # We also cannot cancel the pending actions in the failure case - if the user code had released an entity
922
+ # lock, we *must* send that action to the sidecar.
923
+ # self._pending_actions.clear() # Cancel any pending actions
924
+ self._completion_status = pb.ORCHESTRATION_STATUS_FAILED
925
+
926
+ action = ph.new_complete_orchestration_action(
927
+ self.next_sequence_number(),
928
+ pb.ORCHESTRATION_STATUS_FAILED,
929
+ None,
930
+ ph.new_failure_details(ex) if isinstance(ex, Exception) else ex,
931
+ )
932
+ self._pending_actions[action.id] = action
933
+
934
+ def set_continued_as_new(self, new_input: Any, save_events: bool):
935
+ if self._is_complete:
936
+ return
937
+
938
+ # If the user code called continue_as_new while holding an entity lock, unlock it now
939
+ if self._entity_context.is_inside_critical_section:
940
+ self._exit_critical_section()
941
+
942
+ self._is_complete = True
943
+ # We also cannot cancel the pending actions in the continue as new case - if the user code had released an
944
+ # entity lock, we *must* send that action to the sidecar.
945
+ # self._pending_actions.clear() # Cancel any pending actions
946
+ self._completion_status = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
947
+ self._new_input = new_input
948
+ self._save_events = save_events
949
+
950
+ def get_actions(self) -> list[pb.OrchestratorAction]:
951
+ current_actions = list(self._pending_actions.values())
952
+ if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
953
+ # When continuing-as-new, we only return a single completion action.
954
+ carryover_events: Optional[list[pb.HistoryEvent]] = None
955
+ if self._save_events:
956
+ carryover_events = []
957
+ # We need to save the current set of pending events so that they can be
958
+ # replayed when the new instance starts.
959
+ for event_name, values in self._received_events.items():
960
+ for event_value in values:
961
+ encoded_value = (
962
+ shared.to_json(event_value) if event_value else None
963
+ )
964
+ carryover_events.append(
965
+ ph.new_event_raised_event(event_name, encoded_value)
966
+ )
967
+ action = ph.new_complete_orchestration_action(
968
+ self.next_sequence_number(),
969
+ pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW,
970
+ result=shared.to_json(self._new_input)
971
+ if self._new_input is not None
972
+ else None,
973
+ failure_details=None,
974
+ carryover_events=carryover_events,
975
+ )
976
+ # We must return the existing tasks as well, to capture entity unlocks
977
+ current_actions.append(action)
978
+ return current_actions
979
+
980
+ def next_sequence_number(self) -> int:
981
+ self._sequence_number += 1
982
+ return self._sequence_number
983
+
984
    @property
    def instance_id(self) -> str:
        """The ID of the orchestration instance this context belongs to."""
        return self._instance_id
987
+
988
    @property
    def version(self) -> Optional[str]:
        """The orchestration's version from the executionStarted event, if any."""
        return self._version
991
+
992
    @property
    def current_utc_datetime(self) -> datetime:
        """The deterministic, replay-safe current UTC time for this instance."""
        return self._current_utc_datetime
995
+
996
    @current_utc_datetime.setter
    def current_utc_datetime(self, value: datetime):
        # Updated from orchestratorStarted history events during replay.
        self._current_utc_datetime = value
999
+
1000
    @property
    def is_replaying(self) -> bool:
        """True while old history events are being replayed (not live events)."""
        return self._is_replaying
1003
+
1004
+ def set_custom_status(self, custom_status: Any) -> None:
1005
+ self._encoded_custom_status = (
1006
+ shared.to_json(custom_status) if custom_status is not None else None
1007
+ )
1008
+
1009
    def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task:
        """Create a durable timer that fires at (or after) the given time.

        Args:
            fire_at: An absolute datetime, or a timedelta relative to the
                orchestration's deterministic ``current_utc_datetime``.

        Returns:
            A task that completes when the timer fires.
        """
        return self.create_timer_internal(fire_at)
1011
+
1012
+ def create_timer_internal(
1013
+ self,
1014
+ fire_at: Union[datetime, timedelta],
1015
+ retryable_task: Optional[task.RetryableTask] = None,
1016
+ ) -> task.Task:
1017
+ id = self.next_sequence_number()
1018
+ if isinstance(fire_at, timedelta):
1019
+ fire_at = self.current_utc_datetime + fire_at
1020
+ action = ph.new_create_timer_action(id, fire_at)
1021
+ self._pending_actions[id] = action
1022
+
1023
+ timer_task: task.TimerTask = task.TimerTask()
1024
+ if retryable_task is not None:
1025
+ timer_task.set_retryable_parent(retryable_task)
1026
+ self._pending_tasks[id] = timer_task
1027
+ return timer_task
1028
+
1029
+ def call_activity(
1030
+ self,
1031
+ activity: Union[task.Activity[TInput, TOutput], str],
1032
+ *,
1033
+ input: Optional[TInput] = None,
1034
+ retry_policy: Optional[task.RetryPolicy] = None,
1035
+ tags: Optional[dict[str, str]] = None,
1036
+ ) -> task.Task[TOutput]:
1037
+ id = self.next_sequence_number()
1038
+
1039
+ self.call_activity_function_helper(
1040
+ id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, tags=tags
1041
+ )
1042
+ return self._pending_tasks.get(id, task.CompletableTask())
1043
+
1044
+ def call_entity(
1045
+ self,
1046
+ entity: EntityInstanceId,
1047
+ operation: str,
1048
+ input: Optional[TInput] = None,
1049
+ ) -> task.Task:
1050
+ id = self.next_sequence_number()
1051
+
1052
+ self.call_entity_function_helper(
1053
+ id, entity, operation, input=input
1054
+ )
1055
+
1056
+ return self._pending_tasks.get(id, task.CompletableTask())
1057
+
1058
+ def signal_entity(
1059
+ self,
1060
+ entity_id: EntityInstanceId,
1061
+ operation_name: str,
1062
+ input: Optional[TInput] = None
1063
+ ) -> None:
1064
+ id = self.next_sequence_number()
1065
+
1066
+ self.signal_entity_function_helper(
1067
+ id, entity_id, operation_name, input
1068
+ )
1069
+
1070
+ def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLock]:
1071
+ id = self.next_sequence_number()
1072
+
1073
+ self.lock_entities_function_helper(
1074
+ id, entities
1075
+ )
1076
+ return self._pending_tasks.get(id, task.CompletableTask())
1077
+
1078
+ def call_sub_orchestrator(
1079
+ self,
1080
+ orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
1081
+ *,
1082
+ input: Optional[TInput] = None,
1083
+ instance_id: Optional[str] = None,
1084
+ retry_policy: Optional[task.RetryPolicy] = None,
1085
+ version: Optional[str] = None,
1086
+ ) -> task.Task[TOutput]:
1087
+ id = self.next_sequence_number()
1088
+ if isinstance(orchestrator, str):
1089
+ orchestrator_name = orchestrator
1090
+ else:
1091
+ orchestrator_name = task.get_name(orchestrator)
1092
+ default_version = self._registry.versioning.default_version if self._registry.versioning else None
1093
+ orchestrator_version = version if version else default_version
1094
+ self.call_activity_function_helper(
1095
+ id,
1096
+ orchestrator_name,
1097
+ input=input,
1098
+ retry_policy=retry_policy,
1099
+ is_sub_orch=True,
1100
+ instance_id=instance_id,
1101
+ version=orchestrator_version
1102
+ )
1103
+ return self._pending_tasks.get(id, task.CompletableTask())
1104
+
1105
    def call_activity_function_helper(
        self,
        id: Optional[int],
        activity_function: Union[task.Activity[TInput, TOutput], str],
        *,
        input: Optional[TInput] = None,
        retry_policy: Optional[task.RetryPolicy] = None,
        tags: Optional[dict[str, str]] = None,
        is_sub_orch: bool = False,
        instance_id: Optional[str] = None,
        fn_task: Optional[task.CompletableTask[TOutput]] = None,
        version: Optional[str] = None,
    ):
        """Record the pending action and task for an activity or sub-orchestration call.

        Used both for first-time scheduling (fn_task is None) and for
        re-scheduling a retry (fn_task is the existing RetryableTask, in which
        case ``input`` is already encoded).

        Args:
            id: Sequence number for the action; allocated here when None.
            activity_function: Activity callable or name; for sub-orchestrations
                this must already be a string name.
            input: The call input; JSON-encoded unless fn_task is provided.
            retry_policy: When set (and fn_task is None), a RetryableTask is
                created instead of a plain CompletableTask.
            tags: Optional tags for activity tasks (ignored for sub-orchestrations).
            is_sub_orch: True to schedule a sub-orchestration instead of an activity.
            instance_id: Child instance ID for sub-orchestrations; derived
                deterministically from the parent when None.
            fn_task: Existing task to reuse on retry re-scheduling.
            version: Optional sub-orchestration version.

        Raises:
            ValueError: If a sub-orchestration is requested with a non-string name.
        """
        if id is None:
            id = self.next_sequence_number()

        if fn_task is None:
            encoded_input = shared.to_json(input) if input is not None else None
        else:
            # Here, we don't need to convert the input to JSON because it is already converted.
            # We just need to take string representation of it.
            encoded_input = str(input)
        if not is_sub_orch:
            name = (
                activity_function
                if isinstance(activity_function, str)
                else task.get_name(activity_function)
            )
            action = ph.new_schedule_task_action(id, name, encoded_input, tags)
        else:
            if instance_id is None:
                # Create a deterministic instance ID based on the parent instance ID
                instance_id = f"{self.instance_id}:{id:04x}"
            if not isinstance(activity_function, str):
                raise ValueError("Orchestrator function name must be a string")
            action = ph.new_create_sub_orchestration_action(
                id, activity_function, instance_id, encoded_input, version
            )
        self._pending_actions[id] = action

        if fn_task is None:
            if retry_policy is None:
                fn_task = task.CompletableTask[TOutput]()
            else:
                fn_task = task.RetryableTask[TOutput](
                    retry_policy=retry_policy,
                    action=action,
                    start_time=self.current_utc_datetime,
                    is_sub_orch=is_sub_orch,
                )
        self._pending_tasks[id] = fn_task
1156
+
1157
+ def call_entity_function_helper(
1158
+ self,
1159
+ id: Optional[int],
1160
+ entity_id: EntityInstanceId,
1161
+ operation: str,
1162
+ *,
1163
+ input: Optional[TInput] = None,
1164
+ ):
1165
+ if id is None:
1166
+ id = self.next_sequence_number()
1167
+
1168
+ transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, False)
1169
+ if not transition_valid:
1170
+ raise RuntimeError(error_message)
1171
+
1172
+ encoded_input = shared.to_json(input) if input is not None else None
1173
+ action = ph.new_call_entity_action(id,
1174
+ self.instance_id,
1175
+ entity_id,
1176
+ operation,
1177
+ encoded_input,
1178
+ self.new_uuid())
1179
+ self._pending_actions[id] = action
1180
+
1181
+ fn_task = task.CompletableTask()
1182
+ self._pending_tasks[id] = fn_task
1183
+
1184
+ def signal_entity_function_helper(
1185
+ self,
1186
+ id: Optional[int],
1187
+ entity_id: EntityInstanceId,
1188
+ operation: str,
1189
+ input: Optional[TInput]
1190
+ ) -> None:
1191
+ if id is None:
1192
+ id = self.next_sequence_number()
1193
+
1194
+ transition_valid, error_message = self._entity_context.validate_operation_transition(entity_id, True)
1195
+
1196
+ if not transition_valid:
1197
+ raise RuntimeError(error_message)
1198
+
1199
+ encoded_input = shared.to_json(input) if input is not None else None
1200
+
1201
+ action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input, self.new_uuid())
1202
+ self._pending_actions[id] = action
1203
+
1204
+ def lock_entities_function_helper(self, id: int, entities: list[EntityInstanceId]) -> None:
1205
+ if id is None:
1206
+ id = self.next_sequence_number()
1207
+
1208
+ transition_valid, error_message = self._entity_context.validate_acquire_transition()
1209
+ if not transition_valid:
1210
+ raise RuntimeError(error_message)
1211
+
1212
+ critical_section_id = self.new_uuid()
1213
+
1214
+ request, target = self._entity_context.emit_acquire_message(critical_section_id, entities)
1215
+
1216
+ if not request or not target:
1217
+ raise RuntimeError("Failed to create entity lock request.")
1218
+
1219
+ action = ph.new_lock_entities_action(id, request)
1220
+ self._pending_actions[id] = action
1221
+
1222
+ fn_task = task.CompletableTask[EntityLock]()
1223
+ self._pending_tasks[id] = fn_task
1224
+
1225
+ def _exit_critical_section(self) -> None:
1226
+ if not self._entity_context.is_inside_critical_section:
1227
+ # Possible if the user calls continue_as_new inside the lock - in the success case, we will call
1228
+ # _exit_critical_section both from the EntityLock and the continue_as_new logic. We must keep both calls in
1229
+ # case the user code crashes after calling continue_as_new but before the EntityLock object is exited.
1230
+ return
1231
+ for entity_unlock_message in self._entity_context.emit_lock_release_messages():
1232
+ task_id = self.next_sequence_number()
1233
+ action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
1234
+ self._pending_actions[task_id] = action
1235
+
1236
+ def wait_for_external_event(self, name: str) -> task.Task:
1237
+ # Check to see if this event has already been received, in which case we
1238
+ # can return it immediately. Otherwise, record out intent to receive an
1239
+ # event with the given name so that we can resume the generator when it
1240
+ # arrives. If there are multiple events with the same name, we return
1241
+ # them in the order they were received.
1242
+ external_event_task: task.CompletableTask = task.CompletableTask()
1243
+ event_name = name.casefold()
1244
+ event_list = self._received_events.get(event_name, None)
1245
+ if event_list:
1246
+ event_data = event_list.pop(0)
1247
+ if not event_list:
1248
+ del self._received_events[event_name]
1249
+ external_event_task.complete(event_data)
1250
+ else:
1251
+ task_list = self._pending_events.get(event_name, None)
1252
+ if not task_list:
1253
+ task_list = []
1254
+ self._pending_events[event_name] = task_list
1255
+ task_list.append(external_event_task)
1256
+ return external_event_task
1257
+
1258
+ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
1259
+ if self._is_complete:
1260
+ return
1261
+
1262
+ self.set_continued_as_new(new_input, save_events)
1263
+
1264
+ def new_uuid(self) -> str:
1265
+ URL_NAMESPACE: str = "9e952958-5e33-4daf-827f-2fa12937b875"
1266
+
1267
+ uuid_name_value = \
1268
+ f"{self._instance_id}" \
1269
+ f"_{self.current_utc_datetime.strftime(DATETIME_STRING_FORMAT)}" \
1270
+ f"_{self._new_uuid_counter}"
1271
+ self._new_uuid_counter += 1
1272
+ namespace_uuid = uuid.uuid5(uuid.NAMESPACE_OID, URL_NAMESPACE)
1273
+ return str(uuid.uuid5(namespace_uuid, uuid_name_value))
1274
+
1275
+
1276
class ExecutionResults:
    """Bundle of orchestrator actions produced by one execution episode,
    together with the (already JSON-encoded) custom status, if any."""

    actions: list[pb.OrchestratorAction]
    encoded_custom_status: Optional[str]

    def __init__(
        self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
    ):
        # Store both fields verbatim; no copying or validation is performed.
        self.encoded_custom_status = encoded_custom_status
        self.actions = actions
1285
+
1286
+
1287
+ class _OrchestrationExecutor:
1288
    # The orchestrator callable currently being executed, if any.
    _generator: Optional[task.Orchestrator] = None
1289
+
1290
    def __init__(self, registry: _Registry, logger: logging.Logger):
        """Create an executor that replays orchestrations from the given registry.

        Args:
            registry: Registry used to resolve orchestrator functions by name.
            logger: Logger for replay/progress diagnostics.
        """
        self._registry = registry
        self._logger = logger
        # Set when an executionSuspended event is processed; while True,
        # suspendable events are buffered instead of being applied.
        self._is_suspended = False
        # Events buffered while suspended; replayed on executionResumed.
        self._suspended_events: list[pb.HistoryEvent] = []
1295
+
1296
    def execute(
        self,
        instance_id: str,
        old_events: Sequence[pb.HistoryEvent],
        new_events: Sequence[pb.HistoryEvent],
    ) -> ExecutionResults:
        """Replay history and process new events for one orchestration instance.

        Old events are replayed first to rebuild local state, then new events
        are applied to produce fresh actions.

        Args:
            instance_id: The orchestration instance to execute.
            old_events: Previously processed history events (replayed).
            new_events: Newly received events (must be non-empty).

        Returns:
            ExecutionResults with the pending actions and encoded custom status.

        Raises:
            task.OrchestrationStateError: If ``new_events`` is empty.
            pe.AbandonOrchestrationError: If a version check fails and the
                configured failure strategy is REJECT.
        """
        orchestration_name = "<unknown>"
        orchestration_started_events = [e for e in old_events if e.HasField("executionStarted")]
        if len(orchestration_started_events) >= 1:
            orchestration_name = orchestration_started_events[0].executionStarted.name

        self._logger.debug(
            f"{instance_id}: Beginning replay for orchestrator {orchestration_name}..."
        )

        if not new_events:
            raise task.OrchestrationStateError(
                "The new history event list must have at least one event in it."
            )

        ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
        try:
            # Rebuild local state by replaying old history into the orchestrator function
            self._logger.debug(
                f"{instance_id}: Rebuilding local state with {len(old_events)} history event..."
            )
            ctx._is_replaying = True
            for old_event in old_events:
                self.process_event(ctx, old_event)

            # Get new actions by executing newly received events into the orchestrator function
            if self._logger.level <= logging.DEBUG:
                summary = _get_new_event_summary(new_events)
                self._logger.debug(
                    f"{instance_id}: Processing {len(new_events)} new event(s): {summary}"
                )
            ctx._is_replaying = False
            for new_event in new_events:
                self.process_event(ctx, new_event)

        except pe.VersionFailureException as ex:
            # Version mismatches either fail the orchestration (FAIL) or hand
            # the work item back to the sidecar untouched (REJECT).
            if self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.FAIL:
                if ex.error_details:
                    ctx.set_failed(ex.error_details)
                else:
                    ctx.set_failed(ex)
            elif self._registry.versioning and self._registry.versioning.failure_strategy == VersionFailureStrategy.REJECT:
                raise pe.AbandonOrchestrationError

        except Exception as ex:
            # Unhandled exceptions fail the orchestration
            self._logger.debug(f"{instance_id}: Orchestration {orchestration_name} failed")
            ctx.set_failed(ex)

        if not ctx._is_complete:
            task_count = len(ctx._pending_tasks)
            event_count = len(ctx._pending_events)
            self._logger.info(
                f"{instance_id}: Orchestrator {orchestration_name} yielded with {task_count} task(s) "
                f"and {event_count} event(s) outstanding."
            )
        elif (
            ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
        ):
            completion_status_str = ph.get_orchestration_status_str(
                ctx._completion_status
            )
            self._logger.info(
                f"{instance_id}: Orchestration {orchestration_name} completed with status: {completion_status_str}"
            )

        actions = ctx.get_actions()
        if self._logger.level <= logging.DEBUG:
            self._logger.debug(
                f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}"
            )
        return ExecutionResults(
            actions=actions, encoded_custom_status=ctx._encoded_custom_status
        )
1375
+
1376
+ def process_event(
1377
+ self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent
1378
+ ) -> None:
1379
+ if self._is_suspended and _is_suspendable(event):
1380
+ # We are suspended, so we need to buffer this event until we are resumed
1381
+ self._suspended_events.append(event)
1382
+ return
1383
+
1384
+ # CONSIDER: change to a switch statement with event.WhichOneof("eventType")
1385
+ try:
1386
+ if event.HasField("orchestratorStarted"):
1387
+ ctx.current_utc_datetime = event.timestamp.ToDatetime()
1388
+ elif event.HasField("executionStarted"):
1389
+ # TODO: Check if we already started the orchestration
1390
+ fn = self._registry.get_orchestrator(event.executionStarted.name)
1391
+ if fn is None:
1392
+ raise OrchestratorNotRegisteredError(
1393
+ f"A '{event.executionStarted.name}' orchestrator was not registered."
1394
+ )
1395
+
1396
+ if event.executionStarted.version:
1397
+ ctx._version = event.executionStarted.version.value
1398
+
1399
+ if self._registry.versioning:
1400
+ version_failure = self.evaluate_orchestration_versioning(
1401
+ self._registry.versioning,
1402
+ ctx.version
1403
+ )
1404
+ if version_failure:
1405
+ self._logger.warning(
1406
+ f"Orchestration version did not meet worker versioning requirements. "
1407
+ f"Error action = '{self._registry.versioning.failure_strategy}'. "
1408
+ f"Version error = '{version_failure}'"
1409
+ )
1410
+ raise pe.VersionFailureException(version_failure)
1411
+
1412
+ # deserialize the input, if any
1413
+ input = None
1414
+ if (
1415
+ event.executionStarted.input is not None and event.executionStarted.input.value != ""
1416
+ ):
1417
+ input = shared.from_json(event.executionStarted.input.value)
1418
+
1419
+ result = fn(
1420
+ ctx, input
1421
+ ) # this does not execute the generator, only creates it
1422
+ if isinstance(result, GeneratorType):
1423
+ # Start the orchestrator's generator function
1424
+ ctx.run(result)
1425
+ else:
1426
+ # This is an orchestrator that doesn't schedule any tasks
1427
+ ctx.set_complete(result, pb.ORCHESTRATION_STATUS_COMPLETED)
1428
+ elif event.HasField("timerCreated"):
1429
+ # This history event confirms that the timer was successfully scheduled.
1430
+ # Remove the timerCreated event from the pending action list so we don't schedule it again.
1431
+ timer_id = event.eventId
1432
+ action = ctx._pending_actions.pop(timer_id, None)
1433
+ if not action:
1434
+ raise _get_non_determinism_error(
1435
+ timer_id, task.get_name(ctx.create_timer)
1436
+ )
1437
+ elif not action.HasField("createTimer"):
1438
+ expected_method_name = task.get_name(ctx.create_timer)
1439
+ raise _get_wrong_action_type_error(
1440
+ timer_id, expected_method_name, action
1441
+ )
1442
+ elif event.HasField("timerFired"):
1443
+ timer_id = event.timerFired.timerId
1444
+ timer_task = ctx._pending_tasks.pop(timer_id, None)
1445
+ if not timer_task:
1446
+ # TODO: Should this be an error? When would it ever happen?
1447
+ if not ctx._is_replaying:
1448
+ self._logger.warning(
1449
+ f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
1450
+ )
1451
+ return
1452
+ timer_task.complete(None)
1453
+ if timer_task._retryable_parent is not None:
1454
+ activity_action = timer_task._retryable_parent._action
1455
+
1456
+ if not timer_task._retryable_parent._is_sub_orch:
1457
+ cur_task = activity_action.scheduleTask
1458
+ instance_id = None
1459
+ else:
1460
+ cur_task = activity_action.createSubOrchestration
1461
+ instance_id = cur_task.instanceId
1462
+ ctx.call_activity_function_helper(
1463
+ id=activity_action.id,
1464
+ activity_function=cur_task.name,
1465
+ input=cur_task.input.value,
1466
+ retry_policy=timer_task._retryable_parent._retry_policy,
1467
+ is_sub_orch=timer_task._retryable_parent._is_sub_orch,
1468
+ instance_id=instance_id,
1469
+ fn_task=timer_task._retryable_parent,
1470
+ )
1471
+ else:
1472
+ ctx.resume()
1473
+ elif event.HasField("taskScheduled"):
1474
+ # This history event confirms that the activity execution was successfully scheduled.
1475
+ # Remove the taskScheduled event from the pending action list so we don't schedule it again.
1476
+ task_id = event.eventId
1477
+ action = ctx._pending_actions.pop(task_id, None)
1478
+ activity_task = ctx._pending_tasks.get(task_id, None)
1479
+ if not action:
1480
+ raise _get_non_determinism_error(
1481
+ task_id, task.get_name(ctx.call_activity)
1482
+ )
1483
+ elif not action.HasField("scheduleTask"):
1484
+ expected_method_name = task.get_name(ctx.call_activity)
1485
+ raise _get_wrong_action_type_error(
1486
+ task_id, expected_method_name, action
1487
+ )
1488
+ elif action.scheduleTask.name != event.taskScheduled.name:
1489
+ raise _get_wrong_action_name_error(
1490
+ task_id,
1491
+ method_name=task.get_name(ctx.call_activity),
1492
+ expected_task_name=event.taskScheduled.name,
1493
+ actual_task_name=action.scheduleTask.name,
1494
+ )
1495
+ elif event.HasField("taskCompleted"):
1496
+ # This history event contains the result of a completed activity task.
1497
+ task_id = event.taskCompleted.taskScheduledId
1498
+ activity_task = ctx._pending_tasks.pop(task_id, None)
1499
+ if not activity_task:
1500
+ # TODO: Should this be an error? When would it ever happen?
1501
+ if not ctx.is_replaying:
1502
+ self._logger.warning(
1503
+ f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}."
1504
+ )
1505
+ return
1506
+ result = None
1507
+ if not ph.is_empty(event.taskCompleted.result):
1508
+ result = shared.from_json(event.taskCompleted.result.value)
1509
+ activity_task.complete(result)
1510
+ ctx.resume()
1511
+ elif event.HasField("taskFailed"):
1512
+ task_id = event.taskFailed.taskScheduledId
1513
+ activity_task = ctx._pending_tasks.pop(task_id, None)
1514
+ if not activity_task:
1515
+ # TODO: Should this be an error? When would it ever happen?
1516
+ if not ctx.is_replaying:
1517
+ self._logger.warning(
1518
+ f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}."
1519
+ )
1520
+ return
1521
+
1522
+ if isinstance(activity_task, task.RetryableTask):
1523
+ if activity_task._retry_policy is not None:
1524
+ next_delay = activity_task.compute_next_delay()
1525
+ if next_delay is None:
1526
+ activity_task.fail(
1527
+ f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
1528
+ event.taskFailed.failureDetails,
1529
+ )
1530
+ ctx.resume()
1531
+ else:
1532
+ activity_task.increment_attempt_count()
1533
+ ctx.create_timer_internal(next_delay, activity_task)
1534
+ elif isinstance(activity_task, task.CompletableTask):
1535
+ activity_task.fail(
1536
+ f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
1537
+ event.taskFailed.failureDetails,
1538
+ )
1539
+ ctx.resume()
1540
+ else:
1541
+ raise TypeError("Unexpected task type")
1542
+ elif event.HasField("subOrchestrationInstanceCreated"):
1543
+ # This history event confirms that the sub-orchestration execution was successfully scheduled.
1544
+ # Remove the subOrchestrationInstanceCreated event from the pending action list so we don't schedule it again.
1545
+ task_id = event.eventId
1546
+ action = ctx._pending_actions.pop(task_id, None)
1547
+ if not action:
1548
+ raise _get_non_determinism_error(
1549
+ task_id, task.get_name(ctx.call_sub_orchestrator)
1550
+ )
1551
+ elif not action.HasField("createSubOrchestration"):
1552
+ expected_method_name = task.get_name(ctx.call_sub_orchestrator)
1553
+ raise _get_wrong_action_type_error(
1554
+ task_id, expected_method_name, action
1555
+ )
1556
+ elif (
1557
+ action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name
1558
+ ):
1559
+ raise _get_wrong_action_name_error(
1560
+ task_id,
1561
+ method_name=task.get_name(ctx.call_sub_orchestrator),
1562
+ expected_task_name=event.subOrchestrationInstanceCreated.name,
1563
+ actual_task_name=action.createSubOrchestration.name,
1564
+ )
1565
+ elif event.HasField("subOrchestrationInstanceCompleted"):
1566
+ task_id = event.subOrchestrationInstanceCompleted.taskScheduledId
1567
+ sub_orch_task = ctx._pending_tasks.pop(task_id, None)
1568
+ if not sub_orch_task:
1569
+ # TODO: Should this be an error? When would it ever happen?
1570
+ if not ctx.is_replaying:
1571
+ self._logger.warning(
1572
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}."
1573
+ )
1574
+ return
1575
+ result = None
1576
+ if not ph.is_empty(event.subOrchestrationInstanceCompleted.result):
1577
+ result = shared.from_json(
1578
+ event.subOrchestrationInstanceCompleted.result.value
1579
+ )
1580
+ sub_orch_task.complete(result)
1581
+ ctx.resume()
1582
+ elif event.HasField("subOrchestrationInstanceFailed"):
1583
+ failedEvent = event.subOrchestrationInstanceFailed
1584
+ task_id = failedEvent.taskScheduledId
1585
+ sub_orch_task = ctx._pending_tasks.pop(task_id, None)
1586
+ if not sub_orch_task:
1587
+ # TODO: Should this be an error? When would it ever happen?
1588
+ if not ctx.is_replaying:
1589
+ self._logger.warning(
1590
+ f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}."
1591
+ )
1592
+ return
1593
+ if isinstance(sub_orch_task, task.RetryableTask):
1594
+ if sub_orch_task._retry_policy is not None:
1595
+ next_delay = sub_orch_task.compute_next_delay()
1596
+ if next_delay is None:
1597
+ sub_orch_task.fail(
1598
+ f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
1599
+ failedEvent.failureDetails,
1600
+ )
1601
+ ctx.resume()
1602
+ else:
1603
+ sub_orch_task.increment_attempt_count()
1604
+ ctx.create_timer_internal(next_delay, sub_orch_task)
1605
+ elif isinstance(sub_orch_task, task.CompletableTask):
1606
+ sub_orch_task.fail(
1607
+ f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
1608
+ failedEvent.failureDetails,
1609
+ )
1610
+ ctx.resume()
1611
+ else:
1612
+ raise TypeError("Unexpected sub-orchestration task type")
1613
+ elif event.HasField("eventRaised"):
1614
+ if event.eventRaised.name in ctx._entity_task_id_map:
1615
+ # This eventRaised represents the result of an entity operation after being translated to the old
1616
+ # entity protocol by the Durable WebJobs extension
1617
+ entity_id, task_id, action_type = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None))
1618
+ if entity_id is None:
1619
+ raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
1620
+ if task_id is None:
1621
+ raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
1622
+ entity_task = ctx._pending_tasks.pop(task_id, None)
1623
+ if not entity_task:
1624
+ raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
1625
+ result = None
1626
+ if not ph.is_empty(event.eventRaised.input):
1627
+ # TODO: Investigate why the event result is wrapped in a dict with "result" key
1628
+ result = shared.from_json(event.eventRaised.input.value)["result"]
1629
+ if action_type == "entityOperationCalled":
1630
+ ctx._entity_context.recover_lock_after_call(entity_id)
1631
+ entity_task.complete(result)
1632
+ ctx.resume()
1633
+ elif action_type == "entityLockRequested":
1634
+ ctx._entity_context.complete_acquire(event.eventRaised.name)
1635
+ entity_task.complete(EntityLock(ctx))
1636
+ ctx.resume()
1637
+ else:
1638
+ raise RuntimeError(f"Unknown action type '{action_type}' for entity-related eventRaised "
1639
+ f"with ID '{event.eventId}'")
1640
+
1641
+ else:
1642
+ # event names are case-insensitive
1643
+ event_name = event.eventRaised.name.casefold()
1644
+ if not ctx.is_replaying:
1645
+ self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
1646
+ task_list = ctx._pending_events.get(event_name, None)
1647
+ decoded_result: Optional[Any] = None
1648
+ if task_list:
1649
+ event_task = task_list.pop(0)
1650
+ if not ph.is_empty(event.eventRaised.input):
1651
+ decoded_result = shared.from_json(event.eventRaised.input.value)
1652
+ event_task.complete(decoded_result)
1653
+ if not task_list:
1654
+ del ctx._pending_events[event_name]
1655
+ ctx.resume()
1656
+ else:
1657
+ # buffer the event
1658
+ event_list = ctx._received_events.get(event_name, None)
1659
+ if not event_list:
1660
+ event_list = []
1661
+ ctx._received_events[event_name] = event_list
1662
+ if not ph.is_empty(event.eventRaised.input):
1663
+ decoded_result = shared.from_json(event.eventRaised.input.value)
1664
+ event_list.append(decoded_result)
1665
+ if not ctx.is_replaying:
1666
+ self._logger.info(
1667
+ f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1668
+ )
1669
+ elif event.HasField("executionSuspended"):
1670
+ if not self._is_suspended and not ctx.is_replaying:
1671
+ self._logger.info(f"{ctx.instance_id}: Execution suspended.")
1672
+ self._is_suspended = True
1673
+ elif event.HasField("executionResumed") and self._is_suspended:
1674
+ if not ctx.is_replaying:
1675
+ self._logger.info(f"{ctx.instance_id}: Resuming execution.")
1676
+ self._is_suspended = False
1677
+ for e in self._suspended_events:
1678
+ self.process_event(ctx, e)
1679
+ self._suspended_events = []
1680
+ elif event.HasField("executionTerminated"):
1681
+ if not ctx.is_replaying:
1682
+ self._logger.info(f"{ctx.instance_id}: Execution terminating.")
1683
+ encoded_output = (
1684
+ event.executionTerminated.input.value
1685
+ if not ph.is_empty(event.executionTerminated.input)
1686
+ else None
1687
+ )
1688
+ ctx.set_complete(
1689
+ encoded_output,
1690
+ pb.ORCHESTRATION_STATUS_TERMINATED,
1691
+ is_result_encoded=True,
1692
+ )
1693
+ elif event.HasField("entityOperationCalled"):
1694
+ # This history event confirms that the entity operation was successfully scheduled.
1695
+ # Remove the entityOperationCalled event from the pending action list so we don't schedule it again
1696
+ entity_call_id = event.eventId
1697
+ action = ctx._pending_actions.pop(entity_call_id, None)
1698
+ entity_task = ctx._pending_tasks.get(entity_call_id, None)
1699
+ if not action:
1700
+ raise _get_non_determinism_error(
1701
+ entity_call_id, task.get_name(ctx.call_entity)
1702
+ )
1703
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationCalled"):
1704
+ expected_method_name = task.get_name(ctx.call_entity)
1705
+ raise _get_wrong_action_type_error(
1706
+ entity_call_id, expected_method_name, action
1707
+ )
1708
+ entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
1709
+ if not entity_id:
1710
+ raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
1711
+ ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id, None)
1712
+ elif event.HasField("entityOperationSignaled"):
1713
+ # This history event confirms that the entity signal was successfully scheduled.
1714
+ # Remove the entityOperationSignaled event from the pending action list so we don't schedule it
1715
+ entity_signal_id = event.eventId
1716
+ action = ctx._pending_actions.pop(entity_signal_id, None)
1717
+ if not action:
1718
+ raise _get_non_determinism_error(
1719
+ entity_signal_id, task.get_name(ctx.signal_entity)
1720
+ )
1721
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityOperationSignaled"):
1722
+ expected_method_name = task.get_name(ctx.signal_entity)
1723
+ raise _get_wrong_action_type_error(
1724
+ entity_signal_id, expected_method_name, action
1725
+ )
1726
+ elif event.HasField("entityLockRequested"):
1727
+ section_id = event.entityLockRequested.criticalSectionId
1728
+ task_id = event.eventId
1729
+ action = ctx._pending_actions.pop(task_id, None)
1730
+ entity_task = ctx._pending_tasks.get(task_id, None)
1731
+ if not action:
1732
+ raise _get_non_determinism_error(
1733
+ task_id, task.get_name(ctx.lock_entities)
1734
+ )
1735
+ elif not action.HasField("sendEntityMessage") or not action.sendEntityMessage.HasField("entityLockRequested"):
1736
+ expected_method_name = task.get_name(ctx.lock_entities)
1737
+ raise _get_wrong_action_type_error(
1738
+ task_id, expected_method_name, action
1739
+ )
1740
+ ctx._entity_lock_id_map[section_id] = task_id
1741
+ elif event.HasField("entityUnlockSent"):
1742
+ # Remove the unlock tasks as they have already been processed
1743
+ tasks_to_remove = []
1744
+ for task_id, action in ctx._pending_actions.items():
1745
+ if action.HasField("sendEntityMessage") and action.sendEntityMessage.HasField("entityUnlockSent"):
1746
+ if action.sendEntityMessage.entityUnlockSent.criticalSectionId == event.entityUnlockSent.criticalSectionId:
1747
+ tasks_to_remove.append(task_id)
1748
+ for task_to_remove in tasks_to_remove:
1749
+ ctx._pending_actions.pop(task_to_remove, None)
1750
+ elif event.HasField("entityLockGranted"):
1751
+ section_id = event.entityLockGranted.criticalSectionId
1752
+ task_id = ctx._entity_lock_id_map.pop(section_id, None)
1753
+ if not task_id:
1754
+ # TODO: Should this be an error? When would it ever happen?
1755
+ if not ctx.is_replaying:
1756
+ self._logger.warning(
1757
+ f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
1758
+ )
1759
+ return
1760
+ entity_task = ctx._pending_tasks.pop(task_id, None)
1761
+ if not entity_task:
1762
+ if not ctx.is_replaying:
1763
+ self._logger.warning(
1764
+ f"{ctx.instance_id}: Ignoring unexpected entityLockGranted event for criticalSectionId '{section_id}'."
1765
+ )
1766
+ return
1767
+ ctx._entity_context.complete_acquire(section_id)
1768
+ entity_task.complete(EntityLock(ctx))
1769
+ ctx.resume()
1770
+ elif event.HasField("entityOperationCompleted"):
1771
+ request_id = event.entityOperationCompleted.requestId
1772
+ entity_id, task_id, _ = ctx._entity_task_id_map.pop(request_id, (None, None, None))
1773
+ if not entity_id:
1774
+ raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
1775
+ if not task_id:
1776
+ raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
1777
+ entity_task = ctx._pending_tasks.pop(task_id, None)
1778
+ if not entity_task:
1779
+ if not ctx.is_replaying:
1780
+ self._logger.warning(
1781
+ f"{ctx.instance_id}: Ignoring unexpected entityOperationCompleted event with request ID = {request_id}."
1782
+ )
1783
+ return
1784
+ result = None
1785
+ if not ph.is_empty(event.entityOperationCompleted.output):
1786
+ result = shared.from_json(event.entityOperationCompleted.output.value)
1787
+ ctx._entity_context.recover_lock_after_call(entity_id)
1788
+ entity_task.complete(result)
1789
+ ctx.resume()
1790
+ elif event.HasField("entityOperationFailed"):
1791
+ if not ctx.is_replaying:
1792
+ self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
1793
+ self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
1794
+ pass
1795
+ elif event.HasField("orchestratorCompleted"):
1796
+ # Added in Functions only (for some reason) and does not affect orchestrator flow
1797
+ pass
1798
+ elif event.HasField("eventSent"):
1799
+ # Check if this eventSent corresponds to an entity operation call after being translated to the old
1800
+ # entity protocol by the Durable WebJobs extension. If so, treat this message similarly to
1801
+ # entityOperationCalled and remove the pending action. Also store the entity id and event id for later
1802
+ action = ctx._pending_actions.pop(event.eventId, None)
1803
+ if action and action.HasField("sendEntityMessage"):
1804
+ if action.sendEntityMessage.HasField("entityOperationCalled"):
1805
+ action_type = "entityOperationCalled"
1806
+ elif action.sendEntityMessage.HasField("entityLockRequested"):
1807
+ action_type = "entityLockRequested"
1808
+ else:
1809
+ return
1810
+
1811
+ entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
1812
+ event_id = json.loads(event.eventSent.input.value)["id"]
1813
+ ctx._entity_task_id_map[event_id] = (entity_id, event.eventId, action_type)
1814
+ else:
1815
+ eventType = event.WhichOneof("eventType")
1816
+ raise task.OrchestrationStateError(
1817
+ f"Don't know how to handle event of type '{eventType}'"
1818
+ )
1819
+ except StopIteration as generatorStopped:
1820
+ # The orchestrator generator function completed
1821
+ ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
1822
+
1823
+ def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
1824
+ if versioning is None:
1825
+ return None
1826
+ version_comparison = self.compare_versions(orchestration_version, versioning.version)
1827
+ if versioning.match_strategy == VersionMatchStrategy.NONE:
1828
+ return None
1829
+ elif versioning.match_strategy == VersionMatchStrategy.STRICT:
1830
+ if version_comparison != 0:
1831
+ return pb.TaskFailureDetails(
1832
+ errorType="VersionMismatch",
1833
+ errorMessage=f"The orchestration version '{orchestration_version}' does not match the worker version '{versioning.version}'.",
1834
+ isNonRetriable=True,
1835
+ )
1836
+ elif versioning.match_strategy == VersionMatchStrategy.CURRENT_OR_OLDER:
1837
+ if version_comparison > 0:
1838
+ return pb.TaskFailureDetails(
1839
+ errorType="VersionMismatch",
1840
+ errorMessage=f"The orchestration version '{orchestration_version}' is greater than the worker version '{versioning.version}'.",
1841
+ isNonRetriable=True,
1842
+ )
1843
+ else:
1844
+ # If there is a type of versioning we don't understand, it is better to treat it as a versioning failure.
1845
+ return pb.TaskFailureDetails(
1846
+ errorType="VersionMismatch",
1847
+ errorMessage=f"The version match strategy '{versioning.match_strategy}' is unknown.",
1848
+ isNonRetriable=True,
1849
+ )
1850
+
1851
+ def compare_versions(self, source_version: Optional[str], default_version: Optional[str]) -> int:
1852
+ if not source_version and not default_version:
1853
+ return 0
1854
+ if not source_version:
1855
+ return -1
1856
+ if not default_version:
1857
+ return 1
1858
+ try:
1859
+ source_version_parsed = parse(source_version)
1860
+ default_version_parsed = parse(default_version)
1861
+ return (source_version_parsed > default_version_parsed) - (source_version_parsed < default_version_parsed)
1862
+ except InvalidVersion:
1863
+ return (source_version > default_version) - (source_version < default_version)
1864
+
1865
+
1866
class _ActivityExecutor:
    """Looks up registered activity functions and runs them on behalf of the worker."""

    def __init__(self, registry: _Registry, logger: logging.Logger):
        self._registry = registry
        self._logger = logger

    def execute(
        self,
        orchestration_id: str,
        name: str,
        task_id: int,
        encoded_input: Optional[str],
    ) -> Optional[str]:
        """Executes an activity function and returns the serialized result, if any.

        Raises ActivityNotRegisteredError when no activity named `name` exists
        in the registry.
        """
        self._logger.debug(
            f"{orchestration_id}/{task_id}: Executing activity '{name}'..."
        )
        activity_fn = self._registry.get_activity(name)
        if not activity_fn:
            raise ActivityNotRegisteredError(
                f"Activity function named '{name}' was not registered!"
            )

        # Deserialize the input (if any) and build the per-invocation context.
        parsed_input = shared.from_json(encoded_input) if encoded_input else None
        ctx = task.ActivityContext(orchestration_id, task_id)

        # Run the user's activity function.
        output = activity_fn(ctx, parsed_input)

        # A None result stays None (no serialization) rather than becoming "null".
        encoded_output = None if output is None else shared.to_json(output)
        chars = 0 if not encoded_output else len(encoded_output)
        self._logger.debug(
            f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output."
        )
        return encoded_output
1902
+
1903
+
1904
class _EntityExecutor:
    """Looks up registered entities and dispatches operations to them."""

    def __init__(self, registry: _Registry, logger: logging.Logger):
        self._registry = registry
        self._logger = logger

    def execute(
        self,
        orchestration_id: str,
        entity_id: EntityInstanceId,
        operation: str,
        state: StateShim,
        encoded_input: Optional[str],
    ) -> Optional[str]:
        """Executes an entity function and returns the serialized result, if any.

        Raises EntityNotRegisteredError when the entity name is unknown, and
        AttributeError/TypeError when a class-based entity lacks a callable
        method for `operation`.
        """
        self._logger.debug(
            f"{orchestration_id}: Executing entity '{entity_id}'..."
        )
        fn = self._registry.get_entity(entity_id.entity)
        if not fn:
            raise EntityNotRegisteredError(
                f"Entity function named '{entity_id.entity}' was not registered!"
            )

        entity_input = shared.from_json(encoded_input) if encoded_input else None
        ctx = EntityContext(orchestration_id, operation, state, entity_id)

        if isinstance(fn, type) and issubclass(fn, DurableEntity):
            # Class-based entity: reuse one cached instance per entity ID so
            # the instance persists across operations handled by this worker.
            instances = self._registry.entity_instances
            cache_key = str(entity_id)
            entity_instance = instances.get(cache_key, None)
            if not entity_instance:
                entity_instance = fn()
                instances[cache_key] = entity_instance
            if not hasattr(entity_instance, operation):
                raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
            method = getattr(entity_instance, operation)
            if not callable(method):
                raise TypeError(f"Entity operation '{operation}' is not callable")
            # Execute the entity method
            entity_instance._initialize_entity_context(ctx)
            entity_output = method(entity_input)
        else:
            # Function-based entity: invoke it directly with the context.
            entity_output = fn(ctx, entity_input)

        encoded_output = None if entity_output is None else shared.to_json(entity_output)
        chars = 0 if not encoded_output else len(encoded_output)
        self._logger.debug(
            f"{orchestration_id}: Entity '{entity_id}' completed successfully with {chars} char(s) of encoded output."
        )
        return encoded_output
1956
+
1957
+
1958
def _get_non_determinism_error(
    task_id: int, action_name: str
) -> task.NonDeterminismError:
    """Builds the error raised when replay cannot find a previously recorded action."""
    message = (
        f"A previous execution called {action_name} with ID={task_id}, but the current "
        "execution doesn't have this action with this ID. This problem occurs when either "
        "the orchestration has non-deterministic logic or if the code was changed after an "
        "instance of this orchestration already started running."
    )
    return task.NonDeterminismError(message)
1967
+
1968
+
1969
def _get_wrong_action_type_error(
    task_id: int, expected_method_name: str, action: pb.OrchestratorAction
) -> task.NonDeterminismError:
    """Builds the error raised when replay finds an action of an unexpected type."""
    unexpected_method_name = _get_method_name_for_action(action)
    message = (
        "Failed to restore orchestration state due to a history mismatch: A previous execution called "
        f"{expected_method_name} with ID={task_id}, but the current execution is instead trying to call "
        f"{unexpected_method_name} as part of rebuilding it's history. This kind of mismatch can happen if an "
        "orchestration has non-deterministic logic or if the code was changed after an instance of this "
        "orchestration already started running."
    )
    return task.NonDeterminismError(message)
1980
+
1981
+
1982
def _get_wrong_action_name_error(
    task_id: int, method_name: str, expected_task_name: str, actual_task_name: str
) -> task.NonDeterminismError:
    """Builds the error raised when replay finds an action targeting the wrong task name."""
    message = (
        "Failed to restore orchestration state due to a history mismatch: A previous execution called "
        f"{method_name} with name='{expected_task_name}' and sequence number {task_id}, but the current "
        f"execution is instead trying to call {actual_task_name} as part of rebuilding it's history. "
        "This kind of mismatch can happen if an orchestration has non-deterministic logic or if the code "
        "was changed after an instance of this orchestration already started running."
    )
    return task.NonDeterminismError(message)
1992
+
1993
+
1994
def _get_method_name_for_action(action: pb.OrchestratorAction) -> str:
    """Maps an orchestrator action to the OrchestrationContext method that produces it.

    Raises NotImplementedError for action types without a known producing method.
    """
    action_type = action.WhichOneof("orchestratorActionType")
    if action_type == "scheduleTask":
        return task.get_name(task.OrchestrationContext.call_activity)
    if action_type == "createTimer":
        return task.get_name(task.OrchestrationContext.create_timer)
    if action_type == "createSubOrchestration":
        return task.get_name(task.OrchestrationContext.call_sub_orchestrator)
    # NOTE: "sendEvent" is intentionally unmapped here (same as the original
    # implementation, which kept its mapping commented out).
    raise NotImplementedError(f"Action type '{action_type}' not supported!")
2006
+
2007
+
2008
def _get_new_event_summary(new_events: Sequence[pb.HistoryEvent]) -> str:
    """Returns a summary of the new events that can be used for logging."""
    if not new_events:
        return "[]"
    if len(new_events) == 1:
        only_event = new_events[0]
        return f"[{only_event.WhichOneof('eventType')}]"
    # Tally event types, preserving first-seen order for stable log output.
    counts: dict[str, int] = {}
    for evt in new_events:
        kind = evt.WhichOneof("eventType")
        counts[kind] = counts.get(kind, 0) + 1
    summary = ", ".join(f"{name}={count}" for name, count in counts.items())
    return f"[{summary}]"
2020
+
2021
+
2022
def _get_action_summary(new_actions: Sequence[pb.OrchestratorAction]) -> str:
    """Returns a summary of the new actions that can be used for logging."""
    if not new_actions:
        return "[]"
    if len(new_actions) == 1:
        only_action = new_actions[0]
        return f"[{only_action.WhichOneof('orchestratorActionType')}]"
    # Tally action types, preserving first-seen order for stable log output.
    counts: dict[str, int] = {}
    for act in new_actions:
        kind = act.WhichOneof("orchestratorActionType")
        counts[kind] = counts.get(kind, 0) + 1
    summary = ", ".join(f"{name}={count}" for name, count in counts.items())
    return f"[{summary}]"
2034
+
2035
+
2036
def _is_suspendable(event: pb.HistoryEvent) -> bool:
    """Returns true if the event is one that can be suspended and resumed."""
    # Resume and terminate must always be processed immediately, even while the
    # orchestration is suspended; everything else can be buffered.
    non_suspendable_types = ("executionResumed", "executionTerminated")
    return event.WhichOneof("eventType") not in non_suspendable_types
2042
+
2043
+
2044
class _AsyncWorkerManager:
    """Dispatches activity, orchestration, and entity-batch work items.

    Work items are queued on asyncio queues bound to a specific event loop and
    consumed by background consumer coroutines started by `run()`. Concurrency
    per work type is throttled with semaphores; synchronous work functions are
    offloaded to a shared ThreadPoolExecutor. Because submission can happen
    before (or outside of) any running event loop, items are buffered in plain
    lists until queues can be (re)bound to the current loop.
    """

    def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
        self.concurrency_options = concurrency_options
        self._logger = logger

        # Semaphores are created lazily in run() so they bind to the right loop.
        self.activity_semaphore = None
        self.orchestration_semaphore = None
        self.entity_semaphore = None
        # Don't create queues here - defer until we have an event loop
        self.activity_queue: Optional[asyncio.Queue] = None
        self.orchestration_queue: Optional[asyncio.Queue] = None
        self.entity_batch_queue: Optional[asyncio.Queue] = None
        # The loop the queues above are currently bound to (identity-compared).
        self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None
        # Store work items when no event loop is available
        self._pending_activity_work: list = []
        self._pending_orchestration_work: list = []
        self._pending_entity_batch_work: list = []
        # Shared pool for running synchronous work functions off the loop.
        self.thread_pool = ThreadPoolExecutor(
            max_workers=concurrency_options.maximum_thread_pool_workers,
            thread_name_prefix="DurableTask",
        )
        self._shutdown = False

    def _ensure_queues_for_current_loop(self):
        """Ensure queues are bound to the current event loop."""
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running, can't create queues
            return

        # Check if queues are already properly set up for current loop
        if self._queue_event_loop is current_loop:
            if self.activity_queue is not None and self.orchestration_queue is not None and self.entity_batch_queue is not None:
                # Queues are already bound to the current loop and exist
                return

        # Need to recreate queues for the current event loop
        # First, preserve any existing work items
        existing_activity_items = []
        existing_orchestration_items = []
        existing_entity_batch_items = []

        # Drain best-effort: get_nowait on a foreign-loop queue can raise, in
        # which case remaining items are abandoned rather than crashing here.
        if self.activity_queue is not None:
            try:
                while not self.activity_queue.empty():
                    existing_activity_items.append(self.activity_queue.get_nowait())
            except Exception:
                pass

        if self.orchestration_queue is not None:
            try:
                while not self.orchestration_queue.empty():
                    existing_orchestration_items.append(
                        self.orchestration_queue.get_nowait()
                    )
            except Exception:
                pass

        if self.entity_batch_queue is not None:
            try:
                while not self.entity_batch_queue.empty():
                    existing_entity_batch_items.append(
                        self.entity_batch_queue.get_nowait()
                    )
            except Exception:
                pass

        # Create fresh queues for the current event loop
        self.activity_queue = asyncio.Queue()
        self.orchestration_queue = asyncio.Queue()
        self.entity_batch_queue = asyncio.Queue()
        self._queue_event_loop = current_loop

        # Restore the work items to the new queues
        for item in existing_activity_items:
            self.activity_queue.put_nowait(item)
        for item in existing_orchestration_items:
            self.orchestration_queue.put_nowait(item)
        for item in existing_entity_batch_items:
            self.entity_batch_queue.put_nowait(item)

        # Move pending work items to the queues
        for item in self._pending_activity_work:
            self.activity_queue.put_nowait(item)
        for item in self._pending_orchestration_work:
            self.orchestration_queue.put_nowait(item)
        for item in self._pending_entity_batch_work:
            self.entity_batch_queue.put_nowait(item)

        # Clear the pending work lists
        self._pending_activity_work.clear()
        self._pending_orchestration_work.clear()
        self._pending_entity_batch_work.clear()

    async def run(self):
        """Runs the consumer loops for all three work types until shutdown.

        On an unexpected consumer failure, drains the remaining queued items by
        invoking their cancellation callbacks, then shuts the manager down.
        """
        # Reset shutdown flag in case this manager is being reused
        self._shutdown = False

        # Ensure queues are properly bound to the current event loop
        self._ensure_queues_for_current_loop()

        # Create semaphores in the current event loop
        self.activity_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_activity_work_items
        )
        self.orchestration_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_orchestration_work_items
        )
        self.entity_semaphore = asyncio.Semaphore(
            self.concurrency_options.maximum_concurrent_entity_work_items
        )

        # Start background consumers for each work type
        try:
            if self.activity_queue is not None and self.orchestration_queue is not None \
                    and self.entity_batch_queue is not None:
                await asyncio.gather(
                    self._consume_queue(self.activity_queue, self.activity_semaphore),
                    self._consume_queue(
                        self.orchestration_queue, self.orchestration_semaphore
                    ),
                    self._consume_queue(
                        self.entity_batch_queue, self.entity_semaphore
                    )
                )
        except Exception as queue_exception:
            self._logger.error(f"Shutting down worker - Uncaught error in worker manager: {queue_exception}")
            # Best-effort cancellation of everything still queued, per work type.
            while self.activity_queue is not None and not self.activity_queue.empty():
                try:
                    func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
                    await self._run_func(cancellation_func, *args, **kwargs)
                    self._logger.error(f"Activity work item args: {args}, kwargs: {kwargs}")
                except asyncio.QueueEmpty:
                    # Queue was empty, no cancellation needed
                    pass
                except Exception as cancellation_exception:
                    self._logger.error(f"Uncaught error while cancelling activity work item: {cancellation_exception}")
            while self.orchestration_queue is not None and not self.orchestration_queue.empty():
                try:
                    func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
                    await self._run_func(cancellation_func, *args, **kwargs)
                    self._logger.error(f"Orchestration work item args: {args}, kwargs: {kwargs}")
                except asyncio.QueueEmpty:
                    # Queue was empty, no cancellation needed
                    pass
                except Exception as cancellation_exception:
                    self._logger.error(f"Uncaught error while cancelling orchestration work item: {cancellation_exception}")
            while self.entity_batch_queue is not None and not self.entity_batch_queue.empty():
                try:
                    func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
                    await self._run_func(cancellation_func, *args, **kwargs)
                    self._logger.error(f"Entity batch work item args: {args}, kwargs: {kwargs}")
                except asyncio.QueueEmpty:
                    # Queue was empty, no cancellation needed
                    pass
                except Exception as cancellation_exception:
                    self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
            self.shutdown()

    async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
        """Continuously pulls work from `queue` and processes items concurrently.

        Exits only once shutdown is requested AND the queue is drained AND all
        in-flight tasks have finished. The 1-second get() timeout exists so the
        shutdown condition is re-checked periodically.
        """
        # List to track running tasks
        running_tasks: set[asyncio.Task] = set()

        while True:
            # Clean up completed tasks
            done_tasks = {task for task in running_tasks if task.done()}
            running_tasks -= done_tasks

            # Exit if shutdown is set and the queue is empty and no tasks are running
            if self._shutdown and queue.empty() and not running_tasks:
                break

            try:
                work = await asyncio.wait_for(queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            func, cancellation_func, args, kwargs = work
            # Create a concurrent task for processing
            # NOTE(review): the local name `task` shadows the module-level
            # `durabletask.task` import within this function.
            task = asyncio.create_task(
                self._process_work_item(semaphore, queue, func, cancellation_func, args, kwargs)
            )
            running_tasks.add(task)

    async def _process_work_item(
        self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, cancellation_func, args, kwargs
    ):
        """Runs one work item under the semaphore; on failure invokes its cancellation callback."""
        async with semaphore:
            try:
                await self._run_func(func, *args, **kwargs)
            except Exception as work_exception:
                self._logger.error(f"Uncaught error while processing work item, item will be abandoned: {work_exception}")
                await self._run_func(cancellation_func, *args, **kwargs)
            finally:
                # Always balance queue.get() so queue.join() (if used) can complete.
                queue.task_done()

    async def _run_func(self, func, *args, **kwargs):
        """Awaits `func` if it is a coroutine function, else runs it in the thread pool.

        Returns None without running anything when shutdown has begun and the
        thread pool is already closed, to avoid submitting to a dead executor.
        """
        if inspect.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        else:
            loop = asyncio.get_running_loop()
            # Avoid submitting to executor after shutdown
            if (
                getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr(
                    self.thread_pool, "_shutdown", False)
            ):
                return None
            return await loop.run_in_executor(
                self.thread_pool, lambda: func(*args, **kwargs)
            )

    def submit_activity(self, func, cancellation_func, *args, **kwargs):
        """Enqueues an activity work item; buffers it when no event loop is running."""
        if self._shutdown:
            raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
        work_item = (func, cancellation_func, args, kwargs)
        self._ensure_queues_for_current_loop()
        if self.activity_queue is not None:
            self.activity_queue.put_nowait(work_item)
        else:
            # No event loop running, store in pending list
            self._pending_activity_work.append(work_item)

    def submit_orchestration(self, func, cancellation_func, *args, **kwargs):
        """Enqueues an orchestration work item; buffers it when no event loop is running."""
        if self._shutdown:
            raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
        work_item = (func, cancellation_func, args, kwargs)
        self._ensure_queues_for_current_loop()
        if self.orchestration_queue is not None:
            self.orchestration_queue.put_nowait(work_item)
        else:
            # No event loop running, store in pending list
            self._pending_orchestration_work.append(work_item)

    def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
        """Enqueues an entity-batch work item; buffers it when no event loop is running."""
        if self._shutdown:
            raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
        work_item = (func, cancellation_func, args, kwargs)
        self._ensure_queues_for_current_loop()
        if self.entity_batch_queue is not None:
            self.entity_batch_queue.put_nowait(work_item)
        else:
            # No event loop running, store in pending list
            self._pending_entity_batch_work.append(work_item)

    def shutdown(self):
        """Signals consumers to stop and blocks until the thread pool drains."""
        self._shutdown = True
        self.thread_pool.shutdown(wait=True)

    async def reset_for_new_run(self):
        """Reset the manager state for a new run."""
        self._shutdown = False
        # Clear any existing queues - they'll be recreated when needed
        if self.activity_queue is not None:
            # Clear existing queue by creating a new one
            # This ensures no items from previous runs remain
            try:
                while not self.activity_queue.empty():
                    func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
                    await self._run_func(cancellation_func, *args, **kwargs)
            except Exception as reset_exception:
                self._logger.warning(f"Error while clearing activity queue during reset: {reset_exception}")
        if self.orchestration_queue is not None:
            try:
                while not self.orchestration_queue.empty():
                    func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
                    await self._run_func(cancellation_func, *args, **kwargs)
            except Exception as reset_exception:
                self._logger.warning(f"Error while clearing orchestration queue during reset: {reset_exception}")
        if self.entity_batch_queue is not None:
            try:
                while not self.entity_batch_queue.empty():
                    func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
                    await self._run_func(cancellation_func, *args, **kwargs)
            except Exception as reset_exception:
                self._logger.warning(f"Error while clearing entity queue during reset: {reset_exception}")
        # Clear pending work lists
        self._pending_activity_work.clear()
        self._pending_orchestration_work.clear()
        self._pending_entity_batch_work.clear()
2324
+
2325
+
2326
# Export public API. Only these names are part of the module's supported
# surface; everything prefixed with "_" above is internal.
__all__ = ["ConcurrencyOptions", "TaskHubGrpcWorker"]