cadence-python-client 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. cadence/_internal/activity/_activity_executor.py +20 -4
  2. cadence/_internal/activity/_context.py +41 -6
  3. cadence/_internal/activity/_definition.py +33 -5
  4. cadence/_internal/activity/_heartbeat.py +42 -0
  5. cadence/_internal/fn_signature.py +31 -1
  6. cadence/_internal/rpc/retry.py +13 -5
  7. cadence/_internal/rpc/yarpc.py +6 -3
  8. cadence/_internal/workflow/active_cluster_selection_policy.py +32 -0
  9. cadence/_internal/workflow/context.py +41 -10
  10. cadence/_internal/workflow/deterministic_event_loop.py +25 -2
  11. cadence/_internal/workflow/retry_policy.py +62 -0
  12. cadence/_internal/workflow/statemachine/event_dispatcher.py +3 -2
  13. cadence/_internal/workflow/waiter.py +37 -0
  14. cadence/_internal/workflow/workflow_engine.py +52 -25
  15. cadence/_internal/workflow/workflow_instance.py +53 -1
  16. cadence/activity.py +22 -0
  17. cadence/api/v1/__init__.py +12 -0
  18. cadence/api/v1/common_pb2.py +30 -12
  19. cadence/api/v1/common_pb2.pyi +29 -2
  20. cadence/api/v1/domain_pb2.py +22 -10
  21. cadence/api/v1/domain_pb2.pyi +24 -2
  22. cadence/api/v1/error_pb2.py +3 -1
  23. cadence/api/v1/error_pb2.pyi +4 -0
  24. cadence/api/v1/history_pb2.py +44 -44
  25. cadence/api/v1/history_pb2.pyi +6 -2
  26. cadence/api/v1/schedule_pb2.py +61 -0
  27. cadence/api/v1/schedule_pb2.pyi +154 -0
  28. cadence/api/v1/schedule_pb2_grpc.py +24 -0
  29. cadence/api/v1/service_domain_pb2.py +51 -33
  30. cadence/api/v1/service_domain_pb2.pyi +68 -2
  31. cadence/api/v1/service_domain_pb2_grpc.py +88 -0
  32. cadence/api/v1/service_schedule_pb2.py +72 -0
  33. cadence/api/v1/service_schedule_pb2.pyi +163 -0
  34. cadence/api/v1/service_schedule_pb2_grpc.py +409 -0
  35. cadence/api/v1/service_workflow_pb2.py +9 -9
  36. cadence/api/v1/service_workflow_pb2.pyi +4 -2
  37. cadence/api/v1/service_workflow_pb2_grpc.py +2 -2
  38. cadence/api/v1/workflow_pb2.py +50 -48
  39. cadence/api/v1/workflow_pb2.pyi +29 -2
  40. cadence/client.py +64 -14
  41. cadence/contrib/__init__.py +0 -0
  42. cadence/contrib/openai/README.md +124 -0
  43. cadence/contrib/openai/__init__.py +15 -0
  44. cadence/contrib/openai/cadence_agent_runner.py +133 -0
  45. cadence/contrib/openai/cadence_handoff.py +42 -0
  46. cadence/contrib/openai/cadence_model.py +71 -0
  47. cadence/contrib/openai/cadence_registry.py +6 -0
  48. cadence/contrib/openai/cadence_tool.py +54 -0
  49. cadence/contrib/openai/images/cadence-web-agent-run.jpg +0 -0
  50. cadence/contrib/openai/openai_activities.py +51 -0
  51. cadence/contrib/openai/pydantic_data_converter.py +172 -0
  52. cadence/data_converter.py +21 -6
  53. cadence/error.py +46 -8
  54. cadence/signal.py +22 -94
  55. cadence/worker/_worker.py +25 -5
  56. cadence/workflow.py +102 -5
  57. {cadence_python_client-0.2.0.dist-info → cadence_python_client-0.2.2.dist-info}/METADATA +8 -1
  58. {cadence_python_client-0.2.0.dist-info → cadence_python_client-0.2.2.dist-info}/RECORD +62 -41
  59. {cadence_python_client-0.2.0.dist-info → cadence_python_client-0.2.2.dist-info}/WHEEL +1 -1
  60. {cadence_python_client-0.2.0.dist-info → cadence_python_client-0.2.2.dist-info}/licenses/LICENSE +0 -0
  61. {cadence_python_client-0.2.0.dist-info → cadence_python_client-0.2.2.dist-info}/licenses/NOTICE +0 -0
  62. {cadence_python_client-0.2.0.dist-info → cadence_python_client-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,13 @@
1
1
  from concurrent.futures import ThreadPoolExecutor
2
2
  from logging import getLogger
3
3
  from traceback import format_exception
4
- from typing import Any, Callable, cast
4
+ from typing import Any, Callable, Union, cast
5
5
  from google.protobuf.duration import to_timedelta
6
6
  from google.protobuf.timestamp import to_datetime
7
7
 
8
8
  from cadence._internal.activity._context import _Context, _SyncContext
9
9
  from cadence._internal.activity._definition import BaseDefinition, ExecutionStrategy
10
+ from cadence._internal.activity._heartbeat import _HeartbeatSender
10
11
  from cadence.activity import ActivityInfo, ActivityDefinition
11
12
  from cadence.api.v1.common_pb2 import Failure
12
13
  from cadence.api.v1.service_worker_pb2 import (
@@ -46,7 +47,9 @@ class ActivityExecutor:
46
47
  _logger.exception("Activity failed")
47
48
  await self._report_failure(task, e)
48
49
 
49
- def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
50
+ def _create_context(
51
+ self, task: PollForActivityTaskResponse
52
+ ) -> Union[_Context, _SyncContext]:
50
53
  activity_type = task.activity_type.name
51
54
  try:
52
55
  activity_def = cast(BaseDefinition, self._registry(activity_type))
@@ -54,11 +57,24 @@ class ActivityExecutor:
54
57
  raise KeyError(f"Activity type not found: {activity_type}") from None
55
58
 
56
59
  info = self._create_info(task)
60
+ heartbeat_sender = _HeartbeatSender(
61
+ self._client.worker_stub,
62
+ self._data_converter,
63
+ task.task_token,
64
+ self._identity,
65
+ task.heartbeat_details,
66
+ )
57
67
 
58
68
  if activity_def.strategy == ExecutionStrategy.ASYNC:
59
- return _Context(self._client, info, activity_def)
69
+ return _Context(self._client, info, activity_def, heartbeat_sender)
60
70
  else:
61
- return _SyncContext(self._client, info, activity_def, self._thread_pool)
71
+ return _SyncContext(
72
+ self._client,
73
+ info,
74
+ activity_def,
75
+ self._thread_pool,
76
+ heartbeat_sender,
77
+ )
62
78
 
63
79
  async def _report_failure(
64
80
  self, task: PollForActivityTaskResponse, error: Exception
@@ -1,9 +1,10 @@
1
1
  import asyncio
2
2
  from concurrent.futures.thread import ThreadPoolExecutor
3
- from typing import Any
3
+ from typing import Any, Type
4
4
 
5
5
  from cadence import Client
6
6
  from cadence._internal.activity._definition import BaseDefinition
7
+ from cadence._internal.activity._heartbeat import _HeartbeatSender
7
8
  from cadence.activity import ActivityInfo, ActivityContext
8
9
  from cadence.api.v1.common_pb2 import Payload
9
10
 
@@ -14,15 +15,27 @@ class _Context(ActivityContext):
14
15
  client: Client,
15
16
  info: ActivityInfo,
16
17
  activity_def: BaseDefinition[[Any], Any],
18
+ heartbeat_sender: _HeartbeatSender,
17
19
  ):
18
20
  self._client = client
19
21
  self._info = info
20
22
  self._activity_def = activity_def
23
+ self._heartbeat_sender = heartbeat_sender
24
+ self._heartbeat_tasks: set[asyncio.Future[None]] = set()
21
25
 
async def execute(self, payload: Payload) -> Any:
    """Run the async activity implementation with this context active.

    Whether the implementation returns or raises, every heartbeat send that is
    still pending is awaited before control leaves this method.
    """
    arguments = self._to_params(payload)
    try:
        with self._activate():
            outcome = await self._activity_def.impl_fn(*arguments)
        return outcome
    finally:
        await self._wait_pending_heartbeats()
33
+
async def _wait_pending_heartbeats(self) -> None:
    """Await every heartbeat task still tracked in ``_heartbeat_tasks``.

    Failures are returned by ``gather`` (``return_exceptions=True``) and
    discarded rather than raised.
    """
    if self._heartbeat_tasks:
        # Snapshot first: done-callbacks discard finished tasks from the set.
        snapshot = [*self._heartbeat_tasks]
        await asyncio.gather(*snapshot, return_exceptions=True)
26
39
 
27
40
  def _to_params(self, payload: Payload) -> list[Any]:
28
41
  return self._activity_def.signature.params_from_payload(
@@ -35,6 +48,16 @@ class _Context(ActivityContext):
35
48
  def info(self) -> ActivityInfo:
36
49
  return self._info
37
50
 
def heartbeat(self, *details: Any) -> None:
    """Record a heartbeat carrying ``details`` without blocking the caller.

    The send runs as a task on the running event loop. The task stays in
    ``_heartbeat_tasks`` until it finishes, so it can be awaited later.
    """
    send = self._heartbeat_sender.send_heartbeat(*details)
    pending = asyncio.create_task(send)
    self._heartbeat_tasks.add(pending)
    # Drop the task from the tracking set once it completes.
    pending.add_done_callback(self._heartbeat_tasks.discard)
57
+
def heartbeat_details(self, *types: Type) -> list[Any]:
    """Decode the heartbeat sender's stored details payload as ``types``."""
    sender = self._heartbeat_sender
    return sender.get_details(*types)
60
+
38
61
 
39
62
  class _SyncContext(_Context):
40
63
  def __init__(
@@ -43,14 +66,18 @@ class _SyncContext(_Context):
43
66
  info: ActivityInfo,
44
67
  activity_def: BaseDefinition[[Any], Any],
45
68
  executor: ThreadPoolExecutor,
69
+ heartbeat_sender: _HeartbeatSender,
46
70
  ):
47
- super().__init__(client, info, activity_def)
71
+ super().__init__(client, info, activity_def, heartbeat_sender)
48
72
  self._executor = executor
49
73
 
async def execute(self, payload: Payload) -> Any:
    """Run the synchronous activity implementation on the thread pool.

    The running loop is remembered on ``self._loop`` because ``heartbeat``
    uses it from the worker thread. Pending heartbeats are awaited before
    returning or propagating an error.
    """
    arguments = self._to_params(payload)
    running_loop = asyncio.get_running_loop()
    self._loop = running_loop
    try:
        return await running_loop.run_in_executor(
            self._executor, self._run, arguments
        )
    finally:
        await self._wait_pending_heartbeats()
54
81
 
55
82
  def _run(self, args: list[Any]) -> Any:
56
83
  with self._activate():
@@ -58,3 +85,11 @@ class _SyncContext(_Context):
58
85
 
def client(self) -> Client:
    """Unavailable for thread-pool activities; always raises ``RuntimeError``."""
    reason = "client is only supported in async activities"
    raise RuntimeError(reason)
88
+
def heartbeat(self, *details: Any) -> None:
    """Record a heartbeat from the worker thread running a sync activity.

    asyncio futures/tasks, their done-callbacks and ``_heartbeat_tasks`` are
    all owned by the event-loop thread and are not thread-safe. So this method
    creates and mutates nothing itself. It schedules the base-class
    ``heartbeat`` onto ``self._loop`` with ``call_soon_threadsafe``, and that
    call creates and tracks the send task on the loop thread.

    The callback is queued before ``_run`` returns. ``execute`` resumes only
    after the executor result is delivered back to the loop, so the send task
    is expected to be registered before ``_wait_pending_heartbeats`` runs.
    """
    self._loop.call_soon_threadsafe(super().heartbeat, *details)
@@ -1,4 +1,8 @@
1
1
  import abc
2
+ import asyncio.coroutines
3
+ import inspect
4
+ import sys
5
+
2
6
  from abc import ABC
3
7
  from enum import Enum
4
8
  from functools import update_wrapper, partial
@@ -10,7 +14,9 @@ from typing import (
10
14
  ParamSpec,
11
15
  TypeVar,
12
16
  Awaitable,
17
+ Type,
13
18
  cast,
19
+ overload,
14
20
  Concatenate,
15
21
  )
16
22
 
@@ -21,6 +27,8 @@ T = TypeVar("T")
21
27
  P = ParamSpec("P")
22
28
  R = TypeVar("R")
23
29
 
# Private asyncio marker assigned to ``_is_coroutine`` on Python < 3.12 so
# wrapper objects are reported as coroutine functions (3.12+ uses
# ``inspect.markcoroutinefunction`` instead). Looked up with a default so that
# importing this module cannot fail on a Python release without the attribute;
# the value is only read on < 3.12, where it exists.
_COROUTINE_MARKER = getattr(asyncio.coroutines, "_is_coroutine", None)
31
+
24
32
 
25
33
  class ExecutionStrategy(Enum):
26
34
  ASYNC = "async"
@@ -113,10 +121,13 @@ class SyncMethodImpl(BaseDefinition[P, R], Generic[T, P, R]):
113
121
  super().__init__(name, wrapped, ExecutionStrategy.THREAD_POOL, signature)
114
122
  update_wrapper(self, wrapped)
115
123
 
@overload
def __get__(self, instance: None, owner: Type[T]) -> "SyncMethodImpl[T, P, R]": ...
@overload
def __get__(self, instance: T, owner: Type[T]) -> SyncImpl[P, R]: ...
def __get__(self, instance: T | None, owner: Type[T]) -> "SyncImpl[P, R] | Self":
    """Descriptor hook for activity methods.

    Class-level access returns this definition unchanged. Instance access
    returns a ``SyncImpl`` whose wrapped callable already has ``instance``
    bound, so the ``self`` parameter disappears and it acts like a plain
    function.
    """
    if instance is not None:
        bound = partial(self._wrapped, instance)
        return SyncImpl[P, R](bound, self.name, self._signature)
    return self
@@ -141,6 +152,13 @@ class AsyncImpl(BaseDefinition[P, R]):
141
152
  ):
142
153
  super().__init__(name, wrapped, ExecutionStrategy.ASYNC, signature)
143
154
  update_wrapper(self, wrapped)
155
+ if sys.version_info >= (3, 12):
156
+ """
157
+ Mark the function as a coroutine function. This is only available in python 3.12 and above
158
+ """
159
+ inspect.markcoroutinefunction(self)
160
+ else:
161
+ self._is_coroutine = _COROUTINE_MARKER
144
162
 
145
163
  async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
146
164
  if WorkflowContext.is_set():
@@ -160,11 +178,21 @@ class AsyncMethodImpl(BaseDefinition[P, R], Generic[T, P, R]):
160
178
  ):
161
179
  super().__init__(name, wrapped, ExecutionStrategy.ASYNC, signature)
162
180
  update_wrapper(self, wrapped)
163
-
164
- def __get__(self, instance, owner):
181
+ if sys.version_info >= (3, 12):
182
+ """
183
+ Mark the function as a coroutine function. This is only available in python 3.12 and above
184
+ """
185
+ inspect.markcoroutinefunction(self)
186
+ else:
187
+ self._is_coroutine = _COROUTINE_MARKER
188
+
@overload
def __get__(self, instance: None, owner: Type[T]) -> "AsyncMethodImpl[T, P, R]": ...
@overload
def __get__(self, instance: T, owner: Type[T]) -> AsyncImpl[P, R]: ...
def __get__(self, instance: T | None, owner: Type[T]) -> "AsyncImpl[P, R] | Self":
    """Descriptor hook for async activity methods.

    Class-level access returns this definition unchanged. Instance access
    returns an ``AsyncImpl`` wrapping the method with ``instance`` already
    bound, dropping the ``self`` parameter.
    """
    if instance is not None:
        bound = partial(self._wrapped, instance)
        return AsyncImpl[P, R](bound, self.name, self._signature)
    return self
@@ -0,0 +1,42 @@
1
+ from logging import getLogger
2
+ from typing import Any, Type
3
+
4
+ from cadence.api.v1.common_pb2 import Payload
5
+ from cadence.api.v1.service_worker_pb2 import RecordActivityTaskHeartbeatRequest
6
+ from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
7
+ from cadence.data_converter import DataConverter
8
+
_logger = getLogger(__name__)


class _HeartbeatSender:
    """Send heartbeats for one activity task and keep the latest details.

    Holds what ``RecordActivityTaskHeartbeat`` needs (task token, worker
    identity) plus the most recent details payload. That payload starts as the
    value passed to the constructor.
    """

    def __init__(
        self,
        worker_stub: WorkerAPIStub,
        data_converter: DataConverter,
        task_token: bytes,
        identity: str,
        previous_details: Payload,
    ):
        self._worker_stub = worker_stub
        self._data_converter = data_converter
        self._task_token = task_token
        self._identity = identity
        # Replaced by each payload whose heartbeat RPC succeeds.
        self._previous_details = previous_details

    def get_details(self, *types: Type) -> list[Any]:
        """Decode the stored details payload into values of ``types``."""
        wanted = [*types]
        return self._data_converter.from_data(self._previous_details, wanted)

    async def send_heartbeat(self, *details: Any) -> None:
        """Encode ``details`` and record them as a heartbeat.

        Best-effort: any ``Exception`` is logged as a warning and swallowed.
        The stored details change only after the RPC completes successfully.
        """
        try:
            encoded = self._data_converter.to_data([*details])
            request = RecordActivityTaskHeartbeatRequest(
                task_token=self._task_token,
                details=encoded,
                identity=self._identity,
            )
            await self._worker_stub.RecordActivityTaskHeartbeat(request)
            self._previous_details = encoded
        except Exception:
            _logger.warning("Heartbeat failed", exc_info=True)
@@ -53,8 +53,20 @@ class FnSignature:
def params_from_payload(
    self, data_converter: DataConverter, payload: Payload
) -> list[Any]:
    """Decode call arguments from ``payload`` using the parameters' type hints.

    Trailing parameters absent from the payload take their default value. A
    missing parameter without a default raises ``ValueError``.
    """
    if not self.params:
        return []
    hints = [p.type_hint for p in self.params]
    values = _decode_provided_values(data_converter, payload, hints)
    provided = len(values)
    for position, param in enumerate(self.params):
        if position >= provided:
            if not param.has_default:
                raise ValueError(
                    f"required parameter '{param.name}' (position {position}) not provided in payload"
                )
            values.append(param.default_value)
    return values
58
70
 
59
71
  @staticmethod
60
72
  def of(fn: Callable) -> "FnSignature":
@@ -88,3 +100,21 @@ class FnSignature:
88
100
  return_type = hints.get("return", Any)
89
101
 
90
102
  return FnSignature(params, return_type)
103
+
104
+
def _decode_provided_values(
    data_converter: DataConverter,
    payload: Payload,
    type_hints: Sequence[Type | None],
) -> list[Any]:
    """Decode only the values actually present in ``payload``.

    Optional converter hooks are tried in order:

    - ``_decode_provided_values(payload, type_hints)`` decodes everything itself.
    - ``_payload_value_count(payload, n)`` limits decoding to that many
      leading hints.

    A converter with neither hook decodes against every hint (backward
    compatibility).
    """
    direct = getattr(data_converter, "_decode_provided_values", None)
    if callable(direct):
        return list(direct(payload, type_hints))

    count_values = getattr(data_converter, "_payload_value_count", None)
    hints = list(type_hints)
    if callable(count_values):
        present = int(count_values(payload, len(type_hints)))
        hints = hints[:present]
    return data_converter.from_data(payload, hints)
@@ -60,6 +60,8 @@ class RetryInterceptor(UnaryUnaryClientInterceptor):
60
60
  ) -> Any:
61
61
  loop = asyncio.get_running_loop()
62
62
  expiration_interval = client_call_details.timeout
63
+ if expiration_interval is None:
64
+ expiration_interval = float("inf")
63
65
  start_time = loop.time()
64
66
  deadline = start_time + expiration_interval
65
67
 
@@ -68,11 +70,16 @@ class RetryInterceptor(UnaryUnaryClientInterceptor):
68
70
  remaining = deadline - loop.time()
69
71
  # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
70
72
  # noinspection PyProtectedMember
71
- call_details = client_call_details._replace(timeout=remaining)
73
+ call_details = client_call_details._replace( # type: ignore[attr-defined]
74
+ timeout=remaining
75
+ )
72
76
  rpc_call = await continuation(call_details, request)
73
77
  try:
74
- # Return the result directly if success. GRPC will wrap it back into a UnaryUnaryCall
75
- return await rpc_call
78
+ await rpc_call
79
+ # Return the call object (not the raw response) so outer interceptors
80
+ # that rely on UnaryUnaryCall methods like add_done_callback still work
81
+ # (e.g. opentelemetry-instrumentation-grpc).
82
+ return rpc_call
76
83
  except CadenceRpcError as e:
77
84
  err = e
78
85
 
@@ -92,8 +99,9 @@ class RetryInterceptor(UnaryUnaryClientInterceptor):
92
99
 
93
100
  def is_retryable(err: CadenceRpcError, call_details: ClientCallDetails) -> bool:
94
101
  # Handle requests to the passive side, matching the Go and Java Clients
95
- if call_details.method == GET_WORKFLOW_HISTORY and isinstance(
96
- err, EntityNotExistsError
102
+ if (
103
+ call_details.method == GET_WORKFLOW_HISTORY # type: ignore[comparison-overlap]
104
+ and isinstance(err, EntityNotExistsError)
97
105
  ):
98
106
  return (
99
107
  err.active_cluster is not None
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable
1
+ from typing import Any, Callable, cast
2
2
 
3
3
  from grpc.aio import Metadata
4
4
  from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
@@ -40,6 +40,9 @@ class YarpcMetadataInterceptor(UnaryUnaryClientInterceptor):
40
40
 
41
41
  # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
42
42
  # noinspection PyProtectedMember
43
- return client_call_details._replace(
44
- metadata=metadata, timeout=client_call_details.timeout or 60.0
43
+ return cast(
44
+ ClientCallDetails,
45
+ client_call_details._replace( # type: ignore[attr-defined]
46
+ metadata=metadata, timeout=client_call_details.timeout or 60.0
47
+ ),
45
48
  )
@@ -0,0 +1,32 @@
1
+ """Adapt :class:`cadence.workflow.ActiveClusterSelectionPolicy` (TypedDict) to its protobuf wire form."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Mapping, cast
6
+
7
+ from cadence.api.v1 import common_pb2
8
+ from cadence.workflow import ActiveClusterSelectionPolicy
9
+
10
+
def active_cluster_selection_policy_to_proto(
    policy: ActiveClusterSelectionPolicy | Mapping[str, object] | None,
) -> common_pb2.ActiveClusterSelectionPolicy | None:
    """Convert a user active-cluster selection policy to protobuf.

    Returns ``None`` for ``None`` or an empty mapping. Otherwise returns a
    message. Its ``cluster_attribute`` is filled only when the policy has a
    non-None ``cluster_attribute`` entry; missing ``scope``/``name`` keys
    become ``""``.
    """
    if policy is None:
        return None
    if isinstance(policy, Mapping) and len(policy) == 0:
        return None

    message = common_pb2.ActiveClusterSelectionPolicy()
    attribute = policy.get("cluster_attribute")
    if attribute is not None:
        fields = cast(Mapping[str, str], attribute)
        message.cluster_attribute.CopyFrom(
            common_pb2.ClusterAttribute(
                scope=fields.get("scope", ""),
                name=fields.get("name", ""),
            )
        )
    return message
@@ -1,14 +1,30 @@
1
1
  from contextlib import contextmanager
2
+ from asyncio import get_running_loop
2
3
  from datetime import timedelta
3
4
  from math import ceil
4
- from typing import Iterator, Optional, Any, Unpack, Type, cast
5
+ from typing import Iterator, Optional, Any, Unpack, Type, cast, Callable
5
6
 
7
+ from cadence._internal.workflow.deterministic_event_loop import DeterministicEventLoop
8
+ from cadence._internal.workflow.retry_policy import retry_policy_to_proto
6
9
  from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
7
10
  from cadence.api.v1.common_pb2 import ActivityType
8
- from cadence.api.v1.decision_pb2 import ScheduleActivityTaskDecisionAttributes
11
+ from cadence.api.v1.decision_pb2 import (
12
+ ScheduleActivityTaskDecisionAttributes,
13
+ StartTimerDecisionAttributes,
14
+ )
9
15
  from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind
10
16
  from cadence.data_converter import DataConverter
11
- from cadence.workflow import WorkflowContext, WorkflowInfo, ResultType, ActivityOptions
17
+ from cadence.workflow import (
18
+ ActivityOptions,
19
+ ResultType,
20
+ WorkflowContext,
21
+ WorkflowInfo,
22
+ )
23
+
# Fallback activity timeouts. They are merged beneath the caller's keyword
# options when an activity is scheduled, so caller-supplied values win.
_DEFAULT_ACTIVITY_OPTIONS: ActivityOptions = ActivityOptions(
    schedule_to_close_timeout=timedelta(hours=1),
    schedule_to_start_timeout=timedelta(seconds=10),
)
12
28
 
13
29
 
14
30
  class Context(WorkflowContext):
@@ -35,7 +51,7 @@ class Context(WorkflowContext):
35
51
  *args: Any,
36
52
  **kwargs: Unpack[ActivityOptions],
37
53
  ) -> ResultType:
38
- opts = ActivityOptions(**kwargs)
54
+ opts: ActivityOptions = {**_DEFAULT_ACTIVITY_OPTIONS, **kwargs}
39
55
  if "schedule_to_close_timeout" not in opts and (
40
56
  "schedule_to_start_timeout" not in opts
41
57
  or "start_to_close_timeout" not in opts
@@ -77,19 +93,28 @@ class Context(WorkflowContext):
77
93
  schedule_to_start_timeout=_round_to_nearest_second(schedule_to_start),
78
94
  start_to_close_timeout=_round_to_nearest_second(start_to_close),
79
95
  heartbeat_timeout=_round_to_nearest_second(heartbeat),
80
- retry_policy=None,
96
+ retry_policy=retry_policy_to_proto(opts.get("retry_policy")),
81
97
  header=None,
82
98
  request_local_dispatch=False,
83
99
  )
84
100
 
85
- result_payload = await self._decision_manager.schedule_activity(
86
- schedule_attributes
87
- )
101
+ future = self._decision_manager.schedule_activity(schedule_attributes)
102
+ result_payload = await future
88
103
 
89
104
  result = self.data_converter().from_data(result_payload, [result_type])[0]
90
105
 
91
106
  return cast(ResultType, result)
92
107
 
async def start_timer(self, duration: timedelta) -> None:
    """Start a workflow timer and wait until it fires.

    Non-positive durations return immediately without emitting a decision.
    Positive durations are ceil-rounded to whole seconds before being sent.
    This matches the whole-second rounding applied to other durations in this
    client and the Go client's timer rounding, so a sub-second timer becomes
    one second rather than risking truncation to zero on a whole-second
    server. NOTE(review): confirm server-side truncation behaviour.
    """
    if duration.total_seconds() <= 0:  # nothing to wait for
        return
    fire_after = timedelta(seconds=ceil(duration.total_seconds()))
    future = self._decision_manager.start_timer(
        StartTimerDecisionAttributes(
            start_to_fire_timeout=fire_after,
        )
    )
    await future
117
+
93
118
  def set_replay_mode(self, replay: bool) -> None:
94
119
  """Set whether the workflow is currently in replay mode."""
95
120
  self._replay_mode = replay
@@ -106,11 +131,17 @@ class Context(WorkflowContext):
106
131
  """Get the current replay time in milliseconds."""
107
132
  return self._replay_current_time_milliseconds
108
133
 
async def wait_condition(self, predicate: Callable[[], bool]) -> None:
    """Suspend the workflow until ``predicate()`` returns true.

    Assumes (via ``cast``, not a runtime check) that the running loop is a
    ``DeterministicEventLoop``, which provides ``create_waiter``.
    """
    deterministic_loop = cast(DeterministicEventLoop, get_running_loop())
    waiter = deterministic_loop.create_waiter(predicate)
    await waiter
137
+
@contextmanager
def _activate(self) -> Iterator["Context"]:
    """Make this context the current ``WorkflowContext`` for the ``with`` body.

    The previous value is restored even if the body raises.
    """
    reset_token = WorkflowContext._var.set(self)
    try:
        yield self
    finally:
        WorkflowContext._var.reset(reset_token)
114
145
 
115
146
 
116
147
  def _round_to_nearest_second(delta: timedelta) -> timedelta:
@@ -8,6 +8,8 @@ import threading
8
8
  from typing import Callable, Any, TypeVar, Coroutine, Awaitable, Generator
9
9
  from typing_extensions import Unpack, TypeVarTuple
10
10
 
11
+ from cadence._internal.workflow.waiter import Waiter
12
+
11
13
  logger = logging.getLogger(__name__)
12
14
 
13
15
 
@@ -32,6 +34,7 @@ class DeterministicEventLoop(AbstractEventLoop):
32
34
  self._thread_id: int | None = None # indicate if the event loop is running
33
35
  self._debug: bool = False
34
36
  self._ready: collections.deque[events.Handle] = collections.deque()
37
+ self._waiters: list[Waiter] = []
35
38
  self._stopping: bool = False
36
39
  self._closed: bool = False
37
40
 
@@ -141,6 +144,13 @@ class DeterministicEventLoop(AbstractEventLoop):
def create_future(self) -> Future[Any]:
    """Create a ``Future`` bound to this deterministic loop."""
    fut = futures.Future(loop=self)
    return fut
143
146
 
def create_waiter(self, predicate: Callable[[], bool]) -> Waiter:
    """Register a predicate-driven awaitable.

    The waiter is polled once right away. It is queued on ``self._waiters``
    for later polling only if that first poll reports it unsettled.
    """
    waiter = Waiter(predicate, self)
    settled = waiter.poll()
    if not settled:
        self._waiters.append(waiter)
    return waiter
153
+
144
154
  def _run_once(self) -> None:
145
155
  ntodo = len(self._ready)
146
156
  for i in range(ntodo):
@@ -149,6 +159,19 @@ class DeterministicEventLoop(AbstractEventLoop):
149
159
  continue
150
160
  handle._run()
151
161
 
162
+ # Poll waiters; only stop early if settling one schedules new work,
163
+ # so remaining waiters are not skipped.
164
+ i = 0
165
+ while i < len(self._waiters):
166
+ w = self._waiters[i]
167
+ ready_before = len(self._ready)
168
+ if w.poll():
169
+ del self._waiters[i]
170
+ if len(self._ready) > ready_before:
171
+ return
172
+ else:
173
+ i += 1
174
+
152
175
  def _run_forever_setup(self) -> None:
153
176
  self._check_closed()
154
177
  self._check_running()
@@ -190,6 +213,7 @@ class DeterministicEventLoop(AbstractEventLoop):
190
213
  logger.debug("Close %r", self)
191
214
  self._closed = True
192
215
  self._ready.clear()
216
+ self._waiters.clear()
193
217
 
194
218
  def is_closed(self) -> bool:
195
219
  """Returns True if the event loop was closed."""
@@ -462,13 +486,12 @@ class DeterministicEventLoop(AbstractEventLoop):
462
486
  )
463
487
 
464
488
  def call_exception_handler(self, context: dict[str, Any]) -> None:
465
- # This is called if a task has an unhandled exception. Short term, it's helpful to log these for debugging.
466
- # Long term, we need some combination of failing decision tasks or workflows based on these errors.
467
489
  message = context.get("message")
468
490
  if not message:
469
491
  message = "Unhandled exception in event loop"
470
492
 
471
493
  exception = context.get("exception")
494
+
472
495
  if isinstance(exception, BaseException):
473
496
  exc_info = exception
474
497
  else:
@@ -0,0 +1,62 @@
1
+ """Adapt :class:`cadence.workflow.RetryPolicy` (TypedDict) to its protobuf wire form."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import timedelta
6
+ from math import ceil
7
+ from typing import Mapping, cast
8
+
9
+ from google.protobuf.duration_pb2 import Duration
10
+
11
+ from cadence.api.v1 import common_pb2
12
+ from cadence.workflow import RetryPolicy
13
+
14
+
def _round_to_whole_seconds(delta: timedelta) -> timedelta:
    """Round ``delta`` up (ceiling) to a whole number of seconds."""
    whole = ceil(delta.total_seconds())
    return timedelta(seconds=whole)
18
+
19
+
def _set_duration_field(target: Duration, delta: timedelta) -> None:
    """Store ``delta``, ceil-rounded to whole seconds, into the ``Duration`` ``target``."""
    rounded = Duration()
    rounded.FromTimedelta(_round_to_whole_seconds(delta))
    target.CopyFrom(rounded)
25
+
26
+
def retry_policy_to_proto(
    policy: RetryPolicy | Mapping[str, object] | None,
) -> common_pb2.RetryPolicy | None:
    """Convert a user retry policy to protobuf, or ``None`` if no policy was provided.

    ``None`` and an empty mapping both map to ``None``, so the server applies
    its own defaults instead of receiving an explicit empty policy. Durations
    are ceiled to whole seconds to match the server's resolution and the
    Go/Java SDKs. Keys whose value is ``None`` are left unset.

    ``non_retryable_error_reasons`` may be a list of strings or a single
    string; a single string is treated as one reason.

    Raises:
        ValueError: if ``backoff_coefficient`` is below 1.0 or NaN, or if
            ``maximum_attempts`` is negative.
    """
    if policy is None or (isinstance(policy, Mapping) and len(policy) == 0):
        return None

    out = common_pb2.RetryPolicy()

    if (ii := policy.get("initial_interval")) is not None:
        _set_duration_field(out.initial_interval, cast(timedelta, ii))

    if (coef := policy.get("backoff_coefficient")) is not None:
        coef_f = cast(float, coef)
        # ``not >=`` also rejects NaN, which compares false against everything.
        if not coef_f >= 1.0:
            raise ValueError("backoff_coefficient must be >= 1.0 when provided")
        out.backoff_coefficient = coef_f

    if (mi := policy.get("maximum_interval")) is not None:
        _set_duration_field(out.maximum_interval, cast(timedelta, mi))

    if (ma := policy.get("maximum_attempts")) is not None:
        attempts = int(cast(int, ma))
        # 0 is passed through unchanged; negative counts are never meaningful.
        if attempts < 0:
            raise ValueError("maximum_attempts must be >= 0 when provided")
        out.maximum_attempts = attempts

    if (reasons := policy.get("non_retryable_error_reasons")) is not None:
        # extend() on a bare str would add one "reason" per character.
        if isinstance(reasons, str):
            reasons = [reasons]
        out.non_retryable_error_reasons.extend(cast(list[str], reasons))

    if (ei := policy.get("expiration_interval")) is not None:
        _set_duration_field(out.expiration_interval, cast(timedelta, ei))

    return out
21
21
  class EventDispatcher:
22
22
  handlers: dict[Type, Action]
23
23
 
def __init__(self, default_id_attr: str = "") -> None:
    """Create a dispatcher with no handlers registered.

    ``default_id_attr`` is the event id attribute used by handlers that do not
    name their own. With the empty default, no id-field validation is done
    for such handlers.
    """
    self.handlers = {}
    self._default_id_attr = default_id_attr
27
27
 
@@ -32,7 +32,8 @@ class EventDispatcher:
32
32
  event_type = _find_event_type(func)
33
33
  event_id_attr = id_attr if id_attr else self._default_id_attr
34
34
 
35
- _validate_field(func, event_type, event_id_attr)
35
+ if event_id_attr:
36
+ _validate_field(func, event_type, event_id_attr)
36
37
  if event_type in self.handlers:
37
38
  raise ValueError(
38
39
  f"Duplicate handler for {event_type}: {func.__qualname__} and {self.handlers[event_type].fn.__qualname__}"