durabletask 0.0.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
durabletask/task.py ADDED
@@ -0,0 +1,621 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ # See https://peps.python.org/pep-0563/
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ from abc import ABC, abstractmethod
9
+ from datetime import datetime, timedelta
10
+ from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
11
+
12
+ from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext
13
+ import durabletask.internal.helpers as pbh
14
+ import durabletask.internal.orchestrator_service_pb2 as pb
15
+
16
# Generic result type carried by Task[T] and its subclasses.
T = TypeVar("T")
# Input/output types used by the Orchestrator, Activity and Entity callable aliases.
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")
19
+
20
+
21
class OrchestrationContext(ABC):
    """Abstract interface handed to orchestrator functions.

    Orchestrator functions receive an instance of this class as their first
    argument. They use it to schedule durable work (activities, timers,
    sub-orchestrations, entity operations) and to read replay-safe state.
    """

    @property
    @abstractmethod
    def instance_id(self) -> str:
        """Get the ID of the current orchestration instance.

        The instance ID is generated and fixed when the orchestrator function
        is scheduled. It can be either auto-generated, in which case it is
        formatted as a UUID, or it can be user-specified with any format.

        Returns
        -------
        str
            The ID of the current orchestration instance.
        """

    @property
    @abstractmethod
    def version(self) -> Optional[str]:
        """Get the version of the orchestration instance.

        This version is set when the orchestration is scheduled and can be used
        to determine which version of the orchestrator function is being executed.

        Returns
        -------
        Optional[str]
            The version of the orchestration instance, or None if not set.
        """

    @property
    @abstractmethod
    def current_utc_datetime(self) -> datetime:
        """Get the current date/time as UTC.

        This date/time value is derived from the orchestration history. It
        always returns the same value at specific points in the orchestrator
        function code, making it deterministic and safe for replay.

        Returns
        -------
        datetime
            The current timestamp in a way that is safe for use by orchestrator functions.
        """

    @property
    @abstractmethod
    def is_replaying(self) -> bool:
        """Get the value indicating whether the orchestrator is replaying from history.

        This property is useful when some logic must run only when the
        orchestrator function is *not* replaying. For example, certain types
        of application logging may become too noisy when duplicated during
        orchestrator replay. The orchestrator code can check this property and
        issue the log statements only when it is ``False``.

        Returns
        -------
        bool
            Value indicating whether the orchestrator function is currently replaying.
        """

    @abstractmethod
    def set_custom_status(self, custom_status: Any) -> None:
        """Set the orchestration instance's custom status.

        Parameters
        ----------
        custom_status: Any
            A JSON-serializable custom status value to set.
        """

    @abstractmethod
    def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
        """Create a Timer Task that fires at the specified deadline.

        Parameters
        ----------
        fire_at: datetime.datetime | datetime.timedelta
            The absolute time at which the timer fires, or a time delta from now.

        Returns
        -------
        Task
            A Durable Timer Task that schedules the timer to wake up the orchestrator.
        """

    @abstractmethod
    def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
                      input: Optional[TInput] = None,
                      retry_policy: Optional[RetryPolicy] = None,
                      tags: Optional[dict[str, str]] = None) -> Task[TOutput]:
        """Schedule an activity for execution.

        Parameters
        ----------
        activity: Union[Activity[TInput, TOutput], str]
            A reference to the activity function to call, or its name.
        input: Optional[TInput]
            The JSON-serializable input (or None) to pass to the activity.
        retry_policy: Optional[RetryPolicy]
            The retry policy to use for this activity call.
        tags: Optional[dict[str, str]]
            Optional tags to associate with the activity invocation.

        Returns
        -------
        Task[TOutput]
            A Durable Task that completes when the called activity function completes or fails.
        """

    @abstractmethod
    def call_entity(self, entity: EntityInstanceId,
                    operation: str,
                    input: Optional[TInput] = None) -> Task:
        """Schedule an entity operation for execution and wait for its result.

        Parameters
        ----------
        entity: EntityInstanceId
            The ID of the entity instance to call.
        operation: str
            The name of the operation to invoke on the entity.
        input: Optional[TInput]
            The optional JSON-serializable input to pass to the entity function.

        Returns
        -------
        Task
            A Durable Task that completes when the called entity function completes or fails.
        """

    @abstractmethod
    def signal_entity(
        self,
        entity_id: EntityInstanceId,
        operation_name: str,
        input: Optional[TInput] = None
    ) -> None:
        """Signal an entity function for execution without waiting for a result.

        Parameters
        ----------
        entity_id: EntityInstanceId
            The ID of the entity instance to signal.
        operation_name: str
            The name of the operation to invoke on the entity.
        input: Optional[TInput]
            The optional JSON-serializable input to pass to the entity function.
        """

    @abstractmethod
    def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]:
        """Create a Task that locks the specified entity instances.

        The locks are acquired the next time the orchestrator yields.
        Best practice is to immediately yield this Task and enter the returned EntityLock.
        The locks are released when the EntityLock is exited.

        Parameters
        ----------
        entities: list[EntityInstanceId]
            The list of entity instance IDs to lock.

        Returns
        -------
        Task[EntityLock]
            A task whose result is an EntityLock context manager; exiting the
            context manager releases the locks.
        """

    @abstractmethod
    def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *,
                              input: Optional[TInput] = None,
                              instance_id: Optional[str] = None,
                              retry_policy: Optional[RetryPolicy] = None,
                              version: Optional[str] = None) -> Task[TOutput]:
        """Schedule a sub-orchestrator function for execution.

        Parameters
        ----------
        orchestrator: Union[Orchestrator[TInput, TOutput], str]
            A reference to the orchestrator function to call, or its name.
        input: Optional[TInput]
            The optional JSON-serializable input to pass to the orchestrator function.
        instance_id: Optional[str]
            A unique ID to use for the sub-orchestration instance. If not
            specified, an ID is generated automatically.
        retry_policy: Optional[RetryPolicy]
            The retry policy to use for this sub-orchestrator call.
        version: Optional[str]
            The version of the sub-orchestration to schedule, or None to leave it unspecified.

        Returns
        -------
        Task[TOutput]
            A Durable Task that completes when the called sub-orchestrator completes or fails.
        """

    # TODO: Add a timeout parameter, which allows the task to be canceled if the event is
    # not received within the specified timeout. This requires support for task cancellation.
    @abstractmethod
    def wait_for_external_event(self, name: str) -> Task:
        """Wait asynchronously for an event to be raised with the name `name`.

        Parameters
        ----------
        name : str
            The event name of the event that the task is waiting for.

        Returns
        -------
        Task
            A Durable Task that completes when the event is received.
        """

    @abstractmethod
    def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
        """Continue the orchestration execution as a new instance.

        Parameters
        ----------
        new_input : Any
            The new input to use for the new orchestration instance.
        save_events : bool
            A flag indicating whether to add any unprocessed external events to the new orchestration history.
        """

    @abstractmethod
    def new_uuid(self) -> str:
        """Create a new UUID that is safe for replay within an orchestration or operation.

        The default implementation of this method creates a name-based UUID
        using the algorithm from RFC 4122 §4.3. The name input used to generate
        this value is a combination of the orchestration instance ID and an
        internally managed sequence number.

        Returns
        -------
        str
            New UUID that is safe for replay within an orchestration or operation.
        """

    @abstractmethod
    def _exit_critical_section(self) -> None:
        """Internal hook for leaving an entity critical section.

        NOTE(review): presumably called when an EntityLock obtained from
        lock_entities() is exited, to release the held locks. Confirm against
        the concrete context implementation.
        """
280
+
281
+
282
class FailureDetails:
    """Read-only record describing why a task failed.

    Parameters
    ----------
    message : str
        The error message.
    error_type : str
        The type of the error, as reported for the failed task.
    stack_trace : Optional[str]
        The stack trace text, or None when none is available.
    """

    def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
        self._message, self._error_type, self._stack_trace = message, error_type, stack_trace

    @property
    def message(self) -> str:
        """The error message."""
        return self._message

    @property
    def error_type(self) -> str:
        """The error type reported for the failure."""
        return self._error_type

    @property
    def stack_trace(self) -> Optional[str]:
        """The stack trace text, or None if unavailable."""
        return self._stack_trace
299
+
300
+
301
class TaskFailedError(Exception):
    """Exception raised for orchestration task failures.

    Carries the structured failure information as a FailureDetails instance,
    exposed through the ``details`` property.
    """

    def __init__(self, message: str, details: pb.TaskFailureDetails):
        super().__init__(message)
        error_message = details.errorMessage
        error_type = details.errorType
        # stackTrace appears to be a wrapper message (it exposes `.value`);
        # an empty wrapper is mapped to "no stack trace".
        if pbh.is_empty(details.stackTrace):
            trace = None
        else:
            trace = details.stackTrace.value
        self._details = FailureDetails(error_message, error_type, trace)

    @property
    def details(self) -> FailureDetails:
        """The structured failure details for this error."""
        return self._details
314
+
315
+
316
class NonDeterminismError(Exception):
    """Error type signalling non-deterministic orchestrator behavior.

    NOTE(review): not raised within this module. Presumably raised by the
    orchestration runtime when replay diverges from history; confirm.
    """
318
+
319
+
320
class OrchestrationStateError(Exception):
    """Error type signalling that an orchestration is in an unexpected state.

    NOTE(review): not raised within this module. Confirm the exact trigger
    conditions with the runtime code that uses it.
    """
322
+
323
+
324
class Task(ABC, Generic[T]):
    """Base class for the asynchronous units of work an orchestrator yields.

    A task starts out incomplete. Subclasses finish it by setting
    ``_is_complete`` together with either ``_result`` (success) or
    ``_exception`` (failure).
    """
    _result: T
    _exception: Optional[TaskFailedError]
    _parent: Optional[CompositeTask[T]]

    def __init__(self) -> None:
        super().__init__()
        self._is_complete = False
        self._exception = None
        self._parent = None

    @property
    def is_complete(self) -> bool:
        """True once the task has finished, whether it succeeded or failed."""
        return self._is_complete

    @property
    def is_failed(self) -> bool:
        """True when a failure exception has been recorded for the task."""
        return self._exception is not None

    def get_result(self) -> T:
        """Return the task's result.

        Raises
        ------
        ValueError
            If the task has not completed yet.
        TaskFailedError
            If the task failed; the recorded exception is re-raised.
        """
        if self._is_complete:
            if self._exception is None:
                return self._result
            raise self._exception
        raise ValueError('The task has not completed.')

    def get_exception(self) -> TaskFailedError:
        """Return the exception that caused the task to fail.

        Raises
        ------
        ValueError
            If no failure has been recorded for the task.
        """
        failure = self._exception
        if failure is None:
            raise ValueError('The task has not failed.')
        return failure
359
+
360
+
361
class CompositeTask(Task[T]):
    """Base class for tasks whose completion depends on a set of child tasks.

    Every child is linked back to this task through its ``_parent`` attribute
    so that it can report completion via :meth:`on_child_completed`. Children
    that are already complete at construction time are reported immediately.
    """
    _tasks: list[Task]

    def __init__(self, tasks: list[Task]):
        super().__init__()
        self._tasks = tasks
        # Counters must exist before any already-complete child is reported below.
        self._completed_tasks = 0
        self._failed_tasks = 0
        for child in tasks:
            child._parent = self
            if child.is_complete:
                self.on_child_completed(child)

    def get_tasks(self) -> list[Task]:
        """Return the child tasks in the order supplied to the constructor."""
        return self._tasks

    @abstractmethod
    def on_child_completed(self, task: Task[T]):
        """Handle the completion (success or failure) of one child task."""
381
+
382
+
383
class WhenAllTask(CompositeTask[list[T]]):
    """A task that completes when all of its child tasks complete, or fails as soon as one fails.

    On success the result is the list of child results, in the same order as
    the tasks passed to the constructor. On the first child failure this task
    fails with that child's exception. The remaining children may still
    complete afterwards; they are counted but no longer affect the outcome.
    An empty task list completes immediately with an empty result list.
    """

    def __init__(self, tasks: list[Task[T]]):
        # CompositeTask.__init__ initializes the completion counters and then
        # reports children that are already complete. The counters must not be
        # reset afterwards: that would discard those early completions, and
        # the task could then never finish.
        super().__init__(tasks)
        if not self._tasks:
            # Nothing to wait for.
            self._result = []
            self._is_complete = True

    @property
    def pending_tasks(self) -> int:
        """Returns the number of tasks that have not yet completed."""
        return len(self._tasks) - self._completed_tasks

    def on_child_completed(self, task: Task[T]):
        """Record the completion (success or failure) of one child task."""
        self._completed_tasks += 1
        if task.is_failed:
            self._failed_tasks += 1

        if self.is_complete:
            # This task already failed fast on an earlier child. Late
            # completions of the other children are expected and must not
            # raise from the child's complete()/fail() call.
            return

        if task.is_failed:
            self._exception = task.get_exception()
            self._is_complete = True
        elif self._completed_tasks == len(self._tasks):
            # No child has failed (otherwise we returned above), so every
            # get_result() call below is safe.
            # The order of the result MUST match the order of the tasks provided to the constructor.
            self._result = [child.get_result() for child in self._tasks]
            self._is_complete = True

    def get_completed_tasks(self) -> int:
        """Returns the number of child tasks that have completed so far."""
        return self._completed_tasks
410
+
411
+
412
class CompletableTask(Task[T]):
    """A leaf task whose outcome is set explicitly via complete() or fail()."""

    def __init__(self):
        super().__init__()
        # Populated by TimerTask.set_retryable_parent; None otherwise.
        self._retryable_parent = None

    def complete(self, result: T):
        """Mark the task as successfully completed with *result*.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        if self._is_complete:
            raise ValueError('The task has already completed.')
        self._result = result
        self._finish()

    def fail(self, message: str, details: pb.TaskFailureDetails):
        """Mark the task as failed, recording a TaskFailedError built from *message* and *details*.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        if self._is_complete:
            raise ValueError('The task has already completed.')
        self._exception = TaskFailedError(message, details)
        self._finish()

    def _finish(self) -> None:
        """Flag the task as complete and notify the parent composite task, if any."""
        self._is_complete = True
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)
433
+
434
+
435
class RetryableTask(CompletableTask[T]):
    """A task that can be retried according to a retry policy.

    Parameters
    ----------
    retry_policy : RetryPolicy
        Policy governing the attempt limit, back-off and overall retry timeout.
    action : pb.OrchestratorAction
        The orchestrator action this task represents. NOTE(review): presumably
        re-submitted by the runtime for each retry; confirm.
    start_time : datetime
        Reference time from which the policy's ``retry_timeout`` is measured.
    is_sub_orch : bool
        Whether the task wraps a sub-orchestration call (stored; not read here).
    """

    def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction,
                 start_time: datetime, is_sub_orch: bool) -> None:
        super().__init__()
        self._action = action
        self._retry_policy = retry_policy
        self._attempt_count = 1  # the initial attempt counts as attempt #1
        self._start_time = start_time
        self._is_sub_orch = is_sub_orch

    def increment_attempt_count(self) -> None:
        """Record that one more attempt is being made."""
        self._attempt_count += 1

    def compute_next_delay(self, now: Optional[datetime] = None) -> Optional[timedelta]:
        """Compute how long to wait before the next retry attempt.

        Parameters
        ----------
        now : Optional[datetime]
            The current time, used to check the retry timeout. Defaults to the
            wall-clock time as a naive UTC datetime. Orchestrator callers
            should pass the replay-safe orchestration time so that replays
            make the same decision.

        Returns
        -------
        Optional[timedelta]
            The delay before the next attempt, or None when no further attempt
            should be made (attempt limit reached or retry timeout expired).
        """
        policy = self._retry_policy
        if self._attempt_count >= policy.max_number_of_attempts:
            return None

        # retry_timeout is a timedelta, so the "no limit" sentinel to compare
        # against is timedelta.max. A datetime never equals a timedelta, so
        # comparing with datetime.max would let timedelta.max through to the
        # addition below, which overflows.
        deadline: Optional[datetime] = None
        timeout = policy.retry_timeout
        if timeout is not None and timeout != timedelta.max:
            try:
                deadline = self._start_time + timeout
            except OverflowError:
                # The deadline lies beyond what datetime can represent; treat
                # it as "never expires" instead of crashing.
                deadline = None

        if deadline is not None:
            if now is None:
                from datetime import timezone
                # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
                now = datetime.now(timezone.utc).replace(tzinfo=None)
            if now >= deadline:
                return None

        if policy.backoff_coefficient is None:
            backoff_coefficient = 1.0
        else:
            backoff_coefficient = policy.backoff_coefficient

        # The n-th retry waits first_retry_interval * backoff_coefficient ** (n - 1).
        delay_seconds = (math.pow(backoff_coefficient, self._attempt_count - 1)
                         * policy.first_retry_interval.total_seconds())
        if policy.max_retry_interval is not None:
            delay_seconds = min(delay_seconds, policy.max_retry_interval.total_seconds())
        return timedelta(seconds=delay_seconds)
471
+
472
+
473
class TimerTask(CompletableTask[T]):
    """A durable timer task, optionally linked to a RetryableTask.

    Construction is inherited unchanged from CompletableTask.
    """

    def set_retryable_parent(self, retryable_task: RetryableTask):
        """Link this timer to *retryable_task* (stored in ``_retryable_parent``).

        NOTE(review): presumably the retry whose back-off delay this timer
        implements. Confirm with the runtime code that calls this.
        """
        self._retryable_parent = retryable_task
480
+
481
+
482
class WhenAnyTask(CompositeTask[Task]):
    """A task that completes when any of its child tasks completes.

    The result is the first child task to finish, whether it succeeded or
    failed. Construction (``WhenAnyTask(tasks)``) is inherited unchanged from
    CompositeTask.
    """

    def on_child_completed(self, task: Task):
        """Adopt *task* as the result if it is the first child to finish."""
        if self.is_complete:
            # A different child already won; later completions are ignored.
            return
        self._is_complete = True
        self._result = task
493
+
494
+
495
def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]:
    """Combine *tasks* into a single task that waits for all of them.

    The combined task completes when all of the provided tasks complete, or
    when one of them fails.

    Parameters
    ----------
    tasks : list[Task[T]]
        The tasks to wait for.

    Returns
    -------
    WhenAllTask[T]
        A task whose result is the list of child results, in the order of *tasks*.
    """
    aggregate = WhenAllTask(tasks)
    return aggregate
498
+
499
+
500
def when_any(tasks: list[Task]) -> WhenAnyTask:
    """Combine *tasks* into a single task that waits for the first of them.

    The combined task completes when any of the provided tasks completes or fails.

    Parameters
    ----------
    tasks : list[Task]
        The candidate tasks.

    Returns
    -------
    WhenAnyTask
        A task whose result is the first child task to finish.
    """
    first_finisher = WhenAnyTask(tasks)
    return first_finisher
503
+
504
+
505
class ActivityContext:
    """Context object passed to activity functions.

    Parameters
    ----------
    orchestration_id : str
        ID of the orchestration instance that scheduled the activity.
    task_id : int
        ID of this activity invocation within that orchestration.
    """

    def __init__(self, orchestration_id: str, task_id: int):
        self._orchestration_id = orchestration_id
        self._task_id = task_id

    @property
    def orchestration_id(self) -> str:
        """Get the ID of the orchestration instance that scheduled this activity.

        Returns
        -------
        str
            The ID of the orchestration instance that scheduled this activity.
        """
        return self._orchestration_id

    @property
    def task_id(self) -> int:
        """Get the task ID associated with this activity invocation.

        The task ID is an auto-incrementing integer that is unique within
        the scope of the orchestration instance. It can be used to distinguish
        between multiple activity invocations that are part of the same
        orchestration instance.

        Returns
        -------
        int
            The task ID of this activity invocation.
        """
        return self._task_id
536
+
537
+
538
# An activity is a plain function that receives an ActivityContext and its
# input and returns its output. Orchestrators schedule activities either by
# reference or by name.
Activity = Callable[[ActivityContext, TInput], TOutput]

# An orchestrator receives an OrchestrationContext and its input. It is either
# a generator that yields Task objects (and is sent their results), or a
# function that returns its output directly.
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]]

# An entity is either a function taking an EntityContext and its input, or a
# DurableEntity subclass.
Entity = Union[Callable[[EntityContext, TInput], TOutput], type[DurableEntity]]
545
+
546
+
547
class RetryPolicy:
    """Represents the retry policy for an orchestration or activity function."""

    def __init__(self, *,
                 first_retry_interval: timedelta,
                 max_number_of_attempts: int,
                 backoff_coefficient: Optional[float] = 1.0,
                 max_retry_interval: Optional[timedelta] = None,
                 retry_timeout: Optional[timedelta] = None):
        """Creates a new RetryPolicy instance.

        Parameters
        ----------
        first_retry_interval : timedelta
            The delay before the first retry attempt. Must be >= 0.
        max_number_of_attempts : int
            The maximum total number of attempts, *including* the initial
            attempt. For example, 1 means no retries and 3 means up to two
            retries. Must be >= 1.
        backoff_coefficient : Optional[float]
            Multiplier for the retry delay: the n-th retry waits
            ``first_retry_interval * backoff_coefficient ** (n - 1)``.
            None is treated as 1.0 (constant delay). Must be >= 1 when given.
        max_retry_interval : Optional[timedelta]
            Upper bound on any single retry delay; None means unbounded. Must be >= 0.
        retry_timeout : Optional[timedelta]
            The maximum total time to keep retrying, measured from the task's
            start time; None means no limit. Must be >= 0.

        Raises
        ------
        ValueError
            If any argument is outside its allowed range.
        """
        # validate inputs
        if first_retry_interval < timedelta(seconds=0):
            raise ValueError('first_retry_interval must be >= 0')
        if max_number_of_attempts < 1:
            raise ValueError('max_number_of_attempts must be >= 1')
        if backoff_coefficient is not None and backoff_coefficient < 1:
            raise ValueError('backoff_coefficient must be >= 1')
        if max_retry_interval is not None and max_retry_interval < timedelta(seconds=0):
            raise ValueError('max_retry_interval must be >= 0')
        if retry_timeout is not None and retry_timeout < timedelta(seconds=0):
            raise ValueError('retry_timeout must be >= 0')

        self._first_retry_interval = first_retry_interval
        self._max_number_of_attempts = max_number_of_attempts
        self._backoff_coefficient = backoff_coefficient
        self._max_retry_interval = max_retry_interval
        self._retry_timeout = retry_timeout

    @property
    def first_retry_interval(self) -> timedelta:
        """The delay before the first retry attempt."""
        return self._first_retry_interval

    @property
    def max_number_of_attempts(self) -> int:
        """The maximum total number of attempts, including the initial attempt."""
        return self._max_number_of_attempts

    @property
    def backoff_coefficient(self) -> Optional[float]:
        """The backoff coefficient used to grow the retry interval (None acts as 1.0)."""
        return self._backoff_coefficient

    @property
    def max_retry_interval(self) -> Optional[timedelta]:
        """The maximum delay for any single retry attempt, or None if unbounded."""
        return self._max_retry_interval

    @property
    def retry_timeout(self) -> Optional[timedelta]:
        """The maximum total time to spend retrying, or None for no limit."""
        return self._retry_timeout
613
+
614
+
615
def get_name(fn: Callable) -> str:
    """Return the name of the provided function (its ``__name__``).

    Parameters
    ----------
    fn : Callable
        The function whose name should be inferred.

    Returns
    -------
    str
        ``fn.__name__``.

    Raises
    ------
    ValueError
        If *fn* is a lambda, or if it has no ``__name__`` at all (for example
        a ``functools.partial`` object or a callable instance). In both cases
        no meaningful name can be inferred and the caller must supply one.
    """
    name = getattr(fn, '__name__', None)
    if name is None:
        raise ValueError('Cannot infer a name from the provided callable. Please provide a name explicitly.')
    if name == '<lambda>':
        raise ValueError('Cannot infer a name from a lambda function. Please provide a name explicitly.')

    return name