hcpdiff 2.2.1__py3-none-any.whl → 2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. hcpdiff/ckpt_manager/__init__.py +1 -1
  2. hcpdiff/ckpt_manager/ckpt.py +21 -17
  3. hcpdiff/ckpt_manager/format/diffusers.py +4 -4
  4. hcpdiff/ckpt_manager/format/sd_single.py +3 -3
  5. hcpdiff/ckpt_manager/loader.py +11 -4
  6. hcpdiff/diffusion/noise/__init__.py +0 -1
  7. hcpdiff/diffusion/sampler/VP.py +27 -0
  8. hcpdiff/diffusion/sampler/__init__.py +2 -3
  9. hcpdiff/diffusion/sampler/base.py +106 -44
  10. hcpdiff/diffusion/sampler/diffusers.py +11 -17
  11. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -1
  12. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +77 -2
  13. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +193 -49
  14. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +110 -33
  15. hcpdiff/diffusion/sampler/sigma_scheduler/flow.py +74 -0
  16. hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py +22 -0
  17. hcpdiff/easy/cfg/sd15_train.py +33 -22
  18. hcpdiff/easy/cfg/sdxl_train.py +32 -23
  19. hcpdiff/evaluate/__init__.py +3 -1
  20. hcpdiff/evaluate/evaluator.py +76 -0
  21. hcpdiff/evaluate/metrics/__init__.py +1 -0
  22. hcpdiff/evaluate/metrics/clip_score.py +23 -0
  23. hcpdiff/evaluate/previewer.py +29 -12
  24. hcpdiff/loss/base.py +9 -26
  25. hcpdiff/loss/weighting.py +36 -18
  26. hcpdiff/models/lora_base_patch.py +26 -0
  27. hcpdiff/models/wrapper/sd.py +17 -19
  28. hcpdiff/trainer_ac.py +7 -5
  29. hcpdiff/trainer_ac_single.py +1 -6
  30. hcpdiff/utils/__init__.py +2 -1
  31. hcpdiff/utils/torch_utils.py +25 -0
  32. hcpdiff/workflow/__init__.py +1 -1
  33. hcpdiff/workflow/diffusion.py +27 -7
  34. hcpdiff/workflow/io.py +20 -3
  35. hcpdiff/workflow/text.py +6 -1
  36. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/METADATA +2 -2
  37. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/RECORD +41 -37
  38. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/WHEEL +1 -1
  39. hcpdiff/diffusion/noise/zero_terminal.py +0 -39
  40. hcpdiff/diffusion/sampler/ddpm.py +0 -20
  41. hcpdiff/diffusion/sampler/edm.py +0 -22
  42. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/entry_points.txt +0 -0
  43. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/licenses/LICENSE +0 -0
  44. {hcpdiff-2.2.1.dist-info → hcpdiff-2.3.dist-info}/top_level.txt +0 -0
hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py
@@ -1,17 +1,22 @@
- import torch
  import math
- from typing import Union, Tuple
- from hcpdiff.utils import linear_interp
+ from typing import Union, Tuple, Callable
+
+ import torch
+
+ from hcpdiff.utils import invert_func
  from .base import SigmaScheduler

  class DDPMDiscreteSigmaScheduler(SigmaScheduler):
- def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, num_timesteps=1000):
+ def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, num_timesteps=1000, pred_type='eps'):
  super().__init__()
  self.num_timesteps = num_timesteps
  self.betas = self.make_betas(beta_schedule, linear_start, linear_end, num_timesteps)
  alphas = 1.0-self.betas
  self.alphas_cumprod = torch.cumprod(alphas, dim=0)
- self.sigmas = ((1-self.alphas_cumprod)/self.alphas_cumprod).sqrt()
+
+ self.alphas = self.alphas_cumprod.sqrt()
+ self.sigmas = (1-self.alphas_cumprod).sqrt()
+ self.pred_type = pred_type

  # for VLB calculation
  self.alphas_cumprod_prev = torch.cat([alphas.new_tensor([1.0]), self.alphas_cumprod[:-1]])
@@ -22,37 +27,73 @@ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
  # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
  self.posterior_log_variance_clipped = torch.log(torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]]))

+ # def scale_t(self, t):
+ # return t*(self.num_timesteps-1)

  @property
- def sigma_min(self):
+ def sigma_start(self):
  return self.sigmas[0]

  @property
- def sigma_max(self):
+ def sigma_end(self):
  return self.sigmas[-1]

- def get_sigma(self, t: Union[float, torch.Tensor]):
+ @property
+ def alpha_start(self):
+ return self.alphas[0]
+
+ @property
+ def alpha_end(self):
+ return self.alphas[-1]
+
+ def sigma(self, t: Union[float, torch.Tensor]):
  if isinstance(t, float):
  t = torch.tensor(t)
- return self.sigmas[(t*len(self.sigmas)).long()]
+ self.sigmas = self.sigmas.to(t.device)
+ return self.sigmas[((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)]

- def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
- if isinstance(min_rate, float):
- min_rate = torch.full(shape, min_rate)
- if isinstance(max_rate, float):
- max_rate = torch.full(shape, max_rate)
+ def alpha(self, t: Union[float, torch.Tensor]):
+ if isinstance(t, float):
+ t = torch.tensor(t)
+ self.alphas = self.alphas.to(t.device)
+ return self.alphas[((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)]
+
+ def c_noise(self, t: Union[float, torch.Tensor]):
+ return (t*self.num_timesteps).round()

- t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
- t_scale = (t*(self.num_timesteps-1e-5)).long() # [0, num_timesteps-1)
- return self.sigmas[t_scale], t
+ def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=True) -> Tuple[torch.Tensor, torch.Tensor]:
+ '''
+ v(t) = dx(t)/dt = d\alpha(t)/dt * x(0) + d\sigma(t)/dt *eps
+ :param t: 0-1, rate of time step
+ :return: d\alpha(t)/dt, d\sigma(t)/dt
+ '''
+ d_alpha = -self.sigma(t)
+ d_sigma = self.alpha(t)
+ if normlize:
+ norm = torch.sqrt(d_alpha**2+d_sigma**2)
+ return d_alpha/norm, d_sigma/norm
+ else:
+ return d_alpha, d_sigma

  def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
- t = (self.sigmas-sigma).abs().argmin()
- return t/self.num_timesteps
+ ref_t = np.linspace(0, 1, len(self.sigmas))
+ t = torch.tensor(np.interp(sigma.cpu().clip(min=1e-8).log().numpy(), self.sigmas, ref_t))
+ return t
+
+ def alpha_to_t(self, alpha: Union[float, torch.Tensor]):
+ ref_t = np.linspace(0, 1, len(self.alphas))
+ t = torch.tensor(np.interp(alpha.cpu().clip(min=1e-8).log().numpy(), self.alphas, ref_t))
+ return t
+
+ def alpha_to_sigma(self, alpha):
+ return torch.sqrt(1 - alpha**2)
+
+ def sigma_to_alpha(self, sigma):
+ return torch.sqrt(1 - sigma**2)

  def get_post_mean(self, t, x_0, x_t):
  t = (t*len(self.sigmas)).long()
- return self.posterior_mean_coef1[t].view(-1, 1, 1, 1).to(t.device)*x_0 + self.posterior_mean_coef2[t].view(-1, 1, 1, 1).to(t.device)*x_t
+ return self.posterior_mean_coef1[t].view(-1, 1, 1, 1).to(t.device)*x_0+self.posterior_mean_coef2[t].view(-1, 1, 1, 1).to(t.device)*x_t

  def get_post_log_var(self, t, x_t_var=None):
  t = (t*len(self.sigmas)).long()
@@ -66,7 +107,6 @@ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
  model_log_variance = frac*max_log+(1-frac)*min_log
  return model_log_variance

-
  @staticmethod
  def betas_for_alpha_bar(
  num_diffusion_timesteps,
@@ -130,50 +170,154 @@ class DDPMDiscreteSigmaScheduler(SigmaScheduler):
  else:
  raise NotImplementedError(f"{beta_schedule} does is not implemented.")

- class DDPMContinuousSigmaScheduler(DDPMDiscreteSigmaScheduler):
+ class DDPMContinuousSigmaScheduler(SigmaScheduler):
+ def __init__(self, beta_schedule: str = "scaled_linear", linear_start=0.00085, linear_end=0.0120, t_base=1000):
+ self.alpha_bar_fn = self.make_alpha_bar_fn(beta_schedule, linear_start, linear_end)
+ self.t_base = t_base # base time step for continuous product
+
+ def continuous_product(self, alpha_fn: Callable[[torch.Tensor], torch.Tensor], t: torch.Tensor):
+ '''
+
+ :param alpha_fn: alpha function
+ :param t: timesteps with shape [B]
+ :return: [B]
+ '''
+ bins = torch.linspace(0, 1, self.t_base, dtype=torch.float32).unsqueeze(0)
+ t_grid = bins*t.float().unsqueeze(1) # [B, num_bins]
+ alpha_vals = alpha_fn(t_grid)
+
+ if torch.any(alpha_vals<=0):
+ raise ValueError("alpha(t) must > 0 to avoid log(≤0).")
+
+ log_term = torch.log(alpha_vals) # [B, num_bins]
+ dt = t_grid[:, 1]-t_grid[:, 0] # [B]
+ integral = torch.cumsum((log_term[:, -1]+log_term[:, 1:])/2*dt.unsqueeze(1), dim=1) # [B]
+ x_vals = torch.exp(integral)
+ return x_vals
+
+ @staticmethod
+ def alpha_bar_linear(beta_s, beta_e, t, N=1000):
+ A = beta_e-beta_s
+ B = 1-beta_s
+ B_At = B-A*t
+
+ # avoid numerical instability
+ eps = 1e-12
+ B = torch.clamp(B, min=eps)
+ B_At = torch.clamp(B_At, min=eps)
+
+ term = (B*torch.log(B)-B_At*torch.log(B_At)-A*t)
+ return torch.exp(N*term/A)
+
+ @staticmethod
+ def alpha_bar_scaled_linear(beta_s, beta_e, t, N=1000):
+ sqrt_bs = torch.sqrt(beta_s)
+ sqrt_be = torch.sqrt(beta_e)
+ a = sqrt_be-sqrt_bs
+ b = sqrt_bs
+ u0 = b
+ u1 = a*t+b
+
+ eps = 1e-12
+
+ def safe_log1m(u2):
+ return torch.log(torch.clamp(1-u2, min=eps))
+
+ def safe_log_frac(u):
+ return torch.log(torch.clamp(1+u, min=eps))-torch.log(torch.clamp(1-u, min=eps))

- def get_sigma(self, t: Union[float, torch.Tensor]):
+ term1 = u1*safe_log1m(u1**2)
+ term2 = 0.5*safe_log_frac(u1)
+ term3 = u0*safe_log1m(u0**2)
+ term4 = 0.5*safe_log_frac(u0)
+
+ return torch.exp(N*(term1+term2-term3-term4)/a)
+
+ def make_alpha_bar_fn(self, beta_schedule, beta_start, beta_end, alpha_fn=None):
+ if alpha_fn is not None:
+ return lambda t, alpha_fn_=alpha_fn:self.continuous_product(alpha_fn_(t), t)
+ elif beta_schedule == "linear":
+ return lambda t:self.alpha_bar_linear(beta_start, beta_end, t)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ return lambda t:self.alpha_bar_scaled_linear(beta_start, beta_end, t)
+ elif beta_schedule == "squaredcos_cap_v2":
+ return lambda t:torch.cos((t+0.008)/1.008*math.pi/2)**2
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ alpha_fn = lambda t:1-torch.sigmoid(torch.lerp(torch.full_like(t, -6), torch.full_like(t, 6), t))*(beta_end-beta_start)+beta_start
+ return lambda t, alpha_fn_=alpha_fn:self.continuous_product(alpha_fn_(t), t)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented.")
+
+ def sigma(self, t: Union[float, torch.Tensor]):
  if isinstance(t, float):
- t = torch.tensor(t)
- return linear_interp(self.sigmas, t)
+ t = torch.tensor([t])
+ alpha_cumprod = self.alpha_bar_fn(t)
+ return torch.sqrt(1-alpha_cumprod)

- def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
- if isinstance(min_rate, float):
- min_rate = torch.full(shape, min_rate)
- if isinstance(max_rate, float):
- max_rate = torch.full(shape, max_rate)
+ def alpha(self, t: Union[float, torch.Tensor]):
+ if isinstance(t, float):
+ t = torch.tensor([t])
+ alpha_cumprod = self.alpha_bar_fn(t)
+ return torch.sqrt(alpha_cumprod)

- t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
- t_scale = (t*(self.num_timesteps-1-1e-5)) # [0, num_timesteps-1)
+ def c_noise(self, t: Union[float, torch.Tensor]):
+ return t*self.t_base

- return linear_interp(self.sigmas, t_scale), t
+ @property
+ def sigma_start(self):
+ return self.sigma(0)

- def sigma_to_t(self, sigma: Union[float, torch.Tensor]):
- diff = self.sigmas-sigma
- diff[diff<0] = float('inf')
- t0 = diff.argmin().clamp(0, self.num_timesteps-2)
- return t0 + diff.min()/(self.sigmas[t0+1]-self.sigmas[t0])
+ @property
+ def sigma_end(self):
+ return self.sigma(1)
+
+ @property
+ def alpha_start(self):
+ return self.alpha(0)
+
+ @property
+ def alpha_end(self):
+ return self.alpha(1)
+
+ def alpha_to_t(self, alpha, t_min=0.0, t_max=1.0, tol=1e-5, max_iter=100):
+ """
+ alpha: [B]
+ :return: t [B]
+ """
+ return invert_func(self.alpha, alpha, t_min, t_max, tol, max_iter)
+
+ def sigma_to_t(self, sigma, t_min=0.0, t_max=1.0, tol=1e-5, max_iter=100):
+ """
+ sigma: [B]
+ :return: t [B]
+ """
+ return invert_func(self.sigma, sigma, t_min, t_max, tol, max_iter)

  class TimeSigmaScheduler(SigmaScheduler):
  def __init__(self, num_timesteps=1000):
  super().__init__()
  self.num_timesteps = num_timesteps

- def get_sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
  '''
  :param t: 0-1, rate of time step
  '''
- return t
-
- def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)) -> Tuple[torch.Tensor, torch.Tensor]:
- if isinstance(min_rate, float):
- min_rate = torch.full(shape, min_rate)
- if isinstance(max_rate, float):
- max_rate = torch.full(shape, max_rate)
+ if isinstance(t, float):
+ t = torch.tensor(t)
+ return ((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)

- t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
- t_scale = (t*(self.num_timesteps-1e-5)).long() # [0, num_timesteps-1)
- return t_scale, t
+ def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ '''
+ :param t: 0-1, rate of time step
+ '''
+ if isinstance(t, float):
+ t = torch.tensor(t)
+ return ((t*self.num_timesteps).round().long()).clip(min=0, max=self.num_timesteps-1)
+
+ def c_noise(self, t: Union[float, torch.Tensor]):
+ return (t*self.num_timesteps).round()

  if __name__ == '__main__':
  from matplotlib import pyplot as plt
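
The reworked DDPM scheduler replaces get_sigma/sample_sigma with separate alpha(t), sigma(t) and c_noise(t) accessors over a continuous t in [0, 1]. Below is a minimal sketch of how that interface could be exercised; the import path and tensor shapes are assumptions for illustration, not taken from the diff.

# Minimal sketch (not from the diff): exercising the reworked sigma-scheduler API,
# assuming DDPMDiscreteSigmaScheduler is importable from the sigma_scheduler package.
import torch
from hcpdiff.diffusion.sampler.sigma_scheduler import DDPMDiscreteSigmaScheduler

sched = DDPMDiscreteSigmaScheduler()           # scaled_linear betas, 1000 steps
t = torch.rand(4)                              # continuous time in [0, 1]

alpha, sigma = sched.alpha(t), sched.sigma(t)  # per-sample signal scale and noise level
x0 = torch.randn(4, 4, 64, 64)                 # hypothetical clean latents
eps = torch.randn_like(x0)
x_t = alpha.view(-1, 1, 1, 1)*x0 + sigma.view(-1, 1, 1, 1)*eps  # forward diffusion sample

c_noise = sched.c_noise(t)                     # discrete timestep index fed to the denoiser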
hcpdiff/diffusion/sampler/sigma_scheduler/edm.py
@@ -1,19 +1,18 @@
- from typing import Union
+ from typing import Union, Tuple

- import torch
  import numpy as np
+ import torch

  from .base import SigmaScheduler

  class EDMSigmaScheduler(SigmaScheduler):
- def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
- self.sigma_min = torch.tensor(sigma_min)
- self.sigma_max = torch.tensor(sigma_max)
+ def __init__(self, sigma_min=0.002, sigma_max=80.0, sigma_data=0.5, rho=7.0):
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.sigma_data = sigma_data
  self.rho = rho

- self.num_timesteps=num_timesteps
-
- def get_sigma(self, t: Union[float, torch.Tensor]):
+ def sigma_edm(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
  if isinstance(t, float):
  t = torch.tensor(t)

@@ -21,28 +20,106 @@ class EDMSigmaScheduler(SigmaScheduler):
  max_inv_rho = self.sigma_max**(1/self.rho)
  return torch.lerp(min_inv_rho, max_inv_rho, t)**self.rho

- def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
- if isinstance(min_rate, float):
- min_rate = torch.full(shape, min_rate)
- if isinstance(max_rate, float):
- max_rate = torch.full(shape, max_rate)
-
- t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
- return self.get_sigma(t), t
-
- class EDMRefSigmaScheduler(EDMSigmaScheduler):
- def __init__(self, ref_scheduler, sigma_min=0.002, sigma_max=80.0, rho=7.0, num_timesteps=1000):
- super().__init__(sigma_min, sigma_max, rho, num_timesteps=num_timesteps)
- self.ref_sigmas = ref_scheduler.sigmas.cpu().clip(min=1e-8).log().numpy()
- self.ref_t = np.linspace(0, 1, len(self.ref_sigmas))
-
- def sample_sigma(self, min_rate=0.0, max_rate=1.0, shape=(1,)):
- if isinstance(min_rate, float):
- min_rate = torch.full(shape, min_rate)
- if isinstance(max_rate, float):
- max_rate = torch.full(shape, max_rate)
-
- t = torch.lerp(min_rate, max_rate, torch.rand_like(min_rate))
- sigma = self.get_sigma(t)
- t_rect = torch.tensor(np.interp(sigma.cpu().clip(min=1e-8).log().numpy(), self.ref_sigmas, self.ref_t))
- return sigma, t_rect
+ def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ '''
+ x_t = c_in(t) * (x(0) + \sigma(t)*eps), eps~N(0,I)
+ '''
+ if isinstance(t, float):
+ t = torch.tensor(t)
+
+ sigma_edm = self.sigma_edm(t)
+ return sigma_edm/torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+ def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ '''
+ x_t = c_in(t) * (x(0) + \sigma(t)*eps), eps~N(0,I)
+ '''
+ if isinstance(t, float):
+ t = torch.tensor(t)
+
+ sigma_edm = self.sigma_edm(t)
+ return 1./torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+ def c_skip(self, t: Union[float, torch.Tensor]):
+ '''
+ \hat{x}(0) = c_skip(t)*(x(t)/c_in(t)) + c_out(t)*f(x(t))
+ :param t: 0-1, rate of time step
+ '''
+ sigma_edm = self.sigma_edm(t)
+ return self.sigma_data**2/torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+ def c_out(self, t: Union[float, torch.Tensor]):
+ '''
+ \hat{x}(0) = c_skip(t)*(x(t)/c_in(t)) + c_out(t)*f(x(t))
+ :param t: 0-1, rate of time step
+ '''
+ sigma_edm = self.sigma_edm(t)
+ return (self.sigma_data*sigma_edm)/torch.sqrt(sigma_edm**2+self.sigma_data**2)
+
+ def c_noise(self, t: Union[float, torch.Tensor]):
+ sigma_edm = self.sigma_edm(t)
+ return sigma_edm.log()/4
+
+ @property
+ def sigma_start(self):
+ return self.sigma(0)
+
+ @property
+ def sigma_end(self):
+ return self.sigma(1)
+
+ @property
+ def alpha_start(self):
+ return self.alpha(0)
+
+ @property
+ def alpha_end(self):
+ return self.alpha(1)
+
+ def alpha_to_sigma(self, alpha):
+ return torch.sqrt(1 - (alpha*self.sigma_data)**2)
+
+ def sigma_to_alpha(self, sigma):
+ return torch.sqrt(1 - sigma**2)/self.sigma_data
+
+ class EDMTimeRescaleScheduler(EDMSigmaScheduler):
+ def __init__(self, ref_scheduler: SigmaScheduler, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+ super().__init__(sigma_min, sigma_max, rho)
+ self.ref_scheduler = ref_scheduler
+
+ def scale_t(self, t):
+ ref_t = torch.linspace(0, 1, 1000)
+ alphas = self.alpha(ref_t)
+ sigmas = self.sigma(ref_t)
+ sigmas_edm = sigmas/alphas
+ sigma_edm = self.sigma_edm(t)
+ t = np.interp(sigma_edm.cpu().clip(min=1e-8).log().numpy(), sigmas_edm, ref_t.numpy())
+ return torch.tensor(t)
+
+ def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ return self.ref_scheduler.sigma(t)
+
+ def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ return self.ref_scheduler.alpha(t)
+
+ def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=True) -> Tuple[torch.Tensor, torch.Tensor]:
+ return self.ref_scheduler.velocity(t, dt=dt, normlize=normlize)
+
+ def c_skip(self, t: Union[float, torch.Tensor]):
+ return self.ref_scheduler.c_skip(t)
+
+ def c_out(self, t: Union[float, torch.Tensor]):
+ return self.ref_scheduler.c_out(t)
+
+ def c_noise(self, t: Union[float, torch.Tensor]):
+ return self.ref_scheduler.c_noise(t)
+
+ def sample(self, min_t=0.0, max_t=1.0, shape=(1,)):
+ if isinstance(min_t, float):
+ min_t = torch.full(shape, min_t)
+ if isinstance(max_t, float):
+ max_t = torch.full(shape, max_t)
+
+ t = torch.lerp(min_t, max_t, torch.rand_like(min_t))
+ t = self.scale_t(t)
+ return t
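
The rewritten EDMSigmaScheduler folds the Karras sigma schedule and the sigma_data preconditioning into alpha(t) and sigma(t). The following self-contained sketch (illustrative values only, no hcpdiff imports) shows the relations it encodes: alpha plays the role of c_in, sigma is c_in*sigma_edm, and the scaled pair satisfies (sigma_data*alpha)^2 + sigma^2 = 1.

# Self-contained sketch of the relations the rewritten EDMSigmaScheduler encodes;
# the values and variable names here are illustrative, not taken from the package.
import torch

sigma_min, sigma_max, sigma_data, rho = 0.002, 80.0, 0.5, 7.0
t = torch.linspace(0, 1, 8)

# Karras-style noise schedule, interpolated in sigma**(1/rho) space
sigma_edm = torch.lerp(torch.tensor(sigma_min)**(1/rho), torch.tensor(sigma_max)**(1/rho), t)**rho

# c_in-style scaling: alpha(t) = 1/sqrt(sigma_edm^2 + sigma_data^2), sigma(t) = alpha(t)*sigma_edm
alpha = 1.0/torch.sqrt(sigma_edm**2 + sigma_data**2)
sigma = sigma_edm*alpha

# the scaled pair stays on the unit circle: (sigma_data*alpha)^2 + sigma^2 == 1
assert torch.allclose((sigma_data*alpha)**2 + sigma**2, torch.ones_like(t), atol=1e-6)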
hcpdiff/diffusion/sampler/sigma_scheduler/flow.py
@@ -0,0 +1,74 @@
+ from typing import Union, Tuple
+
+ import torch
+
+ from .base import SigmaScheduler
+
+ class FlowSigmaScheduler(SigmaScheduler):
+ def __init__(self, t_start=0, t_end=1):
+ super().__init__()
+ self.t_start = t_start
+ self.t_end = t_end
+
+ def sigma(self, t: Union[float, torch.Tensor]):
+ if isinstance(t, float):
+ t = torch.tensor([t])
+ t = (self.t_end-self.t_start)*t+self.t_start
+ return t
+
+ def alpha(self, t: Union[float, torch.Tensor]):
+ if isinstance(t, float):
+ t = torch.tensor([t])
+ t = (self.t_end-self.t_start)*t+self.t_start
+ return 1-t
+
+ def velocity(self, t: Union[float, torch.Tensor], dt=1e-8, normlize=False) -> Tuple[torch.Tensor, torch.Tensor]:
+ '''
+ v(t) = dx(t)/dt = d\alpha(t)/dt * x(0) + d\sigma(t)/dt *eps
+ :param t: 0-1, rate of time step
+ :return: d\alpha(t)/dt, d\sigma(t)/dt
+ '''
+ if isinstance(t, float):
+ t = torch.tensor([t])
+ d_alpha = -torch.ones_like(t)
+ d_sigma = torch.ones_like(t)
+ if normlize:
+ norm = torch.sqrt(d_alpha**2+d_sigma**2)
+ return d_alpha/norm, d_sigma/norm
+ else:
+ return d_alpha, d_sigma
+
+ def alpha_to_t(self, alphas):
+ """
+ alphas: [B]
+ :return: t [B]
+ """
+ return alphas
+
+ def sigma_to_t(self, sigmas):
+ """
+ sigmas: [B]
+ :return: t [B]
+ """
+ return 1-sigmas
+
+ def alpha_to_sigma(self, alpha):
+ return 1-alpha
+
+ def sigma_to_alpha(self, sigma):
+ return 1-sigma
+
+ def c_skip(self, t: Union[float, torch.Tensor]):
+ '''
+ \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
+ :param t: 0-1, rate of time step
+ '''
+ return 1.
+
+ def c_out(self, t: Union[float, torch.Tensor]):
+ '''
+ \hat{x}(0) = c_skip*x(t) + c_out*f(x(t))
+ :param t: 0-1, rate of time step
+ '''
+ sigma = self.sigma(t)
+ return -sigma
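
The new FlowSigmaScheduler describes the linear interpolation alpha(t) = 1-t, sigma(t) = t with a constant velocity, i.e. a rectified-flow style forward process. A self-contained sketch of that interpolation and the matching reconstruction rule (names and shapes are illustrative, not from the package) follows.

# Self-contained sketch of the linear interpolation FlowSigmaScheduler encodes.
import torch

x0 = torch.randn(2, 4, 8, 8)            # hypothetical clean latents
eps = torch.randn_like(x0)              # Gaussian noise
t = torch.rand(2).view(-1, 1, 1, 1)     # time in [0, 1]

x_t = (1 - t)*x0 + t*eps                # alpha(t)*x0 + sigma(t)*eps
v_target = eps - x0                     # d_alpha/dt*x0 + d_sigma/dt*eps with d_alpha=-1, d_sigma=1
x0_rec = x_t - t*v_target               # c_skip=1, c_out=-sigma(t): x0 = x_t - sigma(t)*v
assert torch.allclose(x0_rec, x0, atol=1e-6)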
hcpdiff/diffusion/sampler/sigma_scheduler/zero_terminal.py
@@ -0,0 +1,22 @@
+ from typing import Union
+
+ import torch
+ from .base import SigmaScheduler
+
+ class ZeroTerminalScheduler(SigmaScheduler):
+ def __init__(self, ref_scheduler: SigmaScheduler, eps=1e-4):
+ self.ref_scheduler = ref_scheduler
+ self.eps = eps
+
+ def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ alpha_0 = self.ref_scheduler.alpha_start
+ alpha_T = self.ref_scheduler.alpha_end
+ alpha = self.ref_scheduler.alpha(t)
+ return (alpha - alpha_T)*(alpha_0-self.eps)/(alpha_0 - alpha_T) + self.eps
+
+ def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
+ try:
+ alpha = self.alpha(t)
+ return self.ref_scheduler.alpha_to_sigma(alpha)
+ except NotImplementedError:
+ raise NotImplementedError(f'{type(self.ref_scheduler)} cannot be a "ZeroTerminalScheduler"!')
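
ZeroTerminalScheduler affinely rescales a reference scheduler's alpha(t) so that the terminal alpha collapses to eps, i.e. (near) zero terminal SNR, while the initial alpha is preserved. A self-contained sketch of that rescaling on a toy alpha curve (illustrative values only):

# Self-contained sketch of the affine rescaling applied by ZeroTerminalScheduler.
import torch

ref_alpha = torch.linspace(1.0, 0.2, 1000)   # toy reference alpha(t); the real one comes from ref_scheduler
eps = 1e-4
alpha_0, alpha_T = ref_alpha[0], ref_alpha[-1]

rescaled = (ref_alpha - alpha_T)*(alpha_0 - eps)/(alpha_0 - alpha_T) + eps
print(rescaled[0].item(), rescaled[-1].item())  # start stays at alpha_0, end becomes exactly eps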
hcpdiff/easy/cfg/sd15_train.py
@@ -1,18 +1,17 @@
  import torch
- from rainbowneko.ckpt_manager import ckpt_saver, LAYERS_TRAINABLE, NekoPluginSaver, SafeTensorFormat
- from rainbowneko.data import RatioBucket, FixedBucket
- from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
- from rainbowneko.utils import ConstantLR, Path_Like
-
  from hcpdiff.ckpt_manager import LoraWebuiFormat
  from hcpdiff.data import TextImagePairDataset, Text2ImageSource, StableDiffusionHandler
  from hcpdiff.data import VaeCache
  from hcpdiff.easy import SD15_auto_loader
  from hcpdiff.models import SD15Wrapper, TEHookCFG
  from hcpdiff.models.lora_layers_patch import LoraLayer
+ from rainbowneko.ckpt_manager import ckpt_saver, NekoOptimizerSaver, LAYERS_TRAINABLE, NekoPluginSaver, SafeTensorFormat
+ from rainbowneko.data import RatioBucket, FixedBucket
+ from rainbowneko.parser import CfgWDPluginParser, neko_cfg, CfgWDModelParser, disable_neko_cfg
+ from rainbowneko.utils import ConstantLR, Path_Like

  @neko_cfg
- def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, lr: float = 1e-5, clip_skip: int = 0,
+ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int = 500, save_optimizer=False, lr: float = 1e-5, clip_skip: int = 0,
  dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0, name: str = 'SD15'):
  if low_vram:
  from bitsandbytes.optim import AdamW8bit
@@ -20,6 +19,17 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
  else:
  optimizer = torch.optim.AdamW(_partial_=True)

+ ckpt_saver_dict = dict(
+ SD15=ckpt_saver(
+ ckpt_type='safetensors',
+ target_module='denoiser',
+ layers=LAYERS_TRAINABLE,
+ )
+ )
+
+ if save_optimizer:
+ ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
  from cfgs.train.py import train_base, tuning_base

  return dict(
@@ -34,11 +44,7 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
  ], weight_decay=1e-2),

  ckpt_saver=dict(
- SD15=ckpt_saver(
- ckpt_type='safetensors',
- target_module='denoiser',
- layers=LAYERS_TRAINABLE,
- )
+ SD15=ckpt_saver_dict
  ),

  train=dict(
@@ -68,9 +74,9 @@ def SD15_finetuning(base_model: str, train_steps: int, dataset, save_step: int =
  )

  @neko_cfg
- def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, lr: float = 1e-4, rank: int = 4, alpha: float = None,
- clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False, warmup_steps: int = 0,
- name: str = 'SD15', save_webui_format=False):
+ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int = 200, save_optimizer=False, lr: float = 1e-4, rank: int = 4,
+ alpha: float = None, clip_skip: int = 0, with_conv: bool = False, dtype: str = 'fp16', low_vram: bool = False,
+ warmup_steps: int = 0, name: str = 'SD15', save_webui_format=False):
  with disable_neko_cfg:
  if alpha is None:
  alpha = rank
@@ -101,6 +107,17 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
  else:
  lora_format = SafeTensorFormat()

+ ckpt_saver_dict = dict(
+ _replace_=True,
+ lora_unet=NekoPluginSaver(
+ format=lora_format,
+ target_plugin='lora1',
+ )
+ )
+
+ if save_optimizer:
+ ckpt_saver_dict['optimizer'] = NekoOptimizerSaver()
+
  from cfgs.train.py.examples import SD_FT

  return dict(
@@ -118,13 +135,7 @@ def SD15_lora_train(base_model: str, train_steps: int, dataset, save_step: int =
  )
  ), weight_decay=0.1),

- ckpt_saver=dict(
- _replace_ = True,
- lora_unet=NekoPluginSaver(
- format=lora_format,
- target_plugin='lora1',
- )
- ),
+ ckpt_saver=ckpt_saver_dict,

  train=dict(
  train_steps=train_steps,
@@ -181,7 +192,7 @@ def cfg_data_SD_ARB(img_root: Path_Like, batch_size: int = 4, trigger_word: str
  )

  @neko_cfg
- def cfg_data_SD_resize_crop(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', target_size = (512, 512), word_names=None,
+ def cfg_data_SD_resize_crop(img_root: Path_Like, batch_size: int = 4, trigger_word: str = '', target_size=(512, 512), word_names=None,
  prompt_dropout: float = 0, prompt_template: Path_Like = 'prompt_template/caption.txt', loss_weight=1.0):
  if word_names is None:
  word_names = dict(pt1=trigger_word)
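
Both SD15 config helpers gain a save_optimizer switch that, when enabled, adds a NekoOptimizerSaver entry to the ckpt_saver dict. A minimal usage sketch is below; note that these @neko_cfg helpers are normally referenced from a training config rather than called directly, and every value here is a placeholder chosen for illustration.

# Minimal usage sketch (assumption: the helpers are invoked as in 2.2.x configs),
# showing the new save_optimizer switch added in 2.3.
from hcpdiff.easy.cfg.sd15_train import SD15_lora_train, cfg_data_SD_resize_crop

dataset = cfg_data_SD_resize_crop(img_root='imgs/', batch_size=4, trigger_word='sks')
cfg = SD15_lora_train(
    base_model='runwayml/stable-diffusion-v1-5',  # placeholder model id
    train_steps=1000,
    dataset=dataset,
    save_step=200,
    save_optimizer=True,   # new in 2.3: also checkpoint the optimizer state
    rank=8,
)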