dreamer4 0.0.82__tar.gz → 0.0.84__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.0.82
+ Version: 0.0.84
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4

@@ -77,6 +77,7 @@ class Experience:
  latents: Tensor
  video: Tensor | None = None
  proprio: Tensor | None = None
+ agent_embed: Tensor | None = None,
  rewards: Tensor | None = None
  actions: tuple[Tensor, Tensor] | None = None
  log_probs: tuple[Tensor, Tensor] | None = None
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
  step_size = 4,
  max_timesteps = 16,
  env_is_vectorized = False,
- use_time_kv_cache = True
+ use_time_kv_cache = True,
+ store_agent_embed = True
  ):
  assert exists(self.video_tokenizer)

@@ -2130,16 +2132,23 @@ class DynamicsWorldModel(Module):
  values = None
  latents = None

+ acc_agent_embed = None
+
  # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

  is_terminated = full((batch,), False, device = device)
+ is_truncated = full((batch,), False, device = device)
+
  episode_lens = full((batch,), 0, device = device)

  # maybe time kv cache

  time_kv_cache = None

- for i in range(max_timesteps + 1):
+ step_index = 0
+
+ while not is_terminated.all():
+ step_index += 1

  latents = self.video_tokenizer(video, return_latents = True)

@@ -2201,10 +2210,15 @@ class DynamicsWorldModel(Module):

  if len(env_step_out) == 2:
  next_frame, reward = env_step_out
- terminate = full((batch,), False)
+ terminated = full((batch,), False)
+ truncated = full((batch,), False)

  elif len(env_step_out) == 3:
- next_frame, reward, terminate = env_step_out
+ truncated = full((batch,), False)
+ next_frame, reward, terminated = env_step_out
+
+ elif len(env_step_out) == 4:
+ next_frame, reward, terminated, truncated = env_step_out

  # update episode lens

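The new `len(env_step_out) == 4` branch means the interaction loop now also accepts environments that report truncation alongside termination, Gymnasium style. A minimal, illustrative sketch of such an environment follows; the frame layout (batch, channels, time, height, width) and the Module/forward interface are assumptions patterned after the MockEnv changes further down in this diff, not an API defined by the package.

    import torch
    from torch import nn

    class TruncatingToyEnv(nn.Module):
        # toy vectorized env: random frames, never terminates, truncates every env after max_steps
        def __init__(self, num_envs = 4, image_size = 64, max_steps = 8):
            super().__init__()
            self.num_envs = num_envs
            self.image_size = image_size
            self.max_steps = max_steps
            self.register_buffer('_step', torch.tensor(0))

        def forward(self, *actions):
            next_frame = torch.rand(self.num_envs, 3, 1, self.image_size, self.image_size)
            reward = torch.zeros(self.num_envs)
            terminated = torch.zeros(self.num_envs, dtype = torch.bool)
            truncated = torch.full((self.num_envs,), bool(self._step.item() >= self.max_steps))
            self._step.add_(1)
            return next_frame, reward, terminated, truncated
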
@@ -2212,7 +2226,20 @@ class DynamicsWorldModel(Module):

  # update `is_terminated`

- is_terminated |= terminate
+ # (1) - environment says it is terminated
+ # (2) - previous step is truncated (this step is for bootstrap value)
+
+ is_terminated |= (terminated | is_truncated)
+
+ # update `is_truncated`
+
+ if step_index <= max_timesteps:
+ is_truncated |= truncated
+
+ if step_index == max_timesteps:
+ # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
+
+ is_truncated |= ~is_terminated

  # batch and time dimension

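Concretely, this ordering gives a truncated environment exactly one extra step, whose value prediction serves as the bootstrap target, before it is folded into is_terminated. A self-contained toy trace (plain PyTorch over a batch of two envs, illustrative rather than package code) mirroring the updates above:

    import torch

    is_terminated = torch.tensor([False, False])
    is_truncated = torch.tensor([False, False])

    # step k: the environment reports that the second env truncated
    terminated = torch.tensor([False, False])
    truncated = torch.tensor([False, True])

    is_terminated |= (terminated | is_truncated)   # still [False, False]: the truncated env runs one more step
    is_truncated |= truncated                      # [False, True]

    # step k + 1: this extra step only exists to produce a bootstrap value
    terminated = torch.tensor([False, False])

    is_terminated |= (terminated | is_truncated)   # [False, True]: the truncated env now stops
    # once step_index reaches max_timesteps, any env still running is likewise
    # marked truncated via `is_truncated |= ~is_terminated`
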
@@ -2228,10 +2255,7 @@ class DynamicsWorldModel(Module):
  video = cat((video, next_frame), dim = 2)
  rewards = safe_cat((rewards, reward), dim = 1)

- # early break out if all terminated
-
- if is_terminated.all():
- break
+ acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)

  # package up one experience for learning

@@ -2244,9 +2268,10 @@ class DynamicsWorldModel(Module):
  actions = (discrete_actions, continuous_actions),
  log_probs = (discrete_log_probs, continuous_log_probs),
  values = values,
+ agent_embed = acc_agent_embed if store_agent_embed else None,
  step_size = step_size,
  agent_index = agent_index,
- is_truncated = ~is_terminated,
+ is_truncated = is_truncated,
  lens = episode_lens,
  is_from_world_model = False
  )
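
From the caller's side the interaction loop keeps its previous shape and gains one keyword; a hedged usage sketch (model and environment construction elided here, following the test added at the bottom of this diff):

    # `world_model_and_policy` (a DynamicsWorldModel) and `mock_env` are assumed to be
    # constructed as in the tests below; only the new keyword and fields are shown
    experience = world_model_and_policy.interact_with_env(
        mock_env,
        max_timesteps = 8,
        env_is_vectorized = False,
        store_agent_embed = True   # new in this diff: accumulate per-step agent embeddings
    )

    experience.agent_embed    # Tensor when store_agent_embed = True, otherwise None
    experience.is_truncated   # per-env truncation flags, no longer derived as ~is_terminated
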
@@ -2473,6 +2498,7 @@ class DynamicsWorldModel(Module):
  return_agent_actions = False,
  return_log_probs_and_values = False,
  return_time_kv_cache = False,
+ store_agent_embed = False

  ): # (b t n d) | (b c t h w)

@@ -2525,6 +2551,10 @@ class DynamicsWorldModel(Module):
  decoded_continuous_log_probs = None
  decoded_values = None

+ # maybe store agent embed
+
+ acc_agent_embed = None
+
  # maybe return rewards

  decoded_rewards = None

@@ -2633,6 +2663,10 @@ class DynamicsWorldModel(Module):

  decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)

+ # maybe store agent embed
+
+ acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
  # decode the agent actions if needed

  if return_agent_actions:

@@ -2729,6 +2763,7 @@ class DynamicsWorldModel(Module):
  latents = latents,
  video = video,
  proprio = proprio if has_proprio else None,
+ agent_embed = acc_agent_embed if store_agent_embed else None,
  step_size = step_size,
  agent_index = agent_index,
  lens = experience_lens,
@@ -22,7 +22,9 @@ class MockEnv(Module):
  num_envs = 1,
  vectorized = False,
  terminate_after_step = None,
- rand_terminate_prob = 0.05
+ rand_terminate_prob = 0.05,
+ can_truncate = False,
+ rand_truncate_prob = 0.05,
  ):
  super().__init__()
  self.image_shape = image_shape

@@ -32,12 +34,15 @@ class MockEnv(Module):
  self.vectorized = vectorized
  assert not (vectorized and num_envs == 1)

- # mocking termination
+ # mocking termination and truncation

  self.can_terminate = exists(terminate_after_step)
  self.terminate_after_step = terminate_after_step
  self.rand_terminate_prob = rand_terminate_prob

+ self.can_truncate = can_truncate
+ self.rand_truncate_prob = rand_truncate_prob
+
  self.register_buffer('_step', tensor(0))

  def get_random_state(self):

@@ -72,14 +77,21 @@ class MockEnv(Module):

  out = (state, reward)

+
  if self.can_terminate:
- terminate = (
- (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
- (self._step > self.terminate_after_step)
- )
+ shape = (self.num_envs,) if self.vectorized else (1,)
+ valid_step = self._step > self.terminate_after_step
+
+ terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step

  out = (*out, terminate)

+ # maybe truncation
+
+ if self.can_truncate:
+ truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+ out = (*out, truncate)
+
  self._step.add_(1)

  return out
@@ -1,6 +1,6 @@
  [project]
  name = "dreamer4"
- version = "0.0.82"
+ version = "0.0.84"
  description = "Dreamer 4"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -613,10 +613,14 @@ def test_cache_generate():
  @param('vectorized', (False, True))
  @param('use_signed_advantage', (False, True))
  @param('env_can_terminate', (False, True))
+ @param('env_can_truncate', (False, True))
+ @param('store_agent_embed', (False, True))
  def test_online_rl(
  vectorized,
  use_signed_advantage,
- env_can_terminate
+ env_can_terminate,
+ env_can_truncate,
+ store_agent_embed
  ):
  from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

@@ -656,16 +660,20 @@ def test_online_rl(
  vectorized = vectorized,
  num_envs = 4,
  terminate_after_step = 2 if env_can_terminate else None,
+ can_truncate = env_can_truncate,
  rand_terminate_prob = 0.1
  )

  # manually

- one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
- another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+ one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+ another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

  combined_experience = combine_experiences([one_experience, another_experience])

+ if store_agent_embed:
+ assert exists(combined_experience.agent_embed)
+
  actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)

  actor_loss.backward()