mimic-video 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mimic-video might be problematic. Click here for more details.
- mimic_video/mimic_video.py +65 -11
- {mimic_video-0.0.1.dist-info → mimic_video-0.0.3.dist-info}/METADATA +1 -1
- mimic_video-0.0.3.dist-info/RECORD +6 -0
- mimic_video-0.0.1.dist-info/RECORD +0 -6
- {mimic_video-0.0.1.dist-info → mimic_video-0.0.3.dist-info}/WHEEL +0 -0
- {mimic_video-0.0.1.dist-info → mimic_video-0.0.3.dist-info}/licenses/LICENSE +0 -0
mimic_video/mimic_video.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
from torch import nn
|
|
2
|
+
from torch import nn, cat, stack, is_tensor, tensor
|
|
3
3
|
from torch.nn import Module, ModuleList, Linear
|
|
4
4
|
|
|
5
5
|
import torch.nn.functional as F
|
|
6
6
|
|
|
7
7
|
import einx
|
|
8
|
-
from einops import einsum, rearrange
|
|
8
|
+
from einops import einsum, rearrange, repeat
|
|
9
9
|
from einops.layers.torch import Rearrange
|
|
10
10
|
|
|
11
11
|
from x_mlps_pytorch import create_mlp
|
|
12
12
|
|
|
13
13
|
from torch_einops_utils import (
|
|
14
14
|
pad_left_ndim,
|
|
15
|
-
align_dims_left
|
|
15
|
+
align_dims_left,
|
|
16
|
+
pad_at_dim,
|
|
17
|
+
pack_with_inverse,
|
|
16
18
|
)
|
|
17
19
|
|
|
18
20
|
# ein notation
|
|
@@ -37,12 +39,23 @@ def divisible_by(num, den):
|
|
|
37
39
|
|
|
38
40
|
# tensor function
|
|
39
41
|
|
|
42
|
+
def cast_tensor(val, device = None):
|
|
43
|
+
return tensor(val, device = device) if not is_tensor(val) else val
|
|
44
|
+
|
|
40
45
|
def max_neg_value(t):
|
|
41
46
|
return -torch.finfo(t.dtype).max
|
|
42
47
|
|
|
43
48
|
def l2norm(t, eps = 1e-10):
|
|
44
49
|
return F.normalize(t, dim = -1, eps = eps)
|
|
45
50
|
|
|
51
|
+
# token shift from Peng et al. of RWKV
|
|
52
|
+
# cheap way to generate relative positions
|
|
53
|
+
|
|
54
|
+
def shift_feature_dim(t):
|
|
55
|
+
x, x_shift = t.chunk(2, dim = -1)
|
|
56
|
+
x_shift = pad_at_dim(x_shift, (1, -1), dim = 1)
|
|
57
|
+
return cat((x, x_shift), dim = -1)
|
|
58
|
+
|
|
46
59
|
# time
|
|
47
60
|
|
|
48
61
|
# they follow p0's research finding with the beta distribution
|
|
@@ -196,6 +209,7 @@ class MimicVideo(Module):
|
|
|
196
209
|
*,
|
|
197
210
|
dim_video_hidden,
|
|
198
211
|
dim_action = 20,
|
|
212
|
+
dim_joint_state = 32,
|
|
199
213
|
depth = 8,
|
|
200
214
|
dim_head = 64,
|
|
201
215
|
heads = 8,
|
|
@@ -215,10 +229,10 @@ class MimicVideo(Module):
|
|
|
215
229
|
|
|
216
230
|
dim_time_cond = default(dim_time_cond, dim * 2)
|
|
217
231
|
|
|
218
|
-
self.
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
)
|
|
232
|
+
self.to_fourier_embed = RandomFourierEmbed(dim) # used by deepmind, its fine
|
|
233
|
+
self.to_time_cond = create_mlp(dim_in = dim * 2, dim = dim_time_cond, depth = 2, activation = nn.SiLU())
|
|
234
|
+
|
|
235
|
+
self.to_joint_state_token = Linear(dim_joint_state, dim)
|
|
222
236
|
|
|
223
237
|
self.video_hidden_norm = nn.RMSNorm(dim_video_hidden)
|
|
224
238
|
|
|
@@ -262,17 +276,18 @@ class MimicVideo(Module):
|
|
|
262
276
|
actions,
|
|
263
277
|
video_hiddens, # they use layer 19 of cosmos predict, at first denoising step. that's all
|
|
264
278
|
*,
|
|
279
|
+
joint_state,
|
|
265
280
|
time = None,
|
|
281
|
+
time_video_denoise = 0., # 0 is noise in the scheme i prefer - default to their optimal choice, but can be changed
|
|
266
282
|
context_mask = None,
|
|
267
283
|
):
|
|
284
|
+
batch, device = actions.shape[0], actions.device
|
|
268
285
|
|
|
269
286
|
is_training = not exists(time)
|
|
270
287
|
|
|
271
288
|
# handle flow time conditioning
|
|
272
289
|
|
|
273
290
|
if is_training:
|
|
274
|
-
batch, device = actions.shape[0], actions.device
|
|
275
|
-
|
|
276
291
|
time = torch.rand((batch,), device = device)
|
|
277
292
|
time = self.sample_time_fn(time)
|
|
278
293
|
|
|
@@ -285,7 +300,28 @@ class MimicVideo(Module):
|
|
|
285
300
|
else:
|
|
286
301
|
noised = actions
|
|
287
302
|
|
|
288
|
-
|
|
303
|
+
if time.ndim == 0:
|
|
304
|
+
time = rearrange(time, '-> b', b = batch)
|
|
305
|
+
|
|
306
|
+
# handle the video denoising times
|
|
307
|
+
|
|
308
|
+
time_video_denoise = cast_tensor(time_video_denoise)
|
|
309
|
+
|
|
310
|
+
if time_video_denoise.ndim == 0:
|
|
311
|
+
time_video_denoise = rearrange(time_video_denoise, '-> 1')
|
|
312
|
+
|
|
313
|
+
if time_video_denoise.shape[0] != batch:
|
|
314
|
+
time_video_denoise = repeat(time_video_denoise, '1 -> b', b = batch)
|
|
315
|
+
|
|
316
|
+
times = stack((time, time_video_denoise), dim = -1)
|
|
317
|
+
|
|
318
|
+
# fourier embed and mlp to time condition
|
|
319
|
+
|
|
320
|
+
fourier_embed = self.to_fourier_embed(times)
|
|
321
|
+
|
|
322
|
+
fourier_embed = rearrange(fourier_embed, '... times d -> ... (times d)')
|
|
323
|
+
|
|
324
|
+
time_cond = self.to_time_cond(fourier_embed)
|
|
289
325
|
|
|
290
326
|
# handle video hiddens
|
|
291
327
|
|
|
@@ -295,6 +331,10 @@ class MimicVideo(Module):
|
|
|
295
331
|
|
|
296
332
|
tokens = self.to_action_tokens(noised)
|
|
297
333
|
|
|
334
|
+
joint_state_token = self.to_joint_state_token(joint_state)
|
|
335
|
+
|
|
336
|
+
tokens, inverse_pack = pack_with_inverse((joint_state_token, tokens), 'b * d')
|
|
337
|
+
|
|
298
338
|
# transformer layers
|
|
299
339
|
|
|
300
340
|
for (
|
|
@@ -322,14 +362,28 @@ class MimicVideo(Module):
|
|
|
322
362
|
|
|
323
363
|
tokens = residual + attn(tokens) * gate
|
|
324
364
|
|
|
325
|
-
# feedforward
|
|
365
|
+
# prepare feedforward
|
|
326
366
|
|
|
327
367
|
residual = tokens
|
|
328
368
|
|
|
329
369
|
tokens, gate = ff_norm(tokens, time_cond)
|
|
330
370
|
|
|
371
|
+
# shift along time for action tokens for cheap relative positioning, which is better than messing with rope with such short action chunks
|
|
372
|
+
|
|
373
|
+
joint_state_token, tokens = inverse_pack(tokens)
|
|
374
|
+
|
|
375
|
+
tokens = shift_feature_dim(tokens)
|
|
376
|
+
|
|
377
|
+
tokens, _ = pack_with_inverse((joint_state_token, tokens), 'b * d')
|
|
378
|
+
|
|
379
|
+
# feedforward
|
|
380
|
+
|
|
331
381
|
tokens = residual + ff(tokens) * gate
|
|
332
382
|
|
|
383
|
+
# remove joint token
|
|
384
|
+
|
|
385
|
+
_, tokens = inverse_pack(tokens)
|
|
386
|
+
|
|
333
387
|
# prediction
|
|
334
388
|
|
|
335
389
|
pred_flow = self.to_pred_action_flow(tokens)
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
|
|
2
|
+
mimic_video/mimic_video.py,sha256=-2HVpXAgEG28JFkJeUlypdmOMyYDD2tw0Fisf9-BZ-M,10243
|
|
3
|
+
mimic_video-0.0.3.dist-info/METADATA,sha256=MVJMzysTCCpsgxBKUA9ye-aFSeQAXinyP3ejCtJ8JD8,2960
|
|
4
|
+
mimic_video-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
mimic_video-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
mimic_video-0.0.3.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
mimic_video/__init__.py,sha256=-4HP_pbT4YLhRUwNwuL4qyLHbgDyQ099nHL7eVi0_Ag,48
|
|
2
|
-
mimic_video/mimic_video.py,sha256=aejvjr1F3A7pZFikf-kEgeOpi1_53xVddBMpDPoxA90,8272
|
|
3
|
-
mimic_video-0.0.1.dist-info/METADATA,sha256=414y344JcuIKQJss7d9riTrHszIwthHW8DDSSuRntdo,2960
|
|
4
|
-
mimic_video-0.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
mimic_video-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
mimic_video-0.0.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|