rectified-flow-pytorch 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,36 @@
1
+ # This workflow will upload a Python Package using Twine when a release is created
2
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3
+
4
+ # This workflow uses actions that are not certified by GitHub.
5
+ # They are provided by a third-party and are governed by
6
+ # separate terms of service, privacy policy, and support
7
+ # documentation.
8
+
9
+ name: Upload Python Package
10
+
11
+ on:
12
+ release:
13
+ types: [published]
14
+
15
+ jobs:
16
+ deploy:
17
+
18
+ runs-on: ubuntu-latest
19
+
20
+ steps:
21
+ - uses: actions/checkout@v2
22
+ - name: Set up Python
23
+ uses: actions/setup-python@v2
24
+ with:
25
+ python-version: '3.x'
26
+ - name: Install dependencies
27
+ run: |
28
+ python -m pip install --upgrade pip
29
+ pip install build
30
+ - name: Build package
31
+ run: python -m build
32
+ - name: Publish package
33
+ uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34
+ with:
35
+ user: __token__
36
+ password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,162 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Phil Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,138 @@
1
+ Metadata-Version: 2.3
2
+ Name: rectified-flow-pytorch
3
+ Version: 0.0.1
4
+ Summary: Rectified Flow in Pytorch
5
+ Project-URL: Homepage, https://pypi.org/project/rectified-flow-pytorch/
6
+ Project-URL: Repository, https://github.com/lucidrains/rectified-flow-pytorch
7
+ Author-email: Phil Wang <lucidrains@gmail.com>
8
+ License: MIT License
9
+
10
+ Copyright (c) 2024 Phil Wang
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Keywords: artificial intelligence,deep learning,rectified flow
31
+ Classifier: Development Status :: 4 - Beta
32
+ Classifier: Intended Audience :: Developers
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python :: 3.8
35
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
+ Requires-Python: >=3.8
37
+ Requires-Dist: einops>=0.8.0
38
+ Requires-Dist: scipy
39
+ Requires-Dist: torch>=2.0
40
+ Requires-Dist: torchdiffeq
41
+ Provides-Extra: examples
42
+ Description-Content-Type: text/markdown
43
+
44
+ ## Rectified Flow - Pytorch (wip)
45
+
46
+ Implementation of rectified flow and some of its follow-up research / improvements in PyTorch
47
+
48
+ ## Install
49
+
50
+ ```bash
51
+ $ pip install rectified-flow-pytorch
52
+ ```
53
+
54
+ ## Usage
55
+
56
+ ```python
57
+ import torch
58
+ from torch import nn
59
+
60
+ from rectified_flow_pytorch import RectifiedFlow
61
+
62
+ model = nn.Conv2d(3, 3, 1)
63
+
64
+ rectified_flow = RectifiedFlow(model, time_cond_kwarg = None)
65
+
66
+ images = torch.randn(1, 3, 256, 256)
67
+
68
+ loss = rectified_flow(images)
69
+ loss.backward()
70
+
71
+ sampled = rectified_flow.sample()
72
+ assert sampled.shape == images.shape
73
+ ```
74
+
75
+ For reflow as described in the paper
76
+
77
+ ```python
78
+ import torch
79
+ from torch import nn
80
+
81
+ from rectified_flow_pytorch import RectifiedFlow, Reflow
82
+
83
+ model = nn.Conv2d(3, 3, 1)
84
+
85
+ rectified_flow = RectifiedFlow(model, time_cond_kwarg = None)
86
+
87
+ images = torch.randn(1, 3, 256, 256)
88
+
89
+ loss = rectified_flow(images)
90
+ loss.backward()
91
+
92
+ # first train on many images
93
+
94
+ reflow = Reflow(rectified_flow)
95
+
96
+ reflow_loss = reflow()
97
+ reflow_loss.backward()
98
+
99
+ # then do the above in a loop many times
100
+
101
+ sampled = reflow.sample()
102
+ assert sampled.shape == images.shape
103
+ ```
104
+
105
+ ## Citations
106
+
107
+ ```bibtex
108
+ @article{Liu2022FlowSA,
109
+ title = {Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow},
110
+ author = {Xingchao Liu and Chengyue Gong and Qiang Liu},
111
+ journal = {ArXiv},
112
+ year = {2022},
113
+ volume = {abs/2209.03003},
114
+ url = {https://api.semanticscholar.org/CorpusID:252111177}
115
+ }
116
+ ```
117
+
118
+ ```bibtex
119
+ @article{Lee2024ImprovingTT,
120
+ title = {Improving the Training of Rectified Flows},
121
+ author = {Sangyun Lee and Zinan Lin and Giulia Fanti},
122
+ journal = {ArXiv},
123
+ year = {2024},
124
+ volume = {abs/2405.20320},
125
+ url = {https://api.semanticscholar.org/CorpusID:270123378}
126
+ }
127
+ ```
128
+
129
+ ```bibtex
130
+ @article{Esser2024ScalingRF,
131
+ title = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
132
+ author = {Patrick Esser and Sumith Kulal and A. Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
133
+ journal = {ArXiv},
134
+ year = {2024},
135
+ volume = {abs/2403.03206},
136
+ url = {https://api.semanticscholar.org/CorpusID:268247980}
137
+ }
138
+ ```
@@ -0,0 +1,95 @@
1
+ ## Rectified Flow - Pytorch (wip)
2
+
3
+ Implementation of rectified flow and some of its follow-up research / improvements in PyTorch
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ $ pip install rectified-flow-pytorch
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import torch
15
+ from torch import nn
16
+
17
+ from rectified_flow_pytorch import RectifiedFlow
18
+
19
+ model = nn.Conv2d(3, 3, 1)
20
+
21
+ rectified_flow = RectifiedFlow(model, time_cond_kwarg = None)
22
+
23
+ images = torch.randn(1, 3, 256, 256)
24
+
25
+ loss = rectified_flow(images)
26
+ loss.backward()
27
+
28
+ sampled = rectified_flow.sample()
29
+ assert sampled.shape == images.shape
30
+ ```
31
+
32
+ For reflow as described in the paper
33
+
34
+ ```python
35
+ import torch
36
+ from torch import nn
37
+
38
+ from rectified_flow_pytorch import RectifiedFlow, Reflow
39
+
40
+ model = nn.Conv2d(3, 3, 1)
41
+
42
+ rectified_flow = RectifiedFlow(model, time_cond_kwarg = None)
43
+
44
+ images = torch.randn(1, 3, 256, 256)
45
+
46
+ loss = rectified_flow(images)
47
+ loss.backward()
48
+
49
+ # first train on many images
50
+
51
+ reflow = Reflow(rectified_flow)
52
+
53
+ reflow_loss = reflow()
54
+ reflow_loss.backward()
55
+
56
+ # then do the above in a loop many times
57
+
58
+ sampled = reflow.sample()
59
+ assert sampled.shape == images.shape
60
+ ```
61
+
62
+ ## Citations
63
+
64
+ ```bibtex
65
+ @article{Liu2022FlowSA,
66
+ title = {Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow},
67
+ author = {Xingchao Liu and Chengyue Gong and Qiang Liu},
68
+ journal = {ArXiv},
69
+ year = {2022},
70
+ volume = {abs/2209.03003},
71
+ url = {https://api.semanticscholar.org/CorpusID:252111177}
72
+ }
73
+ ```
74
+
75
+ ```bibtex
76
+ @article{Lee2024ImprovingTT,
77
+ title = {Improving the Training of Rectified Flows},
78
+ author = {Sangyun Lee and Zinan Lin and Giulia Fanti},
79
+ journal = {ArXiv},
80
+ year = {2024},
81
+ volume = {abs/2405.20320},
82
+ url = {https://api.semanticscholar.org/CorpusID:270123378}
83
+ }
84
+ ```
85
+
86
+ ```bibtex
87
+ @article{Esser2024ScalingRF,
88
+ title = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
89
+ author = {Patrick Esser and Sumith Kulal and A. Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
90
+ journal = {ArXiv},
91
+ year = {2024},
92
+ volume = {abs/2403.03206},
93
+ url = {https://api.semanticscholar.org/CorpusID:268247980}
94
+ }
95
+ ```
@@ -0,0 +1,46 @@
1
+ [project]
2
+ name = "rectified-flow-pytorch"
3
+ version = "0.0.1"
4
+ description = "Rectified Flow in Pytorch"
5
+ authors = [
6
+ { name = "Phil Wang", email = "lucidrains@gmail.com" }
7
+ ]
8
+ readme = "README.md"
9
+ requires-python = ">= 3.8"
10
+ license = { file = "LICENSE" }
11
+ keywords = [
12
+ 'artificial intelligence',
13
+ 'deep learning',
14
+ 'rectified flow'
15
+ ]
16
+ classifiers=[
17
+ 'Development Status :: 4 - Beta',
18
+ 'Intended Audience :: Developers',
19
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
20
+ 'License :: OSI Approved :: MIT License',
21
+ 'Programming Language :: Python :: 3.8',
22
+ ]
23
+
24
+ dependencies = [
25
+ 'einops>=0.8.0',
26
+ 'scipy',
27
+ 'torch>=2.0',
28
+ 'torchdiffeq',
29
+ ]
30
+
31
+ [project.urls]
32
+ Homepage = "https://pypi.org/project/rectified-flow-pytorch/"
33
+ Repository = "https://github.com/lucidrains/rectified-flow-pytorch"
34
+
35
+ [project.optional-dependencies]
36
+ examples = []
37
+
38
+ [build-system]
39
+ requires = ["hatchling"]
40
+ build-backend = "hatchling.build"
41
+
42
+ [tool.hatch.metadata]
43
+ allow-direct-references = true
44
+
45
+ [tool.hatch.build.targets.wheel]
46
+ packages = ["rectified_flow_pytorch"]
@@ -0,0 +1 @@
1
+ from rectified_flow_pytorch.rectified_flow import RectifiedFlow, Reflow
@@ -0,0 +1,168 @@
1
+ from __future__ import annotations
2
+ from typing import Tuple
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ from torch.nn import Module
7
+ import torch.nn.functional as F
8
+
9
+ from torchdiffeq import odeint
10
+
11
+ # helpers
12
+
13
def exists(v):
    """Return True when *v* is any value other than the None sentinel."""
    return not (v is None)
15
+
16
def default(v, d):
    """Return *v* if it is not None, otherwise fall back to *d*."""
    if v is None:
        return d
    return v
18
+
19
+ # tensor helpers
20
+
21
def append_dims(t, ndims):
    """Return *t* with `ndims` singleton dimensions appended on the right.

    Equivalent to reshaping to (*t.shape, 1, 1, ..., 1); used to broadcast
    per-batch scalars against full data tensors.
    """
    for _ in range(ndims):
        t = t.unsqueeze(-1)
    return t
24
+
25
+ # main class
26
+
27
class RectifiedFlow(Module):
    """Rectified flow (Liu et al. 2022) training and sampling wrapper.

    Trains `model` to predict the straight-line flow `data - noise` at a
    uniformly random interpolation time, and samples by integrating the
    learned ODE from pure noise (t = 0) to data (t = 1) with torchdiffeq.
    """

    def __init__(
        self,
        model: Module,
        time_cond_kwarg: str | None = 'times',
        odeint_kwargs: dict | None = None,
        data_shape: Tuple[int, ...] | None = None,
    ):
        super().__init__()
        self.model = model

        # keyword argument name through which times are passed to the model,
        # or None if the model is not conditioned on time
        self.time_cond_kwarg = time_cond_kwarg

        # sampling

        # avoid a shared mutable default argument - fall back to the
        # previous hard-coded solver settings when none are given
        self.odeint_kwargs = default(odeint_kwargs, dict(
            atol = 1e-5,
            rtol = 1e-5,
            method = 'midpoint'
        ))

        # remembered on the first forward pass if not supplied here
        self.data_shape = data_shape

    @property
    def device(self):
        # infer device from the wrapped model's parameters
        return next(self.model.parameters()).device

    @torch.no_grad()
    def sample(
        self,
        batch_size = 1,
        steps = 16,
        noise = None,
        data_shape: Tuple[int, ...] | None = None,
        **model_kwargs
    ):
        """Integrate the learned ODE from noise to data.

        Returns a tensor of shape (batch_size, *data_shape).
        """
        was_training = self.training
        self.eval()

        data_shape = default(data_shape, self.data_shape)
        assert exists(data_shape), 'you need to either pass in a `data_shape` or have trained at least with one forward'

        def ode_fn(t, x):
            time_kwarg = self.time_cond_kwarg

            if exists(time_kwarg):
                model_kwargs.update(**{time_kwarg: t})

            return self.model(x, **model_kwargs)

        # start with random gaussian noise - y0
        # fix: allocate on the model's device (previously defaulted to cpu,
        # which broke sampling whenever the model lived on an accelerator)

        noise = default(noise, torch.randn((batch_size, *data_shape), device = self.device))

        # time steps

        times = torch.linspace(0., 1., steps, device = self.device)

        # ode

        trajectory = odeint(ode_fn, noise, times, **self.odeint_kwargs)

        sampled_data = trajectory[-1]

        self.train(was_training)
        return sampled_data

    def forward(
        self,
        data,
        noise = None,
        **model_kwargs
    ):
        """Return the rectified-flow mse loss for a batch of `data`."""

        batch, *data_shape = data.shape

        # remember the data shape so `sample` can later be called without one
        self.data_shape = default(self.data_shape, data_shape)

        # x0 - gaussian noise, x1 - data

        noise = default(noise, torch.randn_like(data))

        # times, and times with dimension padding on right

        times = torch.rand(batch, device = self.device)
        padded_times = append_dims(times, data.ndim - 1)

        # Algorithm 2 in paper
        # linear interpolation of noise with data using random times
        # x1 * t + x0 * (1 - t) - so from noise (time = 0) to data (time = 1.)

        noised = padded_times * data + (1. - padded_times) * noise

        # prepare maybe time conditioning for model

        time_kwarg = self.time_cond_kwarg

        if exists(time_kwarg):
            model_kwargs.update(**{time_kwarg: times})

        # the model predicts the flow from the noised data

        flow = data - noise
        pred_flow = self.model(noised, **model_kwargs)

        loss = F.mse_loss(pred_flow, flow)

        return loss
132
+
133
+ # reflow wrapper
134
+
135
class Reflow(Module):
    """Reflow wrapper that straightens a trained rectified flow.

    Keeps a frozen copy of the trained flow as a teacher: couplings are
    formed as (noise, teacher_sample(noise)) and the live model is trained
    on them, per the reflow procedure of the rectified flow paper.
    """

    def __init__(
        self,
        rectified_flow: RectifiedFlow,
        *,
        batch_size = 16,
    ):
        super().__init__()

        data_shape = rectified_flow.data_shape
        assert exists(data_shape), '`data_shape` must be defined in RectifiedFlow'

        self.batch_size = batch_size
        self.data_shape = data_shape

        self.model = rectified_flow

        # frozen snapshot of the flow, used only for generating couplings -
        # disable grads so its parameters can never be trained by mistake
        self.frozen_model = deepcopy(rectified_flow)
        for param in self.frozen_model.parameters():
            param.requires_grad_(False)

    def parameters(self, recurse = True):
        # expose only the trainable flow's parameters - omit frozen teacher
        # (accepts `recurse` for compatibility with nn.Module.parameters)
        return self.model.parameters(recurse = recurse)

    def sample(self, *args, **kwargs):
        """Sample from the live (being-straightened) model."""
        return self.model.sample(*args, **kwargs)

    def forward(self):
        """Return the reflow loss for one batch of synthetic couplings."""

        # fix: allocate noise on the model's device (previously defaulted to
        # cpu, breaking reflow whenever the flow lives on an accelerator)
        noise = torch.randn((self.batch_size, *self.data_shape), device = self.model.device)

        sampled_output = self.frozen_model.sample(noise = noise)

        # the coupling in the paper is (noise, sampled_output)

        loss = self.model(sampled_output, noise = noise)

        return loss