SparkRT 0.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sparkrt/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """SparkRT - edge-side low-latency inference runtime for SparkMind2 models.
2
+
3
+ Public surface:
4
+ from sparkrt import from_sparkmind_agent, InferenceSession
5
+
6
+ The runtime is organised in four decoupled layers (see ``docs``/memory plan):
7
+
8
+ processors obs -> model-ready tensors -> action (normalize/tokenize)
9
+ adapters *what* a model computes (wraps SparkMind2 nn.Module as regions)
10
+ backends *how* it executes (eager now; CUDA-graph / native C++ later)
11
+ session unified, stateful select_action() loop (queue + ensemble)
12
+
13
+ ``sparkrt.core`` defines the contracts that tie these together and imports
14
+ without any SparkMind2 / heavy dependency, so the execution seam stays clean.
15
+ """
16
+
17
+ from sparkrt.api import from_sparkmind_agent, load_policy
18
+ from sparkrt.config import BackendConfig, Pi05RuntimeConfig, RuntimeConfig
19
+ from sparkrt.observation import Observation
20
+ from sparkrt.policy import Policy
21
+ from sparkrt.session import InferenceSession
22
+
23
# Names re-exported at package level; each one is imported above from its
# defining submodule, so ``from sparkrt import *`` exposes exactly this list.
__all__ = [
    "from_sparkmind_agent",
    "load_policy",
    "Policy",
    "Observation",
    "InferenceSession",
    "RuntimeConfig",
    "BackendConfig",
    "Pi05RuntimeConfig",
]
# Package version string (release candidate 1 of 0.1.0).
__version__ = "0.1.0rc1"
@@ -0,0 +1,13 @@
1
+ """Model adapters: wrap SparkMind2 nn.Modules as runtime regions."""
2
+
3
+ from sparkrt.adapters.act_adapter import ACTAdapter
4
+ from sparkrt.adapters.pi05_adapter import Pi05Adapter
5
+
6
+ __all__ = ["ACTAdapter", "Pi05Adapter"]
7
+
8
#: Maps ``cfg.Model.type`` -> adapter class. Extend this (plus a processor) to
#: add a new model; no backend or session changes are required.
#: Values are the adapter *classes* imported above, not instances; the caller
#: is responsible for constructing the adapter it looks up here.
ADAPTER_REGISTRY = {
    "act": ACTAdapter,
    "pi05": Pi05Adapter,
}
@@ -0,0 +1,107 @@
1
+ """Read-only prefix KV-cache shim for Pi0.5 denoise_step.
2
+
3
+ ``ReadOnlyPrefixCache`` is a lightweight HF-cache-protocol adapter that wraps
4
+ the frozen prefix KV tensors (computed once by ``encode_prefix``) and avoids
5
+ cloning them on every denoising step. It implements the minimal subset of the
6
+ HF ``DynamicCache`` interface used by ``PaliGemmaWithExpertModel.forward`` when
7
+ ``use_cache=False``, i.e. cross-attending to a prefix without updating it.
8
+
9
+ When the backend is ``cudagraph`` and no sliding-window layers are present the
10
+ regular ``clone_past_key_values`` copy is replaced by this object. The saving
11
+ is ~2–3 ms/chunk on A800 (18 layers × 2 tensors cloned × 10 steps avoided).
12
+
13
+ For ``torchcompile`` the shim is disabled by default because it introduces a
14
+ small extra numeric drift when combined with Inductor fusion; the ``off`` or
15
+ ``auto`` mode in :class:`~sparkrt.config.Pi05RuntimeConfig` controls this.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Any
21
+
22
+
23
+ __all__ = ["ReadOnlyPrefixCache", "ReadOnlyPrefixCacheLayer"]
24
+
25
+
26
class ReadOnlyPrefixCacheLayer:
    """Cache-layer shim that holds one frozen prefix key/value pair.

    The stored tensors are never written to: :meth:`update` hands back new
    concatenated tensors and leaves ``keys`` / ``values`` as they were.
    """

    # Static flags read by cache consumers; this layer is a plain,
    # always-ready, non-sliding cache entry.
    is_sliding = False
    is_compileable = False
    is_initialized = True

    def __init__(self, keys: Any, values: Any) -> None:
        """Store the prefix tensors by reference (no copy is made)."""
        self.keys, self.values = keys, values

    def update(
        self, key_states: Any, value_states: Any, *args: Any, **kwargs: Any
    ) -> tuple[Any, Any]:
        """Return the prefix followed by the new states, joined on dim ``-2``.

        Extra positional / keyword arguments are accepted and ignored. The
        stored prefix is not modified.
        """
        # Imported lazily so the module itself stays importable without torch.
        import torch

        joined_keys = torch.cat((self.keys, key_states), dim=-2)
        joined_values = torch.cat((self.values, value_states), dim=-2)
        return joined_keys, joined_values

    def get_seq_length(self) -> int:
        """Return the size of the stored keys along dim ``-2``."""
        prefix_len = self.keys.shape[-2]
        return int(prefix_len)

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return ``(prefix length + query length, 0)``."""
        kv_length = self.get_seq_length() + int(query_length)
        return kv_length, 0

    def get_max_cache_shape(self) -> int:
        """Return ``-1`` unconditionally (no maximum is reported)."""
        return -1
55
+
56
+
57
class ReadOnlyPrefixCache:
    """Cache-like view over immutable prefix KV tensors, avoiding a clone.

    Provides the small part of the HF DynamicCache protocol that
    ``PaliGemmaWithExpertModel.forward`` needs when ``use_cache=False``.

    :param layers: Tuple of ``(keys, values, sliding_window)`` triples as
        returned by ``encode_prefix``; ``sliding_window`` is used only to
        decide whether this optimisation is safe (requires all ``None``) and
        is discarded here.
    """

    offloading = False
    is_compileable = False
    is_initialized = True

    def __init__(self, layers: tuple[tuple[Any, Any, Any], ...]) -> None:
        wrapped = []
        for prefix_keys, prefix_values, _window in layers:
            wrapped.append(ReadOnlyPrefixCacheLayer(prefix_keys, prefix_values))
        self.layers = wrapped

    def update(
        self,
        key_states: Any,
        value_states: Any,
        layer_idx: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[Any, Any]:
        """Delegate to the ``layer_idx``-th layer's ``update`` and return its result."""
        target = self.layers[layer_idx]
        return target.update(key_states, value_states, *args, **kwargs)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Return the layer's prefix length, or ``0`` past the last layer."""
        if layer_idx < len(self.layers):
            return self.layers[layer_idx].get_seq_length()
        return 0

    def get_mask_sizes(self, query_length: int, layer_idx: int = 0) -> tuple[int, int]:
        """Return the layer's mask sizes, or ``(query_length, 0)`` past the last layer."""
        if layer_idx < len(self.layers):
            return self.layers[layer_idx].get_mask_sizes(query_length)
        return int(query_length), 0

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        """Return the layer's max cache shape, or ``-1`` past the last layer."""
        if layer_idx < len(self.layers):
            return self.layers[layer_idx].get_max_cache_shape()
        return -1

    @property
    def is_sliding(self) -> list[bool]:
        """One ``False`` per layer: no layer in this view is sliding."""
        return [False] * len(self.layers)

    def __len__(self) -> int:
        """Number of wrapped layers."""
        return len(self.layers)
@@ -0,0 +1,127 @@
1
+ """ACT adapter - single-shot Action Chunking Transformer.
2
+
3
+ ACT is the simple execution shape: one forward pass maps an observation to a
4
+ full action chunk ``[B, chunk_size, action_dim]``; there is no sampling loop and
5
+ no language prompt. This adapter declares a single ``"forward"`` region and
6
+ replicates exactly the image-list assembly that ``ACTAgent.predict_action_chunk``
7
+ does before calling the module.
8
+
9
+ At inference the ACT forward is a fixed-shape, purely device-side computation
10
+ (the VAE encoder is skipped in eval, so there is no host-side control flow or
11
+ RNG): a ResNet backbone per camera plus a transformer encoder/decoder. That
12
+ makes it an ideal single-graph capture target. The region therefore takes its
13
+ inputs as *positional CUDA tensors* (``state`` then one tensor per camera) -
14
+ the form a graph backend can record over static buffers - and reconstructs the
15
+ tiny ``{state, images}`` dict the module indexes internally. Capture only kicks
16
+ in for the simple ``state + images`` configuration; exotic feature sets
17
+ (environment-state or DoF features) keep the eager dict path.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from typing import Any, Dict, List
23
+
24
+ from sparkmind.data.constants import OBS_IMAGES, OBS_STATE
25
+
26
+ from sparkrt.core.adapter import Capabilities, ModelAdapter
27
+ from sparkrt.core.region import Region
28
+
29
+
30
class ACTAdapter(ModelAdapter):
    """Adapter for :class:`sparkmind...act_model.ACTModel` via an ``ACTAgent``.

    Declares a single ``"forward"`` region. When the model only consumes robot
    state and camera images the region takes positional tensors and is marked
    capturable; otherwise it takes the full batch dict and stays eager.

    :param agent: A constructed SparkMind2 ``ACTAgent`` (provides ``.model``,
        ``.cfg`` and ``.image_features``).
    """

    def __init__(self, agent: Any) -> None:
        super().__init__()
        self._agent = agent
        self._model = agent.model
        cfg = agent.cfg
        # image_features is a dict {obs_key: feature}; we only need the keys,
        # in order, to assemble the OBS_IMAGES list the module expects.
        self._image_keys = list(getattr(agent, "image_features", {}) or {})

        model = self._model
        self._has_state = getattr(model, "robot_state_feature", None) is not None
        has_env_state = getattr(model, "env_state_feature", None) is not None
        uses_dof = bool(getattr(model, "use_dof_features", False))
        # The positional/captured path covers the common ACT shape (robot state
        # + camera images). Anything that pulls extra batch keys into the module
        # forward stays on the eager dict path to preserve correctness.
        self._capturable = (
            bool(self._image_keys) and not has_env_state and not uses_dof
        )

        model_cfg = cfg.Model
        action_dim = int(model_cfg.output_features["action"].shape[0])
        # state_dim stays 0 unless the model has a robot-state feature AND the
        # config declares its shape.
        state_dim = 0
        if self._has_state:
            state_feat = model_cfg.input_features.get(OBS_STATE)
            if state_feat is not None:
                state_dim = int(state_feat.shape[0])
        # A missing coefficient means temporal ensembling is disabled.
        ensemble_coeff = getattr(cfg.Trainer, "temporal_ensemble_coeff", None)
        self._ensemble_coeff = ensemble_coeff
        self._caps = Capabilities(
            requires_prompt=False,
            is_iterative=False,
            num_inference_steps=1,
            chunk_size=int(model_cfg.chunk_size),
            n_action_steps=int(model_cfg.n_action_steps),
            action_dim=action_dim,
            supports_temporal_ensemble=ensemble_coeff is not None,
            camera_keys=tuple(self._image_keys),
            state_dim=state_dim,
        )

    @property
    def capabilities(self) -> Capabilities:
        """Static capabilities computed once in ``__init__``."""
        return self._caps

    def build_regions(self) -> Dict[str, Region]:
        """Return the ``{"forward": Region}`` mapping for this model.

        The region's callable signature depends on ``self._capturable``: a
        single batch dict on the eager path, positional tensors otherwise.
        """
        model = self._model

        if not self._capturable:
            # Eager fallback: the module indexes whatever keys it needs straight
            # out of the full batch dict (env-state / DoF configurations).
            def forward_dict(batch: Dict[str, Any]):
                return model(batch)[0]

            return {"forward": Region("forward", forward_dict, capturable=False)}

        has_state = self._has_state

        def forward(*tensors: Any):
            # Positional layout: (state, img0, img1, ...) when a robot-state
            # feature is present, else (img0, img1, ...). Rebuild the minimal
            # dict the module forward indexes; ACTModel returns
            # (actions, (mu, log_sigma)) and we want the actions.
            if has_state:
                batch = {OBS_STATE: tensors[0], OBS_IMAGES: list(tensors[1:])}
            else:
                batch = {OBS_IMAGES: list(tensors)}
            return model(batch)[0]

        # Fixed-shape backbone + transformer over static buffers -> capturable.
        return {"forward": Region("forward", forward, capturable=True)}

    def predict_chunk(self, ctx: Any, batch: Dict[str, Any], *, noise: Any = None):
        """Run the ``"forward"`` region on ``batch`` and return its output.

        ``ctx`` and ``noise`` are accepted for interface compatibility and are
        not used by this adapter. The caller's ``batch`` dict is not mutated.
        """
        if not self._capturable:
            # Assemble the multi-camera image list exactly as the agent does.
            batch = dict(batch)
            batch[OBS_IMAGES] = [batch[key] for key in self._image_keys]
            return self.region("forward")(batch)

        # Pass the exact tensors the module needs, positionally, so a graph
        # backend can capture/replay the region.
        images = [batch[key] for key in self._image_keys]
        args: List[Any] = [batch[OBS_STATE], *images] if self._has_state else images
        return self.region("forward")(*args)

    def make_ensembler(self) -> Any:
        """Build the temporal ensembler for this model.

        :raises NotImplementedError: if the config has no
            ``temporal_ensemble_coeff`` (ensembling disabled).
        """
        # Check the cheap precondition first so a disabled configuration fails
        # with the documented error without importing the model module.
        if self._ensemble_coeff is None:
            raise NotImplementedError("temporal ensemble not enabled for this model")

        from sparkmind.learning.IL.models.act_model import ACTTemporalEnsembler

        return ACTTemporalEnsembler(self._ensemble_coeff, self._caps.chunk_size)
@@ -0,0 +1,381 @@
1
+ """Pi0.5 adapter - VLA with a flow-matching denoising loop.
2
+
3
+ Pi0.5 is the latency-critical, *stateful-per-chunk* shape:
4
+
5
+ prepare images + language -> embed prefix (SigLIP + PaliGemma) -> KV cache
6
+ -> denoise loop (N steps over a 300M expert)
7
+ -> action chunk
8
+
9
+ The adapter declares regions that mirror exactly what
10
+ ``PI05Pytorch.sample_actions`` does, but split at the natural seams:
11
+
12
+ * ``encode_prefix`` (run once per chunk, **not** capturable): embeds the images
13
+ and language prompt and runs the PaliGemma prefill to build the KV cache. This
14
+ involves HF control flow / variable content and is not the hot path.
15
+ * ``embed_suffix`` (run ``num_inference_steps`` times, **not** capturable):
16
+ embeds the noisy actions + timestep into the suffix tokens. It is cheap (a few
17
+ small projections) and stays eager. Upstream builds the suffix attention-mask
18
+ constant via ``torch.tensor([...], device=cuda)`` - a host->device copy that is
19
+ illegal during graph capture - so here the masks are cached on the adapter context.
20
+ * ``denoise_step`` (run ``num_inference_steps`` times, **capturable**): the
21
+ expensive expert-transformer forward over the fixed-shape suffix against the
22
+ cached prefix, plus the action projection. This is the latency-critical inner
23
+ hot path and the CUDA-graph capture target. It mirrors the body of
24
+ ``PI05Pytorch.denoise_step`` *after* ``embed_suffix`` and reuses every
25
+ ``nn.Module`` (the expert, ``action_out_proj``); only the mask-assembly glue
26
+ is inlined, exactly as ``encode_prefix`` mirrors the prefill glue.
27
+
28
+ The KV cache is passed to ``denoise_step`` as a *flat tuple of tensors* (so the
29
+ graph backend can allocate one static buffer per tensor and replay across
30
+ chunks); the per-layer ``sliding_window`` metadata - constant for the model - is
31
+ held on the adapter and used to rebuild the ``DynamicCache`` inside the region.
32
+ The denoise loop body (``x_t += dt * v_t``) stays in Python in ``predict_chunk``,
33
+ identical to ``sample_actions``, which keeps numerical parity exact.
34
+
35
+ Image/language preprocessing intentionally stays *outside* the regions (it has
36
+ data-dependent shapes and control flow), so the captured hot path is a pure
37
+ tensor-in/tensor-out graph a native core could reimplement 1:1.
38
+ """
39
+
40
+ from __future__ import annotations
41
+
42
+ from dataclasses import dataclass, field
43
+ from typing import Any, Dict, List, Optional
44
+
45
+ from sparkrt.adapters._pi05_kv_cache import ReadOnlyPrefixCache
46
+ from sparkrt.config.runtime import Pi05RuntimeConfig
47
+ from sparkrt.core.adapter import Capabilities, ModelAdapter
48
+ from sparkrt.core.region import Region
49
+
50
+
51
@dataclass
class _Pi05Context:
    """Cache container handed out by ``Pi05Adapter.new_context``.

    Both mappings start empty (each instance gets its own dict) and are
    filled lazily by the adapter.
    """

    # (device, batch size, step count, schedule name) -> (timesteps, dts)
    schedules: dict[
        tuple[str, int, int, str], tuple[list[Any], list[float]]
    ] = field(default_factory=dict)
    # (device, batch size, chunk size, dtype) -> (pad masks, attention masks)
    suffix_masks: dict[
        tuple[str, int, int, str], tuple[Any, Any]
    ] = field(default_factory=dict)
59
+
60
+
61
class Pi05Adapter(ModelAdapter):
    """Adapter for ``PI05Pytorch`` via a SparkMind2 ``Pi05Agent``.

    :param agent: A constructed SparkMind2 ``Pi05Agent`` (provides ``.model``,
        ``.cfg`` and the ``prepare_images`` / ``_get_language_inputs`` helpers
        used to build model inputs).
    :param config: Optional runtime configuration. When ``None`` it is
        resolved with ``Pi05RuntimeConfig.from_env()``.
    :raises ValueError: if the resolved number of denoising steps is below 1.
    """

    def __init__(self, agent: Any, config: Optional[Pi05RuntimeConfig] = None) -> None:
        super().__init__()
        self._agent = agent
        self._model = agent._core_model()
        cfg = agent.cfg
        model_cfg = cfg.Model
        self._action_dim = int(model_cfg.output_features["action"].shape[0])
        self._chunk_size = int(model_cfg.chunk_size)
        self._max_action_dim = int(getattr(model_cfg, "max_action_dim", self._action_dim))

        # Camera keys (full ``observation.images.*`` keys, in order) and robot
        # state dim, surfaced through Capabilities for the SDK observation layer.
        input_features = model_cfg.input_features or {}
        self._image_keys = list(getattr(agent, "image_features", []) or [])
        from sparkmind.data.constants import OBS_STATE

        state_feat = input_features.get(OBS_STATE)
        self._state_dim = int(state_feat.shape[0]) if state_feat is not None else 0

        # Resolve runtime config: explicit > env vars > defaults.
        self._rtcfg = config if config is not None else Pi05RuntimeConfig.from_env()

        # Resolve num_steps: config value wins; None means read from checkpoint.
        if self._rtcfg.num_steps is not None:
            num_steps = int(self._rtcfg.num_steps)
        else:
            num_steps = int(getattr(model_cfg, "num_inference_steps", 10))
        # Fail fast: 0 would divide by zero when the schedule is built, and a
        # negative value would yield an empty schedule so predict_chunk would
        # return the raw noise unchanged.
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        self._num_steps = num_steps

        # Per-layer KV-cache sliding-window metadata (constant for the model);
        # learned lazily from the first prefix so denoise_step can rebuild the
        # DynamicCache from a flat tensor tuple.
        self._sliding_windows: Optional[List[Any]] = None
        self._caps = Capabilities(
            requires_prompt=True,
            is_iterative=True,
            num_inference_steps=self._num_steps,
            chunk_size=self._chunk_size,
            n_action_steps=int(model_cfg.n_action_steps),
            action_dim=self._action_dim,
            supports_temporal_ensemble=False,
            camera_keys=tuple(self._image_keys),
            state_dim=self._state_dim,
        )

    @property
    def capabilities(self) -> Capabilities:
        """Static capabilities computed once in ``__init__``."""
        return self._caps

    def new_context(self) -> _Pi05Context:
        """Return a fresh, empty cache context."""
        return _Pi05Context()

    def _get_schedule(
        self,
        ctx: _Pi05Context,
        device: Any,
        batch_size: int,
    ) -> tuple[list[Any], list[float]]:
        """Return ``(timesteps, dts)`` for the configured schedule, memoised on ``ctx``.

        Time nodes are ``((N - idx) / N) ** power`` for ``idx`` in ``0..N``,
        i.e. they run from 1 down to 0. ``timesteps`` holds one tensor per
        step, expanded to ``batch_size``; ``dts`` holds the (non-positive)
        differences between successive nodes as Python floats.
        """
        import torch

        schedule = self._rtcfg.schedule
        key = (str(device), int(batch_size), self._num_steps, schedule)
        cached = ctx.schedules.get(key)
        if cached is not None:
            return cached

        # "uniform" -> linear spacing, "power1.5" -> exponent 1.5; every other
        # schedule name falls back to exponent 2.0.
        if schedule == "uniform":
            power = 1.0
        elif schedule == "power1.5":
            power = 1.5
        else:
            power = 2.0

        t_nodes = [
            ((self._num_steps - idx) / self._num_steps) ** power
            for idx in range(self._num_steps + 1)
        ]
        timesteps = [
            torch.tensor(t_nodes[idx], dtype=torch.float32, device=device).expand(batch_size)
            for idx in range(self._num_steps)
        ]
        dts = [t_nodes[idx + 1] - t_nodes[idx] for idx in range(self._num_steps)]
        cached = (timesteps, dts)
        ctx.schedules[key] = cached
        return cached

    def _get_suffix_masks(
        self,
        ctx: _Pi05Context,
        device: Any,
        batch_size: int,
        dtype: Any,
    ) -> tuple[Any, Any]:
        """Return ``(suffix_pad_masks, suffix_att_masks)``, memoised on ``ctx``.

        The pad mask is all-ones boolean of shape ``(batch_size, chunk_size)``.
        The attention mask has the given ``dtype``, is 1 at position 0 and 0
        elsewhere, and is expanded (not copied) to the same shape.
        """
        import torch

        key = (str(device), int(batch_size), self._chunk_size, str(dtype))
        cached = ctx.suffix_masks.get(key)
        if cached is not None:
            return cached

        suffix_pad_masks = torch.ones(
            batch_size,
            self._chunk_size,
            dtype=torch.bool,
            device=device,
        )
        suffix_att_mask_1d = torch.zeros(self._chunk_size, dtype=dtype, device=device)
        suffix_att_mask_1d[0] = 1
        suffix_att_masks = suffix_att_mask_1d[None, :].expand(batch_size, self._chunk_size)
        cached = (suffix_pad_masks, suffix_att_masks)
        ctx.suffix_masks[key] = cached
        return cached

    def build_regions(self) -> Dict[str, Region]:
        """Return the ``encode_prefix`` / ``embed_suffix`` / ``denoise_step`` regions.

        Only ``denoise_step`` is marked capturable. The runtime config values
        (attention implementation, read-only prefix cache mode) are resolved
        once here and closed over by the region bodies.
        """
        import torch
        import torch.nn.functional as F

        from sparkmind.learning.VLA.models.pi05_model import (
            clone_past_key_values,
            create_sinusoidal_pos_embedding,
            make_att_2d_masks,
        )

        model = self._model
        chunk_size = self._chunk_size
        attn_impl = self._rtcfg.attn_impl
        use_readonly_prefix_cache = self._resolve_readonly_prefix_cache()

        def encode_prefix(images, img_masks, tokens, masks):
            # Mirrors the prefix section of PI05Pytorch.sample_actions exactly.
            prefix_embs, prefix_pad_masks, prefix_att_masks = model.embed_prefix(
                images, img_masks, tokens, masks
            )
            prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
            prefix_position_ids = prefix_pad_masks.cumsum(dim=1) - 1
            prefix_att_2d_masks_4d = model._prepare_attention_masks_4d(prefix_att_2d_masks)
            language_model = model.paligemma_with_expert.paligemma.model.language_model
            language_model.config._attn_implementation = attn_impl  # noqa: SLF001
            if attn_impl == "sdpa":
                # HF's SDPA path uses the 4D mask as an additive bias and
                # requires it to match the query dtype (bf16); the eager path
                # tolerated the float32 mask. Cast to the attention weight dtype.
                prefix_att_2d_masks_4d = prefix_att_2d_masks_4d.to(
                    language_model.layers[0].self_attn.q_proj.weight.dtype
                )
            _, past_key_values = model.paligemma_with_expert.forward(
                attention_mask=prefix_att_2d_masks_4d,
                position_ids=prefix_position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, None],
                use_cache=True,
            )
            return prefix_pad_masks, past_key_values

        def embed_suffix(x_t, timestep):
            # Mirrors PI05Pytorch.embed_suffix's numeric path, but leaves the
            # fixed suffix masks to the adapter context so they can be reused
            # across denoising steps and chunks.
            time_emb = create_sinusoidal_pos_embedding(
                timestep,
                model.action_in_proj.out_features,
                min_period=model.config.min_period,
                max_period=model.config.max_period,
                device=timestep.device,
            )
            time_emb = time_emb.type(dtype=timestep.dtype)
            action_emb = model.action_in_proj(x_t)
            time_emb = model.time_mlp_in(time_emb)
            time_emb = F.silu(time_emb)
            time_emb = model.time_mlp_out(time_emb)
            adarms_cond = F.silu(time_emb)
            return action_emb, adarms_cond

        def denoise_step(
            suffix_embs,
            adarms_cond,
            prefix_pad_masks,
            suffix_pad_masks,
            suffix_att_masks,
            *kv_tensors,
        ):
            # Mirrors PI05Pytorch.denoise_step *after* embed_suffix: assemble the
            # full attention mask + position ids, run the expert against the
            # cached prefix (cloned, or wrapped read-only), and project to a
            # velocity. All ops are device-side and fixed-shape, so this body is
            # graph-capturable.
            #
            # Signature ordering matters for the graph backend: the two inputs
            # that change every denoising step (``suffix_embs``, ``adarms_cond``)
            # come first; everything after them - the prefix/suffix masks and the
            # KV cache - is constant across the loop, so ``invariant_from=2`` lets
            # the backend skip re-copying the (large) KV tensors on each replay.
            sliding = self._sliding_windows
            # Refuse to run with missing or inconsistent metadata: otherwise the
            # expert would silently attend to a truncated (or empty) prefix.
            if sliding is None or len(kv_tensors) != 2 * len(sliding):
                raise RuntimeError(
                    "denoise_step: KV-cache tensors do not match the recorded "
                    "sliding-window metadata; run predict_chunk (encode_prefix) first"
                )
            # Re-pair the flat (k0, v0, k1, v1, ...) tuple with per-layer metadata.
            layers = tuple(
                (kv_tensors[2 * i], kv_tensors[2 * i + 1], window)
                for i, window in enumerate(sliding)
            )
            can_use_readonly_prefix_cache = use_readonly_prefix_cache and all(
                sliding_window is None for sliding_window in sliding
            )
            if can_use_readonly_prefix_cache:
                past_key_values = ReadOnlyPrefixCache(layers)
            else:
                past_key_values = clone_past_key_values(layers)

            suffix_len = suffix_pad_masks.shape[1]
            batch_size = prefix_pad_masks.shape[0]
            prefix_len = prefix_pad_masks.shape[1]
            prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
                batch_size, suffix_len, prefix_len
            )
            suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
            full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
            # Suffix positions continue after the number of valid prefix tokens.
            prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
            position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
            full_att_2d_masks_4d = model._prepare_attention_masks_4d(full_att_2d_masks)
            gemma_expert = model.paligemma_with_expert.gemma_expert.model
            gemma_expert.config._attn_implementation = attn_impl  # noqa: SLF001
            if attn_impl == "sdpa":
                # Same dtype requirement as in encode_prefix.
                full_att_2d_masks_4d = full_att_2d_masks_4d.to(
                    gemma_expert.layers[0].self_attn.q_proj.weight.dtype
                )
            outputs_embeds, _ = model.paligemma_with_expert.forward(
                attention_mask=full_att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=[None, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            suffix_out = outputs_embeds[1][:, -chunk_size:].to(dtype=torch.float32)
            return model.action_out_proj(suffix_out)

        return {
            # Prefill: HF control flow + variable content -> eager only.
            "encode_prefix": Region("encode_prefix", encode_prefix, capturable=False),
            # Suffix embedding: cheap eager region; fixed suffix masks are
            # cached separately in the adapter context.
            "embed_suffix": Region("embed_suffix", embed_suffix, capturable=False),
            # Expert forward + projection: fixed-shape hot path -> captured.
            # Inputs 2.. (masks + KV cache) are loop-invariant within a chunk.
            "denoise_step": Region(
                "denoise_step", denoise_step, capturable=True, invariant_from=2
            ),
        }

    def _resolve_readonly_prefix_cache(self) -> bool:
        """Translate the ``readonly_prefix_cache`` config string into a bool.

        Explicit on/off spellings win; any other value is treated as "auto",
        which enables the read-only cache only for a backend named
        ``"cudagraph"``.
        """
        mode = self._rtcfg.readonly_prefix_cache.lower()
        if mode in {"1", "true", "yes", "on"}:
            return True
        if mode in {"0", "false", "no", "off"}:
            return False
        # auto: use read-only cache for cudagraph (where it was validated),
        # but not for torchcompile (adds numeric drift when combined).
        backend_name = getattr(self._backend, "name", "")
        return backend_name == "cudagraph"

    def predict_chunk(self, ctx: Any, batch: Dict[str, Any], *, noise: Any = None):
        """Run prefill plus the denoising loop and return the action chunk.

        :param ctx: Cache context; anything that is not a ``_Pi05Context`` is
            replaced by a fresh one for this call.
        :param batch: Observation batch passed to the agent's input helpers.
        :param noise: Optional initial noise; sampled from the model when
            ``None``.
        :returns: The final ``x_t`` sliced to the first ``action_dim`` entries
            of the last axis.
        """
        import torch

        model = self._model
        # Build model inputs using the agent's own (parity-exact) helpers; this
        # is preprocessing and stays outside the region bodies.
        images, img_masks = self._agent.prepare_images(batch)
        batch_size = images[0].shape[0]
        device = images[0].device
        tokens, masks = self._agent._get_language_inputs(batch, batch_size, device)

        # Sample noise *before* the prefill, exactly as sample_actions does, so
        # seeded parity matches bit-for-bit (the prefill consumes no RNG).
        if noise is None:
            actions_shape = (batch_size, self._chunk_size, self._max_action_dim)
            noise = model.sample_noise(actions_shape, device)

        prefix_pad_masks, past_key_values = self.region("encode_prefix")(
            images, img_masks, tokens, masks
        )

        # Flatten the KV cache into positional tensors for the capturable region
        # and record the per-layer sliding windows that denoise_step reads back.
        kv_tensors: List[Any] = []
        sliding_windows: List[Any] = []
        for keys, values, sliding_window in past_key_values:
            kv_tensors.append(keys)
            kv_tensors.append(values)
            sliding_windows.append(sliding_window)
        self._sliding_windows = sliding_windows

        if not isinstance(ctx, _Pi05Context):
            ctx = self.new_context()

        embed_suffix = self.region("embed_suffix")
        denoise = self.region("denoise_step")
        x_t = noise
        timesteps, dts = self._get_schedule(ctx, device, batch_size)
        for timestep, dt in zip(timesteps, dts):
            suffix_embs, adarms_cond = embed_suffix(x_t, timestep)
            # Cached after the first step; keyed on the embedding dtype.
            suffix_pad_masks, suffix_att_masks = self._get_suffix_masks(
                ctx,
                device,
                batch_size,
                suffix_embs.dtype,
            )
            v_t = denoise(
                suffix_embs,
                adarms_cond,
                prefix_pad_masks,
                suffix_pad_masks,
                suffix_att_masks,
                *kv_tensors,
            )
            # dt is non-positive: time runs from 1 down to 0.
            x_t = x_t + dt * v_t

        return x_t[:, :, : self._action_dim]