dreamer4 0.0.82__py3-none-any.whl → 0.0.83__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic. Click here for more details.

dreamer4/dreamer4.py CHANGED
@@ -2133,13 +2133,18 @@ class DynamicsWorldModel(Module):
2133
2133
  # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
2134
2134
 
2135
2135
  is_terminated = full((batch,), False, device = device)
2136
+ is_truncated = full((batch,), False, device = device)
2137
+
2136
2138
  episode_lens = full((batch,), 0, device = device)
2137
2139
 
2138
2140
  # maybe time kv cache
2139
2141
 
2140
2142
  time_kv_cache = None
2141
2143
 
2142
- for i in range(max_timesteps + 1):
2144
+ step_index = 0
2145
+
2146
+ while not is_terminated.all():
2147
+ step_index += 1
2143
2148
 
2144
2149
  latents = self.video_tokenizer(video, return_latents = True)
2145
2150
 
@@ -2201,10 +2206,15 @@ class DynamicsWorldModel(Module):
2201
2206
 
2202
2207
  if len(env_step_out) == 2:
2203
2208
  next_frame, reward = env_step_out
2204
- terminate = full((batch,), False)
2209
+ terminated = full((batch,), False)
2210
+ truncated = full((batch,), False)
2205
2211
 
2206
2212
  elif len(env_step_out) == 3:
2207
- next_frame, reward, terminate = env_step_out
2213
+ next_frame, reward, terminated = env_step_out
2214
+ truncated = full((batch,), False)
2215
+
2216
+ elif len(env_step_out) == 4:
2217
+ next_frame, reward, terminated, truncated = env_step_out
2208
2218
 
2209
2219
  # update episode lens
2210
2220
 
@@ -2212,7 +2222,20 @@ class DynamicsWorldModel(Module):
2212
2222
 
2213
2223
  # update `is_terminated`
2214
2224
 
2215
- is_terminated |= terminate
2225
+ # (1) - environment says it is terminated
2226
+ # (2) - previous step is truncated (this step is for bootstrap value)
2227
+
2228
+ is_terminated |= (terminated | is_truncated)
2229
+
2230
+ # update `is_truncated`
2231
+
2232
+ if step_index <= max_timesteps:
2233
+ is_truncated |= truncated
2234
+
2235
+ if step_index == max_timesteps:
2236
+ # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
2237
+
2238
+ is_truncated |= ~is_terminated
2216
2239
 
2217
2240
  # batch and time dimension
2218
2241
 
@@ -2228,11 +2251,6 @@ class DynamicsWorldModel(Module):
2228
2251
  video = cat((video, next_frame), dim = 2)
2229
2252
  rewards = safe_cat((rewards, reward), dim = 1)
2230
2253
 
2231
- # early break out if all terminated
2232
-
2233
- if is_terminated.all():
2234
- break
2235
-
2236
2254
  # package up one experience for learning
2237
2255
 
2238
2256
  batch, device = latents.shape[0], latents.device
@@ -2246,7 +2264,7 @@ class DynamicsWorldModel(Module):
2246
2264
  values = values,
2247
2265
  step_size = step_size,
2248
2266
  agent_index = agent_index,
2249
- is_truncated = ~is_terminated,
2267
+ is_truncated = is_truncated,
2250
2268
  lens = episode_lens,
2251
2269
  is_from_world_model = False
2252
2270
  )
dreamer4/mocks.py CHANGED
@@ -22,7 +22,9 @@ class MockEnv(Module):
22
22
  num_envs = 1,
23
23
  vectorized = False,
24
24
  terminate_after_step = None,
25
- rand_terminate_prob = 0.05
25
+ rand_terminate_prob = 0.05,
26
+ can_truncate = False,
27
+ rand_truncate_prob = 0.05,
26
28
  ):
27
29
  super().__init__()
28
30
  self.image_shape = image_shape
@@ -32,12 +34,15 @@ class MockEnv(Module):
32
34
  self.vectorized = vectorized
33
35
  assert not (vectorized and num_envs == 1)
34
36
 
35
- # mocking termination
37
+ # mocking termination and truncation
36
38
 
37
39
  self.can_terminate = exists(terminate_after_step)
38
40
  self.terminate_after_step = terminate_after_step
39
41
  self.rand_terminate_prob = rand_terminate_prob
40
42
 
43
+ self.can_truncate = can_truncate
44
+ self.rand_truncate_prob = rand_truncate_prob
45
+
41
46
  self.register_buffer('_step', tensor(0))
42
47
 
43
48
  def get_random_state(self):
@@ -72,14 +77,21 @@ class MockEnv(Module):
72
77
 
73
78
  out = (state, reward)
74
79
 
80
+
75
81
  if self.can_terminate:
76
- terminate = (
77
- (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
78
- (self._step > self.terminate_after_step)
79
- )
82
+ shape = (self.num_envs,) if self.vectorized else (1,)
83
+ valid_step = self._step > self.terminate_after_step
84
+
85
+ terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
80
86
 
81
87
  out = (*out, terminate)
82
88
 
89
+ # maybe truncation
90
+
91
+ if self.can_truncate:
92
+ truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
93
+ out = (*out, truncate)
94
+
83
95
  self._step.add_(1)
84
96
 
85
97
  return out
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.82
3
+ Version: 0.0.83
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -0,0 +1,8 @@
1
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
+ dreamer4/dreamer4.py,sha256=t9g-fJVRPn4nefpNiKwqkIkKZtIeJAn5V4ruNCJ5a9A,111265
3
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
+ dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
+ dreamer4-0.0.83.dist-info/METADATA,sha256=pYIw5Pj41JhfJ4Hp9AfegXrxnhIZ9Fk98kF7lY3nAZk,3065
6
+ dreamer4-0.0.83.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ dreamer4-0.0.83.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ dreamer4-0.0.83.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
- dreamer4/dreamer4.py,sha256=1WeHCCJO-wa85Ki9uP8FR70pFTro46LF4WyFMxSD90I,110556
3
- dreamer4/mocks.py,sha256=S1kiENV3kHM1L6pBOLDqLVCZJg7ZydEmagiNb8sFIXc,2077
4
- dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
- dreamer4-0.0.82.dist-info/METADATA,sha256=by1KbrcKVy0epdLDLAiSN1MXR_FBgyD40Ol_Mn5iZNM,3065
6
- dreamer4-0.0.82.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- dreamer4-0.0.82.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- dreamer4-0.0.82.dist-info/RECORD,,