durabletask 1.3.0.dev27__py3-none-any.whl → 1.3.0.dev29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
durabletask/__init__.py CHANGED
@@ -4,13 +4,24 @@
4
4
  """Durable Task SDK for Python"""
5
5
 
6
6
  from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore
7
- from durabletask.worker import ConcurrencyOptions, VersioningOptions
7
+ from durabletask.worker import (
8
+ ActivityWorkItemFilter,
9
+ ConcurrencyOptions,
10
+ EntityWorkItemFilter,
11
+ OrchestrationWorkItemFilter,
12
+ VersioningOptions,
13
+ WorkItemFilters,
14
+ )
8
15
 
9
16
  __all__ = [
17
+ "ActivityWorkItemFilter",
10
18
  "ConcurrencyOptions",
19
+ "EntityWorkItemFilter",
11
20
  "LargePayloadStorageOptions",
21
+ "OrchestrationWorkItemFilter",
12
22
  "PayloadStore",
13
23
  "VersioningOptions",
24
+ "WorkItemFilters",
14
25
  ]
15
26
 
16
27
  PACKAGE_NAME = "durabletask"
durabletask/task.py CHANGED
@@ -98,7 +98,7 @@ class OrchestrationContext(ABC):
98
98
  pass
99
99
 
100
100
  @abstractmethod
101
- def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
101
+ def create_timer(self, fire_at: Union[datetime, timedelta]) -> CancellableTask:
102
102
  """Create a Timer Task to fire after at the specified deadline.
103
103
 
104
104
  Parameters
@@ -228,10 +228,10 @@ class OrchestrationContext(ABC):
228
228
  """
229
229
  pass
230
230
 
231
- # TOOD: Add a timeout parameter, which allows the task to be canceled if the event is
231
+ # TOOD: Add a timeout parameter, which allows the task to be cancelled if the event is
232
232
  # not received within the specified timeout. This requires support for task cancellation.
233
233
  @abstractmethod
234
- def wait_for_external_event(self, name: str) -> CompletableTask:
234
+ def wait_for_external_event(self, name: str) -> CancellableTask:
235
235
  """Wait asynchronously for an event to be raised with the name `name`.
236
236
 
237
237
  Parameters
@@ -324,6 +324,10 @@ class OrchestrationStateError(Exception):
324
324
  pass
325
325
 
326
326
 
327
+ class TaskCancelledError(Exception):
328
+ """Exception type for cancelled orchestration tasks."""
329
+
330
+
327
331
  class Task(ABC, Generic[T]):
328
332
  """Abstract base class for asynchronous tasks in a durable orchestration."""
329
333
  _result: T
@@ -435,6 +439,48 @@ class CompletableTask(Task[T]):
435
439
  self._parent.on_child_completed(self)
436
440
 
437
441
 
442
+ class CancellableTask(CompletableTask[T]):
443
+ """A completable task that can be cancelled before it finishes."""
444
+
445
+ def __init__(self) -> None:
446
+ super().__init__()
447
+ self._is_cancelled = False
448
+ self._cancel_handler: Optional[Callable[[], None]] = None
449
+
450
+ @property
451
+ def is_cancelled(self) -> bool:
452
+ """Returns True if the task was cancelled, False otherwise."""
453
+ return self._is_cancelled
454
+
455
+ def get_result(self) -> T:
456
+ if self._is_cancelled:
457
+ raise TaskCancelledError('The task was cancelled.')
458
+ return super().get_result()
459
+
460
+ def set_cancel_handler(self, cancel_handler: Callable[[], None]) -> None:
461
+ self._cancel_handler = cancel_handler
462
+
463
+ def cancel(self) -> bool:
464
+ """Attempts to cancel this task.
465
+
466
+ Returns
467
+ -------
468
+ bool
469
+ True if cancellation was applied, False if the task had already completed.
470
+ """
471
+ if self._is_complete:
472
+ return False
473
+
474
+ if self._cancel_handler is not None:
475
+ self._cancel_handler()
476
+
477
+ self._is_cancelled = True
478
+ self._is_complete = True
479
+ if self._parent is not None:
480
+ self._parent.on_child_completed(self)
481
+ return True
482
+
483
+
438
484
  class RetryableTask(CompletableTask[T]):
439
485
  """A task that can be retried according to a retry policy."""
440
486
 
@@ -474,14 +520,29 @@ class RetryableTask(CompletableTask[T]):
474
520
  return None
475
521
 
476
522
 
477
- class TimerTask(CompletableTask[T]):
478
-
479
- def __init__(self) -> None:
523
+ class TimerTask(CancellableTask[None]):
524
+ def __init__(self, final_fire_at: Optional[datetime] = None,
525
+ maximum_timer_interval: Optional[timedelta] = None):
480
526
  super().__init__()
527
+ self._final_fire_at = final_fire_at
528
+ self._maximum_timer_interval = maximum_timer_interval
481
529
 
482
530
  def set_retryable_parent(self, retryable_task: RetryableTask):
483
531
  self._retryable_parent = retryable_task
484
532
 
533
+ def _handle_timer_fired(self, current_utc_datetime: datetime) -> Optional[datetime]:
534
+ if (self._final_fire_at is not None
535
+ and self._maximum_timer_interval is not None
536
+ and current_utc_datetime < self._final_fire_at):
537
+ return self._get_next_fire_at(current_utc_datetime)
538
+ super().complete(None)
539
+ return None
540
+
541
+ def _get_next_fire_at(self, current_utc_datetime: datetime) -> datetime:
542
+ if current_utc_datetime + self._maximum_timer_interval < self._final_fire_at:
543
+ return current_utc_datetime + self._maximum_timer_interval
544
+ return self._final_fire_at
545
+
485
546
 
486
547
  class WhenAnyTask(CompositeTask[Task]):
487
548
  """A task that completes when any of its child tasks complete."""
@@ -26,6 +26,7 @@ from google.protobuf import empty_pb2, timestamp_pb2, wrappers_pb2
26
26
  import durabletask.internal.orchestrator_service_pb2 as pb
27
27
  import durabletask.internal.orchestrator_service_pb2_grpc as stubs
28
28
  import durabletask.internal.helpers as helpers
29
+ from durabletask.entities.entity_instance_id import EntityInstanceId
29
30
 
30
31
 
31
32
  @dataclass
@@ -56,6 +57,7 @@ class ActivityWorkItem:
56
57
  task_id: int
57
58
  input: Optional[str]
58
59
  completion_token: int
60
+ version: Optional[str] = None
59
61
 
60
62
 
61
63
  @dataclass
@@ -451,9 +453,57 @@ class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer):
451
453
  f"Restarted instance '{request.instanceId}' as '{new_instance_id}'")
452
454
  return pb.RestartInstanceResponse(instanceId=new_instance_id)
453
455
 
456
+ @staticmethod
457
+ def _parse_work_item_filters(request: pb.GetWorkItemsRequest):
458
+ """Extract filters from the request.
459
+
460
+ Returns a tuple of three values, one per work-item category. Each
461
+ value is either ``None`` (no filtering -- dispatch everything) or a
462
+ ``dict`` mapping a task name to a ``frozenset`` of accepted versions
463
+ (empty frozenset means *any* version of that name is accepted).
464
+ An empty ``dict`` means the worker opted into filtering for that
465
+ category but listed no names, so *nothing* should match.
466
+ """
467
+ if not request.HasField("workItemFilters"):
468
+ return None, None, None
469
+ wf = request.workItemFilters
470
+
471
+ def _build_filter(filters):
472
+ result: dict[str, frozenset[str]] = {}
473
+ for f in filters:
474
+ versions = frozenset(f.versions) if f.versions else frozenset()
475
+ existing = result.get(f.name, frozenset())
476
+ result[f.name] = existing | versions
477
+ return result
478
+
479
+ orch_filter = _build_filter(wf.orchestrations)
480
+ activity_filter = _build_filter(wf.activities)
481
+ entity_filter = {f.name: frozenset() for f in wf.entities}
482
+ return orch_filter, activity_filter, entity_filter
483
+
484
+ @staticmethod
485
+ def _matches_filter(name: str, version: Optional[str],
486
+ filt: Optional[dict[str, frozenset[str]]]) -> bool:
487
+ """Check whether a work item matches the parsed filter.
488
+
489
+ *filt* is ``None`` when the worker did not opt into filtering
490
+ (everything matches). Otherwise it is a dict mapping accepted
491
+ names to a frozenset of accepted versions. An empty frozenset
492
+ means any version of that name is accepted.
493
+ """
494
+ if filt is None:
495
+ return True
496
+ accepted_versions = filt.get(name)
497
+ if accepted_versions is None:
498
+ return False
499
+ if not accepted_versions:
500
+ return True # empty set -- any version
501
+ return (version or "") in accepted_versions
502
+
454
503
  def GetWorkItems(self, request: pb.GetWorkItemsRequest, context):
455
504
  """Streams work items to the worker (orchestration and activity work items)."""
456
505
  self._logger.info("Worker connected and requesting work items")
506
+ orch_filter, activity_filter, entity_filter = self._parse_work_item_filters(request)
457
507
 
458
508
  try:
459
509
  while context.is_active() and not self._shutdown_event.is_set():
@@ -461,6 +511,7 @@ class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer):
461
511
 
462
512
  with self._lock:
463
513
  # Check for orchestration work
514
+ skipped_orchs: list[str] = []
464
515
  while self._orchestration_queue:
465
516
  instance_id = self._orchestration_queue.popleft()
466
517
  self._orchestration_queue_set.discard(instance_id)
@@ -469,11 +520,15 @@ class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer):
469
520
  if not instance or not instance.pending_events:
470
521
  continue
471
522
 
523
+ # Skip if orchestration doesn't match filters
524
+ if not self._matches_filter(
525
+ instance.name, instance.version, orch_filter):
526
+ skipped_orchs.append(instance_id)
527
+ continue
528
+
472
529
  if instance_id in self._orchestration_in_flight:
473
530
  # Already being processed — re-add to queue
474
- if instance_id not in self._orchestration_queue_set:
475
- self._orchestration_queue.append(instance_id)
476
- self._orchestration_queue_set.add(instance_id)
531
+ skipped_orchs.append(instance_id)
477
532
  break
478
533
 
479
534
  # Move pending events to dispatched_events
@@ -500,27 +555,66 @@ class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer):
500
555
  )
501
556
  break
502
557
 
558
+ # Re-queue skipped orchestrations for other workers
559
+ for s in skipped_orchs:
560
+ if s not in self._orchestration_queue_set:
561
+ self._orchestration_queue.append(s)
562
+ self._orchestration_queue_set.add(s)
563
+
503
564
  # Check for activity work
504
565
  if not work_item and self._activity_queue:
505
- activity = self._activity_queue.popleft()
506
- work_item = pb.WorkItem(
507
- completionToken=str(activity.completion_token),
508
- activityRequest=pb.ActivityRequest(
509
- name=activity.name,
510
- taskId=activity.task_id,
511
- input=wrappers_pb2.StringValue(value=activity.input) if activity.input else None,
512
- orchestrationInstance=pb.OrchestrationInstance(instanceId=activity.instance_id)
566
+ # Scan for the first matching activity
567
+ skipped: list = []
568
+ matched_activity = None
569
+ while self._activity_queue:
570
+ candidate = self._activity_queue.popleft()
571
+ if not self._matches_filter(
572
+ candidate.name, candidate.version,
573
+ activity_filter):
574
+ skipped.append(candidate)
575
+ continue
576
+ matched_activity = candidate
577
+ break
578
+ # Put back non-matching items
579
+ for s in skipped:
580
+ self._activity_queue.append(s)
581
+
582
+ if matched_activity is not None:
583
+ work_item = pb.WorkItem(
584
+ completionToken=str(matched_activity.completion_token),
585
+ activityRequest=pb.ActivityRequest(
586
+ name=matched_activity.name,
587
+ taskId=matched_activity.task_id,
588
+ input=wrappers_pb2.StringValue(value=matched_activity.input) if matched_activity.input else None,
589
+ orchestrationInstance=pb.OrchestrationInstance(instanceId=matched_activity.instance_id)
590
+ )
513
591
  )
514
- )
515
592
 
516
593
  # Check for entity work
517
594
  if not work_item:
595
+ skipped_entities: list[str] = []
518
596
  while self._entity_queue:
519
597
  entity_id = self._entity_queue.popleft()
520
598
  self._entity_queue_set.discard(entity_id)
521
599
  entity = self._entities.get(entity_id)
522
600
 
523
601
  if entity and entity.pending_operations:
602
+ # Skip if entity name doesn't match filters
603
+ if entity_filter is not None:
604
+ try:
605
+ parsed = EntityInstanceId.parse(entity_id)
606
+ if not self._matches_filter(
607
+ parsed.entity, None,
608
+ entity_filter):
609
+ skipped_entities.append(entity_id)
610
+ continue
611
+ except ValueError:
612
+ self._logger.warning(
613
+ f"Cannot parse entity ID '{entity_id}' "
614
+ f"for filter matching; skipping")
615
+ skipped_entities.append(entity_id)
616
+ continue
617
+
524
618
  # Skip if this entity is already being processed
525
619
  if entity_id in self._entity_in_flight:
526
620
  continue
@@ -547,6 +641,12 @@ class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer):
547
641
  )
548
642
  break
549
643
 
644
+ # Re-queue skipped entities for other workers
645
+ for s in skipped_entities:
646
+ if s not in self._entity_queue_set:
647
+ self._entity_queue.append(s)
648
+ self._entity_queue_set.add(s)
649
+
550
650
  if work_item:
551
651
  yield work_item
552
652
  else:
@@ -1274,12 +1374,15 @@ class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer):
1274
1374
  instance.status = pb.ORCHESTRATION_STATUS_RUNNING
1275
1375
 
1276
1376
  # Queue activity for execution
1377
+ task_version = schedule_task.version.value \
1378
+ if schedule_task.HasField("version") else None
1277
1379
  self._activity_queue.append(ActivityWorkItem(
1278
1380
  instance_id=instance.instance_id,
1279
1381
  name=task_name,
1280
1382
  task_id=task_id,
1281
1383
  input=input_value,
1282
- completion_token=instance.completion_token
1384
+ completion_token=instance.completion_token,
1385
+ version=task_version,
1283
1386
  ))
1284
1387
  self._work_available.set()
1285
1388
 
durabletask/worker.py CHANGED
@@ -9,11 +9,12 @@ import os
9
9
  import random
10
10
  import time
11
11
  from concurrent.futures import ThreadPoolExecutor
12
+ from dataclasses import dataclass, field
12
13
  from datetime import datetime, timedelta, timezone
13
14
  from threading import Event, Thread
14
15
  from types import GeneratorType
15
16
  from enum import Enum
16
- from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union
17
+ from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union, overload
17
18
  import uuid
18
19
  from packaging.version import InvalidVersion, parse
19
20
 
@@ -42,6 +43,7 @@ from durabletask.payload.store import PayloadStore
42
43
  TInput = TypeVar("TInput")
43
44
  TOutput = TypeVar("TOutput")
44
45
  DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
46
+ DEFAULT_MAXIMUM_TIMER_INTERVAL = timedelta(days=3)
45
47
 
46
48
 
47
49
  class ConcurrencyOptions:
@@ -143,6 +145,109 @@ class VersioningOptions:
143
145
  self.failure_strategy = failure_strategy
144
146
 
145
147
 
148
+ # Sentinel object used to distinguish "auto-generate filters" from "clear filters (None)".
149
+ _AUTO_GENERATE_FILTERS = object()
150
+
151
+
152
+ @dataclass(frozen=True)
153
+ class OrchestrationWorkItemFilter:
154
+ """Specifies a filter for orchestration work items."""
155
+
156
+ name: str
157
+ """The name of the orchestration to filter."""
158
+ versions: list[str] = field(default_factory=list)
159
+ """Optional list of versions to filter."""
160
+
161
+
162
+ @dataclass(frozen=True)
163
+ class ActivityWorkItemFilter:
164
+ """Specifies a filter for activity work items."""
165
+
166
+ name: str
167
+ """The name of the activity to filter."""
168
+ versions: list[str] = field(default_factory=list)
169
+ """Optional list of versions to filter."""
170
+
171
+
172
+ @dataclass(frozen=True)
173
+ class EntityWorkItemFilter:
174
+ """Specifies a filter for entity work items.
175
+
176
+ The name is normalized to lowercase to match entity registration
177
+ and instance ID conventions.
178
+ """
179
+
180
+ name: str
181
+ """The name of the entity to filter."""
182
+
183
+ def __post_init__(self):
184
+ EntityInstanceId.validate_entity_name(self.name)
185
+ object.__setattr__(self, 'name', self.name.lower())
186
+
187
+
188
+ @dataclass(frozen=True)
189
+ class WorkItemFilters:
190
+ """Work item filters for a Durable Task Worker.
191
+
192
+ These filters are passed to the backend and only work items matching the
193
+ filters will be processed by the worker. If no filters are provided, the
194
+ worker will process all work items.
195
+
196
+ By default, no filters are applied. Call
197
+ :meth:`TaskHubGrpcWorker.use_work_item_filters` to enable filtering.
198
+ """
199
+
200
+ orchestrations: list[OrchestrationWorkItemFilter] = field(default_factory=list)
201
+ """List of orchestration filters."""
202
+ activities: list[ActivityWorkItemFilter] = field(default_factory=list)
203
+ """List of activity filters."""
204
+ entities: list[EntityWorkItemFilter] = field(default_factory=list)
205
+ """List of entity filters."""
206
+
207
+ @classmethod
208
+ def _from_registry(cls, registry: '_Registry') -> 'WorkItemFilters':
209
+ """Auto-generate work item filters from the task registry."""
210
+ versions: list[str] = []
211
+ v = registry.versioning
212
+ if v and v.match_strategy == VersionMatchStrategy.STRICT and v.version:
213
+ versions = [registry.versioning.version]
214
+
215
+ orchestrations = [
216
+ OrchestrationWorkItemFilter(name=name, versions=list(versions))
217
+ for name in registry.orchestrators
218
+ ]
219
+ activities = [
220
+ ActivityWorkItemFilter(name=name, versions=list(versions))
221
+ for name in registry.activities
222
+ ]
223
+ entities = [
224
+ EntityWorkItemFilter(name=name)
225
+ for name in registry.entities
226
+ ]
227
+ return cls(
228
+ orchestrations=orchestrations,
229
+ activities=activities,
230
+ entities=entities,
231
+ )
232
+
233
+ def _to_grpc(self) -> pb.WorkItemFilters:
234
+ """Convert to a gRPC WorkItemFilters message."""
235
+ grpc_filters = pb.WorkItemFilters()
236
+ for f in self.orchestrations:
237
+ grpc_filters.orchestrations.append(
238
+ pb.OrchestrationFilter(name=f.name, versions=f.versions)
239
+ )
240
+ for f in self.activities:
241
+ grpc_filters.activities.append(
242
+ pb.ActivityFilter(name=f.name, versions=f.versions)
243
+ )
244
+ for f in self.entities:
245
+ grpc_filters.entities.append(
246
+ pb.EntityFilter(name=f.name)
247
+ )
248
+ return grpc_filters
249
+
250
+
146
251
  class _Registry:
147
252
  orchestrators: dict[str, task.Orchestrator]
148
253
  activities: dict[str, task.Activity]
@@ -311,7 +416,7 @@ class TaskHubGrpcWorker:
311
416
  activity function.
312
417
  """
313
418
 
314
- _response_stream: Optional[grpc.Future] = None
419
+ _response_stream: Optional[Any] = None
315
420
  _interceptors: Optional[list[shared.ClientInterceptor]] = None
316
421
 
317
422
  def __init__(
@@ -324,6 +429,7 @@ class TaskHubGrpcWorker:
324
429
  secure_channel: bool = False,
325
430
  interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
326
431
  concurrency_options: Optional[ConcurrencyOptions] = None,
432
+ maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL,
327
433
  payload_store: Optional[PayloadStore] = None,
328
434
  ):
329
435
  self._registry = _Registry()
@@ -354,12 +460,20 @@ class TaskHubGrpcWorker:
354
460
  self._interceptors = None
355
461
 
356
462
  self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
463
+ self._maximum_timer_interval = maximum_timer_interval
464
+ self._work_item_filters: Optional[WorkItemFilters] = None
465
+ self._auto_generate_work_item_filters: bool = False
357
466
 
358
467
  @property
359
468
  def concurrency_options(self) -> ConcurrencyOptions:
360
469
  """Get the current concurrency options for this worker."""
361
470
  return self._concurrency_options
362
471
 
472
+ @property
473
+ def maximum_timer_interval(self) -> Optional[timedelta]:
474
+ """Get the configured maximum timer interval for long timer chunking."""
475
+ return self._maximum_timer_interval
476
+
363
477
  def __enter__(self):
364
478
  return self
365
479
 
@@ -396,11 +510,65 @@ class TaskHubGrpcWorker:
396
510
  raise RuntimeError("Cannot set default version while the worker is running.")
397
511
  self._registry.versioning = version
398
512
 
513
+ @overload
514
+ def use_work_item_filters(self) -> None:
515
+ ...
516
+
517
+ @overload
518
+ def use_work_item_filters(self, filters: WorkItemFilters) -> None:
519
+ ...
520
+
521
+ @overload
522
+ def use_work_item_filters(self, filters: None) -> None:
523
+ ...
524
+
525
+ def use_work_item_filters(
526
+ self,
527
+ filters: Union[WorkItemFilters, None, object] = _AUTO_GENERATE_FILTERS,
528
+ ) -> None:
529
+ """Configures work item filters for the worker.
530
+
531
+ Work item filters tell the backend which orchestrations, activities,
532
+ and entities this worker can handle. When enabled, only matching work
533
+ items are dispatched to this worker.
534
+
535
+ By default no filters are applied and the worker processes all work
536
+ items. Calling this method enables filtering.
537
+
538
+ Args:
539
+ filters: The filters to apply. If omitted (default), filters are
540
+ auto-generated from registered orchestrations, activities, and
541
+ entities at :meth:`start` time. Pass a :class:`WorkItemFilters`
542
+ instance to provide explicit filters. Pass ``None`` to clear
543
+ any previously configured filters.
544
+ """
545
+ if self._is_running:
546
+ raise RuntimeError(
547
+ "Work item filters cannot be changed while the worker is running."
548
+ )
549
+ if filters is _AUTO_GENERATE_FILTERS:
550
+ self._auto_generate_work_item_filters = True
551
+ self._work_item_filters = None
552
+ elif filters is None:
553
+ self._auto_generate_work_item_filters = False
554
+ self._work_item_filters = None
555
+ elif isinstance(filters, WorkItemFilters):
556
+ self._auto_generate_work_item_filters = False
557
+ self._work_item_filters = filters
558
+ else:
559
+ raise TypeError(
560
+ "filters must be a WorkItemFilters instance, None, or omitted."
561
+ )
562
+
399
563
  def start(self):
400
564
  """Starts the worker on a background thread and begins listening for work items."""
401
565
  if self._is_running:
402
566
  raise RuntimeError("The worker is already running.")
403
567
 
568
+ # Auto-generate work item filters from registry if opted in
569
+ if self._auto_generate_work_item_filters:
570
+ self._work_item_filters = WorkItemFilters._from_registry(self._registry)
571
+
404
572
  def run_loop():
405
573
  loop = asyncio.new_event_loop()
406
574
  asyncio.set_event_loop(loop)
@@ -510,6 +678,10 @@ class TaskHubGrpcWorker:
510
678
  maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
511
679
  capabilities=capabilities,
512
680
  )
681
+ if self._work_item_filters is not None:
682
+ get_work_items_request.workItemFilters.CopyFrom(
683
+ self._work_item_filters._to_grpc()
684
+ )
513
685
  self._response_stream = stub.GetWorkItems(get_work_items_request)
514
686
  self._logger.info(
515
687
  f"Successfully connected to {self._host_address}. Waiting for work items..."
@@ -522,7 +694,11 @@ class TaskHubGrpcWorker:
522
694
 
523
695
  def stream_reader():
524
696
  try:
525
- for work_item in self._response_stream:
697
+ response_stream = self._response_stream
698
+ if response_stream is None:
699
+ return
700
+
701
+ for work_item in response_stream:
526
702
  work_item_queue.put(work_item)
527
703
  except Exception as e:
528
704
  work_item_queue.put(e)
@@ -674,7 +850,8 @@ class TaskHubGrpcWorker:
674
850
  try:
675
851
  executor = _OrchestrationExecutor(
676
852
  self._registry, self._logger,
677
- persisted_orch_span_id=persisted_orch_span_id)
853
+ persisted_orch_span_id=persisted_orch_span_id,
854
+ maximum_timer_interval=self.maximum_timer_interval)
678
855
  result = executor.execute(instance_id, req.pastEvents, req.newEvents)
679
856
 
680
857
  # Determine completion status for span
@@ -962,7 +1139,11 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
962
1139
  _generator: Optional[Generator[task.Task, Any, Any]]
963
1140
  _previous_task: Optional[task.Task]
964
1141
 
965
- def __init__(self, instance_id: str, registry: _Registry):
1142
+ def __init__(self,
1143
+ instance_id: str,
1144
+ registry: _Registry,
1145
+ maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL,
1146
+ ):
966
1147
  self._generator = None
967
1148
  self._is_replaying = True
968
1149
  self._is_complete = False
@@ -983,12 +1164,13 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
983
1164
  self._version: Optional[str] = None
984
1165
  self._completion_status: Optional[pb.OrchestrationStatus] = None
985
1166
  self._received_events: dict[str, list[Any]] = {}
986
- self._pending_events: dict[str, list[task.CompletableTask]] = {}
1167
+ self._pending_events: dict[str, list[task.CancellableTask]] = {}
987
1168
  self._new_input: Optional[Any] = None
988
1169
  self._save_events = False
989
1170
  self._encoded_custom_status: Optional[str] = None
990
1171
  self._parent_trace_context: Optional[pb.TraceContext] = None
991
1172
  self._orchestration_trace_context: Optional[pb.TraceContext] = None
1173
+ self._maximum_timer_interval = maximum_timer_interval
992
1174
 
993
1175
  def run(self, generator: Generator[task.Task, Any, Any]):
994
1176
  self._generator = generator
@@ -1154,7 +1336,7 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
1154
1336
  shared.to_json(custom_status) if custom_status is not None else None
1155
1337
  )
1156
1338
 
1157
- def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task:
1339
+ def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.CancellableTask:
1158
1340
  return self.create_timer_internal(fire_at)
1159
1341
 
1160
1342
  def create_timer_internal(
@@ -1164,11 +1346,30 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
1164
1346
  ) -> task.TimerTask:
1165
1347
  id = self.next_sequence_number()
1166
1348
  if isinstance(fire_at, timedelta):
1167
- fire_at = self.current_utc_datetime + fire_at
1168
- action = ph.new_create_timer_action(id, fire_at)
1349
+ final_fire_at = self.current_utc_datetime + fire_at
1350
+ else:
1351
+ final_fire_at = fire_at
1352
+
1353
+ next_fire_at: datetime = final_fire_at
1354
+
1355
+ if (
1356
+ self._maximum_timer_interval is not None
1357
+ and self._maximum_timer_interval > timedelta(0)
1358
+ and self.current_utc_datetime + self._maximum_timer_interval < final_fire_at
1359
+ ):
1360
+ timer_task = task.TimerTask(final_fire_at, self._maximum_timer_interval)
1361
+ next_fire_at = timer_task._get_next_fire_at(self.current_utc_datetime)
1362
+ else:
1363
+ timer_task = task.TimerTask()
1364
+
1365
+ action = ph.new_create_timer_action(id, next_fire_at)
1169
1366
  self._pending_actions[id] = action
1170
1367
 
1171
- timer_task: task.TimerTask = task.TimerTask()
1368
+ def _cancel_timer() -> None:
1369
+ self._pending_actions.pop(id, None)
1370
+ self._pending_tasks.pop(id, None)
1371
+
1372
+ timer_task.set_cancel_handler(_cancel_timer)
1172
1373
  if retryable_task is not None:
1173
1374
  timer_task.set_retryable_parent(retryable_task)
1174
1375
  self._pending_tasks[id] = timer_task
@@ -1399,13 +1600,13 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
1399
1600
  action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
1400
1601
  self._pending_actions[task_id] = action
1401
1602
 
1402
- def wait_for_external_event(self, name: str) -> task.CompletableTask:
1603
+ def wait_for_external_event(self, name: str) -> task.CancellableTask:
1403
1604
  # Check to see if this event has already been received, in which case we
1404
1605
  # can return it immediately. Otherwise, record out intent to receive an
1405
1606
  # event with the given name so that we can resume the generator when it
1406
1607
  # arrives. If there are multiple events with the same name, we return
1407
1608
  # them in the order they were received.
1408
- external_event_task: task.CompletableTask = task.CompletableTask()
1609
+ external_event_task: task.CancellableTask = task.CancellableTask()
1409
1610
  event_name = name.casefold()
1410
1611
  event_list = self._received_events.get(event_name, None)
1411
1612
  if event_list:
@@ -1419,6 +1620,19 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
1419
1620
  task_list = []
1420
1621
  self._pending_events[event_name] = task_list
1421
1622
  task_list.append(external_event_task)
1623
+
1624
+ def _cancel_wait() -> None:
1625
+ waiting_tasks = self._pending_events.get(event_name)
1626
+ if waiting_tasks is None:
1627
+ return
1628
+ try:
1629
+ waiting_tasks.remove(external_event_task)
1630
+ except ValueError:
1631
+ return
1632
+ if not waiting_tasks:
1633
+ del self._pending_events[event_name]
1634
+
1635
+ external_event_task.set_cancel_handler(_cancel_wait)
1422
1636
  return external_event_task
1423
1637
 
1424
1638
  def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
@@ -1461,9 +1675,11 @@ class _OrchestrationExecutor:
1461
1675
  registry: _Registry,
1462
1676
  logger: logging.Logger,
1463
1677
  persisted_orch_span_id: Optional[str] = None,
1678
+ maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL,
1464
1679
  ):
1465
1680
  self._registry = registry
1466
1681
  self._logger = logger
1682
+ self._maximum_timer_interval = maximum_timer_interval
1467
1683
  self._is_suspended = False
1468
1684
  self._suspended_events: list[pb.HistoryEvent] = []
1469
1685
  self._persisted_orch_span_id = persisted_orch_span_id
@@ -1497,7 +1713,11 @@ class _OrchestrationExecutor:
1497
1713
  "The new history event list must have at least one event in it."
1498
1714
  )
1499
1715
 
1500
- ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
1716
+ ctx = _RuntimeOrchestrationContext(
1717
+ instance_id,
1718
+ self._registry,
1719
+ maximum_timer_interval=self._maximum_timer_interval,
1720
+ )
1501
1721
  try:
1502
1722
  # Rebuild local state by replaying old history into the orchestrator function
1503
1723
  self._logger.debug(
@@ -1654,6 +1874,12 @@ class _OrchestrationExecutor:
1654
1874
  f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
1655
1875
  )
1656
1876
  return
1877
+ if not isinstance(timer_task, task.TimerTask):
1878
+ if not ctx._is_replaying:
1879
+ self._logger.warning(
1880
+ f"{ctx.instance_id}: Ignoring timerFired event with non-timer task ID = {timer_id}."
1881
+ )
1882
+ return
1657
1883
  # Emit timer span with backdated start time (skip during replay)
1658
1884
  if not ctx.is_replaying:
1659
1885
  timer_info = self._timer_fire_at.get(timer_id)
@@ -1665,27 +1891,39 @@ class _OrchestrationExecutor:
1665
1891
  scheduled_time_ns=created_ns,
1666
1892
  parent_trace_context=ctx._orchestration_trace_context or ctx._parent_trace_context,
1667
1893
  )
1668
- timer_task.complete(None)
1669
- if timer_task._retryable_parent is not None:
1670
- activity_action = timer_task._retryable_parent._action
1894
+ next_fire_at = timer_task._handle_timer_fired(event.timerFired.fireAt.ToDatetime())
1895
+ if next_fire_at is not None:
1896
+ id = ctx.next_sequence_number()
1897
+ new_action = ph.new_create_timer_action(id, next_fire_at)
1898
+ ctx._pending_tasks[id] = timer_task
1899
+ ctx._pending_actions[id] = new_action
1900
+
1901
+ def _cancel_timer() -> None:
1902
+ ctx._pending_actions.pop(id, None)
1903
+ ctx._pending_tasks.pop(id, None)
1904
+
1905
+ timer_task.set_cancel_handler(_cancel_timer)
1906
+ else:
1907
+ if timer_task._retryable_parent is not None:
1908
+ activity_action = timer_task._retryable_parent._action
1671
1909
 
1672
- if not timer_task._retryable_parent._is_sub_orch:
1673
- cur_task = activity_action.scheduleTask
1674
- instance_id = None
1910
+ if not timer_task._retryable_parent._is_sub_orch:
1911
+ cur_task = activity_action.scheduleTask
1912
+ instance_id = None
1913
+ else:
1914
+ cur_task = activity_action.createSubOrchestration
1915
+ instance_id = cur_task.instanceId
1916
+ ctx.call_activity_function_helper(
1917
+ id=activity_action.id,
1918
+ activity_function=cur_task.name,
1919
+ input=cur_task.input.value,
1920
+ retry_policy=timer_task._retryable_parent._retry_policy,
1921
+ is_sub_orch=timer_task._retryable_parent._is_sub_orch,
1922
+ instance_id=instance_id,
1923
+ fn_task=timer_task._retryable_parent,
1924
+ )
1675
1925
  else:
1676
- cur_task = activity_action.createSubOrchestration
1677
- instance_id = cur_task.instanceId
1678
- ctx.call_activity_function_helper(
1679
- id=activity_action.id,
1680
- activity_function=cur_task.name,
1681
- input=cur_task.input.value,
1682
- retry_policy=timer_task._retryable_parent._retry_policy,
1683
- is_sub_orch=timer_task._retryable_parent._is_sub_orch,
1684
- instance_id=instance_id,
1685
- fn_task=timer_task._retryable_parent,
1686
- )
1687
- else:
1688
- ctx.resume()
1926
+ ctx.resume()
1689
1927
  elif event.HasField("taskScheduled"):
1690
1928
  # This history event confirms that the activity execution was successfully scheduled.
1691
1929
  # Remove the taskScheduled event from the pending action list so we don't schedule it again.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: durabletask
3
- Version: 1.3.0.dev27
3
+ Version: 1.3.0.dev29
4
4
  Summary: A Durable Task Client SDK for Python
5
5
  License: MIT License
6
6
 
@@ -1,8 +1,8 @@
1
- durabletask/__init__.py,sha256=Xb-2zUIwDbCCl4Q_89TqMJ-3zckHFrGtSMcVhEeSseQ,407
1
+ durabletask/__init__.py,sha256=OdfKCNlS_NJawRfLWsFNj7YIHeGSQkh2VH3OzG0Oric,644
2
2
  durabletask/client.py,sha256=NbIdDTQR7XI_ZiqsGMP0q5vmbe5-ShyGUQre1qgB-Ag,34107
3
3
  durabletask/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- durabletask/task.py,sha256=YeZj-tYHiLml672hAMtri8_hqnpmR5in4AHMi5owFQo,21793
5
- durabletask/worker.py,sha256=TNc1WJTmh42sLWyD_FuwUw7YqrIPA7AFpJKd0iHOn5M,124479
4
+ durabletask/task.py,sha256=QS7RPUaWhb4Xq5hi7nnueooJqQSDO6ipEoiAiSVkAIs,24009
5
+ durabletask/worker.py,sha256=B8aYoOguNuYccmmtGF9GFr3YEEdfTGJqcOoox0kByQM,133661
6
6
  durabletask/entities/__init__.py,sha256=DbNd5riqWZaj3tG6gN82O8Q6wTmFpe6QaH0pQgDSPHs,721
7
7
  durabletask/entities/durable_entity.py,sha256=LQPWnUlRsHiFVRoTdpeSK--eXtjf2UGbVQwEEKf7QwI,3318
8
8
  durabletask/entities/entity_context.py,sha256=U-B3i9QP34N-6Fikx_tMp8zo0YLdmwhwpdxwjHd7z-M,5346
@@ -31,9 +31,9 @@ durabletask/payload/__init__.py,sha256=1h68pQvgk8JUp5LBJuBq9W4GUPYkdlhqmCCQEg6YB
31
31
  durabletask/payload/helpers.py,sha256=RYG5MEVAqHjm4zfFHs3Td91FVQHUoCcb5hbEJ4sYj5s,12350
32
32
  durabletask/payload/store.py,sha256=3qJMvKxRUkr6ScWUzxpKAVgzuhFLywRW8a2_5OOmNk4,3000
33
33
  durabletask/testing/__init__.py,sha256=rXbcSFtzuaRAbDNX-HmdgbxLTegvKJ1FRjZfSOIAMgA,323
34
- durabletask/testing/in_memory_backend.py,sha256=REmzhgAAw_AOpxrRAEbZlU4jqQLZo6QdbWYoWaruBDo,74102
35
- durabletask-1.3.0.dev27.dist-info/licenses/LICENSE,sha256=ws_MuBL-SCEBqPBFl9_FqZkaaydIJmxHrJG2parhU4M,1141
36
- durabletask-1.3.0.dev27.dist-info/METADATA,sha256=uGCKNOJQtqqay5ULKwLclHtAQh3zHYj39_ZP37C1lMU,4404
37
- durabletask-1.3.0.dev27.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
38
- durabletask-1.3.0.dev27.dist-info/top_level.txt,sha256=EBVyuKWnjOwq8bJI1Uvb9U3c4fzQxACWj9p83he6fik,12
39
- durabletask-1.3.0.dev27.dist-info/RECORD,,
34
+ durabletask/testing/in_memory_backend.py,sha256=ELxyCDRDNOabygIvw9ZeRUpP3MzeM5Hdbu6QlRwKdno,79249
35
+ durabletask-1.3.0.dev29.dist-info/licenses/LICENSE,sha256=ws_MuBL-SCEBqPBFl9_FqZkaaydIJmxHrJG2parhU4M,1141
36
+ durabletask-1.3.0.dev29.dist-info/METADATA,sha256=NFWHPGrcuEDhGp2-ut4FlmAyjtBdlOhTMz9ZuCb_EOI,4404
37
+ durabletask-1.3.0.dev29.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
38
+ durabletask-1.3.0.dev29.dist-info/top_level.txt,sha256=EBVyuKWnjOwq8bJI1Uvb9U3c4fzQxACWj9p83he6fik,12
39
+ durabletask-1.3.0.dev29.dist-info/RECORD,,