durabletask 0.1.0a5__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their public registries.

Potentially problematic release.


This version of durabletask might be problematic. See the registry's advisory page for this release for more details (the original page linked to it here).

@@ -5,39 +5,67 @@ import dataclasses
5
5
  import json
6
6
  import logging
7
7
  from types import SimpleNamespace
8
- from typing import Any, Dict, List, Tuple, Union
8
+ from typing import Any, Optional, Sequence, Union
9
9
 
10
10
  import grpc
11
11
 
12
- from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
12
+ ClientInterceptor = Union[
13
+ grpc.UnaryUnaryClientInterceptor,
14
+ grpc.UnaryStreamClientInterceptor,
15
+ grpc.StreamUnaryClientInterceptor,
16
+ grpc.StreamStreamClientInterceptor
17
+ ]
13
18
 
14
19
  # Field name used to indicate that an object was automatically serialized
15
20
  # and should be deserialized as a SimpleNamespace
16
21
  AUTO_SERIALIZED = "__durabletask_autoobject__"
17
22
 
23
+ SECURE_PROTOCOLS = ["https://", "grpcs://"]
24
+ INSECURE_PROTOCOLS = ["http://", "grpc://"]
25
+
18
26
 
19
27
  def get_default_host_address() -> str:
20
28
  return "localhost:4001"
21
29
 
22
30
 
23
- def get_grpc_channel(host_address: Union[str, None], metadata: Union[List[Tuple[str, str]], None], secure_channel: bool = False) -> grpc.Channel:
31
+ def get_grpc_channel(
32
+ host_address: Optional[str],
33
+ secure_channel: bool = False,
34
+ interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel:
35
+
24
36
  if host_address is None:
25
37
  host_address = get_default_host_address()
26
38
 
39
+ for protocol in SECURE_PROTOCOLS:
40
+ if host_address.lower().startswith(protocol):
41
+ secure_channel = True
42
+ # remove the protocol from the host name
43
+ host_address = host_address[len(protocol):]
44
+ break
45
+
46
+ for protocol in INSECURE_PROTOCOLS:
47
+ if host_address.lower().startswith(protocol):
48
+ secure_channel = False
49
+ # remove the protocol from the host name
50
+ host_address = host_address[len(protocol):]
51
+ break
52
+
53
+ # Create the base channel
27
54
  if secure_channel:
28
55
  channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
29
56
  else:
30
57
  channel = grpc.insecure_channel(host_address)
31
58
 
32
- if metadata is not None and len(metadata) > 0:
33
- interceptors = [DefaultClientInterceptorImpl(metadata)]
59
+ # Apply interceptors ONLY if they exist
60
+ if interceptors:
34
61
  channel = grpc.intercept_channel(channel, *interceptors)
35
62
  return channel
36
63
 
64
+
37
65
  def get_logger(
38
66
  name_suffix: str,
39
- log_handler: Union[logging.Handler, None] = None,
40
- log_formatter: Union[logging.Formatter, None] = None) -> logging.Logger:
67
+ log_handler: Optional[logging.Handler] = None,
68
+ log_formatter: Optional[logging.Formatter] = None) -> logging.Logger:
41
69
  logger = logging.Logger(f"durabletask-{name_suffix}")
42
70
 
43
71
  # Add a default log handler if none is provided
@@ -78,7 +106,7 @@ class InternalJSONEncoder(json.JSONEncoder):
78
106
  if dataclasses.is_dataclass(obj):
79
107
  # Dataclasses are not serializable by default, so we convert them to a dict and mark them for
80
108
  # automatic deserialization by the receiver
81
- d = dataclasses.asdict(obj)
109
+ d = dataclasses.asdict(obj) # type: ignore
82
110
  d[AUTO_SERIALIZED] = True
83
111
  return d
84
112
  elif isinstance(obj, SimpleNamespace):
@@ -94,7 +122,7 @@ class InternalJSONDecoder(json.JSONDecoder):
94
122
  def __init__(self, *args, **kwargs):
95
123
  super().__init__(object_hook=self.dict_to_object, *args, **kwargs)
96
124
 
97
- def dict_to_object(self, d: Dict[str, Any]):
125
+ def dict_to_object(self, d: dict[str, Any]):
98
126
  # If the object was serialized by the InternalJSONEncoder, deserialize it as a SimpleNamespace
99
127
  if d.pop(AUTO_SERIALIZED, False):
100
128
  return SimpleNamespace(**d)
durabletask/task.py CHANGED
@@ -4,9 +4,10 @@
4
4
  # See https://peps.python.org/pep-0563/
5
5
  from __future__ import annotations
6
6
 
7
+ import math
7
8
  from abc import ABC, abstractmethod
8
9
  from datetime import datetime, timedelta
9
- from typing import Any, Callable, Generator, Generic, List, TypeVar, Union
10
+ from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
10
11
 
11
12
  import durabletask.internal.helpers as pbh
12
13
  import durabletask.internal.orchestrator_service_pb2 as pb
@@ -69,6 +70,17 @@ class OrchestrationContext(ABC):
69
70
  """
70
71
  pass
71
72
 
73
+ @abstractmethod
74
+ def set_custom_status(self, custom_status: Any) -> None:
75
+ """Set the orchestration instance's custom status.
76
+
77
+ Parameters
78
+ ----------
79
+ custom_status: Any
80
+ A JSON-serializable custom status value to set.
81
+ """
82
+ pass
83
+
72
84
  @abstractmethod
73
85
  def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
74
86
  """Create a Timer Task to fire after at the specified deadline.
@@ -87,17 +99,18 @@ class OrchestrationContext(ABC):
87
99
 
88
100
  @abstractmethod
89
101
  def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
90
- input: Union[TInput, None] = None) -> Task[TOutput]:
102
+ input: Optional[TInput] = None,
103
+ retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
91
104
  """Schedule an activity for execution.
92
105
 
93
106
  Parameters
94
107
  ----------
95
108
  activity: Union[Activity[TInput, TOutput], str]
96
109
  A reference to the activity function to call.
97
- input: Union[TInput, None]
110
+ input: Optional[TInput]
98
111
  The JSON-serializable input (or None) to pass to the activity.
99
- return_type: task.Task[TOutput]
100
- The JSON-serializable output type to expect from the activity result.
112
+ retry_policy: Optional[RetryPolicy]
113
+ The retry policy to use for this activity call.
101
114
 
102
115
  Returns
103
116
  -------
@@ -108,19 +121,22 @@ class OrchestrationContext(ABC):
108
121
 
109
122
  @abstractmethod
110
123
  def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
111
- input: Union[TInput, None] = None,
112
- instance_id: Union[str, None] = None) -> Task[TOutput]:
124
+ input: Optional[TInput] = None,
125
+ instance_id: Optional[str] = None,
126
+ retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
113
127
  """Schedule sub-orchestrator function for execution.
114
128
 
115
129
  Parameters
116
130
  ----------
117
131
  orchestrator: Orchestrator[TInput, TOutput]
118
132
  A reference to the orchestrator function to call.
119
- input: Union[TInput, None]
133
+ input: Optional[TInput]
120
134
  The optional JSON-serializable input to pass to the orchestrator function.
121
- instance_id: Union[str, None]
135
+ instance_id: Optional[str]
122
136
  A unique ID to use for the sub-orchestration instance. If not specified, a
123
137
  random UUID will be used.
138
+ retry_policy: Optional[RetryPolicy]
139
+ The retry policy to use for this sub-orchestrator call.
124
140
 
125
141
  Returns
126
142
  -------
@@ -162,7 +178,7 @@ class OrchestrationContext(ABC):
162
178
 
163
179
 
164
180
  class FailureDetails:
165
- def __init__(self, message: str, error_type: str, stack_trace: Union[str, None]):
181
+ def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
166
182
  self._message = message
167
183
  self._error_type = error_type
168
184
  self._stack_trace = stack_trace
@@ -176,7 +192,7 @@ class FailureDetails:
176
192
  return self._error_type
177
193
 
178
194
  @property
179
- def stack_trace(self) -> Union[str, None]:
195
+ def stack_trace(self) -> Optional[str]:
180
196
  return self._stack_trace
181
197
 
182
198
 
@@ -206,8 +222,8 @@ class OrchestrationStateError(Exception):
206
222
  class Task(ABC, Generic[T]):
207
223
  """Abstract base class for asynchronous tasks in a durable orchestration."""
208
224
  _result: T
209
- _exception: Union[TaskFailedError, None]
210
- _parent: Union[CompositeTask[T], None]
225
+ _exception: Optional[TaskFailedError]
226
+ _parent: Optional[CompositeTask[T]]
211
227
 
212
228
  def __init__(self) -> None:
213
229
  super().__init__()
@@ -242,9 +258,9 @@ class Task(ABC, Generic[T]):
242
258
 
243
259
  class CompositeTask(Task[T]):
244
260
  """A task that is composed of other tasks."""
245
- _tasks: List[Task]
261
+ _tasks: list[Task]
246
262
 
247
- def __init__(self, tasks: List[Task]):
263
+ def __init__(self, tasks: list[Task]):
248
264
  super().__init__()
249
265
  self._tasks = tasks
250
266
  self._completed_tasks = 0
@@ -254,7 +270,7 @@ class CompositeTask(Task[T]):
254
270
  if task.is_complete:
255
271
  self.on_child_completed(task)
256
272
 
257
- def get_tasks(self) -> List[Task]:
273
+ def get_tasks(self) -> list[Task]:
258
274
  return self._tasks
259
275
 
260
276
  @abstractmethod
@@ -262,10 +278,40 @@ class CompositeTask(Task[T]):
262
278
  pass
263
279
 
264
280
 
281
+ class WhenAllTask(CompositeTask[list[T]]):
282
+ """A task that completes when all of its child tasks complete."""
283
+
284
+ def __init__(self, tasks: list[Task[T]]):
285
+ super().__init__(tasks)
286
+ self._completed_tasks = 0
287
+ self._failed_tasks = 0
288
+
289
+ @property
290
+ def pending_tasks(self) -> int:
291
+ """Returns the number of tasks that have not yet completed."""
292
+ return len(self._tasks) - self._completed_tasks
293
+
294
+ def on_child_completed(self, task: Task[T]):
295
+ if self.is_complete:
296
+ raise ValueError('The task has already completed.')
297
+ self._completed_tasks += 1
298
+ if task.is_failed and self._exception is None:
299
+ self._exception = task.get_exception()
300
+ self._is_complete = True
301
+ if self._completed_tasks == len(self._tasks):
302
+ # The order of the result MUST match the order of the tasks provided to the constructor.
303
+ self._result = [task.get_result() for task in self._tasks]
304
+ self._is_complete = True
305
+
306
+ def get_completed_tasks(self) -> int:
307
+ return self._completed_tasks
308
+
309
+
265
310
  class CompletableTask(Task[T]):
266
311
 
267
312
  def __init__(self):
268
313
  super().__init__()
314
+ self._retryable_parent = None
269
315
 
270
316
  def complete(self, result: T):
271
317
  if self._is_complete:
@@ -284,39 +330,57 @@ class CompletableTask(Task[T]):
284
330
  self._parent.on_child_completed(self)
285
331
 
286
332
 
287
- class WhenAllTask(CompositeTask[List[T]]):
288
- """A task that completes when all of its child tasks complete."""
333
+ class RetryableTask(CompletableTask[T]):
334
+ """A task that can be retried according to a retry policy."""
289
335
 
290
- def __init__(self, tasks: List[Task[T]]):
291
- super().__init__(tasks)
292
- self._completed_tasks = 0
293
- self._failed_tasks = 0
336
+ def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction,
337
+ start_time: datetime, is_sub_orch: bool) -> None:
338
+ super().__init__()
339
+ self._action = action
340
+ self._retry_policy = retry_policy
341
+ self._attempt_count = 1
342
+ self._start_time = start_time
343
+ self._is_sub_orch = is_sub_orch
294
344
 
295
- @property
296
- def pending_tasks(self) -> int:
297
- """Returns the number of tasks that have not yet completed."""
298
- return len(self._tasks) - self._completed_tasks
345
+ def increment_attempt_count(self) -> None:
346
+ self._attempt_count += 1
299
347
 
300
- def on_child_completed(self, task: Task[T]):
301
- if self.is_complete:
302
- raise ValueError('The task has already completed.')
303
- self._completed_tasks += 1
304
- if task.is_failed and self._exception is None:
305
- self._exception = task.get_exception()
306
- self._is_complete = True
307
- if self._completed_tasks == len(self._tasks):
308
- # The order of the result MUST match the order of the tasks provided to the constructor.
309
- self._result = [task.get_result() for task in self._tasks]
310
- self._is_complete = True
348
+ def compute_next_delay(self) -> Optional[timedelta]:
349
+ if self._attempt_count >= self._retry_policy.max_number_of_attempts:
350
+ return None
311
351
 
312
- def get_completed_tasks(self) -> int:
313
- return self._completed_tasks
352
+ retry_expiration: datetime = datetime.max
353
+ if self._retry_policy.retry_timeout is not None and self._retry_policy.retry_timeout != datetime.max:
354
+ retry_expiration = self._start_time + self._retry_policy.retry_timeout
355
+
356
+ if self._retry_policy.backoff_coefficient is None:
357
+ backoff_coefficient = 1.0
358
+ else:
359
+ backoff_coefficient = self._retry_policy.backoff_coefficient
360
+
361
+ if datetime.utcnow() < retry_expiration:
362
+ next_delay_f = math.pow(backoff_coefficient, self._attempt_count - 1) * self._retry_policy.first_retry_interval.total_seconds()
363
+
364
+ if self._retry_policy.max_retry_interval is not None:
365
+ next_delay_f = min(next_delay_f, self._retry_policy.max_retry_interval.total_seconds())
366
+ return timedelta(seconds=next_delay_f)
367
+
368
+ return None
369
+
370
+
371
+ class TimerTask(CompletableTask[T]):
372
+
373
+ def __init__(self) -> None:
374
+ super().__init__()
375
+
376
+ def set_retryable_parent(self, retryable_task: RetryableTask):
377
+ self._retryable_parent = retryable_task
314
378
 
315
379
 
316
380
  class WhenAnyTask(CompositeTask[Task]):
317
381
  """A task that completes when any of its child tasks complete."""
318
382
 
319
- def __init__(self, tasks: List[Task]):
383
+ def __init__(self, tasks: list[Task]):
320
384
  super().__init__(tasks)
321
385
 
322
386
  def on_child_completed(self, task: Task):
@@ -326,12 +390,12 @@ class WhenAnyTask(CompositeTask[Task]):
326
390
  self._result = task
327
391
 
328
392
 
329
- def when_all(tasks: List[Task[T]]) -> WhenAllTask[T]:
393
+ def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]:
330
394
  """Returns a task that completes when all of the provided tasks complete or when one of the tasks fail."""
331
395
  return WhenAllTask(tasks)
332
396
 
333
397
 
334
- def when_any(tasks: List[Task]) -> WhenAnyTask:
398
+ def when_any(tasks: list[Task]) -> WhenAnyTask:
335
399
  """Returns a task that completes when any of the provided tasks complete or fail."""
336
400
  return WhenAnyTask(tasks)
337
401
 
@@ -376,6 +440,74 @@ Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, An
376
440
  Activity = Callable[[ActivityContext, TInput], TOutput]
377
441
 
378
442
 
443
+ class RetryPolicy:
444
+ """Represents the retry policy for an orchestration or activity function."""
445
+
446
+ def __init__(self, *,
447
+ first_retry_interval: timedelta,
448
+ max_number_of_attempts: int,
449
+ backoff_coefficient: Optional[float] = 1.0,
450
+ max_retry_interval: Optional[timedelta] = None,
451
+ retry_timeout: Optional[timedelta] = None):
452
+ """Creates a new RetryPolicy instance.
453
+
454
+ Parameters
455
+ ----------
456
+ first_retry_interval : timedelta
457
+ The retry interval to use for the first retry attempt.
458
+ max_number_of_attempts : int
459
+ The maximum number of retry attempts.
460
+ backoff_coefficient : Optional[float]
461
+ The backoff coefficient to use for calculating the next retry interval.
462
+ max_retry_interval : Optional[timedelta]
463
+ The maximum retry interval to use for any retry attempt.
464
+ retry_timeout : Optional[timedelta]
465
+ The maximum amount of time to spend retrying the operation.
466
+ """
467
+ # validate inputs
468
+ if first_retry_interval < timedelta(seconds=0):
469
+ raise ValueError('first_retry_interval must be >= 0')
470
+ if max_number_of_attempts < 1:
471
+ raise ValueError('max_number_of_attempts must be >= 1')
472
+ if backoff_coefficient is not None and backoff_coefficient < 1:
473
+ raise ValueError('backoff_coefficient must be >= 1')
474
+ if max_retry_interval is not None and max_retry_interval < timedelta(seconds=0):
475
+ raise ValueError('max_retry_interval must be >= 0')
476
+ if retry_timeout is not None and retry_timeout < timedelta(seconds=0):
477
+ raise ValueError('retry_timeout must be >= 0')
478
+
479
+ self._first_retry_interval = first_retry_interval
480
+ self._max_number_of_attempts = max_number_of_attempts
481
+ self._backoff_coefficient = backoff_coefficient
482
+ self._max_retry_interval = max_retry_interval
483
+ self._retry_timeout = retry_timeout
484
+
485
+ @property
486
+ def first_retry_interval(self) -> timedelta:
487
+ """The retry interval to use for the first retry attempt."""
488
+ return self._first_retry_interval
489
+
490
+ @property
491
+ def max_number_of_attempts(self) -> int:
492
+ """The maximum number of retry attempts."""
493
+ return self._max_number_of_attempts
494
+
495
+ @property
496
+ def backoff_coefficient(self) -> Optional[float]:
497
+ """The backoff coefficient to use for calculating the next retry interval."""
498
+ return self._backoff_coefficient
499
+
500
+ @property
501
+ def max_retry_interval(self) -> Optional[timedelta]:
502
+ """The maximum retry interval to use for any retry attempt."""
503
+ return self._max_retry_interval
504
+
505
+ @property
506
+ def retry_timeout(self) -> Optional[timedelta]:
507
+ """The maximum amount of time to spend retrying the operation."""
508
+ return self._retry_timeout
509
+
510
+
379
511
  def get_name(fn: Callable) -> str:
380
512
  """Returns the name of the provided function"""
381
513
  name = fn.__name__