evolutionary-policy-optimization 0.0.62-py3-none-any.whl → 0.0.64-py3-none-any.whl

This diff covers the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
evolutionary_policy_optimization/epo.py

@@ -3,6 +3,7 @@ from typing import Callable
 
 from pathlib import Path
 from math import ceil
+from itertools import product
 from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
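The only import added in this release is itertools.product. A minimal sketch (mine, not part of the package) of what it contributes here: it enumerates every (latent_id, episode_id) pair, which the new rollouts_for_machine generator further down then slices up across machines. The sizes below are hypothetical, purely for illustration.

from itertools import product

num_latents = 3
episodes_per_latent = 2

# every (latent_id, episode_id) combination, in row-major order
pairs = list(product(range(num_latents), range(episodes_per_latent)))
print(pairs)  # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]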
@@ -1067,7 +1068,8 @@ class EPO(Module):
         agent: Agent,
         episodes_per_latent,
         max_episode_length,
-        action_sample_temperature = 1.
+        action_sample_temperature = 1.,
+        fix_environ_across_latents = True
     ):
         super().__init__()
         self.agent = agent
@@ -1076,6 +1078,7 @@ class EPO(Module):
         self.num_latents = agent.latent_gene_pool.num_latents
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
+        self.fix_environ_across_latents = fix_environ_across_latents
 
         self.register_buffer('dummy', tensor(0))
 
@@ -1083,31 +1086,60 @@ class EPO(Module):
     def device(self):
         return self.dummy.device
 
-    def latents_for_machine(self):
+    def rollouts_for_machine(
+        self,
+        fix_environ_across_latents = False
+    ): # -> (<latent_id>, <episode_id>, <maybe synced env seed>) for the machine
+
         num_latents = self.num_latents
+        episodes = self.episodes_per_latent
+        num_latent_episodes = num_latents * episodes
+
+        # if fixing environment across latents, compute all the environment seeds upfront for simplicity
+
+        environment_seeds = None
+
+        if fix_environ_across_latents:
+            environment_seeds = torch.randint(0, int(1e6), (episodes,))
+
+            if is_distributed():
+                dist.all_reduce(environment_seeds) # reduce sum as a way to synchronize. it's fine
+
+        # get number of machines, and this machine id
 
         world_size, rank = get_world_and_rank()
 
-        assert num_latents >= world_size, 'number of latents must be greater than world size for now'
-        assert rank < world_size
+        assert num_latent_episodes >= world_size, f'number of ({self.num_latents} latents x {self.episodes_per_latent} episodes) ({num_latent_episodes}) must be greater than world size ({world_size}) for now'
+
+        latent_episode_permutations = list(product(range(num_latents), range(episodes)))
 
-        num_latents_per_machine = ceil(num_latents / world_size)
+        num_rollouts_per_machine = ceil(num_latent_episodes / world_size)
 
-        for i in range(num_latents_per_machine):
-            latent_id = rank * num_latents_per_machine + i
+        for i in range(num_rollouts_per_machine):
+            rollout_id = rank * num_rollouts_per_machine + i
 
-            if latent_id >= num_latents:
+            if rollout_id >= num_latent_episodes:
                 continue
 
-            yield i
+            latent_id, episode_id = latent_episode_permutations[rollout_id]
+
+            # maybe synchronized environment seed
+
+            maybe_seed = None
+            if fix_environ_across_latents:
+                maybe_seed = environment_seeds[episode_id]
+
+            yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
     def forward(
         self,
         env,
-        fix_seed_across_latents = True
+        fix_environ_across_latents = None
     ) -> MemoriesAndCumulativeRewards:
 
+        fix_environ_across_latents = default(fix_environ_across_latents, self.fix_environ_across_latents)
+
         self.agent.eval()
 
         invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
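The rewritten generator above flattens the num_latents x episodes_per_latent grid, hands each rank a contiguous block of it, and optionally attaches one environment seed per episode index so every latent plays the same environments. A standalone sketch of that partitioning scheme, with hypothetical sizes and without torch.distributed (the real code synchronizes the seeds across ranks with an all_reduce):

from math import ceil
from itertools import product
from random import randrange

def rollouts_for_rank(rank, world_size, num_latents, episodes, seeds = None):
    # flatten the (latent, episode) grid, row-major
    pairs = list(product(range(num_latents), range(episodes)))
    per_machine = ceil(len(pairs) / world_size)

    for i in range(per_machine):
        rollout_id = rank * per_machine + i

        if rollout_id >= len(pairs):
            continue

        latent_id, episode_id = pairs[rollout_id]

        # same episode index -> same seed, regardless of latent
        maybe_seed = seeds[episode_id] if seeds is not None else None
        yield latent_id, episode_id, maybe_seed

seeds = [randrange(int(1e6)) for _ in range(2)]  # one seed per episode index

for rank in range(3):  # pretend world size of 3
    print(rank, list(rollouts_for_rank(rank, 3, num_latents = 4, episodes = 2, seeds = seeds)))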
@@ -1116,79 +1148,73 @@ class EPO(Module):
 
         rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
 
-        latent_ids_gen = self.latents_for_machine()
+        rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
-        for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+        for latent_id, episode_id, maybe_seed in tqdm(rollout_gen, desc = 'rollout'):
 
-            maybe_barrier()
+            time = 0
 
-            # maybe fix seed for environment across all latents
+            # initial state
 
-            env_reset_kwargs = dict()
+            reset_kwargs = dict()
 
-            if fix_seed_across_latents:
-                seed = maybe_sync_seed(device = self.device)
-                env_reset_kwargs = dict(seed = seed)
+            if fix_environ_across_latents:
+                reset_kwargs.update(seed = maybe_seed)
 
-            # for each latent (on a single machine for now)
+            state = env.reset(**reset_kwargs)
 
-            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
-                time = 0
+            # get latent from pool
 
-                # initial state
+            latent = self.agent.latent_gene_pool(latent_id = latent_id)
 
-                state = env.reset(**env_reset_kwargs)
+            # until maximum episode length
 
-                # get latent from pool
+            done = tensor(False)
 
-                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+            while time < self.max_episode_length and not done:
 
-                # until maximum episode length
+                # sample action
 
-                done = tensor(False)
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
 
-                while time < self.max_episode_length and not done:
+                # values
 
-                    # sample action
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
-                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                # get the next state, action, and reward
 
-                    # values
+                state, reward, done = env(action)
 
-                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                # update cumulative rewards per latent, to be used as default fitness score
 
-                    # get the next state, action, and reward
+                rewards_per_latent_episode[latent_id, episode_id] += reward
+
+                # store memories
 
-                    state, reward, done = env(action)
-
-                    # update cumulative rewards per latent, to be used as default fitness score
-
-                    rewards_per_latent_episode[latent_id, episode_id] += reward
-
-                    # store memories
-
-                    memory = Memory(
-                        tensor(episode_id),
-                        state,
-                        tensor(latent_id),
-                        action,
-                        log_prob,
-                        reward,
-                        value,
-                        done
-                    )
+                memory = Memory(
+                    tensor(episode_id),
+                    state,
+                    tensor(latent_id),
+                    action,
+                    log_prob,
+                    reward,
+                    value,
+                    done
+                )
 
-                    memories.append(memory)
+                memories.append(memory)
 
-                    time += 1
+                time += 1
 
-                # need the final next value for GAE, iiuc
+            if not done:
+                # add bootstrap value if truncated
 
                 next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
                 memory_for_gae = memory._replace(
                     episode_id = invalid_episode,
-                    value = next_value
+                    value = next_value,
+                    done = tensor(True)
                 )
 
                 memories.append(memory_for_gae)
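The reworked loop above appends one extra memory per rollout only when the episode was truncated at max_episode_length rather than terminated; that memory carries the critic's value of the final state (and done = True) so generalized advantage estimation does not discard the last rewards. A generic, textbook GAE sketch, not the package's own implementation, showing where that bootstrap value enters; the tensors are made up for illustration:

import torch

def gae(rewards, values, dones, next_value, gamma = 0.99, lam = 0.95):
    # next_value is the critic's estimate of the state the episode was truncated at
    # (pass 0. if the episode truly terminated)
    advantages = torch.zeros_like(rewards)
    running = 0.

    values_next = torch.cat((values[1:], next_value.view(1)))

    for t in reversed(range(len(rewards))):
        not_done = 1. - dones[t].float()
        delta = rewards[t] + gamma * values_next[t] * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        advantages[t] = running

    return advantages

rewards = torch.tensor([1., 1., 1.])
values  = torch.tensor([0.5, 0.4, 0.3])
dones   = torch.tensor([False, False, False])   # truncated, not terminated
next_value = torch.tensor(0.7)                  # bootstrap: critic value of the last state

print(gae(rewards, values, dones, next_value))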
evolutionary_policy_optimization/mock_env.py

@@ -26,7 +26,7 @@ class Env(Module):
 
     def reset(
         self,
-        seed
+        seed = None
     ):
         state = randn(self.state_shape, device = self.device)
         return state
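Giving the mock environment's reset a default seed matters because the rollout loop now calls env.reset(**reset_kwargs), and reset_kwargs only gains a seed when environments are fixed across latents. A toy stand-in (not the package's Env, whose constructor is not shown in this diff) illustrating both call shapes:

from torch import randn

class ToyEnv:
    def __init__(self, state_shape):
        self.state_shape = state_shape

    def reset(self, seed = None):
        # like the mock, ignore the seed and just sample a random initial state
        return randn(self.state_shape)

env = ToyEnv((4,))

for fix_environ_across_latents, maybe_seed in ((False, None), (True, 42)):
    reset_kwargs = dict()

    if fix_environ_across_latents:
        reset_kwargs.update(seed = maybe_seed)

    state = env.reset(**reset_kwargs)  # works with or without the seed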
.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.62
+Version: 0.0.64
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.64.dist-info/RECORD

@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=0_jC9Tbl6FiscLHklvTKtuQTwZL8egqFKW-4JUxxwvw,37001
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=gvATGA51Ym5sf3jiR2VmlpjiCcT7KCDDY_SrR-MEwsU,941
+evolutionary_policy_optimization-0.0.64.dist-info/METADATA,sha256=vWdnTe2a86wTenEh29TNJlYEjD8A5CPtsyylxh4XsE0,6220
+evolutionary_policy_optimization-0.0.64.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.64.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.64.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.62.dist-info/RECORD

@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=lWhHpsfq6vpri6yeDXSTLRMKGPwl0kt3klh0fVaInSs,35921
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.62.dist-info/METADATA,sha256=oqJyUOXJwHrdf6JCVKPfOmhGJbXgqOmPWN_46l0JtWs,6220
-evolutionary_policy_optimization-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.62.dist-info/RECORD,,