dmpo 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dmpo-0.0.2/.gitignore ADDED
@@ -0,0 +1,220 @@
1
+ *.sh
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[codz]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py.cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ # Pipfile.lock
98
+
99
+ # UV
100
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # uv.lock
104
+
105
+ # poetry
106
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
108
+ # commonly ignored for libraries.
109
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110
+ # poetry.lock
111
+ # poetry.toml
112
+
113
+ # pdm
114
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
116
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
117
+ # pdm.lock
118
+ # pdm.toml
119
+ .pdm-python
120
+ .pdm-build/
121
+
122
+ # pixi
123
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
124
+ # pixi.lock
125
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
126
+ # in the .venv directory. It is recommended not to include this directory in version control.
127
+ .pixi
128
+
129
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130
+ __pypackages__/
131
+
132
+ # Celery stuff
133
+ celerybeat-schedule
134
+ celerybeat.pid
135
+
136
+ # Redis
137
+ *.rdb
138
+ *.aof
139
+ *.pid
140
+
141
+ # RabbitMQ
142
+ mnesia/
143
+ rabbitmq/
144
+ rabbitmq-data/
145
+
146
+ # ActiveMQ
147
+ activemq-data/
148
+
149
+ # SageMath parsed files
150
+ *.sage.py
151
+
152
+ # Environments
153
+ .env
154
+ .envrc
155
+ .venv
156
+ env/
157
+ venv/
158
+ ENV/
159
+ env.bak/
160
+ venv.bak/
161
+
162
+ # Spyder project settings
163
+ .spyderproject
164
+ .spyproject
165
+
166
+ # Rope project settings
167
+ .ropeproject
168
+
169
+ # mkdocs documentation
170
+ /site
171
+
172
+ # mypy
173
+ .mypy_cache/
174
+ .dmypy.json
175
+ dmypy.json
176
+
177
+ # Pyre type checker
178
+ .pyre/
179
+
180
+ # pytype static type analyzer
181
+ .pytype/
182
+
183
+ # Cython debug symbols
184
+ cython_debug/
185
+
186
+ # PyCharm
187
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
188
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
189
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
190
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
191
+ # .idea/
192
+
193
+ # Abstra
194
+ # Abstra is an AI-powered process automation framework.
195
+ # Ignore directories containing user credentials, local state, and settings.
196
+ # Learn more at https://abstra.io/docs
197
+ .abstra/
198
+
199
+ # Visual Studio Code
200
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
201
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
202
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
203
+ # you could uncomment the following to ignore the entire vscode folder
204
+ # .vscode/
205
+ # Temporary file for partial code execution
206
+ tempCodeRunnerFile.py
207
+
208
+ # Ruff stuff:
209
+ .ruff_cache/
210
+
211
+ # PyPI configuration file
212
+ .pypirc
213
+
214
+ # Marimo
215
+ marimo/_static/
216
+ marimo/_lsp/
217
+ __marimo__/
218
+
219
+ # Streamlit
220
+ .streamlit/secrets.toml
dmpo-0.0.2/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Phil Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
dmpo-0.0.2/PKG-INFO ADDED
@@ -0,0 +1,133 @@
1
+ Metadata-Version: 2.4
2
+ Name: dmpo
3
+ Version: 0.0.2
4
+ Summary: Maximum a Posteriori Policy Optimization and Related Algorithms
5
+ Project-URL: Homepage, https://pypi.org/project/dmpo/
6
+ Project-URL: Repository, https://codeberg.org/lucidrains/dmpo
7
+ Author-email: Phil Wang <lucidrains@gmail.com>
8
+ License: MIT License
9
+
10
+ Copyright (c) 2026 Phil Wang
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Keywords: artificial intelligence,deep learning,mpo,reinforcement learning,tpo
31
+ Classifier: Development Status :: 4 - Beta
32
+ Classifier: Intended Audience :: Developers
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python :: 3.10
35
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
+ Requires-Python: >=3.10
37
+ Requires-Dist: accelerate
38
+ Requires-Dist: discrete-continuous-embed-readout
39
+ Requires-Dist: einops>=0.8.1
40
+ Requires-Dist: memmap-replay-buffer>=0.1.4
41
+ Requires-Dist: torch-einops-utils>=0.1.2
42
+ Requires-Dist: torch>=2.5
43
+ Requires-Dist: tqdm
44
+ Requires-Dist: x-mlps-pytorch
45
+ Requires-Dist: x-transformers
46
+ Provides-Extra: examples
47
+ Provides-Extra: test
48
+ Requires-Dist: pytest; extra == 'test'
49
+ Description-Content-Type: text/markdown
50
+
51
+ ## DMPO (wip)
52
+
53
+ Implementation and explorations into [MPO](https://arxiv.org/abs/1806.06920) / DMPO
54
+
55
+ ## Citations
56
+
57
+ ```bibtex
58
+ @article{Haarnoja_2024,
59
+ title = {Learning agile soccer skills for a bipedal robot with deep reinforcement learning},
60
+ volume = {9},
61
+ ISSN = {2470-9476},
62
+ url = {http://dx.doi.org/10.1126/scirobotics.adi8022},
63
+ DOI = {10.1126/scirobotics.adi8022},
64
+ number = {89},
65
+ journal = {Science Robotics},
66
+ publisher = {American Association for the Advancement of Science (AAAS)},
67
+ author = {Haarnoja, Tuomas and Moran, Ben and Lever, Guy and Huang, Sandy H. and Tirumala, Dhruva and Humplik, Jan and Wulfmeier, Markus and Tunyasuvunakool, Saran and Siegel, Noah Y. and Hafner, Roland and Bloesch, Michael and Hartikainen, Kristian and Byravan, Arunkumar and Hasenclever, Leonard and Tassa, Yuval and Sadeghi, Fereshteh and Batchelor, Nathan and Casarini, Federico and Saliceti, Stefano and Game, Charles and Sreendra, Neil and Patel, Kushal and Gwira, Marlon and Huber, Andrea and Hurley, Nicole and Nori, Francesco and Hadsell, Raia and Heess, Nicolas},
68
+ year = {2024},
69
+ month = {Apr}
70
+ }
71
+ ```
72
+
73
+ ```bibtex
74
+ @misc{abdolmaleki2018maximumposterioripolicyoptimisation,
75
+ title = {Maximum a Posteriori Policy Optimisation},
76
+ author = {Abbas Abdolmaleki and Jost Tobias Springenberg and Yuval Tassa and Remi Munos and Nicolas Heess and Martin Riedmiller},
77
+ year = {2018},
78
+ eprint = {1806.06920},
79
+ archivePrefix = {arXiv},
80
+ primaryClass = {cs.LG},
81
+ url = {https://arxiv.org/abs/1806.06920}
82
+ }
83
+ ```
84
+
85
+ ```bibtex
86
+ @misc{song2019vmpoonpolicymaximumposteriori,
87
+ title = {V-MPO: On-Policy Maximum a Posteriori Policy Optimization for Discrete and Continuous Control},
88
+ author = {H. Francis Song and Abbas Abdolmaleki and Jost Tobias Springenberg and Aidan Clark and Hubert Soyer and Jack W. Rae and Seb Noury and Arun Ahuja and Siqi Liu and Dhruva Tirumala and Nicolas Heess and Dan Belov and Martin Riedmiller and Matthew M. Botvinick},
89
+ year = {2019},
90
+ eprint = {1909.12238},
91
+ archivePrefix = {arXiv},
92
+ primaryClass = {cs.AI},
93
+ url = {https://arxiv.org/abs/1909.12238}
94
+ }
95
+ ```
96
+
97
+ ```bibtex
98
+ @InProceedings{pmlr-v235-li24z,
99
+ title = {Value-Evolutionary-Based Reinforcement Learning},
100
+ author = {Li, Pengyi and Hao, Jianye and Tang, Hongyao and Zheng, Yan and Barez, Fazl},
101
+ booktitle = {Proceedings of the 41st International Conference on Machine Learning},
102
+ pages = {27875--27889},
103
+ year = {2024},
104
+ editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
105
+ volume = {235},
106
+ series = {Proceedings of Machine Learning Research},
107
+ month = {21--27 Jul},
108
+ publisher = {PMLR},
109
+ pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/li24z/li24z.pdf},
110
+ url = {https://proceedings.mlr.press/v235/li24z.html}
111
+ }
112
+ ```
113
+
114
+ ```bibtex
115
+ @article{kaddour2026target,
116
+ title = {Target Policy Optimization},
117
+ author = {Kaddour, Jean},
118
+ journal = {arXiv preprint arXiv:2604.06159},
119
+ year = {2026}
120
+ }
121
+ ```
122
+
123
+ ```bibtex
124
+ @misc{qu2026listwisepolicyoptimizationgroupbased,
125
+ title = {Listwise Policy Optimization: Group-based RLVR as Target-Projection on the LLM Response Simplex},
126
+ author = {Yun Qu and Qi Wang and Yixiu Mao and Heming Zou and Yuhang Jiang and Yingyue Li and Wutong Xu and Lizhou Cai and Weijie Liu and Clive Bai and Kai Yang and Yangkun Chen and Saiyong Yang and Xiangyang Ji},
127
+ year = {2026},
128
+ eprint = {2605.06139},
129
+ archivePrefix = {arXiv},
130
+ primaryClass = {cs.LG},
131
+ url = {https://arxiv.org/abs/2605.06139},
132
+ }
133
+ ```
dmpo-0.0.2/README.md ADDED
@@ -0,0 +1,83 @@
1
+ ## DMPO (wip)
2
+
3
+ Implementation and explorations into [MPO](https://arxiv.org/abs/1806.06920) / DMPO
4
+
5
+ ## Citations
6
+
7
+ ```bibtex
8
+ @article{Haarnoja_2024,
9
+ title = {Learning agile soccer skills for a bipedal robot with deep reinforcement learning},
10
+ volume = {9},
11
+ ISSN = {2470-9476},
12
+ url = {http://dx.doi.org/10.1126/scirobotics.adi8022},
13
+ DOI = {10.1126/scirobotics.adi8022},
14
+ number = {89},
15
+ journal = {Science Robotics},
16
+ publisher = {American Association for the Advancement of Science (AAAS)},
17
+ author = {Haarnoja, Tuomas and Moran, Ben and Lever, Guy and Huang, Sandy H. and Tirumala, Dhruva and Humplik, Jan and Wulfmeier, Markus and Tunyasuvunakool, Saran and Siegel, Noah Y. and Hafner, Roland and Bloesch, Michael and Hartikainen, Kristian and Byravan, Arunkumar and Hasenclever, Leonard and Tassa, Yuval and Sadeghi, Fereshteh and Batchelor, Nathan and Casarini, Federico and Saliceti, Stefano and Game, Charles and Sreendra, Neil and Patel, Kushal and Gwira, Marlon and Huber, Andrea and Hurley, Nicole and Nori, Francesco and Hadsell, Raia and Heess, Nicolas},
18
+ year = {2024},
19
+ month = {Apr}
20
+ }
21
+ ```
22
+
23
+ ```bibtex
24
+ @misc{abdolmaleki2018maximumposterioripolicyoptimisation,
25
+ title = {Maximum a Posteriori Policy Optimisation},
26
+ author = {Abbas Abdolmaleki and Jost Tobias Springenberg and Yuval Tassa and Remi Munos and Nicolas Heess and Martin Riedmiller},
27
+ year = {2018},
28
+ eprint = {1806.06920},
29
+ archivePrefix = {arXiv},
30
+ primaryClass = {cs.LG},
31
+ url = {https://arxiv.org/abs/1806.06920}
32
+ }
33
+ ```
34
+
35
+ ```bibtex
36
+ @misc{song2019vmpoonpolicymaximumposteriori,
37
+ title = {V-MPO: On-Policy Maximum a Posteriori Policy Optimization for Discrete and Continuous Control},
38
+ author = {H. Francis Song and Abbas Abdolmaleki and Jost Tobias Springenberg and Aidan Clark and Hubert Soyer and Jack W. Rae and Seb Noury and Arun Ahuja and Siqi Liu and Dhruva Tirumala and Nicolas Heess and Dan Belov and Martin Riedmiller and Matthew M. Botvinick},
39
+ year = {2019},
40
+ eprint = {1909.12238},
41
+ archivePrefix = {arXiv},
42
+ primaryClass = {cs.AI},
43
+ url = {https://arxiv.org/abs/1909.12238}
44
+ }
45
+ ```
46
+
47
+ ```bibtex
48
+ @InProceedings{pmlr-v235-li24z,
49
+ title = {Value-Evolutionary-Based Reinforcement Learning},
50
+ author = {Li, Pengyi and Hao, Jianye and Tang, Hongyao and Zheng, Yan and Barez, Fazl},
51
+ booktitle = {Proceedings of the 41st International Conference on Machine Learning},
52
+ pages = {27875--27889},
53
+ year = {2024},
54
+ editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
55
+ volume = {235},
56
+ series = {Proceedings of Machine Learning Research},
57
+ month = {21--27 Jul},
58
+ publisher = {PMLR},
59
+ pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/li24z/li24z.pdf},
60
+ url = {https://proceedings.mlr.press/v235/li24z.html}
61
+ }
62
+ ```
63
+
64
+ ```bibtex
65
+ @article{kaddour2026target,
66
+ title = {Target Policy Optimization},
67
+ author = {Kaddour, Jean},
68
+ journal = {arXiv preprint arXiv:2604.06159},
69
+ year = {2026}
70
+ }
71
+ ```
72
+
73
+ ```bibtex
74
+ @misc{qu2026listwisepolicyoptimizationgroupbased,
75
+ title = {Listwise Policy Optimization: Group-based RLVR as Target-Projection on the LLM Response Simplex},
76
+ author = {Yun Qu and Qi Wang and Yixiu Mao and Heming Zou and Yuhang Jiang and Yingyue Li and Wutong Xu and Lizhou Cai and Weijie Liu and Clive Bai and Kai Yang and Yangkun Chen and Saiyong Yang and Xiangyang Ji},
77
+ year = {2026},
78
+ eprint = {2605.06139},
79
+ archivePrefix = {arXiv},
80
+ primaryClass = {cs.LG},
81
+ url = {https://arxiv.org/abs/2605.06139},
82
+ }
83
+ ```
@@ -0,0 +1 @@
1
+ from dmpo.tpo import TPO
@@ -0,0 +1 @@
1
+
dmpo-0.0.2/dmpo/mpo.py ADDED
@@ -0,0 +1 @@
1
+
dmpo-0.0.2/dmpo/tpo.py ADDED
@@ -0,0 +1,399 @@
1
+ from collections import deque, namedtuple
2
+ from tqdm import tqdm
3
+
4
+ import torch
5
+ from torch.nn import Module
6
+ from torch.optim import Adam
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, einsum, reduce
10
+ from torch_einops_utils import masked_mean, lens_to_mask
11
+
12
+ from accelerate import Accelerator
13
+ from memmap_replay_buffer import ReplayBuffer
14
+ from discrete_continuous_embed_readout import ParameterlessReadout
15
+
16
+ # helpers
17
+
18
def exists(val):
    """Return True when `val` is anything other than None (falsy values such as 0 or '' still count)."""
    is_missing = val is None
    return not is_missing
20
+
21
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
23
+
24
def z_score(t, eps = 1e-8):
    """
    Standardize `t` to zero mean and unit (population) standard deviation.

    The mean and std are taken over all elements of `t`; `eps` is added to the
    std so a constant tensor maps to zeros instead of dividing by zero.
    """
    centered = t - t.mean()
    spread = t.std(unbiased = False)
    return centered / (spread + eps)
26
+
27
+ # tpo loss functions
28
+
29
# pair returned by `TPO.calculate_log_scores`: the per-episode log scores and the logits they were computed from
LogScoreReturn = namedtuple('LogScoreReturn', ['log_scores', 'logits'])
30
+
31
def tpo_target(log_scores, u, eta = 1.0):
    """
    Build the log of the target distribution over the last dimension.

    The current log scores are shifted by the utility `u` scaled by the
    temperature `eta`, then renormalized with a log-softmax.
    """
    tilted = log_scores + u / eta
    return F.log_softmax(tilted, dim = -1)
33
+
34
def tpo_forward_kl_loss(log_p, log_q):
    """
    Cross-entropy of the policy `log_p` under the target `log_q`.

    Computes -sum_k q_k * log p_k over the last dimension and averages the
    result over any leading dimensions. The entropy term of the target is not
    included in the returned value.
    """
    target_probs = log_q.exp()
    expected_log_p = (target_probs * log_p).sum(dim = -1)
    return -expected_log_p.mean()
37
+
38
def tpo_reverse_kl_loss(log_p, log_q):
    """
    KL(p || q) between the policy `log_p` and the target `log_q`.

    Sums p_k * (log p_k - log q_k) over the last dimension and averages the
    result over any leading dimensions.
    """
    policy_probs = log_p.exp()
    log_ratio = log_p - log_q
    kl = (policy_probs * log_ratio).sum(dim = -1)
    return kl.mean()
41
+
42
def tpo_js_loss(log_p, log_q, weight = 0.5, eps = 1e-10):
    """
    Weighted Jensen-Shannon style divergence between policy and target.

    The mixture is m = (1 - weight) * q + weight * p, and the loss is
    (1 - weight) * KL(q || m) + weight * KL(p || m). Each KL is summed over the
    last dimension and averaged over any leading dimensions. `eps` is the lower
    clamp applied to the mixture before taking its log.
    """
    p, q = log_p.exp(), log_q.exp()

    mixture = q.lerp(p, weight)
    log_mixture = mixture.clamp(min = eps).log()

    def kl_to_mixture(probs, log_probs):
        # KL(probs || mixture), reduced to a scalar
        return (probs * (log_probs - log_mixture)).sum(dim = -1).mean()

    kl_p_m = kl_to_mixture(p, log_p)
    kl_q_m = kl_to_mixture(q, log_q)

    return kl_q_m.lerp(kl_p_m, weight)
53
+
54
# registry mapping the `divergence` option name to the loss function that implements it
TPO_LOSS_FNS = {
    'forward_kl': tpo_forward_kl_loss,
    'reverse_kl': tpo_reverse_kl_loss,
    'js': tpo_js_loss,
}
59
+
60
+ # environments
61
+
62
class GymEnvironment(Module):
    """
    Rollout collector around an environment object.

    Calling the module with an actor clears the replay buffer, plays
    `group_size` episodes by sampling actions from the actor's logits through
    `readout`, records every step plus each episode's cumulative reward, and
    returns everything the buffer holds.
    """

    def __init__(
        self,
        env,
        readout,
        maybe_reshape_logits,
        action_fields,
        is_discrete,
        is_continuous,
        num_continuous = None,
        num_discrete_categories = None,
        num_discrete_logits = None,
        group_size = 64,
        max_timesteps = None,
        buffer_folder = './tpo_buffer',
        overwrite_buffer_on_start = True
    ):
        """
        Args:
            env: environment whose `reset()` returns a 2-tuple and whose `step(action)`
                returns a 5-tuple (state, reward, terminated, truncated, info), and which
                exposes `observation_space.shape`.
            readout: object providing `sample(logits)`.
            maybe_reshape_logits: callable applied to the raw actor output before sampling.
            action_fields: extra replay buffer field definitions for the action(s).
            is_discrete / is_continuous: which action kinds are present (both may be True).
            num_continuous: stored on the instance; not otherwise used in this class.
            num_discrete_categories: per-dimension category counts for multi-discrete
                actions, or None for a single discrete action.
            num_discrete_logits: accepted but not stored or used in this class.
            group_size: number of episodes collected per call.
            max_timesteps: passed to the replay buffer; defaults to `group_size * 1000`.
            buffer_folder: folder handed to the replay buffer.
            overwrite_buffer_on_start: passed to the replay buffer as `overwrite`.
        """
        super().__init__()
        self.env = env
        self.readout = readout

        self.is_discrete = is_discrete
        self.is_continuous = is_continuous
        self.num_continuous = num_continuous
        self.maybe_reshape_logits = maybe_reshape_logits

        self.num_discrete_categories = num_discrete_categories

        if exists(num_discrete_categories):
            categories = torch.tensor(num_discrete_categories)
            self.register_buffer('categories', categories)
            # place value of each dimension in a mixed-radix number: (1, c0, c0 * c1, ...)
            self.register_buffer('divisors', torch.cat((torch.tensor([1]), categories.cumprod(dim = 0)[:-1])))

        self.group_size = group_size

        # only the first axis of the observation shape is used
        obs_dim = int(env.observation_space.shape[0])
        max_timesteps = default(max_timesteps, group_size * 1000)

        # NOTE(review): the meaning of `max_timesteps` (per episode vs. total) is defined by ReplayBuffer - confirm
        self.buffer = ReplayBuffer(
            folder = buffer_folder,
            max_episodes = group_size,
            max_timesteps = max_timesteps,
            fields = dict(
                state = ('float', (obs_dim,)),
                **action_fields
            ),
            meta_fields = dict(
                cum_reward = 'float'
            ),
            circular = False,
            overwrite = overwrite_buffer_on_start
        )

    @property
    def is_multi_discrete(self):
        """True when per-dimension category counts were given at construction."""
        return exists(self.num_discrete_categories)

    def get_discrete_env_action(self, discrete_tensor):
        """
        Convert a sampled discrete action tensor into the value passed to the environment.

        Single discrete: the Python scalar. Multi-discrete: the value is decoded
        as a mixed-radix number into one index per dimension and returned as a
        numpy array.
        """
        if not self.is_multi_discrete:
            return discrete_tensor.item()
        return ((discrete_tensor // self.divisors.to(discrete_tensor.device)) % self.categories.to(discrete_tensor.device)).cpu().numpy()

    def action_to_env(self, action_tensor):
        """Convert the readout's sampled action (tensor, or (discrete, continuous) pair) into the environment's action."""
        if self.is_continuous and not self.is_discrete:
            return action_tensor.cpu().numpy()

        if self.is_discrete and not self.is_continuous:
            return self.get_discrete_env_action(action_tensor)

        # mixed action space: the readout sample is a (discrete, continuous) pair
        discrete_tensor, continuous_tensor = action_tensor
        return (self.get_discrete_env_action(discrete_tensor), continuous_tensor.cpu().numpy())

    def forward(self, actor):
        """
        Collect `group_size` episodes with `actor` and return the buffer contents.

        The buffer is cleared first, so each call returns only the episodes
        gathered during that call.
        """
        # tensors are created on whatever device the actor's first parameter lives on
        device = next(actor.parameters()).device
        self.buffer.clear()

        for k in range(self.group_size):
            state, _ = self.env.reset()
            episode_reward = 0.
            done = False

            while not done:
                state_t = torch.tensor(state, dtype = torch.float32, device = device)

                with torch.no_grad():
                    logits = self.maybe_reshape_logits(actor(state_t))
                    action_tensor = self.readout.sample(logits)

                action = self.action_to_env(action_tensor)

                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                # the state stored is the one the action was taken from, not the resulting state
                store_kwargs = dict(state = state)

                if self.is_discrete:
                    t = action_tensor[0] if self.is_continuous else action_tensor
                    # NOTE(review): `.item()` assumes the sampled discrete action is a single-element tensor, also in the multi-discrete case - confirm
                    store_kwargs['action_discrete'] = t.item()

                if self.is_continuous:
                    t = action_tensor[1] if self.is_discrete else action_tensor
                    store_kwargs['action_continuous'] = t.cpu().numpy()

                self.buffer.store(**store_kwargs)

                episode_reward += reward
                state = next_state

            # episode finished: record its total reward under episode index k, then move to the next episode
            self.buffer.store_meta_datapoint(k, 'cum_reward', episode_reward)
            self.buffer.advance_episode()

        return self.buffer.get_all_data(device = device)
174
+
175
+ # main class
176
+
177
class TPO(Module):
    """
    Policy optimization trainer.

    Each iteration rolls out a group of episodes, scores every episode by its
    length-normalized log-likelihood under the actor, builds a target
    distribution over the episodes of the group by tilting those scores with
    the standardized episode rewards, and then trains the actor for several
    epochs to match that target under the selected divergence.
    """

    def __init__(
        self,
        actor,
        environment,
        *,
        action_num_discrete = None,
        action_num_continuous = None,
        buffer_folder = './tpo_buffer',
        overwrite_buffer_on_start = True,
        max_timesteps = None,
        epochs = 4,
        group_size = 64,
        optim = None,
        optim_kwargs = None,
        lr = 3e-4,
        max_grad_norm = None,
        eta = 1.0,
        min_rewards_std = 1e-4,
        entropy_coef = 0.01,
        divergence = 'forward_kl',
        reward_moving_average_len = 20,
        cpu = False,
        on_result = None,
        **readout_kwargs
    ):
        """
        Args:
            actor: module mapping states to raw logits.
            environment: either a callable taking the actor and returning rollout data,
                or a non-callable environment object that gets wrapped in `GymEnvironment`.
            action_num_discrete: int, or tuple / list of ints for multi-discrete actions.
            action_num_continuous: number of continuous action dimensions.
            buffer_folder, overwrite_buffer_on_start, max_timesteps, group_size:
                forwarded to `GymEnvironment` when `environment` is wrapped.
            epochs: gradient steps taken on each collected group.
            optim: ready-made optimizer; when None an Adam optimizer over the actor is created.
            optim_kwargs: extra keyword arguments for that Adam optimizer (None means none).
            lr: learning rate for that Adam optimizer.
            max_grad_norm: gradient norm clipping threshold, or None to disable clipping.
            eta: temperature dividing the standardized rewards in the target.
            min_rewards_std: below this reward std the rewards are treated as uninformative.
            entropy_coef: weight of the entropy bonus subtracted from the loss.
            divergence: key into `TPO_LOSS_FNS`.
            reward_moving_average_len: max number of most recent episode rewards kept for logging.
            cpu: forwarded to `Accelerator`.
            on_result: optional callback `(avg_reward, pbar)` replacing the progress bar postfix.
            **readout_kwargs: forwarded to `ParameterlessReadout`.
        """
        super().__init__()

        self.has_discrete = exists(action_num_discrete)
        self.has_continuous = exists(action_num_continuous)

        assert self.has_discrete or self.has_continuous, 'must specify at least one of action_num_discrete or action_num_continuous'

        # readout

        readout_params = dict(**readout_kwargs)

        if self.has_discrete:
            readout_params['num_discrete'] = action_num_discrete

        if self.has_continuous:
            readout_params['num_continuous'] = action_num_continuous

        self.readout = ParameterlessReadout(**readout_params)

        # derive buffer field and action conversion from config

        action_fields = dict()
        num_discrete_categories = None
        self.num_discrete_logits = None

        if self.has_discrete:
            is_multi = isinstance(action_num_discrete, (tuple, list))
            action_fields['action_discrete'] = 'int'
            num_discrete_categories = tuple(action_num_discrete) if is_multi else None
            self.num_discrete_logits = sum(action_num_discrete) if is_multi else action_num_discrete

        if self.has_continuous:
            action_fields['action_continuous'] = ('float', (action_num_continuous,))

        # setup environment

        if not callable(environment):
            # `maybe_reshape_logits` is handed over as a bound method and only invoked
            # during rollouts, after the attributes it reads have been assigned below
            self.environment = GymEnvironment(
                environment,
                readout = self.readout,
                maybe_reshape_logits = self.maybe_reshape_logits,
                action_fields = action_fields,
                is_discrete = self.has_discrete,
                is_continuous = self.has_continuous,
                num_continuous = action_num_continuous,
                num_discrete_categories = num_discrete_categories,
                num_discrete_logits = self.num_discrete_logits,
                group_size = group_size,
                max_timesteps = max_timesteps,
                buffer_folder = buffer_folder,
                overwrite_buffer_on_start = overwrite_buffer_on_start
            )
        else:
            self.environment = environment

        # store refs

        self.num_continuous = action_num_continuous

        self.actor = actor

        self.accelerator = Accelerator(cpu = cpu)
        self.device = self.accelerator.device

        if exists(optim):
            self.optimizer = optim
        else:
            # a None default avoids sharing one dict object across every instantiation
            optim_kwargs = default(optim_kwargs, dict())
            self.optimizer = Adam(self.actor.parameters(), lr = lr, **optim_kwargs)

        self.actor, self.readout, self.optimizer = self.accelerator.prepare(
            self.actor, self.readout, self.optimizer
        )

        self.epochs = epochs
        self.eta = eta
        self.min_rewards_std = min_rewards_std
        self.max_grad_norm = max_grad_norm
        self.entropy_coef = entropy_coef
        self.reward_moving_average_len = reward_moving_average_len

        assert divergence in TPO_LOSS_FNS, f'divergence must be one of {list(TPO_LOSS_FNS.keys())}'
        self.tpo_loss_fn = TPO_LOSS_FNS[divergence]

        self.on_result = on_result

    def maybe_reshape_logits(self, logits):
        """
        Shape raw actor output into what the readout is given.

        Discrete only: returned unchanged. Continuous only: the last dimension is
        split into `(num_continuous, d)`. Mixed: the last dimension is split into
        the discrete logits and `num_continuous * 2` continuous values, returned
        as a `(discrete_logits, continuous_params)` pair.
        """
        if self.has_discrete and not self.has_continuous:
            return logits

        if self.has_continuous and not self.has_discrete:
            return rearrange(logits, '... (c d) -> ... c d', c = self.num_continuous)

        discrete_logits, continuous_logits = logits.split([self.num_discrete_logits, self.num_continuous * 2], dim = -1)
        continuous_params = rearrange(continuous_logits, '... (c d) -> ... c d', c = self.num_continuous)

        return (discrete_logits, continuous_params)

    def calculate_log_scores(self, states, actions, mask, episode_lens_float):
        """
        Score each episode by its length-normalized log-likelihood under the actor.

        The readout's unreduced loss is negated, summed over every non-batch
        dimension and divided by `episode_lens_float`.

        Returns:
            LogScoreReturn(log_scores, logits), with one log score per episode.
        """
        logits = self.maybe_reshape_logits(self.actor(states))

        neg_log_probs = self.readout.calculate_loss(
            logits,
            targets = actions,
            mask = mask,
            return_unreduced_loss = True
        )

        log_scores = reduce(-neg_log_probs, 'b ... -> b', 'sum')
        log_scores = log_scores / episode_lens_float

        return LogScoreReturn(log_scores, logits)

    def forward(
        self,
        num_iterations = 2000
    ):
        """
        Run `num_iterations` rounds of rollout collection followed by policy updates.

        The rollout data must provide `state`, the action(s) (`action`, or
        `action_discrete` / `action_continuous`), `cum_reward` or `reward`, and
        either `mask` or `episode_lens`. Nothing is returned; progress is reported
        through `on_result` or the progress bar.
        """
        device = self.device
        recent_rewards = deque(maxlen = self.reward_moving_average_len)
        pbar = tqdm(range(num_iterations), desc = 'tpo training')

        for _ in pbar:

            # get rollout

            data = self.environment(self.actor)

            # unpack data

            states = data['state']

            if 'action' in data:
                actions = data['action']
            elif self.has_discrete and self.has_continuous:
                actions = (data['action_discrete'], data['action_continuous'])
            elif self.has_discrete:
                actions = data['action_discrete']
            else:
                actions = data['action_continuous']

            rewards = data.get('cum_reward', data.get('reward'))
            assert exists(rewards), 'cum_reward or reward must be returned by environment'

            episode_lens = data.get('episode_lens')

            # log reward

            # the deque keeps only the most recent `reward_moving_average_len` episode rewards
            recent_rewards.extend(rewards.tolist())

            avg_reward = sum(recent_rewards) / max(1, len(recent_rewards))

            if exists(self.on_result):
                self.on_result(avg_reward, pbar)
            else:
                pbar.set_postfix(avg_reward = f'{avg_reward:.2f}')

            # calculate baseline and mask

            # with (near) identical rewards the z-score would only amplify noise,
            # so the utility is zeroed and the target equals the current distribution
            if rewards.std(unbiased = False) < self.min_rewards_std:
                u = torch.zeros_like(rewards)
            else:
                u = z_score(rewards)

            mask = data.get('mask')

            if not exists(mask):
                assert exists(episode_lens), 'episode_lens must be returned by environment if mask is not provided'
                mask = lens_to_mask(episode_lens, max_len = states.shape[1])

            mask = mask.to(device)

            # clamp so an empty episode cannot cause a division by zero
            episode_lens_float = mask.sum(dim = 1).clamp(min = 1.).float()

            # target distribution

            # computed once from the pre-update policy and held fixed for all epochs
            with torch.no_grad():
                out = self.calculate_log_scores(states, actions, mask, episode_lens_float)
                log_q = tpo_target(out.log_scores, u, self.eta)

            # train policy

            for _ in range(self.epochs):
                self.optimizer.zero_grad()

                out = self.calculate_log_scores(states, actions, mask, episode_lens_float)

                # normalize the scores over the episodes of the group
                log_p = F.log_softmax(out.log_scores, dim = -1)

                entropy = self.readout.entropy(out.logits)
                entropy = masked_mean(entropy, mask)

                loss = self.tpo_loss_fn(log_p, log_q)
                loss = loss - self.entropy_coef * entropy

                self.accelerator.backward(loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)

                self.optimizer.step()
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,74 @@
1
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "dmpo"
version = "0.0.2"
description = "Maximum a Posteriori Policy Optimization and Related Algorithms"
readme = "README.md"
requires-python = ">= 3.10"
license = { file = "LICENSE" }
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" },
]
keywords = [
    "artificial intelligence",
    "deep learning",
    "reinforcement learning",
    "mpo",
    "tpo",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3.10",
]
dependencies = [
    "accelerate",
    "discrete-continuous-embed-readout",
    "einops>=0.8.1",
    "memmap-replay-buffer>=0.1.4",
    "torch>=2.5",
    "torch-einops-utils>=0.1.2",
    "tqdm",
    "x-mlps-pytorch",
    "x-transformers",
]

[project.urls]
Homepage = "https://pypi.org/project/dmpo/"
Repository = "https://codeberg.org/lucidrains/dmpo"

[project.optional-dependencies]
# Declared so the `examples` extra exists; it currently pulls in nothing.
examples = []
test = [
    "pytest",
]

[tool.hatch.metadata]
# Permits direct (URL / path) references in dependency strings; none are declared at present.
allow-direct-references = true

[tool.hatch.build]
include = [
    "dmpo",
    "pyproject.toml",
    "README.md",
    "LICENSE",
]

[tool.hatch.build.targets.wheel]
packages = ["dmpo"]

[tool.pytest.ini_options]
# Lets tests import the in-tree `dmpo` package without installing it.
pythonpath = [
    ".",
]

[tool.rye]
managed = true
dev-dependencies = []