durabletask 0.1.0a1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,30 +5,68 @@ import dataclasses
5
5
  import json
6
6
  import logging
7
7
  from types import SimpleNamespace
8
- from typing import Any, Dict, Union
8
+ from typing import Any, Optional, Sequence, Union
9
9
 
10
10
  import grpc
11
11
 
12
+ ClientInterceptor = Union[
13
+ grpc.UnaryUnaryClientInterceptor,
14
+ grpc.UnaryStreamClientInterceptor,
15
+ grpc.StreamUnaryClientInterceptor,
16
+ grpc.StreamStreamClientInterceptor
17
+ ]
18
+
12
19
  # Field name used to indicate that an object was automatically serialized
13
20
  # and should be deserialized as a SimpleNamespace
14
21
  AUTO_SERIALIZED = "__durabletask_autoobject__"
15
22
 
23
+ SECURE_PROTOCOLS = ["https://", "grpcs://"]
24
+ INSECURE_PROTOCOLS = ["http://", "grpc://"]
25
+
16
26
 
17
27
  def get_default_host_address() -> str:
18
28
  return "localhost:4001"
19
29
 
20
30
 
21
- def get_grpc_channel(host_address: Union[str, None]) -> grpc.Channel:
31
+ def get_grpc_channel(
32
+ host_address: Optional[str],
33
+ secure_channel: bool = False,
34
+ interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel:
35
+
22
36
  if host_address is None:
23
37
  host_address = get_default_host_address()
24
- channel = grpc.insecure_channel(host_address)
38
+
39
+ for protocol in SECURE_PROTOCOLS:
40
+ if host_address.lower().startswith(protocol):
41
+ secure_channel = True
42
+ # remove the protocol from the host name
43
+ host_address = host_address[len(protocol):]
44
+ break
45
+
46
+ for protocol in INSECURE_PROTOCOLS:
47
+ if host_address.lower().startswith(protocol):
48
+ secure_channel = False
49
+ # remove the protocol from the host name
50
+ host_address = host_address[len(protocol):]
51
+ break
52
+
53
+ # Create the base channel
54
+ if secure_channel:
55
+ channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
56
+ else:
57
+ channel = grpc.insecure_channel(host_address)
58
+
59
+ # Apply interceptors ONLY if they exist
60
+ if interceptors:
61
+ channel = grpc.intercept_channel(channel, *interceptors)
25
62
  return channel
26
63
 
27
64
 
28
65
  def get_logger(
29
- log_handler: Union[logging.Handler, None] = None,
30
- log_formatter: Union[logging.Formatter, None] = None) -> logging.Logger:
31
- logger = logging.Logger("durabletask")
66
+ name_suffix: str,
67
+ log_handler: Optional[logging.Handler] = None,
68
+ log_formatter: Optional[logging.Formatter] = None) -> logging.Logger:
69
+ logger = logging.Logger(f"durabletask-{name_suffix}")
32
70
 
33
71
  # Add a default log handler if none is provided
34
72
  if log_handler is None:
@@ -68,7 +106,7 @@ class InternalJSONEncoder(json.JSONEncoder):
68
106
  if dataclasses.is_dataclass(obj):
69
107
  # Dataclasses are not serializable by default, so we convert them to a dict and mark them for
70
108
  # automatic deserialization by the receiver
71
- d = dataclasses.asdict(obj)
109
+ d = dataclasses.asdict(obj) # type: ignore
72
110
  d[AUTO_SERIALIZED] = True
73
111
  return d
74
112
  elif isinstance(obj, SimpleNamespace):
@@ -84,7 +122,7 @@ class InternalJSONDecoder(json.JSONDecoder):
84
122
  def __init__(self, *args, **kwargs):
85
123
  super().__init__(object_hook=self.dict_to_object, *args, **kwargs)
86
124
 
87
- def dict_to_object(self, d: Dict[str, Any]):
125
+ def dict_to_object(self, d: dict[str, Any]):
88
126
  # If the object was serialized by the InternalJSONEncoder, deserialize it as a SimpleNamespace
89
127
  if d.pop(AUTO_SERIALIZED, False):
90
128
  return SimpleNamespace(**d)
durabletask/task.py CHANGED
@@ -4,10 +4,12 @@
4
4
  # See https://peps.python.org/pep-0563/
5
5
  from __future__ import annotations
6
6
 
7
+ import math
7
8
  from abc import ABC, abstractmethod
8
9
  from datetime import datetime, timedelta
9
- from typing import Any, Callable, Generator, Generic, List, TypeVar, Union
10
+ from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
10
11
 
12
+ from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext
11
13
  import durabletask.internal.helpers as pbh
12
14
  import durabletask.internal.orchestrator_service_pb2 as pb
13
15
 
@@ -34,6 +36,21 @@ class OrchestrationContext(ABC):
34
36
  """
35
37
  pass
36
38
 
39
+ @property
40
+ @abstractmethod
41
+ def version(self) -> Optional[str]:
42
+ """Get the version of the orchestration instance.
43
+
44
+ This version is set when the orchestration is scheduled and can be used
45
+ to determine which version of the orchestrator function is being executed.
46
+
47
+ Returns
48
+ -------
49
+ Optional[str]
50
+ The version of the orchestration instance, or None if not set.
51
+ """
52
+ pass
53
+
37
54
  @property
38
55
  @abstractmethod
39
56
  def current_utc_datetime(self) -> datetime:
@@ -69,6 +86,17 @@ class OrchestrationContext(ABC):
69
86
  """
70
87
  pass
71
88
 
89
+ @abstractmethod
90
+ def set_custom_status(self, custom_status: Any) -> None:
91
+ """Set the orchestration instance's custom status.
92
+
93
+ Parameters
94
+ ----------
95
+ custom_status: Any
96
+ A JSON-serializable custom status value to set.
97
+ """
98
+ pass
99
+
72
100
  @abstractmethod
73
101
  def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
74
102
  """Create a Timer Task to fire at the specified deadline.
@@ -87,17 +115,21 @@ class OrchestrationContext(ABC):
87
115
 
88
116
  @abstractmethod
89
117
  def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
90
- input: Union[TInput, None] = None) -> Task[TOutput]:
118
+ input: Optional[TInput] = None,
119
+ retry_policy: Optional[RetryPolicy] = None,
120
+ tags: Optional[dict[str, str]] = None) -> Task[TOutput]:
91
121
  """Schedule an activity for execution.
92
122
 
93
123
  Parameters
94
124
  ----------
95
125
  activity: Union[Activity[TInput, TOutput], str]
96
126
  A reference to the activity function to call.
97
- input: Union[TInput, None]
127
+ input: Optional[TInput]
98
128
  The JSON-serializable input (or None) to pass to the activity.
99
- return_type: task.Task[TOutput]
100
- The JSON-serializable output type to expect from the activity result.
129
+ retry_policy: Optional[RetryPolicy]
130
+ The retry policy to use for this activity call.
131
+ tags: Optional[dict[str, str]]
132
+ Optional tags to associate with the activity invocation.
101
133
 
102
134
  Returns
103
135
  -------
@@ -107,20 +139,86 @@ class OrchestrationContext(ABC):
107
139
  pass
108
140
 
109
141
  @abstractmethod
110
- def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
111
- input: Union[TInput, None] = None,
112
- instance_id: Union[str, None] = None) -> Task[TOutput]:
142
+ def call_entity(self, entity: EntityInstanceId,
143
+ operation: str,
144
+ input: Optional[TInput] = None) -> Task:
145
+ """Schedule entity function for execution.
146
+
147
+ Parameters
148
+ ----------
149
+ entity: EntityInstanceId
150
+ The ID of the entity instance to call.
151
+ operation: str
152
+ The name of the operation to invoke on the entity.
153
+ input: Optional[TInput]
154
+ The optional JSON-serializable input to pass to the entity function.
155
+
156
+ Returns
157
+ -------
158
+ Task
159
+ A Durable Task that completes when the called entity function completes or fails.
160
+ """
161
+ pass
162
+
163
+ @abstractmethod
164
+ def signal_entity(
165
+ self,
166
+ entity_id: EntityInstanceId,
167
+ operation_name: str,
168
+ input: Optional[TInput] = None
169
+ ) -> None:
170
+ """Signal an entity function for execution.
171
+
172
+ Parameters
173
+ ----------
174
+ entity_id: EntityInstanceId
175
+ The ID of the entity instance to signal.
176
+ operation_name: str
177
+ The name of the operation to invoke on the entity.
178
+ input: Optional[TInput]
179
+ The optional JSON-serializable input to pass to the entity function.
180
+ """
181
+ pass
182
+
183
+ @abstractmethod
184
+ def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]:
185
+ """Creates a Task object that locks the specified entity instances.
186
+
187
+ The locks will be acquired the next time the orchestrator yields.
188
+ Best practice is to immediately yield this Task and enter the returned EntityLock.
189
+ The lock is released when the EntityLock is exited.
190
+
191
+ Parameters
192
+ ----------
193
+ entities: list[EntityInstanceId]
194
+ The list of entity instance IDs to lock.
195
+
196
+ Returns
197
+ -------
198
+ Task[EntityLock]
199
+ A context manager object that releases the locks when exited.
200
+ """
201
+ pass
202
+
203
+ @abstractmethod
204
+ def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *,
205
+ input: Optional[TInput] = None,
206
+ instance_id: Optional[str] = None,
207
+ retry_policy: Optional[RetryPolicy] = None,
208
+ version: Optional[str] = None) -> Task[TOutput]:
113
209
  """Schedule sub-orchestrator function for execution.
114
210
 
115
211
  Parameters
116
212
  ----------
117
213
  orchestrator: Orchestrator[TInput, TOutput]
118
214
  A reference to the orchestrator function to call.
119
- input: Union[TInput, None]
215
+ input: Optional[TInput]
120
216
  The optional JSON-serializable input to pass to the orchestrator function.
121
- instance_id: Union[str, None]
217
+ instance_id: Optional[str]
122
218
  A unique ID to use for the sub-orchestration instance. If not specified, a
123
219
  random UUID will be used.
220
+ retry_policy: Optional[RetryPolicy]
221
+ The retry policy to use for this sub-orchestrator call.
124
222
 
125
223
  Returns
126
224
  -------
@@ -147,9 +245,26 @@ class OrchestrationContext(ABC):
147
245
  """
148
246
  pass
149
247
 
248
+ @abstractmethod
249
+ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
250
+ """Continue the orchestration execution as a new instance.
251
+
252
+ Parameters
253
+ ----------
254
+ new_input : Any
255
+ The new input to use for the new orchestration instance.
256
+ save_events : bool
257
+ A flag indicating whether to add any unprocessed external events to the new orchestration history.
258
+ """
259
+ pass
260
+
261
+ @abstractmethod
262
+ def _exit_critical_section(self) -> None:
263
+ pass
264
+
150
265
 
151
266
  class FailureDetails:
152
- def __init__(self, message: str, error_type: str, stack_trace: Union[str, None]):
267
+ def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
153
268
  self._message = message
154
269
  self._error_type = error_type
155
270
  self._stack_trace = stack_trace
@@ -163,7 +278,7 @@ class FailureDetails:
163
278
  return self._error_type
164
279
 
165
280
  @property
166
- def stack_trace(self) -> Union[str, None]:
281
+ def stack_trace(self) -> Optional[str]:
167
282
  return self._stack_trace
168
283
 
169
284
 
@@ -193,8 +308,8 @@ class OrchestrationStateError(Exception):
193
308
  class Task(ABC, Generic[T]):
194
309
  """Abstract base class for asynchronous tasks in a durable orchestration."""
195
310
  _result: T
196
- _exception: Union[TaskFailedError, None]
197
- _parent: Union[CompositeTask[T], None]
311
+ _exception: Optional[TaskFailedError]
312
+ _parent: Optional[CompositeTask[T]]
198
313
 
199
314
  def __init__(self) -> None:
200
315
  super().__init__()
@@ -229,9 +344,9 @@ class Task(ABC, Generic[T]):
229
344
 
230
345
  class CompositeTask(Task[T]):
231
346
  """A task that is composed of other tasks."""
232
- _tasks: List[Task]
347
+ _tasks: list[Task]
233
348
 
234
- def __init__(self, tasks: List[Task]):
349
+ def __init__(self, tasks: list[Task]):
235
350
  super().__init__()
236
351
  self._tasks = tasks
237
352
  self._completed_tasks = 0
@@ -241,7 +356,7 @@ class CompositeTask(Task[T]):
241
356
  if task.is_complete:
242
357
  self.on_child_completed(task)
243
358
 
244
- def get_tasks(self) -> List[Task]:
359
+ def get_tasks(self) -> list[Task]:
245
360
  return self._tasks
246
361
 
247
362
  @abstractmethod
@@ -249,10 +364,40 @@ class CompositeTask(Task[T]):
249
364
  pass
250
365
 
251
366
 
367
+ class WhenAllTask(CompositeTask[list[T]]):
368
+ """A task that completes when all of its child tasks complete."""
369
+
370
+ def __init__(self, tasks: list[Task[T]]):
371
+ super().__init__(tasks)
372
+ self._completed_tasks = 0
373
+ self._failed_tasks = 0
374
+
375
+ @property
376
+ def pending_tasks(self) -> int:
377
+ """Returns the number of tasks that have not yet completed."""
378
+ return len(self._tasks) - self._completed_tasks
379
+
380
+ def on_child_completed(self, task: Task[T]):
381
+ if self.is_complete:
382
+ raise ValueError('The task has already completed.')
383
+ self._completed_tasks += 1
384
+ if task.is_failed and self._exception is None:
385
+ self._exception = task.get_exception()
386
+ self._is_complete = True
387
+ if self._completed_tasks == len(self._tasks):
388
+ # The order of the result MUST match the order of the tasks provided to the constructor.
389
+ self._result = [task.get_result() for task in self._tasks]
390
+ self._is_complete = True
391
+
392
+ def get_completed_tasks(self) -> int:
393
+ return self._completed_tasks
394
+
395
+
252
396
  class CompletableTask(Task[T]):
253
397
 
254
398
  def __init__(self):
255
399
  super().__init__()
400
+ self._retryable_parent = None
256
401
 
257
402
  def complete(self, result: T):
258
403
  if self._is_complete:
@@ -271,39 +416,57 @@ class CompletableTask(Task[T]):
271
416
  self._parent.on_child_completed(self)
272
417
 
273
418
 
274
- class WhenAllTask(CompositeTask[List[T]]):
275
- """A task that completes when all of its child tasks complete."""
419
+ class RetryableTask(CompletableTask[T]):
420
+ """A task that can be retried according to a retry policy."""
276
421
 
277
- def __init__(self, tasks: List[Task[T]]):
278
- super().__init__(tasks)
279
- self._completed_tasks = 0
280
- self._failed_tasks = 0
422
+ def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction,
423
+ start_time: datetime, is_sub_orch: bool) -> None:
424
+ super().__init__()
425
+ self._action = action
426
+ self._retry_policy = retry_policy
427
+ self._attempt_count = 1
428
+ self._start_time = start_time
429
+ self._is_sub_orch = is_sub_orch
281
430
 
282
- @property
283
- def pending_tasks(self) -> int:
284
- """Returns the number of tasks that have not yet completed."""
285
- return len(self._tasks) - self._completed_tasks
431
+ def increment_attempt_count(self) -> None:
432
+ self._attempt_count += 1
286
433
 
287
- def on_child_completed(self, task: Task[T]):
288
- if self.is_complete:
289
- raise ValueError('The task has already completed.')
290
- self._completed_tasks += 1
291
- if task.is_failed and self._exception is None:
292
- self._exception = task.get_exception()
293
- self._is_complete = True
294
- if self._completed_tasks == len(self._tasks):
295
- # The order of the result MUST match the order of the tasks provided to the constructor.
296
- self._result = [task.get_result() for task in self._tasks]
297
- self._is_complete = True
434
+ def compute_next_delay(self) -> Optional[timedelta]:
435
+ if self._attempt_count >= self._retry_policy.max_number_of_attempts:
436
+ return None
298
437
 
299
- def get_completed_tasks(self) -> int:
300
- return self._completed_tasks
438
+ retry_expiration: datetime = datetime.max
439
+ if self._retry_policy.retry_timeout is not None and self._retry_policy.retry_timeout != datetime.max:
440
+ retry_expiration = self._start_time + self._retry_policy.retry_timeout
441
+
442
+ if self._retry_policy.backoff_coefficient is None:
443
+ backoff_coefficient = 1.0
444
+ else:
445
+ backoff_coefficient = self._retry_policy.backoff_coefficient
446
+
447
+ if datetime.utcnow() < retry_expiration:
448
+ next_delay_f = math.pow(backoff_coefficient, self._attempt_count - 1) * self._retry_policy.first_retry_interval.total_seconds()
449
+
450
+ if self._retry_policy.max_retry_interval is not None:
451
+ next_delay_f = min(next_delay_f, self._retry_policy.max_retry_interval.total_seconds())
452
+ return timedelta(seconds=next_delay_f)
453
+
454
+ return None
455
+
456
+
457
+ class TimerTask(CompletableTask[T]):
458
+
459
+ def __init__(self) -> None:
460
+ super().__init__()
461
+
462
+ def set_retryable_parent(self, retryable_task: RetryableTask):
463
+ self._retryable_parent = retryable_task
301
464
 
302
465
 
303
466
  class WhenAnyTask(CompositeTask[Task]):
304
467
  """A task that completes when any of its child tasks complete."""
305
468
 
306
- def __init__(self, tasks: List[Task]):
469
+ def __init__(self, tasks: list[Task]):
307
470
  super().__init__(tasks)
308
471
 
309
472
  def on_child_completed(self, task: Task):
@@ -313,12 +476,12 @@ class WhenAnyTask(CompositeTask[Task]):
313
476
  self._result = task
314
477
 
315
478
 
316
- def when_all(tasks: List[Task[T]]) -> WhenAllTask[T]:
479
+ def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]:
317
480
  """Returns a task that completes when all of the provided tasks complete or when one of the tasks fails."""
318
481
  return WhenAllTask(tasks)
319
482
 
320
483
 
321
- def when_any(tasks: List[Task]) -> WhenAnyTask:
484
+ def when_any(tasks: list[Task]) -> WhenAnyTask:
322
485
  """Returns a task that completes when any of the provided tasks complete or fail."""
323
486
  return WhenAnyTask(tasks)
324
487
 
@@ -362,6 +525,76 @@ Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, An
362
525
  # Activities are simple functions that can be scheduled by orchestrators
363
526
  Activity = Callable[[ActivityContext, TInput], TOutput]
364
527
 
528
+ Entity = Union[Callable[[EntityContext, TInput], TOutput], type[DurableEntity]]
529
+
530
+
531
+ class RetryPolicy:
532
+ """Represents the retry policy for an orchestration or activity function."""
533
+
534
+ def __init__(self, *,
535
+ first_retry_interval: timedelta,
536
+ max_number_of_attempts: int,
537
+ backoff_coefficient: Optional[float] = 1.0,
538
+ max_retry_interval: Optional[timedelta] = None,
539
+ retry_timeout: Optional[timedelta] = None):
540
+ """Creates a new RetryPolicy instance.
541
+
542
+ Parameters
543
+ ----------
544
+ first_retry_interval : timedelta
545
+ The retry interval to use for the first retry attempt.
546
+ max_number_of_attempts : int
547
+ The maximum number of retry attempts.
548
+ backoff_coefficient : Optional[float]
549
+ The backoff coefficient to use for calculating the next retry interval.
550
+ max_retry_interval : Optional[timedelta]
551
+ The maximum retry interval to use for any retry attempt.
552
+ retry_timeout : Optional[timedelta]
553
+ The maximum amount of time to spend retrying the operation.
554
+ """
555
+ # validate inputs
556
+ if first_retry_interval < timedelta(seconds=0):
557
+ raise ValueError('first_retry_interval must be >= 0')
558
+ if max_number_of_attempts < 1:
559
+ raise ValueError('max_number_of_attempts must be >= 1')
560
+ if backoff_coefficient is not None and backoff_coefficient < 1:
561
+ raise ValueError('backoff_coefficient must be >= 1')
562
+ if max_retry_interval is not None and max_retry_interval < timedelta(seconds=0):
563
+ raise ValueError('max_retry_interval must be >= 0')
564
+ if retry_timeout is not None and retry_timeout < timedelta(seconds=0):
565
+ raise ValueError('retry_timeout must be >= 0')
566
+
567
+ self._first_retry_interval = first_retry_interval
568
+ self._max_number_of_attempts = max_number_of_attempts
569
+ self._backoff_coefficient = backoff_coefficient
570
+ self._max_retry_interval = max_retry_interval
571
+ self._retry_timeout = retry_timeout
572
+
573
+ @property
574
+ def first_retry_interval(self) -> timedelta:
575
+ """The retry interval to use for the first retry attempt."""
576
+ return self._first_retry_interval
577
+
578
+ @property
579
+ def max_number_of_attempts(self) -> int:
580
+ """The maximum number of retry attempts."""
581
+ return self._max_number_of_attempts
582
+
583
+ @property
584
+ def backoff_coefficient(self) -> Optional[float]:
585
+ """The backoff coefficient to use for calculating the next retry interval."""
586
+ return self._backoff_coefficient
587
+
588
+ @property
589
+ def max_retry_interval(self) -> Optional[timedelta]:
590
+ """The maximum retry interval to use for any retry attempt."""
591
+ return self._max_retry_interval
592
+
593
+ @property
594
+ def retry_timeout(self) -> Optional[timedelta]:
595
+ """The maximum amount of time to spend retrying the operation."""
596
+ return self._retry_timeout
597
+
365
598
 
366
599
  def get_name(fn: Callable) -> str:
367
600
  """Returns the name of the provided function"""