mimic-video 0.0.3__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of mimic-video has been marked as potentially problematic.

mimic_video/__init__.py CHANGED
@@ -1,2 +1 @@
-
  from mimic_video.mimic_video import MimicVideo
mimic_video/cosmos_predict.py ADDED
@@ -0,0 +1,269 @@
+ from __future__ import annotations
+
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ import torch
+ from torch.nn import Module
+ from torch import nn, Tensor
+ from einops import rearrange
+
+ from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel
+ from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos
+ from diffusers.schedulers.scheduling_edm_euler import EDMEulerScheduler
+ from transformers import T5EncoderModel, T5TokenizerFast, T5Config
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def identity(t):
+     return t
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # constants
+
+ TINY_TRANSFORMER_CONFIG = dict(
+     in_channels = 16,
+     out_channels = 16,
+     num_attention_heads = 1,
+     attention_head_dim = 16,
+     mlp_ratio = 1.0,
+     text_embed_dim = 32,
+     adaln_lora_dim = 32,
+     patch_size = (1, 2, 2),
+     max_size = (4, 16, 16),
+     extra_pos_embed_type = None,
+     concat_padding_mask = False,
+ )
+
+ TINY_VAE_CONFIG = dict(
+     in_channels = 3,
+     out_channels = 3,
+     latent_channels = 16,
+     encoder_block_out_channels = (8, 16),
+     decode_block_out_channels = (8, 16),
+     temporal_compression_ratio = 4,
+     spatial_compression_ratio = 4,
+     num_layers = 1,
+     attention_resolutions = (),
+     resolution = 64,
+ )
+
+ TINY_T5_CONFIG = dict(
+     vocab_size = 32128,
+     d_model = 32,
+     d_kv = 8,
+     d_ff = 64,
+     num_layers = 1,
+     num_heads = 1,
+ )
+
+ REAL_TRANSFORMER_CONFIG = dict(
+     in_channels = 16,
+     out_channels = 16,
+     num_attention_heads = 32,
+     attention_head_dim = 128,
+     mlp_ratio = 4.0,
+     text_embed_dim = 1024,
+     patch_size = (1, 2, 2),
+     max_size = (128, 240, 240),
+     extra_pos_embed_type = "learnable",
+     concat_padding_mask = True,
+ )
+
+ REAL_VAE_CONFIG = dict(
+     in_channels = 3,
+     out_channels = 3,
+     latent_channels = 16,
+     encoder_block_out_channels = (128, 256, 512, 512),
+     decode_block_out_channels = (256, 512, 512, 512),
+     temporal_compression_ratio = 8,
+     spatial_compression_ratio = 8,
+ )
+
+ REAL_T5_CONFIG = dict(
+     vocab_size = 32128,
+     d_model = 1024,
+     d_kv = 64,
+     d_ff = 2048,
+     num_layers = 12,
+     num_heads = 16,
+ )
+
+ # main class
+
+ class CosmosPredictWrapper(Module):
+     """
+     Wraps Cosmos VAE + DiT for extracting hidden states from a video.
+     Supports proper EDM Euler denoising steps.
+     """
+
+     def __init__(
+         self,
+         model_name: str = 'nvidia/Cosmos-1.0-Diffusion-7B-Video2World',
+         extract_layer: int = 19,
+         random_weights: bool = False,
+         tiny: bool = False,
+         normalize = lambda t: (t - 0.5) * 2.0
+     ):
+         super().__init__()
+         self.extract_layer = extract_layer
+         self.hook_handle = None
+         self.cached_hidden_states: list[Tensor] = []
+
+         if random_weights:
+             self._init_random_weights(tiny = tiny)
+         else:
+             self._init_pretrained(model_name)
+
+         # Initialize scheduler
+         self.scheduler = EDMEulerScheduler()
+
+         # store hidden dim for consumers
+         self.dim_latent = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim
+
+         # maybe normalize
+         self.normalize = normalize
+
+         self._register_hook()
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def _init_pretrained(self, model_name: str):
+         """Load pretrained weights from HuggingFace"""
+         from diffusers import CosmosVideoToWorldPipeline
+
+         pipeline = CosmosVideoToWorldPipeline.from_pretrained(model_name)
+
+         # Extract components we need
+         self.vae = pipeline.vae
+         self.transformer = pipeline.transformer
+         self.text_encoder = pipeline.text_encoder
+         self.tokenizer = pipeline.tokenizer
+
+         # Clean up pipeline
+         del pipeline
+
+     def _init_random_weights(self, tiny: bool = False):
+         """Initialize with random weights for testing"""
+
+         transformer_config = TINY_TRANSFORMER_CONFIG if tiny else REAL_TRANSFORMER_CONFIG
+         vae_config = TINY_VAE_CONFIG if tiny else REAL_VAE_CONFIG
+         t5_config_dict = TINY_T5_CONFIG if tiny else REAL_T5_CONFIG
+
+         num_layers = max(2, self.extract_layer + 1)
+         if not tiny:
+             num_layers = max(28, num_layers)
+
+         self.transformer = CosmosTransformer3DModel(
+             num_layers = num_layers,
+             **transformer_config
+         )
+
+         self.vae = AutoencoderKLCosmos(**vae_config)
+
+         t5_config = T5Config(**t5_config_dict)
+         self.text_encoder = T5EncoderModel(t5_config)
+         self.tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-small")
+
+     def __del__(self):
+         if exists(self.hook_handle):
+             self.hook_handle.remove()
+
+     def _register_hook(self):
+         assert hasattr(self.transformer, 'transformer_blocks'), 'transformer must have transformer_blocks'
+         assert len(self.transformer.transformer_blocks) > self.extract_layer, f'layer {self.extract_layer} out of bounds'
+
+         target_layer = self.transformer.transformer_blocks[self.extract_layer]
+
+         def hook_fn(module, inp, out):
+             self.cached_hidden_states.append(out.detach().cpu())
+
+         self.hook_handle = target_layer.register_forward_hook(hook_fn)
+
+     def forward(
+         self,
+         videos: Tensor,
+         prompts: str | list[str] | None = None,
+         prompt_token_ids: Tensor | None = None,
+         num_inference_steps: int = 1,
+     ) -> Tensor:
+         """
+         videos: (batch, frames, channels, height, width) in [0, 1]
+         num_inference_steps: number of denoising steps to run
+         returns: hidden states tensor from the specified transformer layer (from first step)
+         """
+         batch, t, c, h, w = videos.shape
+
+         assert exists(prompts) ^ exists(prompt_token_ids)
+
+         # Scale videos from [0, 1] to [-1, 1] for Cosmos VAE
+
+         videos = self.normalize(videos)
+
+         if isinstance(prompts, str):
+             prompts = [prompts] * batch
+
+         self.cached_hidden_states.clear()
+
+         # Move video to device and rearrange for VAE: (B, T, C, H, W) -> (B, C, T, H, W)
+         videos = rearrange(videos, 'b t c h w -> b c t h w')
+
+         with torch.inference_mode():
+             # 1. encode video to latents via VAE
+
+             latents = self.vae.encode(videos).latent_dist.sample()
+
+             # 2. maybe encode text prompts
+
+             if exists(prompt_token_ids):
+                 text_inputs = dict(input_ids = prompt_token_ids)
+             else:
+                 text_inputs = self.tokenizer(
+                     prompts,
+                     return_tensors = "pt",
+                     padding = True,
+                     truncation = True,
+                     max_length = 512
+                 )
+
+             encoder_hidden_states = self.text_encoder(**text_inputs).last_hidden_state
+
+             # 3. Setup scheduler timesteps
+             self.scheduler.set_timesteps(num_inference_steps, device = self.device)
+             timesteps = self.scheduler.timesteps
+
+             # 4. Add noise to latents (start from pure noise scaled by initial sigma)
+             noise = torch.randn_like(latents)
+             latents = latents + noise * self.scheduler.init_noise_sigma
+
+             # 5. Denoising loop
+             for i, timestep in enumerate(timesteps):
+                 # Scale model input
+                 latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+                 # Predict noise residual
+                 noise_pred = self.transformer(
+                     hidden_states = latent_model_input,
+                     encoder_hidden_states = encoder_hidden_states,
+                     timestep = timestep.expand(batch),
+                     return_dict = False
+                 )[0]
+
+                 # Compute previous noisy sample
+                 latents = self.scheduler.step(noise_pred, timestep, latents, return_dict = False)[0]
+
+         assert len(self.cached_hidden_states) > 0, 'hidden states not captured'
+
+         # Return hidden states from the first denoising step
+         hidden = self.cached_hidden_states[0]
+
+         assert hidden.shape[-1] == self.dim_latent, f'hidden dim mismatch: expected {self.dim_latent}, got {hidden.shape[-1]}'
+
+         return hidden
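
To make the new file concrete: `CosmosPredictWrapper` normalizes a `[0, 1]` video to `[-1, 1]`, encodes it with the Cosmos VAE, runs the diffusion transformer for `num_inference_steps` EDM Euler steps, and returns whatever the forward hook captured from `transformer_blocks[extract_layer]` on the first step. A minimal sketch of calling it standalone, using the same tiny random-weight configuration and video shape as the README example further down (the prompt string is illustrative):

```python
import torch
from mimic_video.cosmos_predict import CosmosPredictWrapper

# tiny, randomly initialized components, so nothing is downloaded
wrapper = CosmosPredictWrapper(
    extract_layer = 1,      # which transformer block the forward hook reads from
    random_weights = True,
    tiny = True
)

# (batch, frames, channels, height, width), values in [0, 1]
video = torch.rand(2, 3, 3, 32, 32)

hiddens = wrapper(video, prompts = 'pick up the red block')

# last dim equals wrapper.dim_latent = num_attention_heads * attention_head_dim
assert hiddens.shape[-1] == wrapper.dim_latent
```
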
mimic_video/mimic_video.py CHANGED
@@ -1,6 +1,8 @@
+ from __future__ import annotations
+
  import torch
  from torch import nn, cat, stack, is_tensor, tensor
- from torch.nn import Module, ModuleList, Linear
+ from torch.nn import Module, ModuleList, Linear, GRU

  import torch.nn.functional as F

@@ -10,6 +12,8 @@ from einops.layers.torch import Rearrange

  from x_mlps_pytorch import create_mlp

+ from tqdm import tqdm
+
  from torch_einops_utils import (
      pad_left_ndim,
      align_dims_left,
@@ -85,7 +89,8 @@ class AdaptiveRMSNorm(Module):
          self,
          dim,
          dim_time_cond,
-         eps = 1e-6
+         eps = 1e-6,
+         ada_ln_zero_bias = -5.
      ):
          super().__init__()
          self.scale = dim ** 0.5
@@ -96,6 +101,8 @@ class AdaptiveRMSNorm(Module):

          nn.init.zeros_(self.to_modulation.weight)

+         self.ada_ln_zero_bias = ada_ln_zero_bias
+
      def forward(
          self,
          tokens,
@@ -113,7 +120,9 @@ class AdaptiveRMSNorm(Module):

          adaptive_normed = normed * (scale + 1.) + shift

-         return adaptive_normed, gate
+         gate_with_bias = (gate + self.ada_ln_zero_bias).sigmoid()
+
+         return adaptive_normed, gate_with_bias

  # attention

@@ -149,15 +158,20 @@ class Attention(Module):
          self,
          tokens,
          context = None,
-         context_mask = None
+         context_mask = None,
+         kv = None,
+         return_kv = False
      ):
          context = default(context, tokens)

          queries = self.to_queries(tokens)
-         keys, values = self.to_keys_values(context).chunk(2, dim = -1)
-
          queries = self.split_q_heads(queries)
-         keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
+
+         if not exists(kv):
+             keys, values = self.to_keys_values(context).chunk(2, dim = -1)
+             keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
+         else:
+             keys, values = kv

          queries = queries * self.scale

@@ -173,7 +187,12 @@

          out = self.merge_heads(out)

-         return self.to_out(out)
+         out = self.to_out(out)
+
+         if not return_kv:
+             return out
+
+         return out, stack((keys, values))

  # feedforward

@@ -206,19 +225,43 @@ class MimicVideo(Module):
      def __init__(
          self,
          dim,
+         video_predict_wrapper: Module | None = None,
          *,
-         dim_video_hidden,
+         dim_video_hidden = None,
+         action_chunk_len = 32,
          dim_action = 20,
          dim_joint_state = 32,
+         proprio_mask_prob = 0.1,
          depth = 8,
          dim_head = 64,
          heads = 8,
          expansion_factor = 4.,
+         ada_ln_zero_bias = -5.,
          dim_time_cond = None,
          sample_time_fn = None
      ):
          super().__init__()

+         self.depth = depth
+
+         # maybe video predict
+
+         self.video_predict_wrapper = video_predict_wrapper
+
+         # dims
+
+         self.action_chunk_len = action_chunk_len
+         self.dim_action = dim_action
+
+         self.action_shape = (action_chunk_len, dim_action)
+         self.dim_joint_state = dim_joint_state
+
+         dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
+
+         assert exists(dim_video_hidden), '`dim_video_hidden` must be set or `video_predict_wrapper` passed in with `dim_latent`'
+
+         self.dim_video_hidden = dim_video_hidden
+
          # flow related

          self.sample_time_fn = default(sample_time_fn, default_sample_time_fn)
@@ -232,10 +275,23 @@ class MimicVideo(Module):
          self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
          self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())

+         # joint token related
+
          self.to_joint_state_token = Linear(dim_joint_state, dim)

+         self.proprio_mask_prob = proprio_mask_prob
+         self.has_proprio_masking = proprio_mask_prob > 0.
+
+         self.proprio_mask_token = nn.Parameter(torch.randn(dim))
+
+         # video norm
+
          self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)

+         # rnn
+
+         self.rnn = GRU(dim, dim, batch_first = True) # action tokens are (batch, chunk, dim)
+
          # transformer

          layers = []
@@ -249,7 +305,7 @@ class MimicVideo(Module):

              cross_attn = Attention(dim = dim, dim_head = dim_head, dim_context = dim_video_hidden, heads = heads)

-             ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond)
+             ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond, ada_ln_zero_bias = ada_ln_zero_bias)

              ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)

@@ -268,22 +324,89 @@

          self.to_pred_action_flow = nn.Sequential(
              nn.RMSNorm(dim),
-             Linear(dim, dim_action)
+             Linear(dim, dim_action, bias = False)
          )

+         self.register_buffer('zero', tensor(0.), persistent = False)
+
+     @property
+     def device(self):
+         return self.zero.device
+
+     @torch.no_grad()
+     def sample(
+         self,
+         steps = 16,
+         batch_size = 1,
+         disable_progress_bar = False,
+         **kwargs
+     ):
+
+         self.eval()
+
+         noise = torch.randn((batch_size, *self.action_shape), device = self.device)
+
+         times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
+         delta = 1. / steps
+
+         denoised = noise
+
+         cache = None
+
+         for time in tqdm(times, disable = disable_progress_bar):
+             pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)
+
+             denoised = denoised + delta * pred_flow
+
+         return denoised
+
      def forward(
          self,
-         actions,
-         video_hiddens, # they use layer 19 of cosmos predict, at first denoising step. that's all
          *,
+         actions,
          joint_state,
+         video = None,
+         video_hiddens = None, # they use layer 19 of cosmos predict, at first denoising step. that's all
+         context_mask = None,
          time = None,
          time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
-         context_mask = None,
+         prompts = None,
+         prompt_token_ids = None,
+         cache = None,
+         return_cache = False,
+         return_flow = False
      ):
+         assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
+         assert actions.shape[-2:] == self.action_shape
+
          batch, device = actions.shape[0], actions.device

-         is_training = not exists(time)
+         is_training = not exists(time) and not return_flow
+
+         if not exists(cache):
+             # handle maybe extraction of video hiddens
+             # only if cache is not given
+
+             assert exists(video) ^ exists(video_hiddens)
+
+             if not exists(video_hiddens):
+                 assert exists(self.video_predict_wrapper), '`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
+
+                 video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
+                 video_hiddens = video_hiddens.float() # maybe bfloat to float32
+
+             video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')
+
+             assert video_hiddens.shape[-1] == self.dim_video_hidden
+
+             # handle video hiddens
+
+             video_hiddens = self.video_hidden_norm(video_hiddens)
+
+         # handle caching
+
+         prev_cached_video_hiddens_kv = cache if exists(cache) else ((None,) * self.depth)
+         next_cached_video_hiddens_kv = []

          # handle flow time conditioning

@@ -301,7 +424,7 @@ class MimicVideo(Module):
              noised = actions

          if time.ndim == 0:
-             time = rearrange(time, '-> b', b = batch)
+             time = repeat(time, '-> b', b = batch)

          # handle the video denoising times

@@ -323,28 +446,38 @@

          time_cond = self.to_time_cond(fourier_embed)

-         # handle video hiddens
-
-         video_hiddens = self.video_hidden_norm(video_hiddens)
-
          # embed

          tokens = self.to_action_tokens(noised)

+         # one layer of rnn for actions
+
+         rnn_out, _ = self.rnn(tokens)
+         tokens = rnn_out + tokens
+
+         # mask joint state token for proprioception masking training
+
          joint_state_token = self.to_joint_state_token(joint_state)

+         if self.training and self.has_proprio_masking:
+             mask = torch.rand((batch,), device = device) < self.proprio_mask_prob
+
+             joint_state_token = einx.where('b, d, b d', mask, self.proprio_mask_token, joint_state_token)
+
+         # pack joint with action tokens
+
          tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')

          # transformer layers

-         for (
+         for ((
              attn_norm,
              attn,
              cross_attn_norm,
              cross_attn,
              ff_norm,
              ff
-         ) in self.layers:
+         ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):

              # cross attention

@@ -352,7 +485,12 @@

              tokens, gate = cross_attn_norm(tokens, time_cond)

-             tokens = residual + cross_attn(tokens, context = video_hiddens, context_mask = context_mask) * gate
+             cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)
+
+             tokens = residual + cross_attn_out * gate
+
+             if return_cache:
+                 next_cached_video_hiddens_kv.append(video_kv)

              # self attention

@@ -389,9 +527,19 @@
          pred_flow = self.to_pred_action_flow(tokens)

          if not is_training:
-             return pred_flow
+             # flow
+
+             out = pred_flow
+         else:
+             # mse flow loss
+
+             flow_loss = F.mse_loss(pred_flow, flow)
+
+             out = flow_loss
+
+         if not return_cache:
+             return out

-         # mse flow loss
+         # handle returning of cache

-         flow_loss = F.mse_loss(pred_flow, flow)
-         return flow_loss
+         return out, next_cached_video_hiddens_kv
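
The new `sample` method above is a plain Euler integrator over the learned flow: starting from Gaussian noise over the action chunk, it queries `forward` with the current `time`, steps by `1 / steps` along the predicted flow, and reuses the per-layer video cross-attention keys/values through the new `cache` / `return_cache` path, so the Cosmos hiddens are only computed on the first step. A minimal sketch of that loop written out by hand, in the tiny random-weight test configuration (step count, shapes and prompt are illustrative):

```python
import torch

from mimic_video import MimicVideo
from mimic_video.cosmos_predict import CosmosPredictWrapper

# tiny random-weight setup so the sketch runs without pretrained checkpoints
video_wrapper = CosmosPredictWrapper(extract_layer = 1, random_weights = True, tiny = True)
model = MimicVideo(512, video_wrapper).eval()

video = torch.rand(1, 3, 3, 32, 32)       # (batch, frames, channels, height, width)
joint_state = torch.randn(1, 32)

steps = 4
denoised = torch.randn((1, *model.action_shape))    # noise over the (32, 20) action chunk
times = torch.linspace(0., 1., steps + 1)[:-1]
delta = 1. / steps

cache = None    # per-layer video cross-attention (keys, values), filled after the first call

with torch.no_grad():
    for time in times:
        # with `time` given, forward returns the predicted flow rather than a loss;
        # `return_cache = True` also returns the video KV cache so the Cosmos hiddens
        # and their keys / values are only computed on the first Euler step
        pred_flow, cache = model(
            actions = denoised,
            joint_state = joint_state,
            video = video,
            prompts = 'peel the orange',
            time = time,
            cache = cache,
            return_cache = True
        )

        denoised = denoised + delta * pred_flow    # Euler integration of the flow

assert denoised.shape == (1, 32, 20)
```
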
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mimic-video
- Version: 0.0.3
+ Version: 0.0.19
  Summary: Mimic Video
  Project-URL: Homepage, https://pypi.org/project/mimic-video/
  Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -38,10 +38,14 @@ Requires-Dist: einops>=0.8.1
  Requires-Dist: einx>=0.3.0
  Requires-Dist: torch-einops-utils>=0.0.8
  Requires-Dist: torch>=2.5
+ Requires-Dist: tqdm
  Requires-Dist: x-mlps-pytorch
  Provides-Extra: examples
  Provides-Extra: test
+ Requires-Dist: accelerate; extra == 'test'
+ Requires-Dist: diffusers>=0.32.0; extra == 'test'
  Requires-Dist: pytest; extra == 'test'
+ Requires-Dist: transformers; extra == 'test'
  Description-Content-Type: text/markdown

  <img src="./mimic-video.png" width="450px"></img>
@@ -50,6 +54,73 @@ Description-Content-Type: text/markdown

  Implementation of [Mimic-Video](https://mimic-video.github.io/), Video-Action Models for Generalizable Robot Control Beyond VLAs

+ ## Appreciation
+
+ - [Pranoy](https://github.com/pranoyr) for submitting a pull request for proprioception masking
+
+ ## Install
+
+ ```shell
+ $ pip install mimic-video
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+
+ # video wrapper
+ # but will be agnostic to the model
+
+ from mimic_video.cosmos_predict import CosmosPredictWrapper
+
+ video_wrapper = CosmosPredictWrapper(
+     extract_layer = 1,
+     random_weights = True,
+     tiny = True
+ )
+
+ # mimic video
+
+ from mimic_video import MimicVideo
+
+ model = MimicVideo(512, video_wrapper)
+
+ # states
+
+ video = torch.rand(2, 3, 3, 32, 32)
+
+ joint_state = torch.randn(2, 32)
+
+ # action
+
+ actions = torch.randn(2, 32, 20)
+
+ # training
+
+ loss = model(
+     prompts = [
+         'put the package on the conveyer belt',
+         'pass the butter'
+     ],
+     video = video,
+     actions = actions,
+     joint_state = joint_state
+ )
+
+ loss.backward()
+
+ # inference
+
+ actions = model.sample(
+     prompts = 'peel the orange',
+     video = video[:1],
+     joint_state = joint_state[:1]
+ )
+
+ assert actions.shape == (1, 32, 20)
+ ```
+
  ## Contributing

  First make sure `pytest` and test dependencies are installed with
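
Because the Cosmos forward pass is by far the most expensive part of a training step, `MimicVideo.forward` also accepts precomputed `video_hiddens` in place of raw `video` (the two are mutually exclusive), and the policy can be constructed without a wrapper by passing `dim_video_hidden` directly. A minimal sketch of that offline-extraction path, again with the tiny random-weight configuration (prompt and shapes are illustrative):

```python
import torch

from mimic_video import MimicVideo
from mimic_video.cosmos_predict import CosmosPredictWrapper

video_wrapper = CosmosPredictWrapper(extract_layer = 1, random_weights = True, tiny = True)

# policy built without the wrapper - just tell it the hidden dimension to expect
model = MimicVideo(512, dim_video_hidden = video_wrapper.dim_latent)

video = torch.rand(2, 3, 3, 32, 32)

# extract the layer hiddens once, e.g. in an offline preprocessing pass over the dataset
with torch.no_grad():
    video_hiddens = video_wrapper(video, prompts = 'stack the cups')

# clone so the cached activations are ordinary tensors the policy can backprop against
video_hiddens = video_hiddens.float().clone()

# train on the cached hiddens, skipping the Cosmos forward pass entirely
loss = model(
    actions = torch.randn(2, 32, 20),
    joint_state = torch.randn(2, 32),
    video_hiddens = video_hiddens
)

loss.backward()
```
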
@@ -0,0 +1,7 @@
+ mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
+ mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
+ mimic_video/mimic_video.py,sha256=wQNfdV2PGlR_-S-4Mm3cyswtRQ3nBQGJiHptya3ckKU,14761
+ mimic_video-0.0.19.dist-info/METADATA,sha256=67a7iVIkf557qMos_yvpcqL8dK9yL0Jf1DPQuhb2bwo,4142
+ mimic_video-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ mimic_video-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ mimic_video-0.0.19.dist-info/RECORD,,
@@ -1,6 +0,0 @@
- mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
- mimic_video/mimic_video.py,sha256=-2HVpXAgEG28JFkJeUlypdmOMyYDD2tw0Fisf9-BZ-M,10243
- mimic_video-0.0.3.dist-info/METADATA,sha256=MVJMzysTCCpsgxBKUA9ye-aFSeQAXinyP3ejCtJ8JD8,2960
- mimic_video-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- mimic_video-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- mimic_video-0.0.3.dist-info/RECORD,,