dreamer4 0.0.73__py3-none-any.whl → 0.0.74__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- dreamer4/dreamer4.py +37 -0
- {dreamer4-0.0.73.dist-info → dreamer4-0.0.74.dist-info}/METADATA +1 -1
- dreamer4-0.0.74.dist-info/RECORD +8 -0
- dreamer4-0.0.73.dist-info/RECORD +0 -8
- {dreamer4-0.0.73.dist-info → dreamer4-0.0.74.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.73.dist-info → dreamer4-0.0.74.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
|
@@ -91,6 +91,18 @@ def combine_experiences(
|
|
|
91
91
|
) -> Experience:
|
|
92
92
|
|
|
93
93
|
assert len(exps) > 0
|
|
94
|
+
|
|
95
|
+
# set lens if not there
|
|
96
|
+
|
|
97
|
+
for exp in exps:
|
|
98
|
+
latents = exp.latents
|
|
99
|
+
batch, time, device = *latents.shape[:2], latents.device
|
|
100
|
+
|
|
101
|
+
if not exists(exp.lens):
|
|
102
|
+
exp.lens = torch.full((batch,), time, device = device)
|
|
103
|
+
|
|
104
|
+
# convert to dictionary
|
|
105
|
+
|
|
94
106
|
exps_dict = [asdict(exp) for exp in exps]
|
|
95
107
|
|
|
96
108
|
values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
|
|
@@ -109,7 +121,11 @@ def combine_experiences(
|
|
|
109
121
|
concatted = []
|
|
110
122
|
|
|
111
123
|
for field_values in all_field_values:
|
|
124
|
+
|
|
112
125
|
if is_tensor(first(field_values)):
|
|
126
|
+
|
|
127
|
+
field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
|
|
128
|
+
|
|
113
129
|
new_field_value = cat(field_values)
|
|
114
130
|
else:
|
|
115
131
|
new_field_value = first(list(set(field_values)))
|
|
@@ -223,6 +239,27 @@ def pad_at_dim(
|
|
|
223
239
|
zeros = ((0, 0) * dims_from_right)
|
|
224
240
|
return F.pad(t, (*zeros, *pad), value = value)
|
|
225
241
|
|
|
242
|
+
def pad_to_len(t, target_len, *, dim):
    """
    Pad `t` along `dim` so that its size there is at least `target_len`.

    The padding is appended after the existing entries: `(0, shortfall)` is
    handed to `pad_at_dim`, and the fill value is whatever `pad_at_dim`
    uses by default.

    Args:
        t: tensor to pad.
        target_len: desired minimum size of `t` along `dim`.
        dim: dimension to pad (keyword-only).

    Returns:
        `t` itself (the same object) when it is already at least
        `target_len` long along `dim`, otherwise a padded tensor.
    """
    shortfall = target_len - t.shape[dim]

    # already long enough - hand back the original tensor untouched
    if shortfall <= 0:
        return t

    return pad_at_dim(t, (0, shortfall), dim = dim)
|
|
249
|
+
|
|
250
|
+
def pad_tensors_at_dim_to_max_len(
    tensors: list[Tensor],
    dims: tuple[int, ...]
):
    """
    Pad tensors so that, along each dim in `dims`, they all share the largest
    size found among them.

    Padding is delegated to `pad_to_len`, which returns a tensor unchanged
    when it is already long enough along that dim.

    Args:
        tensors: sequence of tensors to align. Only the first tensor's rank is
            consulted when deciding which dims to pad, so the tensors are
            assumed to share the same number of dimensions.
        dims: dimensions to align, processed in order. Any dim that is
            >= the first tensor's ndim is skipped. This lets a single `dims`
            tuple be applied to tensors of lower rank.

    Returns:
        A new list of (possibly padded) tensors in the original order.
        An empty input yields an empty list.
    """
    tensors = list(tensors)

    # nothing to align - also avoids indexing into an empty sequence below
    if len(tensors) == 0:
        return tensors

    # padding never changes rank, so the first tensor's ndim can be read once
    ndim = tensors[0].ndim

    for dim in dims:
        if dim >= ndim:
            continue

        max_len = max(t.shape[dim] for t in tensors)
        tensors = [pad_to_len(t, max_len, dim = dim) for t in tensors]

    return tensors
|
|
262
|
+
|
|
226
263
|
def align_dims_left(t, aligned_to):
|
|
227
264
|
shape = t.shape
|
|
228
265
|
num_right_dims = aligned_to.ndim - t.ndim
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
+
dreamer4/dreamer4.py,sha256=4B1avRi7yhQNcUKmFWPz6Y0nWlTmvvWpoyPtInvl7Mk,105712
|
|
3
|
+
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
+
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
+
dreamer4-0.0.74.dist-info/METADATA,sha256=UfYxKbbOXqGDkcVU8l9_pvvDruZ4XJmRkyiG8x1JRWk,3065
|
|
6
|
+
dreamer4-0.0.74.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
dreamer4-0.0.74.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
dreamer4-0.0.74.dist-info/RECORD,,
|
dreamer4-0.0.73.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
-
dreamer4/dreamer4.py,sha256=z9zYb9GLRGqfiXz0I6PaE94ZFFOEIy3TwmcUUS_5KvE,104842
|
|
3
|
-
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
-
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
-
dreamer4-0.0.73.dist-info/METADATA,sha256=TwPW4CmYD__Ecv9qR8_uBlj09wmhjspkdPDMvVBdQO4,3065
|
|
6
|
-
dreamer4-0.0.73.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
dreamer4-0.0.73.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
dreamer4-0.0.73.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|