mimic-video 0.0.24__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mimic-video might be problematic. Click here for more details.
- mimic_video/mimic_video.py +100 -3
- {mimic_video-0.0.24.dist-info → mimic_video-0.0.27.dist-info}/METADATA +1 -1
- mimic_video-0.0.27.dist-info/RECORD +7 -0
- mimic_video-0.0.24.dist-info/RECORD +0 -7
- {mimic_video-0.0.24.dist-info → mimic_video-0.0.27.dist-info}/WHEEL +0 -0
- {mimic_video-0.0.24.dist-info → mimic_video-0.0.27.dist-info}/licenses/LICENSE +0 -0
mimic_video/mimic_video.py
CHANGED
|
@@ -47,9 +47,22 @@ def exists(v):
|
|
|
47
47
|
def default(v, d):
    """Return `v` when `exists(v)` holds, otherwise fall back to `d`."""
    if not exists(v):
        return d
    return v
|
|
49
49
|
|
|
50
|
+
def identity(t):
    """Return the argument unchanged; used as the no-op alternative to a call wrapper."""
    passthrough = t
    return passthrough
|
|
52
|
+
|
|
50
53
|
def divisible_by(num, den):
    """Return whether `num % den` leaves a remainder of zero."""
    remainder = num % den
    return remainder == 0
|
|
52
55
|
|
|
56
|
+
# wrappers
|
|
57
|
+
|
|
58
|
+
def eval_no_grad(fn):
    """Wrap `fn` so each call runs under `torch.no_grad()` with `fn` in eval mode.

    When `fn` is a `torch.nn.Module`, the train/eval flag of every submodule is
    recorded before the call and restored afterwards (even if the call raises).
    Wrapping therefore has no lasting side effect on the caller's training state.
    Other objects exposing `.eval()` still have it called, as before. Plain
    callables without `.eval` are simply run under `no_grad`.

    Args:
        fn: module or callable to wrap.

    Returns:
        A function taking the same arguments as `fn` and returning its result.
    """
    def inner(*args, **kwargs):
        is_module = isinstance(fn, torch.nn.Module)

        # snapshot per-submodule modes so mixed train / eval states are restored exactly
        prev_modes = [(module, module.training) for module in fn.modules()] if is_module else []

        with torch.no_grad():
            if is_module or hasattr(fn, 'eval'):
                fn.eval()

            try:
                return fn(*args, **kwargs)
            finally:
                for module, was_training in prev_modes:
                    module.training = was_training

    return inner
|
|
65
|
+
|
|
53
66
|
# tensor function
|
|
54
67
|
|
|
55
68
|
def cast_tensor(val, device = None):
|
|
@@ -69,6 +82,30 @@ def shift_feature_dim(t):
|
|
|
69
82
|
x_shift = pad_at_dim(x_shift, (1, -1), dim = 1)
|
|
70
83
|
return cat((x, x_shift), dim = -1)
|
|
71
84
|
|
|
85
|
+
# action normalization
|
|
86
|
+
|
|
87
|
+
class Normalizer(Module):
    """Elementwise affine normalizer: `(t - mean) / std` and its exact inverse.

    `mean` and `std` are registered as (persistent) buffers, so they move with
    the module on `.to(...)` and are stored in its `state_dict`. They broadcast
    against the input tensor.

    Args:
        mean: tensor of per-feature means.
        std: tensor of per-feature standard deviations; every entry must be > 0.
        eps: floor applied to `std` to avoid dividing by tiny values.
    """

    def __init__(
        self,
        mean,
        std,
        eps = 1e-6
    ):
        super().__init__()
        assert (std > 0.).all(), 'std must be positive'
        self.eps = eps

        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def _clamped_std(self):
        # single shared floor so normalize and inverse_normalize are exact inverses,
        # even when some std entry is positive but below eps
        return self.std.clamp_min(self.eps)

    def normalize(self, t):
        """Map `t` into normalized space: `(t - mean) / max(std, eps)`."""
        return (t - self.mean) / self._clamped_std()

    def inverse_normalize(self, t):
        """Map normalized `t` back to the original space: `t * max(std, eps) + mean`."""
        return t * self._clamped_std() + self.mean
|
|
108
|
+
|
|
72
109
|
# time
|
|
73
110
|
|
|
74
111
|
# they follow p0's research finding with the beta distribution
|
|
@@ -256,7 +293,8 @@ class MimicVideo(Module):
|
|
|
256
293
|
train_time_rtc = False,
|
|
257
294
|
train_time_rtc_max_delay = None,
|
|
258
295
|
num_residual_streams = 1,
|
|
259
|
-
mhc_kwargs: dict = dict()
|
|
296
|
+
mhc_kwargs: dict = dict(),
|
|
297
|
+
action_mean_std: Tensor | None = None
|
|
260
298
|
):
|
|
261
299
|
super().__init__()
|
|
262
300
|
|
|
@@ -266,12 +304,21 @@ class MimicVideo(Module):
|
|
|
266
304
|
|
|
267
305
|
self.video_predict_wrapper = video_predict_wrapper
|
|
268
306
|
|
|
269
|
-
#
|
|
307
|
+
# action related
|
|
270
308
|
|
|
271
309
|
self.action_chunk_len = action_chunk_len
|
|
272
310
|
self.dim_action = dim_action
|
|
273
311
|
|
|
274
312
|
self.action_shape = (action_chunk_len, dim_action)
|
|
313
|
+
|
|
314
|
+
self.action_normalizer = None
|
|
315
|
+
|
|
316
|
+
if exists(action_mean_std):
|
|
317
|
+
assert action_mean_std.shape == (2, dim_action), f'must be in shape of (2 action_dim)'
|
|
318
|
+
self.action_normalizer = Normalizer(*action_mean_std)
|
|
319
|
+
|
|
320
|
+
# joint dim
|
|
321
|
+
|
|
275
322
|
self.dim_joint_state = dim_joint_state
|
|
276
323
|
|
|
277
324
|
dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
|
|
@@ -371,6 +418,12 @@ class MimicVideo(Module):
|
|
|
371
418
|
|
|
372
419
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
373
420
|
|
|
421
|
+
# only action parameters
|
|
422
|
+
|
|
423
|
+
def action_parameters(self):
    """Return the set of this model's parameters excluding the video model's.

    When no `video_predict_wrapper` is attached, every parameter of the model is
    returned.

    NOTE(review): the result is a `set`, so its iteration order is not stable
    across runs. When handing it to an optimizer whose state is checkpointed,
    consider building an ordered list instead.
    """
    # must be an empty *set*: `set - {}` (a dict literal) raises TypeError
    video_model_params = set()

    if exists(self.video_predict_wrapper):
        video_model_params = set(self.video_predict_wrapper.parameters())

    return set(self.parameters()) - video_model_params
|
|
426
|
+
|
|
374
427
|
@property
def device(self):
    """Device the model lives on, read off its registered `zero` buffer."""
    zero_buffer = self.zero
    return zero_buffer.device
|
|
@@ -380,26 +433,60 @@ class MimicVideo(Module):
|
|
|
380
433
|
self,
|
|
381
434
|
steps = 16,
|
|
382
435
|
batch_size = 1,
|
|
436
|
+
prefix_action_chunk = None,
|
|
383
437
|
disable_progress_bar = False,
|
|
384
438
|
**kwargs
|
|
385
439
|
):
|
|
386
440
|
|
|
387
441
|
self.eval()
|
|
388
442
|
|
|
443
|
+
inpainting = exists(prefix_action_chunk)
|
|
444
|
+
|
|
445
|
+
if inpainting:
|
|
446
|
+
prefix_len = prefix_action_chunk.shape[1]
|
|
447
|
+
assert prefix_len < self.action_chunk_len
|
|
448
|
+
|
|
449
|
+
maybe_normed_prefix = prefix_action_chunk
|
|
450
|
+
|
|
451
|
+
if exists(self.action_normalizer):
|
|
452
|
+
maybe_normed_prefix = self.action_normalizer.normalize(prefix_action_chunk)
|
|
453
|
+
|
|
454
|
+
# noise
|
|
455
|
+
|
|
389
456
|
noise = torch.randn((batch_size, *self.action_shape), device = self.device)
|
|
390
457
|
|
|
458
|
+
# times
|
|
459
|
+
|
|
391
460
|
times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
|
|
392
461
|
delta = 1. / steps
|
|
393
462
|
|
|
463
|
+
# denoised action starts as noise
|
|
464
|
+
|
|
394
465
|
denoised = noise
|
|
395
466
|
|
|
396
467
|
cache = None
|
|
397
468
|
|
|
469
|
+
# denoise
|
|
470
|
+
|
|
398
471
|
for time in tqdm(times, disable = disable_progress_bar):
|
|
472
|
+
|
|
473
|
+
if inpainting:
|
|
474
|
+
denoised[:, :prefix_len] = maybe_normed_prefix
|
|
475
|
+
|
|
399
476
|
pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)
|
|
400
477
|
|
|
401
478
|
denoised = denoised + delta * pred_flow
|
|
402
479
|
|
|
480
|
+
# handle action inverse norm
|
|
481
|
+
|
|
482
|
+
if exists(self.action_normalizer):
|
|
483
|
+
denoised = self.action_normalizer.inverse_normalize(denoised)
|
|
484
|
+
|
|
485
|
+
# final set, with unnormalized prefix, if inpainting
|
|
486
|
+
|
|
487
|
+
if inpainting:
|
|
488
|
+
denoised[:, :prefix_len] = prefix_action_chunk
|
|
489
|
+
|
|
403
490
|
return denoised
|
|
404
491
|
|
|
405
492
|
def forward(
|
|
@@ -414,6 +501,8 @@ class MimicVideo(Module):
|
|
|
414
501
|
time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
|
|
415
502
|
prompts = None,
|
|
416
503
|
prompt_token_ids = None,
|
|
504
|
+
detach_video_hiddens = False,
|
|
505
|
+
no_grad_video_model_forward = False,
|
|
417
506
|
cache = None,
|
|
418
507
|
return_cache = False,
|
|
419
508
|
return_flow = False
|
|
@@ -421,6 +510,9 @@ class MimicVideo(Module):
|
|
|
421
510
|
assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
|
|
422
511
|
assert actions.shape[-2:] == self.action_shape
|
|
423
512
|
|
|
513
|
+
if exists(self.action_normalizer):
|
|
514
|
+
actions = self.action_normalizer.normalize(actions)
|
|
515
|
+
|
|
424
516
|
batch, device = actions.shape[0], actions.device
|
|
425
517
|
orig_actions = actions
|
|
426
518
|
|
|
@@ -435,7 +527,9 @@ class MimicVideo(Module):
|
|
|
435
527
|
if not exists(video_hiddens):
|
|
436
528
|
assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
|
|
437
529
|
|
|
438
|
-
|
|
530
|
+
video_forward_wrap = eval_no_grad if no_grad_video_model_forward else identity
|
|
531
|
+
|
|
532
|
+
video_hiddens = video_forward_wrap(self.video_predict_wrapper)(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
|
|
439
533
|
|
|
440
534
|
video_hiddens = video_hiddens.to(self.device).float() # maybe bfloat to float32
|
|
441
535
|
|
|
@@ -445,6 +539,9 @@ class MimicVideo(Module):
|
|
|
445
539
|
|
|
446
540
|
# handle video hiddens
|
|
447
541
|
|
|
542
|
+
if detach_video_hiddens:
|
|
543
|
+
video_hiddens = video_hiddens.detach()
|
|
544
|
+
|
|
448
545
|
video_hiddens = self.video_hidden_norm(video_hiddens)
|
|
449
546
|
|
|
450
547
|
# handle caching
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
|
|
2
|
+
mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
|
|
3
|
+
mimic_video/mimic_video.py,sha256=WlwFfFvOW5k6X-BxRvF0zjwpKEET9C_FIyewD6_GmcE,20017
|
|
4
|
+
mimic_video-0.0.27.dist-info/METADATA,sha256=al9--DJ_U_jwWilronv3IADdbCIuQfEQRMCJ3vEtE80,4581
|
|
5
|
+
mimic_video-0.0.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
6
|
+
mimic_video-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
+
mimic_video-0.0.27.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
|
|
2
|
-
mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
|
|
3
|
-
mimic_video/mimic_video.py,sha256=Qr0Dc4z-LTRlTt0qXlgcJtdSP1pBsarXeOnJSUxj_yY,17388
|
|
4
|
-
mimic_video-0.0.24.dist-info/METADATA,sha256=4kXYmqL3XtJbZ35iX42Z85RFV_ZGMM_phKGUZWnfcaw,4581
|
|
5
|
-
mimic_video-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
6
|
-
mimic_video-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
-
mimic_video-0.0.24.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|