mimic-video 0.0.19__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mimic-video might be problematic.

@@ -1,4 +1,5 @@
  from __future__ import annotations
+ from functools import partial

  import torch
  from torch import nn, cat, stack, is_tensor, tensor
@@ -7,7 +8,7 @@ from torch.nn import Module, ModuleList, Linear, GRU
  import torch.nn.functional as F

  import einx
- from einops import einsum, rearrange, repeat
+ from einops import einsum, rearrange, repeat, reduce
  from einops.layers.torch import Rearrange

  from x_mlps_pytorch import create_mlp
@@ -15,12 +16,16 @@ from x_mlps_pytorch import create_mlp
  from tqdm import tqdm

  from torch_einops_utils import (
+     lens_to_mask,
      pad_left_ndim,
      align_dims_left,
      pad_at_dim,
      pack_with_inverse,
+     masked_mean
  )

+ from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions
+
  # ein notation

  # b - batch
@@ -30,6 +35,10 @@ from torch_einops_utils import (
  # i, j - sequence (source, target)
  # d - feature dimension

+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
  # functions

  def exists(v):
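
A note on the new `LinearNoBias` constant: `functools.partial` simply pins `bias = False` on `nn.Linear`, so the many bias-free projections below can be declared tersely. A minimal sketch of the equivalence:

```python
import torch
from functools import partial
from torch import nn

LinearNoBias = partial(nn.Linear, bias = False)

proj = LinearNoBias(512, 256)       # same as nn.Linear(512, 256, bias = False)
assert proj.bias is None

out = proj(torch.randn(2, 512))     # (2, 256)
```
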
@@ -96,8 +105,8 @@ class AdaptiveRMSNorm(Module):
          self.scale = dim ** 0.5
          self.eps = eps

-         self.to_modulation = Linear(dim_time_cond, dim * 3, bias = False)
-         self.split_modulation = Rearrange('b (three d) -> three b 1 d', three = 3)
+         self.to_modulation = LinearNoBias(dim_time_cond, dim * 3)
+         self.split_modulation = Rearrange('... (three d) -> three ... d', three = 3)

          nn.init.zeros_(self.to_modulation.weight)

@@ -108,9 +117,8 @@ class AdaptiveRMSNorm(Module):
          tokens,
          time_cond
      ):
- 
-         if time_cond.ndim == 1:
-             time_cond = pad_left_ndim(time_cond, 1)
+         if time_cond.ndim == 2:
+             time_cond = rearrange(time_cond, 'b d -> b 1 d')

          modulations = self.to_modulation(time_cond)

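The change to `split_modulation` (from `b (three d)` to `... (three d)`), paired with the new `ndim == 2` lift, is what lets the time condition carry an extra per-token axis. A rough shape walk-through, with illustrative sizes (`b`, `n`, `dt`, `d` are placeholders, not names from the package):

```python
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

b, n, dt, d = 2, 10, 32, 64

to_modulation = nn.Linear(dt, d * 3, bias = False)
split_modulation = Rearrange('... (three d) -> three ... d', three = 3)

# per-sample condition: (b, dt) is lifted to (b, 1, d) modulations that broadcast over tokens
time_cond = rearrange(torch.randn(b, dt), 'b d -> b 1 d')
gamma, beta, gate = split_modulation(to_modulation(time_cond))   # each (b, 1, d)

# a per-token condition now flows through the same two modules unchanged
time_cond = torch.randn(b, n, dt)
gamma, beta, gate = split_modulation(to_modulation(time_cond))   # each (b, n, d)
```
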
@@ -134,7 +142,8 @@ class Attention(Module):
          dim_context = None,
          dim_head = 64,
          heads = 8,
-         kv_heads = 2
+         kv_heads = 2,
+         attn_gate_value = True
      ):
          super().__init__()
          dim_q_inner = dim_head * heads
@@ -143,9 +152,12 @@

          self.scale = dim_head ** -0.5

-         self.to_queries = Linear(dim, dim_q_inner, bias = False)
-         self.to_keys_values = Linear(dim_context, dim_kv_inner * 2, bias = False)
-         self.to_out = Linear(dim_q_inner, dim, bias = False)
+         self.to_queries = LinearNoBias(dim, dim_q_inner)
+         self.to_keys_values = LinearNoBias(dim_context, dim_kv_inner * 2)
+
+         self.attn_gate_value = nn.Sequential(LinearNoBias(dim, heads), Rearrange('b n (g h) -> b g h n 1', h = kv_heads))
+
+         self.to_out = LinearNoBias(dim_q_inner, dim)

          assert divisible_by(heads, kv_heads)
          groups = heads // kv_heads
@@ -185,6 +197,8 @@

          out = einsum(attn, values, 'b g h i j, b h j d -> b g h i d')

+         out = out * self.attn_gate_value(tokens).sigmoid()
+
          out = self.merge_heads(out)

          out = self.to_out(out)
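
The new `attn_gate_value` module is a learned sigmoid gate on the attention output, one gate per head and position, broadcast over the feature dimension and laid out to match the grouped-query head arrangement. A standalone sketch of just the gating path (sizes are illustrative):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

b, n, dim = 2, 16, 512
heads, kv_heads, dim_head = 8, 2, 64
groups = heads // kv_heads   # 4 query heads share each kv head

gate = nn.Sequential(
    nn.Linear(dim, heads, bias = False),
    Rearrange('b n (g h) -> b g h n 1', h = kv_heads),
)

tokens = torch.randn(b, n, dim)
out = torch.randn(b, groups, kv_heads, n, dim_head)   # attention output, pre head-merge

out = out * gate(tokens).sigmoid()   # (b, g, h, n, 1) gate broadcasts over dim_head
```
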
@@ -238,7 +252,11 @@ class MimicVideo(Module):
          expansion_factor = 4.,
          ada_ln_zero_bias = -5.,
          dim_time_cond = None,
-         sample_time_fn = None
+         sample_time_fn = None,
+         train_time_rtc = False,
+         train_time_rtc_max_delay = None,
+         num_residual_streams = 1,
+         mhc_kwargs: dict = dict()
      ):
          super().__init__()

@@ -288,6 +306,10 @@

          self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)

+         # manifold constrained hyper connections (mHC) from bytedance + deepseek
+
+         init_hyper_conn, self.expand_stream, self.reduce_stream = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, **mhc_kwargs)
+
          # rnn

          self.rnn = GRU(dim, dim)
@@ -309,11 +331,20 @@

              ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)

+             # maybe hyper connect
+
+             attn_residual = init_hyper_conn()
+             cross_attn_residual = init_hyper_conn()
+             ff_residual = init_hyper_conn()
+
              layers.append(ModuleList([
-                 attn_adanorm,
-                 attn,
+                 cross_attn_residual,
                  cross_attn_adanorm,
                  cross_attn,
+                 attn_residual,
+                 attn_adanorm,
+                 attn,
+                 ff_residual,
                  ff_adanorm,
                  ff
              ]))
@@ -327,6 +358,17 @@ class MimicVideo(Module):
              Linear(dim, dim_action, bias = False)
          )

+         # inference related
+
+         # train time RTC related - https://arxiv.org/abs/2512.05964
+
+         self.train_time_rtc = train_time_rtc
+
+         assert not train_time_rtc or exists(train_time_rtc_max_delay)
+         self.train_time_rtc_max_delay = train_time_rtc_max_delay
+
+         # aux loss and device
+
          self.register_buffer('zero', tensor(0.), persistent = False)

      @property
@@ -380,6 +422,7 @@ class MimicVideo(Module):
          assert actions.shape[-2:] == self.action_shape

          batch, device = actions.shape[0], actions.device
+         orig_actions = actions

          is_training = not exists(time) and not return_flow

@@ -393,7 +436,8 @@ class MimicVideo(Module):
              assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'

              video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
-             video_hiddens = video_hiddens.float() # maybe bfloat to float32
+
+             video_hiddens = video_hiddens.to(self.device).float() # maybe bfloat to float32

          video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')

@@ -420,9 +464,24 @@
              actions, left_aligned_time = align_dims_left((actions, time))

              noised = noise.lerp(actions, left_aligned_time)
+
          else:
              noised = actions

+         # maybe train time rtc
+
+         action_loss_mask = None
+
+         if is_training and self.train_time_rtc:
+
+             rand_prefix_len = torch.randint(0, self.train_time_rtc_max_delay, (batch,), device = device)
+             action_prefix_mask = lens_to_mask(rand_prefix_len, self.action_chunk_len)
+
+             actions = einx.where('b na, b na d, b na d', action_prefix_mask, orig_actions, actions)
+             time = einx.where('b na, , b', action_prefix_mask, 1., time)
+
+             action_loss_mask = ~action_prefix_mask
+
          if time.ndim == 0:
              time = repeat(time, '-> b', b = batch)

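The train-time RTC block freezes a random prefix of each action chunk at its clean value with flow time 1, so the model learns to condition on already-executed actions, and the loss mask later drops those positions. A self-contained sketch of the masking logic, using `torch.where` in place of `einx.where` and a local stand-in for `lens_to_mask` (assumed to return True below each per-sample length, as in torch-einops-utils):

```python
import torch

def lens_to_mask(lens, total_len):
    # local stand-in: True where the position index is below each sample's length
    return torch.arange(total_len, device = lens.device) < lens[:, None]

batch, chunk_len, dim_action, max_delay = 4, 8, 6, 3

orig_actions = torch.randn(batch, chunk_len, dim_action)   # clean action targets
noised       = torch.randn(batch, chunk_len, dim_action)   # flow-noised actions
time         = torch.rand(batch)                           # flow time per sample

prefix_len  = torch.randint(0, max_delay, (batch,))
prefix_mask = lens_to_mask(prefix_len, chunk_len)          # (batch, chunk_len)

# prefix positions are shown the clean action at time = 1 (treated as fully denoised)
actions = torch.where(prefix_mask[..., None], orig_actions, noised)
time    = torch.where(prefix_mask, torch.ones(()), time[:, None])   # now (batch, chunk_len)

action_loss_mask = ~prefix_mask   # only the free suffix contributes to the flow loss
```
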
@@ -436,8 +495,14 @@
          if time_video_denoise.shape[0] != batch:
              time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)

+         if time.ndim == 2:
+             time_video_denoise = repeat(time_video_denoise, 'b -> b n', n = time.shape[-1])
+
          times = stack((time, time_video_denoise), dim = -1)

+         if times.ndim == 3:
+             times = pad_at_dim(times, (1, 0), dim = 1, value = 1.) # handle joint state token on the action
+
          # fourier embed and mlp to time condition

          fourier_embed = self.to_fourier_embed(times)
@@ -468,41 +533,48 @@

          tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')

+         # maybe expand streams
+
+         tokens = self.expand_stream(tokens)
+
          # transformer layers

          for ((
-             attn_norm,
-             attn,
+             maybe_cross_attn_mhc,
              cross_attn_norm,
              cross_attn,
+             maybe_attn_mhc,
+             attn_norm,
+             attn,
+             maybe_ff_mhc,
              ff_norm,
              ff
          ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):

              # cross attention

-             residual = tokens
+             tokens, add_residual = maybe_cross_attn_mhc(tokens)

              tokens, gate = cross_attn_norm(tokens, time_cond)

              cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)

-             tokens = residual + cross_attn_out * gate
+             tokens = add_residual(cross_attn_out * gate)

              if return_cache:
                  next_cached_video_hiddens_kv.append(video_kv)

              # self attention

-             residual = tokens
+             tokens, add_residual = maybe_attn_mhc(tokens)

              tokens, gate = attn_norm(tokens, time_cond)

-             tokens = residual + attn(tokens) * gate
+             tokens = add_residual(attn(tokens) * gate)

              # prepare feedforward

-             residual = tokens
+             tokens, add_residual = maybe_ff_mhc(tokens)

              tokens, gate = ff_norm(tokens, time_cond)

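The mHC refactor swaps the hand-rolled `residual = tokens ... tokens = residual + out` pattern for the wrapper convention visible above: each hyper-connection returns the branch input plus an `add_residual` closure, with `expand_stream` / `reduce_stream` bracketing the layer stack. A minimal sketch of that control flow, assuming the `hyper_connections` v2 API behaves as these calls suggest (`residual_wrapper` and `branch` are our names; with `num_residual_streams = 1` this degenerates to a plain residual):

```python
import torch
from torch import nn
from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

dim, num_residual_streams = 512, 4

init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
    num_residual_streams, dim = dim, add_stream_embed = True
)

branch = nn.Linear(dim, dim)     # stand-in for one attention or feedforward branch
residual_wrapper = init_hyper_conn()

tokens = torch.randn(2, 16, dim)

tokens = expand_stream(tokens)                     # fan out into residual streams

tokens, add_residual = residual_wrapper(tokens)    # read the branch input off the streams
tokens = add_residual(branch(tokens))              # write the branch output back

tokens = reduce_stream(tokens)                     # collapse back to (2, 16, dim)
```
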
@@ -516,7 +588,11 @@

              # feedforward

-             tokens = residual + ff(tokens) * gate
+             tokens = add_residual(ff(tokens) * gate)
+
+         # maybe reduce streams
+
+         tokens = self.reduce_stream(tokens)

          # remove joint token

@@ -533,9 +609,9 @@
          else:
              # mse flow loss

-             flow_loss = F.mse_loss(pred_flow, flow)
+             flow_loss = F.mse_loss(pred_flow, flow, reduction = 'none')

-             out = flow_loss
+             out = masked_mean(flow_loss, action_loss_mask)

          if not return_cache:
              return out
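
With `reduction = 'none'` the MSE keeps its element-wise shape, so the RTC prefix positions can be excluded from the average. Assuming `masked_mean` averages over unmasked positions and falls back to a plain mean when the mask is `None`, a rough local stand-in (not the torch-einops-utils implementation):

```python
import torch
import torch.nn.functional as F

def masked_mean(t, mask = None):
    # rough stand-in for torch_einops_utils.masked_mean
    if mask is None:
        return t.mean()
    mask = mask[..., None]                          # (b, n) -> (b, n, 1), broadcast over features
    t = t.masked_fill(~mask, 0.)
    denom = (mask.sum() * t.shape[-1]).clamp(min = 1)
    return t.sum() / denom

pred_flow = torch.randn(4, 8, 6)
flow      = torch.randn(4, 8, 6)
loss_mask = torch.rand(4, 8) > 0.3                  # True where the position counts toward the loss

flow_loss = F.mse_loss(pred_flow, flow, reduction = 'none')   # (4, 8, 6), unreduced
loss = masked_mean(flow_loss, loss_mask)                      # scalar
```
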
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mimic-video
- Version: 0.0.19
+ Version: 0.0.24
  Summary: Mimic Video
  Project-URL: Homepage, https://pypi.org/project/mimic-video/
  Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.10
  Requires-Dist: einops>=0.8.1
  Requires-Dist: einx>=0.3.0
- Requires-Dist: torch-einops-utils>=0.0.8
+ Requires-Dist: hyper-connections>=0.4.3
+ Requires-Dist: torch-einops-utils>=0.0.12
  Requires-Dist: torch>=2.5
  Requires-Dist: tqdm
  Requires-Dist: x-mlps-pytorch
@@ -147,3 +148,16 @@ That's it
      url = {https://api.semanticscholar.org/CorpusID:283920528}
  }
  ```
+
+ ```bibtex
+ @misc{black2025trainingtimeactionconditioningefficient,
+     title = {Training-Time Action Conditioning for Efficient Real-Time Chunking},
+     author = {Kevin Black and Allen Z. Ren and Michael Equi and Sergey Levine},
+     year = {2025},
+     eprint = {2512.05964},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.RO},
+     url = {https://arxiv.org/abs/2512.05964},
+ }
+ ```
+
@@ -0,0 +1,7 @@
+ mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
+ mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
+ mimic_video/mimic_video.py,sha256=Qr0Dc4z-LTRlTt0qXlgcJtdSP1pBsarXeOnJSUxj_yY,17388
+ mimic_video-0.0.24.dist-info/METADATA,sha256=4kXYmqL3XtJbZ35iX42Z85RFV_ZGMM_phKGUZWnfcaw,4581
+ mimic_video-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ mimic_video-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ mimic_video-0.0.24.dist-info/RECORD,,
@@ -1,7 +0,0 @@
- mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
- mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
- mimic_video/mimic_video.py,sha256=wQNfdV2PGlR_-S-4Mm3cyswtRQ3nBQGJiHptya3ckKU,14761
- mimic_video-0.0.19.dist-info/METADATA,sha256=67a7iVIkf557qMos_yvpcqL8dK9yL0Jf1DPQuhb2bwo,4142
- mimic_video-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- mimic_video-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- mimic_video-0.0.19.dist-info/RECORD,,