SparkRT 0.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sparkrt/__init__.py +33 -0
- sparkrt/adapters/__init__.py +13 -0
- sparkrt/adapters/_pi05_kv_cache.py +107 -0
- sparkrt/adapters/act_adapter.py +127 -0
- sparkrt/adapters/pi05_adapter.py +381 -0
- sparkrt/api.py +149 -0
- sparkrt/backends/__init__.py +59 -0
- sparkrt/backends/cudagraph.py +220 -0
- sparkrt/backends/eager.py +38 -0
- sparkrt/backends/torchcompile.py +114 -0
- sparkrt/config/__init__.py +35 -0
- sparkrt/config/loader.py +117 -0
- sparkrt/config/presets/default.yaml +18 -0
- sparkrt/config/presets/latency.yaml +25 -0
- sparkrt/config/presets/memory.yaml +30 -0
- sparkrt/config/presets/quality.yaml +21 -0
- sparkrt/config/presets/safe.yaml +15 -0
- sparkrt/config/runtime.py +205 -0
- sparkrt/core/__init__.py +22 -0
- sparkrt/core/adapter.py +146 -0
- sparkrt/core/backend.py +50 -0
- sparkrt/core/region.py +46 -0
- sparkrt/core/shape.py +47 -0
- sparkrt/eval/__init__.py +18 -0
- sparkrt/eval/policy.py +125 -0
- sparkrt/io/__init__.py +12 -0
- sparkrt/io/checkpoint.py +319 -0
- sparkrt/observation.py +210 -0
- sparkrt/policy.py +136 -0
- sparkrt/processors/__init__.py +6 -0
- sparkrt/processors/base.py +30 -0
- sparkrt/processors/sparkmind.py +40 -0
- sparkrt/session/__init__.py +5 -0
- sparkrt/session/session.py +117 -0
- sparkrt-0.1.0rc1.dist-info/METADATA +334 -0
- sparkrt-0.1.0rc1.dist-info/RECORD +39 -0
- sparkrt-0.1.0rc1.dist-info/WHEEL +5 -0
- sparkrt-0.1.0rc1.dist-info/licenses/LICENSE +164 -0
- sparkrt-0.1.0rc1.dist-info/top_level.txt +1 -0
sparkrt/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""SparkRT - edge-side low-latency inference runtime for SparkMind2 models.
|
|
2
|
+
|
|
3
|
+
Public surface:
|
|
4
|
+
from sparkrt import from_sparkmind_agent, InferenceSession
|
|
5
|
+
|
|
6
|
+
The runtime is organised in four decoupled layers (see ``docs``/memory plan):
|
|
7
|
+
|
|
8
|
+
processors obs -> model-ready tensors -> action (normalize/tokenize)
|
|
9
|
+
adapters *what* a model computes (wraps SparkMind2 nn.Module as regions)
|
|
10
|
+
backends *how* it executes (eager, CUDA-graph, torch.compile; native C++ later)
|
|
11
|
+
session unified, stateful select_action() loop (queue + ensemble)
|
|
12
|
+
|
|
13
|
+
``sparkrt.core`` defines the contracts that tie these together and imports
|
|
14
|
+
without any SparkMind2 / heavy dependency, so the execution seam stays clean.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from sparkrt.api import from_sparkmind_agent, load_policy
|
|
18
|
+
from sparkrt.config import BackendConfig, Pi05RuntimeConfig, RuntimeConfig
|
|
19
|
+
from sparkrt.observation import Observation
|
|
20
|
+
from sparkrt.policy import Policy
|
|
21
|
+
from sparkrt.session import InferenceSession
|
|
22
|
+
|
|
23
|
+
# Public names re-exported at package level (``from sparkrt import *``).
__all__ = [
    # Constructors imported from sparkrt.api.
    "from_sparkmind_agent",
    "load_policy",
    # Runtime objects: policy wrapper, observation container, session loop.
    "Policy",
    "Observation",
    "InferenceSession",
    # Configuration types imported from sparkrt.config.
    "RuntimeConfig",
    "BackendConfig",
    "Pi05RuntimeConfig",
]
# Package version string; keep in sync with the wheel metadata.
__version__ = "0.1.0rc1"
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Model adapters: wrap SparkMind2 nn.Modules as runtime regions."""
|
|
2
|
+
|
|
3
|
+
from sparkrt.adapters.act_adapter import ACTAdapter
|
|
4
|
+
from sparkrt.adapters.pi05_adapter import Pi05Adapter
|
|
5
|
+
|
|
6
|
+
__all__ = ["ACTAdapter", "Pi05Adapter"]

#: Lookup from ``cfg.Model.type`` to the adapter class that wraps that model.
#: Registering a new model only needs an entry here plus a matching
#: processor; backends and the session need no changes.
ADAPTER_REGISTRY = dict(
    act=ACTAdapter,
    pi05=Pi05Adapter,
)
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Read-only prefix KV-cache shim for Pi0.5 denoise_step.
|
|
2
|
+
|
|
3
|
+
``ReadOnlyPrefixCache`` is a lightweight HF-cache-protocol adapter that wraps
|
|
4
|
+
the frozen prefix KV tensors (computed once by ``encode_prefix``) and avoids
|
|
5
|
+
cloning them on every denoising step. It implements the minimal subset of the
|
|
6
|
+
HF ``DynamicCache`` interface used by ``PaliGemmaWithExpertModel.forward`` when
|
|
7
|
+
``use_cache=False``, i.e. cross-attending to a prefix without updating it.
|
|
8
|
+
|
|
9
|
+
When the backend is ``cudagraph`` and no sliding-window layers are present the
|
|
10
|
+
regular ``clone_past_key_values`` copy is replaced by this object. The saving
|
|
11
|
+
is ~2–3 ms/chunk on A800 (18 layers × 2 tensors cloned × 10 steps avoided).
|
|
12
|
+
|
|
13
|
+
For ``torchcompile`` the shim is disabled by default because it introduces a
|
|
14
|
+
small extra numeric drift when combined with Inductor fusion; the ``off`` or
|
|
15
|
+
``auto`` value of the ``readonly_prefix_cache`` field in :class:`~sparkrt.config.Pi05RuntimeConfig` keeps it disabled there (``on`` enables it regardless of backend).
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Public names of this module: the prefix-cache shim and its per-layer view.
__all__ = ["ReadOnlyPrefixCache", "ReadOnlyPrefixCacheLayer"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ReadOnlyPrefixCacheLayer:
    """Frozen ``(keys, values)`` pair exposed through the HF cache-layer protocol.

    The stored tensors are never written: :meth:`update` returns the prefix
    concatenated with the incoming states along dim ``-2`` and leaves
    ``self.keys`` / ``self.values`` untouched.
    """

    is_sliding = False
    is_compileable = False
    is_initialized = True

    def __init__(self, keys: Any, values: Any) -> None:
        self.keys, self.values = keys, values

    def update(
        self, key_states: Any, value_states: Any, *args: Any, **kwargs: Any
    ) -> tuple[Any, Any]:
        """Return ``(prefix_keys ++ key_states, prefix_values ++ value_states)``.

        Extra positional/keyword arguments (cache kwargs from the caller) are
        accepted and ignored.
        """
        import torch

        seq_dim = -2
        merged_keys = torch.cat((self.keys, key_states), dim=seq_dim)
        merged_values = torch.cat((self.values, value_states), dim=seq_dim)
        return merged_keys, merged_values

    def get_seq_length(self) -> int:
        """Length of the frozen prefix along dim ``-2``."""
        return int(self.keys.shape[-2])

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return ``(kv_length, kv_offset)`` for a query of ``query_length`` tokens."""
        kv_length = self.get_seq_length() + int(query_length)
        kv_offset = 0
        return kv_length, kv_offset

    def get_max_cache_shape(self) -> int:
        """Return ``-1``: this layer reports no fixed maximum length."""
        return -1
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ReadOnlyPrefixCache:
    """Cache-like view over frozen prefix KV tensors that never clones them.

    Provides the subset of the HF ``DynamicCache`` protocol that
    ``PaliGemmaWithExpertModel.forward`` touches when ``use_cache=False``.
    Per-layer queries for a layer index past the stored layers fall back to an
    empty-cache answer instead of raising.

    :param layers: Tuple of ``(keys, values, sliding_window)`` triples as
        returned by ``encode_prefix``. ``sliding_window`` is only used by the
        caller to decide whether this shim is safe (it requires all ``None``)
        and is discarded here.
    """

    offloading = False
    is_compileable = False
    is_initialized = True

    def __init__(self, layers: tuple[tuple[Any, Any, Any], ...]) -> None:
        self.layers = []
        for keys, values, _sliding_window in layers:
            self.layers.append(ReadOnlyPrefixCacheLayer(keys, values))

    def update(
        self,
        key_states: Any,
        value_states: Any,
        layer_idx: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[Any, Any]:
        """Delegate to the layer at ``layer_idx`` (which does not mutate itself)."""
        target = self.layers[layer_idx]
        return target.update(key_states, value_states, *args, **kwargs)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Prefix length of ``layer_idx``, or ``0`` when that layer is absent."""
        if layer_idx < len(self.layers):
            return self.layers[layer_idx].get_seq_length()
        return 0

    def get_mask_sizes(self, query_length: int, layer_idx: int = 0) -> tuple[int, int]:
        """``(kv_length, kv_offset)`` for ``layer_idx``; query-only when absent."""
        if layer_idx < len(self.layers):
            return self.layers[layer_idx].get_mask_sizes(query_length)
        return int(query_length), 0

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        """Always reports ``-1`` (no fixed maximum) via the layer or the fallback."""
        if layer_idx < len(self.layers):
            return self.layers[layer_idx].get_max_cache_shape()
        return -1

    @property
    def is_sliding(self) -> list[bool]:
        # No layer in this shim is a sliding-window layer.
        return [False] * len(self.layers)

    def __len__(self) -> int:
        return len(self.layers)
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""ACT adapter - single-shot Action Chunking Transformer.
|
|
2
|
+
|
|
3
|
+
ACT is the simple execution shape: one forward pass maps an observation to a
|
|
4
|
+
full action chunk ``[B, chunk_size, action_dim]``; there is no sampling loop and
|
|
5
|
+
no language prompt. This adapter declares a single ``"forward"`` region and
|
|
6
|
+
replicates exactly the image-list assembly that ``ACTAgent.predict_action_chunk``
|
|
7
|
+
does before calling the module.
|
|
8
|
+
|
|
9
|
+
At inference the ACT forward is a fixed-shape, purely device-side computation
|
|
10
|
+
(the VAE encoder is skipped in eval, so there is no host-side control flow or
|
|
11
|
+
RNG): a ResNet backbone per camera plus a transformer encoder/decoder. That
|
|
12
|
+
makes it an ideal single-graph capture target. The region therefore takes its
|
|
13
|
+
inputs as *positional CUDA tensors* (``state`` then one tensor per camera) -
|
|
14
|
+
the form a graph backend can record over static buffers - and reconstructs the
|
|
15
|
+
tiny ``{state, images}`` dict the module indexes internally. Capture only kicks
|
|
16
|
+
in for the simple ``state + images`` configuration; exotic feature sets
|
|
17
|
+
(environment-state or DoF features) keep the eager dict path.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from typing import Any, Dict, List
|
|
23
|
+
|
|
24
|
+
from sparkmind.data.constants import OBS_IMAGES, OBS_STATE
|
|
25
|
+
|
|
26
|
+
from sparkrt.core.adapter import Capabilities, ModelAdapter
|
|
27
|
+
from sparkrt.core.region import Region
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ACTAdapter(ModelAdapter):
    """Adapter for :class:`sparkmind...act_model.ACTModel` via an ``ACTAgent``.

    Declares a single ``"forward"`` region. When the model consumes only robot
    state plus camera images, the region takes positional tensors (``state``
    first when present, then one tensor per camera) so a graph backend can
    capture it. Otherwise it takes the whole batch dict and is marked
    non-capturable.

    :param agent: A constructed SparkMind2 ``ACTAgent`` (provides ``.model``,
        ``.cfg`` and ``.image_features``).
    """

    def __init__(self, agent: Any) -> None:
        super().__init__()
        self._agent = agent
        self._model = agent.model
        cfg = agent.cfg
        # image_features is a dict {obs_key: feature}; only the keys, in
        # order, are needed to assemble the OBS_IMAGES list the module expects.
        self._image_keys = list(getattr(agent, "image_features", {}) or {})

        model = self._model
        self._has_state = getattr(model, "robot_state_feature", None) is not None
        has_env_state = getattr(model, "env_state_feature", None) is not None
        uses_dof = bool(getattr(model, "use_dof_features", False))
        # The positional/captured path covers the common ACT shape (robot state
        # + camera images). Anything that pulls extra batch keys into the module
        # forward (or has no cameras at all) stays on the eager dict path.
        self._capturable = bool(self._image_keys) and not has_env_state and not uses_dof

        model_cfg = cfg.Model
        action_dim = int(model_cfg.output_features["action"].shape[0])
        state_dim = 0
        if self._has_state:
            state_feat = model_cfg.input_features.get(OBS_STATE)
            if state_feat is not None:
                state_dim = int(state_feat.shape[0])
        # None disables temporal ensembling (see make_ensembler).
        ensemble_coeff = getattr(cfg.Trainer, "temporal_ensemble_coeff", None)
        self._ensemble_coeff = ensemble_coeff
        self._caps = Capabilities(
            requires_prompt=False,
            is_iterative=False,
            num_inference_steps=1,
            chunk_size=int(model_cfg.chunk_size),
            n_action_steps=int(model_cfg.n_action_steps),
            action_dim=action_dim,
            supports_temporal_ensemble=ensemble_coeff is not None,
            camera_keys=tuple(self._image_keys),
            state_dim=state_dim,
        )

    @property
    def capabilities(self) -> Capabilities:
        return self._caps

    def build_regions(self) -> Dict[str, Region]:
        """Return the single ``"forward"`` region (dict-based or positional)."""
        model = self._model

        if not self._capturable:
            # Eager fallback: the module indexes whatever keys it needs straight
            # out of the full batch dict (env-state / DoF / camera-less configs).
            def forward_dict(batch: Dict[str, Any]):
                return model(batch)[0]

            return {"forward": Region("forward", forward_dict, capturable=False)}

        has_state = self._has_state

        def forward(*tensors: Any):
            # Positional layout: (state, img0, img1, ...) when a robot-state
            # feature is present, else (img0, img1, ...). Rebuild the minimal
            # dict the module forward indexes; the module returns a tuple whose
            # first element is the action chunk.
            if has_state:
                batch = {OBS_STATE: tensors[0], OBS_IMAGES: list(tensors[1:])}
            else:
                batch = {OBS_IMAGES: list(tensors)}
            return model(batch)[0]

        # Fixed-shape backbone + transformer over static buffers -> capturable.
        return {"forward": Region("forward", forward, capturable=True)}

    def predict_chunk(self, ctx: Any, batch: Dict[str, Any], *, noise: Any = None):
        """Run the ``"forward"`` region on ``batch`` and return the action chunk.

        :param ctx: Unused by ACT (kept for the adapter interface).
        :param batch: Preprocessed observation dict keyed by feature name.
        :param noise: Ignored; ACT has no sampling loop.
        """
        if not self._capturable:
            if self._image_keys:
                # Assemble the multi-camera image list exactly as the agent
                # does: only when image features exist. A camera-less config
                # must not receive an empty OBS_IMAGES entry, since a module
                # that checks ``OBS_IMAGES in batch`` would then index into an
                # empty list.
                batch = dict(batch)
                batch[OBS_IMAGES] = [batch[key] for key in self._image_keys]
            return self.region("forward")(batch)

        # Pass the exact tensors the module needs, positionally, so a graph
        # backend can capture/replay the region.
        images = [batch[key] for key in self._image_keys]
        args: List[Any] = [batch[OBS_STATE], *images] if self._has_state else images
        return self.region("forward")(*args)

    def make_ensembler(self) -> Any:
        """Build an ``ACTTemporalEnsembler`` for this model's chunk size.

        :raises NotImplementedError: If ``cfg.Trainer.temporal_ensemble_coeff``
            is unset.
        """
        # Check first so a problem importing the ensembler cannot mask the
        # "not enabled" error.
        if self._ensemble_coeff is None:
            raise NotImplementedError("temporal ensemble not enabled for this model")
        from sparkmind.learning.IL.models.act_model import ACTTemporalEnsembler

        return ACTTemporalEnsembler(self._ensemble_coeff, self._caps.chunk_size)
|
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
"""Pi0.5 adapter - VLA with a flow-matching denoising loop.
|
|
2
|
+
|
|
3
|
+
Pi0.5 is the latency-critical, *stateful-per-chunk* shape:
|
|
4
|
+
|
|
5
|
+
prepare images + language -> embed prefix (SigLIP + PaliGemma) -> KV cache
|
|
6
|
+
-> denoise loop (N steps over a 300M expert)
|
|
7
|
+
-> action chunk
|
|
8
|
+
|
|
9
|
+
The adapter declares regions that mirror exactly what
|
|
10
|
+
``PI05Pytorch.sample_actions`` does, but split at the natural seams:
|
|
11
|
+
|
|
12
|
+
* ``encode_prefix`` (run once per chunk, **not** capturable): embeds the images
|
|
13
|
+
and language prompt and runs the PaliGemma prefill to build the KV cache. This
|
|
14
|
+
involves HF control flow / variable content and is not the hot path.
|
|
15
|
+
* ``embed_suffix`` (run ``num_inference_steps`` times, **not** capturable):
|
|
16
|
+
embeds the noisy actions + timestep into the suffix tokens. It is cheap (a few
|
|
17
|
+
small projections) but builds its attention-mask constant via
|
|
18
|
+
``torch.tensor([...], device=cuda)`` - a host->device copy that is illegal
|
|
19
|
+
during graph capture - so it stays eager.
|
|
20
|
+
* ``denoise_step`` (run ``num_inference_steps`` times, **capturable**): the
|
|
21
|
+
expensive expert-transformer forward over the fixed-shape suffix against the
|
|
22
|
+
cached prefix, plus the action projection. This is the latency-critical inner
|
|
23
|
+
hot path and the CUDA-graph capture target. It mirrors the body of
|
|
24
|
+
``PI05Pytorch.denoise_step`` *after* ``embed_suffix`` and reuses every
|
|
25
|
+
``nn.Module`` (the expert, ``action_out_proj``); only the mask-assembly glue
|
|
26
|
+
is inlined, exactly as ``encode_prefix`` mirrors the prefill glue.
|
|
27
|
+
|
|
28
|
+
The KV cache is passed to ``denoise_step`` as a *flat tuple of tensors* (so the
|
|
29
|
+
graph backend can allocate one static buffer per tensor and replay across
|
|
30
|
+
chunks); the per-layer ``sliding_window`` metadata - constant for the model - is
|
|
31
|
+
held on the adapter and used to rebuild the prefix cache (a cloned ``DynamicCache`` or a ``ReadOnlyPrefixCache`` view) inside the region.
|
|
32
|
+
The denoise loop body (``x_t += dt * v_t``) stays in Python in ``predict_chunk``,
|
|
33
|
+
identical to ``sample_actions``, which keeps numerical parity exact.
|
|
34
|
+
|
|
35
|
+
Image/language preprocessing intentionally stays *outside* the regions (it has
|
|
36
|
+
data-dependent shapes and control flow), so the captured hot path is a pure
|
|
37
|
+
tensor-in/tensor-out graph a native core could reimplement 1:1.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
from __future__ import annotations
|
|
41
|
+
|
|
42
|
+
from dataclasses import dataclass, field
|
|
43
|
+
from typing import Any, Dict, List, Optional
|
|
44
|
+
|
|
45
|
+
from sparkrt.adapters._pi05_kv_cache import ReadOnlyPrefixCache
|
|
46
|
+
from sparkrt.config.runtime import Pi05RuntimeConfig
|
|
47
|
+
from sparkrt.core.adapter import Capabilities, ModelAdapter
|
|
48
|
+
from sparkrt.core.region import Region
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
|
|
52
|
+
class _Pi05Context:
|
|
53
|
+
schedules: dict[tuple[str, int, int, str], tuple[list[Any], list[float]]] = field(
|
|
54
|
+
default_factory=dict
|
|
55
|
+
)
|
|
56
|
+
suffix_masks: dict[tuple[str, int, int, str], tuple[Any, Any]] = field(
|
|
57
|
+
default_factory=dict
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Pi05Adapter(ModelAdapter):
    """Adapter for ``PI05Pytorch`` via a SparkMind2 ``Pi05Agent``.

    :param agent: A constructed SparkMind2 ``Pi05Agent`` (provides
        ``_core_model()``, ``.cfg``, ``.image_features`` and the
        ``prepare_images`` / ``_get_language_inputs`` helpers used to build
        model inputs).
    :param config: Runtime knobs (``num_steps``, ``schedule``, ``attn_impl``,
        ``readonly_prefix_cache``). When ``None`` they are read with
        :meth:`Pi05RuntimeConfig.from_env`.
    :raises ValueError: If the resolved number of denoising steps is < 1.
    """

    def __init__(self, agent: Any, config: Optional[Pi05RuntimeConfig] = None) -> None:
        super().__init__()
        self._agent = agent
        self._model = agent._core_model()
        cfg = agent.cfg
        model_cfg = cfg.Model
        self._action_dim = int(model_cfg.output_features["action"].shape[0])
        self._chunk_size = int(model_cfg.chunk_size)
        # Width the model denoises at: noise is sampled at this width and
        # predict_chunk slices the result back down to ``action_dim``.
        self._max_action_dim = int(getattr(model_cfg, "max_action_dim", self._action_dim))

        # Camera keys (full ``observation.images.*`` keys, in order) and robot
        # state dim, surfaced through Capabilities for the SDK observation layer.
        input_features = model_cfg.input_features or {}
        self._image_keys = list(getattr(agent, "image_features", []) or [])
        from sparkmind.data.constants import OBS_STATE

        state_feat = input_features.get(OBS_STATE)
        self._state_dim = int(state_feat.shape[0]) if state_feat is not None else 0

        # Resolve runtime config: explicit > env vars > defaults.
        self._rtcfg = config if config is not None else Pi05RuntimeConfig.from_env()

        # Resolve num_steps: config value wins; None means read from checkpoint.
        # Both branches are coerced to int so schedule cache keys are uniform.
        if self._rtcfg.num_steps is not None:
            num_steps = int(self._rtcfg.num_steps)
        else:
            num_steps = int(getattr(model_cfg, "num_inference_steps", 10))
        if num_steps < 1:
            # _get_schedule divides by num_steps: zero would raise
            # ZeroDivisionError on the first chunk, and a negative value would
            # silently skip denoising and return raw noise.
            raise ValueError(f"Pi0.5 requires num_steps >= 1, got {num_steps}")
        self._num_steps = num_steps

        # Per-layer KV-cache sliding-window metadata (constant for the model);
        # learned from the prefix in predict_chunk so denoise_step can rebuild
        # the prefix cache from a flat tensor tuple.
        self._sliding_windows: Optional[List[Any]] = None
        self._caps = Capabilities(
            requires_prompt=True,
            is_iterative=True,
            num_inference_steps=self._num_steps,
            chunk_size=self._chunk_size,
            n_action_steps=int(model_cfg.n_action_steps),
            action_dim=self._action_dim,
            supports_temporal_ensemble=False,
            camera_keys=tuple(self._image_keys),
            state_dim=self._state_dim,
        )

    @property
    def capabilities(self) -> Capabilities:
        return self._caps

    def new_context(self) -> _Pi05Context:
        return _Pi05Context()

    def _get_schedule(
        self,
        ctx: _Pi05Context,
        device: Any,
        batch_size: int,
    ) -> tuple[list[Any], list[float]]:
        """Return the memoised ``(timesteps, dts)`` for the denoising loop.

        Time runs from 1 to 0 over ``num_steps`` steps on the nodes
        ``t_i = ((N - i) / N) ** power``. ``power`` is 1.0 for ``"uniform"``,
        1.5 for ``"power1.5"``, and 2.0 for any other schedule name.
        ``timesteps[i]`` is ``t_i`` as a float32 tensor expanded to
        ``batch_size``; ``dts[i] = t_{i+1} - t_i`` (negative Python floats).
        """
        import torch

        schedule = self._rtcfg.schedule
        key = (str(device), int(batch_size), self._num_steps, schedule)
        cached = ctx.schedules.get(key)
        if cached is not None:
            return cached

        if schedule == "uniform":
            power = 1.0
        elif schedule == "power1.5":
            power = 1.5
        else:
            power = 2.0

        t_nodes = [
            ((self._num_steps - idx) / self._num_steps) ** power
            for idx in range(self._num_steps + 1)
        ]
        timesteps = [
            torch.tensor(t_nodes[idx], dtype=torch.float32, device=device).expand(batch_size)
            for idx in range(self._num_steps)
        ]
        dts = [t_nodes[idx + 1] - t_nodes[idx] for idx in range(self._num_steps)]
        cached = (timesteps, dts)
        ctx.schedules[key] = cached
        return cached

    def _get_suffix_masks(
        self,
        ctx: _Pi05Context,
        device: Any,
        batch_size: int,
        dtype: Any,
    ) -> tuple[Any, Any]:
        """Return memoised ``(suffix_pad_masks, suffix_att_masks)``.

        The pad mask is all-True ``[batch_size, chunk_size]`` (bool); the
        attention mask is ``[1, 0, ..., 0]`` in ``dtype`` expanded to
        ``[batch_size, chunk_size]``.
        """
        import torch

        key = (str(device), int(batch_size), self._chunk_size, str(dtype))
        cached = ctx.suffix_masks.get(key)
        if cached is not None:
            return cached

        suffix_pad_masks = torch.ones(
            batch_size,
            self._chunk_size,
            dtype=torch.bool,
            device=device,
        )
        suffix_att_mask_1d = torch.zeros(self._chunk_size, dtype=dtype, device=device)
        suffix_att_mask_1d[0] = 1
        suffix_att_masks = suffix_att_mask_1d[None, :].expand(batch_size, self._chunk_size)
        cached = (suffix_pad_masks, suffix_att_masks)
        ctx.suffix_masks[key] = cached
        return cached

    def build_regions(self) -> Dict[str, Region]:
        """Build the ``encode_prefix``, ``embed_suffix`` and ``denoise_step`` regions."""
        import torch
        import torch.nn.functional as F

        from sparkmind.learning.VLA.models.pi05_model import (
            clone_past_key_values,
            create_sinusoidal_pos_embedding,
            make_att_2d_masks,
        )

        model = self._model
        chunk_size = self._chunk_size
        attn_impl = self._rtcfg.attn_impl
        # Decided once here and baked into the denoise_step closure.
        use_readonly_prefix_cache = self._resolve_readonly_prefix_cache()

        def encode_prefix(images, img_masks, tokens, masks):
            # Mirrors the prefix section of PI05Pytorch.sample_actions exactly.
            prefix_embs, prefix_pad_masks, prefix_att_masks = model.embed_prefix(
                images, img_masks, tokens, masks
            )
            prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
            prefix_position_ids = prefix_pad_masks.cumsum(dim=1) - 1
            prefix_att_2d_masks_4d = model._prepare_attention_masks_4d(prefix_att_2d_masks)
            language_model = model.paligemma_with_expert.paligemma.model.language_model
            language_model.config._attn_implementation = attn_impl  # noqa: SLF001
            if attn_impl == "sdpa":
                # HF's SDPA path uses the 4D mask as an additive bias and
                # requires it to match the query dtype (bf16); the eager path
                # tolerated the float32 mask. Cast to the attention weight dtype.
                prefix_att_2d_masks_4d = prefix_att_2d_masks_4d.to(
                    language_model.layers[0].self_attn.q_proj.weight.dtype
                )
            _, past_key_values = model.paligemma_with_expert.forward(
                attention_mask=prefix_att_2d_masks_4d,
                position_ids=prefix_position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, None],
                use_cache=True,
            )
            return prefix_pad_masks, past_key_values

        def embed_suffix(x_t, timestep):
            # Mirrors PI05Pytorch.embed_suffix's numeric path, but leaves the
            # fixed suffix masks to the adapter context so they can be reused
            # across denoising steps and chunks.
            time_emb = create_sinusoidal_pos_embedding(
                timestep,
                model.action_in_proj.out_features,
                min_period=model.config.min_period,
                max_period=model.config.max_period,
                device=timestep.device,
            )
            time_emb = time_emb.type(dtype=timestep.dtype)
            action_emb = model.action_in_proj(x_t)
            time_emb = model.time_mlp_in(time_emb)
            time_emb = F.silu(time_emb)
            time_emb = model.time_mlp_out(time_emb)
            adarms_cond = F.silu(time_emb)
            return action_emb, adarms_cond

        def denoise_step(
            suffix_embs,
            adarms_cond,
            prefix_pad_masks,
            suffix_pad_masks,
            suffix_att_masks,
            *kv_tensors,
        ):
            # Mirrors PI05Pytorch.denoise_step *after* embed_suffix: assemble the
            # full attention mask + position ids, run the expert against the
            # cached prefix (a clone, or a read-only view when enabled), and
            # project to a velocity. All ops are device-side and fixed-shape,
            # so this body is graph-capturable.
            #
            # Signature ordering matters for the graph backend: the two inputs
            # that change every denoising step (``suffix_embs``, ``adarms_cond``)
            # come first; everything after them - the prefix/suffix masks and the
            # KV cache - is constant across the loop, so ``invariant_from=2`` lets
            # the backend skip re-copying the (large) KV tensors on each replay.
            sliding = self._sliding_windows or []
            layers = tuple(
                (kv_tensors[2 * i], kv_tensors[2 * i + 1], sliding[i])
                for i in range(len(sliding))
            )
            # The read-only view ignores sliding windows, so it is only used
            # when no layer has one.
            can_use_readonly_prefix_cache = use_readonly_prefix_cache and all(
                sliding_window is None for sliding_window in sliding
            )
            if can_use_readonly_prefix_cache:
                past_key_values = ReadOnlyPrefixCache(layers)
            else:
                past_key_values = clone_past_key_values(layers)

            suffix_len = suffix_pad_masks.shape[1]
            batch_size = prefix_pad_masks.shape[0]
            prefix_len = prefix_pad_masks.shape[1]
            prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
                batch_size, suffix_len, prefix_len
            )
            suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
            full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
            prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
            position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
            full_att_2d_masks_4d = model._prepare_attention_masks_4d(full_att_2d_masks)
            gemma_expert = model.paligemma_with_expert.gemma_expert.model
            gemma_expert.config._attn_implementation = attn_impl  # noqa: SLF001
            if attn_impl == "sdpa":
                full_att_2d_masks_4d = full_att_2d_masks_4d.to(
                    gemma_expert.layers[0].self_attn.q_proj.weight.dtype
                )
            outputs_embeds, _ = model.paligemma_with_expert.forward(
                attention_mask=full_att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=[None, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            suffix_out = outputs_embeds[1][:, -chunk_size:].to(dtype=torch.float32)
            return model.action_out_proj(suffix_out)

        return {
            # Prefill: HF control flow + variable content -> eager only.
            "encode_prefix": Region("encode_prefix", encode_prefix, capturable=False),
            # Suffix embedding: cheap eager region; fixed suffix masks are
            # cached separately in the adapter context.
            "embed_suffix": Region("embed_suffix", embed_suffix, capturable=False),
            # Expert forward + projection: fixed-shape hot path -> captured.
            # Inputs 2.. (masks + KV cache) are loop-invariant within a chunk.
            "denoise_step": Region(
                "denoise_step", denoise_step, capturable=True, invariant_from=2
            ),
        }

    def _resolve_readonly_prefix_cache(self) -> bool:
        """Decide whether ``denoise_step`` may use :class:`ReadOnlyPrefixCache`.

        ``readonly_prefix_cache`` accepts ``1/true/yes/on`` and
        ``0/false/no/off`` (case-insensitive). Any other value (e.g.
        ``"auto"``) enables the shim only when the adapter's backend is named
        ``"cudagraph"``. The value is passed through ``str()`` first, so a
        boolean (YAML 1.1 parses bare ``on``/``off`` as booleans) or ``None``
        is handled instead of raising ``AttributeError``.
        """
        mode = str(self._rtcfg.readonly_prefix_cache).strip().lower()
        if mode in {"1", "true", "yes", "on"}:
            return True
        if mode in {"0", "false", "no", "off"}:
            return False
        # auto: use read-only cache for cudagraph (where it was validated),
        # but not for torchcompile (adds numeric drift when combined).
        backend_name = getattr(self._backend, "name", "")
        return backend_name == "cudagraph"

    def predict_chunk(self, ctx: Any, batch: Dict[str, Any], *, noise: Any = None):
        """Run prefill + the denoising loop and return ``[B, chunk, action_dim]``.

        :param ctx: A context from :meth:`new_context`; anything else is
            replaced by a fresh (empty) context for this call.
        :param batch: Observation dict consumed by the agent's input helpers.
        :param noise: Optional initial ``x_t``; when ``None`` it is drawn with
            ``model.sample_noise`` at ``(B, chunk_size, max_action_dim)``.
        """
        import torch

        model = self._model
        # Build model inputs using the agent's own (parity-exact) helpers; this
        # is preprocessing and stays outside the region bodies.
        images, img_masks = self._agent.prepare_images(batch)
        batch_size = images[0].shape[0]
        device = images[0].device
        tokens, masks = self._agent._get_language_inputs(batch, batch_size, device)

        # Sample noise *before* the prefill, exactly as sample_actions does, so
        # seeded parity matches bit-for-bit (the prefill consumes no RNG).
        if noise is None:
            actions_shape = (batch_size, self._chunk_size, self._max_action_dim)
            noise = model.sample_noise(actions_shape, device)

        prefix_pad_masks, past_key_values = self.region("encode_prefix")(
            images, img_masks, tokens, masks
        )

        # Flatten the KV cache into positional tensors for the capturable region;
        # record the (constant) per-layer sliding windows for denoise_step.
        kv_tensors: List[Any] = []
        sliding_windows: List[Any] = []
        for keys, values, sliding_window in past_key_values:
            kv_tensors.append(keys)
            kv_tensors.append(values)
            sliding_windows.append(sliding_window)
        self._sliding_windows = sliding_windows

        if not isinstance(ctx, _Pi05Context):
            ctx = self.new_context()

        embed_suffix = self.region("embed_suffix")
        denoise = self.region("denoise_step")
        x_t = noise
        timesteps, dts = self._get_schedule(ctx, device, batch_size)
        for timestep, dt in zip(timesteps, dts):
            suffix_embs, adarms_cond = embed_suffix(x_t, timestep)
            suffix_pad_masks, suffix_att_masks = self._get_suffix_masks(
                ctx,
                device,
                batch_size,
                suffix_embs.dtype,
            )
            v_t = denoise(
                suffix_embs,
                adarms_cond,
                prefix_pad_masks,
                suffix_pad_masks,
                suffix_att_masks,
                *kv_tensors,
            )
            # Euler step; dt is negative because time runs from 1 to 0.
            x_t = x_t + dt * v_t

        return x_t[:, :, : self._action_dim]
|