mimic-video 0.0.1__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mimic-video might be problematic.

mimic_video/__init__.py CHANGED
@@ -1,2 +1 @@
-
  from mimic_video.mimic_video import MimicVideo
mimic_video/cosmos_predict.py ADDED
@@ -0,0 +1,269 @@
+ from __future__ import annotations
+
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ import torch
+ from torch.nn import Module
+ from torch import nn, Tensor
+ from einops import rearrange
+
+ from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel
+ from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos
+ from diffusers.schedulers.scheduling_edm_euler import EDMEulerScheduler
+ from transformers import T5EncoderModel, T5TokenizerFast, T5Config
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def identity(t):
+     return t
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # constants
+
+ TINY_TRANSFORMER_CONFIG = dict(
+     in_channels = 16,
+     out_channels = 16,
+     num_attention_heads = 1,
+     attention_head_dim = 16,
+     mlp_ratio = 1.0,
+     text_embed_dim = 32,
+     adaln_lora_dim = 32,
+     patch_size = (1, 2, 2),
+     max_size = (4, 16, 16),
+     extra_pos_embed_type = None,
+     concat_padding_mask = False,
+ )
+
+ TINY_VAE_CONFIG = dict(
+     in_channels = 3,
+     out_channels = 3,
+     latent_channels = 16,
+     encoder_block_out_channels = (8, 16),
+     decode_block_out_channels = (8, 16),
+     temporal_compression_ratio = 4,
+     spatial_compression_ratio = 4,
+     num_layers = 1,
+     attention_resolutions = (),
+     resolution = 64,
+ )
+
+ TINY_T5_CONFIG = dict(
+     vocab_size = 32128,
+     d_model = 32,
+     d_kv = 8,
+     d_ff = 64,
+     num_layers = 1,
+     num_heads = 1,
+ )
+
+ REAL_TRANSFORMER_CONFIG = dict(
+     in_channels = 16,
+     out_channels = 16,
+     num_attention_heads = 32,
+     attention_head_dim = 128,
+     mlp_ratio = 4.0,
+     text_embed_dim = 1024,
+     patch_size = (1, 2, 2),
+     max_size = (128, 240, 240),
+     extra_pos_embed_type = "learnable",
+     concat_padding_mask = True,
+ )
+
+ REAL_VAE_CONFIG = dict(
+     in_channels = 3,
+     out_channels = 3,
+     latent_channels = 16,
+     encoder_block_out_channels = (128, 256, 512, 512),
+     decode_block_out_channels = (256, 512, 512, 512),
+     temporal_compression_ratio = 8,
+     spatial_compression_ratio = 8,
+ )
+
+ REAL_T5_CONFIG = dict(
+     vocab_size = 32128,
+     d_model = 1024,
+     d_kv = 64,
+     d_ff = 2048,
+     num_layers = 12,
+     num_heads = 16,
+ )
+
+ # main class
+
+ class CosmosPredictWrapper(Module):
+     """
+     Wraps Cosmos VAE + DiT for extracting hidden states from a video.
+     Supports proper EDM Euler denoising steps.
+     """
+
+     def __init__(
+         self,
+         model_name: str = 'nvidia/Cosmos-1.0-Diffusion-7B-Video2World',
+         extract_layer: int = 19,
+         random_weights: bool = False,
+         tiny: bool = False,
+         normalize = lambda t: (t - 0.5) * 2.0
+     ):
+         super().__init__()
+         self.extract_layer = extract_layer
+         self.hook_handle = None
+         self.cached_hidden_states: list[Tensor] = []
+
+         if random_weights:
+             self._init_random_weights(tiny = tiny)
+         else:
+             self._init_pretrained(model_name)
+
+         # Initialize scheduler
+         self.scheduler = EDMEulerScheduler()
+
+         # store hidden dim for consumers
+         self.dim_latent = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim
+
+         # maybe normalize
+         self.normalize = normalize
+
+         self._register_hook()
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def _init_pretrained(self, model_name: str):
+         """Load pretrained weights from HuggingFace"""
+         from diffusers import CosmosVideoToWorldPipeline
+
+         pipeline = CosmosVideoToWorldPipeline.from_pretrained(model_name)
+
+         # Extract components we need
+         self.vae = pipeline.vae
+         self.transformer = pipeline.transformer
+         self.text_encoder = pipeline.text_encoder
+         self.tokenizer = pipeline.tokenizer
+
+         # Clean up pipeline
+         del pipeline
+
+     def _init_random_weights(self, tiny: bool = False):
+         """Initialize with random weights for testing"""
+
+         transformer_config = TINY_TRANSFORMER_CONFIG if tiny else REAL_TRANSFORMER_CONFIG
+         vae_config = TINY_VAE_CONFIG if tiny else REAL_VAE_CONFIG
+         t5_config_dict = TINY_T5_CONFIG if tiny else REAL_T5_CONFIG
+
+         num_layers = max(2, self.extract_layer + 1)
+         if not tiny:
+             num_layers = max(28, num_layers)
+
+         self.transformer = CosmosTransformer3DModel(
+             num_layers = num_layers,
+             **transformer_config
+         )
+
+         self.vae = AutoencoderKLCosmos(**vae_config)
+
+         t5_config = T5Config(**t5_config_dict)
+         self.text_encoder = T5EncoderModel(t5_config)
+         self.tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-small")
+
+     def __del__(self):
+         if exists(self.hook_handle):
+             self.hook_handle.remove()
+
+     def _register_hook(self):
+         assert hasattr(self.transformer, 'transformer_blocks'), 'transformer must have transformer_blocks'
+         assert len(self.transformer.transformer_blocks) > self.extract_layer, f'layer {self.extract_layer} out of bounds'
+
+         target_layer = self.transformer.transformer_blocks[self.extract_layer]
+
+         def hook_fn(module, inp, out):
+             self.cached_hidden_states.append(out.detach().cpu())
+
+         self.hook_handle = target_layer.register_forward_hook(hook_fn)
+
+     def forward(
+         self,
+         videos: Tensor,
+         prompts: str | list[str] | None = None,
+         prompt_token_ids: Tensor | None = None,
+         num_inference_steps: int = 1,
+     ) -> Tensor:
+         """
+         videos: (batch, frames, channels, height, width) in [0, 1]
+         num_inference_steps: number of denoising steps to run
+         returns: hidden states tensor from the specified transformer layer (from first step)
+         """
+         batch, t, c, h, w = videos.shape
+
+         assert exists(prompts) ^ exists(prompt_token_ids)
+
+         # Scale videos from [0, 1] to [-1, 1] for Cosmos VAE
+
+         videos = self.normalize(videos)
+
+         if isinstance(prompts, str):
+             prompts = [prompts] * batch
+
+         self.cached_hidden_states.clear()
+
+         # Move video to device and rearrange for VAE: (B, T, C, H, W) -> (B, C, T, H, W)
+         videos = rearrange(videos, 'b t c h w -> b c t h w')
+
+         with torch.inference_mode():
+             # 1. encode video to latents via VAE
+
+             latents = self.vae.encode(videos).latent_dist.sample()
+
+             # 2. maybe encode text prompts
+
+             if exists(prompt_token_ids):
+                 text_inputs = dict(input_ids = prompt_token_ids)
+             else:
+                 text_inputs = self.tokenizer(
+                     prompts,
+                     return_tensors = "pt",
+                     padding = True,
+                     truncation = True,
+                     max_length = 512
+                 )
+
+             encoder_hidden_states = self.text_encoder(**text_inputs).last_hidden_state
+
+             # 3. Setup scheduler timesteps
+             self.scheduler.set_timesteps(num_inference_steps, device = self.device)
+             timesteps = self.scheduler.timesteps
+
+             # 4. Add noise to latents (start from pure noise scaled by initial sigma)
+             noise = torch.randn_like(latents)
+             latents = latents + noise * self.scheduler.init_noise_sigma
+
+             # 5. Denoising loop
+             for i, timestep in enumerate(timesteps):
+                 # Scale model input
+                 latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+                 # Predict noise residual
+                 noise_pred = self.transformer(
+                     hidden_states = latent_model_input,
+                     encoder_hidden_states = encoder_hidden_states,
+                     timestep = timestep.expand(batch),
+                     return_dict = False
+                 )[0]
+
+                 # Compute previous noisy sample
+                 latents = self.scheduler.step(noise_pred, timestep, latents, return_dict = False)[0]
+
+         assert len(self.cached_hidden_states) > 0, 'hidden states not captured'
+
+         # Return hidden states from the first denoising step
+         hidden = self.cached_hidden_states[0]
+
+         assert hidden.shape[-1] == self.dim_latent, f'hidden dim mismatch: expected {self.dim_latent}, got {hidden.shape[-1]}'
+
+         return hidden
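For orientation, here is a minimal sketch of driving the new wrapper on its own, using the tiny randomly initialized configuration above so no pretrained Cosmos weights are needed (the small T5 tokenizer is still fetched from the Hub). Shapes mirror the README example further below and are illustrative only:

```python
import torch
from mimic_video.cosmos_predict import CosmosPredictWrapper

# tiny, random-weight Cosmos stack intended for tests
wrapper = CosmosPredictWrapper(
    extract_layer = 1,
    random_weights = True,
    tiny = True
)

# (batch, frames, channels, height, width), values in [0, 1]
video = torch.rand(1, 3, 3, 32, 32)

# one EDM Euler step; a forward hook captures the chosen transformer block's output
hiddens = wrapper(video, prompts = 'pick up the cup')

# last dimension is heads * head dim of the Cosmos transformer
assert hiddens.shape[-1] == wrapper.dim_latent
```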

mimic_video/mimic_video.py CHANGED
@@ -1,18 +1,24 @@
+ from __future__ import annotations
+
  import torch
- from torch import nn
- from torch.nn import Module, ModuleList, Linear
+ from torch import nn, cat, stack, is_tensor, tensor
+ from torch.nn import Module, ModuleList, Linear, GRU

  import torch.nn.functional as F

  import einx
- from einops import einsum, rearrange
+ from einops import einsum, rearrange, repeat
  from einops.layers.torch import Rearrange

  from x_mlps_pytorch import create_mlp

+ from tqdm import tqdm
+
  from torch_einops_utils import (
      pad_left_ndim,
-     align_dims_left
+     align_dims_left,
+     pad_at_dim,
+     pack_with_inverse,
  )

  # ein notation
@@ -37,12 +43,23 @@ def divisible_by(num, den):

  # tensor function

+ def cast_tensor(val, device = None):
+     return tensor(val, device = device) if not is_tensor(val) else val
+
  def max_neg_value(t):
      return -torch.finfo(t.dtype).max

  def l2norm(t, eps = 1e-10):
      return F.normalize(t, dim = -1, eps = eps)

+ # token shift from Peng et al. of RWKV
+ # cheap way to generate relative positions
+
+ def shift_feature_dim(t):
+     x, x_shift = t.chunk(2, dim = -1)
+     x_shift = pad_at_dim(x_shift, (1, -1), dim = 1)
+     return cat((x, x_shift), dim = -1)
+
  # time

  # they follow p0's research finding with the beta distribution
@@ -72,7 +89,8 @@ class AdaptiveRMSNorm(Module):
          self,
          dim,
          dim_time_cond,
-         eps = 1e-6
+         eps = 1e-6,
+         ada_ln_zero_bias = -5.
      ):
          super().__init__()
          self.scale = dim ** 0.5
@@ -83,6 +101,8 @@ class AdaptiveRMSNorm(Module):

          nn.init.zeros_(self.to_modulation.weight)

+         self.ada_ln_zero_bias = ada_ln_zero_bias
+
      def forward(
          self,
          tokens,
@@ -100,7 +120,9 @@ class AdaptiveRMSNorm(Module):

          adaptive_normed = normed * (scale + 1.) + shift

-         return adaptive_normed, gate
+         gate_with_bias = (gate + self.ada_ln_zero_bias).sigmoid()
+
+         return adaptive_normed, gate_with_bias

  # attention

@@ -136,15 +158,20 @@ class Attention(Module):
          self,
          tokens,
          context = None,
-         context_mask = None
+         context_mask = None,
+         kv = None,
+         return_kv = False
      ):
          context = default(context, tokens)

          queries = self.to_queries(tokens)
-         keys, values = self.to_keys_values(context).chunk(2, dim = -1)
-
          queries = self.split_q_heads(queries)
-         keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
+
+         if not exists(kv):
+             keys, values = self.to_keys_values(context).chunk(2, dim = -1)
+             keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
+         else:
+             keys, values = kv

          queries = queries * self.scale

@@ -160,7 +187,12 @@ class Attention(Module):

          out = self.merge_heads(out)

-         return self.to_out(out)
+         out = self.to_out(out)
+
+         if not return_kv:
+             return out
+
+         return out, stack((keys, values))

  # feedforward

@@ -193,18 +225,43 @@ class MimicVideo(Module):
      def __init__(
          self,
          dim,
+         video_predict_wrapper: Module | None = None,
          *,
-         dim_video_hidden,
+         dim_video_hidden = None,
+         action_chunk_len = 32,
          dim_action = 20,
+         dim_joint_state = 32,
+         proprio_mask_prob = 0.1,
          depth = 8,
          dim_head = 64,
          heads = 8,
          expansion_factor = 4.,
+         ada_ln_zero_bias = -5.,
          dim_time_cond = None,
          sample_time_fn = None
      ):
          super().__init__()

+         self.depth = depth
+
+         # maybe video predict
+
+         self.video_predict_wrapper = video_predict_wrapper
+
+         # dims
+
+         self.action_chunk_len = action_chunk_len
+         self.dim_action = dim_action
+
+         self.action_shape = (action_chunk_len, dim_action)
+         self.dim_joint_state = dim_joint_state
+
+         dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
+
+         assert exists(dim_video_hidden), f'`dim_video_hidden` must be set or `video_predict_wrapper` passed in with `dim_latent`'
+
+         self.dim_video_hidden = dim_video_hidden
+
          # flow related

          self.sample_time_fn = default(sample_time_fn, default_sample_time_fn)
@@ -215,13 +272,26 @@ class MimicVideo(Module):

          dim_time_cond = default(dim_time_cond, dim * 2)

-         self.to_time_cond = nn.Sequential(
-             RandomFourierEmbed(dim),
-             create_mlp(dim_in = dim, dim = dim_time_cond, depth = 2, activation = nn.SiLU())
-         )
+         self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
+         self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())
+
+         # joint token related
+
+         self.to_joint_state_token = Linear(dim_joint_state, dim)
+
+         self.proprio_mask_prob = proprio_mask_prob
+         self.has_proprio_masking = proprio_mask_prob > 0.
+
+         self.proprio_mask_token = nn.Parameter(torch.randn(dim))
+
+         # video norm

          self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)

+         # rnn
+
+         self.rnn = GRU(dim, dim)
+
          # transformer

          layers = []
@@ -235,7 +305,7 @@ class MimicVideo(Module):

              cross_attn = Attention(dim = dim, dim_head = dim_head, dim_context = dim_video_hidden, heads = heads)

-             ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond)
+             ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond, ada_ln_zero_bias = ada_ln_zero_bias)

              ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)

@@ -254,25 +324,93 @@

          self.to_pred_action_flow = nn.Sequential(
              nn.RMSNorm(dim),
-             Linear(dim, dim_action)
+             Linear(dim, dim_action, bias = False)
          )

+         self.register_buffer('zero', tensor(0.), persistent = False)
+
+     @property
+     def device(self):
+         return self.zero.device
+
+     @torch.no_grad()
+     def sample(
+         self,
+         steps = 16,
+         batch_size = 1,
+         disable_progress_bar = False,
+         **kwargs
+     ):
+
+         self.eval()
+
+         noise = torch.randn((batch_size, *self.action_shape), device = self.device)
+
+         times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
+         delta = 1. / steps
+
+         denoised = noise
+
+         cache = None
+
+         for time in tqdm(times, disable = disable_progress_bar):
+             pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)
+
+             denoised = denoised + delta * pred_flow
+
+         return denoised
+
      def forward(
          self,
-         actions,
-         video_hiddens, # they use layer 19 of cosmos predict, at first denoising step. that's all
          *,
-         time = None,
+         actions,
+         joint_state,
+         video = None,
+         video_hiddens = None, # they use layer 19 of cosmos predict, at first denoising step. that's all
          context_mask = None,
+         time = None,
+         time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
+         prompts = None,
+         prompt_token_ids = None,
+         cache = None,
+         return_cache = False,
+         return_flow = False
      ):
+         assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
+         assert actions.shape[-2:] == self.action_shape
+
+         batch, device = actions.shape[0], actions.device
+
+         is_training = not exists(time) and not return_flow
+
+         if not exists(cache):
+             # handle maybe extraction of video hiddens
+             # only if cache is not given
+
+             assert exists(video) ^ exists(video_hiddens)
+
+             if not exists(video_hiddens):
+                 assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
+
+                 video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
+                 video_hiddens = video_hiddens.float() # maybe bfloat to float32
+
+             video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')
+
+             assert video_hiddens.shape[-1] == self.dim_video_hidden
+
+             # handle video hiddens
+
+             video_hiddens = self.video_hidden_norm(video_hiddens)

-         is_training = not exists(time)
+         # handle caching
+
+         prev_cached_video_hiddens_kv = cache if exists(cache) else ((None,) * self.depth)
+         next_cached_video_hiddens_kv = []

          # handle flow time conditioning

          if is_training:
-             batch, device = actions.shape[0], actions.device
-
              time = torch.rand((batch,), device = device)
              time = self.sample_time_fn(time)

@@ -285,26 +423,61 @@
          else:
              noised = actions

-         time_cond = self.to_time_cond(time)
+         if time.ndim == 0:
+             time = repeat(time, '-> b', b = batch)
+
+         # handle the video denoising times
+
+         time_video_denoise = cast_tensor(time_video_denoise)
+
+         if time_video_denoise.ndim == 0:
+             time_video_denoise = rearrange(time_video_denoise, '-> 1')
+
+         if time_video_denoise.shape[0] != batch:
+             time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)

-         # handle video hiddens
+         times = stack((time, time_video_denoise), dim = -1)

-         video_hiddens = self.video_hidden_norm(video_hiddens)
+         # fourier embed and mlp to time condition
+
+         fourier_embed = self.to_fourier_embed(times)
+
+         fourier_embed = rearrange(fourier_embed, '... times d -> ... (times d)')
+
+         time_cond = self.to_time_cond(fourier_embed)

          # embed

          tokens = self.to_action_tokens(noised)

+         # one layer of rnn for actions
+
+         rnn_out, _, = self.rnn(tokens)
+         tokens = rnn_out + tokens
+
+         # mask joint state token for proprioception masking training
+
+         joint_state_token = self.to_joint_state_token(joint_state)
+
+         if self.training and self.has_proprio_masking:
+             mask = torch.rand((batch,), device = device) < self.proprio_mask_prob
+
+             joint_state_token = einx.where('b, d, b d', mask, self.proprio_mask_token, joint_state_token)
+
+         # pack joint with action tokens
+
+         tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')
+
          # transformer layers

-         for (
+         for ((
              attn_norm,
              attn,
              cross_attn_norm,
              cross_attn,
              ff_norm,
              ff
-         ) in self.layers:
+         ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):

              # cross attention

@@ -312,7 +485,12 @@ class MimicVideo(Module):

              tokens, gate = cross_attn_norm(tokens, time_cond)

-             tokens = residual + cross_attn(tokens, context = video_hiddens, context_mask = context_mask) * gate
+             cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)
+
+             tokens = residual + cross_attn_out * gate
+
+             if return_cache:
+                 next_cached_video_hiddens_kv.append(video_kv)

              # self attention

@@ -322,22 +500,46 @@

              tokens = residual + attn(tokens) * gate

-             # feedforward
+             # prepare feedforward

              residual = tokens

              tokens, gate = ff_norm(tokens, time_cond)

+             # shift along time for action tokens for cheap relative positioning, which is better than messing with rope with such short action chunks
+
+             joint_state_token, tokens = inverse_pack(tokens)
+
+             tokens = shift_feature_dim(tokens)
+
+             tokens, _ = pack_with_inverse((joint_state_token, tokens), 'b * d')
+
+             # feedforward
+
              tokens = residual + ff(tokens) * gate

+         # remove joint token
+
+         _, tokens = inverse_pack(tokens)
+
          # prediction

          pred_flow = self.to_pred_action_flow(tokens)

          if not is_training:
-             return pred_flow
+             # flow
+
+             out = pred_flow
+         else:
+             # mse flow loss
+
+             flow_loss = F.mse_loss(pred_flow, flow)
+
+             out = flow_loss
+
+         if not return_cache:
+             return out

-         # mse flow loss
+         # handle returning of cache

-         flow_loss = F.mse_loss(pred_flow, flow)
-         return flow_loss
+         return out, next_cached_video_hiddens_kv
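As a hedged sketch of what the reworked `forward` and `sample` above allow: the video hidden states can be precomputed and passed in directly as `video_hiddens` (no `video_predict_wrapper` attached), and during sampling the cross attention key / values over those hiddens are cached across flow steps. All dimensions below are illustrative, not prescribed by the package:

```python
import torch
from mimic_video import MimicVideo

# no video model attached; supply dim_video_hidden explicitly
model = MimicVideo(512, dim_video_hidden = 384)

video_hiddens = torch.randn(2, 77, 384)  # e.g. cached offline from a video backbone
joint_state = torch.randn(2, 32)
actions = torch.randn(2, 32, 20)          # (batch, action_chunk_len, dim_action)

# training returns the mse flow loss
loss = model(actions = actions, joint_state = joint_state, video_hiddens = video_hiddens)
loss.backward()

# inference: flow steps reuse the cached video key / values internally
sampled = model.sample(
    steps = 4,
    video_hiddens = video_hiddens[:1],
    joint_state = joint_state[:1]
)

assert sampled.shape == (1, 32, 20)
```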

{mimic_video-0.0.1.dist-info → mimic_video-0.0.19.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mimic-video
- Version: 0.0.1
+ Version: 0.0.19
  Summary: Mimic Video
  Project-URL: Homepage, https://pypi.org/project/mimic-video/
  Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -38,10 +38,14 @@ Requires-Dist: einops>=0.8.1
  Requires-Dist: einx>=0.3.0
  Requires-Dist: torch-einops-utils>=0.0.8
  Requires-Dist: torch>=2.5
+ Requires-Dist: tqdm
  Requires-Dist: x-mlps-pytorch
  Provides-Extra: examples
  Provides-Extra: test
+ Requires-Dist: accelerate; extra == 'test'
+ Requires-Dist: diffusers>=0.32.0; extra == 'test'
  Requires-Dist: pytest; extra == 'test'
+ Requires-Dist: transformers; extra == 'test'
  Description-Content-Type: text/markdown

  <img src="./mimic-video.png" width="450px"></img>
@@ -50,6 +54,73 @@ Description-Content-Type: text/markdown

  Implementation of [Mimic-Video](https://mimic-video.github.io/), Video-Action Models for Generalizable Robot Control Beyond VLAs

+ ## Appreciation
+
+ - [Pranoy](https://github.com/pranoyr) for submitting a pull request for proprioception masking
+
+ ## Install
+
+ ```shell
+ $ pip install mimic-video
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+
+ # video wrapper
+ # but will be agnostic to the model
+
+ from mimic_video.cosmos_predict import CosmosPredictWrapper
+
+ video_wrapper = CosmosPredictWrapper(
+     extract_layer = 1,
+     random_weights = True,
+     tiny = True
+ )
+
+ # mimic video
+
+ from mimic_video import MimicVideo
+
+ model = MimicVideo(512, video_wrapper)
+
+ # states
+
+ video = torch.rand(2, 3, 3, 32, 32)
+
+ joint_state = torch.randn(2, 32)
+
+ # action
+
+ actions = torch.randn(2, 32, 20)
+
+ # training
+
+ loss = model(
+     prompts = [
+         'put the package on the conveyer belt',
+         'pass the butter'
+     ],
+     video = video,
+     actions = actions,
+     joint_state = joint_state
+ )
+
+ loss.backward()
+
+ # inference
+
+ actions = model.sample(
+     prompts = 'peel the orange',
+     video = video[:1],
+     joint_state = joint_state[:1]
+ )
+
+ assert actions.shape == (1, 32, 20)
+ ```
+
  ## Contributing

  First make sure `pytest` and test dependencies are installed with

mimic_video-0.0.19.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+ mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
+ mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
+ mimic_video/mimic_video.py,sha256=wQNfdV2PGlR_-S-4Mm3cyswtRQ3nBQGJiHptya3ckKU,14761
+ mimic_video-0.0.19.dist-info/METADATA,sha256=67a7iVIkf557qMos_yvpcqL8dK9yL0Jf1DPQuhb2bwo,4142
+ mimic_video-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ mimic_video-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ mimic_video-0.0.19.dist-info/RECORD,,

mimic_video-0.0.1.dist-info/RECORD DELETED
@@ -1,6 +0,0 @@
- mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
- mimic_video/mimic_video.py,sha256=aejvjr1F3A7pZFikf-kEgeOpi1_53xVddBMpDPoxA90,8272
- mimic_video-0.0.1.dist-info/METADATA,sha256=414y344JcuIKQJss7d9riTrHszIwthHW8DDSSuRntdo,2960
- mimic_video-0.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- mimic_video-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- mimic_video-0.0.1.dist-info/RECORD,,