ema-pytorch 0.3.1__tar.gz → 0.6.3__tar.gz

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ema-pytorch
- Version: 0.3.1
+ Version: 0.6.3
  Summary: Easy way to keep track of exponential moving average version of your pytorch module
  Home-page: https://github.com/lucidrains/ema-pytorch
  Author: Phil Wang
@@ -14,5 +14,4 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.6
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: beartype
- Requires-Dist: torch>=1.6
+ Requires-Dist: torch>=2.0
@@ -0,0 +1,109 @@
+ ## EMA - Pytorch
+
+ A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
+
+ ## Install
+
+ ```bash
+ $ pip install ema-pytorch
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from ema_pytorch import EMA
+
+ # your neural network as a pytorch module
+
+ net = torch.nn.Linear(512, 512)
+
+ # wrap your neural network, specify the decay (beta)
+
+ ema = EMA(
+     net,
+     beta = 0.9999, # exponential moving average factor
+     update_after_step = 100, # only after this number of .update() calls will it start updating
+     update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
+ )
+
+ # mutate your network, with SGD or otherwise
+
+ with torch.no_grad():
+     net.weight.copy_(torch.randn_like(net.weight))
+     net.bias.copy_(torch.randn_like(net.bias))
+
+ # you will call the update function on your moving average wrapper
+
+ ema.update()
+
+ # then, later on, you can invoke the EMA model the same way as your network
+
+ data = torch.randn(1, 512)
+
+ output = net(data)
+ ema_output = ema(data)
+
+ # if you want to save your ema model, it is recommended you save the entire wrapper
+ # as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now)
+ # however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
+ ```
+
+ In order to use the post-hoc synthesized EMA, proposed by Karras et al. in <a href="https://arxiv.org/abs/2312.02696">a recent paper</a>, follow the example below
+
+ ```python
+ import torch
+ from ema_pytorch import PostHocEMA
+
+ # your neural network as a pytorch module
+
+ net = torch.nn.Linear(512, 512)
+
+ # wrap your neural network, specify the sigma_rels or gammas
+
+ emas = PostHocEMA(
+     net,
+     sigma_rels = (0.05, 0.3), # a tuple with the hyperparameter for the multiple EMAs. you need at least 2 here to synthesize a new one
+     update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
+     checkpoint_every_num_steps = 10,
+     checkpoint_folder = './post-hoc-ema-checkpoints' # the folder of saved checkpoints for each sigma_rel (gamma) across timesteps with the hparam above, used to synthesize a new EMA model after training
+ )
+
+ net.train()
+
+ for _ in range(1000):
+     # mutate your network, with SGD or otherwise
+
+     with torch.no_grad():
+         net.weight.copy_(torch.randn_like(net.weight))
+         net.bias.copy_(torch.randn_like(net.bias))
+
+     # you will call the update function on your moving average wrapper
+
+     emas.update()
+
+ # now that you have a few checkpoints
+ # you can synthesize an EMA model with a different sigma_rel (say 0.15)
+
+ synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
+
+ # output with synthesized EMA
+
+ data = torch.randn(1, 512)
+
+ synthesized_ema_output = synthesized_ema(data)
+
+ ```
+
+ ## Citations
+
+ ```bibtex
+ @article{Karras2023AnalyzingAI,
+     title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
+     author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
+     journal = {ArXiv},
+     year = {2023},
+     volume = {abs/2312.02696},
+     url = {https://api.semanticscholar.org/CorpusID:265659032}
+ }
+ ```
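The README above stresses saving the whole wrapper rather than just `ema.ema_model`, since the step counter drives the warmup. A minimal sketch of what that can look like in practice (file path and hyperparameters are illustrative, not taken from the package):

```python
import torch
from ema_pytorch import EMA

net = torch.nn.Linear(512, 512)
ema = EMA(net, beta = 0.9999, update_after_step = 100, update_every = 10)

# ... training loop calling ema.update() ...

# save the entire wrapper: it carries the EMA weights plus the step/initted buffers used by the warmup
torch.save(ema.state_dict(), './ema.pt')

# later: rebuild an identically configured wrapper and restore it
net = torch.nn.Linear(512, 512)
ema = EMA(net, beta = 0.9999, update_after_step = 100, update_every = 10)
ema.load_state_dict(torch.load('./ema.pt'))

ema_output = ema(torch.randn(1, 512))   # the EMA copy itself lives at ema.ema_model
```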
@@ -0,0 +1,6 @@
+ from ema_pytorch.ema_pytorch import EMA
+
+ from ema_pytorch.post_hoc_ema import (
+     KarrasEMA,
+     PostHocEMA
+ )
@@ -1,3 +1,6 @@
+ from __future__ import annotations
+ from typing import Set, Tuple
+
  from copy import deepcopy
  from functools import partial
 
@@ -5,26 +8,35 @@ import torch
  from torch import nn, Tensor
  from torch.nn import Module
 
- from beartype import beartype
- from beartype.typing import Set, Optional
-
  def exists(val):
      return val is not None
 
  def get_module_device(m: Module):
      return next(m.parameters()).device
 
- def inplace_copy(src: Tensor, tgt: Tensor, *, auto_move_device = False):
+ def maybe_coerce_dtype(t, dtype):
+     if t.dtype == dtype:
+         return t
+
+     return t.to(dtype)
+
+ def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False):
      if auto_move_device:
-         tgt = tgt.to(src.device)
+         src = src.to(tgt.device)
+
+     if coerce_dtype:
+         src = maybe_coerce_dtype(src, tgt.dtype)
 
-     src.copy_(tgt)
+     tgt.copy_(src)
 
- def inplace_lerp(src: Tensor, tgt: Tensor, weight, *, auto_move_device = False):
+ def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False):
      if auto_move_device:
-         tgt = tgt.to(src.device)
+         src = src.to(tgt.device)
 
-     src.lerp_(tgt, weight)
+     if coerce_dtype:
+         src = maybe_coerce_dtype(src, tgt.dtype)
+
+     tgt.lerp_(src, weight)
 
  class EMA(Module):
      """
@@ -43,15 +55,14 @@ class EMA(Module):
  
      Args:
          inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-         power (float): Exponential factor of EMA warmup. Default: 1.
+         power (float): Exponential factor of EMA warmup. Default: 2/3.
          min_value (float): The minimum EMA decay rate. Default: 0.
      """
 
-     @beartype
      def __init__(
          self,
          model: Module,
-         ema_model: Optional[Module] = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+         ema_model: Module | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
          beta = 0.9999,
          update_after_step = 100,
          update_every = 10,
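The docstring only names the warmup hyperparameters; the schedule itself lives in `get_current_decay`, which this diff does not show. To the best of my understanding it has roughly the shape sketched below, ramping the decay from 0 toward `beta` — treat the exact expression as an assumption:

```python
# assumed shape of the @crowsonkb-style warmup: decay = 1 - (1 + step / inv_gamma) ** -power, clamped
def warmup_decay(step, beta = 0.9999, inv_gamma = 1., power = 2 / 3, min_value = 0.):
    value = 1. - (1. + step / inv_gamma) ** -power
    return min(max(value, min_value), beta)

for step in (1, 10, 100, 1000, 10000):
    print(step, round(warmup_decay(step), 4))
# roughly 0.37 at step 1, 0.95 at step 100, creeping toward beta = 0.9999
```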
@@ -62,11 +73,17 @@ class EMA(Module):
          ignore_names: Set[str] = set(),
          ignore_startswith_names: Set[str] = set(),
          include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
-         allow_different_devices = False # if the EMA model is on a different device (say CPU), automatically move the tensor
+         allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
+         use_foreach = False,
+         forward_method_names: Tuple[str, ...] = (),
+         move_ema_to_online_device = False,
+         coerce_dtype = False
      ):
          super().__init__()
          self.beta = beta
 
+         self.is_frozen = beta == 1.
+
          # whether to include the online model within the module tree, so that state_dict also saves it
 
          self.include_online_model = include_online_model
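Two of the new keyword arguments are easy to miss: `beta = 1.` freezes the EMA copy entirely (via `is_frozen`), and `forward_method_names` re-exposes chosen methods of the EMA model on the wrapper. A hedged usage sketch; the `sample` method here is a made-up stand-in for whatever your module defines:

```python
import torch
from torch import nn
from ema_pytorch import EMA

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)

    def forward(self, x):
        return self.proj(x)

    def sample(self, n):                          # illustrative extra inference method
        return self.forward(torch.randn(n, 512))

net = Net()

ema = EMA(
    net,
    beta = 0.9999,
    update_every = 10,
    forward_method_names = ('sample',)            # ema.sample(...) now calls ema.ema_model.sample(...)
)

ema.update()
samples = ema.sample(4)
```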
@@ -88,17 +105,24 @@ class EMA(Module):
              print('Your model was not copyable. Please make sure you are not using any LazyLinear')
              exit()
 
-         self.ema_model.requires_grad_(False)
+         for p in self.ema_model.parameters():
+             p.detach_()
+
+         # forwarding methods
+
+         for forward_method_name in forward_method_names:
+             fn = getattr(self.ema_model, forward_method_name)
+             setattr(self, forward_method_name, fn)
 
          # parameter and buffer names
 
-         self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype in [torch.float, torch.float16]}
-         self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype in [torch.float, torch.float16]}
+         self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+         self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
 
          # tensor update functions
 
-         self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
-         self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
+         self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
+         self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
 
          # updating hyperparameters
 
@@ -119,6 +143,21 @@ class EMA(Module):
  
          self.allow_different_devices = allow_different_devices
 
+         # whether to coerce dtype when copy or lerp from online to EMA model
+
+         self.coerce_dtype = coerce_dtype
+
+         # whether to move EMA model to online model device automatically
+
+         self.move_ema_to_online_device = move_ema_to_online_device
+
+         # whether to use foreach
+
+         if use_foreach:
+             assert hasattr(torch, '_foreach_lerp_') and hasattr(torch, '_foreach_copy_'), 'your version of torch does not have the prerequisite foreach functions'
+
+         self.use_foreach = use_foreach
+
          # init and step states
          self.register_buffer('initted', torch.tensor(False))
          self.register_buffer('step', torch.tensor(0))
@@ -193,9 +232,25 @@ class EMA(Module):
  
      @torch.no_grad()
      def update_moving_average(self, ma_model, current_model):
-         copy, lerp = self.inplace_copy, self.inplace_lerp
+         if self.is_frozen:
+             return
+
+         # move ema model to online model device if not same and needed
+
+         if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
+             ma_model.to(get_module_device(current_model))
+
+         # get current decay
+
          current_decay = self.get_current_decay()
 
+         # store all source and target tensors to copy or lerp
+
+         tensors_to_copy = []
+         tensors_to_lerp = []
+
+         # loop through parameters
+
          for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
              if name in self.ignore_names:
                  continue
@@ -204,10 +259,12 @@ class EMA(Module):
                  continue
 
              if name in self.param_or_buffer_names_no_ema:
-                 copy(ma_params.data, current_params.data)
+                 tensors_to_copy.append((ma_params.data, current_params.data))
                  continue
 
-             lerp(ma_params.data, current_params.data, 1. - current_decay)
+             tensors_to_lerp.append((ma_params.data, current_params.data))
+
+         # loop through buffers
 
          for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
              if name in self.ignore_names:
@@ -217,10 +274,39 @@ class EMA(Module):
                  continue
 
              if name in self.param_or_buffer_names_no_ema:
-                 copy(ma_buffer.data, current_buffer.data)
+                 tensors_to_copy.append((ma_buffer.data, current_buffer.data))
                  continue
 
-             lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
+             tensors_to_lerp.append((ma_buffer.data, current_buffer.data))
+
+         # execute inplace copy or lerp
+
+         if not self.use_foreach:
+
+             for tgt, src in tensors_to_copy:
+                 self.inplace_copy(tgt, src)
+
+             for tgt, src in tensors_to_lerp:
+                 self.inplace_lerp(tgt, src, 1. - current_decay)
+
+         else:
+             # use foreach if available and specified
+
+             if self.allow_different_devices:
+                 tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
+                 tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]
+
+             if self.coerce_dtype:
+                 tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy]
+                 tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp]
+
+             if len(tensors_to_copy) > 0:
+                 tgt_copy, src_copy = zip(*tensors_to_copy)
+                 torch._foreach_copy_(tgt_copy, src_copy)
+
+             if len(tensors_to_lerp) > 0:
+                 tgt_lerp, src_lerp = zip(*tensors_to_lerp)
+                 torch._foreach_lerp_(tgt_lerp, src_lerp, 1. - current_decay)
 
      def __call__(self, *args, **kwargs):
          return self.ema_model(*args, **kwargs)
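The foreach branch assumes a single `torch._foreach_lerp_` call over lists of tensors matches a Python loop of per-tensor `lerp_` calls, just with fewer kernel launches. A small self-contained equivalence check, independent of the package:

```python
import torch

decay = 0.9999

ema_tensors    = [torch.zeros(3) for _ in range(4)]
online_tensors = [torch.randn(3) for _ in range(4)]

# reference: one in-place lerp per tensor
reference = [t.clone() for t in ema_tensors]
for tgt, src in zip(reference, online_tensors):
    tgt.lerp_(src, 1. - decay)

# batched: a single foreach call over all tensors at once
torch._foreach_lerp_(ema_tensors, online_tensors, 1. - decay)

assert all(torch.allclose(a, b) for a, b in zip(ema_tensors, reference))
```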
@@ -0,0 +1,420 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from copy import deepcopy
+ from functools import partial
+
+ import torch
+ from torch import nn, Tensor
+ from torch.nn import Module, ModuleList
+
+ import numpy as np
+
+ from typing import Set, Tuple
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ def first(arr):
+     return arr[0]
+
+ def get_module_device(m: Module):
+     return next(m.parameters()).device
+
+ def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
+     if auto_move_device:
+         src = src.to(tgt.device)
+
+     tgt.copy_(src)
+
+ def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
+     if auto_move_device:
+         src = src.to(tgt.device)
+
+     tgt.lerp_(src, weight)
+
+ # algorithm 2 in https://arxiv.org/abs/2312.02696
+
+ def sigma_rel_to_gamma(sigma_rel):
+     t = sigma_rel ** -2
+     return np.roots([1, 7, 16 - t, 12 - t]).real.max().item()
+
+ class KarrasEMA(Module):
+     """
+     exponential moving average module that uses hyperparameters from the paper https://arxiv.org/abs/2312.02696
+     can either use gamma or sigma_rel from paper
+     """
+
+     def __init__(
+         self,
+         model: Module,
+         sigma_rel: float | None = None,
+         gamma: float | None = None,
+         ema_model: Module | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+         update_every: int = 100,
+         frozen: bool = False,
+         param_or_buffer_names_no_ema: Set[str] = set(),
+         ignore_names: Set[str] = set(),
+         ignore_startswith_names: Set[str] = set(),
+         allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
+         move_ema_to_online_device = False # will move entire EMA model to the same device as online model, if different
+     ):
+         super().__init__()
+
+         assert exists(sigma_rel) ^ exists(gamma), 'either sigma_rel or gamma is given. gamma is derived from sigma_rel as in the paper, then beta is derived from gamma'
+
+         if exists(sigma_rel):
+             gamma = sigma_rel_to_gamma(sigma_rel)
+
+         self.gamma = gamma
+         self.frozen = frozen
+
+         self.online_model = [model]
+
+         # ema model
+
+         self.ema_model = ema_model
+
+         if not exists(self.ema_model):
+             try:
+                 self.ema_model = deepcopy(model)
+             except Exception as e:
+                 print(f'Error: While trying to deepcopy model: {e}')
+                 print('Your model was not copyable. Please make sure you are not using any LazyLinear')
+                 exit()
+
+         for p in self.ema_model.parameters():
+             p.detach_()
+
+         # parameter and buffer names
+
+         self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+         self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
+
+         # tensor update functions
+
+         self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
+         self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
+
+         # updating hyperparameters
+
+         self.update_every = update_every
+
+         assert isinstance(param_or_buffer_names_no_ema, (set, list))
+         self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
+
+         self.ignore_names = ignore_names
+         self.ignore_startswith_names = ignore_startswith_names
+
+         # whether to manage if EMA model is kept on a different device
+
+         self.allow_different_devices = allow_different_devices
+
+         # whether to move EMA model to online model device automatically
+
+         self.move_ema_to_online_device = move_ema_to_online_device
+
+         # init and step states
+
+         self.register_buffer('initted', torch.tensor(False))
+         self.register_buffer('step', torch.tensor(0))
+
+     @property
+     def model(self):
+         return first(self.online_model)
+
+     @property
+     def beta(self):
+         return (1. - 1. / (self.step.item() + 1.)) ** (1. + self.gamma)
+
+     def eval(self):
+         return self.ema_model.eval()
+
+     def restore_ema_model_device(self):
+         device = self.initted.device
+         self.ema_model.to(device)
+
+     def get_params_iter(self, model):
+         for name, param in model.named_parameters():
+             if name not in self.parameter_names:
+                 continue
+             yield name, param
+
+     def get_buffers_iter(self, model):
+         for name, buffer in model.named_buffers():
+             if name not in self.buffer_names:
+                 continue
+             yield name, buffer
+
+     def copy_params_from_model_to_ema(self):
+         copy = self.inplace_copy
+
+         for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
+             copy(ma_params.data, current_params.data)
+
+         for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
+             copy(ma_buffers.data, current_buffers.data)
+
+     def copy_params_from_ema_to_model(self):
+         copy = self.inplace_copy
+
+         for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
+             copy(current_params.data, ma_params.data)
+
+         for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
+             copy(current_buffers.data, ma_buffers.data)
+
+     def update(self):
+         step = self.step.item()
+         self.step += 1
+
+         if (step % self.update_every) != 0:
+             return
+
+         if not self.initted.item():
+             self.copy_params_from_model_to_ema()
+             self.initted.data.copy_(torch.tensor(True))
+
+         self.update_moving_average(self.ema_model, self.model)
+
+     def iter_all_ema_params_and_buffers(self):
+         for name, ma_params in self.get_params_iter(self.ema_model):
+             if name in self.ignore_names:
+                 continue
+
+             if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                 continue
+
+             if name in self.param_or_buffer_names_no_ema:
+                 continue
+
+             yield ma_params
+
+         for name, ma_buffer in self.get_buffers_iter(self.ema_model):
+             if name in self.ignore_names:
+                 continue
+
+             if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                 continue
+
+             if name in self.param_or_buffer_names_no_ema:
+                 continue
+
+             yield ma_buffer
+
+     @torch.no_grad()
+     def update_moving_average(self, ma_model, current_model):
+         if self.frozen:
+             return
+
+         # move ema model to online model device if not same and needed
+
+         if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
+             ma_model.to(get_module_device(current_model))
+
+         # get some functions and current decay
+
+         copy, lerp = self.inplace_copy, self.inplace_lerp
+         current_decay = self.beta
+
+         for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
+             if name in self.ignore_names:
+                 continue
+
+             if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                 continue
+
+             if name in self.param_or_buffer_names_no_ema:
+                 copy(ma_params.data, current_params.data)
+                 continue
+
+             lerp(ma_params.data, current_params.data, 1. - current_decay)
+
+         for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
+             if name in self.ignore_names:
+                 continue
+
+             if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                 continue
+
+             if name in self.param_or_buffer_names_no_ema:
+                 copy(ma_buffer.data, current_buffer.data)
+                 continue
+
+             lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
+
+     def __call__(self, *args, **kwargs):
+         return self.ema_model(*args, **kwargs)
+
+ # post hoc ema wrapper
+
+ # solving of the weights for combining all checkpoints into a newly synthesized EMA at desired gamma
+ # Algorithm 3 copied from paper, redone in torch
+
+ def p_dot_p(t_a, gamma_a, t_b, gamma_b):
+     t_ratio = t_a / t_b
+     t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
+     t_max = torch.maximum(t_a , t_b)
+     num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
+     den = (gamma_a + gamma_b + 1) * t_max
+     return num / den
+
+ def solve_weights(t_i, gamma_i, t_r, gamma_r):
+     rv = lambda x: x.double().reshape(-1, 1)
+     cv = lambda x: x.double().reshape(1, -1)
+     A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
+     b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
+     return torch.linalg.solve(A, b)
+
+ class PostHocEMA(Module):
+
+     def __init__(
+         self,
+         model: Module,
+         sigma_rels: Tuple[float, ...] | None = None,
+         gammas: Tuple[float, ...] | None = None,
+         checkpoint_every_num_steps: int = 1000,
+         checkpoint_folder: str = './post-hoc-ema-checkpoints',
+         checkpoint_dtype: torch.dtype = torch.float16,
+         **kwargs
+     ):
+         super().__init__()
+         assert exists(sigma_rels) ^ exists(gammas)
+
+         if exists(sigma_rels):
+             gammas = tuple(map(sigma_rel_to_gamma, sigma_rels))
+
+         assert len(gammas) > 1, 'at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma'
+         assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique'
+
+         self.gammas = gammas
+         self.num_ema_models = len(gammas)
+
+         self._model = [model]
+         self.ema_models = ModuleList([KarrasEMA(model, gamma = gamma, **kwargs) for gamma in gammas])
+
+         self.checkpoint_folder = Path(checkpoint_folder)
+         self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
+         assert self.checkpoint_folder.is_dir()
+
+         self.checkpoint_every_num_steps = checkpoint_every_num_steps
+         self.checkpoint_dtype = checkpoint_dtype
+         self.ema_kwargs = kwargs
+
+     @property
+     def model(self):
+         return first(self._model)
+
+     @property
+     def step(self):
+         return first(self.ema_models).step
+
+     @property
+     def device(self):
+         return self.step.device
+
+     def copy_params_from_model_to_ema(self):
+         for ema_model in self.ema_models:
+             ema_model.copy_params_from_model_to_ema()
+
+     def copy_params_from_ema_to_model(self):
+         for ema_model in self.ema_models:
+             ema_model.copy_params_from_ema_to_model()
+
+     def update(self):
+         for ema_model in self.ema_models:
+             ema_model.update()
+
+         if not (self.step.item() % self.checkpoint_every_num_steps):
+             self.checkpoint()
+
+     def checkpoint(self):
+         step = self.step.item()
+
+         for ind, ema_model in enumerate(self.ema_models):
+             filename = f'{ind}.{step}.pt'
+             path = self.checkpoint_folder / filename
+
+             pkg = deepcopy(ema_model).to(self.checkpoint_dtype).state_dict()
+             torch.save(pkg, str(path))
+
+     def synthesize_ema_model(
+         self,
+         gamma: float | None = None,
+         sigma_rel: float | None = None,
+         step: int | None = None,
+     ) -> KarrasEMA:
+         assert exists(gamma) ^ exists(sigma_rel)
+         device = self.device
+
+         if exists(sigma_rel):
+             gamma = sigma_rel_to_gamma(sigma_rel)
+
+         synthesized_ema_model = KarrasEMA(
+             model = self.model,
+             gamma = gamma,
+             **self.ema_kwargs
+         )
+
+         synthesized_ema_model
+
+         # get all checkpoints
+
+         gammas = []
+         timesteps = []
+         checkpoints = [*self.checkpoint_folder.glob('*.pt')]
+
+         for file in checkpoints:
+             gamma_ind, timestep = map(int, file.stem.split('.'))
+             gammas.append(self.gammas[gamma_ind])
+             timesteps.append(timestep)
+
+         step = default(step, max(timesteps))
+         assert step <= max(timesteps), f'you can only synthesize for a timestep that is less than the max timestep {max(timesteps)}'
+
+         # line up with Algorithm 3
+
+         gamma_i = torch.tensor(gammas, device = device)
+         t_i = torch.tensor(timesteps, device = device)
+
+         gamma_r = torch.tensor([gamma], device = device)
+         t_r = torch.tensor([step], device = device)
+
+         # solve for weights for combining all checkpoints into synthesized, using least squares as in paper
+
+         weights = solve_weights(t_i, gamma_i, t_r, gamma_r)
+         weights = weights.squeeze(-1)
+
+         # now sum up all the checkpoints using the weights one by one
+
+         tmp_ema_model = KarrasEMA(
+             model = self.model,
+             gamma = gamma,
+             **self.ema_kwargs
+         )
+
+         for ind, (checkpoint, weight) in enumerate(zip(checkpoints, weights.tolist())):
+             is_first = ind == 0
+
+             # load checkpoint into a temporary ema model
+
+             ckpt_state_dict = torch.load(str(checkpoint), weights_only=True)
+             tmp_ema_model.load_state_dict(ckpt_state_dict)
+
+             # add weighted checkpoint to synthesized
+
+             for ckpt_tensor, synth_tensor in zip(tmp_ema_model.iter_all_ema_params_and_buffers(), synthesized_ema_model.iter_all_ema_params_and_buffers()):
+                 if is_first:
+                     synth_tensor.zero_()
+
+                 synth_tensor.add_(ckpt_tensor * weight)
+
+         # return the synthesized model
+
+         return synthesized_ema_model
+
+     def __call__(self, *args, **kwargs):
+         return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
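Putting the new module together: `checkpoint()` writes `{gamma_index}.{step}.pt` files as training proceeds, and `synthesize_ema_model` can also target an earlier `step`, not just the latest one. A hedged end-to-end sketch with deliberately small, illustrative settings (the checkpoint folder name is made up):

```python
import torch
from ema_pytorch import PostHocEMA

net = torch.nn.Linear(512, 512)

emas = PostHocEMA(
    net,
    sigma_rels = (0.05, 0.3),
    update_every = 1,                    # tiny values so the toy run checkpoints quickly
    checkpoint_every_num_steps = 5,
    checkpoint_folder = './phema-demo'   # illustrative folder, created on the fly
)

for _ in range(25):
    emas.update()

# synthesize an EMA profile that was never tracked, optionally at an earlier training step
synthesized = emas.synthesize_ema_model(sigma_rel = 0.15, step = 20)
out = synthesized(torch.randn(1, 512))
```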
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ema-pytorch
- Version: 0.3.1
+ Version: 0.6.3
  Summary: Easy way to keep track of exponential moving average version of your pytorch module
  Home-page: https://github.com/lucidrains/ema-pytorch
  Author: Phil Wang
@@ -14,5 +14,4 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.6
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: beartype
- Requires-Dist: torch>=1.6
+ Requires-Dist: torch>=2.0
@@ -3,6 +3,7 @@ README.md
  setup.py
  ema_pytorch/__init__.py
  ema_pytorch/ema_pytorch.py
+ ema_pytorch/post_hoc_ema.py
  ema_pytorch.egg-info/PKG-INFO
  ema_pytorch.egg-info/SOURCES.txt
  ema_pytorch.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+ torch>=2.0
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'ema-pytorch',
    packages = find_packages(exclude=[]),
-   version = '0.3.1',
+   version = '0.6.3',
    license='MIT',
    description = 'Easy way to keep track of exponential moving average version of your pytorch module',
    author = 'Phil Wang',
@@ -16,8 +16,7 @@ setup(
      'exponential moving average'
    ],
    install_requires=[
-     'beartype',
-     'torch>=1.6',
+     'torch>=2.0',
    ],
    classifiers=[
      'Development Status :: 4 - Beta',
@@ -1,54 +0,0 @@
- ## EMA - Pytorch
-
- A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
-
- ## Install
-
- ```bash
- $ pip install ema-pytorch
- ```
-
- ## Usage
-
- ```python
- import torch
- from ema_pytorch import EMA
-
- # your neural network as a pytorch module
-
- net = torch.nn.Linear(512, 512)
-
- # wrap your neural network, specify the decay (beta)
-
- ema = EMA(
-     net,
-     beta = 0.9999, # exponential moving average factor
-     update_after_step = 100, # only after this number of .update() calls will it start updating
-     update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
- )
-
- # mutate your network, with SGD or otherwise
-
- with torch.no_grad():
-     net.weight.copy_(torch.randn_like(net.weight))
-     net.bias.copy_(torch.randn_like(net.bias))
-
- # you will call the update function on your moving average wrapper
-
- ema.update()
-
- # then, later on, you can invoke the EMA model the same way as your network
-
- data = torch.randn(1, 512)
-
- output = net(data)
- ema_output = ema(data)
-
- # if you want to save your ema model, it is recommended you save the entire wrapper
- # as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now)
- # however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
- ```
-
- ## Todo
-
- - [ ] address the issue of annealing EMA to 1 near the end of training for BYOL https://github.com/lucidrains/byol-pytorch/issues/82
@@ -1 +0,0 @@
- from ema_pytorch.ema_pytorch import EMA
@@ -1,2 +0,0 @@
- beartype
- torch>=1.6
File without changes
File without changes