dreamer4 0.0.73__py3-none-any.whl → 0.0.75__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -91,6 +91,18 @@ def combine_experiences(
91
91
  ) -> Experience:
92
92
 
93
93
  assert len(exps) > 0
94
+
95
+ # set lens if not there
96
+
97
+ for exp in exps:
98
+ latents = exp.latents
99
+ batch, time, device = *latents.shape[:2], latents.device
100
+
101
+ if not exists(exp.lens):
102
+ exp.lens = torch.full((batch,), time, device = device)
103
+
104
+ # convert to dictionary
105
+
94
106
  exps_dict = [asdict(exp) for exp in exps]
95
107
 
96
108
  values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
@@ -109,7 +121,11 @@ def combine_experiences(
109
121
  concatted = []
110
122
 
111
123
  for field_values in all_field_values:
124
+
112
125
  if is_tensor(first(field_values)):
126
+
127
+ field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
128
+
113
129
  new_field_value = cat(field_values)
114
130
  else:
115
131
  new_field_value = first(list(set(field_values)))
@@ -223,6 +239,27 @@ def pad_at_dim(
223
239
  zeros = ((0, 0) * dims_from_right)
224
240
  return F.pad(t, (*zeros, *pad), value = value)
225
241
 
242
def pad_to_len(t, target_len, *, dim):
    """
    Pad `t` along `dim` so that its size there reaches `target_len`.

    The pad spec handed to `pad_at_dim` is `(0, shortfall)`, so nothing is
    added before the existing elements. The fill value is `pad_at_dim`'s
    default. When `t` is already `target_len` or longer along `dim`, it is
    returned unchanged (no truncation).
    """
    shortfall = target_len - t.shape[dim]

    # nothing to add - hand back the original tensor object untouched
    if shortfall <= 0:
        return t

    return pad_at_dim(t, (0, shortfall), dim = dim)
249
+
250
def pad_tensors_at_dim_to_max_len(
    tensors: list[Tensor],
    dims: tuple[int, ...]
):
    """
    Make `tensors` agree in size along every dimension listed in `dims`.

    For each dim, every tensor is padded (via `pad_to_len`) up to the largest
    size found among `tensors` along that dim. Tensors that are already that
    long are left as-is.

    Dims that do not exist on the tensors are skipped. Existence is judged by
    the first tensor's `ndim`; all tensors are assumed to share one rank, so
    confirm this at call sites. Negative dims are normalized first, so `-1`
    and `ndim - 1` are treated identically, and an out-of-range negative dim
    is skipped just like an out-of-range positive one.

    An empty `tensors` input is returned unchanged. The input list itself is
    never mutated; a new list is built for every dim that gets padded.
    """
    if not tensors:
        return tensors

    # padding never changes rank, so the first tensor's ndim is valid for every pass
    ndim = first(tensors).ndim

    for dim in dims:
        if dim < 0:
            dim += ndim

        # skip dims these tensors don't have (e.g. lower-rank fields)
        if not (0 <= dim < ndim):
            continue

        max_len = max(t.shape[dim] for t in tensors)
        tensors = [pad_to_len(t, max_len, dim = dim) for t in tensors]

    return tensors
262
+
226
263
  def align_dims_left(t, aligned_to):
227
264
  shape = t.shape
228
265
  num_right_dims = aligned_to.ndim - t.ndim
@@ -2833,6 +2870,13 @@ class DynamicsWorldModel(Module):
2833
2870
 
2834
2871
  space_tokens = self.latents_to_spatial_tokens(noised_latents)
2835
2872
 
2873
+ # maybe add view embedding
2874
+
2875
+ if self.video_has_multi_view:
2876
+ space_tokens = add('b t v ... d, v d', space_tokens, self.view_emb)
2877
+
2878
+ # merge spatial tokens
2879
+
2836
2880
  space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')
2837
2881
 
2838
2882
  num_spatial_tokens = space_tokens.shape[-2]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.73
3
+ Version: 0.0.75
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -0,0 +1,8 @@
1
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
+ dreamer4/dreamer4.py,sha256=6ngJq_cpil97RsITPmcExlxoZTcZ6XNdFfOBDcdACQg,105915
3
+ dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
4
+ dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
5
+ dreamer4-0.0.75.dist-info/METADATA,sha256=VVaIj0vNfpT2JBm9AaSr8D-SP5wfWNgHcJjszdmkwU4,3065
6
+ dreamer4-0.0.75.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ dreamer4-0.0.75.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ dreamer4-0.0.75.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
- dreamer4/dreamer4.py,sha256=z9zYb9GLRGqfiXz0I6PaE94ZFFOEIy3TwmcUUS_5KvE,104842
3
- dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
4
- dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
5
- dreamer4-0.0.73.dist-info/METADATA,sha256=TwPW4CmYD__Ecv9qR8_uBlj09wmhjspkdPDMvVBdQO4,3065
6
- dreamer4-0.0.73.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- dreamer4-0.0.73.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- dreamer4-0.0.73.dist-info/RECORD,,