d4rt 0.0.2__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: d4rt
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Implementation of D4RT, Efficiently Reconstructing Dynamic Scenes
5
5
  Project-URL: Homepage, https://pypi.org/project/d4rt/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/d4rt
@@ -50,6 +50,46 @@ Description-Content-Type: text/markdown
50
50
 
51
51
  Implementation of [D4RT](https://d4rt-paper.github.io/), Efficiently Reconstructing Dynamic Scenes, Deepmind
52
52
 
53
+ ## install
54
+
55
+ ```shell
56
+ $ pip install d4rt
57
+ ```
58
+
59
+ ## usage
60
+
61
+ ```python
62
+ import torch
63
+ from d4rt import D4RT
64
+
65
+ model = D4RT(
66
+ dim = 512,
67
+ video_image_size = 128,
68
+ video_patch_size = 32,
69
+ video_max_time_len = 10,
70
+ enc_depth = 6,
71
+ dec_depth = 6
72
+ )
73
+
74
+ videos = torch.randn(2, 10, 3, 128, 128)
75
+ points = torch.randn(2, 5, 3)
76
+ queries = torch.randn(2, 5, 512)
77
+
78
+ loss = model(
79
+ videos,
80
+ coors = torch.randint(0, 128, (2, 5, 2)),
81
+ time_src = torch.randint(0, 10, (2, 5)),
82
+ time_tgt = torch.randint(0, 10, (2, 5)),
83
+ time_camera = torch.randint(0, 10, (2, 5)),
84
+ points = points
85
+ )
86
+
87
+ loss.backward()
88
+
89
+ pred = model(videos, queries = queries) # (2, 5, 3)
90
+ assert pred.shape == (2, 5, 3)
91
+ ```
92
+
53
93
  ## citations
54
94
 
55
95
  ```bibtex
@@ -4,6 +4,46 @@
4
4
 
5
5
  Implementation of [D4RT](https://d4rt-paper.github.io/), Efficiently Reconstructing Dynamic Scenes, Deepmind
6
6
 
7
+ ## install
8
+
9
+ ```shell
10
+ $ pip install d4rt
11
+ ```
12
+
13
+ ## usage
14
+
15
+ ```python
16
+ import torch
17
+ from d4rt import D4RT
18
+
19
+ model = D4RT(
20
+ dim = 512,
21
+ video_image_size = 128,
22
+ video_patch_size = 32,
23
+ video_max_time_len = 10,
24
+ enc_depth = 6,
25
+ dec_depth = 6
26
+ )
27
+
28
+ videos = torch.randn(2, 10, 3, 128, 128)
29
+ points = torch.randn(2, 5, 3)
30
+ queries = torch.randn(2, 5, 512)
31
+
32
+ loss = model(
33
+ videos,
34
+ coors = torch.randint(0, 128, (2, 5, 2)),
35
+ time_src = torch.randint(0, 10, (2, 5)),
36
+ time_tgt = torch.randint(0, 10, (2, 5)),
37
+ time_camera = torch.randint(0, 10, (2, 5)),
38
+ points = points
39
+ )
40
+
41
+ loss.backward()
42
+
43
+ pred = model(videos, queries = queries) # (2, 5, 3)
44
+ assert pred.shape == (2, 5, 3)
45
+ ```
46
+
7
47
  ## citations
8
48
 
9
49
  ```bibtex
@@ -10,9 +10,9 @@ from x_transformers import Encoder, CrossAttender, Attention, FeedForward
10
10
  # ein notation
11
11
 
12
12
  import einx
13
- from einops import rearrange
13
+ from einops import rearrange, repeat
14
14
  from einops.layers.torch import Rearrange
15
- from torch_einops_utils import pack_with_inverse
15
+ from torch_einops_utils import pack_with_inverse, lens_to_mask, maybe
16
16
 
17
17
  # helpers
18
18
 
@@ -24,7 +24,12 @@ def divisible_by(num, den):
24
24
 
25
25
  # function for the patch embedding in the query
26
26
 
27
- def extract_patches(video, coors, time_src, patch_size):
27
+ def extract_patches(
28
+ video, # float[b t c h w]
29
+ coors, # int[b q 2]
30
+ time_src, # int[b q]
31
+ patch_size
32
+ ):
28
33
  b, q, p, device = *time_src.shape, patch_size, video.device
29
34
 
30
35
  padded_video = F.pad(video, (p,) * 4)
@@ -112,11 +117,15 @@ class VideoEncoder(Module):
112
117
 
113
118
  def forward(
114
119
  self,
115
- video # float[b t c h w]
120
+ video, # float[b t c h w],
121
+ mask = None # bool[b t]
116
122
  ): # float[b n d]
117
123
 
118
124
  tokens = self.patch_to_tokens(video) # float[b t s d]
119
125
 
126
+ if exists(mask):
127
+ mask = repeat(mask, 'b ... -> (b s) ...', s = tokens.shape[-2])
128
+
120
129
  for spatial_attn, time_attn, ff in self.layers:
121
130
 
122
131
  # space attn
@@ -133,7 +142,7 @@ class VideoEncoder(Module):
133
142
 
134
143
  tokens, inverse_pack = pack_with_inverse(tokens, '* t d')
135
144
 
136
- tokens = time_attn(tokens) + tokens
145
+ tokens = time_attn(tokens, mask = mask) + tokens
137
146
 
138
147
  tokens = inverse_pack(tokens)
139
148
 
@@ -223,8 +232,12 @@ class D4RT(Module):
223
232
  time_camera = None, # int[b q]
224
233
  queries = None, # float[b q d]
225
234
  points = None, # float[b q 3]
226
- return_pred = False
235
+ return_pred = False,
236
+ video_lens = None # int[b]
227
237
  ):
238
+
239
+ # embedding to queries
240
+
228
241
  assert (
229
242
  exists(queries) or
230
243
  all([exists(p) for p in (coors, time_src, time_tgt, time_camera)])
@@ -245,12 +258,21 @@ class D4RT(Module):
245
258
 
246
259
  queries = self.norm_queries(queries)
247
260
 
248
- global_spatial_repr = self.to_global_spatial_repr(video)
261
+ # self attention
262
+
263
+ time = video.shape[1]
264
+ video_mask = maybe(lens_to_mask)(video_lens, time)
265
+
266
+ global_spatial_repr = self.to_global_spatial_repr(video, mask = video_mask)
249
267
 
250
268
  global_spatial_repr, inverse_pack_spacetime = pack_with_inverse(global_spatial_repr, 'b * d')
251
269
 
270
+ # cross attention
271
+
252
272
  queried = self.cross_attender(queries, context = global_spatial_repr)
253
273
 
274
+ # prediction
275
+
254
276
  pred = self.to_pred(queried)
255
277
 
256
278
  if not exists(points):
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "d4rt"
3
- version = "0.0.2"
3
+ version = "0.0.3"
4
4
  description = "Implementation of D4RT, Efficiently Reconstructing Dynamic Scenes"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
File without changes
File without changes
File without changes