d4rt 0.0.2__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {d4rt-0.0.2 → d4rt-0.0.3}/PKG-INFO +41 -1
- {d4rt-0.0.2 → d4rt-0.0.3}/README.md +40 -0
- {d4rt-0.0.2 → d4rt-0.0.3}/d4rt/d4rt.py +29 -7
- {d4rt-0.0.2 → d4rt-0.0.3}/pyproject.toml +1 -1
- {d4rt-0.0.2 → d4rt-0.0.3}/.gitignore +0 -0
- {d4rt-0.0.2 → d4rt-0.0.3}/LICENSE +0 -0
- {d4rt-0.0.2 → d4rt-0.0.3}/d4rt/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: d4rt
|
|
3
|
-
Version: 0.0.2
|
|
3
|
+
Version: 0.0.3
|
|
4
4
|
Summary: Implementation of D4RT, Efficiently Reconstructing Dynamic Scenes
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/d4rt/
|
|
6
6
|
Project-URL: Repository, https://codeberg.org/lucidrains/d4rt
|
|
@@ -50,6 +50,46 @@ Description-Content-Type: text/markdown
|
|
|
50
50
|
|
|
51
51
|
Implementation of [D4RT](https://d4rt-paper.github.io/), Efficiently Reconstructing Dynamic Scenes, Deepmind
|
|
52
52
|
|
|
53
|
+
## install
|
|
54
|
+
|
|
55
|
+
```shell
|
|
56
|
+
$ pip install d4rt
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## usage
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
import torch
|
|
63
|
+
from d4rt import D4RT
|
|
64
|
+
|
|
65
|
+
model = D4RT(
|
|
66
|
+
dim = 512,
|
|
67
|
+
video_image_size = 128,
|
|
68
|
+
video_patch_size = 32,
|
|
69
|
+
video_max_time_len = 10,
|
|
70
|
+
enc_depth = 6,
|
|
71
|
+
dec_depth = 6
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
videos = torch.randn(2, 10, 3, 128, 128)
|
|
75
|
+
points = torch.randn(2, 5, 3)
|
|
76
|
+
queries = torch.randn(2, 5, 512)
|
|
77
|
+
|
|
78
|
+
loss = model(
|
|
79
|
+
videos,
|
|
80
|
+
coors = torch.randint(0, 128, (2, 5, 2)),
|
|
81
|
+
time_src = torch.randint(0, 10, (2, 5)),
|
|
82
|
+
time_tgt = torch.randint(0, 10, (2, 5)),
|
|
83
|
+
time_camera = torch.randint(0, 10, (2, 5)),
|
|
84
|
+
points = points
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
loss.backward()
|
|
88
|
+
|
|
89
|
+
pred = model(videos, queries = queries) # (2, 5, 3)
|
|
90
|
+
assert pred.shape == (2, 5, 3)
|
|
91
|
+
```
|
|
92
|
+
|
|
53
93
|
## citations
|
|
54
94
|
|
|
55
95
|
```bibtex
|
|
@@ -4,6 +4,46 @@
|
|
|
4
4
|
|
|
5
5
|
Implementation of [D4RT](https://d4rt-paper.github.io/), Efficiently Reconstructing Dynamic Scenes, Deepmind
|
|
6
6
|
|
|
7
|
+
## install
|
|
8
|
+
|
|
9
|
+
```shell
|
|
10
|
+
$ pip install d4rt
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## usage
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
import torch
|
|
17
|
+
from d4rt import D4RT
|
|
18
|
+
|
|
19
|
+
model = D4RT(
|
|
20
|
+
dim = 512,
|
|
21
|
+
video_image_size = 128,
|
|
22
|
+
video_patch_size = 32,
|
|
23
|
+
video_max_time_len = 10,
|
|
24
|
+
enc_depth = 6,
|
|
25
|
+
dec_depth = 6
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
videos = torch.randn(2, 10, 3, 128, 128)
|
|
29
|
+
points = torch.randn(2, 5, 3)
|
|
30
|
+
queries = torch.randn(2, 5, 512)
|
|
31
|
+
|
|
32
|
+
loss = model(
|
|
33
|
+
videos,
|
|
34
|
+
coors = torch.randint(0, 128, (2, 5, 2)),
|
|
35
|
+
time_src = torch.randint(0, 10, (2, 5)),
|
|
36
|
+
time_tgt = torch.randint(0, 10, (2, 5)),
|
|
37
|
+
time_camera = torch.randint(0, 10, (2, 5)),
|
|
38
|
+
points = points
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
loss.backward()
|
|
42
|
+
|
|
43
|
+
pred = model(videos, queries = queries) # (2, 5, 3)
|
|
44
|
+
assert pred.shape == (2, 5, 3)
|
|
45
|
+
```
|
|
46
|
+
|
|
7
47
|
## citations
|
|
8
48
|
|
|
9
49
|
```bibtex
|
|
@@ -10,9 +10,9 @@ from x_transformers import Encoder, CrossAttender, Attention, FeedForward
|
|
|
10
10
|
# ein notation
|
|
11
11
|
|
|
12
12
|
import einx
|
|
13
|
-
from einops import rearrange
|
|
13
|
+
from einops import rearrange, repeat
|
|
14
14
|
from einops.layers.torch import Rearrange
|
|
15
|
-
from torch_einops_utils import pack_with_inverse
|
|
15
|
+
from torch_einops_utils import pack_with_inverse, lens_to_mask, maybe
|
|
16
16
|
|
|
17
17
|
# helpers
|
|
18
18
|
|
|
@@ -24,7 +24,12 @@ def divisible_by(num, den):
|
|
|
24
24
|
|
|
25
25
|
# function for the patch embedding in the query
|
|
26
26
|
|
|
27
|
-
def extract_patches(
|
|
27
|
+
def extract_patches(
|
|
28
|
+
video, # float[b t c h w]
|
|
29
|
+
coors, # int[b q 2]
|
|
30
|
+
time_src, # int[b q]
|
|
31
|
+
patch_size
|
|
32
|
+
):
|
|
28
33
|
b, q, p, device = *time_src.shape, patch_size, video.device
|
|
29
34
|
|
|
30
35
|
padded_video = F.pad(video, (p,) * 4)
|
|
@@ -112,11 +117,15 @@ class VideoEncoder(Module):
|
|
|
112
117
|
|
|
113
118
|
def forward(
|
|
114
119
|
self,
|
|
115
|
-
video
|
|
120
|
+
video, # float[b t c h w],
|
|
121
|
+
mask = None # bool[b t]
|
|
116
122
|
): # float[b n d]
|
|
117
123
|
|
|
118
124
|
tokens = self.patch_to_tokens(video) # float[b t s d]
|
|
119
125
|
|
|
126
|
+
if exists(mask):
|
|
127
|
+
mask = repeat(mask, 'b ... -> (b s) ...', s = tokens.shape[-2])
|
|
128
|
+
|
|
120
129
|
for spatial_attn, time_attn, ff in self.layers:
|
|
121
130
|
|
|
122
131
|
# space attn
|
|
@@ -133,7 +142,7 @@ class VideoEncoder(Module):
|
|
|
133
142
|
|
|
134
143
|
tokens, inverse_pack = pack_with_inverse(tokens, '* t d')
|
|
135
144
|
|
|
136
|
-
tokens = time_attn(tokens) + tokens
|
|
145
|
+
tokens = time_attn(tokens, mask = mask) + tokens
|
|
137
146
|
|
|
138
147
|
tokens = inverse_pack(tokens)
|
|
139
148
|
|
|
@@ -223,8 +232,12 @@ class D4RT(Module):
|
|
|
223
232
|
time_camera = None, # int[b q]
|
|
224
233
|
queries = None, # float[b q d]
|
|
225
234
|
points = None, # float[b q 3]
|
|
226
|
-
return_pred = False
|
|
235
|
+
return_pred = False,
|
|
236
|
+
video_lens = None # int[b]
|
|
227
237
|
):
|
|
238
|
+
|
|
239
|
+
# embedding to queries
|
|
240
|
+
|
|
228
241
|
assert (
|
|
229
242
|
exists(queries) or
|
|
230
243
|
all([exists(p) for p in (coors, time_src, time_tgt, time_camera)])
|
|
@@ -245,12 +258,21 @@ class D4RT(Module):
|
|
|
245
258
|
|
|
246
259
|
queries = self.norm_queries(queries)
|
|
247
260
|
|
|
248
|
-
|
|
261
|
+
# self attention
|
|
262
|
+
|
|
263
|
+
time = video.shape[1]
|
|
264
|
+
video_mask = maybe(lens_to_mask)(video_lens, time)
|
|
265
|
+
|
|
266
|
+
global_spatial_repr = self.to_global_spatial_repr(video, mask = video_mask)
|
|
249
267
|
|
|
250
268
|
global_spatial_repr, inverse_pack_spacetime = pack_with_inverse(global_spatial_repr, 'b * d')
|
|
251
269
|
|
|
270
|
+
# cross attention
|
|
271
|
+
|
|
252
272
|
queried = self.cross_attender(queries, context = global_spatial_repr)
|
|
253
273
|
|
|
274
|
+
# prediction
|
|
275
|
+
|
|
254
276
|
pred = self.to_pred(queried)
|
|
255
277
|
|
|
256
278
|
if not exists(points):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|