evolutionary-policy-optimization 0.0.62__py3-none-any.whl → 0.0.64__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +81 -55
- evolutionary_policy_optimization/mock_env.py +1 -1
- {evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.64.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.64.dist-info/RECORD +9 -0
- evolutionary_policy_optimization-0.0.62.dist-info/RECORD +0 -9
- {evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.64.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.64.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -3,6 +3,7 @@ from typing import Callable

 from pathlib import Path
 from math import ceil
+from itertools import product
 from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
@@ -1067,7 +1068,8 @@ class EPO(Module):
         agent: Agent,
         episodes_per_latent,
         max_episode_length,
-        action_sample_temperature = 1
+        action_sample_temperature = 1.,
+        fix_environ_across_latents = True
     ):
         super().__init__()
         self.agent = agent
@@ -1076,6 +1078,7 @@ class EPO(Module):
         self.num_latents = agent.latent_gene_pool.num_latents
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
+        self.fix_environ_across_latents = fix_environ_across_latents

         self.register_buffer('dummy', tensor(0))

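Taken together, these two __init__ hunks add a `fix_environ_across_latents` switch that is stored on the module and, as the `forward` hunk further down shows, can still be overridden per call through the package's `default` helper. A minimal standalone sketch of that constructor-default-plus-per-call-override pattern (the `default` function and `RolloutConfig` class below are illustrative stand-ins, not part of the package):

def default(value, fallback):
    # stand-in for the package's helper: use the explicit value when given, else the fallback
    return value if value is not None else fallback

class RolloutConfig:
    def __init__(self, fix_environ_across_latents = True):
        self.fix_environ_across_latents = fix_environ_across_latents

    def forward(self, fix_environ_across_latents = None):
        # a per-call argument wins; otherwise fall back to the value chosen at construction
        return default(fix_environ_across_latents, self.fix_environ_across_latents)

cfg = RolloutConfig(fix_environ_across_latents = True)
assert cfg.forward() is True        # no override -> constructor default
assert cfg.forward(False) is False  # explicit override for a single rollout call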
@@ -1083,31 +1086,60 @@ class EPO(Module):
     def device(self):
         return self.dummy.device

-    def
+    def rollouts_for_machine(
+        self,
+        fix_environ_across_latents = False
+    ): # -> (<latent_id>, <episode_id>, <maybe synced env seed>) for the machine
+
         num_latents = self.num_latents
+        episodes = self.episodes_per_latent
+        num_latent_episodes = num_latents * episodes
+
+        # if fixing environment across latents, compute all the environment seeds upfront for simplicity
+
+        environment_seeds = None
+
+        if fix_environ_across_latents:
+            environment_seeds = torch.randint(0, int(1e6), (episodes,))
+
+            if is_distributed():
+                dist.all_reduce(environment_seeds) # reduce sum as a way to synchronize. it's fine
+
+        # get number of machines, and this machine id

         world_size, rank = get_world_and_rank()

-        assert
-
+        assert num_latent_episodes >= world_size, f'number of ({self.num_latents} latents x {self.episodes_per_latent} episodes) ({num_latent_episodes}) must be greater than world size ({world_size}) for now'
+
+        latent_episode_permutations = list(product(range(num_latents), range(episodes)))

-
+        num_rollouts_per_machine = ceil(num_latent_episodes / world_size)

-        for i in range(
-
+        for i in range(num_rollouts_per_machine):
+            rollout_id = rank * num_rollouts_per_machine + i

-            if
+            if rollout_id >= num_latent_episodes:
                 continue

-
+            latent_id, episode_id = latent_episode_permutations[rollout_id]
+
+            # maybe synchronized environment seed
+
+            maybe_seed = None
+            if fix_environ_across_latents:
+                maybe_seed = environment_seeds[episode_id]
+
+            yield latent_id, episode_id, maybe_seed

     @torch.no_grad()
     def forward(
         self,
         env,
-
+        fix_environ_across_latents = None
     ) -> MemoriesAndCumulativeRewards:

+        fix_environ_across_latents = default(fix_environ_across_latents, self.fix_environ_across_latents)
+
         self.agent.eval()

         invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
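The new `rollouts_for_machine` generator is the core of this hunk: every (latent, episode) pair is enumerated with `itertools.product`, contiguous chunks of that list are handed to each rank, and, when `fix_environ_across_latents` is set, one seed per episode id is shared so every latent faces the same environment. A dependency-free sketch of that scheduling (the `rollouts_for_rank` function below is illustrative only, and plain `random` integers stand in for the `torch.randint` seeds synchronized with `dist.all_reduce`):

from itertools import product
from math import ceil
from random import randint

def rollouts_for_rank(rank, world_size, num_latents, episodes, environment_seeds = None):
    # enumerate every (latent_id, episode_id) pair, as the new generator does with product(...)
    pairs = list(product(range(num_latents), range(episodes)))
    assert len(pairs) >= world_size, 'need at least one rollout per machine'

    num_rollouts_per_machine = ceil(len(pairs) / world_size)

    for i in range(num_rollouts_per_machine):
        rollout_id = rank * num_rollouts_per_machine + i

        if rollout_id >= len(pairs):
            continue  # trailing ranks may receive fewer rollouts

        latent_id, episode_id = pairs[rollout_id]

        # same seed for a given episode id across all latents, so fitness differences
        # come from the latents rather than from environment randomness
        maybe_seed = environment_seeds[episode_id] if environment_seeds is not None else None
        yield latent_id, episode_id, maybe_seed

# 4 latents x 2 episodes spread over 3 machines, with one shared seed per episode id
shared_seeds = [randint(0, int(1e6)) for _ in range(2)]
for rank in range(3):
    print(rank, list(rollouts_for_rank(rank, 3, 4, 2, shared_seeds)))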
@@ -1116,79 +1148,73 @@ class EPO(Module):

         rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))

-
+        rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)

-        for episode_id in tqdm(
+        for latent_id, episode_id, maybe_seed in tqdm(rollout_gen, desc = 'rollout'):

-
+            time = 0

-            #
+            # initial state

-
+            reset_kwargs = dict()

-            if
-                seed =
-                env_reset_kwargs = dict(seed = seed)
+            if fix_environ_across_latents:
+                reset_kwargs.update(seed = maybe_seed)

-
+            state = env.reset(**reset_kwargs)

-
-            time = 0
+            # get latent from pool

-
+            latent = self.agent.latent_gene_pool(latent_id = latent_id)

-
+            # until maximum episode length

-
+            done = tensor(False)

-
+            while time < self.max_episode_length and not done:

-                #
+                # sample action

-
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)

-
+                # values

-
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)

-
+                # get the next state, action, and reward

-
+                state, reward, done = env(action)

-
+                # update cumulative rewards per latent, to be used as default fitness score

-
+                rewards_per_latent_episode[latent_id, episode_id] += reward
+
+                # store memories

-
-
-
-
-
-
-
-
-
-
-                state,
-                tensor(latent_id),
-                action,
-                log_prob,
-                reward,
-                value,
-                done
-            )
+                memory = Memory(
+                    tensor(episode_id),
+                    state,
+                    tensor(latent_id),
+                    action,
+                    log_prob,
+                    reward,
+                    value,
+                    done
+                )

-
+                memories.append(memory)

-
+                time += 1

-
+            if not done:
+                # add bootstrap value if truncated

                 next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)

                 memory_for_gae = memory._replace(
                     episode_id = invalid_episode,
-                    value = next_value
+                    value = next_value,
+                    done = tensor(True)
                 )

                 memories.append(memory_for_gae)
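The last change in this hunk is worth pausing on: when the while-loop exits because of `max_episode_length` rather than a real terminal state, an extra memory is appended carrying the critic's estimate for the final state (with `done` forced to `True`), so generalized advantage estimation can bootstrap past the truncation instead of treating the return after the cutoff as zero. A framework-free sketch of that effect, using a textbook GAE recursion rather than the package's own implementation (`gae` and all the toy numbers below are illustrative):

def gae(rewards, values, dones, next_value, gamma = 0.99, lam = 0.95):
    # standard GAE recursion; `next_value` is the bootstrap for the state after the last stored step
    values = values + [next_value]
    advantages = []
    running = 0.
    for t in reversed(range(len(rewards))):
        not_done = 1. - float(dones[t])
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        advantages.append(running)
    return list(reversed(advantages))

# an episode truncated after 3 steps: bootstrapping with the critic's estimate (5.0)
# keeps credit for rewards expected beyond the cutoff; a zero bootstrap discards it
print(gae([1., 1., 1.], [0.5, 0.6, 0.7], [False, False, False], next_value = 5.0))
print(gae([1., 1., 1.], [0.5, 0.6, 0.7], [False, False, False], next_value = 0.0))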
{evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.64.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.62
+Version: 0.0.64
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.64.dist-info/RECORD

@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=0_jC9Tbl6FiscLHklvTKtuQTwZL8egqFKW-4JUxxwvw,37001
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=gvATGA51Ym5sf3jiR2VmlpjiCcT7KCDDY_SrR-MEwsU,941
+evolutionary_policy_optimization-0.0.64.dist-info/METADATA,sha256=vWdnTe2a86wTenEh29TNJlYEjD8A5CPtsyylxh4XsE0,6220
+evolutionary_policy_optimization-0.0.64.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.64.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.64.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.62.dist-info/RECORD

@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=lWhHpsfq6vpri6yeDXSTLRMKGPwl0kt3klh0fVaInSs,35921
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.62.dist-info/METADATA,sha256=oqJyUOXJwHrdf6JCVKPfOmhGJbXgqOmPWN_46l0JtWs,6220
-evolutionary_policy_optimization-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.62.dist-info/RECORD,,
The remaining files (WHEEL and licenses/LICENSE) have no content changes; only their dist-info directory was renamed from 0.0.62 to 0.0.64.