evolutionary-policy-optimization 0.0.62__py3-none-any.whl → 0.0.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ from typing import Callable
 
 from pathlib import Path
 from math import ceil
+from itertools import product
 from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
@@ -1083,29 +1084,56 @@ class EPO(Module):
     def device(self):
         return self.dummy.device
 
-    def latents_for_machine(self):
+    def rollouts_for_machine(
+        self,
+        fix_environ_across_latents = False
+    ): # -> (<latent_id>, <episode_id>, <maybe synced env seed>) for the machine
+
         num_latents = self.num_latents
+        episodes = self.episodes_per_latent
+        num_latent_episodes = num_latents * episodes
+
+        # if fixing environment across latents, compute all the environment seeds upfront for simplicity
+
+        environment_seeds = None
+
+        if fix_environ_across_latents:
+            environment_seeds = torch.randint(0, int(1e6), (episodes,))
+
+            if is_distributed():
+                dist.all_reduce(environment_seeds) # reduce sum as a way to synchronize. it's fine
+
+        # get number of machines, and this machine id
 
         world_size, rank = get_world_and_rank()
 
-        assert num_latents >= world_size, 'number of latents must be greater than world size for now'
-        assert rank < world_size
+        assert num_latent_episodes >= world_size, f'number of ({self.num_latents} latents x {self.episodes_per_latent} episodes) ({num_latent_episodes}) must be greater than world size ({world_size}) for now'
+
+        latent_episode_permutations = list(product(range(num_latents), range(episodes)))
 
-        num_latents_per_machine = ceil(num_latents / world_size)
+        num_rollouts_per_machine = ceil(num_latent_episodes / world_size)
 
-        for i in range(num_latents_per_machine):
-            latent_id = rank * num_latents_per_machine + i
+        for i in range(num_rollouts_per_machine):
+            rollout_id = rank * num_rollouts_per_machine + i
 
-            if latent_id >= num_latents:
+            if rollout_id >= num_latent_episodes:
                 continue
 
-            yield i
+            latent_id, episode_id = latent_episode_permutations[rollout_id]
+
+            # maybe synchronized environment seed
+
+            maybe_seed = None
+            if fix_environ_across_latents:
+                maybe_seed = environment_seeds[episode_id]
+
+            yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
     def forward(
         self,
         env,
-        fix_seed_across_latents = True
+        fix_environ_across_latents = True
    ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
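
Note: the new rollouts_for_machine synchronizes environment seeds by having every rank draw the same number of seeds and then summing them with an all-reduce, so all ranks end up with identical values. A minimal standalone sketch of that trick, assuming only torch and torch.distributed (the helper name synced_episode_seeds is illustrative, not part of the package):

    import torch
    import torch.distributed as dist

    def synced_episode_seeds(num_episodes: int) -> torch.Tensor:
        # each rank draws its own random episode seeds ...
        seeds = torch.randint(0, int(1e6), (num_episodes,))

        # ... then a sum all-reduce makes the result identical on every rank
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(seeds)

        return seeds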
@@ -1116,82 +1144,74 @@ class EPO(Module):
 
         rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
 
-        latent_ids_gen = self.latents_for_machine()
-
-        for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+        rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
-            maybe_barrier()
+        for latent_id, episode_id, maybe_seed in tqdm(rollout_gen, desc = 'rollout'):
 
-            # maybe fix seed for environment across all latents
+            time = 0
 
-            env_reset_kwargs = dict()
+            # initial state
 
-            if fix_seed_across_latents:
-                seed = maybe_sync_seed(device = self.device)
-                env_reset_kwargs = dict(seed = seed)
+            reset_kwargs = dict()
 
-            # for each latent (on a single machine for now)
+            if fix_environ_across_latents:
+                reset_kwargs.update(seed = maybe_seed)
 
-            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
-                time = 0
+            state = env.reset(**reset_kwargs)
 
-                # initial state
+            # get latent from pool
 
-                state = env.reset(**env_reset_kwargs)
+            latent = self.agent.latent_gene_pool(latent_id = latent_id)
 
-                # get latent from pool
+            # until maximum episode length
 
-                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+            done = tensor(False)
 
-                # until maximum episode length
+            while time < self.max_episode_length and not done:
 
-                done = tensor(False)
+                # sample action
 
-                while time < self.max_episode_length and not done:
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
 
-                    # sample action
+                # values
 
-                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
-                    # values
+                # get the next state, action, and reward
 
-                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                state, reward, done = env(action)
 
-                    # get the next state, action, and reward
+                # update cumulative rewards per latent, to be used as default fitness score
 
-                    state, reward, done = env(action)
+                rewards_per_latent_episode[latent_id, episode_id] += reward
+
+                # store memories
 
-                    # update cumulative rewards per latent, to be used as default fitness score
-
-                    rewards_per_latent_episode[latent_id, episode_id] += reward
-
-                    # store memories
-
-                    memory = Memory(
-                        tensor(episode_id),
-                        state,
-                        tensor(latent_id),
-                        action,
-                        log_prob,
-                        reward,
-                        value,
-                        done
-                    )
+                memory = Memory(
+                    tensor(episode_id),
+                    state,
+                    tensor(latent_id),
+                    action,
+                    log_prob,
+                    reward,
+                    value,
+                    done
+                )
 
-                    memories.append(memory)
+                memories.append(memory)
 
-                    time += 1
+                time += 1
 
-                # need the final next value for GAE, iiuc
+            # need the final next value for GAE, iiuc
 
-                next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+            next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
-                memory_for_gae = memory._replace(
-                    episode_id = invalid_episode,
-                    value = next_value
-                )
+            memory_for_gae = memory._replace(
+                episode_id = invalid_episode,
+                value = next_value
+            )
 
-                memories.append(memory_for_gae)
+            memories.append(memory_for_gae)
 
         return MemoriesAndCumulativeRewards(
             memories = memories,
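
Note: the rollout loop above now iterates over flat (latent, episode) pairs instead of nesting an episode loop around a latent loop; rollouts_for_machine flattens the cartesian product and hands each rank a contiguous slice of it. A minimal sketch of just that partitioning, assuming nothing beyond the standard library (the function name rollouts_for_rank is illustrative, not part of the package):

    from itertools import product
    from math import ceil

    def rollouts_for_rank(num_latents, episodes_per_latent, world_size, rank):
        # flatten latents x episodes into a single list of rollouts
        pairs = list(product(range(num_latents), range(episodes_per_latent)))
        per_machine = ceil(len(pairs) / world_size)

        for i in range(per_machine):
            rollout_id = rank * per_machine + i
            if rollout_id >= len(pairs):
                continue
            yield pairs[rollout_id]  # (latent_id, episode_id)

    # e.g. 3 latents x 2 episodes split over 2 machines
    assert list(rollouts_for_rank(3, 2, 2, 0)) == [(0, 0), (0, 1), (1, 0)]
    assert list(rollouts_for_rank(3, 2, 2, 1)) == [(1, 1), (2, 0), (2, 1)]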
@@ -26,7 +26,7 @@ class Env(Module):
 
     def reset(
         self,
-        seed
+        seed = None
    ):
        state = randn(self.state_shape, device = self.device)
        return state
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.62
+Version: 0.0.63
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=DSG2fYWLk0cyHhfoiwqmSzh2TBOWhz25sD1oWIM5p1k,36695
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=gvATGA51Ym5sf3jiR2VmlpjiCcT7KCDDY_SrR-MEwsU,941
+evolutionary_policy_optimization-0.0.63.dist-info/METADATA,sha256=X2FKT8WJ9T1t0ydEdtxrJsJGXY1ubfvydQSykv2G03M,6220
+evolutionary_policy_optimization-0.0.63.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.63.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.63.dist-info/RECORD,,
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=lWhHpsfq6vpri6yeDXSTLRMKGPwl0kt3klh0fVaInSs,35921
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.62.dist-info/METADATA,sha256=oqJyUOXJwHrdf6JCVKPfOmhGJbXgqOmPWN_46l0JtWs,6220
-evolutionary_policy_optimization-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.62.dist-info/RECORD,,