mimic-video 0.0.24__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mimic-video might be problematic. Click here for more details.

@@ -47,9 +47,22 @@ def exists(v):
47
47
  def default(v, d):
48
48
  return v if exists(v) else d
49
49
 
50
def identity(t):
    """Returns the input untouched.

    Serves as the no-op alternative wherever a wrapper or transform is optional.
    """
    return t
52
+
50
53
  def divisible_by(num, den):
51
54
  return (num % den) == 0
52
55
 
56
+ # wrappers
57
+
58
def eval_no_grad(fn):
    """Wraps a module so that each call runs in eval mode with gradients disabled.

    Args:
        fn: a module-like callable exposing `eval()`, `train(mode)` and `training`.

    Returns:
        A function forwarding all positional and keyword arguments to `fn`.
        The train / eval mode `fn` had before the call is restored afterwards,
        so wrapping a single forward pass does not silently leave the module
        in eval mode for the rest of training.
    """
    def inner(*args, **kwargs):
        was_training = fn.training

        fn.eval()

        try:
            with torch.no_grad():
                return fn(*args, **kwargs)
        finally:
            # restore even if the forward raised, so the caller's mode is never clobbered
            fn.train(was_training)

    return inner
65
+
53
66
  # tensor function
54
67
 
55
68
  def cast_tensor(val, device = None):
@@ -69,6 +82,30 @@ def shift_feature_dim(t):
69
82
  x_shift = pad_at_dim(x_shift, (1, -1), dim = 1)
70
83
  return cat((x, x_shift), dim = -1)
71
84
 
85
+ # action normalization
86
+
87
class Normalizer(Module):
    """Standardizes tensors with fixed statistics and maps them back.

    `mean` and `std` are registered as buffers under those names.
    """

    def __init__(
        self,
        mean,
        std,
        eps = 1e-6
    ):
        """
        Args:
            mean: tensor subtracted from the input in `normalize`.
            std: tensor the centered input is divided by - every entry must be positive.
            eps: lower bound applied to `std` before it is used as a scale.
        """
        super().__init__()
        assert (std > 0.).all(), 'std must be positive'
        self.eps = eps

        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    @property
    def _clamped_std(self):
        # single source for the scale, so `normalize` and `inverse_normalize` stay exact inverses
        return self.std.clamp_min(self.eps)

    def normalize(self, t):
        """Returns `(t - mean) / max(std, eps)`."""
        return (t - self.mean) / self._clamped_std

    def inverse_normalize(self, t):
        """Undoes `normalize`, returning `t * max(std, eps) + mean`."""
        # must use the same clamped std as `normalize`, otherwise entries with std < eps do not round trip
        return (t * self._clamped_std) + self.mean
108
+
72
109
  # time
73
110
 
74
111
  # they follow p0's research finding with the beta distribution
@@ -256,7 +293,8 @@ class MimicVideo(Module):
256
293
  train_time_rtc = False,
257
294
  train_time_rtc_max_delay = None,
258
295
  num_residual_streams = 1,
259
- mhc_kwargs: dict = dict()
296
+ mhc_kwargs: dict = dict(),
297
+ action_mean_std: Tensor | None = None
260
298
  ):
261
299
  super().__init__()
262
300
 
@@ -266,12 +304,21 @@ class MimicVideo(Module):
266
304
 
267
305
  self.video_predict_wrapper = video_predict_wrapper
268
306
 
269
- # dims
307
+ # action related
270
308
 
271
309
  self.action_chunk_len = action_chunk_len
272
310
  self.dim_action = dim_action
273
311
 
274
312
  self.action_shape = (action_chunk_len, dim_action)
313
+
314
+ self.action_normalizer = None
315
+
316
+ if exists(action_mean_std):
317
+ assert action_mean_std.shape == (2, dim_action), f'must be in shape of (2 action_dim)'
318
+ self.action_normalizer = Normalizer(*action_mean_std)
319
+
320
+ # joint dim
321
+
275
322
  self.dim_joint_state = dim_joint_state
276
323
 
277
324
  dim_video_hidden = default(dim_video_hidden, video_predict_wrapper.dim_latent if exists(video_predict_wrapper) else None)
@@ -371,6 +418,12 @@ class MimicVideo(Module):
371
418
 
372
419
  self.register_buffer('zero', tensor(0.), persistent = False)
373
420
 
421
+ # only action parameters
422
+
423
def action_parameters(self):
    """Returns the set of this module's parameters that do not belong to the video prediction model.

    When no `video_predict_wrapper` is present, every parameter of the module is returned.
    The result is a `set`, so its iteration order is not defined.
    """
    video_model = self.video_predict_wrapper

    # both branches must yield a set - `{}` is an empty dict, and `set - dict` raises a TypeError
    video_model_params = set(video_model.parameters()) if exists(video_model) else set()

    return set(self.parameters()) - video_model_params
426
+
374
427
  @property
375
428
  def device(self):
376
429
  return self.zero.device
@@ -380,26 +433,60 @@ class MimicVideo(Module):
380
433
  self,
381
434
  steps = 16,
382
435
  batch_size = 1,
436
+ prefix_action_chunk = None,
383
437
  disable_progress_bar = False,
384
438
  **kwargs
385
439
  ):
386
440
 
387
441
  self.eval()
388
442
 
443
+ inpainting = exists(prefix_action_chunk)
444
+
445
+ if inpainting:
446
+ prefix_len = prefix_action_chunk.shape[1]
447
+ assert prefix_len < self.action_chunk_len
448
+
449
+ maybe_normed_prefix = prefix_action_chunk
450
+
451
+ if exists(self.action_normalizer):
452
+ maybe_normed_prefix = self.action_normalizer.normalize(prefix_action_chunk)
453
+
454
+ # noise
455
+
389
456
  noise = torch.randn((batch_size, *self.action_shape), device = self.device)
390
457
 
458
+ # times
459
+
391
460
  times = torch.linspace(0., 1., steps + 1, device = self.device)[:-1]
392
461
  delta = 1. / steps
393
462
 
463
+ # denoised action starts as noise
464
+
394
465
  denoised = noise
395
466
 
396
467
  cache = None
397
468
 
469
+ # denoise
470
+
398
471
  for time in tqdm(times, disable = disable_progress_bar):
472
+
473
+ if inpainting:
474
+ denoised[:, :prefix_len] = maybe_normed_prefix
475
+
399
476
  pred_flow, cache = self.forward(actions = denoised, time = time, cache = cache, return_cache = True, **kwargs)
400
477
 
401
478
  denoised = denoised + delta * pred_flow
402
479
 
480
+ # handle action inverse norm
481
+
482
+ if exists(self.action_normalizer):
483
+ denoised = self.action_normalizer.inverse_normalize(denoised)
484
+
485
+ # final set, with unnormalized prefix, if inpainting
486
+
487
+ if inpainting:
488
+ denoised[:, :prefix_len] = prefix_action_chunk
489
+
403
490
  return denoised
404
491
 
405
492
  def forward(
@@ -414,6 +501,8 @@ class MimicVideo(Module):
414
501
  time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
415
502
  prompts = None,
416
503
  prompt_token_ids = None,
504
+ detach_video_hiddens = False,
505
+ no_grad_video_model_forward = False,
417
506
  cache = None,
418
507
  return_cache = False,
419
508
  return_flow = False
@@ -421,6 +510,9 @@ class MimicVideo(Module):
421
510
  assert not exists(self.video_predict_wrapper) or (exists(prompts) ^ exists(prompt_token_ids))
422
511
  assert actions.shape[-2:] == self.action_shape
423
512
 
513
+ if exists(self.action_normalizer):
514
+ actions = self.action_normalizer.normalize(actions)
515
+
424
516
  batch, device = actions.shape[0], actions.device
425
517
  orig_actions = actions
426
518
 
@@ -435,7 +527,9 @@ class MimicVideo(Module):
435
527
  if not exists(video_hiddens):
436
528
  assert exists(self.video_predict_wrapper), f'`video_predict_wrapper` must be passed in if raw video is passed into MimicVideo'
437
529
 
438
- video_hiddens = self.video_predict_wrapper(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
530
+ video_forward_wrap = eval_no_grad if no_grad_video_model_forward else identity
531
+
532
+ video_hiddens = video_forward_wrap(self.video_predict_wrapper)(video, prompts = prompts, prompt_token_ids = prompt_token_ids)
439
533
 
440
534
  video_hiddens = video_hiddens.to(self.device).float() # maybe bfloat to float32
441
535
 
@@ -445,6 +539,9 @@ class MimicVideo(Module):
445
539
 
446
540
  # handle video hiddens
447
541
 
542
+ if detach_video_hiddens:
543
+ video_hiddens = video_hiddens.detach()
544
+
448
545
  video_hiddens = self.video_hidden_norm(video_hiddens)
449
546
 
450
547
  # handle caching
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mimic-video
3
- Version: 0.0.24
3
+ Version: 0.0.27
4
4
  Summary: Mimic Video
5
5
  Project-URL: Homepage, https://pypi.org/project/mimic-video/
6
6
  Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -0,0 +1,7 @@
1
+ mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
2
+ mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
3
+ mimic_video/mimic_video.py,sha256=WlwFfFvOW5k6X-BxRvF0zjwpKEET9C_FIyewD6_GmcE,20017
4
+ mimic_video-0.0.27.dist-info/METADATA,sha256=al9--DJ_U_jwWilronv3IADdbCIuQfEQRMCJ3vEtE80,4581
5
+ mimic_video-0.0.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
6
+ mimic_video-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ mimic_video-0.0.27.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- mimic_video/__init__.py,sha256=Rs3QeBBGBKKi1U1ykcyeBrCL2XCbfNvppeeD1Fb1pdY,47
2
- mimic_video/cosmos_predict.py,sha256=2XR9cqcUC4gKpjEDBy-GtLtMkLXvs8yKe7w8g6EeS6s,8471
3
- mimic_video/mimic_video.py,sha256=Qr0Dc4z-LTRlTt0qXlgcJtdSP1pBsarXeOnJSUxj_yY,17388
4
- mimic_video-0.0.24.dist-info/METADATA,sha256=4kXYmqL3XtJbZ35iX42Z85RFV_ZGMM_phKGUZWnfcaw,4581
5
- mimic_video-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
6
- mimic_video-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- mimic_video-0.0.24.dist-info/RECORD,,