mimic-video 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mimic-video might be problematic; see the package's registry page for more details.

@@ -1,18 +1,20 @@
1
1
  import torch
2
- from torch import nn
2
+ from torch import nn, cat, stack, is_tensor, tensor
3
3
  from torch.nn import Module, ModuleList, Linear
4
4
 
5
5
  import torch.nn.functional as F
6
6
 
7
7
  import einx
8
- from einops import einsum, rearrange
8
+ from einops import einsum, rearrange, repeat
9
9
  from einops.layers.torch import Rearrange
10
10
 
11
11
  from x_mlps_pytorch import create_mlp
12
12
 
13
13
  from torch_einops_utils import (
14
14
  pad_left_ndim,
15
- align_dims_left
15
+ align_dims_left,
16
+ pad_at_dim,
17
+ pack_with_inverse,
16
18
  )
17
19
 
18
20
  # ein notation
@@ -37,12 +39,23 @@ def divisible_by(num, den):
37
39
 
38
40
  # tensor function
39
41
 
def cast_tensor(val, device = None):
    """Return `val` as a torch tensor, optionally placed on `device`.

    Args:
        val: a python scalar / nested sequence, or an existing tensor.
        device: optional target device. When None, non-tensor inputs are
            created on torch's default device and tensor inputs are returned
            untouched.

    Returns:
        A tensor. An existing tensor is returned as-is when no device is
        requested (or when it already lives on that device, since `.to`
        is then a no-op that returns the same object).
    """
    if not is_tensor(val):
        return tensor(val, device = device)

    # honour `device` for tensor inputs too - previously it was silently ignored,
    # which could leave the result on a different device than the caller asked for
    if device is not None:
        return val.to(device)

    return val
44
+
def max_neg_value(t):
    """Give the most negative finite number representable in `t`'s dtype.

    Only meaningful for floating-point tensors, since it relies on `torch.finfo`.
    """
    dtype_info = torch.finfo(t.dtype)
    largest = dtype_info.max
    return -largest
42
47
 
def l2norm(t, eps = 1e-10):
    """Rescale every vector along the last axis of `t` to unit Euclidean length.

    `eps` is the lower bound on the divisor, so an all-zero vector maps to
    zeros rather than NaNs.
    """
    return F.normalize(t, p = 2, dim = -1, eps = eps)
45
50
 
# token shift from Peng et al. of RWKV
# cheap way to generate relative positions

def shift_feature_dim(t, dim = 1):
    """Shift the second half of the feature channels one step along `dim`.

    The first half of the last axis is left in place. The second half is
    delayed by one position along `dim`: it is zero-filled at the front and
    its last position is dropped. The output shape equals the input shape.

    Args:
        t: input tensor; features are assumed to live on the last axis.
        dim: axis to shift along (default 1, i.e. the sequence axis of a
            (batch, seq, features) tensor). Negative values are accepted.

    Returns:
        Tensor of the same shape as `t`.
    """
    x, x_shift = t.chunk(2, dim = -1)

    # F.pad consumes (left, right) pairs starting from the LAST axis,
    # so emit a zero pair for every axis to the right of `dim`
    dim = dim % t.ndim
    dims_from_right = t.ndim - dim - 1
    padding = (0, 0) * dims_from_right + (1, -1)

    # +1 zero step at the front, -1 (crop) at the back == shift forward by one
    x_shift = F.pad(x_shift, padding)

    return cat((x, x_shift), dim = -1)
58
+
46
59
  # time
47
60
 
48
61
  # they follow p0's research finding with the beta distribution
@@ -196,6 +209,7 @@ class MimicVideo(Module):
196
209
  *,
197
210
  dim_video_hidden,
198
211
  dim_action = 20,
212
+ dim_joint_state = 32,
199
213
  depth = 8,
200
214
  dim_head = 64,
201
215
  heads = 8,
@@ -215,10 +229,10 @@ class MimicVideo(Module):
215
229
 
216
230
  dim_time_cond = default(dim_time_cond, dim * 2)
217
231
 
218
- self.to_time_cond = nn.Sequential(
219
- RandomFourierEmbed(dim),
220
- create_mlp(dim_in = dim, dim = dim_time_cond, depth = 2, activation = nn.SiLU())
221
- )
232
+ self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
233
+ self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())
234
+
235
+ self.to_joint_state_token = Linear(dim_joint_state, dim)
222
236
 
223
237
  self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)
224
238
 
@@ -262,17 +276,18 @@ class MimicVideo(Module):
262
276
  actions,
263
277
  video_hiddens, # they use layer 19 of cosmos predict, at first denoising step. that's all
264
278
  *,
279
+ joint_state,
265
280
  time = None,
281
+ time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
266
282
  context_mask = None,
267
283
  ):
284
+ batch, device = actions.shape[0], actions.device
268
285
 
269
286
  is_training = not exists(time)
270
287
 
271
288
  # handle flow time conditioning
272
289
 
273
290
  if is_training:
274
- batch, device = actions.shape[0], actions.device
275
-
276
291
  time = torch.rand((batch,), device = device)
277
292
  time = self.sample_time_fn(time)
278
293
 
@@ -285,7 +300,28 @@ class MimicVideo(Module):
285
300
  else:
286
301
  noised = actions
287
302
 
288
- time_cond = self.to_time_cond(time)
303
+ if time.ndim == 0:
304
+ time = rearrange(time, '-> b', b = batch)
305
+
306
+ # handle the video denoising times
307
+
308
+ time_video_denoise = cast_tensor(time_video_denoise)
309
+
310
+ if time_video_denoise.ndim == 0:
311
+ time_video_denoise = rearrange(time_video_denoise, '-> 1')
312
+
313
+ if time_video_denoise.shape[0] != batch:
314
+ time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)
315
+
316
+ times = stack((time, time_video_denoise), dim = -1)
317
+
318
+ # fourier embed and mlp to time condition
319
+
320
+ fourier_embed = self.to_fourier_embed(times)
321
+
322
+ fourier_embed = rearrange(fourier_embed, '... times d -> ... (times d)')
323
+
324
+ time_cond = self.to_time_cond(fourier_embed)
289
325
 
290
326
  # handle video hiddens
291
327
 
@@ -295,6 +331,10 @@ class MimicVideo(Module):
295
331
 
296
332
  tokens = self.to_action_tokens(noised)
297
333
 
334
+ joint_state_token = self.to_joint_state_token(joint_state)
335
+
336
+ tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')
337
+
298
338
  # transformer layers
299
339
 
300
340
  for (
@@ -322,14 +362,28 @@ class MimicVideo(Module):
322
362
 
323
363
  tokens = residual + attn(tokens) * gate
324
364
 
325
- # feedforward
365
+ # prepare feedforward
326
366
 
327
367
  residual = tokens
328
368
 
329
369
  tokens, gate = ff_norm(tokens, time_cond)
330
370
 
371
+ # shift along time for action tokens for cheap relative positioning, which is better than messing with rope with such short action chunks
372
+
373
+ joint_state_token, tokens = inverse_pack(tokens)
374
+
375
+ tokens = shift_feature_dim(tokens)
376
+
377
+ tokens, _ = pack_with_inverse((joint_state_token, tokens), 'b * d')
378
+
379
+ # feedforward
380
+
331
381
  tokens = residual + ff(tokens) * gate
332
382
 
383
+ # remove joint token
384
+
385
+ _, tokens = inverse_pack(tokens)
386
+
333
387
  # prediction
334
388
 
335
389
  pred_flow = self.to_pred_action_flow(tokens)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mimic-video
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: Mimic Video
5
5
  Project-URL: Homepage, https://pypi.org/project/mimic-video/
6
6
  Project-URL: Repository, https://github.com/lucidrains/mimic-video
@@ -0,0 +1,6 @@
1
+ mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
2
+ mimic_video/mimic_video.py,sha256=-2HVpXAgEG28JFkJeUlypdmOMyYDD2tw0Fisf9-BZ-M,10243
3
+ mimic_video-0.0.3.dist-info/METADATA,sha256=MVJMzysTCCpsgxBKUA9ye-aFSeQAXinyP3ejCtJ8JD8,2960
4
+ mimic_video-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ mimic_video-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ mimic_video-0.0.3.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
2
- mimic_video/mimic_video.py,sha256=aejvjr1F3A7pZFikf-kEgeOpi1_53xVddBMpDPoxA90,8272
3
- mimic_video-0.0.1.dist-info/METADATA,sha256=414y344JcuIKQJss7d9riTrHszIwthHW8DDSSuRntdo,2960
4
- mimic_video-0.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- mimic_video-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- mimic_video-0.0.1.dist-info/RECORD,,