mimic-video 0.0.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mimic_video/__init__.py
ADDED
@@ -0,0 +1 @@

from mimic_video.mimic_video import MimicVideo
mimic_video/cosmos_predict.py
ADDED
@@ -0,0 +1,285 @@

from __future__ import annotations

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from torch.nn import Module
from torch import nn, Tensor
from einops import rearrange

from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel
from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos
from diffusers.schedulers.scheduling_edm_euler import EDMEulerScheduler
from transformers import T5EncoderModel, T5TokenizerFast, T5Config

# helpers

def exists(v):
    return v is not None

def identity(t):
    return t

def default(v, d):
    return v if exists(v) else d

# constants

TINY_TRANSFORMER_CONFIG = dict(
    in_channels = 16,
    out_channels = 16,
    num_attention_heads = 1,
    attention_head_dim = 16,
    mlp_ratio = 1.0,
    text_embed_dim = 32,
    adaln_lora_dim = 32,
    patch_size = (1, 2, 2),
    max_size = (4, 16, 16),
    extra_pos_embed_type = None,
    concat_padding_mask = False,
)

TINY_VAE_CONFIG = dict(
    in_channels = 3,
    out_channels = 3,
    latent_channels = 16,
    encoder_block_out_channels = (8, 16),
    decode_block_out_channels = (8, 16),
    temporal_compression_ratio = 4,
    spatial_compression_ratio = 4,
    num_layers = 1,
    attention_resolutions = (),
    resolution = 64,
)

TINY_T5_CONFIG = dict(
    vocab_size = 32128,
    d_model = 32,
    d_kv = 8,
    d_ff = 64,
    num_layers = 1,
    num_heads = 1,
)

REAL_TRANSFORMER_CONFIG = dict(
    in_channels = 16,
    out_channels = 16,
    num_attention_heads = 32,
    attention_head_dim = 128,
    mlp_ratio = 4.0,
    text_embed_dim = 1024,
    patch_size = (1, 2, 2),
    max_size = (128, 240, 240),
    extra_pos_embed_type = "learnable",
    concat_padding_mask = True,
)

REAL_VAE_CONFIG = dict(
    in_channels = 3,
    out_channels = 3,
    latent_channels = 16,
    encoder_block_out_channels = (128, 256, 512, 512),
    decode_block_out_channels = (256, 512, 512, 512),
    temporal_compression_ratio = 8,
    spatial_compression_ratio = 8,
)

REAL_T5_CONFIG = dict(
    vocab_size = 32128,
    d_model = 1024,
    d_kv = 64,
    d_ff = 2048,
    num_layers = 12,
    num_heads = 16,
)

# main class

class CosmosPredictWrapper(Module):
    """
    Wraps Cosmos VAE + DiT for extracting hidden states from a video.
    Supports proper EDM Euler denoising steps.
    """

    def __init__(
        self,
        model_name: str = 'nvidia/Cosmos-1.0-Diffusion-7B-Video2World',
        extract_layers: int | list[int] | None = None,
        random_weights: bool = False,
        tiny: bool = False,
        normalize = lambda t: (t - 0.5) * 2.0,
        extract_layer: int | None = None
    ):
        super().__init__()
        extract_layers = default(extract_layers, extract_layer)
        extract_layers = default(extract_layers, 19)

        self.extract_layers = [extract_layers] if isinstance(extract_layers, int) else extract_layers
        self.return_list = isinstance(extract_layers, list)

        self.hook_handles: list = []
        self.cached_hidden_states: list[Tensor] = []

        if random_weights:
            self._init_random_weights(tiny = tiny)
        else:
            self._init_pretrained(model_name)

        # Initialize scheduler
        self.scheduler = EDMEulerScheduler()

        # store hidden dim for consumers
        self.dim_latent = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim

        # maybe normalize
        self.normalize = normalize

        self._register_hook()

    @property
    def device(self):
        return next(self.parameters()).device

    def _init_pretrained(self, model_name: str):
        """Load pretrained weights from HuggingFace"""
        from diffusers import CosmosVideoToWorldPipeline

        pipeline = CosmosVideoToWorldPipeline.from_pretrained(model_name)

        # Extract components we need
        self.vae = pipeline.vae
        self.transformer = pipeline.transformer
        self.text_encoder = pipeline.text_encoder
        self.tokenizer = pipeline.tokenizer

        # Clean up pipeline
        del pipeline

    def _init_random_weights(self, tiny: bool = False):
        """Initialize with random weights for testing"""

        transformer_config = TINY_TRANSFORMER_CONFIG if tiny else REAL_TRANSFORMER_CONFIG
        vae_config = TINY_VAE_CONFIG if tiny else REAL_VAE_CONFIG
        t5_config_dict = TINY_T5_CONFIG if tiny else REAL_T5_CONFIG

        num_layers = max(2, *[layer + 1 for layer in self.extract_layers])
        if not tiny:
            num_layers = max(28, num_layers)

        self.transformer = CosmosTransformer3DModel(
            num_layers = num_layers,
            **transformer_config
        )

        self.vae = AutoencoderKLCosmos(**vae_config)

        t5_config = T5Config(**t5_config_dict)
        self.text_encoder = T5EncoderModel(t5_config)
        self.tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-small")

    def __del__(self):
        if not hasattr(self, 'hook_handles'):
            return

        for handle in self.hook_handles:
            handle.remove()

    def _register_hook(self):
        assert hasattr(self.transformer, 'transformer_blocks'), 'transformer must have transformer_blocks'

        for layer_index in self.extract_layers:
            assert len(self.transformer.transformer_blocks) > layer_index, f'layer {layer_index} out of bounds'

            target_layer = self.transformer.transformer_blocks[layer_index]

            def hook_fn(module, inp, out):
                self.cached_hidden_states.append(out.detach().cpu())

            handle = target_layer.register_forward_hook(hook_fn)
            self.hook_handles.append(handle)

    def forward(
        self,
        videos: Tensor,
        prompts: str | list[str] | None = None,
        prompt_token_ids: Tensor | None = None,
        num_inference_steps: int = 1,
    ) -> Tensor:
        """
        videos: (batch, frames, channels, height, width) in [0, 1]
        num_inference_steps: number of denoising steps to run
        returns: hidden states tensor from the specified transformer layer (from first step)
        """
        batch, t, c, h, w = videos.shape

        assert exists(prompts) ^ exists(prompt_token_ids)

        # Scale videos from [0, 1] to [-1, 1] for Cosmos VAE

        videos = self.normalize(videos)

        if isinstance(prompts, str):
            prompts = [prompts] * batch

        self.cached_hidden_states.clear()

        # Move video to device and rearrange for VAE: (B, T, C, H, W) -> (B, C, T, H, W)
        videos = rearrange(videos, 'b t c h w -> b c t h w')

        with torch.inference_mode():
            # 1. encode video to latents via VAE

            latents = self.vae.encode(videos).latent_dist.sample()

            # 2. maybe encode text prompts

            if exists(prompt_token_ids):
                text_inputs = dict(input_ids = prompt_token_ids)
            else:
                text_inputs = self.tokenizer(
                    prompts,
                    return_tensors = "pt",
                    padding = True,
                    truncation = True,
                    max_length = 512
                )

            encoder_hidden_states = self.text_encoder(**text_inputs).last_hidden_state

            # 3. Setup scheduler timesteps
            self.scheduler.set_timesteps(num_inference_steps, device = self.device)
            timesteps = self.scheduler.timesteps

            # 4. Add noise to latents (start from pure noise scaled by initial sigma)
            noise = torch.randn_like(latents)
            latents = latents + noise * self.scheduler.init_noise_sigma

            # 5. Denoising loop
            for i, timestep in enumerate(timesteps):
                # Scale model input
                latent_model_input = self.scheduler.scale_model_input(latents, timestep)

                # Predict noise residual
                noise_pred = self.transformer(
                    hidden_states = latent_model_input,
                    encoder_hidden_states = encoder_hidden_states,
                    timestep = timestep.expand(batch),
                    return_dict = False
                )[0]

                # Compute previous noisy sample
                latents = self.scheduler.step(noise_pred, timestep, latents, return_dict = False)[0]

        assert len(self.cached_hidden_states) >= len(self.extract_layers), 'hidden states not captured'

        # Return hidden states from the first denoising step
        hiddens = self.cached_hidden_states[:len(self.extract_layers)]

        for hidden in hiddens:
            assert hidden.shape[-1] == self.dim_latent, f'hidden dim mismatch: expected {self.dim_latent}, got {hidden.shape[-1]}'

        if not self.return_list:
            return hiddens[0]

        return hiddens
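The wrapper above can be exercised on its own. The following is an illustrative sketch, not part of the package: it mirrors the tiny, randomly initialized configuration used in the README further down (so no pretrained checkpoint is downloaded) and relies only on the constructor arguments and `forward` signature defined above.

```python
import torch
from mimic_video.cosmos_predict import CosmosPredictWrapper

# tiny random-weight config for a quick smoke test (no checkpoint download)
wrapper = CosmosPredictWrapper(
    extract_layer = 1,
    random_weights = True,
    tiny = True
)

# (batch, frames, channels, height, width), values in [0, 1]
video = torch.rand(1, 3, 3, 32, 32)

# one denoising step; the forward hook on transformer block 1 captures the hidden states
hiddens = wrapper(video, prompts = 'pick up the cup')

# feature dimension equals heads * head_dim of the wrapped transformer
assert hiddens.shape[-1] == wrapper.dim_latent
```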
mimic_video/mimic_video.py
ADDED
@@ -0,0 +1,797 @@

from __future__ import annotations
from functools import partial

import torch
from torch import nn, cat, stack, is_tensor, tensor
from torch.nn import Module, ModuleList, Linear, GRU

import torch.nn.functional as F

import einx
from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from x_mlps_pytorch import create_mlp

from tqdm import tqdm

from torch_einops_utils import (
    lens_to_mask,
    pad_left_ndim,
    align_dims_left,
    pad_at_dim,
    pack_with_inverse,
    masked_mean,
    tree_map_tensor
)

from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

# ein notation

# b - batch
# h - heads
# g - groups
# n - sequence
# i, j - sequence (source, target)
# d - feature dimension

# constants

LinearNoBias = partial(Linear, bias = False)

# functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def identity(t):
    return t

def divisible_by(num, den):
    return (num % den) == 0

# wrappers

def eval_no_grad(fn):
    def inner(*args, **kwargs):
        with torch.no_grad():
            fn.eval()
            return fn(*args, **kwargs)

    return inner

# tensor function

def cast_tensor(val, device = None):
    return tensor(val, device = device) if not is_tensor(val) else val

def max_neg_value(t):
    return -torch.finfo(t.dtype).max

def l2norm(t, eps = 1e-10):
    return F.normalize(t, dim = -1, eps = eps)

# token shift from Peng et al. of RWKV
# cheap way to generate relative positions

def shift_feature_dim(t):
    x, x_shift = t.chunk(2, dim = -1)
    x_shift = pad_at_dim(x_shift, (1, -1), dim = 1)
    return cat((x, x_shift), dim = -1)

# action normalization

class Normalizer(Module):
    def __init__(
        self,
        mean,
        std,
        eps = 1e-6
    ):
        super().__init__()
        assert (std > 0.).all(), 'std must be positive'
        self.eps = eps

        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def normalize(self, t):
        mean, std = self.mean, self.std
        return (t - mean) / std.clamp_min(self.eps)

    def inverse_normalize(self, t):
        mean, std = self.mean, self.std
        return (t * std) + mean

# time

# they follow p0's research finding with the beta distribution
# lets stick with 0 noise to 1 data instead of the reverse

def default_sample_time_fn(time, s = 0.999):
    return torch.sqrt(s - time)

class RandomFourierEmbed(Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Sequential(
            Rearrange('... -> ... 1'),
            nn.Linear(1, dim)
        )

        self.proj.requires_grad_(False)

    def forward(self, times):
        rand_proj = self.proj(times)
        return torch.cos(2 * torch.pi * rand_proj)

# adaptive rmsnorm

class AdaptiveRMSNorm(Module):
    def __init__(
        self,
        dim,
        dim_time_cond,
        eps = 1e-6,
        ada_ln_zero_bias = -5.
    ):
        super().__init__()
        self.scale = dim ** 0.5
        self.eps = eps

        self.to_modulation = LinearNoBias(dim_time_cond, dim * 3)
        self.split_modulation = Rearrange('... (three d) -> three ... d', three = 3)

        nn.init.zeros_(self.to_modulation.weight)

        self.ada_ln_zero_bias = ada_ln_zero_bias

    def forward(
        self,
        tokens,
        time_cond
    ):
        if time_cond.ndim == 2:
            time_cond = rearrange(time_cond, 'b d -> b 1 d')

        modulations = self.to_modulation(time_cond)

        scale, shift, gate = self.split_modulation(modulations)

        normed = l2norm(tokens, self.eps) * self.scale

        adaptive_normed = normed * (scale + 1.) + shift

        gate_with_bias = (gate + self.ada_ln_zero_bias).sigmoid()

        return adaptive_normed, gate_with_bias

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        *,
        dim_context = None,
        dim_head = 64,
        heads = 8,
        kv_heads = 2,
        attn_gate_value = True,
        norm_context = False
    ):
        super().__init__()
        dim_q_inner = dim_head * heads
        dim_kv_inner = dim_head * kv_heads

        dim_context = default(dim_context, dim)
        self.context_norm = nn.RMSNorm(dim_context) if norm_context else nn.Identity()

        self.scale = dim_head ** -0.5

        self.to_queries = LinearNoBias(dim, dim_q_inner)
        self.to_keys_values = LinearNoBias(dim_context, dim_kv_inner * 2)

        self.attn_gate_value = nn.Sequential(LinearNoBias(dim, heads), Rearrange('b n (g h) -> b g h n 1', h = kv_heads))

        self.to_out = LinearNoBias(dim_q_inner, dim)

        assert divisible_by(heads, kv_heads)
        groups = heads // kv_heads

        self.split_q_heads = Rearrange('b n (g h d) -> b g h n d', g = groups, d = dim_head)
        self.split_kv_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
        self.merge_heads = Rearrange('b g h n d -> b n (g h d)')

    def forward(
        self,
        tokens,
        context = None,
        context_mask = None,
        kv = None,
        return_kv = False
    ):
        context = default(context, tokens)

        queries = self.to_queries(tokens)
        queries = self.split_q_heads(queries)

        if not exists(kv):
            context = self.context_norm(context)

            keys, values = self.to_keys_values(context).chunk(2, dim = -1)
            keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
        else:
            keys, values = kv

        queries = queries * self.scale

        sim = einsum(queries, keys, 'b g h i d, b h j d -> b g h i j')

        if exists(context_mask):
            mask_value = max_neg_value(sim)
            sim = einx.where('b j, b g h i j,', context_mask, sim, mask_value)

        attn = sim.softmax(dim = -1)

        out = einsum(attn, values, 'b g h i j, b h j d -> b g h i d')

        # https://openreview.net/forum?id=1b7whO4SfY - should become standard practice

        out = out * self.attn_gate_value(tokens).sigmoid()

        out = self.merge_heads(out)

        out = self.to_out(out)

        if not return_kv:
            return out

        return out, stack((keys, values))

# feedforward

class SwiGLUFeedForward(Module):
    def __init__(
        self,
        dim,
        *,
        expansion_factor = 4.,
    ):
        super().__init__()
        dim_inner = int(dim * expansion_factor * 2 / 3)

        self.proj_in = nn.Linear(dim, dim_inner * 2)
        self.proj_out = nn.Linear(dim_inner, dim)

    def forward(
        self,
        tokens
    ):
        hidden, gates = self.proj_in(tokens).chunk(2, dim = -1)

        out = hidden * F.gelu(gates)

        return self.proj_out(out)

# classes

class MimicVideo(Module):
    def __init__(
        self,
        dim,
        video_predict_wrapper: Module | None = None,
        *,
        dim_video_hidden = None,
        action_chunk_len = 32,
        dim_action = 20,
        dim_joint_state = 32,
        proprio_mask_prob = 0.1,
        depth = 8,
        dim_head = 64,
        heads = 8,
        expansion_factor = 4.,
        ada_ln_zero_bias = -5.,
        dim_time_cond = None,
        sample_time_fn = None,
        train_time_rtc = False,
        train_time_rtc_max_delay = None,
        num_residual_streams = 1,
        mhc_kwargs: dict = dict(),
        action_mean_std: Tensor | None = None,
        joint_mean_std: Tensor | None = None,
        num_task_ids = 0,
        num_advantage_ids = 0,
        advantage_cfg_dropout = 0.25,
        extracted_video_layer_indices: list[int] | None = None
    ):
        super().__init__()

        self.depth = depth

        # maybe video predict

        self.video_predict_wrapper = video_predict_wrapper

        # action related

        self.action_chunk_len = action_chunk_len
        self.dim_action = dim_action

        self.action_shape = (action_chunk_len, dim_action)

        self.action_normalizer = None

        if exists(action_mean_std):
            assert action_mean_std.shape == (2, dim_action), f'must be in shape of (2 action_dim)'
            self.action_normalizer = Normalizer(*action_mean_std)

        # joint dim

        self.dim_joint_state = dim_joint_state

        dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)

        assert exists(dim_video_hidden), f'`dim_video_hidden` must be set or `video_predict_wrapper` passed in with `dim_latent`'

        self.dim_video_hidden = dim_video_hidden

        self.joint_normalizer = None

        if exists(joint_mean_std):
            assert joint_mean_std == (2, dim_joint_state)
            self.joint_normalizer = Normalizer(*joint_mean_std)

        # flow related

        self.sample_time_fn = default(sample_time_fn, default_sample_time_fn)

        # embed

        self.to_action_tokens = Linear(dim_action, dim)

        dim_time_cond = default(dim_time_cond, dim * 2)

        self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
        self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())

        # joint token related

        self.to_joint_state_token = Linear(dim_joint_state, dim)

        self.proprio_mask_prob = proprio_mask_prob
        self.has_proprio_masking = proprio_mask_prob > 0.

        self.proprio_mask_token = nn.Parameter(torch.randn(dim))

        # video norm

        self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)

        # manifold constrained hyper connections (mHC) from bytedance + deepseek

        init_hyper_conn, self.expand_stream, self.reduce_stream = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, **mhc_kwargs)

        # rnn

        self.rnn = GRU(dim, dim)

        # transformer

        layers = []

        for _ in range(depth):
            attn_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond)

            attn = Attention(dim = dim, dim_head = dim_head, heads = heads)

            cross_attn_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond)

            cross_attn = Attention(dim = dim, dim_head = dim_head, dim_context = dim_video_hidden, heads = heads, norm_context = True)

            ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond, ada_ln_zero_bias = ada_ln_zero_bias)

            ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)

            # maybe hyper connect

            attn_residual = init_hyper_conn()
            cross_attn_residual = init_hyper_conn()
            ff_residual = init_hyper_conn()

            layers.append(ModuleList([
                cross_attn_residual,
                cross_attn_adanorm,
                cross_attn,
                attn_residual,
                attn_adanorm,
                attn,
                ff_residual,
                ff_adanorm,
                ff
            ]))

        self.layers = ModuleList(layers)

        # predictions

        self.to_pred_action_flow = nn.Sequential(
            nn.RMSNorm(dim),
            Linear(dim, dim_action, bias = False)
        )

        # inference related

        # train time RTC related - https://arxiv.org/abs/2512.05964

        self.train_time_rtc = train_time_rtc

        assert not train_time_rtc or exists(train_time_rtc_max_delay)
        self.train_time_rtc_max_delay = train_time_rtc_max_delay

        # condition related

        self.task_embed = nn.Embedding(num_task_ids, dim) if num_task_ids > 0 else None

        self.advantage_embed = nn.Embedding(num_advantage_ids + 1, dim) if num_advantage_ids > 0 else None

        assert advantage_cfg_dropout > 0.

        self.advantage_cfg_dropout = advantage_cfg_dropout

        # allow for researchers to explore beyond just one layer of pretrained
        # we should also open up research into multiple pretrained models eventually

        self.extracted_video_layer_indices = default(extracted_video_layer_indices, (0,) * depth)
        assert len(self.extracted_video_layer_indices) == depth

        # aux loss and device

        self.register_buffer('zero', tensor(0.), persistent = False)

    # only action parameters

    def action_parameters(self):
        video_model_params = set(self.video_predict_wrapper.parameters()) if exists(self.video_predict_wrapper) else {}
        return set(self.parameters()) - video_model_params

    @property
    def device(self):
        return self.zero.device

    @torch.no_grad()
    def sample(
        self,
        steps = 16,
        batch_size = 1,
        prefix_action_chunk = None,
        disable_progress_bar = False,
        **kwargs
    ):

        self.eval()

        inpainting = exists(prefix_action_chunk)

        # times

        times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
        delta = 1. / steps

        # inpaint

        if inpainting:
            prefix_len = prefix_action_chunk.shape[1]
            assert prefix_len < self.action_chunk_len

            maybe_normed_prefix = prefix_action_chunk

            if exists(self.action_normalizer):
                maybe_normed_prefix = self.action_normalizer.normalize(prefix_action_chunk)

            times = repeat(times, 'steps -> steps b n', b = batch_size, n = self.action_chunk_len).clone()
            times[..., :prefix_len] = 1.

        # noise

        noise = torch.randn((batch_size, *self.action_shape), device = self.device)

        # denoised action starts as noise

        denoised = noise

        cache = None

        # denoise

        for time in tqdm(times, disable = disable_progress_bar):

            if inpainting:
                denoised[:, :prefix_len] = maybe_normed_prefix

            pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)

            denoised = denoised + delta * pred_flow

        # handle action inverse norm

        if exists(self.action_normalizer):
            denoised = self.action_normalizer.inverse_normalize(denoised)

        # final set, with unnormalized prefix, if inpainting

        if inpainting:
            denoised[:, :prefix_len] = prefix_action_chunk

        return denoised

    def forward(
        self,
        *,
        actions,                   # (b na d)
        joint_state,               # (b)
        task_ids = None,           # (b)
        advantage_ids = None,      # (b)
        video = None,              # (b c t h w)
        video_hiddens = None,      # (b nv dv) - they use layer 19 of cosmos predict, at first denoising step. that's all
        context_mask = None,
        time = None,               # () | (b) | (b n)
        time_video_denoise = 0.,   # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
        prompts: list[str] | None = None,
        prompt_token_ids = None,
        detach_video_hiddens = False,
        no_grad_video_model_forward = False,
        cache = None,
        return_cache = False,
        return_flow = False
    ):
        assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
        assert actions.shape[-2:] == self.action_shape

        if exists(self.action_normalizer):
            actions = self.action_normalizer.normalize(actions)

        batch, device = actions.shape[0], actions.device
        orig_actions = actions

        is_training = not exists(time) and not return_flow

        if not exists(cache):
            # handle maybe extraction of video hiddens
            # only if cache is not given

            assert exists(video) ^ exists(video_hiddens)

            if not exists(video_hiddens):
                assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'

                video_forward_wrap = eval_no_grad if no_grad_video_model_forward else identity

                video_hiddens = video_forward_wrap(self.video_predict_wrapper)(video, prompts = prompts, prompt_token_ids = prompt_token_ids)

            video_hiddens = tree_map_tensor(lambda t: t.to(self.device).float(), video_hiddens) # maybe bfloat to float32

            video_hiddens = tree_map_tensor(lambda t: pack_with_inverse(t, 'b * d')[0], video_hiddens)

            # handle video hiddens

            if detach_video_hiddens:
                video_hiddens = video_hiddens.detach()

            if not isinstance(video_hiddens, list):
                video_hiddens = [video_hiddens]

        # handle caching

        prev_cached_video_hiddens_kv = cache if exists(cache) else ((None,) * self.depth)
        next_cached_video_hiddens_kv = []

        # handle flow time conditioning

        if is_training:
            time = torch.rand((batch,), device = device)
            time = self.sample_time_fn(time)

            noise = torch.randn_like(actions)
            flow = actions - noise

            actions, left_aligned_time = align_dims_left((actions, time))

            noised = noise.lerp(actions, left_aligned_time)

        else:
            noised = actions

        # maybe train time rtc

        action_loss_mask = None

        if is_training and self.train_time_rtc:

            rand_prefix_len = torch.randint(0, self.train_time_rtc_max_delay, (batch,), device = device)
            action_prefix_mask = lens_to_mask(rand_prefix_len, self.action_chunk_len)

            actions = einx.where('b na, b na d, b na d', action_prefix_mask, orig_actions, actions)
            time = einx.where('b na, , b', action_prefix_mask, 1., time)

            action_loss_mask = ~action_prefix_mask

        if time.ndim == 0:
            time = repeat(time, '-> b', b = batch)

        # handle the video denoising times

        time_video_denoise = cast_tensor(time_video_denoise)

        if time_video_denoise.ndim == 0:
            time_video_denoise = rearrange(time_video_denoise, '-> 1')

        if time_video_denoise.shape[0] != batch:
            time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)

        if time.ndim == 2:
            time_video_denoise = repeat(time_video_denoise, 'b -> b n', n = time.shape[-1])

        times = stack((time, time_video_denoise), dim = -1)

        # embed

        tokens = self.to_action_tokens(noised)

        # setup empty tokens for various packed condition tokens

        empty_token = tokens[:, 0:0]

        # one layer of rnn for actions

        rnn_out, _, = self.rnn(tokens)
        tokens = rnn_out + tokens

        # mask joint state token for proprioception masking training

        if exists(self.joint_normalizer):
            joint_state = self.joint_normalizer.normalize(joint_state)

        joint_state_token = self.to_joint_state_token(joint_state)

        if self.training and self.has_proprio_masking:
            mask = torch.rand((batch,), device = device) < self.proprio_mask_prob

            joint_state_token = einx.where('b, d, b d', mask, self.proprio_mask_token, joint_state_token)

        # setup task

        task_embed = empty_token

        if exists(task_ids):
            assert exists(self.task_embed)
            task_embed = self.task_embed(task_ids)

        # setup maybe advantage

        advantage_embed = empty_token

        if exists(advantage_ids):
            assert exists(self.advantage_embed)

            advantage_ids = advantage_ids + 1 # 0 for dropout
            cfg_dropout = torch.rand_like(advantage_ids.float()) < self.advantage_cfg_dropout

            advantage_ids = einx.where('b, , b', cfg_dropout, 0, advantage_ids)

            advantage_embed = self.advantage_embed(advantage_ids)

        # determine time - need to handle the sequence dimension given train time RTC and various conditioning tokens

        if times.ndim == 3:
            joint_task_advantage_times = 1 + int(exists(advantage_ids)) + int(exists(task_ids))

            times = pad_at_dim(times, (joint_task_advantage_times, 0), dim = 1, value = 1.) # handle joint state token on the action

        # fourier embed and mlp to time condition

        fourier_embed = self.to_fourier_embed(times)

        fourier_embed = rearrange(fourier_embed, '... times d -> ... (times d)')

        time_cond = self.to_time_cond(fourier_embed)

        # pack with action tokens for attention tower

        tokens, inverse_pack = pack_with_inverse((advantage_embed, task_embed, joint_state_token, tokens), 'b * d')

        # maybe expand streams

        tokens = self.expand_stream(tokens)

        # transformer layers

        for ((
            maybe_cross_attn_mhc,
            cross_attn_norm,
            cross_attn,
            maybe_attn_mhc,
            attn_norm,
            attn,
            maybe_ff_mhc,
            ff_norm,
            ff
        ), layer_video_hidden_index, cached_video_kv) in zip(self.layers, self.extracted_video_layer_indices, prev_cached_video_hiddens_kv):

            # cross attention

            tokens, add_residual = maybe_cross_attn_mhc(tokens)

            tokens, gate = cross_attn_norm(tokens, time_cond)

            layer_video_hidden = None

            if exists(video_hiddens):
                layer_video_hidden = video_hiddens[layer_video_hidden_index]

            cross_attn_out, video_kv = cross_attn(tokens, context = layer_video_hidden, context_mask = context_mask, kv = cached_video_kv, return_kv = True)

            tokens = add_residual(cross_attn_out * gate)

            if return_cache:
                next_cached_video_hiddens_kv.append(video_kv)

            # self attention

            tokens, add_residual = maybe_attn_mhc(tokens)

            tokens, gate = attn_norm(tokens, time_cond)

            tokens = add_residual(attn(tokens) * gate)

            # prepare feedforward

            tokens, add_residual = maybe_ff_mhc(tokens)

            tokens, gate = ff_norm(tokens, time_cond)

            # shift along time for action tokens for cheap relative positioning, which is better than messing with rope with such short action chunks

            *non_action_tokens, tokens = inverse_pack(tokens)

            tokens = shift_feature_dim(tokens)

            tokens, _ = pack_with_inverse((*non_action_tokens, tokens), 'b * d')

            # feedforward

            tokens = add_residual(ff(tokens) * gate)

        # maybe reduce streams

        tokens = self.reduce_stream(tokens)

        # remove joint token

        *_, tokens = inverse_pack(tokens)

        # prediction

        pred_flow = self.to_pred_action_flow(tokens)

        if not is_training:
            # flow

            out = pred_flow
        else:
            # mse flow loss

            flow_loss = F.mse_loss(pred_flow, flow, reduction = 'none')

            out = masked_mean(flow_loss, action_loss_mask)

        if not return_cache:
            return out

        # handle returning of cache

        return out, next_cached_video_hiddens_kv
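Because `forward` accepts precomputed `video_hiddens` in place of raw video, the policy above can also be trained with no video model attached, provided `dim_video_hidden` is given. The following is a minimal sketch, not part of the package, with assumed (hypothetical) dimensions for the cached features:

```python
import torch
from mimic_video import MimicVideo

# no video_predict_wrapper - cached video hidden states are fed in directly
model = MimicVideo(512, dim_video_hidden = 16)

actions = torch.randn(2, 32, 20)        # (b, action_chunk_len, dim_action)
joint_state = torch.randn(2, 32)        # (b, dim_joint_state)
video_hiddens = torch.randn(2, 64, 16)  # (b, nv, dv) - assumed shape for precomputed video features

loss = model(
    actions = actions,
    joint_state = joint_state,
    video_hiddens = video_hiddens
)

loss.backward()
```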
mimic_video-0.0.31.dist-info/METADATA
ADDED
@@ -0,0 +1,174 @@

Metadata-Version: 2.4
Name: mimic-video
Version: 0.0.31
Summary: Mimic Video
Project-URL: Homepage, https://pypi.org/project/mimic-video/
Project-URL: Repository, https://github.com/lucidrains/mimic-video
Author-email: Phil Wang <lucidrains@gmail.com>
License: MIT License

Copyright (c) 2025 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
License-File: LICENSE
Keywords: artificial intelligence,attention mechanism,deep learning,video language action model
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.10
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: einops>=0.8.1
Requires-Dist: einx>=0.3.0
Requires-Dist: hyper-connections>=0.4.3
Requires-Dist: torch-einops-utils>=0.0.14
Requires-Dist: torch>=2.5
Requires-Dist: tqdm
Requires-Dist: x-mlps-pytorch
Provides-Extra: examples
Provides-Extra: test
Requires-Dist: accelerate; extra == 'test'
Requires-Dist: diffusers>=0.32.0; extra == 'test'
Requires-Dist: pytest; extra == 'test'
Requires-Dist: transformers; extra == 'test'
Description-Content-Type: text/markdown

<img src="./mimic-video.png" width="450px"></img>

## Mimic Video (wip)

Implementation of [Mimic-Video](https://mimic-video.github.io/), Video-Action Models for Generalizable Robot Control Beyond VLAs

## Appreciation

- [Pranoy](https://github.com/pranoyr) for submitting a pull request for proprioception masking

## Install

```shell
$ pip install mimic-video
```

## Usage

```python
import torch

# video wrapper
# but will be agnostic to the model

from mimic_video.cosmos_predict import CosmosPredictWrapper

video_wrapper = CosmosPredictWrapper(
    extract_layer = 1,
    random_weights = True,
    tiny = True
)

# mimic video

from mimic_video import MimicVideo

model = MimicVideo(512, video_wrapper)

# states

video = torch.rand(2, 3, 3, 32, 32)

joint_state = torch.randn(2, 32)

# action

actions = torch.randn(2, 32, 20)

# training

loss = model(
    prompts = [
        'put the package on the conveyer belt',
        'pass the butter'
    ],
    video = video,
    actions = actions,
    joint_state = joint_state
)

loss.backward()

# inference

actions = model.sample(
    prompts = 'peel the orange',
    video = video[:1],
    joint_state = joint_state[:1]
)

assert actions.shape == (1, 32, 20)
```
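`sample` also accepts a `prefix_action_chunk` for conditioning on actions that are already committed (the real-time chunking setup cited below). The snippet here is an illustrative sketch, not from the package README; it reuses `model`, `video` and `joint_state` from the example above, with a hypothetical 8-step prefix.

```python
# hypothetical: the first 8 actions of the chunk are already fixed
prefix = torch.randn(1, 8, 20)

next_actions = model.sample(
    prompts = 'peel the orange',
    video = video[:1],
    joint_state = joint_state[:1],
    prefix_action_chunk = prefix
)

assert next_actions.shape == (1, 32, 20)
```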

## Contributing

First make sure `pytest` and test dependencies are installed with

```shell
$ pip install '.[test]'
```

Then add your test to `tests/test_mimic_video.py` and run

```shell
$ pytest tests
```

That's it

## Citations

```bibtex
@inproceedings{Pai2025mimicvideoVM,
    title = {mimic-video: Video-Action Models for Generalizable Robot Control Beyond VLAs},
    author = {Jonas Pai and Liam Achenbach and Victoriano Montesinos and Benedek Forrai and Oier Mees and Elvis Nava},
    year = {2025},
    url = {https://api.semanticscholar.org/CorpusID:283920528}
}
```

```bibtex
@misc{black2025trainingtimeactionconditioningefficient,
    title = {Training-Time Action Conditioning for Efficient Real-Time Chunking},
    author = {Kevin Black and Allen Z. Ren and Michael Equi and Sergey Levine},
    year = {2025},
    eprint = {2512.05964},
    archivePrefix = {arXiv},
    primaryClass = {cs.RO},
    url = {https://arxiv.org/abs/2512.05964},
}
```

```bibtex
@misc{intelligence2025pi06vlalearnsexperience,
    title = {$\pi^{*}_{0.6}$: a VLA That Learns From Experience},
    author = {Physical Intelligence and Ali Amin and Raichelle Aniceto and Ashwin Balakrishna and Kevin Black and Ken Conley and Grace Connors and James Darpinian and Karan Dhabalia and Jared DiCarlo and Danny Driess and Michael Equi and Adnan Esmail and Yunhao Fang and Chelsea Finn and Catherine Glossop and Thomas Godden and Ivan Goryachev and Lachy Groom and Hunter Hancock and Karol Hausman and Gashon Hussein and Brian Ichter and Szymon Jakubczak and Rowan Jen and Tim Jones and Ben Katz and Liyiming Ke and Chandra Kuchi and Marinda Lamb and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Yao Lu and Vishnu Mano and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Charvi Sharma and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and Will Stoeckle and Alex Swerdlow and James Tanner and Marcel Torne and Quan Vuong and Anna Walling and Haohuan Wang and Blake Williams and Sukwon Yoo and Lili Yu and Ury Zhilinsky and Zhiyuan Zhou},
    year = {2025},
    eprint = {2511.14759},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url = {https://arxiv.org/abs/2511.14759},
}
```
mimic_video-0.0.31.dist-info/RECORD
ADDED
@@ -0,0 +1,7 @@

mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
mimic_video/cosmos_predict.py,sha256=TJYIBFY2OGMAb1-5U5AOYZPbUKx-CuBIN3qbnpVWj4k,9078
mimic_video/mimic_video.py,sha256=9jMkWzHu9qxVN5Uez8jTgErsWyaiJz9r-FVhPpf7Fbg,23165
mimic_video-0.0.31.dist-info/METADATA,sha256=icex90WVH5NJhaZVUJ-2NI3X3qDHZhmKBxxsoe20kTQ,5876
mimic_video-0.0.31.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
mimic_video-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
mimic_video-0.0.31.dist-info/RECORD,,
mimic_video-0.0.31.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@

MIT License

Copyright (c) 2025 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.