evolutionary-policy-optimization 0.0.61-py3-none-any.whl → 0.0.63-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
@@ -3,6 +3,7 @@ from typing import Callable
 
 from pathlib import Path
 from math import ceil
+from itertools import product
 from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
@@ -104,8 +105,11 @@ def temp_batch_dim(fn):
 
 # fitness related
 
-def get_fitness_scores(cum_rewards, memories):
-    return cum_rewards
+def get_fitness_scores(
+    cum_rewards, # Float['gene episodes']
+    memories
+): # Float['gene']
+    return cum_rewards.sum(dim = -1) # sum all rewards across episodes, but could override this function for normalizing with whatever
 
 # generalized advantage estimate
 
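The reworked get_fitness_scores receives per-episode cumulative rewards shaped Float['gene episodes'] and reduces them to one score per gene; its inline comment invites overriding. A minimal sketch of such an override, assuming only the same input and output shapes (the normalization shown here is illustrative, not the library default):

import torch

def normalized_fitness_scores(cum_rewards, memories):
    # hypothetical override: average rewards over episodes, then z-score across genes
    # so that fitness is insensitive to the reward scale of the environment
    scores = cum_rewards.mean(dim = -1)                          # Float['gene']
    return (scores - scores.mean()) / scores.std().clamp(min = 1e-6)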
@@ -684,7 +688,8 @@ class Agent(Module):
         ),
         actor_loss_kwargs: dict = dict(
             eps_clip = 0.2,
-            entropy_weight = .01
+            entropy_weight = .01,
+            norm_advantages = True
         ),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
@@ -826,9 +831,7 @@ class Agent(Module):
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
     ):
-        memories, cumulative_rewards = memories_and_cumulative_rewards
-
-        fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)
+        memories, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
 
@@ -839,7 +842,13 @@ class Agent(Module):
         if is_distributed():
             memories = map(partial(all_gather_variable_dim, dim = 0), memories)
 
-            fitness_scores = all_gather_variable_dim(fitness_scores, dim = 0)
+            rewards_per_latent_episode = dist.all_reduce(rewards_per_latent_episode)
+
+        # calculate fitness scores
+
+        fitness_scores = self.get_fitness_scores(rewards_per_latent_episode, memories)
+
+        # process memories
 
         (
             episode_ids,
@@ -854,12 +863,16 @@ class Agent(Module):
 
         masks = 1. - dones.float()
 
+        # generalized advantage estimate
+
         advantages = self.calc_gae(
             rewards[:-1],
             values,
             masks[:-1],
         )
 
+        # dataset and dataloader
+
         valid_episode = episode_ids >= 0
 
         dataset = TensorDataset(
@@ -871,6 +884,8 @@ class Agent(Module):
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+        # updating actor and critic
+
         self.actor.train()
         self.critic.train()
 
@@ -954,7 +969,8 @@ def actor_loss(
     advantages, # Float[b]
     eps_clip = 0.2,
     entropy_weight = .01,
-    eps = 1e-5
+    eps = 1e-5,
+    norm_advantages = True
):
     batch = logits.shape[0]
 
@@ -966,7 +982,8 @@ def actor_loss(
 
    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
 
-    advantages = F.layer_norm(advantages, (batch,), eps = eps)
+    if norm_advantages:
+        advantages = F.layer_norm(advantages, (batch,), eps = eps)
 
    actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
 
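The new norm_advantages flag gates the existing F.layer_norm call, which, with no affine parameters, simply standardizes the advantages across the batch. A small self-contained check of that equivalence (the batch size of 8 is arbitrary):

import torch
import torch.nn.functional as F

advantages = torch.randn(8)
eps = 1e-5

# layer_norm over the full (batch,) shape = zero-mean, unit-variance standardization
standardized = (advantages - advantages.mean()) / torch.sqrt(advantages.var(unbiased = False) + eps)
via_layer_norm = F.layer_norm(advantages, (8,), eps = eps)

assert torch.allclose(standardized, via_layer_norm, atol = 1e-6)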
@@ -1041,7 +1058,7 @@ Memory = namedtuple('Memory', [
 
 MemoriesAndCumulativeRewards = namedtuple('MemoriesAndCumulativeRewards', [
     'memories',
-    'cumulative_rewards'
+    'cumulative_rewards' # Float['latent episodes']
 ])
 
 class EPO(Module):
@@ -1067,29 +1084,56 @@ class EPO(Module):
     def device(self):
         return self.dummy.device
 
-    def latents_for_machine(self):
+    def rollouts_for_machine(
+        self,
+        fix_environ_across_latents = False
+    ): # -> (<latent_id>, <episode_id>, <maybe synced env seed>) for the machine
+
         num_latents = self.num_latents
+        episodes = self.episodes_per_latent
+        num_latent_episodes = num_latents * episodes
+
+        # if fixing environment across latents, compute all the environment seeds upfront for simplicity
+
+        environment_seeds = None
+
+        if fix_environ_across_latents:
+            environment_seeds = torch.randint(0, int(1e6), (episodes,))
+
+            if is_distributed():
+                dist.all_reduce(environment_seeds) # reduce sum as a way to synchronize. it's fine
+
+        # get number of machines, and this machine id
 
         world_size, rank = get_world_and_rank()
 
-        assert num_latents >= world_size, 'number of latents must be greater than world size for now'
-        assert rank < world_size
+        assert num_latent_episodes >= world_size, f'number of ({self.num_latents} latents x {self.episodes_per_latent} episodes) ({num_latent_episodes}) must be greater than world size ({world_size}) for now'
+
+        latent_episode_permutations = list(product(range(num_latents), range(episodes)))
 
-        num_latents_per_machine = ceil(num_latents / world_size)
+        num_rollouts_per_machine = ceil(num_latent_episodes / world_size)
 
-        for i in range(num_latents_per_machine):
-            latent_id = rank * num_latents_per_machine + i
+        for i in range(num_rollouts_per_machine):
+            rollout_id = rank * num_rollouts_per_machine + i
 
-            if latent_id >= num_latents:
+            if rollout_id >= num_latent_episodes:
                 continue
 
-            yield i
+            latent_id, episode_id = latent_episode_permutations[rollout_id]
+
+            # maybe synchronized environment seed
+
+            maybe_seed = None
+            if fix_environ_across_latents:
+                maybe_seed = environment_seeds[episode_id]
+
+            yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
     def forward(
         self,
         env,
-        fix_seed_across_latents = True
+        fix_environ_across_latents = True
    ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
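rollouts_for_machine now enumerates every (latent_id, episode_id) pair and hands each rank a contiguous chunk of that flattened list. A standalone sketch of the partitioning with made-up sizes (4 latents, 3 episodes per latent, 2 machines), independent of the package:

from math import ceil
from itertools import product

num_latents, episodes, world_size = 4, 3, 2            # assumed toy values

rollouts = list(product(range(num_latents), range(episodes)))
per_machine = ceil(len(rollouts) / world_size)

for rank in range(world_size):
    chunk = rollouts[rank * per_machine : (rank + 1) * per_machine]
    print(f'rank {rank} runs rollouts {chunk}')        # each item is (latent_id, episode_id)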
@@ -1098,86 +1142,78 @@ class EPO(Module):
 
         memories: list[Memory] = []
 
-        cumulative_rewards = torch.zeros((self.num_latents))
+        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
 
-        latent_ids_gen = self.latents_for_machine()
+        rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
-        for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+        for latent_id, episode_id, maybe_seed in tqdm(rollout_gen, desc = 'rollout'):
 
-            maybe_barrier()
+            time = 0
 
-            # maybe fix seed for environment across all latents
+            # initial state
 
-            env_reset_kwargs = dict()
+            reset_kwargs = dict()
 
-            if fix_seed_across_latents:
-                seed = maybe_sync_seed(device = self.device)
-                env_reset_kwargs = dict(seed = seed)
+            if fix_environ_across_latents:
+                reset_kwargs.update(seed = maybe_seed)
 
-            # for each latent (on a single machine for now)
+            state = env.reset(**reset_kwargs)
 
-            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
-                time = 0
+            # get latent from pool
 
-                # initial state
+            latent = self.agent.latent_gene_pool(latent_id = latent_id)
 
-                state = env.reset(**env_reset_kwargs)
+            # until maximum episode length
 
-                # get latent from pool
+            done = tensor(False)
 
-                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+            while time < self.max_episode_length and not done:
 
-                # until maximum episode length
+                # sample action
 
-                done = tensor(False)
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
 
-                while time < self.max_episode_length and not done:
+                # values
 
-                    # sample action
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
-                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                # get the next state, action, and reward
 
-                    # values
+                state, reward, done = env(action)
 
-                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                # update cumulative rewards per latent, to be used as default fitness score
 
-                    # get the next state, action, and reward
+                rewards_per_latent_episode[latent_id, episode_id] += reward
+
+                # store memories
 
-                    state, reward, done = env(action)
-
-                    # update cumulative rewards per latent, to be used as default fitness score
-
-                    cumulative_rewards[latent_id] += reward
-
-                    # store memories
-
-                    memory = Memory(
-                        tensor(episode_id),
-                        state,
-                        tensor(latent_id),
-                        action,
-                        log_prob,
-                        reward,
-                        value,
-                        done
-                    )
+                memory = Memory(
+                    tensor(episode_id),
+                    state,
+                    tensor(latent_id),
+                    action,
+                    log_prob,
+                    reward,
+                    value,
+                    done
+                )
 
-                    memories.append(memory)
+                memories.append(memory)
 
-                    time += 1
+                time += 1
 
-                # need the final next value for GAE, iiuc
+            # need the final next value for GAE, iiuc
 
-                next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+            next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
-                memory_for_gae = memory._replace(
-                    episode_id = invalid_episode,
-                    value = next_value
-                )
+            memory_for_gae = memory._replace(
+                episode_id = invalid_episode,
+                value = next_value
+            )
 
-                memories.append(memory_for_gae)
+            memories.append(memory_for_gae)
 
         return MemoriesAndCumulativeRewards(
             memories = memories,
-            cumulative_rewards = cumulative_rewards
+            cumulative_rewards = rewards_per_latent_episode
         )
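Cumulative rewards are now accumulated per (latent, episode) rather than per latent, and the default fitness score recovers the old per-latent total by summing over episodes. A small shape sketch with assumed sizes (8 latents, 2 episodes per latent), mirroring the indexing in the rollout loop and the reduction in get_fitness_scores:

import torch

num_latents, episodes_per_latent = 8, 2                       # assumed toy sizes

rewards_per_latent_episode = torch.zeros((num_latents, episodes_per_latent))
rewards_per_latent_episode[3, 1] += 1.5                       # as done per step inside the rollout loop

fitness_scores = rewards_per_latent_episode.sum(dim = -1)     # Float['gene'], the default fitness
print(fitness_scores.shape)                                   # torch.Size([8])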
@@ -26,7 +26,7 @@ class Env(Module):
 
     def reset(
         self,
-        seed
+        seed = None
    ):
         state = randn(self.state_shape, device = self.device)
         return state
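With the default argument, the mock environment can now be reset with or without a seed, matching how the rollout loop only passes a seed when fix_environ_across_latents is enabled. A hedged usage sketch (the Env constructor argument below is an assumption, taken to be the state shape):

from evolutionary_policy_optimization.mock_env import Env

env = Env((5,))                   # assumed: constructor takes the state shape
state = env.reset()               # seed now defaults to None
state = env.reset(seed = 42)      # still accepted, as the EPO rollout loop does when syncing environments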
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.61
+Version: 0.0.63
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=DSG2fYWLk0cyHhfoiwqmSzh2TBOWhz25sD1oWIM5p1k,36695
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=gvATGA51Ym5sf3jiR2VmlpjiCcT7KCDDY_SrR-MEwsU,941
+evolutionary_policy_optimization-0.0.63.dist-info/METADATA,sha256=X2FKT8WJ9T1t0ydEdtxrJsJGXY1ubfvydQSykv2G03M,6220
+evolutionary_policy_optimization-0.0.63.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.63.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.63.dist-info/RECORD,,
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=kFT49rJdcmaDehfpx3YyhYhvAcp7S-gRWDkS2y20Q2Y,35377
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.61.dist-info/METADATA,sha256=3IbcY9kg71P6lTNxZaRBw3IYfDjcK4uTJJaFRD0Skwg,6220
-evolutionary_policy_optimization-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.61.dist-info/RECORD,,