ema-pytorch 0.3.1 (py3-none-any.whl) → 0.3.3 (py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ema-pytorch might be problematic. Click here for more details.

@@ -43,7 +43,7 @@ class EMA(Module):
43
43
 
44
44
  Args:
45
45
  inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
46
- power (float): Exponential factor of EMA warmup. Default: 1.
46
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
47
47
  min_value (float): The minimum EMA decay rate. Default: 0.
48
48
  """
49
49
 
@@ -53,6 +53,7 @@ class EMA(Module):
53
53
  model: Module,
54
54
  ema_model: Optional[Module] = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
55
55
  beta = 0.9999,
56
+ karras_beta = False, # if True, uses the karras time dependent beta
56
57
  update_after_step = 100,
57
58
  update_every = 10,
58
59
  inv_gamma = 1.0,
@@ -65,7 +66,10 @@ class EMA(Module):
65
66
  allow_different_devices = False # if the EMA model is on a different device (say CPU), automatically move the tensor
66
67
  ):
67
68
  super().__init__()
68
- self.beta = beta
69
+ self._beta = beta
70
+ self.karras_beta = karras_beta
71
+
72
+ self.is_frozen = beta == 1.
69
73
 
70
74
  # whether to include the online model within the module tree, so that state_dict also saves it
71
75
 
@@ -127,6 +131,13 @@ class EMA(Module):
127
131
  @property
128
132
  def model(self):
129
133
  return self.online_model if self.include_online_model else self.online_model[0]
134
+
135
+ @property
136
+ def beta(self):
137
+ if self.karras_beta:
138
+ return (1 - 1 / (self.step + 1)) ** (1 + self.power)
139
+
140
+ return self._beta
130
141
 
131
142
  def eval(self):
132
143
  return self.ema_model.eval()
@@ -193,6 +204,9 @@ class EMA(Module):
193
204
 
194
205
  @torch.no_grad()
195
206
  def update_moving_average(self, ma_model, current_model):
207
+ if self.is_frozen:
208
+ return
209
+
196
210
  copy, lerp = self.inplace_copy, self.inplace_lerp
197
211
  current_decay = self.get_current_decay()
198
212
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ema-pytorch
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Easy way to keep track of exponential moving average version of your pytorch module
5
5
  Home-page: https://github.com/lucidrains/ema-pytorch
6
6
  Author: Phil Wang
@@ -0,0 +1,7 @@
1
+ ema_pytorch/__init__.py,sha256=jCgUV6FfA65ct6lMEClKKwl-gjZMDv7aOeaN3RshE0U,40
2
+ ema_pytorch/ema_pytorch.py,sha256=eOycxsyO3aCBzToiBGEqOJIGxklX3N3Yitq3wldwHkg,8573
3
+ ema_pytorch-0.3.3.dist-info/LICENSE,sha256=xZDkKtpHE2TPCAeqKe1fjdpKernl1YW-d01j_1ltkAU,1066
4
+ ema_pytorch-0.3.3.dist-info/METADATA,sha256=i2-cAyNXNi4l-xUVsrvFP5wpJHvzyglB1BBBFhobwqg,715
5
+ ema_pytorch-0.3.3.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
6
+ ema_pytorch-0.3.3.dist-info/top_level.txt,sha256=XXFJmHviark_32Hfm5X9niezVmnRTUIhfdifCrJgXmE,12
7
+ ema_pytorch-0.3.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.41.3)
2
+ Generator: bdist_wheel (0.42.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,7 +0,0 @@
1
- ema_pytorch/__init__.py,sha256=jCgUV6FfA65ct6lMEClKKwl-gjZMDv7aOeaN3RshE0U,40
2
- ema_pytorch/ema_pytorch.py,sha256=ALWUc6STWA7ZqACn4OuXAiWqb57tk3Vj5B-RowHKNYU,8186
3
- ema_pytorch-0.3.1.dist-info/LICENSE,sha256=xZDkKtpHE2TPCAeqKe1fjdpKernl1YW-d01j_1ltkAU,1066
4
- ema_pytorch-0.3.1.dist-info/METADATA,sha256=dqsiNPoAH7aG_Na--NPg4P_DsgbpJjcc3KsA1lu6e-I,715
5
- ema_pytorch-0.3.1.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92
6
- ema_pytorch-0.3.1.dist-info/top_level.txt,sha256=XXFJmHviark_32Hfm5X9niezVmnRTUIhfdifCrJgXmE,12
7
- ema_pytorch-0.3.1.dist-info/RECORD,,