hcpdiff 2.2.1__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hcpdiff/ckpt_manager/__init__.py +1 -1
- hcpdiff/ckpt_manager/ckpt.py +21 -17
- hcpdiff/ckpt_manager/format/diffusers.py +4 -4
- hcpdiff/ckpt_manager/format/sd_single.py +3 -3
- hcpdiff/ckpt_manager/loader.py +11 -4
- hcpdiff/diffusion/noise/__init__.py +0 -1
- hcpdiff/diffusion/sampler/VP.py +27 -0
- hcpdiff/diffusion/sampler/__init__.py +2 -3
- hcpdiff/diffusion/sampler/base.py +106 -44
- hcpdiff/diffusion/sampler/diffusers.py +11 -17
- hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
- hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
- hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
- hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
- hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
- hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
- hcpdiff/easy/cfg/sd15_train.py +33 -22
- hcpdiff/easy/cfg/sdxl_train.py +32 -23
- hcpdiff/evaluate/__init__.py +3 -1
- hcpdiff/evaluate/evaluator.py +76 -0
- hcpdiff/evaluate/metrics/__init__.py +1 -0
- hcpdiff/evaluate/metrics/clip_score.py +23 -0
- hcpdiff/evaluate/previewer.py +29 -12
- hcpdiff/loss/base.py +9 -26
- hcpdiff/loss/weighting.py +36 -18
- hcpdiff/models/lora_base_patch.py +26 -0
- hcpdiff/models/wrapper/sd.py +17 -19
- hcpdiff/trainer_ac.py +7 -5
- hcpdiff/trainer_ac_single.py +1 -6
- hcpdiff/utils/__init__.py +2 -1
- hcpdiff/utils/torch_utils.py +25 -0
- hcpdiff/workflow/__init__.py +1 -1
- hcpdiff/workflow/diffusion.py +27 -7
- hcpdiff/workflow/io.py +20 -3
- hcpdiff/workflow/text.py +6 -1
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/METADATA +2 -2
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/RECORD +41 -37
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/WHEEL +1 -1
- hcpdiff/diffusion/noise/zero_terminal.py +0 -39
- hcpdiff/diffusion/sampler/ddpm.py +0 -20
- hcpdiff/diffusion/sampler/edm.py +0 -22
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/entry_points.txt +0 -0
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/licenses/LICENSE +0 -0
- {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.1.dist-info}/top_level.txt +0 -0
hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py
CHANGED
@@ -1,17 +1,22 @@
-import torch
 import math
-from typing import Union, Tuple
-
+from typing import Union, Tuple, Callable
+
+import torch
+
+from hcpdiff.utils import invert_func
 from .base import SigmaScheduler
 
 class DDPMDiscreteSigmaScheduler(SigmaScheduler):
-    def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, num_timesteps=1000):
+    def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, num_timesteps=1000, pred_type='eps'):
         super().__init__()
         self.num_timesteps = num_timesteps
         self.betas = self.make_betas(beta_schedule, linear_start, linear_end, num_timesteps)
         alphas = 1.0-self.betas
         self.alphas_cumprod = torch.cumprod(alphas, dim=0)
-
+
+        self.alphas = self.alphas_cumprod.sqrt()
+        self.sigmas = (1-self.alphas_cumprod).sqrt()
+        self.pred_type = pred_type
 
         # for VLB calculation
         self.alphas_cumprod_prev = torch.cat([alphas.new_tensor([1.0]), self.alphas_cumprod[:-1]])
@@ -22,37 +27,73 @@ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
         self.posterior_log_variance_clipped = torch.log(torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]]))
 
+    # def scale_t(self, t):
+    #     return t*(self.num_timesteps-1)
 
     @property
-    def
+    def sigma_start(self):
         return self.sigmas[0]
 
     @property
-    def
+    def sigma_end(self):
         return self.sigmas[-1]
 
-
+    @property
+    def alpha_start(self):
+        return self.alphas[0]
+
+    @property
+    def alpha_end(self):
+        return self.alphas[-1]
+
+    def sigma(self, t: Union[float, torch.Tensor]):
         if isinstance(t, float):
             t = torch.tensor(t)
-
+        self.sigmas = self.sigmas.to(t.device)
+        return self.sigmas[((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)]
 
-    def
-        if isinstance(
-
-
-
+    def alpha(self, t: Union[float, torch.Tensor]):
+        if isinstance(t, float):
+            t = torch.tensor(t)
+        self.alphas = self.alphas.to(t.device)
+        return self.alphas[((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)]
+
+    def c_noise(self, t: Union[float, torch.Tensor]):
+        return (t*self.num_timesteps).round()
 
-
-
-
+    def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=True) -> Tuple[torch.Tensor, torch.Tensor]:
+        '''
+        v(t) = dx(t)/dt = d\alpha(t)/dt * x(0) + d\sigma(t)/dt *eps
+        :param t: 0-1, rate of time step
+        :return: d\alpha(t)/dt, d\sigma(t)/dt
+        '''
+        d_alpha = -self.sigma(t)
+        d_sigma = self.alpha(t)
+        if normlize:
+            norm = torch.sqrt(d_alpha**2+d_sigma**2)
+            return d_alpha/norm, d_sigma/norm
+        else:
+            return d_alpha, d_sigma
 
     def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
-
-
+        ref_t = np.linspace(0, 1, len(self.sigmas))
+        t = torch.tensor(np.interp(sigma.cpu().clip(min=1e-8).log().numpy(), self.sigmas, ref_t))
+        return t
+
+    def alpha_to_t(self, alpha: Union[float, torch.Tensor]):
+        ref_t = np.linspace(0, 1, len(self.alphas))
+        t = torch.tensor(np.interp(alpha.cpu().clip(min=1e-8).log().numpy(), self.alphas, ref_t))
+        return t
+
+    def alpha_to_sigma(self, alpha):
+        return torch.sqrt(1 - alpha**2)
+
+    def sigma_to_alpha(self, sigma):
+        return torch.sqrt(1 - sigma**2)
 
     def get_post_mean(self, t, x_0, x_t):
         t = (t*len(self.sigmas)).long()
-        return self.posterior_mean_coef1[t].view(-1, 1, 1, 1).to(t.device)*x_0
+        return self.posterior_mean_coef1[t].view(-1, 1, 1, 1).to(t.device)*x_0+self.posterior_mean_coef2[t].view(-1, 1, 1, 1).to(t.device)*x_t
 
     def get_post_log_var(self, t, x_t_var=None):
         t = (t*len(self.sigmas)).long()
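The hunks above give DDPMDiscreteSigmaScheduler an explicit (alpha, sigma) parameterization on top of the cached alpha_bar table, plus velocity() for v-prediction-style targets and c_noise() for the timestep value fed to the denoiser. A minimal usage sketch (the direct module import and the sample values are assumptions, not taken from the diff):

```python
import torch
from hcpdiff.diffusion.sampler.sigma_scheduler.ddpm import DDPMDiscreteSigmaScheduler

# Query the alpha/sigma/velocity API added in 2.3.1 (illustrative values).
sched = DDPMDiscreteSigmaScheduler(beta_schedule="scaled_linear", num_timesteps=1000)
t = torch.rand(4)                              # normalized timesteps in [0, 1]
alpha, sigma = sched.alpha(t), sched.sigma(t)  # sqrt(alpha_bar), sqrt(1 - alpha_bar)
assert torch.allclose(alpha**2 + sigma**2, torch.ones_like(t), atol=1e-5)  # VP identity
d_alpha, d_sigma = sched.velocity(t)           # normalized d(alpha)/dt, d(sigma)/dt
c_noise = sched.c_noise(t)                     # discrete timestep index passed to the model
```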
@@ -66,7 +107,6 @@ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
         model_log_variance = frac*max_log+(1-frac)*min_log
         return model_log_variance
 
-
     @staticmethod
     def betas_for_alpha_bar(
         num_diffusion_timesteps,
@@ -130,50 +170,154 @@ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented.")
 
-class DDPMContinuousSigmaScheduler(
+class DDPMContinuousSigmaScheduler(SigmaScheduler):
+    def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, t_base=1000):
+        self.alpha_bar_fn = self.make_alpha_bar_fn(beta_schedule, linear_start, linear_end)
+        self.t_base = t_base # base time step for continuous product
+
+    def continuous_product(self, alpha_fn: Callable[[torch.Tensor], torch.Tensor], t: torch.Tensor):
+        '''
+
+        :param alpha_fn: alpha function
+        :param t: timesteps with shape [B]
+        :return: [B]
+        '''
+        bins = torch.linspace(0, 1, self.t_base, dtype=torch.float32).unsqueeze(0)
+        t_grid = bins*t.float().unsqueeze(1) # [B, num_bins]
+        alpha_vals = alpha_fn(t_grid)
+
+        if torch.any(alpha_vals<=0):
+            raise ValueError("alpha(t) must > 0 to avoid log(≤0).")
+
+        log_term = torch.log(alpha_vals) # [B, num_bins]
+        dt = t_grid[:, 1]-t_grid[:, 0] # [B]
+        integral = torch.cumsum((log_term[:, -1]+log_term[:, 1:])/2*dt.unsqueeze(1), dim=1) # [B]
+        x_vals = torch.exp(integral)
+        return x_vals
+
+    @staticmethod
+    def alpha_bar_linear(beta_s, beta_e, t, N=1000):
+        A = beta_e-beta_s
+        B = 1-beta_s
+        B_At = B-A*t
+
+        # 避免数值不稳定 (avoid numerical instability)
+        eps = 1e-12
+        B = torch.clamp(B, min=eps)
+        B_At = torch.clamp(B_At, min=eps)
+
+        term = (B*torch.log(B)-B_At*torch.log(B_At)-A*t)
+        return torch.exp(N*term/A)
+
+    @staticmethod
+    def alpha_bar_scaled_linear(beta_s, beta_e, t, N=1000):
+        sqrt_bs = torch.sqrt(beta_s)
+        sqrt_be = torch.sqrt(beta_e)
+        a = sqrt_be-sqrt_bs
+        b = sqrt_bs
+        u0 = b
+        u1 = a*t+b
+
+        eps = 1e-12
+
+        def safe_log1m(u2):
+            return torch.log(torch.clamp(1-u2, min=eps))
+
+        def safe_log_frac(u):
+            return torch.log(torch.clamp(1+u, min=eps))-torch.log(torch.clamp(1-u, min=eps))
 
-
+        term1 = u1*safe_log1m(u1**2)
+        term2 = 0.5*safe_log_frac(u1)
+        term3 = u0*safe_log1m(u0**2)
+        term4 = 0.5*safe_log_frac(u0)
+
+        return torch.exp(N*(term1+term2-term3-term4)/a)
+
+    def make_alpha_bar_fn(self, beta_schedule, beta_start, beta_end, alpha_fn=None):
+        if alpha_fn is not None:
+            return lambda t, alpha_fn_=alpha_fn:self.continuous_product(alpha_fn_(t), t)
+        elif beta_schedule == "linear":
+            return lambda t:self.alpha_bar_linear(beta_start, beta_end, t)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            return lambda t:self.alpha_bar_scaled_linear(beta_start, beta_end, t)
+        elif beta_schedule == "squaredcos_cap_v2":
+            return lambda t:torch.cos((t+0.008)/1.008*math.pi/2)**2
+        elif beta_schedule == "sigmoid":
+            # GeoDiff sigmoid schedule
+            alpha_fn = lambda t:1-torch.sigmoid(torch.lerp(torch.full_like(t, -6), torch.full_like(t, 6), t))*(beta_end-beta_start)+beta_start
+            return lambda t, alpha_fn_=alpha_fn:self.continuous_product(alpha_fn_(t), t)
+        else:
+            raise NotImplementedError(f"{beta_schedule} does is not implemented.")
+
+    def sigma(self, t: Union[float, torch.Tensor]):
         if isinstance(t, float):
-            t = torch.tensor(t)
-
+            t = torch.tensor([t])
+        alpha_cumprod = self.alpha_bar_fn(t)
+        return torch.sqrt(1-alpha_cumprod)
 
-    def
-        if isinstance(
-
-
-
+    def alpha(self, t: Union[float, torch.Tensor]):
+        if isinstance(t, float):
+            t = torch.tensor([t])
+        alpha_cumprod = self.alpha_bar_fn(t)
+        return torch.sqrt(alpha_cumprod)
 
-
-
+    def c_noise(self, t: Union[float, torch.Tensor]):
+        return t*self.t_base
 
-
+    @property
+    def sigma_start(self):
+        return self.sigma(0)
 
-
-
-
-
-
+    @property
+    def sigma_end(self):
+        return self.sigma(1)
+
+    @property
+    def alpha_start(self):
+        return self.alpha(0)
+
+    @property
+    def alpha_end(self):
+        return self.alpha(1)
+
+    def alpha_to_t(self, alpha, t_min=0.0, t_max=1.0, tol=1e-5, max_iter=100):
+        """
+        alpha: [B]
+        :return: t [B]
+        """
+        return invert_func(self.alpha, alpha, t_min, t_max, tol, max_iter)
+
+    def sigma_to_t(self, sigma, t_min=0.0, t_max=1.0, tol=1e-5, max_iter=100):
+        """
+        sigma: [B]
+        :return: t [B]
+        """
+        return invert_func(self.sigma, sigma, t_min, t_max, tol, max_iter)
 
 class TimeSigmaScheduler(SigmaScheduler):
     def __init__(self, num_timesteps=1000):
         super().__init__()
         self.num_timesteps = num_timesteps
 
-    def
+    def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
         '''
         :param t: 0-1, rate of time step
         '''
-
-
-
-        if isinstance(min_rate, float):
-            min_rate = torch.full(shape, min_rate)
-        if isinstance(max_rate, float):
-            max_rate = torch.full(shape, max_rate)
+        if isinstance(t, float):
+            t = torch.tensor(t)
+        return ((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)
 
-
-
-
+    def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+        '''
+        :param t: 0-1, rate of time step
+        '''
+        if isinstance(t, float):
+            t = torch.tensor(t)
+        return ((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)
+
+    def c_noise(self, t: Union[float, torch.Tensor]):
+        return (t*self.num_timesteps).round()
 
 if __name__ == '__main__':
     from matplotlib import pyplot as plt
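DDPMContinuousSigmaScheduler drops the fixed per-step table: alpha_bar(t) is evaluated in closed form for the linear and scaled_linear beta schedules, or via continuous_product, which approximates exp(N * ∫ log alpha(s) ds) with a trapezoidal sum; alpha_to_t and sigma_to_t then invert the schedule numerically through invert_func. The discrete cumulative product it generalizes is just the exponential of a cumulative log-sum, which the following self-contained sketch checks (the 1000-step linear beta schedule below is illustrative, not taken from the diff):

```python
import torch

# Illustrative check: the discrete alpha_bar (a cumulative product) equals the
# exponential of the cumulative sum of log(alpha), which is the quantity the
# continuous scheduler evaluates as an integral.
N = 1000
beta = torch.linspace(1e-4, 2e-2, N)          # assumed linear beta schedule
alpha = 1.0 - beta
alpha_bar_prod = torch.cumprod(alpha, dim=0)                       # discrete DDPM form
alpha_bar_log = torch.exp(torch.cumsum(torch.log(alpha), dim=0))   # log-sum/integral form
assert torch.allclose(alpha_bar_prod, alpha_bar_log, atol=1e-5)
```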
hcpdiff/diffusion/sampler/sigma_scheduler/edm.py
CHANGED
@@ -1,19 +1,18 @@
-from typing import Union
+from typing import Union, Tuple
 
-import torch
 import numpy as np
+import torch
 
 from .base import SigmaScheduler
 
 class EDMSigmaScheduler(SigmaScheduler):
-    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0
-        self.sigma_min =
-        self.sigma_max =
+    def __init__(self, sigma_min=0.002, sigma_max=80.0, sigma_data=0.5, rho=7.0):
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+        self.sigma_data = sigma_data
         self.rho = rho
 
-
-
-    def get_sigma(self, t: Union[float, torch.Tensor]):
+    def sigma_edm(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
         if isinstance(t, float):
             t = torch.tensor(t)
 
@@ -21,28 +20,106 @@ class EDMSigmaScheduler(SigmaScheduler):
         max_inv_rho = self.sigma_max**(1/self.rho)
         return torch.lerp(min_inv_rho, max_inv_rho, t)**self.rho
 
-    def
-
-
-
-
-
-
-
-
-
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+        '''
+        x_t = c_in(t) * (x(0) + \sigma(t)*eps), eps~N(0,I)
+        '''
+        if isinstance(t, float):
+            t = torch.tensor(t)
+
+        sigma_edm = self.sigma_edm(t)
+        return sigma_edm/torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+    def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+        '''
+        x_t = c_in(t) * (x(0) + \sigma(t)*eps), eps~N(0,I)
+        '''
+        if isinstance(t, float):
+            t = torch.tensor(t)
+
+        sigma_edm = self.sigma_edm(t)
+        return 1./torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+    def c_skip(self, t: Union[float, torch.Tensor]):
+        '''
+        \hat{x}(0) = c_skip(t)*(x(t)/c_in(t)) + c_out(t)*f(x(t))
+        :param t: 0-1, rate of time step
+        '''
+        sigma_edm = self.sigma_edm(t)
+        return self.sigma_data**2/torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+    def c_out(self, t: Union[float, torch.Tensor]):
+        '''
+        \hat{x}(0) = c_skip(t)*(x(t)/c_in(t)) + c_out(t)*f(x(t))
+        :param t: 0-1, rate of time step
+        '''
+        sigma_edm = self.sigma_edm(t)
+        return (self.sigma_data*sigma_edm)/torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+    def c_noise(self, t: Union[float, torch.Tensor]):
+        sigma_edm = self.sigma_edm(t)
+        return sigma_edm.log()/4
+
+    @property
+    def sigma_start(self):
+        return self.sigma(0)
+
+    @property
+    def sigma_end(self):
+        return self.sigma(1)
+
+    @property
+    def alpha_start(self):
+        return self.alpha(0)
+
+    @property
+    def alpha_end(self):
+        return self.alpha(1)
+
+    def alpha_to_sigma(self, alpha):
+        return torch.sqrt(1 - (alpha*self.sigma_data)**2)
+
+    def sigma_to_alpha(self, sigma):
+        return torch.sqrt(1 - sigma**2)/self.sigma_data
+
+class EDMTimeRescaleScheduler(EDMSigmaScheduler):
+    def __init__(self, ref_scheduler: SigmaScheduler, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+        super().__init__(sigma_min, sigma_max, rho)
+        self.ref_scheduler = ref_scheduler
+
+    def scale_t(self, t):
+        ref_t = torch.linspace(0, 1, 1000)
+        alphas = self.alpha(ref_t)
+        sigmas = self.sigma(ref_t)
+        sigmas_edm = sigmas/alphas
+        sigma_edm = self.sigma_edm(t)
+        t = np.interp(sigma_edm.cpu().clip(min=1e-8).log().numpy(), sigmas_edm, ref_t.numpy())
+        return torch.tensor(t)
+
+    def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+        return self.ref_scheduler.sigma(t)
+
+    def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+        return self.ref_scheduler.alpha(t)
+
+    def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=True) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.ref_scheduler.velocity(t, dt=dt, normlize=normlize)
+
+    def c_skip(self, t: Union[float, torch.Tensor]):
+        return self.ref_scheduler.c_skip(t)
+
+    def c_out(self, t: Union[float, torch.Tensor]):
+        return self.ref_scheduler.c_out(t)
+
+    def c_noise(self, t: Union[float, torch.Tensor]):
+        return self.ref_scheduler.c_noise(t)
+
+    def sample(self, min_t=0.0, max_t=1.0, shape=(1,)):
+        if isinstance(min_t, float):
+            min_t = torch.full(shape, min_t)
+        if isinstance(max_t, float):
+            max_t = torch.full(shape, max_t)
+
+        t = torch.lerp(min_t, max_t, torch.rand_like(min_t))
+        t = self.scale_t(t)
+        return t
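The rewritten edm.py maps Karras-style sigma values onto the same (alpha, sigma) interface as the other schedulers: alpha(t) plays the role of c_in, and c_skip/c_out/c_noise follow the EDM preconditioning. A self-contained sketch of those relations (sigma_data=0.5 and the rho=7 ramp mirror the defaults in the diff; the rest is illustrative):

```python
import torch

# Illustrative EDM preconditioning, mirroring the formulas in the diff.
sigma_min, sigma_max, rho, sigma_data = 0.002, 80.0, 7.0, 0.5
t = torch.linspace(0, 1, 5)

# Karras ramp: interpolate in sigma^(1/rho) space, then raise back to the rho power
sigma_edm = (sigma_min**(1/rho) + t*(sigma_max**(1/rho) - sigma_min**(1/rho)))**rho

c_in = 1.0/torch.sqrt(sigma_edm**2 + sigma_data**2)   # alpha(t) in the diff
sigma = sigma_edm*c_in                                # sigma(t) in the diff
c_out = sigma_data*sigma_edm*c_in
c_noise = sigma_edm.log()/4

# x_t = c_in*(x0 + sigma_edm*eps): the pair stays on (alpha*sigma_data)^2 + sigma^2 = 1,
# and sigma/alpha recovers the raw EDM sigma.
assert torch.allclose((c_in*sigma_data)**2 + sigma**2, torch.ones_like(t))
assert torch.allclose(sigma/c_in, sigma_edm)
```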
hcpdiff/diffusion/sampler/sigma_scheduler/flow.py
ADDED
@@ -0,0 +1,74 @@
+from typing import Union, Tuple
+
+import torch
+
+from .base import SigmaScheduler
+
+class FlowSigmaScheduler(SigmaScheduler):
+    def __init__(self, t_start=0, t_end=1):
+        super().__init__()
+        self.t_start = t_start
+        self.t_end = t_end
+
+    def sigma(self, t: Union[float, torch.Tensor]):
+        if isinstance(t, float):
+            t = torch.tensor([t])
+        t = (self.t_end-self.t_start)*t+self.t_start
+        return t
+
+    def alpha(self, t: Union[float, torch.Tensor]):
+        if isinstance(t, float):
+            t = torch.tensor([t])
+        t = (self.t_end-self.t_start)*t+self.t_start
+        return 1-t
+
+    def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=False) -> Tuple[torch.Tensor, torch.Tensor]:
+        '''
+        v(t) = dx(t)/dt = d\alpha(t)/dt * x(0) + d\sigma(t)/dt *eps
+        :param t: 0-1, rate of time step
+        :return: d\alpha(t)/dt, d\sigma(t)/dt
+        '''
+        if isinstance(t, float):
+            t = torch.tensor([t])
+        d_alpha = -torch.ones_like(t)
+        d_sigma = torch.ones_like(t)
+        if normlize:
+            norm = torch.sqrt(d_alpha**2+d_sigma**2)
+            return d_alpha/norm, d_sigma/norm
+        else:
+            return d_alpha, d_sigma
+
+    def alpha_to_t(self, alphas):
+        """
+        alphas: [B]
+        :return: t [B]
+        """
+        return alphas
+
+    def sigma_to_t(self, sigmas):
+        """
+        sigmas: [B]
+        :return: t [B]
+        """
+        return 1-sigmas
+
+    def alpha_to_sigma(self, alpha):
+        return 1-alpha
+
+    def sigma_to_alpha(self, sigma):
+        return 1-sigma
+
+    def c_skip(self, t: Union[float, torch.Tensor]):
+        '''
+        \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
+        :param t: 0-1, rate of time step
+        '''
+        return 1.
+
+    def c_out(self, t: Union[float, torch.Tensor]):
+        '''
+        \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
+        :param t: 0-1, rate of time step
+        '''
+        sigma = self.sigma(t)
+        return -sigma
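FlowSigmaScheduler is the rectified-flow interpolation in this interface: alpha = 1 - t and sigma = t (optionally restricted to [t_start, t_end]), so the velocity target is constant in t. A self-contained sketch of the interpolation it encodes (shapes and values are illustrative):

```python
import torch

# Illustrative rectified-flow interpolation: x_t = (1-t)*x0 + t*eps, v = eps - x0.
x0 = torch.randn(2, 4, 8, 8)           # assumed latent shape
eps = torch.randn_like(x0)
t = torch.rand(2).view(-1, 1, 1, 1)

alpha, sigma = 1 - t, t                # FlowSigmaScheduler with t_start=0, t_end=1
x_t = alpha*x0 + sigma*eps
v_target = eps - x0                    # d(alpha)/dt = -1, d(sigma)/dt = +1
```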
hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py
ADDED
@@ -0,0 +1,22 @@
+from typing import Union
+
+import torch
+from .base import SigmaScheduler
+
+class ZeroTerminalScheduler(SigmaScheduler):
+    def __init__(self, ref_scheduler: SigmaScheduler, eps=1e-4):
+        self.ref_scheduler = ref_scheduler
+        self.eps = eps
+
+    def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+        alpha_0 = self.ref_scheduler.alpha_start
+        alpha_T = self.ref_scheduler.alpha_end
+        alpha = self.ref_scheduler.alpha(t)
+        return (alpha - alpha_T)*(alpha_0-self.eps)/(alpha_0 - alpha_T) + self.eps
+
+    def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+        try:
+            alpha = self.alpha(t)
+            return self.ref_scheduler.alpha_to_sigma(alpha)
+        except NotImplementedError:
+            raise NotImplementedError(f'{type(self.ref_scheduler)} cannot be a "ZeroTerminalScheduler"!')
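ZeroTerminalScheduler rescales a reference scheduler's alpha(t) affinely so that the start value is preserved and the terminal value collapses to eps, i.e. near-zero terminal SNR; sigma(t) is then recovered through the reference scheduler's alpha_to_sigma. A self-contained sketch of that rescaling (the reference alpha curve below is a stand-in, not taken from the diff):

```python
import torch

# Illustrative zero-terminal rescaling of a stand-in reference alpha(t) curve.
eps = 1e-4
beta = torch.linspace(1e-4, 2e-2, 1000)                   # assumed DDPM-like schedule
alpha_ref = torch.sqrt(torch.cumprod(1 - beta, dim=0))    # stand-in for ref_scheduler.alpha

alpha_0, alpha_T = alpha_ref[0], alpha_ref[-1]
alpha_zt = (alpha_ref - alpha_T)*(alpha_0 - eps)/(alpha_0 - alpha_T) + eps

assert torch.isclose(alpha_zt[0], alpha_0)                # start point preserved
assert torch.isclose(alpha_zt[-1], torch.tensor(eps))     # terminal alpha forced to eps
```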
hcpdiff/easy/cfg/sd15_train.py
CHANGED
@@ -1,18 +1,17 @@
 import torch
-from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, NekoPluginSaver, SafeTensorFormat
-from rainbowneko.data import RatioBucket, FixedBucket
-from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
-from rainbowneko.utils import ConstantLR, Path_Like
-
 from hcpdiff.ckpt_manager import LoraWebuiFormat
 from hcpdiff.data import TextImagePairDataset, Text2ImageSource, StableDiffusionHandler
 from hcpdiff.data import VaeCache
 from hcpdiff.easy import SD15_auto_loader
 from hcpdiff.models import SD15Wrapper, TEHookCFG
 from hcpdiff.models.lora_layers_patch import LoraLayer
+from rainbowneko.ckpt_manager import ckpt_saver, NekoOptimizerSaver, LAYERS_TRAINABLE, NekoPluginSaver, SafeTensorFormat
+from rainbowneko.data import RatioBucket, FixedBucket
+from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
+from rainbowneko.utils import ConstantLR, Path_Like
 
 @neko_cfg
-def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5, clip_skip: int = 0,
+def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, save_optimizer=False, lr: float = 1e-5, clip_skip: int = 0,
                     dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
     if low_vram:
         from bitsandbytes.optim import AdamW8bit
@@ -20,6 +19,17 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
     else:
         optimizer = torch.optim.AdamW(_partial_=True)
 
+    ckpt_saver_dict = dict(
+        SD15=ckpt_saver(
+            ckpt_type='safetensors',
+            target_module='denoiser',
+            layers=LAYERS_TRAINABLE,
+        )
+    )
+
+    if save_optimizer:
+        ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
     from cfgs.train.py import train_base, tuning_base
 
     return dict(
@@ -34,11 +44,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
         ], weight_decay=1e-2),
 
         ckpt_saver=dict(
-            SD15=
-                ckpt_type='safetensors',
-                target_module='denoiser',
-                layers=LAYERS_TRAINABLE,
-            )
+            SD15=ckpt_saver_dict
         ),
 
         train=dict(
@@ -68,9 +74,9 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
     )
 
 @neko_cfg
-def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4,
-                    clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False,
-                    name: str = 'SD15', save_webui_format=False):
+def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, save_optimizer=False, lr: float = 1e-4, rank: int = 4,
+                    alpha: float = None, clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False,
+                    warmup_steps: int = 0, name: str = 'SD15', save_webui_format=False):
     with disable_neko_cfg:
         if alpha is None:
             alpha = rank
@@ -101,6 +107,17 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
     else:
         lora_format = SafeTensorFormat()
 
+    ckpt_saver_dict = dict(
+        _replace_=True,
+        lora_unet=NekoPluginSaver(
+            format=lora_format,
+            target_plugin='lora1',
+        )
+    )
+
+    if save_optimizer:
+        ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
     from cfgs.train.py.examples import SD_FT
 
     return dict(
@@ -118,13 +135,7 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
                 )
             ), weight_decay=0.1),
 
-        ckpt_saver=
-            _replace_ = True,
-            lora_unet=NekoPluginSaver(
-                format=lora_format,
-                target_plugin='lora1',
-            )
-        ),
+        ckpt_saver=ckpt_saver_dict,
 
         train=dict(
             train_steps=train_steps,
@@ -181,7 +192,7 @@ def cfg_data_SD_ARB(img_root: Path_Like, batch_size: int = 4, trigger_word: str
     )
 
 @neko_cfg
-def cfg_data_SD_resize_crop(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', target_size
+def cfg_data_SD_resize_crop(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', target_size=(512, 512), word_names=None,
                             prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
     if word_names is None:
         word_names = dict(pt1=trigger_word)
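Both SD15_finetuning and SD15_lora_train now build their ckpt_saver mapping up front and accept save_optimizer, which registers a NekoOptimizerSaver so optimizer state is checkpointed alongside the model or LoRA weights; SD15_lora_train also exposes alpha and warmup_steps directly. A hypothetical invocation of the LoRA config (paths, trigger word, and hyperparameters are placeholders, and the call is resolved by the @neko_cfg machinery exactly as in 2.2.1):

```python
from hcpdiff.easy.cfg.sd15_train import SD15_lora_train, cfg_data_SD_ARB

# Hypothetical config call showing the new save_optimizer flag; all values below
# are placeholders, not taken from the diff.
config = SD15_lora_train(
    base_model='ckpts/sd15',                 # placeholder model path
    train_steps=1000,
    dataset=cfg_data_SD_ARB(img_root='data/train_imgs', batch_size=4, trigger_word='sks'),
    save_step=200,
    save_optimizer=True,                     # new: also snapshot optimizer state via NekoOptimizerSaver
    lr=1e-4,
    rank=8,
    alpha=8,
    save_webui_format=False,
)
```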