evolutionary-policy-optimization 0.0.39__tar.gz → 0.0.40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (13)
  1. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/epo.py +18 -14
  3. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/.github/workflows/python-publish.yml +0 -0
  5. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/.github/workflows/test.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/.gitignore +0 -0
  7. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/LICENSE +0 -0
  8. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/README.md +0 -0
  9. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/__init__.py +0 -0
  10. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/experimental.py +0 -0
  11. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/mock_env.py +0 -0
  12. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/requirements.txt +0 -0
  13. {evolutionary_policy_optimization-0.0.39 → evolutionary_policy_optimization-0.0.40}/tests/test_epo.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.39
+Version: 0.0.40
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/epo.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from collections import namedtuple
 
@@ -9,6 +9,7 @@ from torch import nn, cat, stack, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import einx
 from einops import rearrange, repeat, einsum, pack
@@ -73,6 +74,19 @@ def gather_log_prob(
     log_prob = log_probs.gather(-1, indices)
     return rearrange(log_prob, '... 1 -> ...')
 
+def temp_batch_dim(fn):
+
+    @wraps(fn)
+    def inner(*args, **kwargs):
+        args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))
+
+        out = fn(*args, **kwargs)
+
+        out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
+        return out
+
+    return inner
+
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
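
Note: the heart of this release is the new temp_batch_dim decorator added above. It temporarily gives every tensor argument a leading batch dimension of 1, calls the wrapped function, then strips that dimension from every tensor output, using torch.utils._pytree.tree_map so arbitrary argument and return structures are handled. Below is a minimal, self-contained sketch of the same mechanism; the critic function is a hypothetical stand-in, not part of the library's API.

# Minimal, self-contained sketch of the temp_batch_dim mechanism added in 0.0.40.
# `critic` is a made-up example function, not the library's API.
from functools import wraps

import torch
from torch import is_tensor
from torch.utils._pytree import tree_map
from einops import rearrange

def temp_batch_dim(fn):

    @wraps(fn)
    def inner(*args, **kwargs):
        # give every tensor argument a temporary leading batch dimension of 1
        args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))

        out = fn(*args, **kwargs)

        # strip that batch dimension from every tensor output
        return tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)

    return inner

def critic(state):
    # pretend critic that expects a batched state of shape (batch, dim)
    assert state.ndim == 2
    return state.sum(dim = -1)  # one value per batch element, shape (batch,)

value = temp_batch_dim(critic)(torch.randn(5))  # pass a single unbatched observation
print(value.shape)  # torch.Size([]) - the temporary batch dimension is removed again

Because tree_map visits every leaf of (args, kwargs) and of the return value, non-tensor arguments pass through untouched and tuple returns are unbatched element-wise, which is what lets the rollout loop below drop its manual rearrange calls.
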
evolutionary_policy_optimization/epo.py (continued)

@@ -939,20 +953,13 @@ class EPO(Module):
 
        while time < self.max_episode_length:
 
-            batched_state = rearrange(state, '... -> 1 ...')
-
            # sample action
 
-            action, log_prob = self.agent.get_actor_actions(batched_state, latent = latent, sample = True, temperature = self.action_sample_temperature)
-
-            action = rearrange(action, '1 ... -> ...')
-            log_prob = rearrange(log_prob, '1 ... -> ...')
+            action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
 
            # values
 
-            value = self.agent.get_critic_values(batched_state, latent = latent)
-
-            value = rearrange(value, '1 ... -> ...')
+            value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
            # get the next state, action, and reward
 
@@ -981,10 +988,7 @@
 
        # need the final next value for GAE, iiuc
 
-        batched_state = rearrange(state, '... -> 1 ...')
-
-        next_value = self.agent.get_critic_values(batched_state, latent = latent)
-        next_value = rearrange(next_value, '1 ... -> ...')
+        next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
        memory_for_gae = memory._replace(
            episode_id = invalid_episode,
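
These two hunks replace the manual pattern of rearranging the state to '1 ...', calling the agent, and rearranging each result back, with single temp_batch_dim-wrapped calls. A short usage sketch follows, under the assumption that the package at 0.0.40 or later is installed so temp_batch_dim is importable from the epo module; toy_actor is made up and only mimics get_actor_actions returning an (action, log_prob) pair.

# Toy stand-in for agent.get_actor_actions, returning an (action, log_prob) tuple.
# Assumes evolutionary-policy-optimization >= 0.0.40 is installed.
import torch
from evolutionary_policy_optimization.epo import temp_batch_dim

def toy_actor(state, temperature = 1.):
    # expects a batched state of shape (batch, state_dim)
    logits = state / temperature
    action = logits.argmax(dim = -1)                         # shape (batch,)
    log_prob = logits.log_softmax(dim = -1).amax(dim = -1)   # shape (batch,)
    return action, log_prob

state = torch.randn(5)  # single unbatched observation
action, log_prob = temp_batch_dim(toy_actor)(state, temperature = 1.5)
print(action.shape, log_prob.shape)  # both torch.Size([]) - batch dim stripped from each output
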
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.0.39"
+version = "0.0.40"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }