flaxdiff 0.1.35.6__tar.gz → 0.1.36__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/PKG-INFO +8 -2
  2. flaxdiff-0.1.36/flaxdiff/data/dataset_map.py +71 -0
  3. flaxdiff-0.1.36/flaxdiff/data/datasets.py +169 -0
  4. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/data/online_loader.py +69 -42
  5. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/common.py +72 -20
  6. flaxdiff-0.1.36/flaxdiff/samplers/ddim.py +10 -0
  7. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/ddpm.py +5 -11
  8. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/euler.py +7 -10
  9. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/heun_sampler.py +3 -4
  10. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/multistep_dpm.py +2 -3
  11. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/rk4_sampler.py +9 -9
  12. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/trainer/autoencoder_trainer.py +1 -1
  13. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/trainer/diffusion_trainer.py +124 -32
  14. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/trainer/simple_trainer.py +187 -91
  15. flaxdiff-0.1.36/flaxdiff/trainer/video_diffusion_trainer.py +62 -0
  16. flaxdiff-0.1.36/flaxdiff/utils.py +192 -0
  17. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff.egg-info/PKG-INFO +8 -2
  18. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff.egg-info/SOURCES.txt +4 -1
  19. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/setup.py +1 -1
  20. flaxdiff-0.1.35.6/flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff-0.1.35.6/flaxdiff/utils.py +0 -89
  22. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/README.md +0 -0
  23. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/__init__.py +0 -0
  24. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/data/__init__.py +0 -0
  25. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/__init__.py +0 -0
  26. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/attention.py +0 -0
  27. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/__init__.py +0 -0
  28. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  29. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  30. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  31. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/common.py +0 -0
  32. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/favor_fastattn.py +0 -0
  33. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/simple_unet.py +0 -0
  34. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/models/simple_vit.py +0 -0
  35. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/predictors/__init__.py +0 -0
  36. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/__init__.py +0 -0
  37. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/__init__.py +0 -0
  38. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/common.py +0 -0
  39. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/continuous.py +0 -0
  40. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/cosine.py +0 -0
  41. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/discrete.py +0 -0
  42. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/exp.py +0 -0
  43. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/karras.py +0 -0
  44. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/linear.py +0 -0
  45. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/schedulers/sqrt.py +0 -0
  46. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/trainer/__init__.py +0 -0
  47. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff.egg-info/dependency_links.txt +0 -0
  48. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff.egg-info/requires.txt +0 -0
  49. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff.egg-info/top_level.txt +0 -0
  50. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/setup.cfg +0 -0
{flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/PKG-INFO
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.35.6
+ Version: 0.1.36
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
@@ -10,6 +10,12 @@ Requires-Dist: optax>=0.2.2
  Requires-Dist: jax>=0.4.28
  Requires-Dist: orbax
  Requires-Dist: clu
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: requires-dist
+ Dynamic: summary

  # ![](images/logo.jpeg "FlaxDiff")

flaxdiff-0.1.36/flaxdiff/data/dataset_map.py
@@ -0,0 +1,71 @@
+ from .sources.tfds import data_source_tfds, tfds_augmenters
+ from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
+
+ # Configure the following for your datasets
+ datasetMap = {
+     "oxford_flowers102": {
+         "source": data_source_tfds("oxford_flowers102", use_tf=False),
+         "augmenter": tfds_augmenters,
+     },
+     "cc12m": {
+         "source": data_source_gcs('arrayrecord2/cc12m'),
+         "augmenter": gcs_augmenters,
+     },
+     "laiona_coco": {
+         "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+         "augmenter": gcs_augmenters,
+     },
+     "aesthetic_coyo": {
+         "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
+         "augmenter": gcs_augmenters,
+     },
+     "combined_aesthetic": {
+         "source": data_source_combined_gcs([
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+             'arrayrecord2/cc12m',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+         ]),
+         "augmenter": gcs_augmenters,
+     },
+     "laiona_coco_coyo": {
+         "source": data_source_combined_gcs([
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+         ]),
+         "augmenter": gcs_augmenters,
+     },
+     "combined_30m": {
+         "source": data_source_combined_gcs([
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecord2/cc12m',
+             'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
+             "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
+         ]),
+         "augmenter": gcs_augmenters,
+     }
+ }
+
+ onlineDatasetMap = {
+     "combined_online": {
+         "source": [
+             # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
+             # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
+             # "dclure/laion-aesthetics-12m-umap",
+             "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
+             "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/cc12m",
+             "gs://flaxdiff-datasets-regional/datasets/playground-liked",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/cc3m",
+             "gs://flaxdiff-datasets-regional/datasets/cc3m",
+             "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
+             # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
+         ]
+     }
+ }
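
For orientation: each datasetMap entry pairs a source factory with an augmenter factory, and get_dataset_grain in the next file consumes them in exactly this shape. A minimal sketch of that lookup (the dataset name and mount path are illustrative):

    entry = datasetMap["cc12m"]
    data_source = entry["source"]("/mnt/gcs_mount/arrayrecord2/cc12m/")   # record source rooted at the mount
    augmenter = entry["augmenter"](256, jax.image.ResizeMethod.LANCZOS3)  # image_scale, resize method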
flaxdiff-0.1.36/flaxdiff/data/datasets.py
@@ -0,0 +1,169 @@
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from typing import Dict
+ import numpy as np
+ import jax
+ from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
+ from .dataset_map import datasetMap, onlineDatasetMap
+ import traceback
+ from .online_loader import OnlineStreamingDataLoader
+ import queue
+ from jax.sharding import Mesh
+ import threading
+
+ def batch_mesh_map(mesh):
+     class augmenters(pygrain.MapTransform):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+
+         def map(self, batch) -> Dict[str, jnp.array]:
+             return convert_to_global_tree(mesh, batch)
+     return augmenters
+
+ def get_dataset_grain(
+     data_name="cc12m",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     dataset = datasetMap[data_name]
+     data_source = dataset["source"](dataset_source)
+     augmenter = dataset["augmenter"](image_scale, method)
+
+     local_batch_size = batch_size // jax.process_count()
+
+     sampler = pygrain.IndexSampler(
+         num_records=len(data_source) if count is None else count,
+         shuffle=True,
+         seed=seed,
+         num_epochs=num_epochs,
+         shard_options=pygrain.ShardByJaxProcess(),
+     )
+
+     def get_trainset():
+         transformations = [
+             augmenter(),
+             pygrain.Batch(local_batch_size, drop_remainder=True),
+         ]
+
+         # if mesh != None:
+         #     transformations += [batch_mesh_map(mesh)]
+
+         loader = pygrain.DataLoader(
+             data_source=data_source,
+             sampler=sampler,
+             operations=transformations,
+             worker_count=worker_count,
+             read_options=pygrain.ReadOptions(
+                 read_thread_count, read_buffer_size
+             ),
+             worker_buffer_size=worker_buffer_size,
+         )
+         return loader
+
+
+     return {
+         "train": get_trainset,
+         "train_len": len(data_source),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         # "null_labels": null_labels,
+         # "null_labels_full": null_labels_full,
+         # "model": model,
+         # "tokenizer": tokenizer,
+     }
+
+ def generate_collate_fn():
+     auto_tokenize = AutoTextTokenizer(tensor_type="np")
+     def default_collate(batch):
+         try:
+             # urls = [sample["url"] for sample in batch]
+             captions = [sample["caption"] for sample in batch]
+             results = auto_tokenize(captions)
+             images = np.stack([sample["image"] for sample in batch], axis=0)
+             return {
+                 "image": images,
+                 "input_ids": results['input_ids'],
+                 "attention_mask": results['attention_mask'],
+             }
+         except Exception as e:
+             print("Error in collate function", e, [sample["image"].shape for sample in batch])
+             traceback.print_exc()
+
+     return default_collate
+
+ def get_dataset_online(
+     data_name="combined_online",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     local_batch_size = batch_size // jax.process_count()
+
+     sources = onlineDatasetMap[data_name]["source"]
+     dataloader = OnlineStreamingDataLoader(
+         sources,
+         batch_size=local_batch_size,
+         num_workers=worker_count,
+         num_threads=read_thread_count,
+         image_shape=(image_scale, image_scale),
+         global_process_count=jax.process_count(),
+         global_process_index=jax.process_index(),
+         prefetch=worker_buffer_size,
+         collate_fn=generate_collate_fn(),
+         default_split="train",
+     )
+
+     def get_trainset(mesh: Mesh = None):
+         if mesh != None:
+             class dataLoaderWithMesh:
+                 def __init__(self, dataloader, mesh):
+                     self.dataloader = dataloader
+                     self.mesh = mesh
+                     self.tmp_queue = queue.Queue(worker_buffer_size)
+                     def batch_loader():
+                         for batch in self.dataloader:
+                             try:
+                                 self.tmp_queue.put(convert_to_global_tree(mesh, batch))
+                             except Exception as e:
+                                 print("Error processing batch", e)
+                     self.loader_thread = threading.Thread(target=batch_loader)
+                     self.loader_thread.start()
+
+                 def __iter__(self):
+                     return self
+
+                 def __next__(self):
+                     return self.tmp_queue.get()
+
+             dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
+
+             return dataloader_with_mesh
+         return dataloader
+
+     return {
+         "train": get_trainset,
+         "train_len": len(dataloader) * jax.process_count(),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         # "null_labels": null_labels,
+         # "null_labels_full": null_labels_full,
+         # "model": model,
+         # "tokenizer": tokenizer,
+     }
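
A minimal usage sketch of the loader factory added above; the dataset name and batch size are illustrative:

    data = get_dataset_grain("cc12m", batch_size=64, image_scale=256)
    train_loader = data["train"]()  # builds the pygrain.DataLoader
    for batch in train_loader:
        break  # each batch carries data["local_batch_size"] samples for this process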
{flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/data/online_loader.py
@@ -45,36 +45,43 @@ def fetch_single_image(image_url, timeout=None, retries=0):


  def default_image_processor(
-     image, image_shape,
+     image, image_shape,
      min_image_shape=(128, 128),
      upscale_interpolation=cv2.INTER_CUBIC,
      downscale_interpolation=cv2.INTER_AREA,
  ):
-     image = np.array(image)
-     original_height, original_width = image.shape[:2]
-     # check if the image is too small
-     if min(original_height, original_width) < min(min_image_shape):
-         return None, original_height, original_width
-     # check if wrong aspect ratio
-     if max(original_height, original_width) / min(original_height, original_width) > 2.4:
-         return None, original_height, original_width
-     # check if the variance is too low
-     if np.std(image) < 1e-5:
-         return None, original_height, original_width
-     # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-     downscale = max(original_width, original_height) > max(image_shape)
-     interpolation = downscale_interpolation if downscale else upscale_interpolation
-
-     image = A.longest_max_size(image, max(
-         image_shape), interpolation=interpolation)
-     image = A.pad(
-         image,
-         min_height=image_shape[0],
-         min_width=image_shape[1],
-         border_mode=cv2.BORDER_CONSTANT,
-         value=[255, 255, 255],
-     )
-     return image, original_height, original_width
+     try:
+         image = np.array(image)
+         if len(image.shape) != 3 or image.shape[2] != 3:
+             return None, 0, 0
+         original_height, original_width = image.shape[:2]
+         # check if the image is too small
+         if min(original_height, original_width) < min(min_image_shape):
+             return None, original_height, original_width
+         # check if wrong aspect ratio
+         if max(original_height, original_width) / min(original_height, original_width) > 2.4:
+             return None, original_height, original_width
+         # check if the variance is too low
+         if np.std(image) < 1e-5:
+             return None, original_height, original_width
+         # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         downscale = max(original_width, original_height) > max(image_shape)
+         interpolation = downscale_interpolation if downscale else upscale_interpolation
+
+         image = A.longest_max_size(image, max(
+             image_shape), interpolation=interpolation)
+         image = A.pad(
+             image,
+             min_height=image_shape[0],
+             min_width=image_shape[1],
+             border_mode=cv2.BORDER_CONSTANT,
+             value=[255, 255, 255],
+         )
+         return image, original_height, original_width
+     except Exception as e:
+         # print("Error processing image", e, image_shape, interpolation)
+         # traceback.print_exc()
+         return None, 0, 0


  def map_sample(
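
The revised processor above returns (image, height, width) and signals rejection by returning None for the image. A hedged illustration of that contract (pil_image is a hypothetical input):

    img, h, w = default_image_processor(pil_image, image_shape=(256, 256))
    if img is None:
        # rejected: not 3-channel, too small, aspect ratio over 2.4,
        # near-zero variance, or a processing error
        pass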
@@ -120,14 +127,36 @@ def map_sample(
          # })
          pass

-
  def default_feature_extractor(sample):
+     url = None
+     if "url" in sample:
+         url = sample["url"]
+     elif "URL" in sample:
+         url = sample["URL"]
+     elif "image_url" in sample:
+         url = sample["image_url"]
+     else:
+         print("No url found in sample, skipping", sample.keys())
+
+     caption = None
+     if "caption" in sample:
+         caption = sample["caption"]
+     elif "CAPTION" in sample:
+         caption = sample["CAPTION"]
+     elif "txt" in sample:
+         caption = sample["txt"]
+     elif "TEXT" in sample:
+         caption = sample["TEXT"]
+     elif "text" in sample:
+         caption = sample["text"]
+     else:
+         print("No caption found in sample, skipping", sample.keys())
+
      return {
-         "url": sample["url"],
-         "caption": sample["caption"],
+         "url": url,
+         "caption": caption,
      }
-
-
+
  def map_batch(
      batch, num_threads=256, image_shape=(256, 256),
      min_image_shape=(128, 128),
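
Illustrative behaviour of the revised default_feature_extractor: it now normalizes common key variants to url and caption instead of assuming lower-case keys (the sample dict here is hypothetical):

    sample = {"URL": "http://example.com/a.jpg", "TEXT": "a photo of a dog"}
    default_feature_extractor(sample)
    # -> {"url": "http://example.com/a.jpg", "caption": "a photo of a dog"}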
@@ -301,15 +330,13 @@ class OnlineStreamingDataLoader():
          self.dataset = dataset.shard(
              num_shards=global_process_count, index=global_process_index)
          print(f"Dataset length: {len(dataset)}")
-         self.iterator = ImageBatchIterator(
-             self.dataset, image_shape=image_shape,
-             min_image_shape=min_image_shape,
-             num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-             timeout=timeout, retries=retries, image_processor=image_processor,
-             upscale_interpolation=upscale_interpolation,
-             downscale_interpolation=downscale_interpolation,
-             feature_extractor=feature_extractor
-         )
+         self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                            min_image_shape=min_image_shape,
+                                            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                            timeout=timeout, retries=retries, image_processor=image_processor,
+                                            upscale_interpolation=upscale_interpolation,
+                                            downscale_interpolation=downscale_interpolation,
+                                            feature_extractor=feature_extractor)
          self.batch_size = batch_size

          # Launch a thread to load batches in the background
@@ -320,7 +347,7 @@ class OnlineStreamingDataLoader():
              try:
                  self.batch_queue.put(collate_fn(batch))
              except Exception as e:
-                 print("Error processing batch", e)
+                 print("Error collating batch", e)

          self.loader_thread = threading.Thread(target=batch_loader)
          self.loader_thread.start()
@@ -333,4 +360,4 @@ class OnlineStreamingDataLoader():
          # return self.collate_fn(next(self.iterator))

      def __len__(self):
-         return len(self.dataset)
+         return len(self.dataset)
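
For reference, a minimal construction sketch of OnlineStreamingDataLoader mirroring how get_dataset_online calls it; the source list and batch size are illustrative:

    loader = OnlineStreamingDataLoader(
        ["gs://flaxdiff-datasets-regional/datasets/cc12m"],
        batch_size=64 // jax.process_count(),
        image_shape=(256, 256),
        global_process_count=jax.process_count(),
        global_process_index=jax.process_index(),
    )
    print(len(loader))  # per-process length, per the __len__ above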
@@ -15,36 +15,76 @@ class DiffusionSampler():
{flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/common.py
@@ -15,36 +15,76 @@ class DiffusionSampler():

      def __init__(self, model:nn.Module, params:dict,
                   noise_schedule:NoiseScheduler,
-                  model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform()):
+                  model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
+                  guidance_scale:float = 0.0,
+                  null_labels_seq:jax.Array=None,
+                  autoencoder=None,
+                  image_size=256,
+                  autoenc_scale_reduction=8,
+                  autoenc_latent_channels=4,
+                  ):
          self.model = model
          self.noise_schedule = noise_schedule
          self.params = params
          self.model_output_transform = model_output_transform
-
-         @jax.jit
-         def sample_model(x_t, t):
-             rates = self.noise_schedule.get_rates(t)
-             c_in = self.model_output_transform.get_input_scale(rates)
-             model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
-             x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
-             return x_0, eps, model_output
+         self.guidance_scale = guidance_scale
+         self.image_size = image_size
+         self.autoenc_scale_reduction = autoenc_scale_reduction
+         self.autoencoder = autoencoder
+         self.autoenc_latent_channels = autoenc_latent_channels

+         if self.guidance_scale > 0:
+             # Classifier free guidance
+             assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
+             print("Using classifier-free guidance")
+             def sample_model(x_t, t, *additional_inputs):
+                 # Concatenate unconditional and conditional inputs
+                 x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
+                 t_cat = jnp.concatenate([t] * 2, axis=0)
+                 rates_cat = self.noise_schedule.get_rates(t_cat)
+                 c_in_cat = self.model_output_transform.get_input_scale(rates_cat)
+
+                 text_labels_seq, = additional_inputs
+                 text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
+                 model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
+                 # Split model output into unconditional and conditional parts
+                 model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
+                 model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
+
+                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
+                 return x_0, eps, model_output
+         else:
+             # Unconditional sampling
+             def sample_model(x_t, t, *additional_inputs):
+                 rates = self.noise_schedule.get_rates(t)
+                 c_in = self.model_output_transform.get_input_scale(rates)
+                 model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
+                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
+                 return x_0, eps, model_output
+
+         # if jax.device_count() > 1:
+         #     mesh = jax.sharding.Mesh(jax.devices(), 'data')
+         #     sample_model = shard_map(sample_model, mesh=mesh, in_specs=(P('data'), P('data'), P('data')),
+         #                              out_specs=(P('data'), P('data'), P('data')))
+         sample_model = jax.jit(sample_model)
          self.sample_model = sample_model

      # Used to sample from the diffusion model
-     def sample_step(self, current_samples:jnp.ndarray, current_step, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+     def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
          # First clip the noisy images
-         # pred_images = clip_images(pred_images)
          step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
          current_step = step_ones * current_step
          next_step = step_ones * next_step
-         pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
+         pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
          # plotImages(pred_images)
+         # pred_images = clip_images(pred_images)
          new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
-                                                  pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state)
+                                                  pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+                                                  model_conditioning_inputs=model_conditioning_inputs
+                                                  )
          return new_samples, state

-     def take_next_step(self, current_samples, reconstructed_samples,
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                         pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
          # estimate the q(x_{t-1} | x_t, x_0).
          # pred_images is x_0, noisy_images is x_t, steps is t
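
The guidance combination above is the standard classifier-free guidance extrapolation, output_uncond + s * (output_cond - output_uncond). A tiny numeric sketch with illustrative scalars:

    uncond, cond, scale = 0.2, 0.5, 2.0
    guided = uncond + scale * (cond - uncond)  # 0.8: pushed past the conditional prediction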
@@ -62,11 +102,16 @@ class DiffusionSampler():
          steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
          return steps

-     def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step, image_size=64):
+     def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step):
          start_step = self.scale_steps(start_step)
          alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
          variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
-         return jax.random.normal(rngs, (num_images, image_size, image_size, 3)) * variance
+         image_size = self.image_size
+         image_channels = 3
+         if self.autoencoder is not None:
+             image_size = image_size // self.autoenc_scale_reduction
+             image_channels = self.autoenc_latent_channels
+         return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance

      def generate_images(self,
                          num_images=16,
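
Shape bookkeeping implied by the new get_initial_samples: with an autoencoder attached, the initial noise is drawn in latent space. A sketch with the constructor defaults (batch size illustrative):

    image_size, scale_reduction, latent_channels = 256, 8, 4
    latent_shape = (16, image_size // scale_reduction,
                    image_size // scale_reduction, latent_channels)
    # (16, 32, 32, 4) rather than (16, 256, 256, 3)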
@@ -75,18 +120,23 @@ class DiffusionSampler():
                          end_step:int = 0,
                          steps_override=None,
                          priors=None,
-                         rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42))) -> jnp.ndarray:
+                         rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
+                         model_conditioning_inputs:tuple=()
+                         ) -> jnp.ndarray:
          if priors is None:
              rngstate, newrngs = rngstate.get_random_key()
              samples = self.get_initial_samples(num_images, newrngs, start_step)
          else:
              print("Using priors")
+             if self.autoencoder is not None:
+                 priors = self.autoencoder.encode(priors)
              samples = priors

-         @jax.jit
+         # @jax.jit
          def sample_step(state:RandomMarkovState, samples, current_step, next_step):
              samples, state = self.sample_step(current_samples=samples,
                                                current_step=current_step,
+                                               model_conditioning_inputs=model_conditioning_inputs,
                                                state=state, next_step=next_step)
              return samples, state

@@ -108,6 +158,8 @@ class DiffusionSampler():
          else:
              # print("last step")
              step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
-             samples, _, _ = self.sample_model(samples, current_step * step_ones)
+             samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs)
+         if self.autoencoder is not None:
+             samples = self.autoencoder.decode(samples)
          samples = clip_images(samples)
-         return samples
+         return samples
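
A hedged end-to-end call sketch for the new conditioning path; sampler and text_embeddings are hypothetical stand-ins for a constructed DiffusionSampler and a text-conditioning tensor:

    images = sampler.generate_images(
        num_images=4,
        diffusion_steps=50,
        model_conditioning_inputs=(text_embeddings,),
    )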
@@ -0,0 +1,10 @@
flaxdiff-0.1.36/flaxdiff/samplers/ddim.py
@@ -0,0 +1,10 @@
+ import jax.numpy as jnp
+ from .common import DiffusionSampler
+ from ..utils import MarkovState, RandomMarkovState
+
+ class DDIMSampler(DiffusionSampler):
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                        pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+         next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
+         return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
+
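
The DDIM step above is the deterministic update: recombine the predicted clean sample and the predicted noise at the next step's signal and noise rates, with no fresh noise injected. As a scalar sketch (the rates, x0_pred and eps_pred are illustrative):

    next_signal_rate, next_noise_rate = 0.99, 0.14
    x_next = next_signal_rate * x0_pred + next_noise_rate * eps_pred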
@@ -3,9 +3,8 @@ import jax.numpy as jnp
{flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/ddpm.py
@@ -3,9 +3,8 @@ import jax.numpy as jnp
  from .common import DiffusionSampler
  from ..utils import MarkovState, RandomMarkovState
  class DDPMSampler(DiffusionSampler):
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                        pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
          mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
          variance = self.noise_schedule.get_posterior_variance(steps=current_step)

@@ -19,9 +18,8 @@ class DDPMSampler(DiffusionSampler):
          return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs)

  class SimpleDDPMSampler(DiffusionSampler):
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                        pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
          state, rng = state.get_random_key()
          noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)

@@ -33,11 +31,7 @@ class SimpleDDPMSampler(DiffusionSampler):

          noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
          signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)
-         betas = (1 - signal_ratio_squared)
-         gamma = jnp.sqrt(noise_ratio_squared * betas)
+         gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared))

          next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma
-         # pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)
-         # next_samples = (2 - jnp.sqrt(1 - betas)) * current_samples - betas * (pred_noise / current_noise_rate) + noise * gamma#jnp.sqrt(betas)
-         # next_samples = (1 / (jnp.sqrt(1 - betas) + 1.e-24)) * (current_samples - betas * (pred_noise / current_noise_rate)) + noise * gamma
          return next_samples, state
@@ -5,9 +5,8 @@ from ..utils import RandomMarkovState
{flaxdiff-0.1.35.6 → flaxdiff-0.1.36}/flaxdiff/samplers/euler.py
@@ -5,9 +5,8 @@ from ..utils import RandomMarkovState

  class EulerSampler(DiffusionSampler):
      # Basically a DDIM Sampler but parameterized as an ODE
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                        pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
          current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
          next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -22,9 +21,8 @@ class SimplifiedEulerSampler(DiffusionSampler):
      """
      This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
      """
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                        pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
          _, current_sigma = self.noise_schedule.get_rates(current_step)
          _, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -38,9 +36,8 @@ class EulerAncestralSampler(DiffusionSampler):
      """
      Similar to EulerSampler but with ancestral sampling
      """
-     def take_next_step(self,
-                        current_samples, reconstructed_samples,
-                        pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
+     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                        pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
          current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
          next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -56,4 +53,4 @@ class EulerAncestralSampler(DiffusionSampler):
          dW = jax.random.normal(subkey, current_samples.shape) * sigma_up

          next_samples = current_samples + dx * dt + dW
-         return next_samples, state
+         return next_samples, state