mimic-video 0.0.3-py3-none-any.whl → 0.0.24-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.


mimic_video/__init__.py CHANGED
@@ -1,2 +1 @@
1
-
2
1
  from mimic_video.mimic_video import MimicVideo
mimic_video/cosmos_predict.py ADDED
@@ -0,0 +1,269 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
5
+
6
+ import torch
7
+ from torch.nn import Module
8
+ from torch import nn, Tensor
9
+ from einops import rearrange
10
+
11
+ from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel
12
+ from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos
13
+ from diffusers.schedulers.scheduling_edm_euler import EDMEulerScheduler
14
+ from transformers import T5EncoderModel, T5TokenizerFast, T5Config
15
+
16
+ # helpers
17
+
18
+ def exists(v):
19
+ return v is not None
20
+
21
+ def identity(t):
22
+ return t
23
+
24
+ def default(v, d):
25
+ return v if exists(v) else d
26
+
27
+ # constants
28
+
29
+ TINY_TRANSFORMER_CONFIG = dict(
30
+ in_channels = 16,
31
+ out_channels = 16,
32
+ num_attention_heads = 1,
33
+ attention_head_dim = 16,
34
+ mlp_ratio = 1.0,
35
+ text_embed_dim = 32,
36
+ adaln_lora_dim = 32,
37
+ patch_size = (1, 2, 2),
38
+ max_size = (4, 16, 16),
39
+ extra_pos_embed_type = None,
40
+ concat_padding_mask = False,
41
+ )
42
+
43
+ TINY_VAE_CONFIG = dict(
44
+ in_channels = 3,
45
+ out_channels = 3,
46
+ latent_channels = 16,
47
+ encoder_block_out_channels = (8, 16),
48
+ decode_block_out_channels = (8, 16),
49
+ temporal_compression_ratio = 4,
50
+ spatial_compression_ratio = 4,
51
+ num_layers = 1,
52
+ attention_resolutions = (),
53
+ resolution = 64,
54
+ )
55
+
56
+ TINY_T5_CONFIG = dict(
57
+ vocab_size = 32128,
58
+ d_model = 32,
59
+ d_kv = 8,
60
+ d_ff = 64,
61
+ num_layers = 1,
62
+ num_heads = 1,
63
+ )
64
+
65
+ REAL_TRANSFORMER_CONFIG = dict(
66
+ in_channels = 16,
67
+ out_channels = 16,
68
+ num_attention_heads = 32,
69
+ attention_head_dim = 128,
70
+ mlp_ratio = 4.0,
71
+ text_embed_dim = 1024,
72
+ patch_size = (1, 2, 2),
73
+ max_size = (128, 240, 240),
74
+ extra_pos_embed_type = "learnable",
75
+ concat_padding_mask = True,
76
+ )
77
+
78
+ REAL_VAE_CONFIG = dict(
79
+ in_channels = 3,
80
+ out_channels = 3,
81
+ latent_channels = 16,
82
+ encoder_block_out_channels = (128, 256, 512, 512),
83
+ decode_block_out_channels = (256, 512, 512, 512),
84
+ temporal_compression_ratio = 8,
85
+ spatial_compression_ratio = 8,
86
+ )
87
+
88
+ REAL_T5_CONFIG = dict(
89
+ vocab_size = 32128,
90
+ d_model = 1024,
91
+ d_kv = 64,
92
+ d_ff = 2048,
93
+ num_layers = 12,
94
+ num_heads = 16,
95
+ )
96
+
97
+ # main class
98
+
99
+ class CosmosPredictWrapper(Module):
100
+ """
101
+ Wraps Cosmos VAE + DiT for extracting hidden states from a video.
102
+ Supports proper EDM Euler denoising steps.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ model_name: str = 'nvidia/Cosmos-1.0-Diffusion-7B-Video2World',
108
+ extract_layer: int = 19,
109
+ random_weights: bool = False,
110
+ tiny: bool = False,
111
+ normalize = lambda t: (t - 0.5) * 2.0
112
+ ):
113
+ super().__init__()
114
+ self.extract_layer = extract_layer
115
+ self.hook_handle = None
116
+ self.cached_hidden_states: list[Tensor] = []
117
+
118
+ if random_weights:
119
+ self._init_random_weights(tiny = tiny)
120
+ else:
121
+ self._init_pretrained(model_name)
122
+
123
+ # Initialize scheduler
124
+ self.scheduler = EDMEulerScheduler()
125
+
126
+ # store hidden dim for consumers
127
+ self.dim_latent = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim
128
+
129
+ # maybe normalize
130
+ self.normalize = normalize
131
+
132
+ self._register_hook()
133
+
134
+ @property
135
+ def device(self):
136
+ return next(self.parameters()).device
137
+
138
+ def _init_pretrained(self, model_name: str):
139
+ """Load pretrained weights from HuggingFace"""
140
+ from diffusers import CosmosVideoToWorldPipeline
141
+
142
+ pipeline = CosmosVideoToWorldPipeline.from_pretrained(model_name)
143
+
144
+ # Extract components we need
145
+ self.vae = pipeline.vae
146
+ self.transformer = pipeline.transformer
147
+ self.text_encoder = pipeline.text_encoder
148
+ self.tokenizer = pipeline.tokenizer
149
+
150
+ # Clean up pipeline
151
+ del pipeline
152
+
153
+ def _init_random_weights(self, tiny: bool = False):
154
+ """Initialize with random weights for testing"""
155
+
156
+ transformer_config = TINY_TRANSFORMER_CONFIG if tiny else REAL_TRANSFORMER_CONFIG
157
+ vae_config = TINY_VAE_CONFIG if tiny else REAL_VAE_CONFIG
158
+ t5_config_dict = TINY_T5_CONFIG if tiny else REAL_T5_CONFIG
159
+
160
+ num_layers = max(2, self.extract_layer + 1)
161
+ if not tiny:
162
+ num_layers = max(28, num_layers)
163
+
164
+ self.transformer = CosmosTransformer3DModel(
165
+ num_layers = num_layers,
166
+ **transformer_config
167
+ )
168
+
169
+ self.vae = AutoencoderKLCosmos(**vae_config)
170
+
171
+ t5_config = T5Config(**t5_config_dict)
172
+ self.text_encoder = T5EncoderModel(t5_config)
173
+ self.tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-small")
174
+
175
+ def __del__(self):
176
+ if exists(self.hook_handle):
177
+ self.hook_handle.remove()
178
+
179
+ def _register_hook(self):
180
+ assert hasattr(self.transformer, 'transformer_blocks'), 'transformer must have transformer_blocks'
181
+ assert len(self.transformer.transformer_blocks) > self.extract_layer, f'layer {self.extract_layer} out of bounds'
182
+
183
+ target_layer = self.transformer.transformer_blocks[self.extract_layer]
184
+
185
+ def hook_fn(module, inp, out):
186
+ self.cached_hidden_states.append(out.detach().cpu())
187
+
188
+ self.hook_handle = target_layer.register_forward_hook(hook_fn)
189
+
190
+ def forward(
191
+ self,
192
+ videos: Tensor,
193
+ prompts: str | list[str] | None = None,
194
+ prompt_token_ids: Tensor | None = None,
195
+ num_inference_steps: int = 1,
196
+ ) -> Tensor:
197
+ """
198
+ videos: (batch, frames, channels, height, width) in [0, 1]
199
+ num_inference_steps: number of denoising steps to run
200
+ returns: hidden states tensor from the specified transformer layer (from first step)
201
+ """
202
+ batch, t, c, h, w = videos.shape
203
+
204
+ assert exists(prompts) ^ exists(prompt_token_ids)
205
+
206
+ # Scale videos from [0, 1] to [-1, 1] for Cosmos VAE
207
+
208
+ videos = self.normalize(videos)
209
+
210
+ if isinstance(prompts, str):
211
+ prompts = [prompts] * batch
212
+
213
+ self.cached_hidden_states.clear()
214
+
215
+ # Rearrange for VAE: (B, T, C, H, W) -> (B, C, T, H, W)
216
+ videos = rearrange(videos, 'b t c h w -> b c t h w')
217
+
218
+ with torch.inference_mode():
219
+ # 1. encode video to latents via VAE
220
+
221
+ latents = self.vae.encode(videos).latent_dist.sample()
222
+
223
+ # 2. maybe encode text prompts
224
+
225
+ if exists(prompt_token_ids):
226
+ text_inputs = dict(input_ids = prompt_token_ids)
227
+ else:
228
+ text_inputs = self.tokenizer(
229
+ prompts,
230
+ return_tensors = "pt",
231
+ padding = True,
232
+ truncation = True,
233
+ max_length = 512
234
+ )
235
+
236
+ encoder_hidden_states = self.text_encoder(**text_inputs).last_hidden_state
237
+
238
+ # 3. Setup scheduler timesteps
239
+ self.scheduler.set_timesteps(num_inference_steps, device = self.device)
240
+ timesteps = self.scheduler.timesteps
241
+
242
+ # 4. Add noise to the encoded latents, scaled by the scheduler's initial sigma
243
+ noise = torch.randn_like(latents)
244
+ latents = latents + noise * self.scheduler.init_noise_sigma
245
+
246
+ # 5. Denoising loop
247
+ for i, timestep in enumerate(timesteps):
248
+ # Scale model input
249
+ latent_model_input = self.scheduler.scale_model_input(latents, timestep)
250
+
251
+ # Predict noise residual
252
+ noise_pred = self.transformer(
253
+ hidden_states = latent_model_input,
254
+ encoder_hidden_states = encoder_hidden_states,
255
+ timestep = timestep.expand(batch),
256
+ return_dict = False
257
+ )[0]
258
+
259
+ # Compute previous noisy sample
260
+ latents = self.scheduler.step(noise_pred, timestep, latents, return_dict = False)[0]
261
+
262
+ assert len(self.cached_hidden_states) > 0, 'hidden states not captured'
263
+
264
+ # Return hidden states from the first denoising step
265
+ hidden = self.cached_hidden_states[0]
266
+
267
+ assert hidden.shape[-1] == self.dim_latent, f'hidden dim mismatch: expected {self.dim_latent}, got {hidden.shape[-1]}'
268
+
269
+ return hidden
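For orientation, a minimal sketch of exercising the wrapper above in its tiny random-weight configuration, mirroring the README usage further below; it assumes the `diffusers`, `transformers` and `accelerate` test extras are installed, and the prompt string is illustrative only.

```python
import torch
from mimic_video.cosmos_predict import CosmosPredictWrapper

# tiny, randomly initialized stand-in for the real Cosmos checkpoint,
# so no pretrained weights are downloaded (the T5 tokenizer still is)
wrapper = CosmosPredictWrapper(
    extract_layer = 1,
    random_weights = True,
    tiny = True
)

# (batch, frames, channels, height, width) in [0, 1], as documented in forward()
videos = torch.rand(1, 3, 3, 32, 32)

hiddens = wrapper(videos, prompts = 'pick up the cup', num_inference_steps = 1)

# hidden states come from the forward hook registered on transformer block `extract_layer`
assert hiddens.shape[-1] == wrapper.dim_latent
```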
mimic_video/mimic_video.py CHANGED
@@ -1,22 +1,31 @@
1
+ from __future__ import annotations
2
+ from functools import partial
3
+
1
4
  import torch
2
5
  from torch import nn, cat, stack, is_tensor, tensor
3
- from torch.nn import Module, ModuleList, Linear
6
+ from torch.nn import Module, ModuleList, Linear, GRU
4
7
 
5
8
  import torch.nn.functional as F
6
9
 
7
10
  import einx
8
- from einops import einsum, rearrange, repeat
11
+ from einops import einsum, rearrange, repeat, reduce
9
12
  from einops.layers.torch import Rearrange
10
13
 
11
14
  from x_mlps_pytorch import create_mlp
12
15
 
16
+ from tqdm import tqdm
17
+
13
18
  from torch_einops_utils import (
19
+ lens_to_mask,
14
20
  pad_left_ndim,
15
21
  align_dims_left,
16
22
  pad_at_dim,
17
23
  pack_with_inverse,
24
+ masked_mean
18
25
  )
19
26
 
27
+ from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions
28
+
20
29
  # ein notation
21
30
 
22
31
  # b - batch
@@ -26,6 +35,10 @@ from torch_einops_utils import (
26
35
  # i, j - sequence (source, target)
27
36
  # d - feature dimension
28
37
 
38
+ # constants
39
+
40
+ LinearNoBias = partial(Linear, bias = False)
41
+
29
42
  # functions
30
43
 
31
44
  def exists(v):
@@ -85,25 +98,27 @@ class AdaptiveRMSNorm(Module):
85
98
  self,
86
99
  dim,
87
100
  dim_time_cond,
88
- eps = 1e-6
101
+ eps = 1e-6,
102
+ ada_ln_zero_bias = -5.
89
103
  ):
90
104
  super().__init__()
91
105
  self.scale = dim ** 0.5
92
106
  self.eps = eps
93
107
 
94
- self.to_modulation = Linear(dim_time_cond, dim * 3, bias = False)
95
- self.split_modulation = Rearrange('b (three d) -> three b 1 d', three = 3)
108
+ self.to_modulation = LinearNoBias(dim_time_cond, dim * 3)
109
+ self.split_modulation = Rearrange('... (three d) -> three ... d', three = 3)
96
110
 
97
111
  nn.init.zeros_(self.to_modulation.weight)
98
112
 
113
+ self.ada_ln_zero_bias = ada_ln_zero_bias
114
+
99
115
  def forward(
100
116
  self,
101
117
  tokens,
102
118
  time_cond
103
119
  ):
104
-
105
- if time_cond.ndim == 1:
106
- time_cond = pad_left_ndim(time_cond, 1)
120
+ if time_cond.ndim == 2:
121
+ time_cond = rearrange(time_cond, 'b d -> b 1 d')
107
122
 
108
123
  modulations = self.to_modulation(time_cond)
109
124
 
@@ -113,7 +128,9 @@ class AdaptiveRMSNorm(Module):
113
128
 
114
129
  adaptive_normed = normed * (scale + 1.) + shift
115
130
 
116
- return adaptive_normed, gate
131
+ gate_with_bias = (gate + self.ada_ln_zero_bias).sigmoid()
132
+
133
+ return adaptive_normed, gate_with_bias
117
134
 
118
135
  # attention
119
136
 
@@ -125,7 +142,8 @@ class Attention(Module):
125
142
  dim_context = None,
126
143
  dim_head = 64,
127
144
  heads = 8,
128
- kv_heads = 2
145
+ kv_heads = 2,
146
+ attn_gate_value = True
129
147
  ):
130
148
  super().__init__()
131
149
  dim_q_inner = dim_head * heads
@@ -134,9 +152,12 @@ class Attention(Module):
134
152
 
135
153
  self.scale = dim_head ** -0.5
136
154
 
137
- self.to_queries = Linear(dim, dim_q_inner, bias = False)
138
- self.to_keys_values = Linear(dim_context, dim_kv_inner * 2, bias = False)
139
- self.to_out = Linear(dim_q_inner, dim, bias = False)
155
+ self.to_queries = LinearNoBias(dim, dim_q_inner)
156
+ self.to_keys_values = LinearNoBias(dim_context, dim_kv_inner * 2)
157
+
158
+ self.attn_gate_value = nn.Sequential(LinearNoBias(dim, heads), Rearrange('b n (g h) -> b g h n 1', h = kv_heads))
159
+
160
+ self.to_out = LinearNoBias(dim_q_inner, dim)
140
161
 
141
162
  assert divisible_by(heads, kv_heads)
142
163
  groups = heads // kv_heads
@@ -149,15 +170,20 @@ class Attention(Module):
149
170
  self,
150
171
  tokens,
151
172
  context = None,
152
- context_mask = None
173
+ context_mask = None,
174
+ kv = None,
175
+ return_kv = False
153
176
  ):
154
177
  context = default(context, tokens)
155
178
 
156
179
  queries = self.to_queries(tokens)
157
- keys, values = self.to_keys_values(context).chunk(2, dim = -1)
158
-
159
180
  queries = self.split_q_heads(queries)
160
- keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
181
+
182
+ if not exists(kv):
183
+ keys, values = self.to_keys_values(context).chunk(2, dim = -1)
184
+ keys, values = tuple(self.split_kv_heads(t) for t in (keys, values))
185
+ else:
186
+ keys, values = kv
161
187
 
162
188
  queries = queries * self.scale
163
189
 
@@ -171,9 +197,16 @@ class Attention(Module):
171
197
 
172
198
  out = einsum(attn, values, 'b g h i j, b h j d -> b g h i d')
173
199
 
200
+ out = out * self.attn_gate_value(tokens).sigmoid()
201
+
174
202
  out = self.merge_heads(out)
175
203
 
176
- return self.to_out(out)
204
+ out = self.to_out(out)
205
+
206
+ if not return_kv:
207
+ return out
208
+
209
+ return out, stack((keys, values))
177
210
 
178
211
  # feedforward
179
212
 
@@ -206,19 +239,47 @@ class MimicVideo(Module):
206
239
  def __init__(
207
240
  self,
208
241
  dim,
242
+ video_predict_wrapper: Module | None = None,
209
243
  *,
210
- dim_video_hidden,
244
+ dim_video_hidden = None,
245
+ action_chunk_len = 32,
211
246
  dim_action = 20,
212
247
  dim_joint_state = 32,
248
+ proprio_mask_prob = 0.1,
213
249
  depth = 8,
214
250
  dim_head = 64,
215
251
  heads = 8,
216
252
  expansion_factor = 4.,
253
+ ada_ln_zero_bias = -5.,
217
254
  dim_time_cond = None,
218
- sample_time_fn = None
255
+ sample_time_fn = None,
256
+ train_time_rtc = False,
257
+ train_time_rtc_max_delay = None,
258
+ num_residual_streams = 1,
259
+ mhc_kwargs: dict = dict()
219
260
  ):
220
261
  super().__init__()
221
262
 
263
+ self.depth = depth
264
+
265
+ # maybe video predict
266
+
267
+ self.video_predict_wrapper = video_predict_wrapper
268
+
269
+ # dims
270
+
271
+ self.action_chunk_len = action_chunk_len
272
+ self.dim_action = dim_action
273
+
274
+ self.action_shape = (action_chunk_len, dim_action)
275
+ self.dim_joint_state = dim_joint_state
276
+
277
+ dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
278
+
279
+ assert exists(dim_video_hidden), f'`dim_video_hidden` must be set or `video_predict_wrapper` passed in with `dim_latent`'
280
+
281
+ self.dim_video_hidden = dim_video_hidden
282
+
222
283
  # flow related
223
284
 
224
285
  self.sample_time_fn = default(sample_time_fn, default_sample_time_fn)
@@ -232,10 +293,27 @@ class MimicVideo(Module):
232
293
  self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
233
294
  self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())
234
295
 
296
+ # joint token related
297
+
235
298
  self.to_joint_state_token = Linear(dim_joint_state, dim)
236
299
 
300
+ self.proprio_mask_prob = proprio_mask_prob
301
+ self.has_proprio_masking = proprio_mask_prob > 0.
302
+
303
+ self.proprio_mask_token = nn.Parameter(torch.randn(dim))
304
+
305
+ # video norm
306
+
237
307
  self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)
238
308
 
309
+ # manifold constrained hyper connections (mHC) from bytedance + deepseek
310
+
311
+ init_hyper_conn, self.expand_stream, self.reduce_stream = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, **mhc_kwargs)
312
+
313
+ # rnn
314
+
315
+ self.rnn = GRU(dim, dim)
316
+
239
317
  # transformer
240
318
 
241
319
  layers = []
@@ -249,15 +327,24 @@ class MimicVideo(Module):
249
327
 
250
328
  cross_attn = Attention(dim = dim, dim_head = dim_head, dim_context = dim_video_hidden, heads = heads)
251
329
 
252
- ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond)
330
+ ff_adanorm = AdaptiveRMSNorm(dim = dim, dim_time_cond = dim_time_cond, ada_ln_zero_bias = ada_ln_zero_bias)
253
331
 
254
332
  ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)
255
333
 
334
+ # maybe hyper connect
335
+
336
+ attn_residual = init_hyper_conn()
337
+ cross_attn_residual = init_hyper_conn()
338
+ ff_residual = init_hyper_conn()
339
+
256
340
  layers.append(ModuleList([
257
- attn_adanorm,
258
- attn,
341
+ cross_attn_residual,
259
342
  cross_attn_adanorm,
260
343
  cross_attn,
344
+ attn_residual,
345
+ attn_adanorm,
346
+ attn,
347
+ ff_residual,
261
348
  ff_adanorm,
262
349
  ff
263
350
  ]))
@@ -268,22 +355,102 @@ class MimicVideo(Module):
268
355
 
269
356
  self.to_pred_action_flow = nn.Sequential(
270
357
  nn.RMSNorm(dim),
271
- Linear(dim, dim_action)
358
+ Linear(dim, dim_action, bias = False)
272
359
  )
273
360
 
361
+ # inference related
362
+
363
+ # train time RTC related - https://arxiv.org/abs/2512.05964
364
+
365
+ self.train_time_rtc = train_time_rtc
366
+
367
+ assert not train_time_rtc or exists(train_time_rtc_max_delay)
368
+ self.train_time_rtc_max_delay = train_time_rtc_max_delay
369
+
370
+ # aux loss and device
371
+
372
+ self.register_buffer('zero', tensor(0.), persistent = False)
373
+
374
+ @property
375
+ def device(self):
376
+ return self.zero.device
377
+
378
+ @torch.no_grad()
379
+ def sample(
380
+ self,
381
+ steps = 16,
382
+ batch_size = 1,
383
+ disable_progress_bar = False,
384
+ **kwargs
385
+ ):
386
+
387
+ self.eval()
388
+
389
+ noise = torch.randn((batch_size, *self.action_shape), device = self.device)
390
+
391
+ times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
392
+ delta = 1. / steps
393
+
394
+ denoised = noise
395
+
396
+ cache = None
397
+
398
+ for time in tqdm(times, disable = disable_progress_bar):
399
+ pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)
400
+
401
+ denoised = denoised + delta * pred_flow
402
+
403
+ return denoised
404
+
274
405
  def forward(
275
406
  self,
276
- actions,
277
- video_hiddens, # they use layer 19 of cosmos predict, at first denoising step. that's all
278
407
  *,
408
+ actions,
279
409
  joint_state,
410
+ video = None,
411
+ video_hiddens = None, # they use layer 19 of cosmos predict, at first denoising step. that's all
412
+ context_mask = None,
280
413
  time = None,
281
414
  time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
282
- context_mask = None,
415
+ prompts = None,
416
+ prompt_token_ids = None,
417
+ cache = None,
418
+ return_cache = False,
419
+ return_flow = False
283
420
  ):
421
+ assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
422
+ assert actions.shape[-2:] == self.action_shape
423
+
284
424
  batch, device = actions.shape[0], actions.device
425
+ orig_actions = actions
426
+
427
+ is_training = not exists(time) and not return_flow
428
+
429
+ if not exists(cache):
430
+ # handle maybe extraction of video hiddens
431
+ # only if cache is not given
432
+
433
+ assert exists(video) ^ exists(video_hiddens)
434
+
435
+ if not exists(video_hiddens):
436
+ assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
285
437
 
286
- is_training = not exists(time)
438
+ video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
439
+
440
+ video_hiddens = video_hiddens.to(self.device).float() # maybe bfloat to float32
441
+
442
+ video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')
443
+
444
+ assert video_hiddens.shape[-1] == self.dim_video_hidden
445
+
446
+ # handle video hiddens
447
+
448
+ video_hiddens = self.video_hidden_norm(video_hiddens)
449
+
450
+ # handle caching
451
+
452
+ prev_cached_video_hiddens_kv = cache if exists(cache) else ((None,) * self.depth)
453
+ next_cached_video_hiddens_kv = []
287
454
 
288
455
  # handle flow time conditioning
289
456
 
@@ -297,11 +464,26 @@ class MimicVideo(Module):
297
464
  actions, left_aligned_time = align_dims_left((actions, time))
298
465
 
299
466
  noised = noise.lerp(actions, left_aligned_time)
467
+
300
468
  else:
301
469
  noised = actions
302
470
 
471
+ # maybe train time rtc
472
+
473
+ action_loss_mask = None
474
+
475
+ if is_training and self.train_time_rtc:
476
+
477
+ rand_prefix_len = torch.randint(0, self.train_time_rtc_max_delay, (batch,), device = device)
478
+ action_prefix_mask = lens_to_mask(rand_prefix_len, self.action_chunk_len)
479
+
480
+ actions = einx.where('b na, b na d, b na d', action_prefix_mask, orig_actions, actions)
481
+ time = einx.where('b na, , b', action_prefix_mask, 1., time)
482
+
483
+ action_loss_mask = ~action_prefix_mask
484
+
303
485
  if time.ndim == 0:
304
- time = rearrange(time, '-> b', b = batch)
486
+ time = repeat(time, '-> b', b = batch)
305
487
 
306
488
  # handle the video denoising times
307
489
 
@@ -313,8 +495,14 @@ class MimicVideo(Module):
313
495
  if time_video_denoise.shape[0] != batch:
314
496
  time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)
315
497
 
498
+ if time.ndim == 2:
499
+ time_video_denoise = repeat(time_video_denoise, 'b -> b n', n = time.shape[-1])
500
+
316
501
  times = stack((time, time_video_denoise), dim = -1)
317
502
 
503
+ if times.ndim == 3:
504
+ times = pad_at_dim(times, (1, 0), dim = 1, value = 1.) # handle joint state token on the action
505
+
318
506
  # fourier embed and mlp to time condition
319
507
 
320
508
  fourier_embed = self.to_fourier_embed(times)
@@ -323,48 +511,70 @@ class MimicVideo(Module):
323
511
 
324
512
  time_cond = self.to_time_cond(fourier_embed)
325
513
 
326
- # handle video hiddens
327
-
328
- video_hiddens = self.video_hidden_norm(video_hiddens)
329
-
330
514
  # embed
331
515
 
332
516
  tokens = self.to_action_tokens(noised)
333
517
 
518
+ # one layer of rnn for actions
519
+
520
+ rnn_out, _ = self.rnn(tokens)
521
+ tokens = rnn_out + tokens
522
+
523
+ # mask joint state token for proprioception masking training
524
+
334
525
  joint_state_token = self.to_joint_state_token(joint_state)
335
526
 
527
+ if self.training and self.has_proprio_masking:
528
+ mask = torch.rand((batch,), device = device) < self.proprio_mask_prob
529
+
530
+ joint_state_token = einx.where('b, d, b d', mask, self.proprio_mask_token, joint_state_token)
531
+
532
+ # pack joint with action tokens
533
+
336
534
  tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')
337
535
 
536
+ # maybe expand streams
537
+
538
+ tokens = self.expand_stream(tokens)
539
+
338
540
  # transformer layers
339
541
 
340
- for (
341
- attn_norm,
342
- attn,
542
+ for ((
543
+ maybe_cross_attn_mhc,
343
544
  cross_attn_norm,
344
545
  cross_attn,
546
+ maybe_attn_mhc,
547
+ attn_norm,
548
+ attn,
549
+ maybe_ff_mhc,
345
550
  ff_norm,
346
551
  ff
347
- ) in self.layers:
552
+ ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):
348
553
 
349
554
  # cross attention
350
555
 
351
- residual = tokens
556
+ tokens, add_residual = maybe_cross_attn_mhc(tokens)
352
557
 
353
558
  tokens, gate = cross_attn_norm(tokens, time_cond)
354
559
 
355
- tokens = residual + cross_attn(tokens, context = video_hiddens, context_mask = context_mask) * gate
560
+ cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)
561
+
562
+ tokens = add_residual(cross_attn_out * gate)
563
+
564
+ if return_cache:
565
+ next_cached_video_hiddens_kv.append(video_kv)
356
566
 
357
567
  # self attention
358
568
 
359
- residual = tokens
569
+ tokens, add_residual = maybe_attn_mhc(tokens)
360
570
 
361
571
  tokens, gate = attn_norm(tokens, time_cond)
362
572
 
363
- tokens = residual + attn(tokens) * gate
573
+ tokens = add_residual(attn(tokens) * gate)
364
574
 
365
575
  # prepare feedforward
366
576
 
367
- residual = tokens
577
+ tokens, add_residual = maybe_ff_mhc(tokens)
368
578
 
369
579
  tokens, gate = ff_norm(tokens, time_cond)
370
580
 
@@ -378,7 +588,11 @@ class MimicVideo(Module):
378
588
 
379
589
  # feedforward
380
590
 
381
- tokens = residual + ff(tokens) * gate
591
+ tokens = add_residual(ff(tokens) * gate)
592
+
593
+ # maybe reduce streams
594
+
595
+ tokens = self.reduce_stream(tokens)
382
596
 
383
597
  # remove joint token
384
598
 
@@ -389,9 +603,19 @@ class MimicVideo(Module):
389
603
  pred_flow = self.to_pred_action_flow(tokens)
390
604
 
391
605
  if not is_training:
392
- return pred_flow
606
+ # flow
607
+
608
+ out = pred_flow
609
+ else:
610
+ # mse flow loss
611
+
612
+ flow_loss = F.mse_loss(pred_flow, flow, reduction = 'none')
613
+
614
+ out = masked_mean(flow_loss, action_loss_mask)
615
+
616
+ if not return_cache:
617
+ return out
393
618
 
394
- # mse flow loss
619
+ # handle returning of cache
395
620
 
396
- flow_loss = F.mse_loss(pred_flow, flow)
397
- return flow_loss
621
+ return out, next_cached_video_hiddens_kv
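The train-time RTC conditioning added above (the `train_time_rtc` branch in `forward`) reduces to the following plain-PyTorch sketch, with hypothetical standalone tensors standing in for the `einx.where` and `lens_to_mask` calls.

```python
import torch

# hypothetical shapes matching the defaults above
batch, chunk_len, dim_action, max_delay = 2, 32, 20, 4

actions = torch.randn(batch, chunk_len, dim_action)   # ground truth action chunk
noised  = torch.randn_like(actions)                   # stand-in for the noise-lerped actions
time    = torch.rand(batch)                           # flow time per sample

# a random prefix of each chunk is fed in clean, emulating actions already being executed
prefix_len  = torch.randint(0, max_delay, (batch,))
prefix_mask = torch.arange(chunk_len) < prefix_len[:, None]          # (batch, chunk_len)

model_input = torch.where(prefix_mask[..., None], actions, noised)   # clean prefix, noisy rest

per_token_time = time[:, None].repeat(1, chunk_len)
per_token_time[prefix_mask] = 1.                      # prefix positions count as fully denoised

loss_mask = ~prefix_mask                              # flow loss only on the non-prefix positions
```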
mimic_video-0.0.3.dist-info/METADATA → mimic_video-0.0.24.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mimic-video
3
- Version: 0.0.3
3
+ Version: 0.0.24
4
4
  Summary: Mimic Video
5
5
  Project-URL: Homepage, https://pypi.org/project/mimic-video/
6
6
  Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -36,12 +36,17 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.10
37
37
  Requires-Dist: einops>=0.8.1
38
38
  Requires-Dist: einx>=0.3.0
39
- Requires-Dist: torch-einops-utils>=0.0.8
39
+ Requires-Dist: hyper-connections>=0.4.3
40
+ Requires-Dist: torch-einops-utils>=0.0.12
40
41
  Requires-Dist: torch>=2.5
42
+ Requires-Dist: tqdm
41
43
  Requires-Dist: x-mlps-pytorch
42
44
  Provides-Extra: examples
43
45
  Provides-Extra: test
46
+ Requires-Dist: accelerate; extra == 'test'
47
+ Requires-Dist: diffusers>=0.32.0; extra == 'test'
44
48
  Requires-Dist: pytest; extra == 'test'
49
+ Requires-Dist: transformers; extra == 'test'
45
50
  Description-Content-Type: text/markdown
46
51
 
47
52
  <img src="./mimic-video.png" width="450px"></img>
@@ -50,6 +55,73 @@ Description-Content-Type: text/markdown
50
55
 
51
56
  Implementation of [Mimic-Video](https://mimic-video.github.io/), Video-Action Models for Generalizable Robot Control Beyond VLAs
52
57
 
58
+ ## Appreciation
59
+
60
+ - [Pranoy](https://github.com/pranoyr) for submitting a pull request for proprioception masking
61
+
62
+ ## Install
63
+
64
+ ```shell
65
+ $ pip install mimic-video
66
+ ```
67
+
68
+ ## Usage
69
+
70
+ ```python
71
+ import torch
72
+
73
+ # video wrapper
74
+ # but will be agnostic to the model
75
+
76
+ from mimic_video.cosmos_predict import CosmosPredictWrapper
77
+
78
+ video_wrapper = CosmosPredictWrapper(
79
+ extract_layer = 1,
80
+ random_weights = True,
81
+ tiny = True
82
+ )
83
+
84
+ # mimic video
85
+
86
+ from mimic_video import MimicVideo
87
+
88
+ model = MimicVideo(512, video_wrapper)
89
+
90
+ # states
91
+
92
+ video = torch.rand(2, 3, 3, 32, 32)
93
+
94
+ joint_state = torch.randn(2, 32)
95
+
96
+ # action
97
+
98
+ actions = torch.randn(2, 32, 20)
99
+
100
+ # training
101
+
102
+ loss = model(
103
+ prompts = [
104
+ 'put the package on the conveyer belt',
105
+ 'pass the butter'
106
+ ],
107
+ video = video,
108
+ actions = actions,
109
+ joint_state = joint_state
110
+ )
111
+
112
+ loss.backward()
113
+
114
+ # inference
115
+
116
+ actions = model.sample(
117
+ prompts = 'peel the orange',
118
+ video = video[:1],
119
+ joint_state = joint_state[:1]
120
+ )
121
+
122
+ assert actions.shape == (1, 32, 20)
123
+ ```
124
+
53
125
  ## Contributing
54
126
 
55
127
  First make sure `pytest` and test dependencies are installed with
@@ -76,3 +148,16 @@ That's it
76
148
  url = {https://api.semanticscholar.org/CorpusID:283920528}
77
149
  }
78
150
  ```
151
+
152
+ ```bibtex
153
+ @misc{black2025trainingtimeactionconditioningefficient,
154
+ title = {Training-Time Action Conditioning for Efficient Real-Time Chunking},
155
+ author = {Kevin Black and Allen Z. Ren and Michael Equi and Sergey Levine},
156
+ year = {2025},
157
+ eprint = {2512.05964},
158
+ archivePrefix = {arXiv},
159
+ primaryClass = {cs.RO},
160
+ url = {https://arxiv.org/abs/2512.05964},
161
+ }
162
+ ```
163
+
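The new constructor options surfaced in this release compose with the README example roughly as follows; a sketch with illustrative values, reusing the tiny wrapper from the usage section.

```python
from mimic_video import MimicVideo
from mimic_video.cosmos_predict import CosmosPredictWrapper

# tiny random-weight wrapper, as in the README usage example
video_wrapper = CosmosPredictWrapper(extract_layer = 1, random_weights = True, tiny = True)

# hyper connection streams and train-time RTC are both opt-in
model = MimicVideo(
    512,
    video_wrapper,
    num_residual_streams = 4,        # manifold-constrained hyper connection (mHC) streams, illustrative value
    train_time_rtc = True,           # train-time real-time chunking conditioning
    train_time_rtc_max_delay = 4     # required whenever train_time_rtc is enabled, illustrative value
)
```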
mimic_video-0.0.24.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
1
+ mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
2
+ mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
3
+ mimic_video/mimic_video.py,sha256=Qr0Dc4z-LTRlTt0qXlgcJtdSP1pBsarXeOnJSUxj_yY,17388
4
+ mimic_video-0.0.24.dist-info/METADATA,sha256=4kXYmqL3XtJbZ35iX42Z85RFV_ZGMM_phKGUZWnfcaw,4581
5
+ mimic_video-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
6
+ mimic_video-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ mimic_video-0.0.24.dist-info/RECORD,,
mimic_video-0.0.3.dist-info/RECORD DELETED
@@ -1,6 +0,0 @@
1
- mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
2
- mimic_video/mimic_video.py,sha256=-2HVpXAgEG28JFkJeUlypdmOMyYDD2tw0Fisf9-BZ-M,10243
3
- mimic_video-0.0.3.dist-info/METADATA,sha256=MVJMzysTCCpsgxBKUA9ye-aFSeQAXinyP3ejCtJ8JD8,2960
4
- mimic_video-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- mimic_video-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- mimic_video-0.0.3.dist-info/RECORD,,