evolutionary-policy-optimization 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from collections import namedtuple
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -9,6 +11,8 @@ from torch.nn import Linear, Module, ModuleList
 
 from einops import rearrange, repeat
 
+from assoc_scan import AssocScan
+
 # helpers
 
 def exists(v):
@@ -56,18 +60,18 @@ def actor_loss(
 ):
     log_probs = gather_log_prob(logits, actions)
 
-    entropy = calc_entropy(logits)
-
     ratio = (log_probs - old_log_probs).exp()
 
-    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
-
     # classic clipped surrogate loss from ppo
 
+    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
+
     actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
 
     # add entropy loss for exploration
 
+    entropy = calc_entropy(logits)
+
     entropy_aux_loss = -entropy_weight * entropy
 
     return actor_loss + entropy_aux_loss
@@ -80,6 +84,39 @@ def critic_loss(
     discounted_values = advantages + old_values
     return F.mse_loss(pred_values, discounted_values)
 
+# generalized advantage estimate
+
+def calc_generalized_advantage_estimate(
+    rewards, # Float[g n]
+    values, # Float[g n+1]
+    masks, # Bool[n]
+    gamma = 0.99,
+    lam = 0.95,
+    use_accelerated = None
+
+):
+    assert values.shape[-1] == (rewards.shape[-1] + 1)
+
+    use_accelerated = default(use_accelerated, rewards.is_cuda)
+    device = rewards.device
+
+    masks = repeat(masks, 'n -> g n', g = rewards.shape[0])
+
+    values, values_next = values[:, :-1], values[:, 1:]
+
+    delta = rewards + gamma * values_next * masks - values
+    gates = gamma * lam * masks
+
+    gates, delta = gates[..., :, None], delta[..., :, None]
+
+    scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
+
+    gae = scan(gates, delta)
+
+    gae = gae[..., :, 0]
+
+    return gae
+
 # evolution related functions
 
 def crossover_latents(
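
The hunk above introduces calc_generalized_advantage_estimate, which evaluates the standard GAE recurrence A_t = delta_t + gamma * lam * mask_t * A_{t+1}, with delta_t = r_t + gamma * V_{t+1} * mask_t - V_t, by phrasing the backward pass as an associative scan (AssocScan with reverse = True) so it can run accelerated on GPU. For reference, below is a minimal sketch of the same recurrence written as an explicit reverse-time loop; the name gae_reference and the example shapes are illustrative assumptions, not part of the package.

import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards: (groups, n) float, values: (groups, n + 1) float, masks: (n,) bool
    # explicit reverse-time loop over the same recurrence the scan-based version computes
    assert values.shape[-1] == rewards.shape[-1] + 1

    masks = masks.float().expand_as(rewards)

    values, values_next = values[:, :-1], values[:, 1:]
    delta = rewards + gamma * values_next * masks - values

    gae = torch.zeros_like(rewards)
    running = torch.zeros(rewards.shape[0], device = rewards.device)

    for t in reversed(range(rewards.shape[-1])):
        running = delta[:, t] + gamma * lam * masks[:, t] * running
        gae[:, t] = running

    return gae

Given the same (rewards, values, masks) inputs, the loop and the scan-based version should agree up to floating point error; the scan form exists so the backward pass parallelizes.
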
@@ -183,6 +220,100 @@ class MLP(Module):
 
         return x
 
+# actor, critic, and agent (actor + critic)
+# eventually, should just create a separate repo and aggregate all the MLP related architectures
+
+class Actor(Module):
+    def __init__(
+        self,
+        dim_in,
+        num_actions,
+        dim_hiddens: tuple[int, ...],
+        dim_latent = 0,
+    ):
+        super().__init__()
+
+        assert len(dim_hiddens) >= 2
+        dim_first, *_, dim_last = dim_hiddens
+
+        self.init_layer = nn.Sequential(
+            nn.Linear(dim_in, dim_first),
+            nn.SiLU()
+        )
+
+        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+        self.to_out = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim_last, num_actions),
+        )
+
+    def forward(
+        self,
+        state,
+        latent
+    ):
+
+        hidden = self.init_layer(state)
+
+        hidden = self.mlp(state, latent)
+
+        return self.to_out(hidden)
+
+class Critic(Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_hiddens: tuple[int, ...],
+        dim_latent = 0,
+    ):
+        super().__init__()
+
+        assert len(dim_hiddens) >= 2
+        dim_first, *_, dim_last = dim_hiddens
+
+        self.init_layer = nn.Sequential(
+            nn.Linear(dim_in, dim_first),
+            nn.SiLU()
+        )
+
+        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+        self.to_out = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim_last, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+    def forward(
+        self,
+        state,
+        latent
+    ):
+
+        hidden = self.init_layer(state)
+
+        hidden = self.mlp(state, latent)
+
+        return self.to_out(hidden)
+
+class Agent(Module):
+    def __init__(
+        self,
+        actor: Actor,
+        critic: Critic,
+    ):
+        super().__init__()
+
+        self.actor = actor
+        self.critic = critic
+
+    def forward(
+        self,
+        memories: list[Memory]
+    ):
+        raise NotImplementedError
+
 # classes
 
 class LatentGenePool(Module):
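
The hunk above adds Actor, Critic, and Agent. Actor and Critic share a shape: a Linear + SiLU projection from the input into the first hidden width, the existing latent-conditioned MLP, and a SiLU head that emits action logits (actor) or a single squeezed value (critic); Agent simply holds one of each, with its forward over memories still a stub in this version. A minimal construction sketch follows; the sizes are hypothetical, chosen only to illustrate the constructor signatures visible in the diff.

# hypothetical sizes, for illustration only
dim_state   = 512
num_actions = 5
dim_latent  = 32

actor = Actor(
    dim_in = dim_state,
    num_actions = num_actions,
    dim_hiddens = (512, 256),
    dim_latent = dim_latent
)

critic = Critic(
    dim_in = dim_state,
    dim_hiddens = (512, 256),
    dim_latent = dim_latent
)

# Agent bundles the two; both networks are called as net(state, latent)
agent = Agent(actor, critic)
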
@@ -333,3 +464,35 @@ class LatentGenePool(Module):
             latent = latent,
             **kwargs
         )
+
+# EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
+# the tricky part is that the latent ids for each episode / trajectory needs to be tracked
+
+Memory = namedtuple('Memory', [
+    'state',
+    'latent_gene_id',
+    'action',
+    'log_prob',
+    'reward',
+    'values',
+    'done'
+])
+
+class EPO(Module):
+
+    def __init__(
+        self,
+        agent: Agent,
+        latent_gene_pool: LatentGenePool
+    ):
+        super().__init__()
+
+        self.agent = agent
+        self.latent_gene_pool = latent_gene_pool
+
+    def forward(
+        self,
+        env
+    ) -> list[Memory]:
+
+        raise NotImplementedError
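
This final hunk of epo.py lays out the bookkeeping for the training loop: a Memory namedtuple recorded per environment step, carrying latent_gene_id so each trajectory can be credited to the latent (gene) that conditioned it, and an EPO module that pairs the Agent with the LatentGenePool. Both EPO.forward and Agent.forward raise NotImplementedError in 0.0.5, so rollout collection is not wired up yet; the snippet below only illustrates, with made-up values, what one Memory entry would contain.

import torch

# purely illustrative values for a single environment step; field names follow the namedtuple above
memory = Memory(
    state = torch.randn(512),       # observation passed to the actor / critic
    latent_gene_id = 3,             # index of the latent (gene) that conditioned this episode
    action = torch.tensor(2),       # sampled action
    log_prob = torch.tensor(-1.2),  # log probability of that action under the old policy
    reward = torch.tensor(1.),      # reward returned by the environment
    values = torch.tensor(0.8),     # critic value estimate for the state
    done = torch.tensor(False)      # episode termination flag
)

memories: list[Memory] = [memory]
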
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.3
+Version: 0.0.5
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8
+Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.0
 Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+evolutionary_policy_optimization/epo.py,sha256=lDhMV535MhUw1di7D7RM-Rr_J6aiuLqV-puh4EaNCd8,13455
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.5.dist-info/METADATA,sha256=uzkB4DrpzLLxbMEeiTID4CDxDxmEX1pO9fabwryDQcY,4098
+evolutionary_policy_optimization-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.5.dist-info/RECORD,,
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=gB9_ZPtVjjFilZEeXX1ZWh67ZSd-OBt6Xs6WyfjRSxI,9967
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.3.dist-info/METADATA,sha256=qlQ1NxiA-AVWiNlpDcMI9hVqBg1g5DY9AvISpnrKkV0,4072
-evolutionary_policy_optimization-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.3.dist-info/RECORD,,