ema-pytorch 0.3.1__tar.gz → 0.6.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/PKG-INFO +2 -3
- ema_pytorch-0.6.3/README.md +109 -0
- ema_pytorch-0.6.3/ema_pytorch/__init__.py +6 -0
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch/ema_pytorch.py +109 -23
- ema_pytorch-0.6.3/ema_pytorch/post_hoc_ema.py +420 -0
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch.egg-info/PKG-INFO +2 -3
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch.egg-info/SOURCES.txt +1 -0
- ema_pytorch-0.6.3/ema_pytorch.egg-info/requires.txt +1 -0
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/setup.py +2 -3
- ema-pytorch-0.3.1/README.md +0 -54
- ema-pytorch-0.3.1/ema_pytorch/__init__.py +0 -1
- ema-pytorch-0.3.1/ema_pytorch.egg-info/requires.txt +0 -2
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/LICENSE +0 -0
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch.egg-info/dependency_links.txt +0 -0
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch.egg-info/top_level.txt +0 -0
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/setup.cfg +0 -0
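For orientation before the hunks below: this release drops the beartype dependency, raises the torch floor from 1.6 to 2.0, and adds a post-hoc (Karras et al.) EMA implementation alongside the existing wrapper. A minimal sketch of the resulting public surface, inferred from the new README rather than from the rewritten `__init__.py` (whose +6 lines are not shown in this diff):

```python
# a sketch only; import names are taken from the new README's examples
import torch
from ema_pytorch import EMA, PostHocEMA

net = torch.nn.Linear(512, 512)

ema = EMA(net, beta = 0.9999, update_every = 10)     # classic EMA wrapper, same call pattern as 0.3.x
emas = PostHocEMA(net, sigma_rels = (0.05, 0.3))     # new in this release: multiple Karras EMAs plus checkpointing
```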
{ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ema-pytorch
-Version: 0.3.1
+Version: 0.6.3
 Summary: Easy way to keep track of exponential moving average version of your pytorch module
 Home-page: https://github.com/lucidrains/ema-pytorch
 Author: Phil Wang
@@ -14,5 +14,4 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: beartype
-Requires-Dist: torch>=1.6
+Requires-Dist: torch>=2.0
ema_pytorch-0.6.3/README.md (new file)

@@ -0,0 +1,109 @@
+## EMA - Pytorch
+
+A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
+
+## Install
+
+```bash
+$ pip install ema-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from ema_pytorch import EMA
+
+# your neural network as a pytorch module
+
+net = torch.nn.Linear(512, 512)
+
+# wrap your neural network, specify the decay (beta)
+
+ema = EMA(
+    net,
+    beta = 0.9999,              # exponential moving average factor
+    update_after_step = 100,    # only after this number of .update() calls will it start updating
+    update_every = 10,          # how often to actually update, to save on compute (updates every 10th .update() call)
+)
+
+# mutate your network, with SGD or otherwise
+
+with torch.no_grad():
+    net.weight.copy_(torch.randn_like(net.weight))
+    net.bias.copy_(torch.randn_like(net.bias))
+
+# you will call the update function on your moving average wrapper
+
+ema.update()
+
+# then, later on, you can invoke the EMA model the same way as your network
+
+data = torch.randn(1, 512)
+
+output = net(data)
+ema_output = ema(data)
+
+# if you want to save your ema model, it is recommended you save the entire wrapper
+# as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now)
+# however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
+```
+
+In order to use the post-hoc synthesized EMA, proposed by Karras et al. in <a href="https://arxiv.org/abs/2312.02696">a recent paper</a>, follow the example below
+
+```python
+import torch
+from ema_pytorch import PostHocEMA
+
+# your neural network as a pytorch module
+
+net = torch.nn.Linear(512, 512)
+
+# wrap your neural network, specify the sigma_rels or gammas
+
+emas = PostHocEMA(
+    net,
+    sigma_rels = (0.05, 0.3),         # a tuple with the hyperparameter for the multiple EMAs. you need at least 2 here to synthesize a new one
+    update_every = 10,                # how often to actually update, to save on compute (updates every 10th .update() call)
+    checkpoint_every_num_steps = 10,
+    checkpoint_folder = './post-hoc-ema-checkpoints'  # the folder of saved checkpoints for each sigma_rel (gamma) across timesteps with the hparam above, used to synthesizing a new EMA model after training
+)
+
+net.train()
+
+for _ in range(1000):
+    # mutate your network, with SGD or otherwise
+
+    with torch.no_grad():
+        net.weight.copy_(torch.randn_like(net.weight))
+        net.bias.copy_(torch.randn_like(net.bias))
+
+    # you will call the update function on your moving average wrapper
+
+    emas.update()
+
+# now that you have a few checkpoints
+# you can synthesize an EMA model with a different sigma_rel (say 0.15)
+
+synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
+
+# output with synthesized EMA
+
+data = torch.randn(1, 512)
+
+synthesized_ema_output = synthesized_ema(data)
+```
+
+## Citations
+
+```bibtex
+@article{Karras2023AnalyzingAI,
+    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
+    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
+    journal = {ArXiv},
+    year    = {2023},
+    volume  = {abs/2312.02696},
+    url     = {https://api.semanticscholar.org/CorpusID:265659032}
+}
+```
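The README above recommends saving the whole wrapper because it carries the step counter used by the warmup logic. A minimal save/load sketch under that advice, assuming the loading side rebuilds the wrapper with the same constructor arguments (the path is illustrative):

```python
import torch
from ema_pytorch import EMA

net = torch.nn.Linear(512, 512)
ema = EMA(net, beta = 0.9999, update_after_step = 100, update_every = 10)

# ... training loop with ema.update() calls ...

torch.save(ema.state_dict(), './ema.pt')   # includes the step/initted buffers and, by default, the online model

# later: rebuild an identically configured wrapper, then restore everything in one go
ema_restored = EMA(torch.nn.Linear(512, 512), beta = 0.9999, update_after_step = 100, update_every = 10)
ema_restored.load_state_dict(torch.load('./ema.pt'))
```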
{ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch/ema_pytorch.py

@@ -1,3 +1,6 @@
+from __future__ import annotations
+from typing import Set, Tuple
+
 from copy import deepcopy
 from functools import partial
 
@@ -5,26 +8,35 @@ import torch
 from torch import nn, Tensor
 from torch.nn import Module
 
-from beartype import beartype
-from beartype.typing import Set, Optional
-
 def exists(val):
     return val is not None
 
 def get_module_device(m: Module):
     return next(m.parameters()).device
 
-def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
+def maybe_coerce_dtype(t, dtype):
+    if t.dtype == dtype:
+        return t
+
+    return t.to(dtype)
+
+def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False):
     if auto_move_device:
-        src = src.to(tgt.device)
+        src = src.to(tgt.device)
+
+    if coerce_dtype:
+        src = maybe_coerce_dtype(src, tgt.dtype)
 
-    tgt.copy_(src)
+    tgt.copy_(src)
 
-def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
+def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False):
     if auto_move_device:
-        src = src.to(tgt.device)
+        src = src.to(tgt.device)
 
-    tgt.lerp_(src, weight)
+    if coerce_dtype:
+        src = maybe_coerce_dtype(src, tgt.dtype)
+
+    tgt.lerp_(src, weight)
 
 class EMA(Module):
     """
@@ -43,15 +55,14 @@ class EMA(Module):
 
     Args:
         inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-        power (float): Exponential factor of EMA warmup. Default:
+        power (float): Exponential factor of EMA warmup. Default: 2/3.
         min_value (float): The minimum EMA decay rate. Default: 0.
     """
 
-    @beartype
     def __init__(
         self,
        model: Module,
-        ema_model: Optional[Module] = None,
+        ema_model: Module | None = None,  # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
         beta = 0.9999,
         update_after_step = 100,
         update_every = 10,
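The docstring above parameterizes the @crowsonkb-style warmup that `get_current_decay()` applies (that method is unchanged by this diff and therefore not shown). As a hedged sketch of the schedule those three knobs control, written against the usual formulation:

```python
# sketch only: mirrors the common inv_gamma / power / min_value warmup, not copied verbatim from this diff
def current_decay(step, beta = 0.9999, inv_gamma = 1., power = 2 / 3, min_value = 0., update_after_step = 100):
    epoch = max(step - update_after_step - 1, 0)

    if epoch <= 0:
        return 0.                                   # still in the holdoff: the EMA just copies the online weights

    value = 1 - (1 + epoch / inv_gamma) ** -power   # ramps from 0 toward 1 as training progresses
    return min(max(value, min_value), beta)         # clamped between min_value and beta
```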
@@ -62,11 +73,17 @@
         ignore_names: Set[str] = set(),
         ignore_startswith_names: Set[str] = set(),
         include_online_model = True,  # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
-        allow_different_devices = False
+        allow_different_devices = False,  # if the EMA model is on a different device (say CPU), automatically move the tensor
+        use_foreach = False,
+        forward_method_names: Tuple[str, ...] = (),
+        move_ema_to_online_device = False,
+        coerce_dtype = False
     ):
         super().__init__()
         self.beta = beta
 
+        self.is_frozen = beta == 1.
+
         # whether to include the online model within the module tree, so that state_dict also saves it
 
         self.include_online_model = include_online_model
@@ -88,17 +105,24 @@
             print('Your model was not copyable. Please make sure you are not using any LazyLinear')
             exit()
 
-        self.ema_model.requires_grad_(False)
+        for p in self.ema_model.parameters():
+            p.detach_()
+
+        # forwarding methods
+
+        for forward_method_name in forward_method_names:
+            fn = getattr(self.ema_model, forward_method_name)
+            setattr(self, forward_method_name, fn)
 
         # parameter and buffer names
 
-        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype in [torch.float, torch.float16]}
-        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype in [torch.float, torch.float16]}
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
 
         # tensor update functions
 
-        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
-        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
+        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
+        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
 
         # updating hyperparameters
 
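A small illustration of the new parameter/buffer filtering above: only floating point or complex tensors take part in the EMA, so integer buffers (counters, index tables) are skipped rather than lerped:

```python
import torch
from ema_pytorch import EMA

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.register_buffer('num_calls', torch.tensor(0))   # int64 buffer, excluded by the filter

ema = EMA(Net())

assert 'linear.weight' in ema.parameter_names
assert 'num_calls' not in ema.buffer_names
```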
@@ -119,6 +143,21 @@
 
         self.allow_different_devices = allow_different_devices
 
+        # whether to coerce dtype when copy or lerp from online to EMA model
+
+        self.coerce_dtype = coerce_dtype
+
+        # whether to move EMA model to online model device automatically
+
+        self.move_ema_to_online_device = move_ema_to_online_device
+
+        # whether to use foreach
+
+        if use_foreach:
+            assert hasattr(torch, '_foreach_lerp_') and hasattr(torch, '_foreach_copy_'), 'your version of torch does not have the prerequisite foreach functions'
+
+        self.use_foreach = use_foreach
+
         # init and step states
 
         self.register_buffer('initted', torch.tensor(False))
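For context on the guard above: the fast path added later in this diff leans on torch's private `_foreach_*` kernels, which batch an elementwise operation across a list of tensors in one call; the `hasattr` assert exists precisely because older torch builds lack them. A tiny hedged illustration of what they do:

```python
import torch

tgts = [torch.zeros(3), torch.zeros(3)]
srcs = [torch.ones(3), torch.full((3,), 2.)]

torch._foreach_lerp_(tgts, srcs, 0.1)   # in place, per pair: tgt becomes tgt + 0.1 * (src - tgt)
torch._foreach_copy_(tgts, srcs)        # in place, per pair: tgt becomes src
```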
@@ -193,9 +232,25 @@
 
     @torch.no_grad()
     def update_moving_average(self, ma_model, current_model):
-        copy, lerp = self.inplace_copy, self.inplace_lerp
+        if self.is_frozen:
+            return
+
+        # move ema model to online model device if not same and needed
+
+        if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
+            ma_model.to(get_module_device(current_model))
+
+        # get current decay
+
         current_decay = self.get_current_decay()
 
+        # store all source and target tensors to copy or lerp
+
+        tensors_to_copy = []
+        tensors_to_lerp = []
+
+        # loop through parameters
+
         for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
             if name in self.ignore_names:
                 continue
@@ -204,10 +259,12 @@
                 continue
 
             if name in self.param_or_buffer_names_no_ema:
-                copy(ma_params.data, current_params.data)
+                tensors_to_copy.append((ma_params.data, current_params.data))
                 continue
 
-            lerp(ma_params.data, current_params.data, 1. - current_decay)
+            tensors_to_lerp.append((ma_params.data, current_params.data))
+
+        # loop through buffers
 
         for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
             if name in self.ignore_names:
@@ -217,10 +274,39 @@
                 continue
 
             if name in self.param_or_buffer_names_no_ema:
-                copy(ma_buffer.data, current_buffer.data)
+                tensors_to_copy.append((ma_buffer.data, current_buffer.data))
                 continue
 
-            lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
+            tensors_to_lerp.append((ma_buffer.data, current_buffer.data))
+
+        # execute inplace copy or lerp
+
+        if not self.use_foreach:
+
+            for tgt, src in tensors_to_copy:
+                self.inplace_copy(tgt, src)
+
+            for tgt, src in tensors_to_lerp:
+                self.inplace_lerp(tgt, src, 1. - current_decay)
+
+        else:
+            # use foreach if available and specified
+
+            if self.allow_different_devices:
+                tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
+                tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]
+
+            if self.coerce_dtype:
+                tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy]
+                tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp]
+
+            if len(tensors_to_copy) > 0:
+                tgt_copy, src_copy = zip(*tensors_to_copy)
+                torch._foreach_copy_(tgt_copy, src_copy)
+
+            if len(tensors_to_lerp) > 0:
+                tgt_lerp, src_lerp = zip(*tensors_to_lerp)
+                torch._foreach_lerp_(tgt_lerp, src_lerp, 1. - current_decay)
 
     def __call__(self, *args, **kwargs):
         return self.ema_model(*args, **kwargs)
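Putting the new constructor options together, a hedged usage sketch (the method name passed to `forward_method_names` is hypothetical; any callable attribute of your module could be forwarded this way):

```python
import torch
from ema_pytorch import EMA

net = torch.nn.Linear(512, 512)

ema = EMA(
    net,
    beta = 0.9999,
    update_every = 10,
    allow_different_devices = True,      # EMA copy may live on another device; tensors are moved during the update
    coerce_dtype = True,                 # cast online tensors to the EMA dtype when copying / lerping
    move_ema_to_online_device = False,   # alternatively, move the whole EMA module onto the online device first
    use_foreach = True,                  # batch the update through torch._foreach_copy_ / torch._foreach_lerp_
    forward_method_names = ()            # e.g. ('sample',) would also expose ema.sample(...) (hypothetical method name)
)

# note: beta == 1. marks the EMA as frozen, so update_moving_average() returns immediately
```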
ema_pytorch-0.6.3/ema_pytorch/post_hoc_ema.py (new file)

@@ -0,0 +1,420 @@
+from __future__ import annotations
+
+from pathlib import Path
+from copy import deepcopy
+from functools import partial
+
+import torch
+from torch import nn, Tensor
+from torch.nn import Module, ModuleList
+
+import numpy as np
+
+from typing import Set, Tuple
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def first(arr):
+    return arr[0]
+
+def get_module_device(m: Module):
+    return next(m.parameters()).device
+
+def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
+    if auto_move_device:
+        src = src.to(tgt.device)
+
+    tgt.copy_(src)
+
+def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
+    if auto_move_device:
+        src = src.to(tgt.device)
+
+    tgt.lerp_(src, weight)
+
+# algorithm 2 in https://arxiv.org/abs/2312.02696
+
+def sigma_rel_to_gamma(sigma_rel):
+    t = sigma_rel ** -2
+    return np.roots([1, 7, 16 - t, 12 - t]).real.max().item()
+
+class KarrasEMA(Module):
+    """
+    exponential moving average module that uses hyperparameters from the paper https://arxiv.org/abs/2312.02696
+    can either use gamma or sigma_rel from paper
+    """
+
+    def __init__(
+        self,
+        model: Module,
+        sigma_rel: float | None = None,
+        gamma: float | None = None,
+        ema_model: Module | None = None,  # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+        update_every: int = 100,
+        frozen: bool = False,
+        param_or_buffer_names_no_ema: Set[str] = set(),
+        ignore_names: Set[str] = set(),
+        ignore_startswith_names: Set[str] = set(),
+        allow_different_devices = False,  # if the EMA model is on a different device (say CPU), automatically move the tensor
+        move_ema_to_online_device = False  # will move entire EMA model to the same device as online model, if different
+    ):
+        super().__init__()
+
+        assert exists(sigma_rel) ^ exists(gamma), 'either sigma_rel or gamma is given. gamma is derived from sigma_rel as in the paper, then beta is dervied from gamma'
+
+        if exists(sigma_rel):
+            gamma = sigma_rel_to_gamma(sigma_rel)
+
+        self.gamma = gamma
+        self.frozen = frozen
+
+        self.online_model = [model]
+
+        # ema model
+
+        self.ema_model = ema_model
+
+        if not exists(self.ema_model):
+            try:
+                self.ema_model = deepcopy(model)
+            except Exception as e:
+                print(f'Error: While trying to deepcopy model: {e}')
+                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
+                exit()
+
+        for p in self.ema_model.parameters():
+            p.detach_()
+
+        # parameter and buffer names
+
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
+
+        # tensor update functions
+
+        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
+        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
+
+        # updating hyperparameters
+
+        self.update_every = update_every
+
+        assert isinstance(param_or_buffer_names_no_ema, (set, list))
+        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
+
+        self.ignore_names = ignore_names
+        self.ignore_startswith_names = ignore_startswith_names
+
+        # whether to manage if EMA model is kept on a different device
+
+        self.allow_different_devices = allow_different_devices
+
+        # whether to move EMA model to online model device automatically
+
+        self.move_ema_to_online_device = move_ema_to_online_device
+
+        # init and step states
+
+        self.register_buffer('initted', torch.tensor(False))
+        self.register_buffer('step', torch.tensor(0))
+
+    @property
+    def model(self):
+        return first(self.online_model)
+
+    @property
+    def beta(self):
+        return (1. - 1. / (self.step.item() + 1.)) ** (1. + self.gamma)
+
+    def eval(self):
+        return self.ema_model.eval()
+
+    def restore_ema_model_device(self):
+        device = self.initted.device
+        self.ema_model.to(device)
+
+    def get_params_iter(self, model):
+        for name, param in model.named_parameters():
+            if name not in self.parameter_names:
+                continue
+            yield name, param
+
+    def get_buffers_iter(self, model):
+        for name, buffer in model.named_buffers():
+            if name not in self.buffer_names:
+                continue
+            yield name, buffer
+
+    def copy_params_from_model_to_ema(self):
+        copy = self.inplace_copy
+
+        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
+            copy(ma_params.data, current_params.data)
+
+        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
+            copy(ma_buffers.data, current_buffers.data)
+
+    def copy_params_from_ema_to_model(self):
+        copy = self.inplace_copy
+
+        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
+            copy(current_params.data, ma_params.data)
+
+        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
+            copy(current_buffers.data, ma_buffers.data)
+
+    def update(self):
+        step = self.step.item()
+        self.step += 1
+
+        if (step % self.update_every) != 0:
+            return
+
+        if not self.initted.item():
+            self.copy_params_from_model_to_ema()
+            self.initted.data.copy_(torch.tensor(True))
+
+        self.update_moving_average(self.ema_model, self.model)
+
+    def iter_all_ema_params_and_buffers(self):
+        for name, ma_params in self.get_params_iter(self.ema_model):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                continue
+
+            yield ma_params
+
+        for name, ma_buffer in self.get_buffers_iter(self.ema_model):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                continue
+
+            yield ma_buffer
+
+    @torch.no_grad()
+    def update_moving_average(self, ma_model, current_model):
+        if self.frozen:
+            return
+
+        # move ema model to online model device if not same and needed
+
+        if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
+            ma_model.to(get_module_device(current_model))
+
+        # get some functions and current decay
+
+        copy, lerp = self.inplace_copy, self.inplace_lerp
+        current_decay = self.beta
+
+        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                copy(ma_params.data, current_params.data)
+                continue
+
+            lerp(ma_params.data, current_params.data, 1. - current_decay)
+
+        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                copy(ma_buffer.data, current_buffer.data)
+                continue
+
+            lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
+
+    def __call__(self, *args, **kwargs):
+        return self.ema_model(*args, **kwargs)
+
+# post hoc ema wrapper
+
+# solving of the weights for combining all checkpoints into a newly synthesized EMA at desired gamma
+# Algorithm 3 copied from paper, redone in torch
+
+def p_dot_p(t_a, gamma_a, t_b, gamma_b):
+    t_ratio = t_a / t_b
+    t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
+    t_max = torch.maximum(t_a , t_b)
+    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
+    den = (gamma_a + gamma_b + 1) * t_max
+    return num / den
+
+def solve_weights(t_i, gamma_i, t_r, gamma_r):
+    rv = lambda x: x.double().reshape(-1, 1)
+    cv = lambda x: x.double().reshape(1, -1)
+    A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
+    b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
+    return torch.linalg.solve(A, b)
+
+class PostHocEMA(Module):
+
+    def __init__(
+        self,
+        model: Module,
+        sigma_rels: Tuple[float, ...] | None = None,
+        gammas: Tuple[float, ...] | None = None,
+        checkpoint_every_num_steps: int = 1000,
+        checkpoint_folder: str = './post-hoc-ema-checkpoints',
+        checkpoint_dtype: torch.dtype = torch.float16,
+        **kwargs
+    ):
+        super().__init__()
+        assert exists(sigma_rels) ^ exists(gammas)
+
+        if exists(sigma_rels):
+            gammas = tuple(map(sigma_rel_to_gamma, sigma_rels))
+
+        assert len(gammas) > 1, 'at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma'
+        assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique'
+
+        self.gammas = gammas
+        self.num_ema_models = len(gammas)
+
+        self._model = [model]
+        self.ema_models = ModuleList([KarrasEMA(model, gamma = gamma, **kwargs) for gamma in gammas])
+
+        self.checkpoint_folder = Path(checkpoint_folder)
+        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
+        assert self.checkpoint_folder.is_dir()
+
+        self.checkpoint_every_num_steps = checkpoint_every_num_steps
+        self.checkpoint_dtype = checkpoint_dtype
+        self.ema_kwargs = kwargs
+
+    @property
+    def model(self):
+        return first(self._model)
+
+    @property
+    def step(self):
+        return first(self.ema_models).step
+
+    @property
+    def device(self):
+        return self.step.device
+
+    def copy_params_from_model_to_ema(self):
+        for ema_model in self.ema_models:
+            ema_model.copy_params_from_model_to_ema()
+
+    def copy_params_from_ema_to_model(self):
+        for ema_model in self.ema_models:
+            ema_model.copy_params_from_ema_to_model()
+
+    def update(self):
+        for ema_model in self.ema_models:
+            ema_model.update()
+
+        if not (self.step.item() % self.checkpoint_every_num_steps):
+            self.checkpoint()
+
+    def checkpoint(self):
+        step = self.step.item()
+
+        for ind, ema_model in enumerate(self.ema_models):
+            filename = f'{ind}.{step}.pt'
+            path = self.checkpoint_folder / filename
+
+            pkg = deepcopy(ema_model).to(self.checkpoint_dtype).state_dict()
+            torch.save(pkg, str(path))
+
+    def synthesize_ema_model(
+        self,
+        gamma: float | None = None,
+        sigma_rel: float | None = None,
+        step: int | None = None,
+    ) -> KarrasEMA:
+        assert exists(gamma) ^ exists(sigma_rel)
+        device = self.device
+
+        if exists(sigma_rel):
+            gamma = sigma_rel_to_gamma(sigma_rel)
+
+        synthesized_ema_model = KarrasEMA(
+            model = self.model,
+            gamma = gamma,
+            **self.ema_kwargs
+        )
+
+        synthesized_ema_model
+
+        # get all checkpoints
+
+        gammas = []
+        timesteps = []
+        checkpoints = [*self.checkpoint_folder.glob('*.pt')]
+
+        for file in checkpoints:
+            gamma_ind, timestep = map(int, file.stem.split('.'))
+            gammas.append(self.gammas[gamma_ind])
+            timesteps.append(timestep)
+
+        step = default(step, max(timesteps))
+        assert step <= max(timesteps), f'you can only synthesize for a timestep that is less than the max timestep {max(timesteps)}'
+
+        # line up with Algorithm 3
+
+        gamma_i = torch.tensor(gammas, device = device)
+        t_i = torch.tensor(timesteps, device = device)
+
+        gamma_r = torch.tensor([gamma], device = device)
+        t_r = torch.tensor([step], device = device)
+
+        # solve for weights for combining all checkpoints into synthesized, using least squares as in paper
+
+        weights = solve_weights(t_i, gamma_i, t_r, gamma_r)
+        weights = weights.squeeze(-1)
+
+        # now sum up all the checkpoints using the weights one by one
+
+        tmp_ema_model = KarrasEMA(
+            model = self.model,
+            gamma = gamma,
+            **self.ema_kwargs
+        )
+
+        for ind, (checkpoint, weight) in enumerate(zip(checkpoints, weights.tolist())):
+            is_first = ind == 0
+
+            # load checkpoint into a temporary ema model
+
+            ckpt_state_dict = torch.load(str(checkpoint), weights_only=True)
+            tmp_ema_model.load_state_dict(ckpt_state_dict)
+
+            # add weighted checkpoint to synthesized
+
+            for ckpt_tensor, synth_tensor in zip(tmp_ema_model.iter_all_ema_params_and_buffers(), synthesized_ema_model.iter_all_ema_params_and_buffers()):
+                if is_first:
+                    synth_tensor.zero_()
+
+                synth_tensor.add_(ckpt_tensor * weight)
+
+        # return the synthesized model
+
+        return synthesized_ema_model
+
+    def __call__(self, *args, **kwargs):
+        return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
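For reference, the two numeric helpers in the new file transcribe directly into the paper's notation (Karras et al. 2023, Algorithms 2 and 3); writing them out makes the code above easier to check:

```latex
% Algorithm 2: gamma from sigma_rel. With t = \sigma_{rel}^{-2}, gamma is the largest real root of
\gamma^3 + 7\gamma^2 + (16 - t)\,\gamma + (12 - t) = 0

% Algorithm 3: inner product of two EMA profiles, as computed by p_dot_p
\langle p_a, p_b \rangle
  = \frac{(\gamma_a + 1)(\gamma_b + 1)\,(t_a / t_b)^{e}}{(\gamma_a + \gamma_b + 1)\,\max(t_a, t_b)},
\qquad
e = \begin{cases} \gamma_b & \text{if } t_a < t_b \\ -\gamma_a & \text{otherwise} \end{cases}
```

`solve_weights` then solves the linear system built from these inner products, which gives the least-squares combination of the stored checkpoints at the requested gamma and timestep.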
{ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ema-pytorch
-Version: 0.3.1
+Version: 0.6.3
 Summary: Easy way to keep track of exponential moving average version of your pytorch module
 Home-page: https://github.com/lucidrains/ema-pytorch
 Author: Phil Wang
@@ -14,5 +14,4 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: beartype
-Requires-Dist: torch>=1.6
+Requires-Dist: torch>=2.0
ema_pytorch-0.6.3/ema_pytorch.egg-info/requires.txt (new file)

@@ -0,0 +1 @@
+torch>=2.0
{ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.1',
+  version = '0.6.3',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
@@ -16,8 +16,7 @@ setup(
     'exponential moving average'
   ],
   install_requires=[
-    'beartype',
-    'torch>=1.6',
+    'torch>=2.0',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
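Given the raised floor in `install_requires`, a hedged upgrade path (pip resolves the exact versions; beartype is simply no longer pulled in):

```bash
$ pip install -U ema-pytorch 'torch>=2.0'
```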
ema-pytorch-0.3.1/README.md (DELETED)

@@ -1,54 +0,0 @@
-## EMA - Pytorch
-
-A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
-
-## Install
-
-```bash
-$ pip install ema-pytorch
-```
-
-## Usage
-
-```python
-import torch
-from ema_pytorch import EMA
-
-# your neural network as a pytorch module
-
-net = torch.nn.Linear(512, 512)
-
-# wrap your neural network, specify the decay (beta)
-
-ema = EMA(
-    net,
-    beta = 0.9999,              # exponential moving average factor
-    update_after_step = 100,    # only after this number of .update() calls will it start updating
-    update_every = 10,          # how often to actually update, to save on compute (updates every 10th .update() call)
-)
-
-# mutate your network, with SGD or otherwise
-
-with torch.no_grad():
-    net.weight.copy_(torch.randn_like(net.weight))
-    net.bias.copy_(torch.randn_like(net.bias))
-
-# you will call the update function on your moving average wrapper
-
-ema.update()
-
-# then, later on, you can invoke the EMA model the same way as your network
-
-data = torch.randn(1, 512)
-
-output = net(data)
-ema_output = ema(data)
-
-# if you want to save your ema model, it is recommended you save the entire wrapper
-# as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now)
-# however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
-```
-
-## Todo
-
-- [ ] address the issue of annealing EMA to 1 near the end of training for BYOL https://github.com/lucidrains/byol-pytorch/issues/82
ema-pytorch-0.3.1/ema_pytorch/__init__.py (DELETED)

@@ -1 +0,0 @@
-from ema_pytorch.ema_pytorch import EMA
Files without changes:

- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/LICENSE
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch.egg-info/dependency_links.txt
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/ema_pytorch.egg-info/top_level.txt
- {ema-pytorch-0.3.1 → ema_pytorch-0.6.3}/setup.cfg