grctl-sdk-python 0.1.1__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/PKG-INFO +1 -1
  2. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/directive.py +2 -2
  3. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/history.py +3 -3
  4. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/publisher.py +9 -3
  5. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/codec.py +7 -0
  6. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/context.py +3 -3
  7. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/runner.py +64 -80
  8. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/runtime.py +10 -8
  9. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/task.py +21 -16
  10. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/workflow/__init__.py +2 -2
  11. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/workflow/workflow.py +2 -8
  12. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/pyproject.toml +1 -1
  13. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/LICENSE +0 -0
  14. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/README.md +0 -0
  15. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/__init__.py +0 -0
  16. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/client/__init__.py +0 -0
  17. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/client/client.py +0 -0
  18. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/logging_config.py +0 -0
  19. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/__init__.py +0 -0
  20. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/api.py +0 -0
  21. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/command.py +0 -0
  22. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/common.py +0 -0
  23. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/errors.py +0 -0
  24. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/run_info.py +0 -0
  25. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/run_info_helper.py +0 -0
  26. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/models/worker.py +0 -0
  27. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/__init__.py +0 -0
  28. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/connection.py +0 -0
  29. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/history_fetch.py +0 -0
  30. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/history_sub.py +0 -0
  31. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/kv_store.py +0 -0
  32. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/manifest.py +0 -0
  33. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/nats_client.py +0 -0
  34. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/nats_manifest.yaml +0 -0
  35. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/nats/subscriber.py +0 -0
  36. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/py.typed +0 -0
  37. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/settings.py +0 -0
  38. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/__init__.py +0 -0
  39. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/errors.py +0 -0
  40. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/logger.py +0 -0
  41. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/run_manager.py +0 -0
  42. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/store.py +0 -0
  43. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/worker/worker.py +0 -0
  44. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/workflow/future.py +0 -0
  45. {grctl_sdk_python-0.1.1 → grctl_sdk_python-0.1.2}/grctl/workflow/handle.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grctl-sdk-python
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: The Python SDK for the Ground Control
5
5
  Author: cemevren
6
6
  Author-email: cemevren <cemevren@gmail.com>
@@ -158,11 +158,11 @@ class DirectiveWire(msgspec.Struct, omit_defaults=True):
158
158
  kv: dict[str, Any] | None = None # kv_revs
159
159
 
160
160
 
161
- def directive_encoder(directive: Directive) -> bytes:
161
+ def directive_encoder(directive: Directive, enc_hook: Any = None) -> bytes:
162
162
  if directive.msg is None:
163
163
  raise ValueError("Directive message cannot be None")
164
164
 
165
- msg_bytes = msgspec.msgpack.encode(directive.msg)
165
+ msg_bytes = msgspec.msgpack.encode(directive.msg, enc_hook=enc_hook)
166
166
 
167
167
  wire = DirectiveWire(
168
168
  id=directive.id,
@@ -135,7 +135,7 @@ class TaskCompleted(msgspec.Struct):
135
135
 
136
136
  task_id: str
137
137
  task_name: str
138
- output: Any # Task return value
138
+ output: dict[str, Any] # Always {"result": <primitive>}
139
139
  step_name: str
140
140
  duration_ms: int
141
141
 
@@ -297,7 +297,7 @@ class HistoryWire(msgspec.Struct):
297
297
  o: str = ""
298
298
 
299
299
 
300
- def history_encoder(event: HistoryEvent) -> bytes:
300
+ def history_encoder(event: HistoryEvent, enc_hook: Any = None) -> bytes:
301
301
  """Encode history event to msgpack as array: [kind, msg, run_id, timestamp, wf_id, worker_id]."""
302
302
  if event.msg is None:
303
303
  raise ValueError("HistoryEvent message cannot be None")
@@ -308,7 +308,7 @@ def history_encoder(event: HistoryEvent) -> bytes:
308
308
  wo=event.worker_id,
309
309
  ts=event.timestamp,
310
310
  k=event.kind,
311
- m=msgspec.msgpack.encode(event.msg),
311
+ m=msgspec.msgpack.encode(event.msg, enc_hook=enc_hook),
312
312
  o=event.operation_id,
313
313
  )
314
314
 
@@ -1,3 +1,6 @@
1
+ from collections.abc import Callable
2
+ from typing import Any
3
+
1
4
  from nats.aio.client import Client as NATSClient
2
5
  from nats.js.client import JetStreamContext
3
6
 
@@ -11,22 +14,25 @@ class Publisher:
11
14
  self._js = js
12
15
  self._manifest = manifest
13
16
 
14
- async def publish_history(self, run_info: RunInfo, event: HistoryEvent) -> None:
17
+ async def publish_history(
18
+ self, run_info: RunInfo, event: HistoryEvent, enc_hook: Callable[[Any], Any] | None = None
19
+ ) -> None:
15
20
  subject = self._manifest.history_subject(wf_id=run_info.wf_id, run_id=run_info.id)
16
- data = history_encoder(event)
21
+ data = history_encoder(event, enc_hook=enc_hook)
17
22
  await self._js.publish(subject, data)
18
23
 
19
24
  async def publish_next_directive(
20
25
  self,
21
26
  run: RunInfo,
22
27
  directive: Directive,
28
+ enc_hook: Callable[[Any], Any] | None = None,
23
29
  ) -> None:
24
30
  subject = self._manifest.directive_subject(
25
31
  wf_type=run.wf_type,
26
32
  wf_id=run.wf_id,
27
33
  run_id=run.id,
28
34
  )
29
- data = directive_encoder(directive)
35
+ data = directive_encoder(directive, enc_hook=enc_hook)
30
36
  await self._js.publish(subject, data)
31
37
 
32
38
  async def publish_cmd(
@@ -1,6 +1,7 @@
1
1
  from collections.abc import Callable
2
2
  from typing import Any
3
3
 
4
+ import msgspec
4
5
  import msgspec.msgpack
5
6
  from pydantic import BaseModel
6
7
 
@@ -35,6 +36,12 @@ class CodecRegistry:
35
36
  return decode(tp, obj)
36
37
  raise TypeError(f"Unsupported type: {tp}")
37
38
 
39
+ def to_primitive(self, value: Any) -> Any:
40
+ return msgspec.to_builtins(value, enc_hook=self.enc_hook)
41
+
42
+ def from_primitive(self, raw: Any, tp: type) -> Any:
43
+ return msgspec.convert(raw, tp, dec_hook=self.dec_hook)
44
+
38
45
  def encode(self, value: Any) -> bytes:
39
46
  return msgspec.msgpack.encode(value, enc_hook=self.enc_hook)
40
47
 
@@ -33,7 +33,7 @@ from grctl.worker.logger import ReplayAwareLogger
33
33
  from grctl.worker.runtime import get_step_runtime
34
34
  from grctl.worker.store import Store
35
35
  from grctl.workflow import WorkflowHandle
36
- from grctl.workflow.workflow import StepConfig
36
+ from grctl.workflow.workflow import HandlerConfig
37
37
 
38
38
  StepHandler = Callable[..., Awaitable[Directive]]
39
39
 
@@ -50,7 +50,7 @@ class NextBuilder:
50
50
  worker_id: str,
51
51
  store: Store,
52
52
  current_directive: Directive,
53
- step_configs: dict[str, StepConfig] | None = None,
53
+ step_configs: dict[str, HandlerConfig] | None = None,
54
54
  ) -> None:
55
55
  self._run = run
56
56
  self._worker_id = worker_id
@@ -149,7 +149,7 @@ class Context:
149
149
  worker_id: str,
150
150
  directive: Directive,
151
151
  parent_run: RunInfo | None = None,
152
- step_configs: dict[str, StepConfig] | None = None,
152
+ step_configs: dict[str, HandlerConfig] | None = None,
153
153
  workflow_logger: logging.Logger | None = None,
154
154
  ) -> None:
155
155
  self.run = run_info
@@ -1,11 +1,8 @@
1
1
  import functools
2
2
  import traceback
3
- from collections.abc import Awaitable, Callable
4
3
  from datetime import UTC, datetime
5
4
  from typing import Any
6
5
 
7
- import msgspec
8
-
9
6
  from grctl.logging_config import get_logger
10
7
  from grctl.models import (
11
8
  Directive,
@@ -18,11 +15,9 @@ from grctl.models import (
18
15
  )
19
16
  from grctl.models.directive import StepResult
20
17
  from grctl.models.run_info_helper import RunInfoManager
21
- from grctl.worker.codec import CodecRegistry
22
- from grctl.worker.context import Context
23
18
  from grctl.worker.errors import NextDirectiveMissingError
24
19
  from grctl.worker.runtime import StepRuntime, set_step_runtime
25
- from grctl.workflow.workflow import HandlerSpec
20
+ from grctl.workflow.workflow import HandlerConfig
26
21
 
27
22
  logger = get_logger(__name__)
28
23
 
@@ -46,7 +41,9 @@ def workflow_error_handler(func): # noqa: ANN001, ANN201
46
41
  stack_trace=stack_trace,
47
42
  ),
48
43
  )
49
- await self.runtime.publisher.publish_next_directive(self.runtime.run_info, fail_directive)
44
+ await self.runtime.publisher.publish_next_directive(
45
+ self.runtime.run_info, fail_directive, enc_hook=self.runtime.codec.enc_hook
46
+ )
50
47
  raise
51
48
 
52
49
  return wrapper
@@ -60,20 +57,13 @@ But it's okay for now as we are focusing on correctness and not efficiency.
60
57
 
61
58
 
62
59
  class WorkflowRunner:
63
- """Orchestrates workflow run lifecycle.
64
-
65
- Manages the complete workflow run lifecycle including:
66
- - Validating run requests
67
- - Publishing lifecycle events (started, completed/failed/timeout)
68
- - Executing workflow with timeout enforcement
69
- - Tracking execution duration
70
- """
60
+ """Orchestrates workflow run lifecycle."""
71
61
 
72
62
  _result = None
73
63
 
74
64
  def __init__(self, runtime: StepRuntime) -> None:
75
65
  self.runtime = runtime
76
- set_step_runtime(runtime)
66
+ self._runtime_token = set_step_runtime(runtime)
77
67
  self.workflow = runtime.workflow
78
68
 
79
69
  async def handle_directive(self, directive: Directive) -> None:
@@ -90,102 +80,96 @@ class WorkflowRunner:
90
80
 
91
81
  @workflow_error_handler
92
82
  async def handle_start(self, payload: Any | None) -> None:
93
- start_config = self.workflow._start_handler # noqa: SLF001
94
- if start_config is None:
83
+ handler_config = self.workflow._start_handler # noqa: SLF001
84
+ if handler_config is None:
95
85
  raise ValueError("Workflow start handler is not defined.")
96
86
 
97
87
  self.runtime.run_info = RunInfoManager.start(self.runtime.run_info, datetime.now(UTC))
98
88
  self.runtime.step_name = "start"
99
- await _execute_step(self.runtime, start_config.handler, start_config.spec, payload)
89
+ await self._execute_step(handler_config, payload)
100
90
 
101
91
  @workflow_error_handler
102
92
  async def handle_event(self, event_name: str, payload: Any | None) -> None:
103
- """Handle an event by executing its corresponding handler."""
104
93
  handler_config = self.workflow._on_event_handlers.get(event_name) # noqa: SLF001
105
94
  if handler_config is None:
106
95
  logger.warning(f"No handler registered for event '{event_name}'")
107
96
  return
108
97
 
109
98
  self.runtime.step_name = event_name
110
- await _execute_step(self.runtime, handler_config.handler, handler_config.spec, payload)
99
+ await self._execute_step(handler_config, payload)
111
100
 
112
101
  @workflow_error_handler
113
102
  async def handle_step(self, step: Step) -> None:
114
- """Execute workflow step."""
115
103
  logger.debug(f"Executing step: {step.step_name} for run {self.runtime.run_info.id}")
116
104
 
117
105
  step_config = self.workflow._step_handlers.get(step.step_name) # noqa: SLF001
118
106
  if step_config is None:
119
107
  raise ValueError(f"Step handler '{step.step_name}' is not defined.")
120
108
  self.runtime.step_name = step.step_name
121
- await _execute_step(self.runtime, step_config.handler, step_config.spec, None)
109
+ await self._execute_step(step_config, None)
122
110
 
123
111
  def _get_event_name(self, handler: Any) -> str | None:
124
- """Check if handler is an @on_event handler and return its event name."""
125
112
  for event_name, event_config in self.workflow._on_event_handlers.items(): # noqa: SLF001
126
113
  if handler == event_config.handler:
127
114
  return event_name
128
115
  return None
129
116
 
117
+ async def _execute_step(self, handler_config: HandlerConfig, payload: Any | None) -> None:
118
+ ctx = self.runtime.get_step_context()
119
+ start_time = datetime.now(UTC)
130
120
 
131
- async def _dispatch_handler(
132
- codec: CodecRegistry,
133
- spec: HandlerSpec,
134
- ctx: Context,
135
- handler: Callable[..., Awaitable[Directive]],
136
- payload: Any | None,
137
- ) -> Directive:
138
- if not spec.params:
139
- return await handler(ctx)
140
- if not isinstance(payload, dict):
141
- raise TypeError(f"Handler expects params {list(spec.params)} but payload is not a dict: {type(payload)}")
142
- typed_kwargs = {
143
- name: msgspec.convert(payload[name], param_type, dec_hook=codec.dec_hook)
144
- for name, param_type in spec.params.items()
145
- }
146
- return await handler(ctx, **typed_kwargs)
147
-
148
-
149
- async def _execute_step(
150
- runtime: StepRuntime,
151
- handler: Callable[..., Awaitable[Directive]],
152
- spec: HandlerSpec,
153
- payload: Any | None,
154
- ) -> None:
155
- ctx = runtime.get_step_context()
156
- start_time = datetime.now(UTC)
157
-
158
- await _publish_step_started_event(runtime)
159
-
160
- directive = await _dispatch_handler(runtime.codec, spec, ctx, handler, payload)
161
-
162
- await _publish_next_directive(runtime, directive, start_time)
163
-
164
-
165
- async def _publish_step_started_event(runtime: StepRuntime) -> None:
166
- # Step started event will be send by the worker
167
- if runtime.step_history is None or len(runtime.step_history) == 0:
168
- await runtime.record(
169
- kind=HistoryKind.step_started,
170
- payload=StepStarted(step_name=runtime.step_name),
171
- operation_id="",
172
- )
121
+ await self._publish_step_started_event()
173
122
 
123
+ spec = handler_config.spec
124
+ handler = handler_config.handler
125
+ if not spec.params:
126
+ directive = await handler(ctx)
174
127
 
175
- async def _publish_next_directive(
176
- runtime: StepRuntime,
177
- directive: Directive,
178
- step_start_time: datetime | None = None,
179
- ) -> None:
180
- """Process Directive returned by a handler."""
181
- if not isinstance(directive, Directive):
182
- raise NextDirectiveMissingError(f"Step did not return a Directive. {directive=}", runtime.step_name)
128
+ # Single param: if payload is already keyed by param name use the value,
129
+ # otherwise treat payload itself as the value (e.g. bare Pydantic model).
130
+ elif len(spec.params) == 1:
131
+ name, param_type = next(iter(spec.params.items()))
132
+ raw = payload[name] if isinstance(payload, dict) and name in payload else payload
133
+ typed_value = self.runtime.codec.from_primitive(raw, param_type)
134
+ directive = await handler(ctx, **{name: typed_value})
183
135
 
184
- if step_start_time is not None and isinstance(directive.msg, StepResult):
185
- directive.msg.duration_ms = int((datetime.now(UTC) - step_start_time).total_seconds() * 1000)
136
+ # Multi param: convert each param from the payload dict and pass as kwargs
137
+ else:
138
+ if not isinstance(payload, dict):
139
+ raise TypeError(
140
+ f"Handler expects params {list(spec.params)} but payload is not a dict: {type(payload)}"
141
+ )
142
+ typed_kwargs = {
143
+ name: self.runtime.codec.from_primitive(payload[name], param_type)
144
+ for name, param_type in spec.params.items()
145
+ }
146
+ directive = await handler(ctx, **typed_kwargs)
147
+
148
+ await self._publish_next_directive(directive, start_time)
149
+
150
+ async def _publish_step_started_event(self) -> None:
151
+ if self.runtime.step_history is None or len(self.runtime.step_history) == 0:
152
+ await self.runtime.record(
153
+ kind=HistoryKind.step_started,
154
+ payload=StepStarted(step_name=self.runtime.step_name),
155
+ operation_id="",
156
+ )
157
+
158
+ async def _publish_next_directive(
159
+ self,
160
+ directive: Directive,
161
+ step_start_time: datetime | None = None,
162
+ ) -> None:
163
+ if not isinstance(directive, Directive):
164
+ raise NextDirectiveMissingError(f"Step did not return a Directive. {directive=}", self.runtime.step_name)
186
165
 
187
- pending_updates = runtime.store.get_pending_updates()
188
- if pending_updates:
189
- directive.kv_revs = pending_updates
166
+ if step_start_time is not None and isinstance(directive.msg, StepResult):
167
+ directive.msg.duration_ms = int((datetime.now(UTC) - step_start_time).total_seconds() * 1000)
190
168
 
191
- await runtime.publisher.publish_next_directive(runtime.run_info, directive)
169
+ pending_updates = self.runtime.store.get_pending_updates()
170
+ if pending_updates:
171
+ directive.kv_revs = pending_updates
172
+
173
+ await self.runtime.publisher.publish_next_directive(
174
+ self.runtime.run_info, directive, enc_hook=self.runtime.codec.enc_hook
175
+ )
@@ -1,6 +1,6 @@
1
1
  import asyncio
2
2
  import hashlib
3
- from contextvars import ContextVar
3
+ from contextvars import ContextVar, Token
4
4
  from datetime import UTC, datetime
5
5
  from typing import TYPE_CHECKING, Any
6
6
 
@@ -29,10 +29,10 @@ def _generate_operation_id(fn_name: str, args: dict[str, Any], seq: int) -> str:
29
29
 
30
30
 
31
31
  class NonDeterminismError(Exception):
32
- """Raised when replay journal doesn't match current execution order."""
32
+ """Raised when replay history doesn't match current execution order."""
33
33
 
34
34
 
35
- # Only these kinds participate in replay journal matching — everything else is observability-only
35
+ # Only these kinds participate in replay history matching — everything else is observability-only
36
36
  _REPLAY_KINDS = frozenset(
37
37
  {
38
38
  HistoryKind.task_completed,
@@ -86,6 +86,8 @@ class StepRuntime:
86
86
  async def next(
87
87
  self, acceptable_kinds: HistoryKind | frozenset[HistoryKind], operation_id: str
88
88
  ) -> asyncio.Future[HistoryEvents] | None:
89
+
90
+ # Check if we have a step history and the cursor is within bounds. If not, we are not replaying — return None.
89
91
  if not self.step_history or self._cursor >= len(self.step_history):
90
92
  return None
91
93
 
@@ -101,7 +103,7 @@ class StepRuntime:
101
103
  if not future.done():
102
104
  if self._cursor >= len(self.step_history):
103
105
  self._pending.pop(operation_id, None)
104
- return None # journal exhausted — live execution
106
+ return None # history exhausted — live execution
105
107
  raise NonDeterminismError(
106
108
  f"Unresolved operation {operation_id} ({acceptable_kinds}) after yield — "
107
109
  f"cursor at {self._cursor}, pending: {list(self._pending.keys())}"
@@ -124,7 +126,7 @@ class StepRuntime:
124
126
  if entry.kind not in acceptable_kinds:
125
127
  future.set_exception(
126
128
  NonDeterminismError(
127
- f"Expected one of {acceptable_kinds} but journal has {entry.kind} "
129
+ f"Expected one of {acceptable_kinds} but history has {entry.kind} "
128
130
  f"at cursor {self._cursor} for {entry.operation_id}"
129
131
  )
130
132
  )
@@ -134,7 +136,7 @@ class StepRuntime:
134
136
 
135
137
  async def record(self, kind: HistoryKind, payload: HistoryEvents, operation_id: str) -> None:
136
138
  event = self._create_history_event(kind, payload, operation_id)
137
- await self.publisher.publish_history(run_info=self.run_info, event=event)
139
+ await self.publisher.publish_history(run_info=self.run_info, event=event, enc_hook=self.codec.enc_hook)
138
140
 
139
141
  def get_step_context(self) -> "Context":
140
142
  from grctl.worker.context import Context # noqa: PLC0415
@@ -179,5 +181,5 @@ def get_step_runtime() -> StepRuntime:
179
181
  return _step_run_time.get()
180
182
 
181
183
 
182
- def set_step_runtime(runtime: StepRuntime) -> None:
183
- _step_run_time.set(runtime)
184
+ def set_step_runtime(runtime: StepRuntime) -> Token:
185
+ return _step_run_time.set(runtime)
@@ -7,7 +7,7 @@ import traceback
7
7
  from collections.abc import AsyncGenerator, Awaitable, Callable
8
8
  from dataclasses import dataclass
9
9
  from datetime import UTC, datetime
10
- from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
10
+ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, get_type_hints, overload
11
11
 
12
12
  from grctl.logging_config import get_logger
13
13
  from grctl.models.directive import RetryPolicy
@@ -43,8 +43,7 @@ def _capture_args(fn: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[st
43
43
 
44
44
 
45
45
  def _normalize_args(args_dict: dict[str, Any], codec: CodecRegistry) -> dict[str, Any]:
46
- """Normalize task args to msgspec-encodable primitives via codec round-trip."""
47
- return {k: codec.decode(codec.encode(v)) for k, v in args_dict.items()}
46
+ return {k: codec.to_primitive(v) for k, v in args_dict.items()}
48
47
 
49
48
 
50
49
  def _calculate_backoff_delay(policy: RetryPolicy, attempt: int) -> int:
@@ -80,20 +79,17 @@ def _is_error_retryable(error: Exception, policy: RetryPolicy) -> bool:
80
79
  """Check if an error should be retried based on the retry policy filters."""
81
80
  error_type = type(error).__name__
82
81
 
83
- has_retryable = policy.retryable_errors is not None
84
- has_non_retryable = policy.non_retryable_errors is not None
85
-
86
- if not has_retryable and not has_non_retryable:
82
+ if policy.retryable_errors is None and policy.non_retryable_errors is None:
87
83
  return True
88
84
 
89
- if has_retryable and not has_non_retryable:
90
- return error_type in policy.retryable_errors # type: ignore[operator] # ty:ignore[unsupported-operator]
85
+ if policy.retryable_errors is not None and policy.non_retryable_errors is None:
86
+ return error_type in policy.retryable_errors
91
87
 
92
- if not has_retryable and has_non_retryable:
93
- return error_type not in policy.non_retryable_errors # type: ignore[operator] # ty:ignore[unsupported-operator]
88
+ if policy.retryable_errors is None and policy.non_retryable_errors is not None:
89
+ return error_type not in policy.non_retryable_errors
94
90
 
95
91
  # Both set: must be in retryable AND not in non_retryable
96
- return error_type in policy.retryable_errors and error_type not in policy.non_retryable_errors # type: ignore[operator] # ty:ignore[unsupported-operator]
92
+ return error_type in policy.retryable_errors and error_type not in policy.non_retryable_errors # ty:ignore[unsupported-operator]
97
93
 
98
94
 
99
95
  @dataclass
@@ -139,8 +135,8 @@ class RetryRunner:
139
135
  can_retry = (
140
136
  self._policy is not None and attempt < self._max_attempts and _is_error_retryable(e, self._policy)
141
137
  )
142
- if can_retry:
143
- delay_ms = _calculate_backoff_delay(self._policy, attempt) # type: ignore[arg-type] # ty:ignore[invalid-argument-type]
138
+ if can_retry and self._policy is not None:
139
+ delay_ms = _calculate_backoff_delay(self._policy, attempt)
144
140
  yield AttemptFailed(
145
141
  attempt=attempt,
146
142
  max_attempts=self._max_attempts,
@@ -349,6 +345,8 @@ async def _execute_task(
349
345
  normalized_args = _normalize_args(args_dict, runtime.codec)
350
346
  operation_id = runtime.generate_operation_id(task_name, normalized_args)
351
347
 
348
+ # If it's a replaying run, wait for the task outcome from the step history (in-memory).
349
+ # Otherwise, execute the task live.
352
350
  future = await runtime.next(
353
351
  frozenset({HistoryKind.task_completed, HistoryKind.task_failed, HistoryKind.task_cancelled}),
354
352
  operation_id,
@@ -357,7 +355,14 @@ async def _execute_task(
357
355
  logger.info(f"Replaying outcome for task {task_name} ({operation_id}) in step {step_name}")
358
356
  event = await future
359
357
  if isinstance(event, TaskCompleted):
360
- return event.output
358
+ raw = event.output["result"]
359
+ try:
360
+ return_type = get_type_hints(fn).get("return")
361
+ except Exception:
362
+ return_type = None
363
+ if return_type is not None:
364
+ return runtime.codec.from_primitive(raw, return_type)
365
+ return raw
361
366
  if isinstance(event, TaskFailed):
362
367
  raise _reconstruct_error(event.error)
363
368
  # TaskCancelled — task didn't finish; fall through to execute it live
@@ -391,7 +396,7 @@ async def _execute_task(
391
396
  task_id=operation_id,
392
397
  task_name=task_name,
393
398
  step_name=step_name,
394
- output=result,
399
+ output={"result": runtime.codec.to_primitive(result)},
395
400
  duration_ms=duration_ms,
396
401
  ),
397
402
  operation_id,
@@ -2,11 +2,11 @@
2
2
 
3
3
  from grctl.models.directive import Directive
4
4
  from grctl.workflow.handle import WorkflowHandle
5
- from grctl.workflow.workflow import StepConfig, Workflow
5
+ from grctl.workflow.workflow import HandlerConfig, Workflow
6
6
 
7
7
  __all__ = [
8
8
  "Directive",
9
- "StepConfig",
9
+ "HandlerConfig",
10
10
  "Workflow",
11
11
  "WorkflowHandle",
12
12
  ]
@@ -44,12 +44,6 @@ def inspect_handler(fn: Callable[..., Any]) -> HandlerSpec:
44
44
  class HandlerConfig:
45
45
  handler: Callable[..., Awaitable[Directive]]
46
46
  spec: HandlerSpec
47
-
48
-
49
- @dataclasses.dataclass
50
- class StepConfig:
51
- handler: Callable[..., Awaitable[Directive]]
52
- spec: HandlerSpec
53
47
  timeout: timedelta | None = None
54
48
 
55
49
 
@@ -82,7 +76,7 @@ class Workflow:
82
76
  self._type = workflow_type
83
77
  self._start_handler: HandlerConfig | None = None
84
78
  self._run_handler: Callable[..., Any] | None = None
85
- self._step_handlers: dict[str, StepConfig] = {}
79
+ self._step_handlers: dict[str, HandlerConfig] = {}
86
80
  self._on_event_handlers: dict[str, HandlerConfig] = {}
87
81
  self._update_handlers: dict[str, Callable[..., Any]] = {}
88
82
  self._query_handlers: dict[str, Callable[..., Any]] = {}
@@ -170,7 +164,7 @@ class Workflow:
170
164
  step_timeout = timeout if timeout is not None else timedelta(seconds=10)
171
165
  spec = inspect_handler(func)
172
166
 
173
- self._step_handlers[func.__name__] = StepConfig(
167
+ self._step_handlers[func.__name__] = HandlerConfig(
174
168
  handler=func,
175
169
  spec=spec,
176
170
  timeout=step_timeout,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "grctl-sdk-python"
3
- version = "0.1.1"
3
+ version = "0.1.2"
4
4
  description = "The Python SDK for the Ground Control"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.13"