diffusiongym 2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. diffusiongym-2.0/.editorconfig +16 -0
  2. diffusiongym-2.0/.envrc +2 -0
  3. diffusiongym-2.0/.github/workflows/docs.yml +22 -0
  4. diffusiongym-2.0/.github/workflows/test.yml +32 -0
  5. diffusiongym-2.0/.gitignore +188 -0
  6. diffusiongym-2.0/.pre-commit-config.yaml +27 -0
  7. diffusiongym-2.0/LICENSE +21 -0
  8. diffusiongym-2.0/PKG-INFO +98 -0
  9. diffusiongym-2.0/README.md +46 -0
  10. diffusiongym-2.0/diffusiongym/__init__.py +59 -0
  11. diffusiongym-2.0/diffusiongym/base_models/__init__.py +6 -0
  12. diffusiongym-2.0/diffusiongym/base_models/base.py +163 -0
  13. diffusiongym-2.0/diffusiongym/base_models/one_dim_gmm.py +202 -0
  14. diffusiongym-2.0/diffusiongym/environments/__init__.py +16 -0
  15. diffusiongym-2.0/diffusiongym/environments/base.py +383 -0
  16. diffusiongym-2.0/diffusiongym/environments/endpoint.py +90 -0
  17. diffusiongym-2.0/diffusiongym/environments/epsilon.py +91 -0
  18. diffusiongym-2.0/diffusiongym/environments/score.py +90 -0
  19. diffusiongym-2.0/diffusiongym/environments/velocity.py +95 -0
  20. diffusiongym-2.0/diffusiongym/images/__init__.py +17 -0
  21. diffusiongym-2.0/diffusiongym/images/base_models/cifar.py +102 -0
  22. diffusiongym-2.0/diffusiongym/images/base_models/dit.py +171 -0
  23. diffusiongym-2.0/diffusiongym/images/base_models/refl_data.json +40002 -0
  24. diffusiongym-2.0/diffusiongym/images/base_models/stable_diffusion.py +358 -0
  25. diffusiongym-2.0/diffusiongym/images/rewards/__init__.py +6 -0
  26. diffusiongym-2.0/diffusiongym/images/rewards/aesthetic.py +75 -0
  27. diffusiongym-2.0/diffusiongym/images/rewards/compression.py +84 -0
  28. diffusiongym-2.0/diffusiongym/make.py +136 -0
  29. diffusiongym-2.0/diffusiongym/molecules/__init__.py +39 -0
  30. diffusiongym-2.0/diffusiongym/molecules/flowmol.py +301 -0
  31. diffusiongym-2.0/diffusiongym/molecules/rewards/__init__.py +27 -0
  32. diffusiongym-2.0/diffusiongym/molecules/rewards/qed.py +33 -0
  33. diffusiongym-2.0/diffusiongym/molecules/rewards/utils.py +117 -0
  34. diffusiongym-2.0/diffusiongym/molecules/rewards/validity.py +30 -0
  35. diffusiongym-2.0/diffusiongym/molecules/rewards/xtb.py +236 -0
  36. diffusiongym-2.0/diffusiongym/molecules/types.py +220 -0
  37. diffusiongym-2.0/diffusiongym/py.typed +0 -0
  38. diffusiongym-2.0/diffusiongym/registry.py +171 -0
  39. diffusiongym-2.0/diffusiongym/rewards/__init__.py +11 -0
  40. diffusiongym-2.0/diffusiongym/rewards/base.py +26 -0
  41. diffusiongym-2.0/diffusiongym/rewards/one_dim.py +34 -0
  42. diffusiongym-2.0/diffusiongym/schedulers/__init__.py +15 -0
  43. diffusiongym-2.0/diffusiongym/schedulers/base.py +199 -0
  44. diffusiongym-2.0/diffusiongym/schedulers/noise_schedules.py +24 -0
  45. diffusiongym-2.0/diffusiongym/schedulers/schedulers.py +93 -0
  46. diffusiongym-2.0/diffusiongym/types.py +241 -0
  47. diffusiongym-2.0/diffusiongym/utils.py +258 -0
  48. diffusiongym-2.0/docs/Makefile +20 -0
  49. diffusiongym-2.0/docs/_static/teaser.gif +0 -0
  50. diffusiongym-2.0/docs/_templates/class.rst +8 -0
  51. diffusiongym-2.0/docs/api/base_models.rst +11 -0
  52. diffusiongym-2.0/docs/api/environments.rst +15 -0
  53. diffusiongym-2.0/docs/api/images.rst +18 -0
  54. diffusiongym-2.0/docs/api/molecules.rst +23 -0
  55. diffusiongym-2.0/docs/api/rewards.rst +11 -0
  56. diffusiongym-2.0/docs/api/schedulers.rst +16 -0
  57. diffusiongym-2.0/docs/api/types.rst +14 -0
  58. diffusiongym-2.0/docs/api.rst +14 -0
  59. diffusiongym-2.0/docs/changelog/includes/1.11.rst +8 -0
  60. diffusiongym-2.0/docs/changelog/includes/1.12.rst +6 -0
  61. diffusiongym-2.0/docs/changelog/includes/1.13.rst +7 -0
  62. diffusiongym-2.0/docs/changelog/includes/1.3.rst +4 -0
  63. diffusiongym-2.0/docs/changelog/includes/1.6.rst +7 -0
  64. diffusiongym-2.0/docs/changelog/includes/1.7.rst +4 -0
  65. diffusiongym-2.0/docs/changelog/includes/1.8.rst +4 -0
  66. diffusiongym-2.0/docs/changelog/includes/1.9.rst +7 -0
  67. diffusiongym-2.0/docs/changelog/includes/2.0.rst +8 -0
  68. diffusiongym-2.0/docs/changelog/index.rst +28 -0
  69. diffusiongym-2.0/docs/conf.py +106 -0
  70. diffusiongym-2.0/docs/index.rst +58 -0
  71. diffusiongym-2.0/docs/make.bat +35 -0
  72. diffusiongym-2.0/docs/math.rst +146 -0
  73. diffusiongym-2.0/docs/policies.rst +29 -0
  74. diffusiongym-2.0/docs/quickstart.rst +77 -0
  75. diffusiongym-2.0/docs/registries.rst +44 -0
  76. diffusiongym-2.0/docs/stable_diffusion.rst +181 -0
  77. diffusiongym-2.0/pixi.lock +4087 -0
  78. diffusiongym-2.0/pyproject.toml +152 -0
@@ -0,0 +1,16 @@
1
+ # http://editorconfig.org/#file-format-details
2
+ root = true
3
+
4
+ [*]
5
+ charset = utf-8
6
+ end_of_line = lf
7
+ indent_size = 4
8
+ indent_style = space
9
+ insert_final_newline = true
10
+ trim_trailing_whitespace = true
11
+
12
+ [*.md]
13
+ trim_trailing_whitespace = false
14
+
15
+ [Makefile]
16
+ indent_style = tab
@@ -0,0 +1,2 @@
1
+ watch_file pixi.lock
2
+ eval "$(pixi shell-hook -e dev)"
@@ -0,0 +1,22 @@
1
+ name: Deploy Docs
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+
7
+ permissions:
8
+ contents: read
9
+ pages: write
10
+ id-token: write
11
+
12
+ jobs:
13
+ deploy:
14
+ runs-on: ubuntu-latest
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+ - uses: prefix-dev/setup-pixi@v0.8.1
18
+ - run: pixi run -e dev sphinx-build -b html docs docs/_build/html
19
+ - uses: actions/upload-pages-artifact@v3
20
+ with:
21
+ path: docs/_build/html
22
+ - uses: actions/deploy-pages@v4
@@ -0,0 +1,32 @@
1
+ name: Test
2
+
3
+ on:
4
+ pull_request: {}
5
+ push:
6
+ branches: master
7
+
8
+ jobs:
9
+ test:
10
+ strategy:
11
+ matrix:
12
+ python-version: ['3.13']
13
+ os: [ubuntu-latest]
14
+
15
+ name: Python ${{ matrix.os }} ${{ matrix.python-version }}
16
+ runs-on: ${{ matrix.os }}
17
+
18
+ steps:
19
+ - uses: actions/checkout@v4
20
+
21
+ - uses: actions/setup-python@v5
22
+ with:
23
+ python-version: ${{ matrix.python-version }}
24
+
25
+ - uses: prefix-dev/setup-pixi@v0.8.8
26
+ with:
27
+ pixi-version: v0.48.0
28
+ cache: true
29
+ environments: dev
30
+
31
+ - run: pixi run fmt
32
+ - run: pixi run lint
@@ -0,0 +1,188 @@
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
3
+
4
+ ### Python ###
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
165
+
166
+ ### Python Patch ###
167
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
168
+ poetry.toml
169
+
170
+ # ruff
171
+ .ruff_cache/
172
+
173
+ # LSP config files
174
+ pyrightconfig.json
175
+
176
+ # End of https://www.toptal.com/developers/gitignore/api/python
177
+
178
+ # Pixi
179
+ .pixi
180
+ *.ipynb
181
+
182
+ docs/api/generated/
183
+ .DS_Store
184
+ *.png
185
+ jobs/
186
+ slurm-*
187
+ *.pt
188
+ scripts/
@@ -0,0 +1,27 @@
1
+ exclude: '.pixi/'
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v5.0.0 # this is optional, use `pre-commit autoupdate` to get the latest rev!
5
+ hooks:
6
+ - id: check-yaml
7
+ - id: check-toml
8
+ - id: end-of-file-fixer
9
+ - id: trailing-whitespace
10
+
11
+ - repo: local
12
+ hooks:
13
+ - id: ruff
14
+ name: ruff-format
15
+ stages: [pre-commit, pre-push]
16
+ language: system
17
+ entry: pixi run fmt
18
+ types: [python]
19
+ pass_filenames: false
20
+
21
+ - id: ruff
22
+ name: ruff-check
23
+ stages: [pre-commit, pre-push]
24
+ language: system
25
+ entry: pixi run lint
26
+ types: [python]
27
+ pass_filenames: false
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Cristian Perez Jensen
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,98 @@
1
+ Metadata-Version: 2.4
2
+ Name: diffusiongym
3
+ Version: 2.0
4
+ Summary: Diffusion Gym
5
+ Project-URL: Homepage, https://github.com/cristianpjensen/diffusiongym
6
+ Project-URL: Issues, https://github.com/cristianpjensen/diffusiongym/issues
7
+ Author: Cristian Perez Jensen
8
+ License: MIT License
9
+
10
+ Copyright (c) 2025 Cristian Perez Jensen
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Requires-Python: <3.12,>=3.8
31
+ Requires-Dist: datasets==2.11.0
32
+ Requires-Dist: dgl==2.4
33
+ Requires-Dist: diffusers==0.33.1
34
+ Requires-Dist: meeko==0.6.1
35
+ Requires-Dist: numpy<2,>=1.24
36
+ Requires-Dist: open-clip-torch<3,>=2.26
37
+ Requires-Dist: pandas<3,>=2
38
+ Requires-Dist: polars<2,>=1.34.0
39
+ Requires-Dist: prody<3,>=2.4
40
+ Requires-Dist: pyarrow==11.0.0
41
+ Requires-Dist: pydantic<3,>=1.10
42
+ Requires-Dist: pyyaml<7,>=6
43
+ Requires-Dist: rdkit-stubs<0.9,>=0.6
44
+ Requires-Dist: rdkit<2025,>=2023.9.4
45
+ Requires-Dist: torch==2.3.1
46
+ Requires-Dist: torchdata<0.9,>=0.7
47
+ Requires-Dist: torchvision==0.18.1
48
+ Requires-Dist: tqdm<5,>=4.66
49
+ Requires-Dist: transformers>=4.36
50
+ Requires-Dist: typing-extensions<5,>=4.9
51
+ Description-Content-Type: text/markdown
52
+
53
+ # Diffusion Gym
54
+
55
+ <div align="center">
56
+ <img src="docs/_static/teaser.gif" width="100%" />
57
+ </div>
58
+
59
+ <p align="center">
60
+ <a href="https://github.com/cristianpjensen/diffusiongym/blob/master/LICENSE"><img alt="License" src="https://img.shields.io/github/license/cristianpjensen/diffusiongym"></a>
61
+ <a href="https://github.com/astral-sh/ruff"><img alt="Code style: ruff" src="https://img.shields.io/badge/code%20style-ruff-000000.svg"></a>
62
+ </p>
63
+
64
+ `diffusiongym` is a library for reward adaptation of any pre-trained flow model on any data modality.
65
+
66
+ ## Installation
67
+
68
+ In order to install *diffusiongym*, execute the following command:
69
+ ```bash
70
+ pip install diffusiongym
71
+ ```
72
+
73
+ *diffusiongym* requires PyTorch 2.3.1 and pins several other dependencies to exact versions. Please open an issue if
74
+ installation with the above command fails.
75
+
76
+ Molecule environments depend on [FlowMol](https://github.com/cristianpjensen/FlowMol),
77
+ which currently needs to be installed manually:
78
+ ```bash
79
+ pip install git+https://github.com/cristianpjensen/FlowMol.git@8f4c98cbe68111e4e63480b250d925b6d960d3bc
80
+ ```
81
+
82
+ Some image rewards depend on the clip package, which needs to be installed manually as well:
83
+ ```bash
84
+ pip install git+https://github.com/openai/CLIP.git
85
+ ```
86
+
87
+ ## High-level overview
88
+
89
+ Diffusion and flow models are largely agnostic to their data modality. They only require that the underlying data type supports a small set of operations. Building on this idea, *diffusiongym* is designed to be fully modular. You only need to provide the following:
90
+ * Data type `YourDataType` that implements `DDProtocol`, which defines some functions necessary for interacting with it as a flow model.
91
+ * Base model `BaseModel[YourDataType]`, which defines the scheduler, how to sample $p_0$, how to compute the forward pass, and how to preprocess and postprocess data.
92
+ * Reward function `Reward[YourDataType]`.
93
+
94
+ Once these are defined, you can sample from the flow model and apply reward adaptation methods, such as Value Matching.
95
+
96
+ ## Documentation
97
+
98
+ Much more information can be found in [the documentation](https://cristianpjensen.github.io/diffusiongym/), including tutorials and API references.
@@ -0,0 +1,46 @@
1
+ # Diffusion Gym
2
+
3
+ <div align="center">
4
+ <img src="docs/_static/teaser.gif" width="100%" />
5
+ </div>
6
+
7
+ <p align="center">
8
+ <a href="https://github.com/cristianpjensen/diffusiongym/blob/master/LICENSE"><img alt="License" src="https://img.shields.io/github/license/cristianpjensen/diffusiongym"></a>
9
+ <a href="https://github.com/astral-sh/ruff"><img alt="Code style: ruff" src="https://img.shields.io/badge/code%20style-ruff-000000.svg"></a>
10
+ </p>
11
+
12
+ `diffusiongym` is a library for reward adaptation of any pre-trained flow model on any data modality.
13
+
14
+ ## Installation
15
+
16
+ In order to install *diffusiongym*, execute the following command:
17
+ ```bash
18
+ pip install diffusiongym
19
+ ```
20
+
21
+ *diffusiongym* requires PyTorch 2.3.1 and pins several other dependencies to exact versions. Please open an issue if
22
+ installation with the above command fails.
23
+
24
+ Molecule environments depend on [FlowMol](https://github.com/cristianpjensen/FlowMol),
25
+ which currently needs to be installed manually:
26
+ ```bash
27
+ pip install git+https://github.com/cristianpjensen/FlowMol.git@8f4c98cbe68111e4e63480b250d925b6d960d3bc
28
+ ```
29
+
30
+ Some image rewards depend on the clip package, which needs to be installed manually as well:
31
+ ```bash
32
+ pip install git+https://github.com/openai/CLIP.git
33
+ ```
34
+
35
+ ## High-level overview
36
+
37
+ Diffusion and flow models are largely agnostic to their data modality. They only require that the underlying data type supports a small set of operations. Building on this idea, *diffusiongym* is designed to be fully modular. You only need to provide the following:
38
+ * Data type `YourDataType` that implements `DDProtocol`, which defines some functions necessary for interacting with it as a flow model.
39
+ * Base model `BaseModel[YourDataType]`, which defines the scheduler, how to sample $p_0$, how to compute the forward pass, and how to preprocess and postprocess data.
40
+ * Reward function `Reward[YourDataType]`.
41
+
42
+ Once these are defined, you can sample from the flow model and apply reward adaptation methods, such as Value Matching.
43
+
44
+ ## Documentation
45
+
46
+ Much more information can be found in [the documentation](https://cristianpjensen.github.io/diffusiongym/), including tutorials and API references.
@@ -0,0 +1,59 @@
1
+ """Diffusion Gym package."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
+ try:
6
+ __version__ = version("diffusiongym")
7
+ except PackageNotFoundError:
8
+ __version__ = "0.0.0"
9
+
10
+ from diffusiongym.base_models import BaseModel
11
+ from diffusiongym.environments import (
12
+ EndpointEnvironment,
13
+ Environment,
14
+ EpsilonEnvironment,
15
+ Sample,
16
+ ScoreEnvironment,
17
+ VelocityEnvironment,
18
+ )
19
+ from diffusiongym.make import construct_env, make
20
+ from diffusiongym.registry import base_model_registry, reward_registry
21
+ from diffusiongym.rewards import DummyReward, Reward
22
+ from diffusiongym.schedulers import (
23
+ ConstantNoiseSchedule,
24
+ CosineScheduler,
25
+ DiffusionScheduler,
26
+ MemorylessNoiseSchedule,
27
+ NoiseSchedule,
28
+ OptimalTransportScheduler,
29
+ Scheduler,
30
+ )
31
+ from diffusiongym.types import D, DDMixin, DDTensor
32
+ from diffusiongym.utils import train_base_model
33
+
34
+ __all__ = [
35
+ "BaseModel",
36
+ "ConstantNoiseSchedule",
37
+ "CosineScheduler",
38
+ "D",
39
+ "DDMixin",
40
+ "DDTensor",
41
+ "DiffusionScheduler",
42
+ "DummyReward",
43
+ "EndpointEnvironment",
44
+ "Environment",
45
+ "EpsilonEnvironment",
46
+ "MemorylessNoiseSchedule",
47
+ "NoiseSchedule",
48
+ "OptimalTransportScheduler",
49
+ "Reward",
50
+ "Sample",
51
+ "Scheduler",
52
+ "ScoreEnvironment",
53
+ "VelocityEnvironment",
54
+ "base_model_registry",
55
+ "construct_env",
56
+ "make",
57
+ "reward_registry",
58
+ "train_base_model",
59
+ ]
@@ -0,0 +1,6 @@
1
+ """Base models for flow matching and diffusion."""
2
+
3
+ from .base import BaseModel
4
+ from .one_dim_gmm import OneDimensionalBaseModel
5
+
6
+ __all__ = ["BaseModel", "OneDimensionalBaseModel"]
@@ -0,0 +1,163 @@
1
+ """Abstract base class for base models used in flow matching and diffusion models."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Generic, Literal, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from diffusiongym.schedulers import Scheduler
10
+ from diffusiongym.types import D
11
+
12
+ OutputType = Literal["epsilon", "endpoint", "velocity", "score"]
13
+
14
+
15
class BaseModel(ABC, nn.Module, Generic[D]):
    """Abstract base class for base models used in flow matching and diffusion.

    Subclasses must implement ``scheduler``, ``sample_p0`` and ``forward``, and must set
    ``output_type`` so that ``train_loss`` can pick the matching regression target.
    ``preprocess`` and ``postprocess`` default to the identity and may be overridden.
    """

    # Quantity predicted by ``forward``; selects the regression target in ``train_loss``.
    output_type: OutputType

    def __init__(self, device: Optional[torch.device] = None):
        """Initialize the module and remember the target device.

        Parameters
        ----------
        device : Optional[torch.device], default=None
            Device associated with this model. If None, the CPU device is used.
        """
        super().__init__()

        # Only the reference is stored here; nothing is moved to the device.
        self.device = torch.device("cpu") if device is None else device

    @property
    @abstractmethod
    def scheduler(self) -> Scheduler[D]:
        """Base model-dependent scheduler used for sampling."""

    @abstractmethod
    def sample_p0(self, n: int, **kwargs: Any) -> tuple[D, dict[str, Any]]:
        """Sample n data points from the base distribution p0.

        Parameters
        ----------
        n : int
            Number of samples to draw.
        **kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        samples : D
            Samples from the base distribution p0.
        kwargs : dict
            Additional keyword arguments.
        """

    @abstractmethod
    def forward(self, x: D, t: torch.Tensor, **kwargs: Any) -> D:
        """Forward pass of the base model.

        Parameters
        ----------
        x : D
            Input data.
        t : torch.Tensor, shape (n,)
            Time steps, values in [0, 1].
        **kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        output : D
            Output of the model.
        """

    def preprocess(self, x: D, **kwargs: Any) -> tuple[D, dict[str, Any]]:
        """Preprocess data and keyword arguments for the base model.

        The default implementation returns its inputs unchanged.

        Parameters
        ----------
        x : D
            Input data to preprocess.
        **kwargs : dict
            Additional keyword arguments to preprocess.

        Returns
        -------
        output : D
            Preprocessed data.
        kwargs : dict
            Preprocessed keyword arguments.
        """
        return x, kwargs

    def postprocess(self, x: D) -> D:
        """Postprocess samples x_1 (e.g., decode with VAE).

        The default implementation returns its input unchanged.

        Parameters
        ----------
        x : D
            Input data to postprocess.

        Returns
        -------
        output : D
            Postprocessed output.
        """
        return x

    def train_loss(
        self,
        x1: D,
        xt: Optional[D] = None,
        t: Optional[torch.Tensor] = None,
        pred: Optional[D] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Compute loss for a single batch training step.

        The loss is the squared error between the model prediction and the target selected
        by ``output_type``, reduced with ``aggregate("mean")``.

        Parameters
        ----------
        x1 : D
            Target data points.
        xt : Optional[D], default=None
            Noisy data points at time t. If None, will be sampled.
        t : Optional[torch.Tensor], shape (len(x1),), default=None
            Time steps. If None, will be sampled uniformly with ``torch.rand``.
        pred : Optional[D], default=None
            Model predictions. If None, will be computed by the model.
        **kwargs : dict
            Keyword arguments forwarded to ``forward`` when ``pred`` is None.

        Returns
        -------
        loss : torch.Tensor, shape (len(x1),)
            Computed loss for the training step.

        Raises
        ------
        ValueError
            If ``output_type`` is not one of the supported output types.
        """
        batch_size = len(x1)

        if t is None:
            t = torch.rand(batch_size, device=x1.device)

        assert t.shape == (batch_size,), (
            f"Expected t to have shape ({batch_size},), got {tuple(t.shape)}"
        )

        alpha = self.scheduler.alpha(x1, t)
        beta = self.scheduler.beta(x1, t)

        if xt is None:
            # Draw fresh noise and build the noisy point x_t = alpha * x_1 + beta * x_0.
            x0 = x1.randn_like()
            xt = alpha * x1 + beta * x0
        else:
            assert len(xt) == batch_size, (
                f"Expected xt to have length {batch_size}, got {len(xt)}"
            )
            # Invert x_t = alpha * x_1 + beta * x_0 to recover the noise.
            # NOTE(review): this divides by beta; presumably callers avoid time steps where
            # the scheduler's beta is zero -- confirm against the scheduler in use.
            x0 = (xt - alpha * x1) / beta

        if pred is None:
            pred = self.forward(xt, t, **kwargs)

        if self.output_type == "velocity":
            alpha_dot = self.scheduler.alpha_dot(x1, t)
            beta_dot = self.scheduler.beta_dot(x1, t)
            target = alpha_dot * x1 + beta_dot * x0
        elif self.output_type == "score":
            target = -x0 / beta
        elif self.output_type == "endpoint":
            target = x1
        elif self.output_type == "epsilon":
            target = x0
        else:
            raise ValueError(f"Unknown output type: {self.output_type}")

        # ``aggregate`` is provided by the data type D (defined outside this module).
        return ((pred - target) ** 2).aggregate("mean")