flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,71 @@
+from .sources.tfds import data_source_tfds, tfds_augmenters
+from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
+
+# Configure the following for your datasets
+datasetMap = {
+    "oxford_flowers102": {
+        "source": data_source_tfds("oxford_flowers102", use_tf=False),
+        "augmenter": tfds_augmenters,
+    },
+    "cc12m": {
+        "source": data_source_gcs('arrayrecord2/cc12m'),
+        "augmenter": gcs_augmenters,
+    },
+    "laiona_coco": {
+        "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+        "augmenter": gcs_augmenters,
+    },
+    "aesthetic_coyo": {
+        "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
+        "augmenter": gcs_augmenters,
+    },
+    "combined_aesthetic": {
+        "source": data_source_combined_gcs([
+            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+            'arrayrecord2/cc12m',
+            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+        ]),
+        "augmenter": gcs_augmenters,
+    },
+    "laiona_coco_coyo": {
+        "source": data_source_combined_gcs([
+            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+        ]),
+        "augmenter": gcs_augmenters,
+    },
+    "combined_30m": {
+        "source": data_source_combined_gcs([
+            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+            'arrayrecord2/cc12m',
+            'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
+            "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
+        ]),
+        "augmenter": gcs_augmenters,
+    }
+}
+
+onlineDatasetMap = {
+    "combined_online": {
+        "source": [
+            # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
+            # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
+            # "dclure/laion-aesthetics-12m-umap",
+            "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
+            "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
+            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+            "gs://flaxdiff-datasets-regional/datasets/cc12m",
+            "gs://flaxdiff-datasets-regional/datasets/playground-liked",
+            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+            "gs://flaxdiff-datasets-regional/datasets/cc3m",
+            "gs://flaxdiff-datasets-regional/datasets/cc3m",
+            "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
+            # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
+        ]
+    }
+}
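The new `datasetMap` registry pairs each named dataset with a `source` factory and an `augmenter` factory; `get_dataset_grain` in the next file resolves entries by calling both. A minimal sketch of that resolution, assuming the factory signatures used there (the dataset name and mount path are only examples):

    import jax

    entry = datasetMap["cc12m"]
    data_source = entry["source"]("/mnt/gcs_mount/arrayrecord2/cc12m/")   # source factory takes the data root
    augmenter = entry["augmenter"](256, jax.image.ResizeMethod.LANCZOS3)  # (image_scale, resize method)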
@@ -0,0 +1,169 @@
+import jax.numpy as jnp
+import grain.python as pygrain
+from typing import Dict
+import numpy as np
+import jax
+from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
+from .dataset_map import datasetMap, onlineDatasetMap
+import traceback
+from .online_loader import OnlineStreamingDataLoader
+import queue
+from jax.sharding import Mesh
+import threading
+
+def batch_mesh_map(mesh):
+    class augmenters(pygrain.MapTransform):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def map(self, batch) -> Dict[str, jnp.array]:
+            return convert_to_global_tree(mesh, batch)
+    return augmenters
+
+def get_dataset_grain(
+    data_name="cc12m",
+    batch_size=64,
+    image_scale=256,
+    count=None,
+    num_epochs=None,
+    method=jax.image.ResizeMethod.LANCZOS3,
+    worker_count=32,
+    read_thread_count=64,
+    read_buffer_size=50,
+    worker_buffer_size=20,
+    seed=0,
+    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+):
+    dataset = datasetMap[data_name]
+    data_source = dataset["source"](dataset_source)
+    augmenter = dataset["augmenter"](image_scale, method)
+
+    local_batch_size = batch_size // jax.process_count()
+
+    sampler = pygrain.IndexSampler(
+        num_records=len(data_source) if count is None else count,
+        shuffle=True,
+        seed=seed,
+        num_epochs=num_epochs,
+        shard_options=pygrain.ShardByJaxProcess(),
+    )
+
+    def get_trainset():
+        transformations = [
+            augmenter(),
+            pygrain.Batch(local_batch_size, drop_remainder=True),
+        ]
+
+        # if mesh != None:
+        #     transformations += [batch_mesh_map(mesh)]
+
+        loader = pygrain.DataLoader(
+            data_source=data_source,
+            sampler=sampler,
+            operations=transformations,
+            worker_count=worker_count,
+            read_options=pygrain.ReadOptions(
+                read_thread_count, read_buffer_size
+            ),
+            worker_buffer_size=worker_buffer_size,
+        )
+        return loader
+
+
+    return {
+        "train": get_trainset,
+        "train_len": len(data_source),
+        "local_batch_size": local_batch_size,
+        "global_batch_size": batch_size,
+        # "null_labels": null_labels,
+        # "null_labels_full": null_labels_full,
+        # "model": model,
+        # "tokenizer": tokenizer,
+    }
+
+def generate_collate_fn():
+    auto_tokenize = AutoTextTokenizer(tensor_type="np")
+    def default_collate(batch):
+        try:
+            # urls = [sample["url"] for sample in batch]
+            captions = [sample["caption"] for sample in batch]
+            results = auto_tokenize(captions)
+            images = np.stack([sample["image"] for sample in batch], axis=0)
+            return {
+                "image": images,
+                "input_ids": results['input_ids'],
+                "attention_mask": results['attention_mask'],
+            }
+        except Exception as e:
+            print("Error in collate function", e, [sample["image"].shape for sample in batch])
+            traceback.print_exc()
+
+    return default_collate
+
+def get_dataset_online(
+    data_name="combined_online",
+    batch_size=64,
+    image_scale=256,
+    count=None,
+    num_epochs=None,
+    method=jax.image.ResizeMethod.LANCZOS3,
+    worker_count=32,
+    read_thread_count=64,
+    read_buffer_size=50,
+    worker_buffer_size=20,
+    seed=0,
+    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+):
+    local_batch_size = batch_size // jax.process_count()
+
+    sources = onlineDatasetMap[data_name]["source"]
+    dataloader = OnlineStreamingDataLoader(
+        sources,
+        batch_size=local_batch_size,
+        num_workers=worker_count,
+        num_threads=read_thread_count,
+        image_shape=(image_scale, image_scale),
+        global_process_count=jax.process_count(),
+        global_process_index=jax.process_index(),
+        prefetch=worker_buffer_size,
+        collate_fn=generate_collate_fn(),
+        default_split="train",
+    )
+
+    def get_trainset(mesh: Mesh = None):
+        if mesh != None:
+            class dataLoaderWithMesh:
+                def __init__(self, dataloader, mesh):
+                    self.dataloader = dataloader
+                    self.mesh = mesh
+                    self.tmp_queue = queue.Queue(worker_buffer_size)
+                    def batch_loader():
+                        for batch in self.dataloader:
+                            try:
+                                self.tmp_queue.put(convert_to_global_tree(mesh, batch))
+                            except Exception as e:
+                                print("Error processing batch", e)
+                    self.loader_thread = threading.Thread(target=batch_loader)
+                    self.loader_thread.start()
+
+                def __iter__(self):
+                    return self
+
+                def __next__(self):
+                    return self.tmp_queue.get()
+
+            dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
+
+            return dataloader_with_mesh
+        return dataloader
+
+    return {
+        "train": get_trainset,
+        "train_len": len(dataloader) * jax.process_count(),
+        "local_batch_size": local_batch_size,
+        "global_batch_size": batch_size,
+        # "null_labels": null_labels,
+        # "null_labels_full": null_labels_full,
+        # "model": model,
+        # "tokenizer": tokenizer,
+    }
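A usage sketch for the grain-backed pipeline, assuming a single JAX process (the `"image"` batch key is an assumption based on the collate conventions in this file, not something the diff confirms):

    data = get_dataset_grain(data_name="cc12m", batch_size=64, image_scale=256)
    loader = data["train"]()      # builds the pygrain.DataLoader
    for batch in loader:
        images = batch["image"]   # assumed key produced by the dataset's augmenter
        break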
@@ -45,36 +45,43 @@ def fetch_single_image(image_url, timeout=None, retries=0):
 
 
 def default_image_processor(
-    image, image_shape,
+    image, image_shape,
     min_image_shape=(128, 128),
     upscale_interpolation=cv2.INTER_CUBIC,
     downscale_interpolation=cv2.INTER_AREA,
 ):
-    image = np.array(image)
-    original_height, original_width = image.shape[:2]
-    # check if the image is too small
-    if min(original_height, original_width) < min(min_image_shape):
-        return None, original_height, original_width
-    # check if wrong aspect ratio
-    if max(original_height, original_width) / min(original_height, original_width) > 2.4:
-        return None, original_height, original_width
-    # check if the variance is too low
-    if np.std(image) < 1e-5:
-        return None, original_height, original_width
-    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    downscale = max(original_width, original_height) > max(image_shape)
-    interpolation = downscale_interpolation if downscale else upscale_interpolation
-
-    image = A.longest_max_size(image, max(
-        image_shape), interpolation=interpolation)
-    image = A.pad(
-        image,
-        min_height=image_shape[0],
-        min_width=image_shape[1],
-        border_mode=cv2.BORDER_CONSTANT,
-        value=[255, 255, 255],
-    )
-    return image, original_height, original_width
+    try:
+        image = np.array(image)
+        if len(image.shape) != 3 or image.shape[2] != 3:
+            return None, 0, 0
+        original_height, original_width = image.shape[:2]
+        # check if the image is too small
+        if min(original_height, original_width) < min(min_image_shape):
+            return None, original_height, original_width
+        # check if wrong aspect ratio
+        if max(original_height, original_width) / min(original_height, original_width) > 2.4:
+            return None, original_height, original_width
+        # check if the variance is too low
+        if np.std(image) < 1e-5:
+            return None, original_height, original_width
+        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        downscale = max(original_width, original_height) > max(image_shape)
+        interpolation = downscale_interpolation if downscale else upscale_interpolation
+
+        image = A.longest_max_size(image, max(
+            image_shape), interpolation=interpolation)
+        image = A.pad(
+            image,
+            min_height=image_shape[0],
+            min_width=image_shape[1],
+            border_mode=cv2.BORDER_CONSTANT,
+            value=[255, 255, 255],
+        )
+        return image, original_height, original_width
+    except Exception as e:
+        # print("Error processing image", e, image_shape, interpolation)
+        # traceback.print_exc()
+        return None, 0, 0
 
 
 def map_sample(
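The rework of `default_image_processor` adds an up-front channel check and converts any processing failure into a `(None, 0, 0)` sentinel, so a single bad sample no longer kills a worker thread. An illustration of the new contract (the sample shapes here are arbitrary):

    import numpy as np

    gray = np.zeros((256, 256), dtype=np.uint8)               # 2-D input, not RGB
    img, h, w = default_image_processor(gray, (256, 256))
    assert img is None and (h, w) == (0, 0)                   # rejected, not raised

    rgb = np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8)
    img, h, w = default_image_processor(rgb, (256, 256))      # resized to longest side 256, then padded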
@@ -120,14 +127,36 @@ def map_sample(
         # })
         pass
 
-
 def default_feature_extractor(sample):
+    url = None
+    if "url" in sample:
+        url = sample["url"]
+    elif "URL" in sample:
+        url = sample["URL"]
+    elif "image_url" in sample:
+        url = sample["image_url"]
+    else:
+        print("No url found in sample, skipping", sample.keys())
+
+    caption = None
+    if "caption" in sample:
+        caption = sample["caption"]
+    elif "CAPTION" in sample:
+        caption = sample["CAPTION"]
+    elif "txt" in sample:
+        caption = sample["txt"]
+    elif "TEXT" in sample:
+        caption = sample["TEXT"]
+    elif "text" in sample:
+        caption = sample["text"]
+    else:
+        print("No caption found in sample, skipping", sample.keys())
+
     return {
-        "url": sample["url"],
-        "caption": sample["caption"],
+        "url": url,
+        "caption": caption,
     }
-
-
+
 def map_batch(
     batch, num_threads=256, image_shape=(256, 256),
     min_image_shape=(128, 128),
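`default_feature_extractor` now probes the key spellings used across the mixed sources (HuggingFace parquet dumps and GCS datasets) instead of assuming `url`/`caption`. For example, with an illustrative LAION-style record:

    sample = {"URL": "https://example.com/cat.jpg", "TEXT": "a cat"}
    default_feature_extractor(sample)
    # -> {"url": "https://example.com/cat.jpg", "caption": "a cat"}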
@@ -301,15 +330,13 @@ class OnlineStreamingDataLoader():
         self.dataset = dataset.shard(
             num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(
-            self.dataset, image_shape=image_shape,
-            min_image_shape=min_image_shape,
-            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-            timeout=timeout, retries=retries, image_processor=image_processor,
-            upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation,
-            feature_extractor=feature_extractor
-        )
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           min_image_shape=min_image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor,
+                                           upscale_interpolation=upscale_interpolation,
+                                           downscale_interpolation=downscale_interpolation,
+                                           feature_extractor=feature_extractor)
         self.batch_size = batch_size
 
         # Launch a thread to load batches in the background
@@ -320,7 +347,7 @@ class OnlineStreamingDataLoader():
                 try:
                     self.batch_queue.put(collate_fn(batch))
                 except Exception as e:
-                    print("Error processing batch", e)
+                    print("Error collating batch", e)
 
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
@@ -333,4 +360,4 @@ class OnlineStreamingDataLoader():
         # return self.collate_fn(next(self.iterator))
 
     def __len__(self):
-        return len(self.dataset)
+        return len(self.dataset)
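For reference, a minimal way to drive the streaming loader directly, assuming constructor arguments not listed here keep the defaults shown in `get_dataset_online` (the source and batch size are illustrative; note that `__len__` reports the per-process sample count, not batches):

    loader = OnlineStreamingDataLoader(
        ["gs://flaxdiff-datasets-regional/datasets/cc12m"],
        batch_size=32,
        num_workers=8,
        default_split="train",
    )
    for batch in loader:
        break   # dict produced by the configured collate_fn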
@@ -11,6 +11,7 @@ import einops
 import functools
 import math
 from .common import kernel_init
+import jax.experimental.pallas.ops.tpu.flash_attention
 
 class EfficientAttention(nn.Module):
     """
@@ -50,7 +50,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(1.0),
+            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -65,7 +65,7 @@ class Unet(nn.Module):
                     down_conv_type,
                     name=f"down_{i}_residual_{j}",
                     features=dim_in,
-                    kernel_init=self.kernel_init(1.0),
+                    kernel_init=self.kernel_init(scale=1.0),
                     kernel_size=(3, 3),
                     strides=(1, 1),
                     activation=self.activation,
@@ -85,7 +85,7 @@ class Unet(nn.Module):
                         force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
                         norm_inputs=attention_config.get("norm_inputs", True),
                         explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                        kernel_init=self.kernel_init(1.0),
+                        kernel_init=self.kernel_init(scale=1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
             # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
             downs.append(x)
@@ -108,7 +108,7 @@ class Unet(nn.Module):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(1.0),
+                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -129,13 +129,13 @@ class Unet(nn.Module):
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                     norm_inputs=middle_attention.get("norm_inputs", True),
                     explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
-                    kernel_init=self.kernel_init(1.0),
+                    kernel_init=self.kernel_init(scale=1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=self.kernel_init(1.0),
+                kernel_init=self.kernel_init(scale=1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
@@ -157,7 +157,7 @@ class Unet(nn.Module):
                     up_conv_type,# if j == 0 else "separable",
                     name=f"up_{i}_residual_{j}",
                     features=dim_out,
-                    kernel_init=self.kernel_init(1.0),
+                    kernel_init=self.kernel_init(scale=1.0),
                     kernel_size=kernel_size,
                     strides=(1, 1),
                     activation=self.activation,
@@ -177,7 +177,7 @@ class Unet(nn.Module):
                         force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                         norm_inputs=attention_config.get("norm_inputs", True),
                         explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                        kernel_init=self.kernel_init(1.0),
+                        kernel_init=self.kernel_init(scale=1.0),
                         name=f"up_{i}_attention_{j}")(x, textcontext)
             # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -196,7 +196,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=self.kernel_init(1.0),
+            kernel_init=self.kernel_init(scale=1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -207,7 +207,7 @@ class Unet(nn.Module):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=self.kernel_init(1.0),
+            kernel_init=self.kernel_init(scale=1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
@@ -226,7 +226,7 @@ class Unet(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=self.kernel_init(0.0),
+            kernel_init=self.kernel_init(scale=0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -70,7 +70,7 @@ class UViT(nn.Module):
     kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
     norm_inputs: bool = False
-    explicitly_add_residual: bool = False
+    explicitly_add_residual: bool = True
 
     def setup(self):
         if self.norm_groups > 0:
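Switching the `Unet` call sites above from `self.kernel_init(1.0)` to `self.kernel_init(scale=1.0)` is more than style: `UViT.kernel_init` defaults to `partial(kernel_init, scale=1.0)`, and a keyword-bound `partial` rejects a positional value for the same parameter. A self-contained demonstration of the Python semantics involved:

    from functools import partial

    def kernel_init(scale=1.0):
        return scale

    bound = partial(kernel_init, scale=1.0)
    bound(scale=0.0)   # OK: the caller's keyword overrides the bound one
    bound(0.0)         # TypeError: got multiple values for argument 'scale'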
@@ -15,36 +15,76 @@ class DiffusionSampler():
 
     def __init__(self, model:nn.Module, params:dict,
                  noise_schedule:NoiseScheduler,
-                 model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform()):
+                 model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
+                 guidance_scale:float = 0.0,
+                 null_labels_seq:jax.Array=None,
+                 autoencoder=None,
+                 image_size=256,
+                 autoenc_scale_reduction=8,
+                 autoenc_latent_channels=4,
+                 ):
         self.model = model
         self.noise_schedule = noise_schedule
         self.params = params
         self.model_output_transform = model_output_transform
-
-        @jax.jit
-        def sample_model(x_t, t):
-            rates = self.noise_schedule.get_rates(t)
-            c_in = self.model_output_transform.get_input_scale(rates)
-            model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
-            x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
-            return x_0, eps, model_output
+        self.guidance_scale = guidance_scale
+        self.image_size = image_size
+        self.autoenc_scale_reduction = autoenc_scale_reduction
+        self.autoencoder = autoencoder
+        self.autoenc_latent_channels = autoenc_latent_channels
 
+        if self.guidance_scale > 0:
+            # Classifier free guidance
+            assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
+            print("Using classifier-free guidance")
+            def sample_model(x_t, t, *additional_inputs):
+                # Concatenate unconditional and conditional inputs
+                x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
+                t_cat = jnp.concatenate([t] * 2, axis=0)
+                rates_cat = self.noise_schedule.get_rates(t_cat)
+                c_in_cat = self.model_output_transform.get_input_scale(rates_cat)
+
+                text_labels_seq, = additional_inputs
+                text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
+                model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
+                # Split model output into unconditional and conditional parts
+                model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
+                model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
+
+                x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
+                return x_0, eps, model_output
+        else:
+            # Unconditional sampling
+            def sample_model(x_t, t, *additional_inputs):
+                rates = self.noise_schedule.get_rates(t)
+                c_in = self.model_output_transform.get_input_scale(rates)
+                model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
+                x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
+                return x_0, eps, model_output
+
+        # if jax.device_count() > 1:
+        #     mesh = jax.sharding.Mesh(jax.devices(), 'data')
+        #     sample_model = shard_map(sample_model, mesh=mesh, in_specs=(P('data'), P('data'), P('data')),
+        #                              out_specs=(P('data'), P('data'), P('data')))
+        sample_model = jax.jit(sample_model)
         self.sample_model = sample_model
 
     # Used to sample from the diffusion model
-    def sample_step(self, current_samples:jnp.ndarray, current_step, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         # First clip the noisy images
-        # pred_images = clip_images(pred_images)
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
-        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
+        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
         # plotImages(pred_images)
+        # pred_images = clip_images(pred_images)
         new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
-                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state)
+                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+                                                 model_conditioning_inputs=model_conditioning_inputs
+                                                 )
         return new_samples, state
 
-    def take_next_step(self, current_samples, reconstructed_samples,
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
                        pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # estimate the q(x_{t-1} | x_t, x_0).
         # pred_images is x_0, noisy_images is x_t, steps is t
@@ -62,11 +102,16 @@ class DiffusionSampler():
         steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
         return steps
 
-    def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step, image_size=64):
+    def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step):
         start_step = self.scale_steps(start_step)
         alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
         variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
-        return jax.random.normal(rngs, (num_images, image_size, image_size, 3)) * variance
+        image_size = self.image_size
+        image_channels = 3
+        if self.autoencoder is not None:
+            image_size = image_size // self.autoenc_scale_reduction
+            image_channels = self.autoenc_latent_channels
+        return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance
 
     def generate_images(self,
                         num_images=16,
@@ -75,18 +120,23 @@ class DiffusionSampler():
                         end_step:int = 0,
                         steps_override=None,
                         priors=None,
-                        rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42))) -> jnp.ndarray:
+                        rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
+                        model_conditioning_inputs:tuple=()
+                        ) -> jnp.ndarray:
         if priors is None:
             rngstate, newrngs = rngstate.get_random_key()
             samples = self.get_initial_samples(num_images, newrngs, start_step)
         else:
            print("Using priors")
+            if self.autoencoder is not None:
+                priors = self.autoencoder.encode(priors)
             samples = priors
 
-        @jax.jit
+        # @jax.jit
        def sample_step(state:RandomMarkovState, samples, current_step, next_step):
             samples, state = self.sample_step(current_samples=samples,
                                               current_step=current_step,
+                                              model_conditioning_inputs=model_conditioning_inputs,
                                               state=state, next_step=next_step)
             return samples, state
 
@@ -108,6 +158,8 @@ class DiffusionSampler():
             else:
                 # print("last step")
                 step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
-                samples, _, _ = self.sample_model(samples, current_step * step_ones)
+                samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs)
+                if self.autoencoder is not None:
+                    samples = self.autoencoder.decode(samples)
         samples = clip_images(samples)
-        return samples
+        return samples
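The guidance branch is standard classifier-free guidance: the conditional and unconditional predictions are computed in one doubled batch, then blended as output = uncond + s * (cond - uncond), so s = 1 recovers the purely conditional prediction and s > 1 extrapolates away from the null condition. A sketch of how the extended sampler might be invoked through a concrete subclass such as the `DDIMSampler` below (`text_embeds` and `null_text_embeds` are hypothetical placeholders for text-sequence embeddings produced elsewhere):

    sampler = DDIMSampler(model, params, noise_schedule,
                          guidance_scale=7.5,               # > 0 enables classifier-free guidance
                          null_labels_seq=null_text_embeds, # embedding of the empty prompt
                          image_size=256)
    images = sampler.generate_images(num_images=4,
                                     model_conditioning_inputs=(text_embeds,))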
flaxdiff/samplers/ddim.py CHANGED
@@ -1,10 +1,10 @@
 import jax.numpy as jnp
 from .common import DiffusionSampler
-from ..utils import MarkovState
+from ..utils import MarkovState, RandomMarkovState
 
 class DDIMSampler(DiffusionSampler):
-    def take_next_step(self,
-                       current_samples, reconstructed_samples,
-                       pred_noise, current_step, state:MarkovState, next_step=None) -> tuple[jnp.ndarray, MarkovState]:
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
-        return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
+        return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
+
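With the conditioning inputs threaded through, the DDIM update itself is unchanged: given the predicted clean sample and predicted noise, the next sample is reconstructed as alpha_next * x_0 + sigma_next * eps, i.e. the deterministic (eta = 0) DDIM step, with `state` passed through untouched.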