mimic-video 0.0.1__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mimic-video might be problematic.
- mimic_video/__init__.py +0 -1
- mimic_video/cosmos_predict.py +269 -0
- mimic_video/mimic_video.py +237 -35
- {mimic_video-0.0.1.dist-info → mimic_video-0.0.19.dist-info}/METADATA +72 -1
- mimic_video-0.0.19.dist-info/RECORD +7 -0
- mimic_video-0.0.1.dist-info/RECORD +0 -6
- {mimic_video-0.0.1.dist-info → mimic_video-0.0.19.dist-info}/WHEEL +0 -0
- {mimic_video-0.0.1.dist-info → mimic_video-0.0.19.dist-info}/licenses/LICENSE +0 -0
mimic_video/cosmos_predict.py
ADDED

```python
from __future__ import annotations

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from torch.nn import Module
from torch import nn, Tensor
from einops import rearrange

from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel
from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos
from diffusers.schedulers.scheduling_edm_euler import EDMEulerScheduler
from transformers import T5EncoderModel, T5TokenizerFast, T5Config

# helpers

def exists(v):
    return v is not None

def identity(t):
    return t

def default(v, d):
    return v if exists(v) else d

# constants

TINY_TRANSFORMER_CONFIG = dict(
    in_channels = 16,
    out_channels = 16,
    num_attention_heads = 1,
    attention_head_dim = 16,
    mlp_ratio = 1.0,
    text_embed_dim = 32,
    adaln_lora_dim = 32,
    patch_size = (1, 2, 2),
    max_size = (4, 16, 16),
    extra_pos_embed_type = None,
    concat_padding_mask = False,
)

TINY_VAE_CONFIG = dict(
    in_channels = 3,
    out_channels = 3,
    latent_channels = 16,
    encoder_block_out_channels = (8, 16),
    decode_block_out_channels = (8, 16),
    temporal_compression_ratio = 4,
    spatial_compression_ratio = 4,
    num_layers = 1,
    attention_resolutions = (),
    resolution = 64,
)

TINY_T5_CONFIG = dict(
    vocab_size = 32128,
    d_model = 32,
    d_kv = 8,
    d_ff = 64,
    num_layers = 1,
    num_heads = 1,
)

REAL_TRANSFORMER_CONFIG = dict(
    in_channels = 16,
    out_channels = 16,
    num_attention_heads = 32,
    attention_head_dim = 128,
    mlp_ratio = 4.0,
    text_embed_dim = 1024,
    patch_size = (1, 2, 2),
    max_size = (128, 240, 240),
    extra_pos_embed_type = "learnable",
    concat_padding_mask = True,
)

REAL_VAE_CONFIG = dict(
    in_channels = 3,
    out_channels = 3,
    latent_channels = 16,
    encoder_block_out_channels = (128, 256, 512, 512),
    decode_block_out_channels = (256, 512, 512, 512),
    temporal_compression_ratio = 8,
    spatial_compression_ratio = 8,
)

REAL_T5_CONFIG = dict(
    vocab_size = 32128,
    d_model = 1024,
    d_kv = 64,
    d_ff = 2048,
    num_layers = 12,
    num_heads = 16,
)

# main class

class CosmosPredictWrapper(Module):
    """
    Wraps Cosmos VAE + DiT for extracting hidden states from a video.
    Supports proper EDM Euler denoising steps.
    """

    def __init__(
        self,
        model_name: str = 'nvidia/Cosmos-1.0-Diffusion-7B-Video2World',
        extract_layer: int = 19,
        random_weights: bool = False,
        tiny: bool = False,
        normalize = lambda t: (t - 0.5) * 2.0
    ):
        super().__init__()
        self.extract_layer = extract_layer
        self.hook_handle = None
        self.cached_hidden_states: list[Tensor] = []

        if random_weights:
            self._init_random_weights(tiny = tiny)
        else:
            self._init_pretrained(model_name)

        # Initialize scheduler
        self.scheduler = EDMEulerScheduler()

        # store hidden dim for consumers
        self.dim_latent = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim

        # maybe normalize
        self.normalize = normalize

        self._register_hook()

    @property
    def device(self):
        return next(self.parameters()).device

    def _init_pretrained(self, model_name: str):
        """Load pretrained weights from HuggingFace"""
        from diffusers import CosmosVideoToWorldPipeline

        pipeline = CosmosVideoToWorldPipeline.from_pretrained(model_name)

        # Extract components we need
        self.vae = pipeline.vae
        self.transformer = pipeline.transformer
        self.text_encoder = pipeline.text_encoder
        self.tokenizer = pipeline.tokenizer

        # Clean up pipeline
        del pipeline

    def _init_random_weights(self, tiny: bool = False):
        """Initialize with random weights for testing"""

        transformer_config = TINY_TRANSFORMER_CONFIG if tiny else REAL_TRANSFORMER_CONFIG
        vae_config = TINY_VAE_CONFIG if tiny else REAL_VAE_CONFIG
        t5_config_dict = TINY_T5_CONFIG if tiny else REAL_T5_CONFIG

        num_layers = max(2, self.extract_layer + 1)
        if not tiny:
            num_layers = max(28, num_layers)

        self.transformer = CosmosTransformer3DModel(
            num_layers = num_layers,
            **transformer_config
        )

        self.vae = AutoencoderKLCosmos(**vae_config)

        t5_config = T5Config(**t5_config_dict)
        self.text_encoder = T5EncoderModel(t5_config)
        self.tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-small")

    def __del__(self):
        if exists(self.hook_handle):
            self.hook_handle.remove()

    def _register_hook(self):
        assert hasattr(self.transformer, 'transformer_blocks'), 'transformer must have transformer_blocks'
        assert len(self.transformer.transformer_blocks) > self.extract_layer, f'layer {self.extract_layer} out of bounds'

        target_layer = self.transformer.transformer_blocks[self.extract_layer]

        def hook_fn(module, inp, out):
            self.cached_hidden_states.append(out.detach().cpu())

        self.hook_handle = target_layer.register_forward_hook(hook_fn)

    def forward(
        self,
        videos: Tensor,
        prompts: str | list[str] | None = None,
        prompt_token_ids: Tensor | None = None,
        num_inference_steps: int = 1,
    ) -> Tensor:
        """
        videos: (batch, frames, channels, height, width) in [0, 1]
        num_inference_steps: number of denoising steps to run
        returns: hidden states tensor from the specified transformer layer (from first step)
        """
        batch, t, c, h, w = videos.shape

        assert exists(prompts) ^ exists(prompt_token_ids)

        # Scale videos from [0, 1] to [-1, 1] for Cosmos VAE

        videos = self.normalize(videos)

        if isinstance(prompts, str):
            prompts = [prompts] * batch

        self.cached_hidden_states.clear()

        # Move video to device and rearrange for VAE: (B, T, C, H, W) -> (B, C, T, H, W)
        videos = rearrange(videos, 'b t c h w -> b c t h w')

        with torch.inference_mode():
            # 1. encode video to latents via VAE

            latents = self.vae.encode(videos).latent_dist.sample()

            # 2. maybe encode text prompts

            if exists(prompt_token_ids):
                text_inputs = dict(input_ids = prompt_token_ids)
            else:
                text_inputs = self.tokenizer(
                    prompts,
                    return_tensors = "pt",
                    padding = True,
                    truncation = True,
                    max_length = 512
                )

            encoder_hidden_states = self.text_encoder(**text_inputs).last_hidden_state

            # 3. Setup scheduler timesteps
            self.scheduler.set_timesteps(num_inference_steps, device = self.device)
            timesteps = self.scheduler.timesteps

            # 4. Add noise to latents (start from pure noise scaled by initial sigma)
            noise = torch.randn_like(latents)
            latents = latents + noise * self.scheduler.init_noise_sigma

            # 5. Denoising loop
            for i, timestep in enumerate(timesteps):
                # Scale model input
                latent_model_input = self.scheduler.scale_model_input(latents, timestep)

                # Predict noise residual
                noise_pred = self.transformer(
                    hidden_states = latent_model_input,
                    encoder_hidden_states = encoder_hidden_states,
                    timestep = timestep.expand(batch),
                    return_dict = False
                )[0]

                # Compute previous noisy sample
                latents = self.scheduler.step(noise_pred, timestep, latents, return_dict = False)[0]

        assert len(self.cached_hidden_states) > 0, 'hidden states not captured'

        # Return hidden states from the first denoising step
        hidden = self.cached_hidden_states[0]

        assert hidden.shape[-1] == self.dim_latent, f'hidden dim mismatch: expected {self.dim_latent}, got {hidden.shape[-1]}'

        return hidden
```
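
The layer extraction in `CosmosPredictWrapper` is a standard PyTorch forward hook on `transformer_blocks[extract_layer]`. A minimal self-contained sketch of the same pattern, using a hypothetical toy stack of linear layers in place of the Cosmos transformer:

```python
import torch
from torch import nn

# toy stand-in for a stack of transformer blocks (hypothetical, for illustration only)
blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])

captured = []

def hook_fn(module, inp, out):
    # detach and move to cpu, mirroring the hook in CosmosPredictWrapper
    captured.append(out.detach().cpu())

# capture the output of block index 2, analogous to `extract_layer`
handle = blocks[2].register_forward_hook(hook_fn)

x = torch.randn(3, 16)
_ = blocks(x)

assert captured[0].shape == (3, 16)

handle.remove()  # remove the hook when done, as __del__ does above
```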
mimic_video/mimic_video.py
CHANGED

```diff
@@ -1,18 +1,24 @@
+from __future__ import annotations
+
 import torch
-from torch import nn
-from torch.nn import Module, ModuleList, Linear
+from torch import nn, cat, stack, is_tensor, tensor
+from torch.nn import Module, ModuleList, Linear, GRU

 import torch.nn.functional as F

 import einx
-from einops import einsum, rearrange
+from einops import einsum, rearrange, repeat
 from einops.layers.torch import Rearrange

 from x_mlps_pytorch import create_mlp

+from tqdm import tqdm
+
 from torch_einops_utils import (
     pad_left_ndim,
-    align_dims_left
+    align_dims_left,
+    pad_at_dim,
+    pack_with_inverse,
 )

 # ein notation
@@ -37,12 +43,23 @@ def divisible_by(num, den):

 # tensor function

+def cast_tensor(val, device = None):
+    return tensor(val, device = device) if not is_tensor(val) else val
+
 def max_neg_value(t):
     return -torch.finfo(t.dtype).max

 def l2norm(t, eps = 1e-10):
     return F.normalize(t, dim = -1, eps = eps)

+# token shift from Peng et al. of RWKV
+# cheap way to generate relative positions
+
+def shift_feature_dim(t):
+    x, x_shift = t.chunk(2, dim = -1)
+    x_shift = pad_at_dim(x_shift, (1, -1), dim = 1)
+    return cat((x, x_shift), dim = -1)
+
 # time

 # they follow p0's research finding with the beta distribution
@@ -72,7 +89,8 @@ class AdaptiveRMSNorm(Module):
         self,
         dim,
         dim_time_cond,
-        eps = 1e-6
+        eps = 1e-6,
+        ada_ln_zero_bias = -5.
     ):
         super().__init__()
         self.scale = dim ** 0.5
@@ -83,6 +101,8 @@ class AdaptiveRMSNorm(Module):

         nn.init.zeros_(self.to_modulation.weight)

+        self.ada_ln_zero_bias = ada_ln_zero_bias
+
     def forward(
         self,
         tokens,
@@ -100,7 +120,9 @@ class AdaptiveRMSNorm(Module):

         adaptive_normed = normed * (scale + 1.) + shift

-
+        gate_with_bias = (gate + self.ada_ln_zero_bias).sigmoid()
+
+        return adaptive_normed, gate_with_bias

 # attention

@@ -136,15 +158,20 @@ class Attention(Module):
         self,
         tokens,
         context = None,
-        context_mask = None
+        context_mask = None,
+        kv = None,
+        return_kv = False
     ):
         context = default(context, tokens)

         queries = self.to_queries(tokens)
-        keys, values = self.to_keys_values(context).chunk(2, dim = -1)
-
         queries = self.split_q_heads(queries)
-
+
+        if not exists(kv):
+            keys, values = self.to_keys_values(context).chunk(2, dim = -1)
+            keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
+        else:
+            keys, values = kv

         queries = queries * self.scale

@@ -160,7 +187,12 @@ class Attention(Module):

         out = self.merge_heads(out)

-
+        out = self.to_out(out)
+
+        if not return_kv:
+            return out
+
+        return out, stack((keys, values))

 # feedforward

@@ -193,18 +225,43 @@ class MimicVideo(Module):
     def __init__(
         self,
         dim,
+        video_predict_wrapper: Module | None = None,
         *,
-        dim_video_hidden,
+        dim_video_hidden = None,
+        action_chunk_len = 32,
         dim_action = 20,
+        dim_joint_state = 32,
+        proprio_mask_prob = 0.1,
         depth = 8,
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
+        ada_ln_zero_bias = -5.,
         dim_time_cond = None,
         sample_time_fn = None
     ):
         super().__init__()

+        self.depth = depth
+
+        # maybe video predict
+
+        self.video_predict_wrapper = video_predict_wrapper
+
+        # dims
+
+        self.action_chunk_len = action_chunk_len
+        self.dim_action = dim_action
+
+        self.action_shape = (action_chunk_len, dim_action)
+        self.dim_joint_state = dim_joint_state
+
+        dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
+
+        assert exists(dim_video_hidden), f'`dim_video_hidden` must be set or `video_predict_wrapper` passed in with `dim_latent`'
+
+        self.dim_video_hidden = dim_video_hidden
+
         # flow related

         self.sample_time_fn = default(sample_time_fn, default_sample_time_fn)
@@ -215,13 +272,26 @@ class MimicVideo(Module):

         dim_time_cond = default(dim_time_cond, dim * 2)

-        self.
-
-
-
+        self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
+        self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())
+
+        # joint token related
+
+        self.to_joint_state_token = Linear(dim_joint_state, dim)
+
+        self.proprio_mask_prob = proprio_mask_prob
+        self.has_proprio_masking = proprio_mask_prob > 0.
+
+        self.proprio_mask_token = nn.Parameter(torch.randn(dim))
+
+        # video norm

         self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)

+        # rnn
+
+        self.rnn = GRU(dim, dim)
+
         # transformer

         layers = []
@@ -235,7 +305,7 @@ class MimicVideo(Module):

             cross_attn = Attention(dim = dim, dim_head = dim_head, dim_context = dim_video_hidden, heads = heads)

-            ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond)
+            ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond, ada_ln_zero_bias = ada_ln_zero_bias)

             ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)

@@ -254,25 +324,93 @@ class MimicVideo(Module):

         self.to_pred_action_flow = nn.Sequential(
             nn.RMSNorm(dim),
-            Linear(dim, dim_action)
+            Linear(dim, dim_action, bias = False)
         )

+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+    @property
+    def device(self):
+        return self.zero.device
+
+    @torch.no_grad()
+    def sample(
+        self,
+        steps = 16,
+        batch_size = 1,
+        disable_progress_bar = False,
+        **kwargs
+    ):
+
+        self.eval()
+
+        noise = torch.randn((batch_size, *self.action_shape), device = self.device)
+
+        times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
+        delta = 1. / steps
+
+        denoised = noise
+
+        cache = None
+
+        for time in tqdm(times, disable = disable_progress_bar):
+            pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)
+
+            denoised = denoised + delta * pred_flow
+
+        return denoised
+
     def forward(
         self,
-        actions,
-        video_hiddens, # they use layer 19 of cosmos predict, at first denoising step. that's all
         *,
-
+        actions,
+        joint_state,
+        video = None,
+        video_hiddens = None, # they use layer 19 of cosmos predict, at first denoising step. that's all
         context_mask = None,
+        time = None,
+        time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
+        prompts = None,
+        prompt_token_ids = None,
+        cache = None,
+        return_cache = False,
+        return_flow = False
     ):
+        assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
+        assert actions.shape[-2:] == self.action_shape
+
+        batch, device = actions.shape[0], actions.device
+
+        is_training = not exists(time) and not return_flow
+
+        if not exists(cache):
+            # handle maybe extraction of video hiddens
+            # only if cache is not given
+
+            assert exists(video) ^ exists(video_hiddens)
+
+            if not exists(video_hiddens):
+                assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
+
+                video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
+                video_hiddens = video_hiddens.float() # maybe bfloat to float32
+
+            video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')
+
+            assert video_hiddens.shape[-1] == self.dim_video_hidden
+
+            # handle video hiddens
+
+            video_hiddens = self.video_hidden_norm(video_hiddens)

-
+        # handle caching
+
+        prev_cached_video_hiddens_kv = cache if exists(cache) else ((None,) * self.depth)
+        next_cached_video_hiddens_kv = []

         # handle flow time conditioning

         if is_training:
-            batch, device = actions.shape[0], actions.device
-
             time = torch.rand((batch,), device = device)
             time = self.sample_time_fn(time)

@@ -285,26 +423,61 @@ class MimicVideo(Module):
         else:
             noised = actions

-
+        if time.ndim == 0:
+            time = repeat(time, '-> b', b = batch)
+
+        # handle the video denoising times
+
+        time_video_denoise = cast_tensor(time_video_denoise)
+
+        if time_video_denoise.ndim == 0:
+            time_video_denoise = rearrange(time_video_denoise, '-> 1')
+
+        if time_video_denoise.shape[0] != batch:
+            time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)

-
+        times = stack((time, time_video_denoise), dim = -1)

-
+        # fourier embed and mlp to time condition
+
+        fourier_embed = self.to_fourier_embed(times)
+
+        fourier_embed = rearrange(fourier_embed, '... times d -> ... (times d)')
+
+        time_cond = self.to_time_cond(fourier_embed)

         # embed

         tokens = self.to_action_tokens(noised)

+        # one layer of rnn for actions
+
+        rnn_out, _, = self.rnn(tokens)
+        tokens = rnn_out + tokens
+
+        # mask joint state token for proprioception masking training
+
+        joint_state_token = self.to_joint_state_token(joint_state)
+
+        if self.training and self.has_proprio_masking:
+            mask = torch.rand((batch,), device = device) < self.proprio_mask_prob
+
+            joint_state_token = einx.where('b, d, b d', mask, self.proprio_mask_token, joint_state_token)
+
+        # pack joint with action tokens
+
+        tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')
+
         # transformer layers

-        for (
+        for ((
             attn_norm,
             attn,
             cross_attn_norm,
             cross_attn,
             ff_norm,
             ff
-        ) in self.layers:
+        ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):

             # cross attention

@@ -312,7 +485,12 @@ class MimicVideo(Module):

             tokens, gate = cross_attn_norm(tokens, time_cond)

-
+            cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)
+
+            tokens = residual + cross_attn_out * gate
+
+            if return_cache:
+                next_cached_video_hiddens_kv.append(video_kv)

             # self attention

@@ -322,22 +500,46 @@ class MimicVideo(Module):

             tokens = residual + attn(tokens) * gate

-            # feedforward
+            # prepare feedforward

             residual = tokens

             tokens, gate = ff_norm(tokens, time_cond)

+            # shift along time for action tokens for cheap relative positioning, which is better than messing with rope with such short action chunks
+
+            joint_state_token, tokens = inverse_pack(tokens)
+
+            tokens = shift_feature_dim(tokens)
+
+            tokens, _ = pack_with_inverse((joint_state_token, tokens), 'b * d')
+
+            # feedforward
+
             tokens = residual + ff(tokens) * gate

+        # remove joint token
+
+        _, tokens = inverse_pack(tokens)
+
         # prediction

         pred_flow = self.to_pred_action_flow(tokens)

         if not is_training:
-
+            # flow
+
+            out = pred_flow
+        else:
+            # mse flow loss
+
+            flow_loss = F.mse_loss(pred_flow, flow)
+
+            out = flow_loss
+
+        if not return_cache:
+            return out

-        #
+        # handle returning of cache

-
-        return flow_loss
+        return out, next_cached_video_hiddens_kv
```
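
For reference, a standalone sketch of the RWKV-style token shift that `shift_feature_dim` performs, rewritten with plain `F.pad` under the assumption that `pad_at_dim(t, (1, -1), dim = 1)` from `torch_einops_utils` pads one position at the front of the sequence dimension and drops the last:

```python
import torch
import torch.nn.functional as F

def shift_feature_dim(t):
    # split features in half; the second half will carry the previous timestep's values
    x, x_shift = t.chunk(2, dim = -1)

    # shift the second half one step forward in time (pad front, drop last)
    x_shift = F.pad(x_shift, (0, 0, 1, -1))

    return torch.cat((x, x_shift), dim = -1)

tokens = torch.randn(2, 32, 512)  # (batch, action chunk length, dim)
shifted = shift_feature_dim(tokens)

assert shifted.shape == tokens.shape
# the shifted half of the first timestep is zeros; every other position sees timestep t - 1
assert torch.equal(shifted[:, 0, 256:], torch.zeros(2, 256))
assert torch.equal(shifted[:, 1:, 256:], tokens[:, :-1, 256:])
```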
{mimic_video-0.0.1.dist-info → mimic_video-0.0.19.dist-info}/METADATA
CHANGED

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mimic-video
-Version: 0.0.1
+Version: 0.0.19
 Summary: Mimic Video
 Project-URL: Homepage, https://pypi.org/project/mimic-video/
 Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -38,10 +38,14 @@ Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: torch-einops-utils>=0.0.8
 Requires-Dist: torch>=2.5
+Requires-Dist: tqdm
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
 Provides-Extra: test
+Requires-Dist: accelerate; extra == 'test'
+Requires-Dist: diffusers>=0.32.0; extra == 'test'
 Requires-Dist: pytest; extra == 'test'
+Requires-Dist: transformers; extra == 'test'
 Description-Content-Type: text/markdown

 <img src="./mimic-video.png" width="450px"></img>
@@ -50,6 +54,73 @@ Description-Content-Type: text/markdown

 Implementation of [Mimic-Video](https://mimic-video.github.io/), Video-Action Models for Generalizable Robot Control Beyond VLAs

+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for proprioception masking
+
+## Install
+
+```shell
+$ pip install mimic-video
+```
+
+## Usage
+
+```python
+import torch
+
+# video wrapper
+# but will be agnostic to the model
+
+from mimic_video.cosmos_predict import CosmosPredictWrapper
+
+video_wrapper = CosmosPredictWrapper(
+    extract_layer = 1,
+    random_weights = True,
+    tiny = True
+)
+
+# mimic video
+
+from mimic_video import MimicVideo
+
+model = MimicVideo(512, video_wrapper)
+
+# states
+
+video = torch.rand(2, 3, 3, 32, 32)
+
+joint_state = torch.randn(2, 32)
+
+# action
+
+actions = torch.randn(2, 32, 20)
+
+# training
+
+loss = model(
+    prompts = [
+        'put the package on the conveyer belt',
+        'pass the butter'
+    ],
+    video = video,
+    actions = actions,
+    joint_state = joint_state
+)
+
+loss.backward()
+
+# inference
+
+actions = model.sample(
+    prompts = 'peel the orange',
+    video = video[:1],
+    joint_state = joint_state[:1]
+)
+
+assert actions.shape == (1, 32, 20)
+```
+
 ## Contributing

 First make sure `pytest` and test dependencies are installed with
````
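
The README pairs `MimicVideo` with `CosmosPredictWrapper`, but the wrapper argument is described as model-agnostic: judging from the diff above, any module that exposes a `dim_latent` attribute and returns hidden states with that feature dimension from `forward(video, prompts = ..., prompt_token_ids = ...)` should slot in. A rough sketch with a hypothetical toy backbone (not part of the package, and relying on MimicVideo internals not shown in this diff):

```python
import torch
from torch import nn

from mimic_video import MimicVideo

# hypothetical stand-in backbone, for illustration only: it just projects pixels
# into a flat sequence of hidden states of width `dim_latent`
class ToyVideoBackbone(nn.Module):
    def __init__(self, dim_latent = 64):
        super().__init__()
        self.dim_latent = dim_latent          # MimicVideo reads this to size its cross attention
        self.proj = nn.Linear(3, dim_latent)

    def forward(self, video, prompts = None, prompt_token_ids = None):
        # video: (batch, frames, channels, height, width) -> (batch, tokens, dim_latent)
        hiddens = self.proj(video.movedim(2, -1))   # move channels last, project
        return hiddens.flatten(1, 3)                # flatten frames / height / width into tokens

model = MimicVideo(512, ToyVideoBackbone())

loss = model(
    video = torch.rand(2, 3, 3, 32, 32),
    actions = torch.randn(2, 32, 20),
    joint_state = torch.randn(2, 32),
    prompts = ['toy prompt', 'another toy prompt']
)

loss.backward()
```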
mimic_video-0.0.19.dist-info/RECORD
ADDED

```diff
@@ -0,0 +1,7 @@
+mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
+mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
+mimic_video/mimic_video.py,sha256=wQNfdV2PGlR_-S-4Mm3cyswtRQ3nBQGJiHptya3ckKU,14761
+mimic_video-0.0.19.dist-info/METADATA,sha256=67a7iVIkf557qMos_yvpcqL8dK9yL0Jf1DPQuhb2bwo,4142
+mimic_video-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+mimic_video-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+mimic_video-0.0.19.dist-info/RECORD,,
```
mimic_video-0.0.1.dist-info/RECORD
DELETED

```diff
@@ -1,6 +0,0 @@
-mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
-mimic_video/mimic_video.py,sha256=aejvjr1F3A7pZFikf-kEgeOpi1_53xVddBMpDPoxA90,8272
-mimic_video-0.0.1.dist-info/METADATA,sha256=414y344JcuIKQJss7d9riTrHszIwthHW8DDSSuRntdo,2960
-mimic_video-0.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-mimic_video-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-mimic_video-0.0.1.dist-info/RECORD,,
```
{mimic_video-0.0.1.dist-info → mimic_video-0.0.19.dist-info}/WHEEL
File without changes

{mimic_video-0.0.1.dist-info → mimic_video-0.0.19.dist-info}/licenses/LICENSE
File without changes