mimic-video 0.0.19-py3-none-any.whl → 0.0.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of mimic-video has been flagged as potentially problematic.

mimic_video/mimic_video.py

@@ -1,4 +1,5 @@
  from __future__ import annotations
+ from functools import partial

  import torch
  from torch import nn, cat, stack, is_tensor, tensor
@@ -7,7 +8,7 @@ from torch.nn import Module, ModuleList, Linear, GRU
  import torch.nn.functional as F

  import einx
- from einops import einsum, rearrange, repeat
+ from einops import einsum, rearrange, repeat, reduce
  from einops.layers.torch import Rearrange

  from x_mlps_pytorch import create_mlp
@@ -15,12 +16,16 @@ from x_mlps_pytorch import create_mlp
  from tqdm import tqdm

  from torch_einops_utils import (
+     lens_to_mask,
      pad_left_ndim,
      align_dims_left,
      pad_at_dim,
      pack_with_inverse,
+     masked_mean
  )

+ from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions
+
  # ein notation

  # b - batch
@@ -30,6 +35,10 @@ from torch_einops_utils import (
  # i, j - sequence (source, target)
  # d - feature dimension

+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
  # functions

  def exists(v):
@@ -38,9 +47,22 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d

+ def identity(t):
+     return t
+
  def divisible_by(num, den):
      return (num % den) == 0

+ # wrappers
+
+ def eval_no_grad(fn):
+     def inner(*args, **kwargs):
+         with torch.no_grad():
+             fn.eval()
+             return fn(*args, **kwargs)
+
+     return inner
+
  # tensor function

  def cast_tensor(val, device = None):
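
The new `eval_no_grad` wrapper runs a module in eval mode with autograd disabled for the duration of a single call (note it does not restore the module's previous training mode afterwards). A minimal usage sketch, with a hypothetical stand-in module:

```python
import torch
from torch import nn

video_model = nn.Linear(8, 8)  # hypothetical stand-in for the video prediction model

wrapped = eval_no_grad(video_model)
out = wrapped(torch.randn(2, 8))  # forward runs in eval mode, no autograd graph is built

assert not out.requires_grad
```
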
@@ -60,6 +82,30 @@ def shift_feature_dim(t):
      x_shift = pad_at_dim(x_shift, (1, -1), dim = 1)
      return cat((x, x_shift), dim = -1)

+ # action normalization
+
+ class Normalizer(Module):
+     def __init__(
+         self,
+         mean,
+         std,
+         eps = 1e-6
+     ):
+         super().__init__()
+         assert (std > 0.).all(), 'std must be positive'
+         self.eps = eps
+
+         self.register_buffer('mean', mean)
+         self.register_buffer('std', std)
+
+     def normalize(self, t):
+         mean, std = self.mean, self.std
+         return (t - mean) / std.clamp_min(self.eps)
+
+     def inverse_normalize(self, t):
+         mean, std = self.mean, self.std
+         return (t * std) + mean
+
  # time

  # they follow p0's research finding with the beta distribution
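
`Normalizer` stores per-dimension action statistics as buffers and is built from the new `action_mean_std` constructor argument further down, a single tensor of shape `(2, dim_action)`. A sketch of deriving that tensor from a dataset of action chunks (the `all_actions` tensor is a placeholder):

```python
import torch

# placeholder dataset: (num_samples, action_chunk_len, dim_action)
all_actions = torch.randn(1000, 16, 7)

flat = all_actions.reshape(-1, all_actions.shape[-1])
action_mean_std = torch.stack((flat.mean(dim = 0), flat.std(dim = 0)))  # (2, dim_action)

normalizer = Normalizer(*action_mean_std)
recon = normalizer.inverse_normalize(normalizer.normalize(all_actions))

assert torch.allclose(recon, all_actions, atol = 1e-5)
```
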
@@ -96,8 +142,8 @@ class AdaptiveRMSNorm(Module):
          self.scale = dim ** 0.5
          self.eps = eps

-         self.to_modulation = Linear(dim_time_cond, dim * 3, bias = False)
-         self.split_modulation = Rearrange('b (three d) -> three b 1 d', three = 3)
+         self.to_modulation = LinearNoBias(dim_time_cond, dim * 3)
+         self.split_modulation = Rearrange('... (three d) -> three ... d', three = 3)

          nn.init.zeros_(self.to_modulation.weight)

@@ -108,9 +154,8 @@
          tokens,
          time_cond
      ):
-
-         if time_cond.ndim == 1:
-             time_cond = pad_left_ndim(time_cond, 1)
+         if time_cond.ndim == 2:
+             time_cond = rearrange(time_cond, 'b d -> b 1 d')

          modulations = self.to_modulation(time_cond)
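
These two changes generalize `AdaptiveRMSNorm` from per-sample to per-token time conditioning: the modulation split no longer hard-codes a `b 1 d` broadcast, and a 2-dim `time_cond` simply gains a sequence axis. This is what later lets each action token carry its own flow time under train-time RTC. A shape sketch of the new split:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

b, n, d, d_cond = 2, 16, 64, 32

to_modulation = nn.Linear(d_cond, d * 3, bias = False)
split = Rearrange('... (three d) -> three ... d', three = 3)

# per-sample conditioning, broadcast over all n tokens
assert split(to_modulation(torch.randn(b, 1, d_cond))).shape == (3, b, 1, d)

# per-token conditioning, one (scale, shift, gate) triple per token
assert split(to_modulation(torch.randn(b, n, d_cond))).shape == (3, b, n, d)
```
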
@@ -134,7 +179,8 @@ class Attention(Module):
          dim_context = None,
          dim_head = 64,
          heads = 8,
-         kv_heads = 2
+         kv_heads = 2,
+         attn_gate_value = True
      ):
          super().__init__()
          dim_q_inner = dim_head * heads
@@ -143,9 +189,12 @@

          self.scale = dim_head ** -0.5

-         self.to_queries = Linear(dim, dim_q_inner, bias = False)
-         self.to_keys_values = Linear(dim_context, dim_kv_inner * 2, bias = False)
-         self.to_out = Linear(dim_q_inner, dim, bias = False)
+         self.to_queries = LinearNoBias(dim, dim_q_inner)
+         self.to_keys_values = LinearNoBias(dim_context, dim_kv_inner * 2)
+
+         self.attn_gate_value = nn.Sequential(LinearNoBias(dim, heads), Rearrange('b n (g h) -> b g h n 1', h = kv_heads))
+
+         self.to_out = LinearNoBias(dim_q_inner, dim)

          assert divisible_by(heads, kv_heads)
          groups = heads // kv_heads
@@ -185,6 +234,8 @@

          out = einsum(attn, values, 'b g h i j, b h j d -> b g h i d')

+         out = out * self.attn_gate_value(tokens).sigmoid()
+
          out = self.merge_heads(out)

          out = self.to_out(out)
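
The new `attn_gate_value` head gives each (group, kv-head, token) a sigmoid gate on the attention output before heads are merged. A shape-level sketch of the gate path, under the module's grouped-query layout:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

b, n, dim = 2, 10, 512
heads, kv_heads, dim_head = 8, 2, 64
groups = heads // kv_heads

attn_gate_value = nn.Sequential(
    nn.Linear(dim, heads, bias = False),
    Rearrange('b n (g h) -> b g h n 1', h = kv_heads)
)

tokens = torch.randn(b, n, dim)
out = torch.randn(b, groups, kv_heads, n, dim_head)  # attention output, pre-merge

# one gate per (group, kv-head, token), broadcast across dim_head
out = out * attn_gate_value(tokens).sigmoid()
assert out.shape == (b, groups, kv_heads, n, dim_head)
```
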
@@ -238,7 +289,12 @@ class MimicVideo(Module):
          expansion_factor = 4.,
          ada_ln_zero_bias = -5.,
          dim_time_cond = None,
-         sample_time_fn = None
+         sample_time_fn = None,
+         train_time_rtc = False,
+         train_time_rtc_max_delay = None,
+         num_residual_streams = 1,
+         mhc_kwargs: dict = dict(),
+         action_mean_std: Tensor | None = None
      ):
          super().__init__()

@@ -248,12 +304,21 @@ class MimicVideo(Module):

          self.video_predict_wrapper = video_predict_wrapper

-         # dims
+         # action related

          self.action_chunk_len = action_chunk_len
          self.dim_action = dim_action

          self.action_shape = (action_chunk_len, dim_action)
+
+         self.action_normalizer = None
+
+         if exists(action_mean_std):
+             assert action_mean_std.shape == (2, dim_action), f'must be in shape of (2 action_dim)'
+             self.action_normalizer = Normalizer(*action_mean_std)
+
+         # joint dim
+
          self.dim_joint_state = dim_joint_state

          dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
@@ -288,6 +353,10 @@ class MimicVideo(Module):

          self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)

+         # manifold constrained hyper connections (mHC) from bytedance + deepseek
+
+         init_hyper_conn, self.expand_stream, self.reduce_stream = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, **mhc_kwargs)
+
          # rnn

          self.rnn = GRU(dim, dim)
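
`get_init_and_expand_reduce_stream_functions` is the entry point of the `hyper-connections` package. As used in the forward pass below: the token stream is expanded into `num_residual_streams` copies once, each residual branch gets its own hyper-connection that picks the branch input and folds the output back, and the streams are reduced to one at the end. A minimal sketch that mirrors how this diff uses the API (the `nn.Linear` branch is a placeholder; the exact stream semantics belong to the library):

```python
import torch
from torch import nn
from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

dim = 64
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
    4, dim = dim, add_stream_embed = True
)

residual_fn = init_hyper_conn()  # one per residual branch
branch = nn.Linear(dim, dim)     # placeholder branch

tokens = torch.randn(2, 16, dim)

tokens = expand_stream(tokens)              # 1 stream -> 4 streams
tokens, add_residual = residual_fn(tokens)  # pick branch input, keep residual closure
tokens = add_residual(branch(tokens))       # fold branch output back into the streams
tokens = reduce_stream(tokens)              # 4 streams -> 1
```

With `num_residual_streams = 1` the same calls reduce to ordinary residual connections, which is why the layer variables below are named `maybe_*`.
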
@@ -309,11 +378,20 @@ class MimicVideo(Module):

              ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)

+             # maybe hyper connect
+
+             attn_residual = init_hyper_conn()
+             cross_attn_residual = init_hyper_conn()
+             ff_residual = init_hyper_conn()
+
              layers.append(ModuleList([
-                 attn_adanorm,
-                 attn,
+                 cross_attn_residual,
                  cross_attn_adanorm,
                  cross_attn,
+                 attn_residual,
+                 attn_adanorm,
+                 attn,
+                 ff_residual,
                  ff_adanorm,
                  ff
              ]))
@@ -327,8 +405,25 @@ class MimicVideo(Module):
              Linear(dim, dim_action, bias = False)
          )

+         # inference related
+
+         # train time RTC related - https://arxiv.org/abs/2512.05964
+
+         self.train_time_rtc = train_time_rtc
+
+         assert not train_time_rtc or exists(train_time_rtc_max_delay)
+         self.train_time_rtc_max_delay = train_time_rtc_max_delay
+
+         # aux loss and device
+
          self.register_buffer('zero', tensor(0.), persistent = False)

+     # only action parameters
+
+     def action_parameters(self):
+         video_model_params = set(self.video_predict_wrapper.parameters()) if exists(self.video_predict_wrapper) else {}
+         return set(self.parameters()) - video_model_params
+
      @property
      def device(self):
          return self.zero.device
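
`action_parameters` returns every parameter except those of the wrapped video model, so the action expert can be optimized while the video backbone stays frozen or is trained separately. A hedged usage sketch, assuming `model` is a constructed `MimicVideo`:

```python
import torch

# only the action-expert parameters receive updates; the video model is left alone
optimizer = torch.optim.AdamW(list(model.action_parameters()), lr = 3e-4)
```
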
@@ -338,26 +433,60 @@ class MimicVideo(Module):
          self,
          steps = 16,
          batch_size = 1,
+         prefix_action_chunk = None,
          disable_progress_bar = False,
          **kwargs
      ):

          self.eval()

+         inpainting = exists(prefix_action_chunk)
+
+         if inpainting:
+             prefix_len = prefix_action_chunk.shape[1]
+             assert prefix_len < self.action_chunk_len
+
+             maybe_normed_prefix = prefix_action_chunk
+
+             if exists(self.action_normalizer):
+                 maybe_normed_prefix = self.action_normalizer.normalize(prefix_action_chunk)
+
+         # noise
+
          noise = torch.randn((batch_size, *self.action_shape), device = self.device)

+         # times
+
          times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
          delta = 1. / steps

+         # denoised action starts as noise
+
          denoised = noise

          cache = None

+         # denoise
+
          for time in tqdm(times, disable = disable_progress_bar):
+
+             if inpainting:
+                 denoised[:, :prefix_len] = maybe_normed_prefix
+
              pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)

              denoised = denoised + delta * pred_flow

+         # handle action inverse norm
+
+         if exists(self.action_normalizer):
+             denoised = self.action_normalizer.inverse_normalize(denoised)
+
+         # final set, with unnormalized prefix, if inpainting
+
+         if inpainting:
+             denoised[:, :prefix_len] = prefix_action_chunk
+
          return denoised

      def forward(
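
`sample` now supports inpainting a committed action prefix for real-time chunking: at every denoising step the first `prefix_len` positions are overwritten with the (normalized) prefix, and the returned chunk carries the raw prefix verbatim. A usage sketch, assuming a trained `model` with `action_chunk_len = 16` and `dim_action = 7`, and a `conditioning` dict holding the usual forward kwargs (`video` or `video_hiddens`, prompts, joint state):

```python
import torch

prev_chunk_tail = torch.randn(1, 4, 7)  # actions already committed from the previous chunk

actions = model.sample(
    steps = 16,
    batch_size = 1,
    prefix_action_chunk = prev_chunk_tail,  # must be shorter than action_chunk_len
    **conditioning
)

assert actions.shape == (1, 16, 7)
assert torch.equal(actions[:, :4], prev_chunk_tail)  # prefix is returned untouched
```
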
@@ -372,6 +501,8 @@
          time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
          prompts = None,
          prompt_token_ids = None,
+         detach_video_hiddens = False,
+         no_grad_video_model_forward = False,
          cache = None,
          return_cache = False,
          return_flow = False
@@ -379,7 +510,11 @@
          assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
          assert actions.shape[-2:] == self.action_shape

+         if exists(self.action_normalizer):
+             actions = self.action_normalizer.normalize(actions)
+
          batch, device = actions.shape[0], actions.device
+         orig_actions = actions

          is_training = not exists(time) and not return_flow

@@ -392,8 +527,11 @@
          if not exists(video_hiddens):
              assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'

-             video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
-             video_hiddens = video_hiddens.float() # maybe bfloat to float32
+             video_forward_wrap = eval_no_grad if no_grad_video_model_forward else identity
+
+             video_hiddens = video_forward_wrap(self.video_predict_wrapper)(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
+
+             video_hiddens = video_hiddens.to(self.device).float() # maybe bfloat to float32

          video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')

@@ -401,6 +539,9 @@

          # handle video hiddens

+         if detach_video_hiddens:
+             video_hiddens = video_hiddens.detach()
+
          video_hiddens = self.video_hidden_norm(video_hiddens)

          # handle caching
@@ -420,9 +561,24 @@
              actions, left_aligned_time = align_dims_left((actions, time))

              noised = noise.lerp(actions, left_aligned_time)
+
          else:
              noised = actions

+         # maybe train time rtc
+
+         action_loss_mask = None
+
+         if is_training and self.train_time_rtc:
+
+             rand_prefix_len = torch.randint(0, self.train_time_rtc_max_delay, (batch,), device = device)
+             action_prefix_mask = lens_to_mask(rand_prefix_len, self.action_chunk_len)
+
+             actions = einx.where('b na, b na d, b na d', action_prefix_mask, orig_actions, actions)
+             time = einx.where('b na, , b', action_prefix_mask, 1., time)
+
+             action_loss_mask = ~action_prefix_mask
+
          if time.ndim == 0:
              time = repeat(time, '-> b', b = batch)

@@ -436,8 +592,14 @@
          if time_video_denoise.shape[0] != batch:
              time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)

+         if time.ndim == 2:
+             time_video_denoise = repeat(time_video_denoise, 'b -> b n', n = time.shape[-1])
+
          times = stack((time, time_video_denoise), dim = -1)

+         if times.ndim == 3:
+             times = pad_at_dim(times, (1, 0), dim = 1, value = 1.) # handle joint state token on the action
+
          # fourier embed and mlp to time condition

          fourier_embed = self.to_fourier_embed(times)
@@ -468,41 +630,48 @@

          tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')

+         # maybe expand streams
+
+         tokens = self.expand_stream(tokens)
+
          # transformer layers

          for ((
-             attn_norm,
-             attn,
+             maybe_cross_attn_mhc,
              cross_attn_norm,
              cross_attn,
+             maybe_attn_mhc,
+             attn_norm,
+             attn,
+             maybe_ff_mhc,
              ff_norm,
              ff
          ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):

              # cross attention

-             residual = tokens
+             tokens, add_residual = maybe_cross_attn_mhc(tokens)

              tokens, gate = cross_attn_norm(tokens, time_cond)

              cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)

-             tokens = residual + cross_attn_out * gate
+             tokens = add_residual(cross_attn_out * gate)

              if return_cache:
                  next_cached_video_hiddens_kv.append(video_kv)

              # self attention

-             residual = tokens
+             tokens, add_residual = maybe_attn_mhc(tokens)

              tokens, gate = attn_norm(tokens, time_cond)

-             tokens = residual + attn(tokens) * gate
+             tokens = add_residual(attn(tokens) * gate)

              # prepare feedforward

-             residual = tokens
+             tokens, add_residual = maybe_ff_mhc(tokens)

              tokens, gate = ff_norm(tokens, time_cond)
@@ -516,7 +685,11 @@

              # feedforward

-             tokens = residual + ff(tokens) * gate
+             tokens = add_residual(ff(tokens) * gate)
+
+         # maybe reduce streams
+
+         tokens = self.reduce_stream(tokens)

          # remove joint token

@@ -533,9 +706,9 @@
          else:
              # mse flow loss

-             flow_loss = F.mse_loss(pred_flow, flow)
+             flow_loss = F.mse_loss(pred_flow, flow, reduction = 'none')

-             out = flow_loss
+             out = masked_mean(flow_loss, action_loss_mask)

          if not return_cache:
              return out
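
With per-token loss masking, the reduction moves from a plain `F.mse_loss` mean to `masked_mean`, so inpainted prefix positions contribute nothing to the gradient. Assuming semantics like the common implementation (a plain mean when the mask is `None`), a hypothetical reference version:

```python
import torch

def masked_mean_ref(t, mask = None):
    # hypothetical reference; the real helper lives in torch_einops_utils
    if mask is None:
        return t.mean()

    mask = mask[..., None].expand_as(t)  # broadcast (b, na) mask over the feature dim
    return t[mask].sum() / mask.sum().clamp(min = 1)
```
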
mimic_video-0.0.27.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mimic-video
- Version: 0.0.19
+ Version: 0.0.27
  Summary: Mimic Video
  Project-URL: Homepage, https://pypi.org/project/mimic-video/
  Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.10
  Requires-Dist: einops>=0.8.1
  Requires-Dist: einx>=0.3.0
- Requires-Dist: torch-einops-utils>=0.0.8
+ Requires-Dist: hyper-connections>=0.4.3
+ Requires-Dist: torch-einops-utils>=0.0.12
  Requires-Dist: torch>=2.5
  Requires-Dist: tqdm
  Requires-Dist: x-mlps-pytorch
@@ -147,3 +148,16 @@ That's it
      url = {https://api.semanticscholar.org/CorpusID:283920528}
  }
  ```
+
+ ```bibtex
+ @misc{black2025trainingtimeactionconditioningefficient,
+     title = {Training-Time Action Conditioning for Efficient Real-Time Chunking},
+     author = {Kevin Black and Allen Z. Ren and Michael Equi and Sergey Levine},
+     year = {2025},
+     eprint = {2512.05964},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.RO},
+     url = {https://arxiv.org/abs/2512.05964},
+ }
+ ```
+
mimic_video-0.0.27.dist-info/RECORD (added)

@@ -0,0 +1,7 @@
+ mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
+ mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
+ mimic_video/mimic_video.py,sha256=WlwFfFvOW5k6X-BxRvF0zjwpKEET9C_FIyewD6_GmcE,20017
+ mimic_video-0.0.27.dist-info/METADATA,sha256=al9--DJ_U_jwWilronv3IADdbCIuQfEQRMCJ3vEtE80,4581
+ mimic_video-0.0.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ mimic_video-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ mimic_video-0.0.27.dist-info/RECORD,,

mimic_video-0.0.19.dist-info/RECORD (removed)

@@ -1,7 +0,0 @@
- mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
- mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
- mimic_video/mimic_video.py,sha256=wQNfdV2PGlR_-S-4Mm3cyswtRQ3nBQGJiHptya3ckKU,14761
- mimic_video-0.0.19.dist-info/METADATA,sha256=67a7iVIkf557qMos_yvpcqL8dK9yL0Jf1DPQuhb2bwo,4142
- mimic_video-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- mimic_video-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- mimic_video-0.0.19.dist-info/RECORD,,