flaxdiff 0.1.36.5__tar.gz → 0.1.37.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/PKG-INFO +1 -1
  2. flaxdiff-0.1.37.1/flaxdiff/samplers/__init__.py +7 -0
  3. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/common.py +20 -13
  4. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/ddim.py +1 -1
  5. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/ddpm.py +2 -2
  6. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/euler.py +3 -3
  7. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/heun_sampler.py +2 -2
  8. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/multistep_dpm.py +1 -1
  9. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/rk4_sampler.py +7 -7
  10. flaxdiff-0.1.37.1/flaxdiff/schedulers/__init__.py +6 -0
  11. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/diffusion_trainer.py +16 -14
  12. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/PKG-INFO +1 -1
  13. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/pyproject.toml +1 -1
  14. flaxdiff-0.1.36.5/flaxdiff/samplers/__init__.py +0 -7
  15. flaxdiff-0.1.36.5/flaxdiff/schedulers/__init__.py +0 -6
  16. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/README.md +0 -0
  17. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/__init__.py +0 -0
  18. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/__init__.py +0 -0
  19. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/dataset_map.py +0 -0
  20. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/datasets.py +0 -0
  21. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/online_loader.py +0 -0
  22. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/sources/gcs.py +0 -0
  23. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/data/sources/tfds.py +0 -0
  24. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/metrics/inception.py +0 -0
  25. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/metrics/utils.py +0 -0
  26. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/__init__.py +0 -0
  27. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/attention.py +0 -0
  28. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/__init__.py +0 -0
  29. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  30. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  31. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  32. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/common.py +0 -0
  33. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/favor_fastattn.py +0 -0
  34. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/simple_unet.py +0 -0
  35. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/models/simple_vit.py +0 -0
  36. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/predictors/__init__.py +0 -0
  37. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/common.py +0 -0
  38. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/continuous.py +0 -0
  39. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/cosine.py +0 -0
  40. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/discrete.py +0 -0
  41. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/exp.py +0 -0
  42. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/karras.py +0 -0
  43. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/linear.py +0 -0
  44. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/schedulers/sqrt.py +0 -0
  45. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/__init__.py +0 -0
  46. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  47. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/simple_trainer.py +0 -0
  48. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  49. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/utils.py +0 -0
  50. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/SOURCES.txt +0 -0
  51. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
  52. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/requires.txt +0 -0
  53. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/top_level.txt +0 -0
  54. {flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/setup.cfg +0 -0
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36.5
+Version: 0.1.37.1
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
flaxdiff-0.1.37.1/flaxdiff/samplers/__init__.py
@@ -0,0 +1,7 @@
+from .common import *
+from .ddim import *
+from .ddpm import *
+from .euler import *
+from .heun_sampler import *
+from .rk4_sampler import *
+from .multistep_dpm import *
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/common.py
@@ -37,7 +37,7 @@ class DiffusionSampler():
             # Classifier free guidance
             assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
             print("Using classifier-free guidance")
-            def sample_model(x_t, t, *additional_inputs):
+            def sample_model(params, x_t, t, *additional_inputs):
                 # Concatenate unconditional and conditional inputs
                 x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
                 t_cat = jnp.concatenate([t] * 2, axis=0)
@@ -46,7 +46,7 @@ class DiffusionSampler():

                 text_labels_seq, = additional_inputs
                 text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
-                model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
+                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
                 # Split model output into unconditional and conditional parts
                 model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
                 model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
@@ -55,10 +55,10 @@ class DiffusionSampler():
                 return x_0, eps, model_output
         else:
             # Unconditional sampling
-            def sample_model(x_t, t, *additional_inputs):
+            def sample_model(params, x_t, t, *additional_inputs):
                 rates = self.noise_schedule.get_rates(t)
                 c_in = self.model_output_transform.get_input_scale(rates)
-                model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
+                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
                 return x_0, eps, model_output

@@ -70,22 +70,23 @@ class DiffusionSampler():
         self.sample_model = sample_model

     # Used to sample from the diffusion model
-    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+    def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         # First clip the noisy images
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
-        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
+        pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
         # plotImages(pred_images)
         # pred_images = clip_images(pred_images)
         new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
-                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
-                                                 model_conditioning_inputs=model_conditioning_inputs
+                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+                                                 model_conditioning_inputs=model_conditioning_inputs,
+                                                 sample_model_fn=sample_model_fn,
                                                  )
         return new_samples, state

     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1,) -> tuple[jnp.ndarray, RandomMarkovState]:
         # estimate the q(x_{t-1} | x_t, x_0).
         # pred_images is x_0, noisy_images is x_t, steps is t
         return NotImplementedError
@@ -114,6 +115,7 @@ class DiffusionSampler():
         return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance

     def generate_images(self,
+                        params:dict=None,
                         num_images=16,
                         diffusion_steps=1000,
                         start_step:int = None,
@@ -131,10 +133,15 @@ class DiffusionSampler():
             if self.autoencoder is not None:
                 priors = self.autoencoder.encode(priors)
             samples = priors
+
+        params = params if params is not None else self.params
+
+        def sample_model_fn(x_t, t, *additional_inputs):
+            return self.sample_model(params, x_t, t, *additional_inputs)

         # @jax.jit
-        def sample_step(state:RandomMarkovState, samples, current_step, next_step):
-            samples, state = self.sample_step(current_samples=samples,
+        def sample_step(sample_model_fn, state:RandomMarkovState, samples, current_step, next_step):
+            samples, state = self.sample_step(sample_model_fn=sample_model_fn, current_samples=samples,
                                               current_step=current_step,
                                               model_conditioning_inputs=model_conditioning_inputs,
                                               state=state, next_step=next_step)
@@ -154,11 +161,11 @@ class DiffusionSampler():
             next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
             if i != len(steps) - 1:
                 # print("normal step")
-                samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
+                samples, rngstate = sample_step(sample_model_fn, rngstate, samples, current_step, next_step)
             else:
                 # print("last step")
                 step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
-                samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs)
+                samples, _, _ = sample_model_fn(samples, current_step * step_ones, *model_conditioning_inputs)
         if self.autoencoder is not None:
             samples = self.autoencoder.decode(samples)
         samples = clip_images(samples)
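
Note on the common.py change above: the closed-over `sample_model(x_t, t, ...)` becomes `sample_model(params, x_t, t, ...)`, `sample_step` and `take_next_step` receive an explicit `sample_model_fn`, and `generate_images` gains a `params` argument that falls back to `self.params`. A minimal usage sketch of the new call path follows; `unet`, `ema_params`, `schedule`, and `transform` are stand-ins for objects built elsewhere, the keyword names come from these hunks and the trainer hunks further down, and the concrete values are illustrative only:

    # Sketch only: `unet`, `ema_params`, `schedule`, and `transform` are assumed
    # to exist (model, EMA parameter pytree, noise schedule, output transform).
    from flaxdiff.samplers import DDIMSampler

    sampler = DDIMSampler(
        model=unet,
        params=None,                 # params can now be deferred to generation time
        noise_schedule=schedule,
        model_output_transform=transform,
        image_size=64,
    )
    images = sampler.generate_images(
        params=ema_params,           # new in 0.1.37.1: used instead of self.params
        num_images=4,
        diffusion_steps=200,
        start_step=1000,
    )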
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/ddim.py
@@ -4,7 +4,7 @@ from ..utils import MarkovState, RandomMarkovState

 class DDIMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
         return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state

{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/ddpm.py
@@ -4,7 +4,7 @@ from .common import DiffusionSampler
 from ..utils import MarkovState, RandomMarkovState
 class DDPMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
         variance = self.noise_schedule.get_posterior_variance(steps=current_step)

@@ -19,7 +19,7 @@ class DDPMSampler(DiffusionSampler):

 class SimpleDDPMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         state, rng = state.get_random_key()
         noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)

{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/euler.py
@@ -6,7 +6,7 @@ from ..utils import RandomMarkovState
 class EulerSampler(DiffusionSampler):
     # Basically a DDIM Sampler but parameterized as an ODE
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -22,7 +22,7 @@ class SimplifiedEulerSampler(DiffusionSampler):
     This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
     """
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         _, current_sigma = self.noise_schedule.get_rates(current_step)
         _, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -37,7 +37,7 @@ class EulerAncestralSampler(DiffusionSampler):
     Similar to EulerSampler but with ancestral sampling
     """
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/heun_sampler.py
@@ -5,7 +5,7 @@ from ..utils import RandomMarkovState

 class HeunSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
@@ -17,7 +17,7 @@ class HeunSampler(DiffusionSampler):
         next_samples_0 = current_samples + dx_0 * dt

         # Recompute x_0 and eps at the first estimate to refine the derivative
-        estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs)
+        estimated_x_0, _, _ = sample_model_fn(next_samples_0, next_step, *model_conditioning_inputs)

         # Estimate the refined derivative using the midpoint (Heun's method)
         dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/multistep_dpm.py
@@ -9,7 +9,7 @@ class MultiStepDPM(DiffusionSampler):
         self.history = []

     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/samplers/rk4_sampler.py
@@ -9,14 +9,14 @@ class RK4Sampler(DiffusionSampler):
         super().__init__(*args, **kwargs)
         assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"
         @jax.jit
-        def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
+        def get_derivative(sample_model_fn, x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
            t = self.noise_schedule.get_timesteps(sigma)
-            x_0, eps, _ = self.sample_model(x_t, t, *model_conditioning_inputs)
+            x_0, eps, _ = sample_model_fn(x_t, t, *model_conditioning_inputs)
            return eps, state

         self.get_derivative = get_derivative

-    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+    def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
@@ -25,10 +25,10 @@ class RK4Sampler(DiffusionSampler):

         dt = next_sigma - current_sigma

-        k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs)
-        k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
-        k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
-        k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)
+        k1, state = self.get_derivative(sample_model_fn, current_samples, current_sigma, state, model_conditioning_inputs)
+        k2, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k3, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k4, state = self.get_derivative(sample_model_fn, current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)

         next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6)
         return next_samples, state
flaxdiff-0.1.37.1/flaxdiff/schedulers/__init__.py
@@ -0,0 +1,6 @@
+from .discrete import *
+from .common import *
+from .cosine import *
+from .linear import *
+from .sqrt import *
+from .karras import *
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff/trainer/diffusion_trainer.py
@@ -235,19 +235,19 @@ class DiffusionTrainer(SimpleTrainer):
         null_labels_full = null_labels_full.astype(jnp.float16)
         # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

-        def generate_sampler(state: TrainState):
-            sampler = sampler_class(
-                model=model,
-                params=state.ema_params,
-                noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
-                model_output_transform=self.model_output_transform,
-                image_size=self.input_shapes['x'][0],
-                null_labels_seq=null_labels_full,
-                autoencoder=autoencoder,
-            )
-            return sampler
+        sampler = sampler_class(
+            model=model,
+            params=None,
+            noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
+            model_output_transform=self.model_output_transform,
+            image_size=self.input_shapes['x'][0],
+            null_labels_seq=null_labels_full,
+            autoencoder=autoencoder,
+            guidance_scale=3.0,
+        )

         def generate_samples(
+            val_state: TrainState,
             batch,
             sampler: DiffusionSampler,
             diffusion_steps: int,
@@ -255,6 +255,7 @@ class DiffusionTrainer(SimpleTrainer):
             labels_seq = encoder.encode_from_tokens(batch)
             labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
             samples = sampler.generate_images(
+                params=val_state.ema_params,
                 num_images=len(labels_seq),
                 diffusion_steps=diffusion_steps,
                 start_step=1000,
@@ -264,7 +265,7 @@ class DiffusionTrainer(SimpleTrainer):
             )
             return samples

-        return generate_sampler, generate_samples
+        return sampler, generate_samples

     def validation_loop(
         self,
@@ -275,14 +276,15 @@ class DiffusionTrainer(SimpleTrainer):
         current_step,
         diffusion_steps=200,
     ):
-        generate_sampler, generate_samples = val_step_fn
+        sampler, generate_samples = val_step_fn

-        sampler = generate_sampler(val_state)
+        # sampler = generate_sampler(val_state)

         val_ds = iter(val_ds()) if val_ds else None
         # Evaluation step
         try:
             samples = generate_samples(
+                val_state,
                 next(val_ds),
                 sampler,
                 diffusion_steps,
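
The trainer hunks mirror the sampler change: the sampler is now built once with `params=None` (and `guidance_scale=3.0`), the validation step returns `(sampler, generate_samples)` instead of `(generate_sampler, generate_samples)`, and the EMA parameters travel through `generate_samples(val_state, ...)` into `generate_images(params=val_state.ema_params, ...)`. A condensed, illustrative view of the new validation flow, using the names from the hunks above with error handling omitted:

    # Sketch only: val_step_fn, val_state, val_ds, and diffusion_steps come from
    # DiffusionTrainer.validation_loop as shown in the diff above.
    sampler, generate_samples = val_step_fn      # sampler is no longer rebuilt per validation
    batch = next(iter(val_ds()))
    samples = generate_samples(
        val_state,                               # supplies val_state.ema_params
        batch,
        sampler,
        diffusion_steps,
    )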
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36.5
+Version: 0.1.37.1
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
{flaxdiff-0.1.36.5 → flaxdiff-0.1.37.1}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "flaxdiff"
-version = "0.1.36.5"
+version = "0.1.37.1"
 description = "A versatile and easy to understand Diffusion library"
 readme = "README.md"
 authors = [
flaxdiff-0.1.36.5/flaxdiff/samplers/__init__.py
@@ -1,7 +0,0 @@
-from .common import DiffusionSampler
-from .ddim import DDIMSampler
-from .ddpm import DDPMSampler, SimpleDDPMSampler
-from .euler import EulerSampler, SimplifiedEulerSampler
-from .heun_sampler import HeunSampler
-from .rk4_sampler import RK4Sampler
-from .multistep_dpm import MultiStepDPM
flaxdiff-0.1.36.5/flaxdiff/schedulers/__init__.py
@@ -1,6 +0,0 @@
-from .discrete import DiscreteNoiseScheduler
-from .common import NoiseScheduler, GeneralizedNoiseScheduler
-from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
-from .linear import LinearNoiseSchedule
-from .sqrt import SqrtContinuousNoiseScheduler
-from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler