dreamer4 0.0.73__py3-none-any.whl → 0.0.75__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +44 -0
- {dreamer4-0.0.73.dist-info → dreamer4-0.0.75.dist-info}/METADATA +1 -1
- dreamer4-0.0.75.dist-info/RECORD +8 -0
- dreamer4-0.0.73.dist-info/RECORD +0 -8
- {dreamer4-0.0.73.dist-info → dreamer4-0.0.75.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.73.dist-info → dreamer4-0.0.75.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
|
@@ -91,6 +91,18 @@ def combine_experiences(
|
|
|
91
91
|
) -> Experience:
|
|
92
92
|
|
|
93
93
|
assert len(exps) > 0
|
|
94
|
+
|
|
95
|
+
# set lens if not there
|
|
96
|
+
|
|
97
|
+
for exp in exps:
|
|
98
|
+
latents = exp.latents
|
|
99
|
+
batch, time, device = *latents.shape[:2], latents.device
|
|
100
|
+
|
|
101
|
+
if not exists(exp.lens):
|
|
102
|
+
exp.lens = torch.full((batch,), time, device = device)
|
|
103
|
+
|
|
104
|
+
# convert to dictionary
|
|
105
|
+
|
|
94
106
|
exps_dict = [asdict(exp) for exp in exps]
|
|
95
107
|
|
|
96
108
|
values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
|
|
@@ -109,7 +121,11 @@ def combine_experiences(
|
|
|
109
121
|
concatted = []
|
|
110
122
|
|
|
111
123
|
for field_values in all_field_values:
|
|
124
|
+
|
|
112
125
|
if is_tensor(first(field_values)):
|
|
126
|
+
|
|
127
|
+
field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
|
|
128
|
+
|
|
113
129
|
new_field_value = cat(field_values)
|
|
114
130
|
else:
|
|
115
131
|
new_field_value = first(list(set(field_values)))
|
|
@@ -223,6 +239,27 @@ def pad_at_dim(
|
|
|
223
239
|
zeros = ((0, 0) * dims_from_right)
|
|
224
240
|
return F.pad(t, (*zeros, *pad), value = value)
|
|
225
241
|
|
|
242
|
+
def pad_to_len(t, target_len, *, dim):
    """
    Pad `t` at the end of dimension `dim` so its size there reaches `target_len`.

    Padding is delegated to `pad_at_dim` with a `(0, n)` pad spec, i.e. nothing
    before and `n` elements after, using `pad_at_dim`'s fill value.
    A tensor that is already `target_len` or longer along `dim` is returned as-is
    (the very same object, never truncated).
    """
    shortfall = target_len - t.shape[dim]

    # already long enough - hand back the input untouched
    if shortfall <= 0:
        return t

    return pad_at_dim(t, (0, shortfall), dim = dim)
|
|
249
|
+
|
|
250
|
+
def pad_tensors_at_dim_to_max_len(
    tensors: list[Tensor],
    dims: tuple[int, ...]
):
    """
    Pad every tensor in `tensors` so they all share the same size along each of `dims`.

    For each requested dim, the largest size across the tensors is found and every
    tensor is padded up to it with `pad_to_len` (padding appended at the end of that dim).

    Dims that the tensors do not have (judged by the first tensor's `ndim`, for both
    positive and negative indices) are skipped. This lets one call with e.g. `dims = (1, 2)`
    be applied to fields of differing rank.

    Args:
        tensors: tensors to align; only their sizes along `dims` may differ.
        dims: dimensions to align, applied in order.

    Returns:
        The padded tensors as a list. If `tensors` is empty, or no dim applies,
        the input is returned unchanged.
    """
    # nothing to align - avoids indexing into an empty list below
    if len(tensors) == 0:
        return tensors

    ndim = tensors[0].ndim

    for dim in dims:
        # skip dims the tensors don't have, treating negative indices the same way as positive
        if not (-ndim <= dim < ndim):
            continue

        max_len = max(t.shape[dim] for t in tensors)
        tensors = [pad_to_len(t, max_len, dim = dim) for t in tensors]

    return tensors
|
|
262
|
+
|
|
226
263
|
def align_dims_left(t, aligned_to):
|
|
227
264
|
shape = t.shape
|
|
228
265
|
num_right_dims = aligned_to.ndim - t.ndim
|
|
@@ -2833,6 +2870,13 @@ class DynamicsWorldModel(Module):
|
|
|
2833
2870
|
|
|
2834
2871
|
space_tokens = self.latents_to_spatial_tokens(noised_latents)
|
|
2835
2872
|
|
|
2873
|
+
# maybe add view embedding
|
|
2874
|
+
|
|
2875
|
+
if self.video_has_multi_view:
|
|
2876
|
+
space_tokens = add('b t v ... d, v d', space_tokens, self.view_emb)
|
|
2877
|
+
|
|
2878
|
+
# merge spatial tokens
|
|
2879
|
+
|
|
2836
2880
|
space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')
|
|
2837
2881
|
|
|
2838
2882
|
num_spatial_tokens = space_tokens.shape[-2]
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
+
dreamer4/dreamer4.py,sha256=6ngJq_cpil97RsITPmcExlxoZTcZ6XNdFfOBDcdACQg,105915
|
|
3
|
+
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
+
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
+
dreamer4-0.0.75.dist-info/METADATA,sha256=VVaIj0vNfpT2JBm9AaSr8D-SP5wfWNgHcJjszdmkwU4,3065
|
|
6
|
+
dreamer4-0.0.75.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
dreamer4-0.0.75.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
dreamer4-0.0.75.dist-info/RECORD,,
|
dreamer4-0.0.73.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
-
dreamer4/dreamer4.py,sha256=z9zYb9GLRGqfiXz0I6PaE94ZFFOEIy3TwmcUUS_5KvE,104842
|
|
3
|
-
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
-
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
-
dreamer4-0.0.73.dist-info/METADATA,sha256=TwPW4CmYD__Ecv9qR8_uBlj09wmhjspkdPDMvVBdQO4,3065
|
|
6
|
-
dreamer4-0.0.73.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
dreamer4-0.0.73.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
dreamer4-0.0.73.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|