mimic-video 0.0.3__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mimic_video/__init__.py +0 -1
- mimic_video/cosmos_predict.py +269 -0
- mimic_video/mimic_video.py +175 -27
- {mimic_video-0.0.3.dist-info → mimic_video-0.0.19.dist-info}/METADATA +72 -1
- mimic_video-0.0.19.dist-info/RECORD +7 -0
- mimic_video-0.0.3.dist-info/RECORD +0 -6
- {mimic_video-0.0.3.dist-info → mimic_video-0.0.19.dist-info}/WHEEL +0 -0
- {mimic_video-0.0.3.dist-info → mimic_video-0.0.19.dist-info}/licenses/LICENSE +0 -0
mimic_video/cosmos_predict.py
ADDED
@@ -0,0 +1,269 @@
+from __future__ import annotations
+
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+import torch
+from torch.nn import Module
+from torch import nn, Tensor
+from einops import rearrange
+
+from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel
+from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos
+from diffusers.schedulers.scheduling_edm_euler import EDMEulerScheduler
+from transformers import T5EncoderModel, T5TokenizerFast, T5Config
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def identity(t):
+    return t
+
+def default(v, d):
+    return v if exists(v) else d
+
+# constants
+
+TINY_TRANSFORMER_CONFIG = dict(
+    in_channels = 16,
+    out_channels = 16,
+    num_attention_heads = 1,
+    attention_head_dim = 16,
+    mlp_ratio = 1.0,
+    text_embed_dim = 32,
+    adaln_lora_dim = 32,
+    patch_size = (1, 2, 2),
+    max_size = (4, 16, 16),
+    extra_pos_embed_type = None,
+    concat_padding_mask = False,
+)
+
+TINY_VAE_CONFIG = dict(
+    in_channels = 3,
+    out_channels = 3,
+    latent_channels = 16,
+    encoder_block_out_channels = (8, 16),
+    decode_block_out_channels = (8, 16),
+    temporal_compression_ratio = 4,
+    spatial_compression_ratio = 4,
+    num_layers = 1,
+    attention_resolutions = (),
+    resolution = 64,
+)
+
+TINY_T5_CONFIG = dict(
+    vocab_size = 32128,
+    d_model = 32,
+    d_kv = 8,
+    d_ff = 64,
+    num_layers = 1,
+    num_heads = 1,
+)
+
+REAL_TRANSFORMER_CONFIG = dict(
+    in_channels = 16,
+    out_channels = 16,
+    num_attention_heads = 32,
+    attention_head_dim = 128,
+    mlp_ratio = 4.0,
+    text_embed_dim = 1024,
+    patch_size = (1, 2, 2),
+    max_size = (128, 240, 240),
+    extra_pos_embed_type = "learnable",
+    concat_padding_mask = True,
+)
+
+REAL_VAE_CONFIG = dict(
+    in_channels = 3,
+    out_channels = 3,
+    latent_channels = 16,
+    encoder_block_out_channels = (128, 256, 512, 512),
+    decode_block_out_channels = (256, 512, 512, 512),
+    temporal_compression_ratio = 8,
+    spatial_compression_ratio = 8,
+)
+
+REAL_T5_CONFIG = dict(
+    vocab_size = 32128,
+    d_model = 1024,
+    d_kv = 64,
+    d_ff = 2048,
+    num_layers = 12,
+    num_heads = 16,
+)
+
+# main class
+
+class CosmosPredictWrapper(Module):
+    """
+    Wraps Cosmos VAE + DiT for extracting hidden states from a video.
+    Supports proper EDM Euler denoising steps.
+    """
+
+    def __init__(
+        self,
+        model_name: str = 'nvidia/Cosmos-1.0-Diffusion-7B-Video2World',
+        extract_layer: int = 19,
+        random_weights: bool = False,
+        tiny: bool = False,
+        normalize = lambda t: (t - 0.5) * 2.0
+    ):
+        super().__init__()
+        self.extract_layer = extract_layer
+        self.hook_handle = None
+        self.cached_hidden_states: list[Tensor] = []
+
+        if random_weights:
+            self._init_random_weights(tiny = tiny)
+        else:
+            self._init_pretrained(model_name)
+
+        # Initialize scheduler
+        self.scheduler = EDMEulerScheduler()
+
+        # store hidden dim for consumers
+        self.dim_latent = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim
+
+        # maybe normalize
+        self.normalize = normalize
+
+        self._register_hook()
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def _init_pretrained(self, model_name: str):
+        """Load pretrained weights from HuggingFace"""
+        from diffusers import CosmosVideoToWorldPipeline
+
+        pipeline = CosmosVideoToWorldPipeline.from_pretrained(model_name)
+
+        # Extract components we need
+        self.vae = pipeline.vae
+        self.transformer = pipeline.transformer
+        self.text_encoder = pipeline.text_encoder
+        self.tokenizer = pipeline.tokenizer
+
+        # Clean up pipeline
+        del pipeline
+
+    def _init_random_weights(self, tiny: bool = False):
+        """Initialize with random weights for testing"""
+
+        transformer_config = TINY_TRANSFORMER_CONFIG if tiny else REAL_TRANSFORMER_CONFIG
+        vae_config = TINY_VAE_CONFIG if tiny else REAL_VAE_CONFIG
+        t5_config_dict = TINY_T5_CONFIG if tiny else REAL_T5_CONFIG
+
+        num_layers = max(2, self.extract_layer + 1)
+        if not tiny:
+            num_layers = max(28, num_layers)
+
+        self.transformer = CosmosTransformer3DModel(
+            num_layers = num_layers,
+            **transformer_config
+        )
+
+        self.vae = AutoencoderKLCosmos(**vae_config)
+
+        t5_config = T5Config(**t5_config_dict)
+        self.text_encoder = T5EncoderModel(t5_config)
+        self.tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-small")
+
+    def __del__(self):
+        if exists(self.hook_handle):
+            self.hook_handle.remove()
+
+    def _register_hook(self):
+        assert hasattr(self.transformer, 'transformer_blocks'), 'transformer must have transformer_blocks'
+        assert len(self.transformer.transformer_blocks) > self.extract_layer, f'layer {self.extract_layer} out of bounds'
+
+        target_layer = self.transformer.transformer_blocks[self.extract_layer]
+
+        def hook_fn(module, inp, out):
+            self.cached_hidden_states.append(out.detach().cpu())
+
+        self.hook_handle = target_layer.register_forward_hook(hook_fn)
+
+    def forward(
+        self,
+        videos: Tensor,
+        prompts: str | list[str] | None = None,
+        prompt_token_ids: Tensor | None = None,
+        num_inference_steps: int = 1,
+    ) -> Tensor:
+        """
+        videos: (batch, frames, channels, height, width) in [0, 1]
+        num_inference_steps: number of denoising steps to run
+        returns: hidden states tensor from the specified transformer layer (from first step)
+        """
+        batch, t, c, h, w = videos.shape
+
+        assert exists(prompts) ^ exists(prompt_token_ids)
+
+        # Scale videos from [0, 1] to [-1, 1] for Cosmos VAE
+
+        videos = self.normalize(videos)
+
+        if isinstance(prompts, str):
+            prompts = [prompts] * batch
+
+        self.cached_hidden_states.clear()
+
+        # Move video to device and rearrange for VAE: (B, T, C, H, W) -> (B, C, T, H, W)
+        videos = rearrange(videos, 'b t c h w -> b c t h w')
+
+        with torch.inference_mode():
+            # 1. encode video to latents via VAE
+
+            latents = self.vae.encode(videos).latent_dist.sample()
+
+            # 2. maybe encode text prompts
+
+            if exists(prompt_token_ids):
+                text_inputs = dict(input_ids = prompt_token_ids)
+            else:
+                text_inputs = self.tokenizer(
+                    prompts,
+                    return_tensors = "pt",
+                    padding = True,
+                    truncation = True,
+                    max_length = 512
+                )
+
+            encoder_hidden_states = self.text_encoder(**text_inputs).last_hidden_state
+
+            # 3. Setup scheduler timesteps
+            self.scheduler.set_timesteps(num_inference_steps, device = self.device)
+            timesteps = self.scheduler.timesteps
+
+            # 4. Add noise to latents (start from pure noise scaled by initial sigma)
+            noise = torch.randn_like(latents)
+            latents = latents + noise * self.scheduler.init_noise_sigma
+
+            # 5. Denoising loop
+            for i, timestep in enumerate(timesteps):
+                # Scale model input
+                latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+                # Predict noise residual
+                noise_pred = self.transformer(
+                    hidden_states = latent_model_input,
+                    encoder_hidden_states = encoder_hidden_states,
+                    timestep = timestep.expand(batch),
+                    return_dict = False
+                )[0]
+
+                # Compute previous noisy sample
+                latents = self.scheduler.step(noise_pred, timestep, latents, return_dict = False)[0]
+
+        assert len(self.cached_hidden_states) > 0, 'hidden states not captured'
+
+        # Return hidden states from the first denoising step
+        hidden = self.cached_hidden_states[0]
+
+        assert hidden.shape[-1] == self.dim_latent, f'hidden dim mismatch: expected {self.dim_latent}, got {hidden.shape[-1]}'
+
+        return hidden
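For context, a minimal sketch of how this new wrapper might be exercised on its own, mirroring the tiny, randomly initialized configuration used in the README example further down; the prompt text, video size, and layer index here are illustrative assumptions, not values prescribed by the package:

```python
import torch
from mimic_video.cosmos_predict import CosmosPredictWrapper

# tiny, randomly initialized Cosmos stack, as in the package README / tests
video_wrapper = CosmosPredictWrapper(
    extract_layer = 1,      # transformer block to hook (illustrative choice for the tiny config)
    random_weights = True,
    tiny = True
)

# (batch, frames, channels, height, width) video in [0, 1]
video = torch.rand(2, 3, 3, 32, 32)

# runs the EDM Euler denoising step(s) and returns the hooked block's output
hiddens = video_wrapper(video, prompts = 'pick up the cup')

# last dimension matches the hidden size the wrapper advertises via `dim_latent`
assert hiddens.shape[-1] == video_wrapper.dim_latent
```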
mimic_video/mimic_video.py
CHANGED
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import torch
 from torch import nn, cat, stack, is_tensor, tensor
-from torch.nn import Module, ModuleList, Linear
+from torch.nn import Module, ModuleList, Linear, GRU

 import torch.nn.functional as F

@@ -10,6 +12,8 @@ from einops.layers.torch import Rearrange

 from x_mlps_pytorch import create_mlp

+from tqdm import tqdm
+
 from torch_einops_utils import (
     pad_left_ndim,
     align_dims_left,
@@ -85,7 +89,8 @@ class AdaptiveRMSNorm(Module):
         self,
         dim,
         dim_time_cond,
-        eps = 1e-6
+        eps = 1e-6,
+        ada_ln_zero_bias = -5.
     ):
         super().__init__()
         self.scale = dim ** 0.5
@@ -96,6 +101,8 @@ class AdaptiveRMSNorm(Module):

         nn.init.zeros_(self.to_modulation.weight)

+        self.ada_ln_zero_bias = ada_ln_zero_bias
+
     def forward(
         self,
         tokens,
@@ -113,7 +120,9 @@ class AdaptiveRMSNorm(Module):

         adaptive_normed = normed * (scale + 1.) + shift

-
+        gate_with_bias = (gate + self.ada_ln_zero_bias).sigmoid()
+
+        return adaptive_normed, gate_with_bias

 # attention

@@ -149,15 +158,20 @@ class Attention(Module):
         self,
         tokens,
         context = None,
-        context_mask = None
+        context_mask = None,
+        kv = None,
+        return_kv = False
     ):
         context = default(context, tokens)

         queries = self.to_queries(tokens)
-        keys, values = self.to_keys_values(context).chunk(2, dim = -1)
-
         queries = self.split_q_heads(queries)
-
+
+        if not exists(kv):
+            keys, values = self.to_keys_values(context).chunk(2, dim = -1)
+            keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
+        else:
+            keys, values = kv

         queries = queries * self.scale

@@ -173,7 +187,12 @@ class Attention(Module):

         out = self.merge_heads(out)

-
+        out = self.to_out(out)
+
+        if not return_kv:
+            return out
+
+        return out, stack((keys, values))

 # feedforward

@@ -206,19 +225,43 @@ class MimicVideo(Module):
     def __init__(
         self,
         dim,
+        video_predict_wrapper: Module | None = None,
         *,
-        dim_video_hidden,
+        dim_video_hidden = None,
+        action_chunk_len = 32,
         dim_action = 20,
         dim_joint_state = 32,
+        proprio_mask_prob = 0.1,
         depth = 8,
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
+        ada_ln_zero_bias = -5.,
         dim_time_cond = None,
         sample_time_fn = None
     ):
         super().__init__()

+        self.depth = depth
+
+        # maybe video predict
+
+        self.video_predict_wrapper = video_predict_wrapper
+
+        # dims
+
+        self.action_chunk_len = action_chunk_len
+        self.dim_action = dim_action
+
+        self.action_shape = (action_chunk_len, dim_action)
+        self.dim_joint_state = dim_joint_state
+
+        dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
+
+        assert exists(dim_video_hidden), f'`dim_video_hidden` must be set or `video_predict_wrapper` passed in with `dim_latent`'
+
+        self.dim_video_hidden = dim_video_hidden
+
         # flow related

         self.sample_time_fn = default(sample_time_fn, default_sample_time_fn)
@@ -232,10 +275,23 @@ class MimicVideo(Module):
         self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
         self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())

+        # joint token related
+
         self.to_joint_state_token = Linear(dim_joint_state, dim)

+        self.proprio_mask_prob = proprio_mask_prob
+        self.has_proprio_masking = proprio_mask_prob > 0.
+
+        self.proprio_mask_token = nn.Parameter(torch.randn(dim))
+
+        # video norm
+
         self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)

+        # rnn
+
+        self.rnn = GRU(dim, dim)
+
         # transformer

         layers = []
@@ -249,7 +305,7 @@ class MimicVideo(Module):

             cross_attn = Attention(dim = dim, dim_head = dim_head, dim_context = dim_video_hidden, heads = heads)

-            ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond)
+            ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond, ada_ln_zero_bias = ada_ln_zero_bias)

             ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)

@@ -268,22 +324,89 @@ class MimicVideo(Module):

         self.to_pred_action_flow = nn.Sequential(
             nn.RMSNorm(dim),
-            Linear(dim, dim_action)
+            Linear(dim, dim_action, bias = False)
         )

+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+    @property
+    def device(self):
+        return self.zero.device
+
+    @torch.no_grad()
+    def sample(
+        self,
+        steps = 16,
+        batch_size = 1,
+        disable_progress_bar = False,
+        **kwargs
+    ):
+
+        self.eval()
+
+        noise = torch.randn((batch_size, *self.action_shape), device = self.device)
+
+        times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
+        delta = 1. / steps
+
+        denoised = noise
+
+        cache = None
+
+        for time in tqdm(times, disable = disable_progress_bar):
+            pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)
+
+            denoised = denoised + delta * pred_flow
+
+        return denoised
+
     def forward(
         self,
-        actions,
-        video_hiddens, # they use layer 19 of cosmos predict, at first denoising step. that's all
         *,
+        actions,
         joint_state,
+        video = None,
+        video_hiddens = None, # they use layer 19 of cosmos predict, at first denoising step. that's all
+        context_mask = None,
         time = None,
         time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
-
+        prompts = None,
+        prompt_token_ids = None,
+        cache = None,
+        return_cache = False,
+        return_flow = False
     ):
+        assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
+        assert actions.shape[-2:] == self.action_shape
+
         batch, device = actions.shape[0], actions.device

-        is_training = not exists(time)
+        is_training = not exists(time) and not return_flow
+
+        if not exists(cache):
+            # handle maybe extraction of video hiddens
+            # only if cache is not given
+
+            assert exists(video) ^ exists(video_hiddens)
+
+            if not exists(video_hiddens):
+                assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
+
+                video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
+                video_hiddens = video_hiddens.float() # maybe bfloat to float32
+
+            video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')
+
+            assert video_hiddens.shape[-1] == self.dim_video_hidden
+
+            # handle video hiddens
+
+            video_hiddens = self.video_hidden_norm(video_hiddens)
+
+        # handle caching
+
+        prev_cached_video_hiddens_kv = cache if exists(cache) else ((None,) * self.depth)
+        next_cached_video_hiddens_kv = []

         # handle flow time conditioning

@@ -301,7 +424,7 @@ class MimicVideo(Module):
             noised = actions

         if time.ndim == 0:
-            time =
+            time = repeat(time, '-> b', b = batch)

         # handle the video denoising times

@@ -323,28 +446,38 @@ class MimicVideo(Module):

         time_cond = self.to_time_cond(fourier_embed)

-        # handle video hiddens
-
-        video_hiddens = self.video_hidden_norm(video_hiddens)
-
         # embed

         tokens = self.to_action_tokens(noised)

+        # one layer of rnn for actions
+
+        rnn_out, _, = self.rnn(tokens)
+        tokens = rnn_out + tokens
+
+        # mask joint state token for proprioception masking training
+
         joint_state_token = self.to_joint_state_token(joint_state)

+        if self.training and self.has_proprio_masking:
+            mask = torch.rand((batch,), device = device) < self.proprio_mask_prob
+
+            joint_state_token = einx.where('b, d, b d', mask, self.proprio_mask_token, joint_state_token)
+
+        # pack joint with action tokens
+
         tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')

         # transformer layers

-        for (
+        for ((
             attn_norm,
             attn,
             cross_attn_norm,
             cross_attn,
             ff_norm,
             ff
-        ) in self.layers:
+        ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):

             # cross attention

@@ -352,7 +485,12 @@ class MimicVideo(Module):

             tokens, gate = cross_attn_norm(tokens, time_cond)

-
+            cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)
+
+            tokens = residual + cross_attn_out * gate
+
+            if return_cache:
+                next_cached_video_hiddens_kv.append(video_kv)

             # self attention

@@ -389,9 +527,19 @@ class MimicVideo(Module):
         pred_flow = self.to_pred_action_flow(tokens)

         if not is_training:
-
+            # flow
+
+            out = pred_flow
+        else:
+            # mse flow loss
+
+            flow_loss = F.mse_loss(pred_flow, flow)
+
+            out = flow_loss
+
+        if not return_cache:
+            return out

-        #
+        # handle returning of cache

-
-        return flow_loss
+        return out, next_cached_video_hiddens_kv
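The new `sample` method above is a plain Euler integration of the predicted action flow over uniform times, with the cross-attention keys/values over the video hiddens computed on the first step and reused through `cache` on later steps. A self-contained sketch of that update rule, with a toy flow function standing in for the model (`toy_flow_fn` and the tensor shapes are illustrative assumptions, not part of the package):

```python
import torch

def toy_flow_fn(actions, *, video_hiddens, cache = None):
    # stand-in for MimicVideo.forward(..., return_cache = True):
    # the video-conditioned state is computed once, then reused on later steps
    if cache is None:
        cache = video_hiddens.mean(dim = 1, keepdim = True)  # pretend cached cross-attn key/values

    pred_flow = cache - actions  # toy flow pointing from the current sample toward the conditioned target
    return pred_flow, cache

steps = 16
delta = 1. / steps

denoised = torch.randn(1, 32, 20)       # start from noise: (batch, action_chunk_len, dim_action)
video_hiddens = torch.randn(1, 64, 20)   # pretend video hidden states

cache = None
for time in torch.linspace(0., 1., steps + 1)[:-1]:
    pred_flow, cache = toy_flow_fn(denoised, video_hiddens = video_hiddens, cache = cache)
    denoised = denoised + delta * pred_flow  # Euler step, as in `sample`
```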
{mimic_video-0.0.3.dist-info → mimic_video-0.0.19.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mimic-video
-Version: 0.0.3
+Version: 0.0.19
 Summary: Mimic Video
 Project-URL: Homepage, https://pypi.org/project/mimic-video/
 Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -38,10 +38,14 @@ Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: torch-einops-utils>=0.0.8
 Requires-Dist: torch>=2.5
+Requires-Dist: tqdm
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
 Provides-Extra: test
+Requires-Dist: accelerate; extra == 'test'
+Requires-Dist: diffusers>=0.32.0; extra == 'test'
 Requires-Dist: pytest; extra == 'test'
+Requires-Dist: transformers; extra == 'test'
 Description-Content-Type: text/markdown

 <img src="./mimic-video.png" width="450px"></img>
@@ -50,6 +54,73 @@ Description-Content-Type: text/markdown

 Implementation of [Mimic-Video](https://mimic-video.github.io/), Video-Action Models for Generalizable Robot Control Beyond VLAs

+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for proprioception masking
+
+## Install
+
+```shell
+$ pip install mimic-video
+```
+
+## Usage
+
+```python
+import torch
+
+# video wrapper
+# but will be agnostic to the model
+
+from mimic_video.cosmos_predict import CosmosPredictWrapper
+
+video_wrapper = CosmosPredictWrapper(
+    extract_layer = 1,
+    random_weights = True,
+    tiny = True
+)
+
+# mimic video
+
+from mimic_video import MimicVideo
+
+model = MimicVideo(512, video_wrapper)
+
+# states
+
+video = torch.rand(2, 3, 3, 32, 32)
+
+joint_state = torch.randn(2, 32)
+
+# action
+
+actions = torch.randn(2, 32, 20)
+
+# training
+
+loss = model(
+    prompts = [
+        'put the package on the conveyer belt',
+        'pass the butter'
+    ],
+    video = video,
+    actions = actions,
+    joint_state = joint_state
+)
+
+loss.backward()
+
+# inference
+
+actions = model.sample(
+    prompts = 'peel the orange',
+    video = video[:1],
+    joint_state = joint_state[:1]
+)
+
+assert actions.shape == (1, 32, 20)
+```
+
 ## Contributing

 First make sure `pytest` and test dependencies are installed with
mimic_video-0.0.19.dist-info/RECORD
ADDED
@@ -0,0 +1,7 @@
+mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
+mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
+mimic_video/mimic_video.py,sha256=wQNfdV2PGlR_-S-4Mm3cyswtRQ3nBQGJiHptya3ckKU,14761
+mimic_video-0.0.19.dist-info/METADATA,sha256=67a7iVIkf557qMos_yvpcqL8dK9yL0Jf1DPQuhb2bwo,4142
+mimic_video-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+mimic_video-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+mimic_video-0.0.19.dist-info/RECORD,,
mimic_video-0.0.3.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
-mimic_video/mimic_video.py,sha256=-2HVpXAgEG28JFkJeUlypdmOMyYDD2tw0Fisf9-BZ-M,10243
-mimic_video-0.0.3.dist-info/METADATA,sha256=MVJMzysTCCpsgxBKUA9ye-aFSeQAXinyP3ejCtJ8JD8,2960
-mimic_video-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-mimic_video-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-mimic_video-0.0.3.dist-info/RECORD,,
{mimic_video-0.0.3.dist-info → mimic_video-0.0.19.dist-info}/WHEEL
File without changes
{mimic_video-0.0.3.dist-info → mimic_video-0.0.19.dist-info}/licenses/LICENSE
File without changes