migma 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migma-0.1.0/.gitignore +6 -0
- migma-0.1.0/PKG-INFO +26 -0
- migma-0.1.0/migma/__init__.py +21 -0
- migma-0.1.0/migma/models/__init__.py +11 -0
- migma-0.1.0/migma/models/action_heads/__init__.py +11 -0
- migma-0.1.0/migma/models/action_heads/configuration.py +121 -0
- migma-0.1.0/migma/models/action_heads/layerwise_fm_dsc.py +403 -0
- migma-0.1.0/migma/models/action_heads/layerwise_fm_pi.py +432 -0
- migma-0.1.0/migma/models/action_heads/layerwise_fm_pi05.py +419 -0
- migma-0.1.0/migma/nn/__init__.py +47 -0
- migma-0.1.0/migma/nn/dit.py +641 -0
- migma-0.1.0/migma/nn/dit_dsc.py +481 -0
- migma-0.1.0/migma/nn/dit_pi05.py +244 -0
- migma-0.1.0/migma/nn/encoders.py +196 -0
- migma-0.1.0/migma/utils/__init__.py +1 -0
- migma-0.1.0/pyproject.toml +53 -0
migma-0.1.0/.gitignore
ADDED
migma-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: migma
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Modular action policy components for Vision-Language-Action models
|
|
5
|
+
Project-URL: Repository, https://github.com/lucasjinreal/migma
|
|
6
|
+
License: Apache-2.0
|
|
7
|
+
Keywords: action-policy,diffusion,flow-matching,robotics,vla
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: >=3.9
|
|
18
|
+
Requires-Dist: diffusers>=0.27.0
|
|
19
|
+
Requires-Dist: numpy>=1.24.0
|
|
20
|
+
Requires-Dist: torch>=2.0.0
|
|
21
|
+
Requires-Dist: transformers>=4.40.0
|
|
22
|
+
Provides-Extra: dev
|
|
23
|
+
Requires-Dist: black>=23.0.0; extra == 'dev'
|
|
24
|
+
Requires-Dist: pytest>=7.0.0; extra == 'dev'
|
|
25
|
+
Requires-Dist: ruff>=0.1.0; extra == 'dev'
|
|
26
|
+
Requires-Dist: twine>=4.0.0; extra == 'dev'
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""
|
|
2
|
+
migma — Modular action policy components for Vision-Language-Action (VLA) models.
|
|
3
|
+
|
|
4
|
+
Provides layerwise flow-matching action heads that attach to a frozen VLM backbone
|
|
5
|
+
and decode robot actions from its per-layer KV cache or hidden states.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from migma.models.action_heads import (
|
|
9
|
+
FlowmatchingActionHeadConfig,
|
|
10
|
+
LayerwiseFMDSCActionHead,
|
|
11
|
+
LayerwiseFMPi05ActionHead,
|
|
12
|
+
LayerwiseFMPiActionHead,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
__version__ = "0.1.0"
|
|
16
|
+
__all__ = [
|
|
17
|
+
"FlowmatchingActionHeadConfig",
|
|
18
|
+
"LayerwiseFMPiActionHead",
|
|
19
|
+
"LayerwiseFMPi05ActionHead",
|
|
20
|
+
"LayerwiseFMDSCActionHead",
|
|
21
|
+
]
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Action-head package: shared configuration plus the three layerwise
flow-matching head variants (Pi, Pi05, DSC)."""

from migma.models.action_heads.configuration import FlowmatchingActionHeadConfig
from migma.models.action_heads.layerwise_fm_dsc import LayerwiseFMDSCActionHead
from migma.models.action_heads.layerwise_fm_pi import LayerwiseFMPiActionHead
from migma.models.action_heads.layerwise_fm_pi05 import LayerwiseFMPi05ActionHead

# Explicit public surface for `from migma.models.action_heads import *`.
__all__ = [
    "FlowmatchingActionHeadConfig",
    "LayerwiseFMPiActionHead",
    "LayerwiseFMDSCActionHead",
    "LayerwiseFMPi05ActionHead",
]
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration dataclass for layerwise flow-matching action heads.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class FlowmatchingActionHeadConfig:
    """
    Unified configuration for all layerwise flow-matching action head variants.

    The three concrete head classes
    (:class:`~migma.models.action_heads.LayerwiseFMPiActionHead`,
    :class:`~migma.models.action_heads.LayerwiseFMPi05ActionHead` and
    :class:`~migma.models.action_heads.LayerwiseFMDSCActionHead`) share this
    configuration schema and ignore fields that are not relevant to them.

    Core dimensions
    ---------------
    action_dim:
        Dimension of a single action vector (e.g. 14 for bimanual euler, 20 for r6d).
    action_horizon:
        Total rollout length including the current step
        (``future_action_window_size + 1`` in the original starVLA convention).
    vl_hidden_dim:
        Hidden dimension of the VLM backbone (e.g. 2048 for Qwen2.5-VL-2B).
    num_inference_timesteps:
        Number of Euler integration steps at inference time.

    DiT architecture
    ----------------
    dit_cfg:
        Dict of overrides merged on top of the default DiT hyperparameters.
        Recognised keys: ``num_layers``, ``num_attention_heads``,
        ``attention_head_dim``, ``dropout``, ``activation_fn``, …
        Any key accepted by the underlying :class:`~migma.nn.DiT` /
        :class:`~migma.nn.DSCDiT` constructor is valid here.
    use_kv_cache_dit:
        When ``True``, uses the KV-cache-capable DiT variant.
    kv_hidden_size:
        Hidden size of the KV tensors supplied by the VLM.
    kv_compress_ratio:
        If < 1.0, a :class:`~migma.nn.GatedMLPCompressor` is applied to the KV
        cache before projection (DSC head only).
    global_cond_dim:
        Dimension of ``vlm_last_hidden_state``. Defaults to ``kv_hidden_size``
        when ``None`` (DSC head only).

    State encoder
    -------------
    state_dim:
        When set, a two-layer MLP maps the proprioceptive state vector to
        ``vl_hidden_dim`` before concatenating with action tokens.

    Positional embedding
    --------------------
    add_pos_embed:
        Whether to add a learnable positional embedding to action features.
    max_seq_len:
        Maximum sequence length for the positional embedding table.
    num_target_vision_tokens:
        Number of learnable "future" tokens prepended to the action sequence.

    Flow-matching noise
    -------------------
    noise_beta_alpha / noise_beta_beta:
        Parameters of the Beta distribution used to sample diffusion time.
    noise_s:
        Flow-matching schedule shift: ``t = (s - Beta_sample) / s``.
    num_timestep_buckets:
        Number of discrete bins for the continuous time variable.
    use_scaled_noise:
        Whether to scale initial noise by per-dimension action statistics.
    t_eps:
        Minimum ``(1 - t)`` denominator to avoid division by zero in the AML
        velocity formula (DSC head only).

    Loss coefficients
    -----------------
    smoothness_loss_weight:
        Weight for the action smoothness regulariser (0 = disabled).
    enable_first_action_state_loss:
        Anchor the first predicted action to the current proprioceptive state
        (Pi head only).
    first_action_state_loss_weight:
        Weight for the first-action anchor loss.
    """

    # ── Core dimensions (required) ──────────────────────────────────────────
    action_dim: int
    action_horizon: int
    vl_hidden_dim: int
    num_inference_timesteps: int

    # ── DiT architecture ────────────────────────────────────────────────────
    dit_cfg: Dict[str, Any] = field(default_factory=dict)
    use_kv_cache_dit: bool = False
    kv_hidden_size: Optional[int] = None  # None means same as vl_hidden_dim (no projection)
    kv_compress_ratio: float = 1.0
    global_cond_dim: Optional[int] = None  # DSC-specific
    # ── State encoder ───────────────────────────────────────────────────────
    state_dim: Optional[int] = None

    # ── Token sequence ──────────────────────────────────────────────────────
    num_target_vision_tokens: int = 32
    add_pos_embed: bool = True
    max_seq_len: int = 1024

    # ── Flow-matching noise ─────────────────────────────────────────────────
    noise_beta_alpha: float = 1.5
    noise_beta_beta: float = 1.0
    noise_s: float = 0.999
    num_timestep_buckets: int = 1000
    use_scaled_noise: bool = True
    t_eps: float = 5e-2  # DSC/AML-specific

    # ── Loss coefficients ───────────────────────────────────────────────────
    smoothness_loss_weight: float = 0.0
    enable_first_action_state_loss: bool = False  # Pi-specific
    first_action_state_loss_weight: float = 0.0
|
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Layerwise Flow-Matching Action Head — AML / DSC variant.
|
|
3
|
+
|
|
4
|
+
Velocity target: ``v = (actions − x_t) / (1 − t)`` (AML / posterior-mean)
|
|
5
|
+
KV conditioning: list of ``(key, value)`` tuples or hidden-state tensors.
|
|
6
|
+
Extra signal: optional ``vlm_last_hidden_state`` drives dual-stream
|
|
7
|
+
conditioning inside :class:`~migma.nn.DSCDiT`.
|
|
8
|
+
|
|
9
|
+
References:
|
|
10
|
+
- Consistent Flow Matching (AML training objective)
|
|
11
|
+
- DSC-DiT dual-stream conditioning design
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
from typing import List, Optional, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
from torch.distributions import Beta
|
|
21
|
+
|
|
22
|
+
from migma.models.action_heads.configuration import FlowmatchingActionHeadConfig
|
|
23
|
+
from migma.nn.dit import DiT
|
|
24
|
+
from migma.nn.dit_dsc import DSCDiT
|
|
25
|
+
from migma.nn.encoders import ActionEncoder, MLP
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
# Baseline DiT hyperparameters; any key here may be overridden via
# FlowmatchingActionHeadConfig.dit_cfg.
_DEFAULT_DIT_CFG = {
    "num_layers": 28,
    "attention_head_dim": 64,
    "dropout": 0.1,
    "attention_bias": True,
    "activation_fn": "gelu-approximate",
    "norm_type": "ada_norm",
    "norm_elementwise_affine": False,
    "norm_eps": 1e-5,
    "max_num_positional_embeddings": 512,
    "final_dropout": True,
    "positional_embeddings": "sinusoidal",
    "interleave_self_attention": False,
}


def _build_dit_config(cfg: FlowmatchingActionHeadConfig) -> dict:
    """Merge user overrides onto the DiT defaults and derive the dimension
    keys the DiT/DSCDiT constructor needs from the head configuration.

    Explicit keys in ``cfg.dit_cfg`` always win over both the defaults and
    the derived values.
    """
    merged = {**_DEFAULT_DIT_CFG, **cfg.dit_cfg}

    backbone_dim = cfg.vl_hidden_dim
    head_dim = merged.get("attention_head_dim", 64)

    # Heads default to filling the full backbone width.
    merged.setdefault("num_attention_heads", backbone_dim // head_dim)
    # The DiT always emits features at the backbone width.
    merged["output_dim"] = backbone_dim
    merged.setdefault("kv_hidden_size", cfg.kv_hidden_size)
    merged.setdefault("kv_compress_ratio", cfg.kv_compress_ratio)
    fallback_global = (
        cfg.global_cond_dim if cfg.global_cond_dim is not None else cfg.kv_hidden_size
    )
    merged.setdefault("global_cond_dim", fallback_global)
    # cross_attention_dim is only consumed by the standard-DiT fallback path.
    merged.setdefault("cross_attention_dim", backbone_dim)
    return merged
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class LayerwiseFMDSCActionHead(nn.Module):
    """
    AML flow-matching head with Dual-Stream Conditioned DiT (DSC-DiT).

    **Training objective** (AML / consistent FM):

    .. math::

        x_t = t \\cdot x_0 + (1-t) \\cdot \\epsilon, \\quad
        v_\\text{target} = \\frac{x_0 - x_t}{1-t}

    where :math:`\\epsilon \\sim \\mathcal{N}(0, I)` (optionally scaled).

    **Conditioning**:

    * *Primary* — per-layer ``(key, value)`` tensors from the VLM or a list
      of hidden-state tensors (standard cross-attention).
    * *Secondary* — ``vlm_last_hidden_state`` fed into :class:`~migma.nn.DSCDiT`
      for attention-pooled global adaLN and sequence-level cross-attention.

    Args:
        config: :class:`~migma.models.action_heads.FlowmatchingActionHeadConfig`

    Example::

        cfg = FlowmatchingActionHeadConfig(
            action_dim=14, action_horizon=16,
            vl_hidden_dim=2048, num_inference_timesteps=10,
            use_kv_cache_dit=True,
            kv_hidden_size=512,
        )
        head = LayerwiseFMDSCActionHead(cfg)

        loss = head(context_list, actions, vlm_last_hidden_state=last_hs)
        actions = head.predict_action(context_list, vlm_last_hidden_state=last_hs)
    """

    def __init__(self, config: FlowmatchingActionHeadConfig) -> None:
        super().__init__()
        self.config = config

        dit_kwargs = _build_dit_config(config)
        use_kv_cache = config.use_kv_cache_dit
        if use_kv_cache:
            # KV-cache path: dual-stream conditioned DiT consumes per-layer
            # (key, value) tensors from the VLM directly.
            self.model = DSCDiT(**dit_kwargs)
            logger.info("LayerwiseFMDSCActionHead: using DSCDiT")
        else:
            # Fallback: plain cross-attention DiT over hidden states.
            self.model = DiT(**dit_kwargs)
            logger.info("LayerwiseFMDSCActionHead: using standard DiT (no KV cache)")

        self.use_kv_cache_dit = use_kv_cache
        self.action_dim = config.action_dim
        self.action_horizon = config.action_horizon
        self.num_inference_timesteps = config.num_inference_timesteps
        self.smoothness_loss_weight = config.smoothness_loss_weight
        self.t_eps = config.t_eps

        hidden_dim = config.vl_hidden_dim

        # Optional proprioceptive-state encoder: maps state_dim -> hidden_dim.
        self.state_encoder = (
            MLP(input_dim=config.state_dim, output_dim=hidden_dim)
            if config.state_dim
            else None
        )
        self.action_encoder = ActionEncoder(
            action_dim=config.action_dim, hidden_size=hidden_dim
        )
        self.action_decoder = MLP(
            input_dim=hidden_dim, hidden_dim=1024, output_dim=config.action_dim
        )

        # Learnable "future" tokens prepended to the action sequence.
        self.future_tokens = nn.Embedding(config.num_target_vision_tokens, hidden_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)

        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, hidden_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        # Diffusion-time sampler and per-dimension noise scaling buffer
        # (calibrated later via set_noise_scale_from_stats).
        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        self.use_scaled_noise = config.use_scaled_noise
        self.register_buffer("noise_scale", torch.ones(config.action_dim))

        logger.info(
            "LayerwiseFMDSCActionHead initialised | "
            f"action_dim={config.action_dim} horizon={config.action_horizon} "
            f"vl_hidden_dim={config.vl_hidden_dim} "
            f"params={sum(p.numel() for p in self.parameters()):,}"
        )

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------

    def _sample_time(
        self, batch_size: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Sample one diffusion time per batch element via the shifted-Beta
        schedule ``t = (s - Beta_sample) / s``."""
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (self.config.noise_s - sample) / self.config.noise_s

    @staticmethod
    def _is_kv_list(context: list) -> bool:
        """Return True when ``context`` is a non-empty list of (key, value)
        tuples, as opposed to a list of plain hidden-state tensors."""
        return (
            isinstance(context, (list, tuple))
            and len(context) > 0
            and isinstance(context[0], (list, tuple))
        )

    def _build_sa_embs(
        self,
        action_features: torch.Tensor,
        state_features: Optional[torch.Tensor],
        B: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Assemble the DiT input sequence: [state?] + future tokens + actions.

        Positional embeddings are added to the action features only.
        """
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(B, -1, -1)
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
        if state_features is not None:
            return torch.cat([state_features, future_tokens, action_features], dim=1)
        return torch.cat([future_tokens, action_features], dim=1)

    def _run_model(
        self,
        sa_embs: torch.Tensor,
        context_list: list,
        t_disc: torch.Tensor,
        is_kv: bool,
        vlm_last_hidden_state: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Dispatch one denoising step to either DSCDiT (KV-cache path) or the
        standard DiT (hidden-state cross-attention path)."""
        device, dtype = sa_embs.device, sa_embs.dtype

        if self.use_kv_cache_dit:
            return self.model(
                hidden_states=sa_embs,
                kv_cache_list=context_list,
                timestep=t_disc,
                encoder_attention_mask=None,
                vlm_last_hidden_state=vlm_last_hidden_state,
            )

        # Standard DiT: build encoder_hidden_states from context.
        if vlm_last_hidden_state is not None:
            enc_hs = vlm_last_hidden_state.to(device=device, dtype=dtype)
        else:
            # Fall back to the last layer's value tensor (KV) or hidden state.
            enc_hs = (context_list[-1][1] if is_kv else context_list[-1]).to(
                device=device, dtype=dtype
            )
        # NOTE(review): the standard-DiT path also forwards
        # vlm_last_hidden_state — confirm DiT accepts this kwarg.
        return self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=enc_hs,
            timestep=t_disc,
            encoder_attention_mask=None,
            vlm_last_hidden_state=vlm_last_hidden_state,
        )

    # -----------------------------------------------------------------------
    # Noise-scale calibration
    # -----------------------------------------------------------------------

    def set_noise_scale_from_stats(
        self,
        action_std: Union[list, "np.ndarray", torch.Tensor],
    ) -> None:
        """
        Calibrate per-dimension noise scaling from dataset action statistics.

        Effectively caps each dimension's noise std at TARGET while leaving
        smaller stds unchanged (``std * clamp(TARGET/std, max=1) = min(std, TARGET)``
        up to the 1e-6 stabiliser).

        See :meth:`LayerwiseFMPiActionHead.set_noise_scale_from_stats` for the
        full docstring — the logic is identical.
        """
        if isinstance(action_std, (list, np.ndarray)):
            action_std = torch.tensor(action_std, dtype=torch.float32)

        # Convenience: accept 14-D euler stats for a 20-D r6d head.
        if len(action_std) == 14 and self.action_dim == 20:
            action_std = _expand_14d_to_20d(action_std)

        if len(action_std) != self.action_dim:
            raise ValueError(
                f"action_std length {len(action_std)} != action_dim {self.action_dim}."
            )

        TARGET = 0.3
        scale = torch.clamp(TARGET / (action_std + 1e-6), max=1.0)
        self.noise_scale.data.copy_((action_std * scale).to(self.noise_scale.device))
        logger.info(
            f"Noise scale updated | min={self.noise_scale.min():.4f} "
            f"max={self.noise_scale.max():.4f}"
        )

    # -----------------------------------------------------------------------
    # Forward (training)
    # -----------------------------------------------------------------------

    def forward(
        self,
        context_list: List[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]],
        actions: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        vlm_last_hidden_state: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the AML flow-matching velocity loss.

        Args:
            context_list: Per-layer context — list of ``(key, value)``
                tuples or list of hidden-state tensors.
            actions: Target trajectory ``(B, T, action_dim)``.
            state: Optional proprioceptive state.
            vlm_last_hidden_state: Optional ``(B, S, D)`` — enables DSC dual-stream
                conditioning when using :class:`~migma.nn.DSCDiT`.

        Returns:
            Scalar training loss.
        """
        device = actions.device
        is_kv = self._is_kv_list(context_list)
        # Batch size from the first context entry (key tensor or hidden state).
        B = context_list[0][0].shape[0] if is_kv else context_list[0].shape[0]

        t = self._sample_time(B, device=device, dtype=actions.dtype)
        t = t[:, None, None]  # (B, 1, 1)

        noise = torch.randn(B, self.action_horizon, self.action_dim,
                            dtype=actions.dtype, device=device)
        if self.use_scaled_noise:
            noise = noise * self.noise_scale[None, None, :]

        # AML interpolation: x_t = t·x_0 + (1-t)·ε
        noisy = t * actions + (1 - t) * noise
        # AML velocity target: v = (x_0 - x_t) / (1 - t), denominator clamped
        # at t_eps to avoid blow-up near t = 1.
        velocity = (actions - noisy) / (1 - t).clamp_min(self.t_eps)

        # Discretise continuous t into timestep-embedding buckets.
        t_disc = (t[:, 0, 0] * self.num_timestep_buckets).long()
        action_features = self.action_encoder(noisy, t_disc)

        state_features = (
            self.state_encoder(state.to(device=device, dtype=actions.dtype))
            if state is not None and self.state_encoder is not None
            else None
        )
        sa_embs = self._build_sa_embs(action_features, state_features, B, device)

        model_out = self._run_model(sa_embs, context_list, t_disc, is_kv, vlm_last_hidden_state)

        # AML: model predicts action samples; velocity derived from predictions.
        # Only the trailing T positions of the sequence decode to actions.
        pred_actions = self.action_decoder(model_out)[:, -actions.shape[1]:]
        pred_velocity = (pred_actions - noisy) / (1 - t).clamp_min(self.t_eps)

        loss = ((pred_velocity - velocity) ** 2).mean()

        if self.smoothness_loss_weight > 0 and pred_actions.shape[1] > 1:
            # Penalise large step-to-step jumps in the predicted trajectory.
            smooth_loss = ((pred_actions[:, 1:] - pred_actions[:, :-1]) ** 2).mean()
            loss = loss + self.smoothness_loss_weight * smooth_loss

        return loss

    # -----------------------------------------------------------------------
    # Predict (inference)
    # -----------------------------------------------------------------------

    @torch.no_grad()
    def predict_action(
        self,
        context_list: List[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]],
        state: Optional[torch.Tensor] = None,
        vlm_last_hidden_state: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Denoise from pure noise to a clean action rollout.

        Performs ``num_inference_timesteps`` Euler steps of the AML ODE.

        Args:
            context_list: Same format as :meth:`forward`.
            state: Optional proprioceptive state.
            vlm_last_hidden_state: Optional last-layer VLM hidden states.

        Returns:
            ``(B, action_horizon, action_dim)``
        """
        is_kv = self._is_kv_list(context_list)
        if is_kv:
            B, device = context_list[0][0].shape[0], context_list[0][0].device
        else:
            B, device = context_list[0].shape[0], context_list[0].device
        dtype = torch.float32

        # Start from (optionally scaled) Gaussian noise.
        actions = torch.randn(B, self.action_horizon, self.action_dim, dtype=dtype, device=device)
        if self.use_scaled_noise:
            actions = actions * self.noise_scale[None, None, :].to(dtype=dtype)

        state_features = None
        if state is not None and self.state_encoder is not None:
            state_features = self.state_encoder(state.to(device=device, dtype=dtype))

        dt = 1.0 / self.num_inference_timesteps
        for step in range(self.num_inference_timesteps):
            # t_cont < 1 for every step, so (1 - t_cont) below never hits zero.
            t_cont = step / float(self.num_inference_timesteps)
            t_disc_int = int(t_cont * self.num_timestep_buckets)
            t_tensor = torch.full((B,), t_disc_int, device=device, dtype=torch.long)

            action_features = self.action_encoder(actions, t_tensor)
            sa_embs = self._build_sa_embs(action_features, state_features, B, device)

            model_out = self._run_model(sa_embs, context_list, t_tensor, is_kv,
                                        vlm_last_hidden_state)

            pred_actions = self.action_decoder(model_out)[:, -self.action_horizon:]
            # AML: velocity derived from predicted actions
            pred_velocity = (pred_actions - actions) / (1.0 - t_cont)
            # Euler update.
            actions = actions + dt * pred_velocity

        return actions

    # -----------------------------------------------------------------------
    # Properties
    # -----------------------------------------------------------------------

    @property
    def device(self) -> torch.device:
        # Device of the first parameter; assumes all parameters are colocated.
        return next(iter(self.parameters())).device

    @property
    def dtype(self) -> torch.dtype:
        # Dtype of the first parameter; assumes uniform parameter dtype.
        return next(iter(self.parameters())).dtype
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
# ---------------------------------------------------------------------------
|
|
390
|
+
# Utilities
|
|
391
|
+
# ---------------------------------------------------------------------------
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def _expand_14d_to_20d(action_std_14: torch.Tensor) -> torch.Tensor:
|
|
395
|
+
"""Expand 14-D euler action statistics to 20-D r6d representation."""
|
|
396
|
+
out = torch.zeros(20, dtype=action_std_14.dtype)
|
|
397
|
+
out[0:3] = action_std_14[0:3]
|
|
398
|
+
out[3:9] = torch.repeat_interleave(action_std_14[3:6], 2)
|
|
399
|
+
out[9] = action_std_14[6]
|
|
400
|
+
out[10:13] = action_std_14[7:10]
|
|
401
|
+
out[13:19] = torch.repeat_interleave(action_std_14[10:13], 2)
|
|
402
|
+
out[19] = action_std_14[13]
|
|
403
|
+
return out
|