dreamer4 0.0.82__py3-none-any.whl → 0.0.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- dreamer4/dreamer4.py +28 -10
- dreamer4/mocks.py +18 -6
- {dreamer4-0.0.82.dist-info → dreamer4-0.0.83.dist-info}/METADATA +1 -1
- dreamer4-0.0.83.dist-info/RECORD +8 -0
- dreamer4-0.0.82.dist-info/RECORD +0 -8
- {dreamer4-0.0.82.dist-info → dreamer4-0.0.83.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.82.dist-info → dreamer4-0.0.83.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
|
@@ -2133,13 +2133,18 @@ class DynamicsWorldModel(Module):
|
|
|
2133
2133
|
# keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
|
|
2134
2134
|
|
|
2135
2135
|
is_terminated = full((batch,), False, device = device)
|
|
2136
|
+
is_truncated = full((batch,), False, device = device)
|
|
2137
|
+
|
|
2136
2138
|
episode_lens = full((batch,), 0, device = device)
|
|
2137
2139
|
|
|
2138
2140
|
# maybe time kv cache
|
|
2139
2141
|
|
|
2140
2142
|
time_kv_cache = None
|
|
2141
2143
|
|
|
2142
|
-
|
|
2144
|
+
step_index = 0
|
|
2145
|
+
|
|
2146
|
+
while not is_terminated.all():
|
|
2147
|
+
step_index += 1
|
|
2143
2148
|
|
|
2144
2149
|
latents = self.video_tokenizer(video, return_latents = True)
|
|
2145
2150
|
|
|
@@ -2201,10 +2206,15 @@ class DynamicsWorldModel(Module):
|
|
|
2201
2206
|
|
|
2202
2207
|
if len(env_step_out) == 2:
|
|
2203
2208
|
next_frame, reward = env_step_out
|
|
2204
|
-
|
|
2209
|
+
terminated = full((batch,), False)
|
|
2210
|
+
truncated = full((batch,), False)
|
|
2205
2211
|
|
|
2206
2212
|
elif len(env_step_out) == 3:
|
|
2207
|
-
next_frame, reward,
|
|
2213
|
+
next_frame, reward, terminated = env_step_out
|
|
2214
|
+
truncated = full((batch,), False)
|
|
2215
|
+
|
|
2216
|
+
elif len(env_step_out) == 4:
|
|
2217
|
+
next_frame, reward, terminated, truncated = env_step_out
|
|
2208
2218
|
|
|
2209
2219
|
# update episode lens
|
|
2210
2220
|
|
|
@@ -2212,7 +2222,20 @@ class DynamicsWorldModel(Module):
|
|
|
2212
2222
|
|
|
2213
2223
|
# update `is_terminated`
|
|
2214
2224
|
|
|
2215
|
-
|
|
2225
|
+
# (1) - environment says it is terminated
|
|
2226
|
+
# (2) - previous step is truncated (this step is for bootstrap value)
|
|
2227
|
+
|
|
2228
|
+
is_terminated |= (terminated | is_truncated)
|
|
2229
|
+
|
|
2230
|
+
# update `is_truncated`
|
|
2231
|
+
|
|
2232
|
+
if step_index <= max_timesteps:
|
|
2233
|
+
is_truncated |= truncated
|
|
2234
|
+
|
|
2235
|
+
if step_index == max_timesteps:
|
|
2236
|
+
# if the step index is at the max time step allowed, set the truncated flag, if not already terminated
|
|
2237
|
+
|
|
2238
|
+
is_truncated |= ~is_terminated
|
|
2216
2239
|
|
|
2217
2240
|
# batch and time dimension
|
|
2218
2241
|
|
|
@@ -2228,11 +2251,6 @@ class DynamicsWorldModel(Module):
|
|
|
2228
2251
|
video = cat((video, next_frame), dim = 2)
|
|
2229
2252
|
rewards = safe_cat((rewards, reward), dim = 1)
|
|
2230
2253
|
|
|
2231
|
-
# early break out if all terminated
|
|
2232
|
-
|
|
2233
|
-
if is_terminated.all():
|
|
2234
|
-
break
|
|
2235
|
-
|
|
2236
2254
|
# package up one experience for learning
|
|
2237
2255
|
|
|
2238
2256
|
batch, device = latents.shape[0], latents.device
|
|
@@ -2246,7 +2264,7 @@ class DynamicsWorldModel(Module):
|
|
|
2246
2264
|
values = values,
|
|
2247
2265
|
step_size = step_size,
|
|
2248
2266
|
agent_index = agent_index,
|
|
2249
|
-
is_truncated =
|
|
2267
|
+
is_truncated = is_truncated,
|
|
2250
2268
|
lens = episode_lens,
|
|
2251
2269
|
is_from_world_model = False
|
|
2252
2270
|
)
|
dreamer4/mocks.py
CHANGED
|
@@ -22,7 +22,9 @@ class MockEnv(Module):
|
|
|
22
22
|
num_envs = 1,
|
|
23
23
|
vectorized = False,
|
|
24
24
|
terminate_after_step = None,
|
|
25
|
-
rand_terminate_prob = 0.05
|
|
25
|
+
rand_terminate_prob = 0.05,
|
|
26
|
+
can_truncate = False,
|
|
27
|
+
rand_truncate_prob = 0.05,
|
|
26
28
|
):
|
|
27
29
|
super().__init__()
|
|
28
30
|
self.image_shape = image_shape
|
|
@@ -32,12 +34,15 @@ class MockEnv(Module):
|
|
|
32
34
|
self.vectorized = vectorized
|
|
33
35
|
assert not (vectorized and num_envs == 1)
|
|
34
36
|
|
|
35
|
-
# mocking termination
|
|
37
|
+
# mocking termination and truncation
|
|
36
38
|
|
|
37
39
|
self.can_terminate = exists(terminate_after_step)
|
|
38
40
|
self.terminate_after_step = terminate_after_step
|
|
39
41
|
self.rand_terminate_prob = rand_terminate_prob
|
|
40
42
|
|
|
43
|
+
self.can_truncate = can_truncate
|
|
44
|
+
self.rand_truncate_prob = rand_truncate_prob
|
|
45
|
+
|
|
41
46
|
self.register_buffer('_step', tensor(0))
|
|
42
47
|
|
|
43
48
|
def get_random_state(self):
|
|
@@ -72,14 +77,21 @@ class MockEnv(Module):
|
|
|
72
77
|
|
|
73
78
|
out = (state, reward)
|
|
74
79
|
|
|
80
|
+
|
|
75
81
|
if self.can_terminate:
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
)
|
|
82
|
+
shape = (self.num_envs,) if self.vectorized else (1,)
|
|
83
|
+
valid_step = self._step > self.terminate_after_step
|
|
84
|
+
|
|
85
|
+
terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
|
|
80
86
|
|
|
81
87
|
out = (*out, terminate)
|
|
82
88
|
|
|
89
|
+
# maybe truncation
|
|
90
|
+
|
|
91
|
+
if self.can_truncate:
|
|
92
|
+
truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
|
|
93
|
+
out = (*out, truncate)
|
|
94
|
+
|
|
83
95
|
self._step.add_(1)
|
|
84
96
|
|
|
85
97
|
return out
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
+
dreamer4/dreamer4.py,sha256=t9g-fJVRPn4nefpNiKwqkIkKZtIeJAn5V4ruNCJ5a9A,111265
|
|
3
|
+
dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
|
|
4
|
+
dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
|
|
5
|
+
dreamer4-0.0.83.dist-info/METADATA,sha256=pYIw5Pj41JhfJ4Hp9AfegXrxnhIZ9Fk98kF7lY3nAZk,3065
|
|
6
|
+
dreamer4-0.0.83.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
dreamer4-0.0.83.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
dreamer4-0.0.83.dist-info/RECORD,,
|
dreamer4-0.0.82.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
-
dreamer4/dreamer4.py,sha256=1WeHCCJO-wa85Ki9uP8FR70pFTro46LF4WyFMxSD90I,110556
|
|
3
|
-
dreamer4/mocks.py,sha256=S1kiENV3kHM1L6pBOLDqLVCZJg7ZydEmagiNb8sFIXc,2077
|
|
4
|
-
dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
|
|
5
|
-
dreamer4-0.0.82.dist-info/METADATA,sha256=by1KbrcKVy0epdLDLAiSN1MXR_FBgyD40Ol_Mn5iZNM,3065
|
|
6
|
-
dreamer4-0.0.82.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
dreamer4-0.0.82.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
dreamer4-0.0.82.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|