migma 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
migma-0.1.0/.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ vendor/
2
+ __pycache__/
3
+ build/
4
+ *egg-info/
5
+ dist/
6
+ *.egg
migma-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.4
2
+ Name: migma
3
+ Version: 0.1.0
4
+ Summary: Modular action policy components for Vision-Language-Action models
5
+ Project-URL: Repository, https://github.com/lucasjinreal/migma
6
+ License: Apache-2.0
7
+ Keywords: action-policy,diffusion,flow-matching,robotics,vla
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Intended Audience :: Science/Research
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.9
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: >=3.9
18
+ Requires-Dist: diffusers>=0.27.0
19
+ Requires-Dist: numpy>=1.24.0
20
+ Requires-Dist: torch>=2.0.0
21
+ Requires-Dist: transformers>=4.40.0
22
+ Provides-Extra: dev
23
+ Requires-Dist: black>=23.0.0; extra == 'dev'
24
+ Requires-Dist: pytest>=7.0.0; extra == 'dev'
25
+ Requires-Dist: ruff>=0.1.0; extra == 'dev'
26
+ Requires-Dist: twine>=4.0.0; extra == 'dev'
@@ -0,0 +1,21 @@
1
+ """
2
+ migma — Modular action policy components for Vision-Language-Action (VLA) models.
3
+
4
+ Provides layerwise flow-matching action heads that attach to a frozen VLM backbone
5
+ and decode robot actions from its per-layer KV cache or hidden states.
6
+ """
7
+
8
+ from migma.models.action_heads import (
9
+ FlowmatchingActionHeadConfig,
10
+ LayerwiseFMDSCActionHead,
11
+ LayerwiseFMPi05ActionHead,
12
+ LayerwiseFMPiActionHead,
13
+ )
14
+
15
+ __version__ = "0.1.0"
16
+ __all__ = [
17
+ "FlowmatchingActionHeadConfig",
18
+ "LayerwiseFMPiActionHead",
19
+ "LayerwiseFMPi05ActionHead",
20
+ "LayerwiseFMDSCActionHead",
21
+ ]
@@ -0,0 +1,11 @@
1
"""Re-export the action-head public API at the ``migma.models`` level.

NOTE(review): the ``action_heads`` subpackage also exports
``LayerwiseFMPi05ActionHead``; it is re-exported here as well so the three
namespaces (``migma``, ``migma.models``, ``migma.models.action_heads``)
expose a consistent set of heads.
"""

from migma.models.action_heads import (
    FlowmatchingActionHeadConfig,
    LayerwiseFMDSCActionHead,
    LayerwiseFMPi05ActionHead,
    LayerwiseFMPiActionHead,
)

__all__ = [
    "FlowmatchingActionHeadConfig",
    "LayerwiseFMPiActionHead",
    "LayerwiseFMPi05ActionHead",
    "LayerwiseFMDSCActionHead",
]
@@ -0,0 +1,11 @@
1
"""Public API of the ``action_heads`` subpackage: one shared config plus the
three concrete layerwise flow-matching head variants."""

from migma.models.action_heads.configuration import FlowmatchingActionHeadConfig
from migma.models.action_heads.layerwise_fm_dsc import LayerwiseFMDSCActionHead
from migma.models.action_heads.layerwise_fm_pi import LayerwiseFMPiActionHead
from migma.models.action_heads.layerwise_fm_pi05 import LayerwiseFMPi05ActionHead

__all__ = [
    "FlowmatchingActionHeadConfig",
    "LayerwiseFMPiActionHead",
    "LayerwiseFMDSCActionHead",
    "LayerwiseFMPi05ActionHead",
]
@@ -0,0 +1,121 @@
1
+ """
2
+ Configuration dataclass for layerwise flow-matching action heads.
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, Optional
7
+
8
+
9
+ @dataclass
10
+ class FlowmatchingActionHeadConfig:
11
+ """
12
+ Unified configuration for all layerwise flow-matching action head variants.
13
+
14
+ The two concrete head classes (:class:`~migma.models.action_heads.LayerwiseFMPiActionHead`
15
+ and :class:`~migma.models.action_heads.LayerwiseFMDSCActionHead`) share this
16
+ configuration schema and ignore fields that are not relevant to them.
17
+
18
+ Core dimensions
19
+ ---------------
20
+ action_dim:
21
+ Dimension of a single action vector (e.g. 14 for bimanual euler, 20 for r6d).
22
+ action_horizon:
23
+ Total rollout length including the current step
24
+ (``future_action_window_size + 1`` in the original starVLA convention).
25
+ vl_hidden_dim:
26
+ Hidden dimension of the VLM backbone (e.g. 2048 for Qwen2.5-VL-2B).
27
+ num_inference_timesteps:
28
+ Number of Euler integration steps at inference time.
29
+
30
+ DiT architecture
31
+ ----------------
32
+ dit_cfg:
33
+ Dict of overrides merged on top of the default DiT hyperparameters.
34
+ Recognised keys: ``num_layers``, ``num_attention_heads``,
35
+ ``attention_head_dim``, ``dropout``, ``activation_fn``, …
36
+ Any key accepted by the underlying :class:`~migma.nn.DiT` /
37
+ :class:`~migma.nn.DSCDiT` constructor is valid here.
38
+ use_kv_cache_dit:
39
+ When ``True``, uses the KV-cache-capable DiT variant.
40
+ kv_hidden_size:
41
+ Hidden size of the KV tensors supplied by the VLM.
42
+ kv_compress_ratio:
43
+ If < 1.0, a :class:`~migma.nn.GatedMLPCompressor` is applied to the KV
44
+ cache before projection (DSC head only).
45
+ global_cond_dim:
46
+ Dimension of ``vlm_last_hidden_state``. Defaults to ``kv_hidden_size``
47
+ when ``None`` (DSC head only).
48
+
49
+ State encoder
50
+ -------------
51
+ state_dim:
52
+ When set, a two-layer MLP maps the proprioceptive state vector to
53
+ ``vl_hidden_dim`` before concatenating with action tokens.
54
+
55
+ Positional embedding
56
+ --------------------
57
+ add_pos_embed:
58
+ Whether to add a learnable positional embedding to action features.
59
+ max_seq_len:
60
+ Maximum sequence length for the positional embedding table.
61
+ num_target_vision_tokens:
62
+ Number of learnable "future" tokens prepended to the action sequence.
63
+
64
+ Flow-matching noise
65
+ -------------------
66
+ noise_beta_alpha / noise_beta_beta:
67
+ Parameters of the Beta distribution used to sample diffusion time.
68
+ noise_s:
69
+ Flow-matching schedule shift: ``t = (s - Beta_sample) / s``.
70
+ num_timestep_buckets:
71
+ Number of discrete bins for the continuous time variable.
72
+ use_scaled_noise:
73
+ Whether to scale initial noise by per-dimension action statistics.
74
+ t_eps:
75
+ Minimum ``(1 - t)`` denominator to avoid division by zero in the AML
76
+ velocity formula (DSC head only).
77
+
78
+ Loss coefficients
79
+ -----------------
80
+ smoothness_loss_weight:
81
+ Weight for the action smoothness regulariser (0 = disabled).
82
+ enable_first_action_state_loss:
83
+ Anchor the first predicted action to the current proprioceptive state
84
+ (Pi head only).
85
+ first_action_state_loss_weight:
86
+ Weight for the first-action anchor loss.
87
+ """
88
+
89
+ # ── Core dimensions (required) ──────────────────────────────────────────
90
+ action_dim: int
91
+ action_horizon: int
92
+ vl_hidden_dim: int
93
+ num_inference_timesteps: int
94
+
95
+ # ── DiT architecture ────────────────────────────────────────────────────
96
+ dit_cfg: Dict[str, Any] = field(default_factory=dict)
97
+ use_kv_cache_dit: bool = False
98
+ kv_hidden_size: Optional[int] = None # None means same as vl_hidden_dim (no projection)
99
+ kv_compress_ratio: float = 1.0
100
+ global_cond_dim: Optional[int] = None # DSC-specific
101
+
102
+ # ── State encoder ───────────────────────────────────────────────────────
103
+ state_dim: Optional[int] = None
104
+
105
+ # ── Token sequence ──────────────────────────────────────────────────────
106
+ num_target_vision_tokens: int = 32
107
+ add_pos_embed: bool = True
108
+ max_seq_len: int = 1024
109
+
110
+ # ── Flow-matching noise ─────────────────────────────────────────────────
111
+ noise_beta_alpha: float = 1.5
112
+ noise_beta_beta: float = 1.0
113
+ noise_s: float = 0.999
114
+ num_timestep_buckets: int = 1000
115
+ use_scaled_noise: bool = True
116
+ t_eps: float = 5e-2 # DSC/AML-specific
117
+
118
+ # ── Loss coefficients ───────────────────────────────────────────────────
119
+ smoothness_loss_weight: float = 0.0
120
+ enable_first_action_state_loss: bool = False # Pi-specific
121
+ first_action_state_loss_weight: float = 0.0
@@ -0,0 +1,403 @@
1
+ """
2
+ Layerwise Flow-Matching Action Head — AML / DSC variant.
3
+
4
+ Velocity target: ``v = (actions − x_t) / (1 − t)`` (AML / posterior-mean)
5
+ KV conditioning: list of ``(key, value)`` tuples or hidden-state tensors.
6
+ Extra signal: optional ``vlm_last_hidden_state`` drives dual-stream
7
+ conditioning inside :class:`~migma.nn.DSCDiT`.
8
+
9
+ References:
10
+ - Consistent Flow Matching (AML training objective)
11
+ - DSC-DiT dual-stream conditioning design
12
+ """
13
+
14
+ import logging
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.distributions import Beta
21
+
22
+ from migma.models.action_heads.configuration import FlowmatchingActionHeadConfig
23
+ from migma.nn.dit import DiT
24
+ from migma.nn.dit_dsc import DSCDiT
25
+ from migma.nn.encoders import ActionEncoder, MLP
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ _DEFAULT_DIT_CFG = {
30
+ "num_layers": 28,
31
+ "attention_head_dim": 64,
32
+ "dropout": 0.1,
33
+ "attention_bias": True,
34
+ "activation_fn": "gelu-approximate",
35
+ "norm_type": "ada_norm",
36
+ "norm_elementwise_affine": False,
37
+ "norm_eps": 1e-5,
38
+ "max_num_positional_embeddings": 512,
39
+ "final_dropout": True,
40
+ "positional_embeddings": "sinusoidal",
41
+ "interleave_self_attention": False,
42
+ }
43
+
44
+
45
+ def _build_dit_config(cfg: FlowmatchingActionHeadConfig) -> dict:
46
+ final = _DEFAULT_DIT_CFG.copy()
47
+ final.update(cfg.dit_cfg)
48
+
49
+ vl_dim = cfg.vl_hidden_dim
50
+ head_dim = final.get("attention_head_dim", 64)
51
+ final["num_attention_heads"] = final.get("num_attention_heads", vl_dim // head_dim)
52
+ final["output_dim"] = vl_dim
53
+ final["kv_hidden_size"] = final.get("kv_hidden_size", cfg.kv_hidden_size)
54
+ final["kv_compress_ratio"] = final.get("kv_compress_ratio", cfg.kv_compress_ratio)
55
+ final["global_cond_dim"] = final.get(
56
+ "global_cond_dim",
57
+ cfg.global_cond_dim if cfg.global_cond_dim is not None else cfg.kv_hidden_size,
58
+ )
59
+ # cross_attention_dim only used by the standard DiT fallback.
60
+ final["cross_attention_dim"] = final.get("cross_attention_dim", vl_dim)
61
+ return final
62
+
63
+
64
class LayerwiseFMDSCActionHead(nn.Module):
    """
    AML flow-matching head with Dual-Stream Conditioned DiT (DSC-DiT).

    **Training objective** (AML / consistent FM):

    .. math::

        x_t = t \\cdot x_0 + (1-t) \\cdot \\epsilon, \\quad
        v_\\text{target} = \\frac{x_0 - x_t}{1-t}

    where :math:`\\epsilon \\sim \\mathcal{N}(0, I)` (optionally scaled).

    **Conditioning**:

    * *Primary* — per-layer ``(key, value)`` tensors from the VLM or a list
      of hidden-state tensors (standard cross-attention).
    * *Secondary* — ``vlm_last_hidden_state`` fed into :class:`~migma.nn.DSCDiT`
      for attention-pooled global adaLN and sequence-level cross-attention.

    Args:
        config: :class:`~migma.models.action_heads.FlowmatchingActionHeadConfig`

    Example::

        cfg = FlowmatchingActionHeadConfig(
            action_dim=14, action_horizon=16,
            vl_hidden_dim=2048, num_inference_timesteps=10,
            use_kv_cache_dit=True,
            kv_hidden_size=512,
        )
        head = LayerwiseFMDSCActionHead(cfg)

        loss = head(context_list, actions, vlm_last_hidden_state=last_hs)
        actions = head.predict_action(context_list, vlm_last_hidden_state=last_hs)
    """

    def __init__(self, config: "FlowmatchingActionHeadConfig") -> None:
        super().__init__()
        self.config = config

        dit_kwargs = _build_dit_config(config)
        use_kv_cache = config.use_kv_cache_dit
        if use_kv_cache:
            self.model = DSCDiT(**dit_kwargs)
            logger.info("LayerwiseFMDSCActionHead: using DSCDiT")
        else:
            self.model = DiT(**dit_kwargs)
            logger.info("LayerwiseFMDSCActionHead: using standard DiT (no KV cache)")

        self.use_kv_cache_dit = use_kv_cache
        self.action_dim = config.action_dim
        self.action_horizon = config.action_horizon
        self.num_inference_timesteps = config.num_inference_timesteps
        self.smoothness_loss_weight = config.smoothness_loss_weight
        self.t_eps = config.t_eps

        hidden_dim = config.vl_hidden_dim

        # Optional proprioceptive-state encoder (only when state_dim is set).
        self.state_encoder = (
            MLP(input_dim=config.state_dim, output_dim=hidden_dim)
            if config.state_dim
            else None
        )
        self.action_encoder = ActionEncoder(
            action_dim=config.action_dim, hidden_size=hidden_dim
        )
        self.action_decoder = MLP(
            input_dim=hidden_dim, hidden_dim=1024, output_dim=config.action_dim
        )

        # Learnable "future" tokens prepended to the action sequence.
        self.future_tokens = nn.Embedding(config.num_target_vision_tokens, hidden_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)

        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, hidden_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        self.use_scaled_noise = config.use_scaled_noise
        # Per-dimension noise scale; calibrated via set_noise_scale_from_stats.
        self.register_buffer("noise_scale", torch.ones(config.action_dim))

        logger.info(
            "LayerwiseFMDSCActionHead initialised | "
            f"action_dim={config.action_dim} horizon={config.action_horizon} "
            f"vl_hidden_dim={config.vl_hidden_dim} "
            f"params={sum(p.numel() for p in self.parameters()):,}"
        )

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------

    def _sample_time(
        self, batch_size: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Sample a batch of flow-matching times via the shifted Beta schedule:
        ``t = (s - Beta_sample) / s``."""
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (self.config.noise_s - sample) / self.config.noise_s

    @staticmethod
    def _is_kv_list(context: list) -> bool:
        """Return True when ``context`` is a non-empty list of (key, value) pairs
        rather than a list of hidden-state tensors."""
        return (
            isinstance(context, (list, tuple))
            and len(context) > 0
            and isinstance(context[0], (list, tuple))
        )

    def _build_sa_embs(
        self,
        action_features: torch.Tensor,
        state_features: Optional[torch.Tensor],
        B: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Concatenate [state?, future tokens, action features] into the DiT
        input sequence, adding positional embeddings to the action part."""
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(B, -1, -1)
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
        if state_features is not None:
            return torch.cat([state_features, future_tokens, action_features], dim=1)
        return torch.cat([future_tokens, action_features], dim=1)

    def _run_model(
        self,
        sa_embs: torch.Tensor,
        context_list: list,
        t_disc: torch.Tensor,
        is_kv: bool,
        vlm_last_hidden_state: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Dispatch one denoiser call to DSCDiT (KV-cache path) or standard DiT."""
        device, dtype = sa_embs.device, sa_embs.dtype

        if self.use_kv_cache_dit:
            return self.model(
                hidden_states=sa_embs,
                kv_cache_list=context_list,
                timestep=t_disc,
                encoder_attention_mask=None,
                vlm_last_hidden_state=vlm_last_hidden_state,
            )

        # Standard DiT: build encoder_hidden_states from context.
        if vlm_last_hidden_state is not None:
            enc_hs = vlm_last_hidden_state.to(device=device, dtype=dtype)
        else:
            # Fall back to the deepest layer's context (value tensor when KV).
            enc_hs = (context_list[-1][1] if is_kv else context_list[-1]).to(
                device=device, dtype=dtype
            )
        return self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=enc_hs,
            timestep=t_disc,
            encoder_attention_mask=None,
            vlm_last_hidden_state=vlm_last_hidden_state,
        )

    # -----------------------------------------------------------------------
    # Noise-scale calibration
    # -----------------------------------------------------------------------

    def set_noise_scale_from_stats(
        self,
        action_std: Union[list, "np.ndarray", torch.Tensor],
    ) -> None:
        """
        Calibrate per-dimension noise scaling from dataset action statistics.

        See :meth:`LayerwiseFMPiActionHead.set_noise_scale_from_stats` for the
        full docstring — the logic is identical.
        """
        if isinstance(action_std, (list, np.ndarray)):
            action_std = torch.tensor(action_std, dtype=torch.float32)

        # 14-D euler stats supplied for a 20-D r6d head: expand them.
        if len(action_std) == 14 and self.action_dim == 20:
            action_std = _expand_14d_to_20d(action_std)

        if len(action_std) != self.action_dim:
            raise ValueError(
                f"action_std length {len(action_std)} != action_dim {self.action_dim}."
            )

        # Cap the effective std at TARGET so low-variance dims are not inflated.
        TARGET = 0.3
        scale = torch.clamp(TARGET / (action_std + 1e-6), max=1.0)
        self.noise_scale.data.copy_((action_std * scale).to(self.noise_scale.device))
        logger.info(
            f"Noise scale updated | min={self.noise_scale.min():.4f} "
            f"max={self.noise_scale.max():.4f}"
        )

    # -----------------------------------------------------------------------
    # Forward (training)
    # -----------------------------------------------------------------------

    def forward(
        self,
        context_list: List[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]],
        actions: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        vlm_last_hidden_state: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the AML flow-matching velocity loss.

        Args:
            context_list: Per-layer context — list of ``(key, value)``
                tuples or list of hidden-state tensors.
            actions: Target trajectory ``(B, T, action_dim)``.
            state: Optional proprioceptive state.
            vlm_last_hidden_state: Optional ``(B, S, D)`` — enables DSC dual-stream
                conditioning when using :class:`~migma.nn.DSCDiT`.

        Returns:
            Scalar training loss.
        """
        device = actions.device
        is_kv = self._is_kv_list(context_list)
        B = context_list[0][0].shape[0] if is_kv else context_list[0].shape[0]

        t = self._sample_time(B, device=device, dtype=actions.dtype)
        t = t[:, None, None]  # (B, 1, 1)

        noise = torch.randn(B, self.action_horizon, self.action_dim,
                            dtype=actions.dtype, device=device)
        if self.use_scaled_noise:
            noise = noise * self.noise_scale[None, None, :]

        # AML interpolation: x_t = t·x_0 + (1-t)·ε
        noisy = t * actions + (1 - t) * noise
        # AML velocity target: v = (x_0 - x_t) / (1 - t); clamp avoids blow-up at t→1.
        velocity = (actions - noisy) / (1 - t).clamp_min(self.t_eps)

        t_disc = (t[:, 0, 0] * self.num_timestep_buckets).long()
        action_features = self.action_encoder(noisy, t_disc)

        state_features = (
            self.state_encoder(state.to(device=device, dtype=actions.dtype))
            if state is not None and self.state_encoder is not None
            else None
        )
        sa_embs = self._build_sa_embs(action_features, state_features, B, device)

        model_out = self._run_model(sa_embs, context_list, t_disc, is_kv, vlm_last_hidden_state)

        # AML: model predicts action samples; velocity derived from predictions.
        pred_actions = self.action_decoder(model_out)[:, -actions.shape[1]:]
        pred_velocity = (pred_actions - noisy) / (1 - t).clamp_min(self.t_eps)

        loss = ((pred_velocity - velocity) ** 2).mean()

        if self.smoothness_loss_weight > 0 and pred_actions.shape[1] > 1:
            # Penalise step-to-step jerk in the predicted trajectory.
            smooth_loss = ((pred_actions[:, 1:] - pred_actions[:, :-1]) ** 2).mean()
            loss = loss + self.smoothness_loss_weight * smooth_loss

        return loss

    # -----------------------------------------------------------------------
    # Predict (inference)
    # -----------------------------------------------------------------------

    @torch.no_grad()
    def predict_action(
        self,
        context_list: List[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]],
        state: Optional[torch.Tensor] = None,
        vlm_last_hidden_state: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Denoise from pure noise to a clean action rollout.

        Args:
            context_list: Same format as :meth:`forward`.
            state: Optional proprioceptive state.
            vlm_last_hidden_state: Optional last-layer VLM hidden states.

        Returns:
            ``(B, action_horizon, action_dim)``
        """
        is_kv = self._is_kv_list(context_list)
        if is_kv:
            B, device = context_list[0][0].shape[0], context_list[0][0].device
        else:
            B, device = context_list[0].shape[0], context_list[0].device
        dtype = torch.float32

        actions = torch.randn(B, self.action_horizon, self.action_dim, dtype=dtype, device=device)
        if self.use_scaled_noise:
            actions = actions * self.noise_scale[None, None, :].to(dtype=dtype)

        state_features = None
        if state is not None and self.state_encoder is not None:
            state_features = self.state_encoder(state.to(device=device, dtype=dtype))

        # Euler integration from t=0 to t=1 in num_inference_timesteps steps.
        dt = 1.0 / self.num_inference_timesteps
        for step in range(self.num_inference_timesteps):
            t_cont = step / float(self.num_inference_timesteps)
            t_disc_int = int(t_cont * self.num_timestep_buckets)
            t_tensor = torch.full((B,), t_disc_int, device=device, dtype=torch.long)

            action_features = self.action_encoder(actions, t_tensor)
            sa_embs = self._build_sa_embs(action_features, state_features, B, device)

            model_out = self._run_model(sa_embs, context_list, t_tensor, is_kv,
                                        vlm_last_hidden_state)

            pred_actions = self.action_decoder(model_out)[:, -self.action_horizon:]
            # AML: velocity derived from predicted actions; t_cont < 1 here, so
            # the denominator never hits zero.
            pred_velocity = (pred_actions - actions) / (1.0 - t_cont)
            actions = actions + dt * pred_velocity

        return actions

    # -----------------------------------------------------------------------
    # Properties
    # -----------------------------------------------------------------------

    @property
    def device(self) -> torch.device:
        return next(iter(self.parameters())).device

    @property
    def dtype(self) -> torch.dtype:
        return next(iter(self.parameters())).dtype
387
+
388
+
389
+ # ---------------------------------------------------------------------------
390
+ # Utilities
391
+ # ---------------------------------------------------------------------------
392
+
393
+
394
+ def _expand_14d_to_20d(action_std_14: torch.Tensor) -> torch.Tensor:
395
+ """Expand 14-D euler action statistics to 20-D r6d representation."""
396
+ out = torch.zeros(20, dtype=action_std_14.dtype)
397
+ out[0:3] = action_std_14[0:3]
398
+ out[3:9] = torch.repeat_interleave(action_std_14[3:6], 2)
399
+ out[9] = action_std_14[6]
400
+ out[10:13] = action_std_14[7:10]
401
+ out[13:19] = torch.repeat_interleave(action_std_14[10:13], 2)
402
+ out[19] = action_std_14[13]
403
+ return out