evolutionary-policy-optimization 0.0.62__py3-none-any.whl → 0.0.63__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +79 -59
- evolutionary_policy_optimization/mock_env.py +1 -1
- {evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.63.dist-info/RECORD +9 -0
- evolutionary_policy_optimization-0.0.62.dist-info/RECORD +0 -9
- {evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -3,6 +3,7 @@ from typing import Callable
 
 from pathlib import Path
 from math import ceil
+from itertools import product
 from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
@@ -1083,29 +1084,56 @@ class EPO(Module):
     def device(self):
         return self.dummy.device
 
-    def
+    def rollouts_for_machine(
+        self,
+        fix_environ_across_latents = False
+    ): # -> (<latent_id>, <episode_id>, <maybe synced env seed>) for the machine
+
         num_latents = self.num_latents
+        episodes = self.episodes_per_latent
+        num_latent_episodes = num_latents * episodes
+
+        # if fixing environment across latents, compute all the environment seeds upfront for simplicity
+
+        environment_seeds = None
+
+        if fix_environ_across_latents:
+            environment_seeds = torch.randint(0, int(1e6), (episodes,))
+
+            if is_distributed():
+                dist.all_reduce(environment_seeds) # reduce sum as a way to synchronize. it's fine
+
+        # get number of machines, and this machine id
 
         world_size, rank = get_world_and_rank()
 
-        assert
-
+        assert num_latent_episodes >= world_size, f'number of ({self.num_latents} latents x {self.episodes_per_latent} episodes) ({num_latent_episodes}) must be greater than world size ({world_size}) for now'
+
+        latent_episode_permutations = list(product(range(num_latents), range(episodes)))
 
-
+        num_rollouts_per_machine = ceil(num_latent_episodes / world_size)
 
-        for i in range(
-
+        for i in range(num_rollouts_per_machine):
+            rollout_id = rank * num_rollouts_per_machine + i
 
-            if
+            if rollout_id >= num_latent_episodes:
                 continue
 
-
+            latent_id, episode_id = latent_episode_permutations[rollout_id]
+
+            # maybe synchronized environment seed
+
+            maybe_seed = None
+            if fix_environ_across_latents:
+                maybe_seed = environment_seeds[episode_id]
+
+            yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
     def forward(
         self,
         env,
-
+        fix_environ_across_latents = True
     ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
@@ -1116,82 +1144,74 @@ class EPO(Module):
 
         rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
 
-
-
-        for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+        rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
-
+        for latent_id, episode_id, maybe_seed in tqdm(rollout_gen, desc = 'rollout'):
 
-
+            time = 0
 
-
+            # initial state
 
-
-            seed = maybe_sync_seed(device = self.device)
-            env_reset_kwargs = dict(seed = seed)
+            reset_kwargs = dict()
 
-
+            if fix_environ_across_latents:
+                reset_kwargs.update(seed = maybe_seed)
 
-
-            time = 0
+            state = env.reset(**reset_kwargs)
 
-
+            # get latent from pool
 
-
+            latent = self.agent.latent_gene_pool(latent_id = latent_id)
 
-
+            # until maximum episode length
 
-
+            done = tensor(False)
 
-
+            while time < self.max_episode_length and not done:
 
-
+                # sample action
 
-
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
 
-
+                # values
 
-
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
-
+                # get the next state, action, and reward
 
-
+                state, reward, done = env(action)
 
-
+                # update cumulative rewards per latent, to be used as default fitness score
 
-
+                rewards_per_latent_episode[latent_id, episode_id] += reward
+
+                # store memories
 
-
-
-
-
-
-
-
-
-
-
-                    action,
-                    log_prob,
-                    reward,
-                    value,
-                    done
-                )
+                memory = Memory(
+                    tensor(episode_id),
+                    state,
+                    tensor(latent_id),
+                    action,
+                    log_prob,
+                    reward,
+                    value,
+                    done
+                )
 
-
+                memories.append(memory)
 
-
+                time += 1
 
-
+            # need the final next value for GAE, iiuc
 
-
+            next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
-
-
-
-
+            memory_for_gae = memory._replace(
+                episode_id = invalid_episode,
+                value = next_value
+            )
 
-
+            memories.append(memory_for_gae)
 
         return MemoriesAndCumulativeRewards(
             memories = memories,
{evolutionary_policy_optimization-0.0.62.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.62
+Version: 0.0.63
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.63.dist-info/RECORD

@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=DSG2fYWLk0cyHhfoiwqmSzh2TBOWhz25sD1oWIM5p1k,36695
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=gvATGA51Ym5sf3jiR2VmlpjiCcT7KCDDY_SrR-MEwsU,941
+evolutionary_policy_optimization-0.0.63.dist-info/METADATA,sha256=X2FKT8WJ9T1t0ydEdtxrJsJGXY1ubfvydQSykv2G03M,6220
+evolutionary_policy_optimization-0.0.63.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.63.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.63.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.62.dist-info/RECORD

@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=lWhHpsfq6vpri6yeDXSTLRMKGPwl0kt3klh0fVaInSs,35921
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.62.dist-info/METADATA,sha256=oqJyUOXJwHrdf6JCVKPfOmhGJbXgqOmPWN_46l0JtWs,6220
-evolutionary_policy_optimization-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.62.dist-info/RECORD,,
File without changes