mimic-video 0.0.19__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mimic-video might be problematic.
- mimic_video/mimic_video.py +100 -24
- {mimic_video-0.0.19.dist-info → mimic_video-0.0.24.dist-info}/METADATA +16 -2
- mimic_video-0.0.24.dist-info/RECORD +7 -0
- mimic_video-0.0.19.dist-info/RECORD +0 -7
- {mimic_video-0.0.19.dist-info → mimic_video-0.0.24.dist-info}/WHEEL +0 -0
- {mimic_video-0.0.19.dist-info → mimic_video-0.0.24.dist-info}/licenses/LICENSE +0 -0
mimic_video/mimic_video.py
CHANGED
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from functools import partial
 
 import torch
 from torch import nn, cat, stack, is_tensor, tensor
@@ -7,7 +8,7 @@ from torch.nn import Module, ModuleList, Linear, GRU
 import torch.nn.functional as F
 
 import einx
-from einops import einsum, rearrange, repeat
+from einops import einsum, rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
 from x_mlps_pytorch import create_mlp
@@ -15,12 +16,16 @@ from x_mlps_pytorch import create_mlp
 from tqdm import tqdm
 
 from torch_einops_utils import (
+    lens_to_mask,
     pad_left_ndim,
     align_dims_left,
     pad_at_dim,
     pack_with_inverse,
+    masked_mean
 )
 
+from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions
+
 # ein notation
 
 # b - batch
@@ -30,6 +35,10 @@ from torch_einops_utils import (
 # i, j - sequence (source, target)
 # d - feature dimension
 
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
 # functions
 
 def exists(v):
@@ -96,8 +105,8 @@ class AdaptiveRMSNorm(Module):
         self.scale = dim ** 0.5
         self.eps = eps
 
-        self.to_modulation =
-        self.split_modulation = Rearrange('
+        self.to_modulation = LinearNoBias(dim_time_cond, dim * 3)
+        self.split_modulation = Rearrange('... (three d) -> three ... d', three = 3)
 
         nn.init.zeros_(self.to_modulation.weight)
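The filled-in lines above give AdaptiveRMSNorm a bias-free, zero-initialized projection from the time condition into a 3-way (scale, shift, gate) split. A minimal sketch of the conditioning pattern this implies — the class and variable names are illustrative, not the package's exact code, and the sigmoid-with-negative-bias gate is an assumption based on the `ada_ln_zero_bias` hyperparameter appearing elsewhere in this diff:

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class AdaRMSNormSketch(nn.Module):
    # illustrative reconstruction of the adaptive RMSNorm pattern in the diff
    def __init__(self, dim, dim_time_cond, ada_ln_zero_bias = -5.):
        super().__init__()
        self.scale = dim ** 0.5
        self.ada_ln_zero_bias = ada_ln_zero_bias

        self.to_modulation = nn.Linear(dim_time_cond, dim * 3, bias = False)
        nn.init.zeros_(self.to_modulation.weight)  # block starts out as a plain RMSNorm

    def forward(self, tokens, time_cond):
        if time_cond.ndim == 2:
            time_cond = rearrange(time_cond, 'b d -> b 1 d')  # broadcast over sequence

        normed = F.normalize(tokens, dim = -1) * self.scale

        scale, shift, gate = rearrange(self.to_modulation(time_cond), '... (three d) -> three ... d', three = 3)

        # assumption: the gate is sigmoided with a negative bias so branches start near closed
        return normed * (scale + 1.) + shift, (gate + self.ada_ln_zero_bias).sigmoid()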
@@ -108,9 +117,8 @@ class AdaptiveRMSNorm(Module):
         tokens,
         time_cond
     ):
-
-
-        time_cond = pad_left_ndim(time_cond, 1)
+        if time_cond.ndim == 2:
+            time_cond = rearrange(time_cond, 'b d -> b 1 d')
 
         modulations = self.to_modulation(time_cond)
@@ -134,7 +142,8 @@ class Attention(Module):
         dim_context = None,
         dim_head = 64,
         heads = 8,
-        kv_heads = 2
+        kv_heads = 2,
+        attn_gate_value = True
     ):
         super().__init__()
         dim_q_inner = dim_head * heads
@@ -143,9 +152,12 @@
 
         self.scale = dim_head ** -0.5
 
-        self.to_queries =
-        self.to_keys_values =
-
+        self.to_queries = LinearNoBias(dim, dim_q_inner)
+        self.to_keys_values = LinearNoBias(dim_context, dim_kv_inner * 2)
+
+        self.attn_gate_value = nn.Sequential(LinearNoBias(dim, heads), Rearrange('b n (g h) -> b g h n 1', h = kv_heads))
+
+        self.to_out = LinearNoBias(dim_q_inner, dim)
 
         assert divisible_by(heads, kv_heads)
         groups = heads // kv_heads
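The restored projections imply grouped-query attention: queries get `heads` heads while keys/values get only `kv_heads`, each serving a group of `heads // kv_heads` query heads. A standalone sketch of the shape bookkeeping, with made-up tensor sizes (batch 1, sequence 16):

import torch
from einops import rearrange, einsum

heads, kv_heads, dim_head = 8, 2, 64
groups = heads // kv_heads  # query heads served by each kv head

q = torch.randn(1, 16, heads * dim_head)     # output of to_queries
k = torch.randn(1, 16, kv_heads * dim_head)  # key half of to_keys_values output

q = rearrange(q, 'b n (g h d) -> b g h n d', g = groups, h = kv_heads)
k = rearrange(k, 'b n (h d) -> b h n d', h = kv_heads)

# each kv head attends for all g query heads in its group
sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
assert sim.shape == (1, groups, kv_heads, 16, 16)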
@@ -185,6 +197,8 @@
 
         out = einsum(attn, values, 'b g h i j, b h j d -> b g h i d')
 
+        out = out * self.attn_gate_value(tokens).sigmoid()
+
         out = self.merge_heads(out)
 
         out = self.to_out(out)
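The added line multiplies the attention output by a sigmoid gate computed from the input tokens, one scalar gate per head and position, broadcast across the head feature dimension. A minimal standalone sketch of that gating, reusing the `'b n (g h) -> b g h n 1'` layout from the diff (sizes are illustrative):

import torch
from torch import nn
from einops.layers.torch import Rearrange

heads, kv_heads, dim, dim_head = 8, 2, 512, 64
groups = heads // kv_heads

attn_gate_value = nn.Sequential(
    nn.Linear(dim, heads, bias = False),
    Rearrange('b n (g h) -> b g h n 1', h = kv_heads)  # one gate per head, broadcast over dim_head
)

tokens = torch.randn(1, 16, dim)
out = torch.randn(1, groups, kv_heads, 16, dim_head)  # attention output, before merge_heads

out = out * attn_gate_value(tokens).sigmoid()  # gates in (0, 1) scale each head's values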
@@ -238,7 +252,11 @@ class MimicVideo(Module):
         expansion_factor = 4.,
         ada_ln_zero_bias = -5.,
         dim_time_cond = None,
-        sample_time_fn = None
+        sample_time_fn = None,
+        train_time_rtc = False,
+        train_time_rtc_max_delay = None,
+        num_residual_streams = 1,
+        mhc_kwargs: dict = dict()
     ):
         super().__init__()
 
@@ -288,6 +306,10 @@ class MimicVideo(Module):
 
         self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)
 
+        # manifold constrained hyper connections (mHC) from bytedance + deepseek
+
+        init_hyper_conn, self.expand_stream, self.reduce_stream = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, **mhc_kwargs)
+
         # rnn
 
         self.rnn = GRU(dim, dim)
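For context, here is how the three stream functions returned above get used, following only the calls visible in this diff (the library internals are not shown here; with `num_residual_streams = 1` this should reduce to a plain residual stream):

from hyper_connections.mHCv2 import get_init_and_expand_reduce_stream_functions

# same call as in the diff, with an illustrative dim
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(
    4,                 # num_residual_streams
    dim = 512,
    add_stream_embed = True
)

residual = init_hyper_conn()  # instantiated once per wrapped block, as in the diff

# usage pattern from the forward pass later in this diff:
#   tokens = expand_stream(tokens)            # widen tokens into multiple residual streams
#   tokens, add_residual = residual(tokens)   # read the branch input off the streams
#   tokens = add_residual(block_out)          # write the branch output back into the streams
#   tokens = reduce_stream(tokens)            # collapse the streams back to one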
@@ -309,11 +331,20 @@ class MimicVideo(Module):
 
             ff = SwiGLUFeedForward(dim = dim, expansion_factor = expansion_factor)
 
+            # maybe hyper connect
+
+            attn_residual = init_hyper_conn()
+            cross_attn_residual = init_hyper_conn()
+            ff_residual = init_hyper_conn()
+
             layers.append(ModuleList([
-
-                attn,
+                cross_attn_residual,
                 cross_attn_adanorm,
                 cross_attn,
+                attn_residual,
+                attn_adanorm,
+                attn,
+                ff_residual,
                 ff_adanorm,
                 ff
             ]))
@@ -327,6 +358,17 @@ class MimicVideo(Module):
             Linear(dim, dim_action, bias = False)
         )
 
+        # inference related
+
+        # train time RTC related - https://arxiv.org/abs/2512.05964
+
+        self.train_time_rtc = train_time_rtc
+
+        assert not train_time_rtc or exists(train_time_rtc_max_delay)
+        self.train_time_rtc_max_delay = train_time_rtc_max_delay
+
+        # aux loss and device
+
         self.register_buffer('zero', tensor(0.), persistent = False)
 
     @property
@@ -380,6 +422,7 @@ class MimicVideo(Module):
         assert actions.shape[-2:] == self.action_shape
 
         batch, device = actions.shape[0], actions.device
+        orig_actions = actions
 
         is_training = not exists(time) and not return_flow
 
@@ -393,7 +436,8 @@ class MimicVideo(Module):
             assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
 
             video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
-
+
+        video_hiddens = video_hiddens.to(self.device).float() # maybe bfloat to float32
 
         video_hiddens, _ = pack_with_inverse(video_hiddens, 'b * d')
 
@@ -420,9 +464,24 @@ class MimicVideo(Module):
             actions, left_aligned_time = align_dims_left((actions, time))
 
             noised = noise.lerp(actions, left_aligned_time)
+
         else:
             noised = actions
 
+        # maybe train time rtc
+
+        action_loss_mask = None
+
+        if is_training and self.train_time_rtc:
+
+            rand_prefix_len = torch.randint(0, self.train_time_rtc_max_delay, (batch,), device = device)
+            action_prefix_mask = lens_to_mask(rand_prefix_len, self.action_chunk_len)
+
+            actions = einx.where('b na, b na d, b na d', action_prefix_mask, orig_actions, actions)
+            time = einx.where('b na, , b', action_prefix_mask, 1., time)
+
+            action_loss_mask = ~action_prefix_mask
+
         if time.ndim == 0:
             time = repeat(time, '-> b', b = batch)
 
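The new block implements the train-time real-time-chunking trick from the paper cited in the comment: a random prefix of each action chunk is conditioned on the clean, already-committed actions at time 1 and excluded from the loss. A plain-torch re-derivation of the same masks (`lens_to_mask` and `einx.where` are what the package actually uses; shapes here are illustrative):

import torch

batch, chunk_len, dim, max_delay = 2, 8, 4, 3

rand_prefix_len = torch.randint(0, max_delay, (batch,))
positions = torch.arange(chunk_len)

# lens_to_mask equivalent: True for the first rand_prefix_len positions per sample
action_prefix_mask = positions < rand_prefix_len[:, None]   # (b, na)

clean_actions = torch.randn(batch, chunk_len, dim)
noised_actions = torch.randn(batch, chunk_len, dim)

# prefix positions see the clean actions, at time = 1.
actions = torch.where(action_prefix_mask[..., None], clean_actions, noised_actions)
time = torch.where(action_prefix_mask, torch.ones(()), torch.rand(batch)[:, None])

action_loss_mask = ~action_prefix_mask  # flow loss only on the still-denoising suffix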
@@ -436,8 +495,14 @@ class MimicVideo(Module):
         if time_video_denoise.shape[0] != batch:
             time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)
 
+        if time.ndim == 2:
+            time_video_denoise = repeat(time_video_denoise, 'b -> b n', n = time.shape[-1])
+
         times = stack((time, time_video_denoise), dim = -1)
 
+        if times.ndim == 3:
+            times = pad_at_dim(times, (1, 0), dim = 1, value = 1.) # handle joint state token on the action
+
         # fourier embed and mlp to time condition
 
         fourier_embed = self.to_fourier_embed(times)
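The `pad_at_dim(times, (1, 0), dim = 1, value = 1.)` call prepends one position along the sequence dimension so the joint state token carries a time of 1. A plain-torch equivalent, assuming `pad_at_dim` has the usual left/right-pad semantics:

import torch
import torch.nn.functional as F

times = torch.rand(2, 8, 2)  # (batch, action positions, [action time, video time])

# F.pad pads from the last dim backwards: (0, 0) leaves the time pair alone,
# (1, 0) adds one slot at the front of dim 1 for the joint state token
times = F.pad(times, (0, 0, 1, 0), value = 1.)
assert times.shape == (2, 9, 2)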
@@ -468,41 +533,48 @@ class MimicVideo(Module):
 
         tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')
 
+        # maybe expand streams
+
+        tokens = self.expand_stream(tokens)
+
         # transformer layers
 
         for ((
-
-            attn,
+            maybe_cross_attn_mhc,
             cross_attn_norm,
             cross_attn,
+            maybe_attn_mhc,
+            attn_norm,
+            attn,
+            maybe_ff_mhc,
             ff_norm,
             ff
         ), cached_video_kv) in zip(self.layers, prev_cached_video_hiddens_kv):
 
             # cross attention
 
-
+            tokens, add_residual = maybe_cross_attn_mhc(tokens)
 
             tokens, gate = cross_attn_norm(tokens, time_cond)
 
             cross_attn_out, video_kv = cross_attn(tokens, context = video_hiddens, context_mask = context_mask, kv = cached_video_kv, return_kv = True)
 
-            tokens =
+            tokens = add_residual(cross_attn_out * gate)
 
             if return_cache:
                 next_cached_video_hiddens_kv.append(video_kv)
 
             # self attention
 
-
+            tokens, add_residual = maybe_attn_mhc(tokens)
 
             tokens, gate = attn_norm(tokens, time_cond)
 
-            tokens =
+            tokens = add_residual(attn(tokens) * gate)
 
             # prepare feedforward
 
-
+            tokens, add_residual = maybe_ff_mhc(tokens)
 
             tokens, gate = ff_norm(tokens, time_cond)
 
@@ -516,7 +588,11 @@ class MimicVideo(Module):
 
             # feedforward
 
-            tokens =
+            tokens = add_residual(ff(tokens) * gate)
+
+        # maybe reduce streams
+
+        tokens = self.reduce_stream(tokens)
 
         # remove joint token
 
@@ -533,9 +609,9 @@ class MimicVideo(Module):
         else:
             # mse flow loss
 
-            flow_loss = F.mse_loss(pred_flow, flow)
+            flow_loss = F.mse_loss(pred_flow, flow, reduction = 'none')
 
-            out = flow_loss
+            out = masked_mean(flow_loss, action_loss_mask)
 
         if not return_cache:
             return out
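Switching the mse to `reduction = 'none'` and reducing with `masked_mean` makes the flow loss respect `action_loss_mask` from the train-time RTC block. A sketch of the assumed `masked_mean` semantics in plain torch (the real helper lives in torch-einops-utils and may differ in detail):

import torch

def masked_mean_sketch(loss, mask = None):
    # loss: (b, na, d) unreduced flow loss, mask: (b, na) True where loss counts
    if mask is None:
        return loss.mean()

    loss = loss.mean(dim = -1)                     # reduce the feature dim first
    loss = loss * mask                             # zero out masked-off prefix positions
    return loss.sum() / mask.sum().clamp(min = 1)  # mean over the positions kept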
{mimic_video-0.0.19.dist-info → mimic_video-0.0.24.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mimic-video
-Version: 0.0.19
+Version: 0.0.24
 Summary: Mimic Video
 Project-URL: Homepage, https://pypi.org/project/mimic-video/
 Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.10
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
-Requires-Dist:
+Requires-Dist: hyper-connections>=0.4.3
+Requires-Dist: torch-einops-utils>=0.0.12
 Requires-Dist: torch>=2.5
 Requires-Dist: tqdm
 Requires-Dist: x-mlps-pytorch
@@ -147,3 +148,16 @@ That's it
 url = {https://api.semanticscholar.org/CorpusID:283920528}
 }
 ```
+
+```bibtex
+@misc{black2025trainingtimeactionconditioningefficient,
+    title = {Training-Time Action Conditioning for Efficient Real-Time Chunking},
+    author = {Kevin Black and Allen Z. Ren and Michael Equi and Sergey Levine},
+    year = {2025},
+    eprint = {2512.05964},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.RO},
+    url = {https://arxiv.org/abs/2512.05964},
+}
+```
+
mimic_video-0.0.24.dist-info/RECORD
ADDED
@@ -0,0 +1,7 @@
+mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
+mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
+mimic_video/mimic_video.py,sha256=Qr0Dc4z-LTRlTt0qXlgcJtdSP1pBsarXeOnJSUxj_yY,17388
+mimic_video-0.0.24.dist-info/METADATA,sha256=4kXYmqL3XtJbZ35iX42Z85RFV_ZGMM_phKGUZWnfcaw,4581
+mimic_video-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+mimic_video-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+mimic_video-0.0.24.dist-info/RECORD,,
mimic_video-0.0.19.dist-info/RECORD
REMOVED
@@ -1,7 +0,0 @@
-mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
-mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
-mimic_video/mimic_video.py,sha256=wQNfdV2PGlR_-S-4Mm3cyswtRQ3nBQGJiHptya3ckKU,14761
-mimic_video-0.0.19.dist-info/METADATA,sha256=67a7iVIkf557qMos_yvpcqL8dK9yL0Jf1DPQuhb2bwo,4142
-mimic_video-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-mimic_video-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-mimic_video-0.0.19.dist-info/RECORD,,
{mimic_video-0.0.19.dist-info → mimic_video-0.0.24.dist-info}/WHEEL
File without changes

{mimic_video-0.0.19.dist-info → mimic_video-0.0.24.dist-info}/licenses/LICENSE
File without changes