evolutionary-policy-optimization 0.0.61__py3-none-any.whl → 0.0.63__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +107 -71
- evolutionary_policy_optimization/mock_env.py +1 -1
- {evolutionary_policy_optimization-0.0.61.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.63.dist-info/RECORD +9 -0
- evolutionary_policy_optimization-0.0.61.dist-info/RECORD +0 -9
- {evolutionary_policy_optimization-0.0.61.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.61.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

```diff
@@ -3,6 +3,7 @@ from typing import Callable
 
 from pathlib import Path
 from math import ceil
+from itertools import product
 from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
```
```diff
@@ -104,8 +105,11 @@ def temp_batch_dim(fn):
 
 # fitness related
 
-def get_fitness_scores(
-
+def get_fitness_scores(
+    cum_rewards, # Float['gene episodes']
+    memories
+): # Float['gene']
+    return cum_rewards.sum(dim = -1) # sum all rewards across episodes, but could override this function for normalizing with whatever
 
 # generalized advantage estimate
 
```
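The default fitness score for a latent (gene) is simply its cumulative reward summed over episodes, with `memories` passed in so that an override can use more than raw reward. How an override gets wired in is outside this diff; as a minimal sketch, a hypothetical `get_fitness_scores_mean` with the same `(cum_rewards, memories)` signature that averages across episodes instead of summing:

```python
import torch

def get_fitness_scores_mean(
    cum_rewards, # Float['gene episodes'] - cumulative reward per latent (gene) per episode
    memories     # unused here, but available for e.g. episode-length penalties
):               # -> Float['gene']
    # average instead of sum, so the score is not tied to the number of episodes
    return cum_rewards.mean(dim = -1)
```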
```diff
@@ -684,7 +688,8 @@ class Agent(Module):
         ),
         actor_loss_kwargs: dict = dict(
             eps_clip = 0.2,
-            entropy_weight = .01
+            entropy_weight = .01,
+            norm_advantages = True
         ),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
```
```diff
@@ -826,9 +831,7 @@ class Agent(Module):
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
     ):
-        memories,
-
-        fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)
+        memories, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
 
@@ -839,7 +842,13 @@ class Agent(Module):
         if is_distributed():
             memories = map(partial(all_gather_variable_dim, dim = 0), memories)
 
-
+            rewards_per_latent_episode = dist.all_reduce(rewards_per_latent_episode)
+
+        # calculate fitness scores
+
+        fitness_scores = self.get_fitness_scores(rewards_per_latent_episode, memories)
+
+        # process memories
 
         (
             episode_ids,
```
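In `Agent.learn`, the per-latent, per-episode rewards now ride along with the memories; under distributed training they are reduced across ranks before fitness scores are computed. A rough, self-contained sketch of the intended flow, with illustrative shapes; the reduce call is left as a comment since it requires an initialized process group, and it assumes `dist` is `torch.distributed`:

```python
import torch

# each rank accumulates rewards only for the (latent, episode) rollouts it ran,
# leaving every other entry at zero, so a sum reduction across ranks recovers
# the full reward matrix on every machine
num_latents, episodes_per_latent = 8, 2
rewards_per_latent_episode = torch.zeros((num_latents, episodes_per_latent))

rewards_per_latent_episode[3, 1] += 1.25  # a local rollout's cumulative reward

# with torch.distributed initialized, the cross-rank sum would be:
#   import torch.distributed as dist
#   dist.all_reduce(rewards_per_latent_episode)  # in-place sum across ranks

# fitness scores then default to the per-latent sum over episodes
fitness_scores = rewards_per_latent_episode.sum(dim = -1)  # Float['num_latents']
```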
```diff
@@ -854,12 +863,16 @@ class Agent(Module):
 
         masks = 1. - dones.float()
 
+        # generalized advantage estimate
+
         advantages = self.calc_gae(
             rewards[:-1],
             values,
             masks[:-1],
         )
 
+        # dataset and dataloader
+
         valid_episode = episode_ids >= 0
 
         dataset = TensorDataset(
```
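For context on what `calc_gae` consumes here (rewards and masks trimmed by one step, values including the bootstrapped next value): a textbook sketch of generalized advantage estimation, not the package's exact implementation, with hypothetical defaults for `gamma` and `lam`:

```python
import torch

def generalized_advantage_estimate(
    rewards, # Float['t']
    values,  # Float['t + 1'] - includes the bootstrapped value for the final state
    masks,   # Float['t']     - 0. where the episode terminated
    gamma = 0.99,
    lam = 0.95
):           # -> Float['t']
    # delta_t = r_t + gamma * V_{t+1} * mask_t - V_t
    # A_t     = delta_t + gamma * lam * mask_t * A_{t+1}
    deltas = rewards + gamma * values[1:] * masks - values[:-1]

    advantages = torch.zeros_like(deltas)
    running = torch.tensor(0.)

    for t in reversed(range(deltas.shape[-1])):
        running = deltas[t] + gamma * lam * masks[t] * running
        advantages[t] = running

    return advantages
```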
```diff
@@ -871,6 +884,8 @@ class Agent(Module):
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+        # updating actor and critic
+
         self.actor.train()
         self.critic.train()
 
```
```diff
@@ -954,7 +969,8 @@ def actor_loss(
     advantages, # Float[b]
     eps_clip = 0.2,
     entropy_weight = .01,
-    eps = 1e-5
+    eps = 1e-5,
+    norm_advantages = True
 ):
     batch = logits.shape[0]
 
@@ -966,7 +982,8 @@ def actor_loss(
 
     clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
 
-
+    if norm_advantages:
+        advantages = F.layer_norm(advantages, (batch,), eps = eps)
 
     actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
 
```
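A small self-contained check (illustrative batch size) of what the new `norm_advantages` branch does: `F.layer_norm` over the batch dimension, with no affine parameters, is just the familiar zero-mean, unit-variance advantage normalization.

```python
import torch
import torch.nn.functional as F

batch = 8
eps = 1e-5
advantages = torch.randn(batch)

# layer norm across the whole batch of advantages ...
normed = F.layer_norm(advantages, (batch,), eps = eps)

# ... matches manual standardization with biased variance
manual = (advantages - advantages.mean()) / torch.sqrt(advantages.var(unbiased = False) + eps)

assert torch.allclose(normed, manual, atol = 1e-6)
```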
```diff
@@ -1041,7 +1058,7 @@ Memory = namedtuple('Memory', [
 
 MemoriesAndCumulativeRewards = namedtuple('MemoriesAndCumulativeRewards', [
     'memories',
-    'cumulative_rewards'
+    'cumulative_rewards' # Float['latent episodes']
 ])
 
 class EPO(Module):
```
```diff
@@ -1067,29 +1084,56 @@ class EPO(Module):
     def device(self):
         return self.dummy.device
 
-    def
+    def rollouts_for_machine(
+        self,
+        fix_environ_across_latents = False
+    ): # -> (<latent_id>, <episode_id>, <maybe synced env seed>) for the machine
+
         num_latents = self.num_latents
+        episodes = self.episodes_per_latent
+        num_latent_episodes = num_latents * episodes
+
+        # if fixing environment across latents, compute all the environment seeds upfront for simplicity
+
+        environment_seeds = None
+
+        if fix_environ_across_latents:
+            environment_seeds = torch.randint(0, int(1e6), (episodes,))
+
+            if is_distributed():
+                dist.all_reduce(environment_seeds) # reduce sum as a way to synchronize. it's fine
+
+        # get number of machines, and this machine id
 
         world_size, rank = get_world_and_rank()
 
-        assert
-
+        assert num_latent_episodes >= world_size, f'number of ({self.num_latents} latents x {self.episodes_per_latent} episodes) ({num_latent_episodes}) must be greater than world size ({world_size}) for now'
+
+        latent_episode_permutations = list(product(range(num_latents), range(episodes)))
 
-
+        num_rollouts_per_machine = ceil(num_latent_episodes / world_size)
 
-        for i in range(
-
+        for i in range(num_rollouts_per_machine):
+            rollout_id = rank * num_rollouts_per_machine + i
 
-            if
+            if rollout_id >= num_latent_episodes:
                 continue
 
-
+            latent_id, episode_id = latent_episode_permutations[rollout_id]
+
+            # maybe synchronized environment seed
+
+            maybe_seed = None
+            if fix_environ_across_latents:
+                maybe_seed = environment_seeds[episode_id]
+
+            yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
     def forward(
         self,
         env,
-
+        fix_environ_across_latents = True
     ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
```
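The new `rollouts_for_machine` generator enumerates every (latent, episode) pair, hands each machine a contiguous chunk of `ceil(total / world_size)` rollouts, and skips indices past the end. The chunking logic, restated as a standalone sketch (hypothetical `rollouts_for_rank`, without the environment-seed syncing):

```python
from math import ceil
from itertools import product

def rollouts_for_rank(num_latents, episodes_per_latent, world_size, rank):
    # every (latent, episode) pair is one rollout; each machine takes a
    # contiguous chunk of ceil(total / world_size) of them
    pairs = list(product(range(num_latents), range(episodes_per_latent)))
    per_machine = ceil(len(pairs) / world_size)

    start = rank * per_machine
    return pairs[start:start + per_machine]

# e.g. 3 latents x 2 episodes = 6 rollouts split across 4 machines
print(rollouts_for_rank(3, 2, world_size = 4, rank = 0))  # [(0, 0), (0, 1)]
print(rollouts_for_rank(3, 2, world_size = 4, rank = 3))  # [] - nothing left for the last rank
```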
```diff
@@ -1098,86 +1142,78 @@ class EPO(Module):
 
         memories: list[Memory] = []
 
-
+        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
 
-
+        rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
-        for episode_id in tqdm(
+        for latent_id, episode_id, maybe_seed in tqdm(rollout_gen, desc = 'rollout'):
 
-
+            time = 0
 
-            #
+            # initial state
 
-
+            reset_kwargs = dict()
 
-            if
-                seed =
-                env_reset_kwargs = dict(seed = seed)
+            if fix_environ_across_latents:
+                reset_kwargs.update(seed = maybe_seed)
 
-
+            state = env.reset(**reset_kwargs)
 
-
-            time = 0
+            # get latent from pool
 
-
+            latent = self.agent.latent_gene_pool(latent_id = latent_id)
 
-
+            # until maximum episode length
 
-
+            done = tensor(False)
 
-
+            while time < self.max_episode_length and not done:
 
-                #
+                # sample action
 
-
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
 
-
+                # values
 
-
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
-
+                # get the next state, action, and reward
 
-
+                state, reward, done = env(action)
 
-
+                # update cumulative rewards per latent, to be used as default fitness score
 
-
+                rewards_per_latent_episode[latent_id, episode_id] += reward
+
+                # store memories
 
-
-
-
-
-
-
-
-
-
-
-                    state,
-                    tensor(latent_id),
-                    action,
-                    log_prob,
-                    reward,
-                    value,
-                    done
-                )
+                memory = Memory(
+                    tensor(episode_id),
+                    state,
+                    tensor(latent_id),
+                    action,
+                    log_prob,
+                    reward,
+                    value,
+                    done
+                )
 
-
+                memories.append(memory)
 
-
+                time += 1
 
-
+            # need the final next value for GAE, iiuc
 
-
+            next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
-
-
-
-
+            memory_for_gae = memory._replace(
+                episode_id = invalid_episode,
+                value = next_value
+            )
 
-
+            memories.append(memory_for_gae)
 
         return MemoriesAndCumulativeRewards(
             memories = memories,
-            cumulative_rewards =
+            cumulative_rewards = rewards_per_latent_episode
         )
```
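Taken together, `EPO.forward` now returns the collected memories plus a `[num_latents, episodes_per_latent]` reward matrix, and `Agent.learn` unpacks that pair directly. A minimal sketch of the data contract between the two (the `Memory` entries themselves are elided and the shape is illustrative):

```python
from collections import namedtuple
import torch

# mirrors the namedtuple in the diff: the rollout memories plus Float['latent episodes'] rewards
MemoriesAndCumulativeRewards = namedtuple('MemoriesAndCumulativeRewards', [
    'memories',
    'cumulative_rewards'
])

rollout = MemoriesAndCumulativeRewards(
    memories = [],                            # list[Memory] in the real rollout
    cumulative_rewards = torch.zeros((8, 2))  # (num_latents, episodes_per_latent)
)

# Agent.learn unpacks both fields in a single step
memories, rewards_per_latent_episode = rollout
```

The extra `memory_for_gae` appended at the end of each rollout carries only the bootstrapped `next_value` and is tagged with `invalid_episode`, which lines up with the `valid_episode = episode_ids >= 0` filter applied in `learn` before the dataset is built.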
{evolutionary_policy_optimization-0.0.61.dist-info → evolutionary_policy_optimization-0.0.63.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.61
+Version: 0.0.63
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
evolutionary_policy_optimization-0.0.63.dist-info/RECORD

```diff
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=DSG2fYWLk0cyHhfoiwqmSzh2TBOWhz25sD1oWIM5p1k,36695
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=gvATGA51Ym5sf3jiR2VmlpjiCcT7KCDDY_SrR-MEwsU,941
+evolutionary_policy_optimization-0.0.63.dist-info/METADATA,sha256=X2FKT8WJ9T1t0ydEdtxrJsJGXY1ubfvydQSykv2G03M,6220
+evolutionary_policy_optimization-0.0.63.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.63.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.63.dist-info/RECORD,,
```
evolutionary_policy_optimization-0.0.61.dist-info/RECORD

```diff
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=kFT49rJdcmaDehfpx3YyhYhvAcp7S-gRWDkS2y20Q2Y,35377
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.61.dist-info/METADATA,sha256=3IbcY9kg71P6lTNxZaRBw3IYfDjcK4uTJJaFRD0Skwg,6220
-evolutionary_policy_optimization-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.61.dist-info/RECORD,,
```
The WHEEL and licenses/LICENSE files are unchanged between versions.