ema-pytorch 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ema-pytorch might be problematic. Click here for more details.
- ema_pytorch/ema_pytorch.py +16 -2
- {ema_pytorch-0.3.1.dist-info → ema_pytorch-0.3.3.dist-info}/METADATA +1 -1
- ema_pytorch-0.3.3.dist-info/RECORD +7 -0
- {ema_pytorch-0.3.1.dist-info → ema_pytorch-0.3.3.dist-info}/WHEEL +1 -1
- ema_pytorch-0.3.1.dist-info/RECORD +0 -7
- {ema_pytorch-0.3.1.dist-info → ema_pytorch-0.3.3.dist-info}/LICENSE +0 -0
- {ema_pytorch-0.3.1.dist-info → ema_pytorch-0.3.3.dist-info}/top_level.txt +0 -0
ema_pytorch/ema_pytorch.py
CHANGED
|
@@ -43,7 +43,7 @@ class EMA(Module):
|
|
|
43
43
|
|
|
44
44
|
Args:
|
|
45
45
|
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
|
46
|
-
power (float): Exponential factor of EMA warmup. Default: 1.
|
|
46
|
+
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
|
47
47
|
min_value (float): The minimum EMA decay rate. Default: 0.
|
|
48
48
|
"""
|
|
49
49
|
|
|
@@ -53,6 +53,7 @@ class EMA(Module):
|
|
|
53
53
|
model: Module,
|
|
54
54
|
ema_model: Optional[Module] = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
|
|
55
55
|
beta = 0.9999,
|
|
56
|
+
karras_beta = False, # if True, uses the karras time dependent beta
|
|
56
57
|
update_after_step = 100,
|
|
57
58
|
update_every = 10,
|
|
58
59
|
inv_gamma = 1.0,
|
|
@@ -65,7 +66,10 @@ class EMA(Module):
|
|
|
65
66
|
allow_different_devices = False # if the EMA model is on a different device (say CPU), automatically move the tensor
|
|
66
67
|
):
|
|
67
68
|
super().__init__()
|
|
68
|
-
self.beta = beta
|
|
69
|
+
self._beta = beta
|
|
70
|
+
self.karras_beta = karras_beta
|
|
71
|
+
|
|
72
|
+
self.is_frozen = beta == 1.
|
|
69
73
|
|
|
70
74
|
# whether to include the online model within the module tree, so that state_dict also saves it
|
|
71
75
|
|
|
@@ -127,6 +131,13 @@ class EMA(Module):
|
|
|
127
131
|
@property
|
|
128
132
|
def model(self):
|
|
129
133
|
return self.online_model if self.include_online_model else self.online_model[0]
|
|
134
|
+
|
|
135
|
+
@property
|
|
136
|
+
def beta(self):
|
|
137
|
+
if self.karras_beta:
|
|
138
|
+
return (1 - 1 / (self.step + 1)) ** (1 + self.power)
|
|
139
|
+
|
|
140
|
+
return self._beta
|
|
130
141
|
|
|
131
142
|
def eval(self):
|
|
132
143
|
return self.ema_model.eval()
|
|
@@ -193,6 +204,9 @@ class EMA(Module):
|
|
|
193
204
|
|
|
194
205
|
@torch.no_grad()
|
|
195
206
|
def update_moving_average(self, ma_model, current_model):
|
|
207
|
+
if self.is_frozen:
|
|
208
|
+
return
|
|
209
|
+
|
|
196
210
|
copy, lerp = self.inplace_copy, self.inplace_lerp
|
|
197
211
|
current_decay = self.get_current_decay()
|
|
198
212
|
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
ema_pytorch/__init__.py,sha256=jCgUV6FfA65ct6lMEClKKwl-gjZMDv7aOeaN3RshE0U,40
|
|
2
|
+
ema_pytorch/ema_pytorch.py,sha256=eOycxsyO3aCBzToiBGEqOJIGxklX3N3Yitq3wldwHkg,8573
|
|
3
|
+
ema_pytorch-0.3.3.dist-info/LICENSE,sha256=xZDkKtpHE2TPCAeqKe1fjdpKernl1YW-d01j_1ltkAU,1066
|
|
4
|
+
ema_pytorch-0.3.3.dist-info/METADATA,sha256=i2-cAyNXNi4l-xUVsrvFP5wpJHvzyglB1BBBFhobwqg,715
|
|
5
|
+
ema_pytorch-0.3.3.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
6
|
+
ema_pytorch-0.3.3.dist-info/top_level.txt,sha256=XXFJmHviark_32Hfm5X9niezVmnRTUIhfdifCrJgXmE,12
|
|
7
|
+
ema_pytorch-0.3.3.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
ema_pytorch/__init__.py,sha256=jCgUV6FfA65ct6lMEClKKwl-gjZMDv7aOeaN3RshE0U,40
|
|
2
|
-
ema_pytorch/ema_pytorch.py,sha256=ALWUc6STWA7ZqACn4OuXAiWqb57tk3Vj5B-RowHKNYU,8186
|
|
3
|
-
ema_pytorch-0.3.1.dist-info/LICENSE,sha256=xZDkKtpHE2TPCAeqKe1fjdpKernl1YW-d01j_1ltkAU,1066
|
|
4
|
-
ema_pytorch-0.3.1.dist-info/METADATA,sha256=dqsiNPoAH7aG_Na--NPg4P_DsgbpJjcc3KsA1lu6e-I,715
|
|
5
|
-
ema_pytorch-0.3.1.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92
|
|
6
|
-
ema_pytorch-0.3.1.dist-info/top_level.txt,sha256=XXFJmHviark_32Hfm5X9niezVmnRTUIhfdifCrJgXmE,12
|
|
7
|
-
ema_pytorch-0.3.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|