dreamer4 0.0.72__tar.gz → 0.0.74__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -5,11 +5,11 @@ jobs:
   build:

     runs-on: ubuntu-latest
-    timeout-minutes: 20
+    timeout-minutes: 60
     strategy:
       fail-fast: false
       matrix:
-        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

     steps:
     - uses: actions/checkout@v4
@@ -24,4 +24,4 @@ jobs:
         python -m uv pip install -e .[test]
     - name: Test with pytest
       run: |
-        python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
+        python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/
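The CI change pairs a doubled shard count with a tripled per-job timeout. For orientation, a minimal sketch of the modulo-style assignment that --num-shards/--shard-id suggest (an illustration only, not pytest-shard's actual implementation):

def tests_for_shard(test_ids, num_shards, shard_id):
    # each shard keeps the tests whose index falls in its residue class
    return [t for i, t in enumerate(test_ids) if i % num_shards == shard_id]

# going from 10 shards to 20 roughly halves the tests per job
assert tests_for_shard(list(range(40)), num_shards = 20, shard_id = 3) == [3, 23]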
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.72
+Version: 0.0.74
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -82,6 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True

@@ -90,6 +91,18 @@ def combine_experiences(
 ) -> Experience:

     assert len(exps) > 0
+
+    # set lens if not there
+
+    for exp in exps:
+        latents = exp.latents
+        batch, time, device = *latents.shape[:2], latents.device
+
+        if not exists(exp.lens):
+            exp.lens = torch.full((batch,), time, device = device)
+
+    # convert to dictionary
+
     exps_dict = [asdict(exp) for exp in exps]

     values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
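A brief sketch of what the new default amounts to, assuming latents is laid out as (batch, time, ...) as the shape unpacking above implies (toy shapes, hypothetical values):

import torch

latents = torch.randn(4, 16, 32)   # (batch, time, dim)
batch, time = latents.shape[:2]

# an experience arriving without lens is treated as fully valid:
# every sequence is assigned the full time dimension as its length
lens = torch.full((batch,), time, device = latents.device)
assert lens.tolist() == [16, 16, 16, 16]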
@@ -108,7 +121,11 @@ def combine_experiences(
     concatted = []

     for field_values in all_field_values:
+
         if is_tensor(first(field_values)):
+
+            field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
+
             new_field_value = cat(field_values)
         else:
             new_field_value = first(list(set(field_values)))
@@ -222,6 +239,27 @@ def pad_at_dim(
222
239
  zeros = ((0, 0) * dims_from_right)
223
240
  return F.pad(t, (*zeros, *pad), value = value)
224
241
 
242
+ def pad_to_len(t, target_len, *, dim):
243
+ curr_len = t.shape[dim]
244
+
245
+ if curr_len >= target_len:
246
+ return t
247
+
248
+ return pad_at_dim(t, (0, target_len - curr_len), dim = dim)
249
+
250
+ def pad_tensors_at_dim_to_max_len(
251
+ tensors: list[Tensor],
252
+ dims: tuple[int, ...]
253
+ ):
254
+ for dim in dims:
255
+ if dim >= first(tensors).ndim:
256
+ continue
257
+
258
+ max_time = max([t.shape[dim] for t in tensors])
259
+ tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors]
260
+
261
+ return tensors
262
+
225
263
  def align_dims_left(t, aligned_to):
226
264
  shape = t.shape
227
265
  num_right_dims = aligned_to.ndim - t.ndim
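A usage sketch of the two new helpers with toy tensors of unequal time length; pad_at_dim and first belong to this module, so a minimal right-padding equivalent is restated here only to keep the example self-contained:

import torch
import torch.nn.functional as F

def pad_to_len(t, target_len, *, dim):
    curr_len = t.shape[dim]
    if curr_len >= target_len:
        return t
    # right-pad only the requested dim; F.pad pairs run from the last dim
    pad = [0, 0] * (t.ndim - dim - 1) + [0, target_len - curr_len]
    return F.pad(t, pad)

a = torch.ones(2, 3, 8)   # time = 3
b = torch.ones(2, 5, 8)   # time = 5

# pad both to the max time length so they can be concatenated along batch
out = torch.cat([pad_to_len(t, 5, dim = 1) for t in (a, b)])
assert out.shape == (4, 5, 8)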
@@ -2560,12 +2598,16 @@ class DynamicsWorldModel(Module):

         # returning agent actions, rewards, and log probs + values for policy optimization

+        batch, device = latents.shape[0], latents.device
+        experience_lens = torch.full((batch,), time_steps, device = device)
+
         gen = Experience(
             latents = latents,
             video = video,
             proprio = proprio if has_proprio else None,
             step_size = step_size,
             agent_index = agent_index,
+            lens = experience_lens,
             is_from_world_model = True
         )

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.72"
+version = "0.0.74"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -653,7 +653,7 @@ def test_online_rl(

     # manually

-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)

     combined_experience = combine_experiences([one_experience, another_experience])
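With the padding path above, this test change makes the combination step exercise rollouts of different lengths (8 vs 16 timesteps). A sketch of the expected behavior, assuming latents of shape (batch, time, dim); the shapes here are hypothetical stand-ins, not the test's actual tensors:

import torch
import torch.nn.functional as F

short = torch.randn(1, 8, 32)
long = torch.randn(1, 16, 32)

# combine_experiences right-pads the time dim to the max length,
# then concatenates along batch; lens keeps the true lengths
padded_short = F.pad(short, (0, 0, 0, 16 - 8))
combined = torch.cat([padded_short, long])

assert combined.shape == (2, 16, 32)
lens = torch.tensor([8, 16])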
5 files without changes