flaxdiff 0.1.35.5__tar.gz → 0.1.36__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/PKG-INFO +8 -2
  2. flaxdiff-0.1.36/flaxdiff/data/dataset_map.py +71 -0
  3. flaxdiff-0.1.36/flaxdiff/data/datasets.py +169 -0
  4. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/data/online_loader.py +69 -42
  5. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/attention.py +1 -0
  6. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/simple_unet.py +11 -11
  7. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/simple_vit.py +1 -1
  8. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/common.py +72 -20
  9. flaxdiff-0.1.36/flaxdiff/samplers/ddim.py +10 -0
  10. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/ddpm.py +5 -11
  11. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/euler.py +7 -10
  12. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/heun_sampler.py +3 -4
  13. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/multistep_dpm.py +2 -3
  14. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/rk4_sampler.py +9 -9
  15. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/autoencoder_trainer.py +1 -1
  16. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/diffusion_trainer.py +124 -32
  17. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/simple_trainer.py +187 -91
  18. flaxdiff-0.1.36/flaxdiff/trainer/video_diffusion_trainer.py +62 -0
  19. flaxdiff-0.1.36/flaxdiff/utils.py +192 -0
  20. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/PKG-INFO +8 -2
  21. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/SOURCES.txt +4 -1
  22. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/setup.py +1 -1
  23. flaxdiff-0.1.35.5/flaxdiff/samplers/ddim.py +0 -10
  24. flaxdiff-0.1.35.5/flaxdiff/utils.py +0 -89
  25. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/README.md +0 -0
  26. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/__init__.py +0 -0
  27. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/data/__init__.py +0 -0
  28. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/__init__.py +0 -0
  29. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/__init__.py +0 -0
  30. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  31. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  32. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  33. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/common.py +0 -0
  34. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/favor_fastattn.py +0 -0
  35. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/predictors/__init__.py +0 -0
  36. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/samplers/__init__.py +0 -0
  37. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/__init__.py +0 -0
  38. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/common.py +0 -0
  39. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/continuous.py +0 -0
  40. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/cosine.py +0 -0
  41. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/discrete.py +0 -0
  42. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/exp.py +0 -0
  43. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/karras.py +0 -0
  44. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/linear.py +0 -0
  45. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/schedulers/sqrt.py +0 -0
  46. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/trainer/__init__.py +0 -0
  47. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/dependency_links.txt +0 -0
  48. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/requires.txt +0 -0
  49. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff.egg-info/top_level.txt +0 -0
  50. {flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/setup.cfg +0 -0
{flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/PKG-INFO
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.35.5
+ Version: 0.1.36
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
@@ -10,6 +10,12 @@ Requires-Dist: optax>=0.2.2
  Requires-Dist: jax>=0.4.28
  Requires-Dist: orbax
  Requires-Dist: clu
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: requires-dist
+ Dynamic: summary

  # ![](images/logo.jpeg "FlaxDiff")

flaxdiff-0.1.36/flaxdiff/data/dataset_map.py (new file)
@@ -0,0 +1,71 @@
+ from .sources.tfds import data_source_tfds, tfds_augmenters
+ from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
+
+ # Configure the following for your datasets
+ datasetMap = {
+     "oxford_flowers102": {
+         "source": data_source_tfds("oxford_flowers102", use_tf=False),
+         "augmenter": tfds_augmenters,
+     },
+     "cc12m": {
+         "source": data_source_gcs('arrayrecord2/cc12m'),
+         "augmenter": gcs_augmenters,
+     },
+     "laiona_coco": {
+         "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+         "augmenter": gcs_augmenters,
+     },
+     "aesthetic_coyo": {
+         "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
+         "augmenter": gcs_augmenters,
+     },
+     "combined_aesthetic": {
+         "source": data_source_combined_gcs([
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+             'arrayrecord2/cc12m',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+         ]),
+         "augmenter": gcs_augmenters,
+     },
+     "laiona_coco_coyo": {
+         "source": data_source_combined_gcs([
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+         ]),
+         "augmenter": gcs_augmenters,
+     },
+     "combined_30m": {
+         "source": data_source_combined_gcs([
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecord2/cc12m',
+             'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
+             "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
+         ]),
+         "augmenter": gcs_augmenters,
+     }
+ }
+
+ onlineDatasetMap = {
+     "combined_online": {
+         "source": [
+             # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
+             # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
+             # "dclure/laion-aesthetics-12m-umap",
+             "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
+             "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/cc12m",
+             "gs://flaxdiff-datasets-regional/datasets/playground-liked",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/cc3m",
+             "gs://flaxdiff-datasets-regional/datasets/cc3m",
+             "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
+             # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
+         ]
+     }
+ }
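
Editor's note: each datasetMap entry above pairs a source factory with an augmenter factory. The new get_dataset_grain in datasets.py (next hunk) resolves an entry roughly like the minimal sketch below; the mount path and scale are just that function's defaults, and the local names are illustrative.

    import jax
    entry = datasetMap["cc12m"]
    # build the array_record-backed source, then an augmenter for 256px training
    data_source = entry["source"]("/mnt/gcs_mount/arrayrecord2/cc12m/")
    augmenter = entry["augmenter"](256, jax.image.ResizeMethod.LANCZOS3)
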
flaxdiff-0.1.36/flaxdiff/data/datasets.py (new file)
@@ -0,0 +1,169 @@
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from typing import Dict
+ import numpy as np
+ import jax
+ from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
+ from .dataset_map import datasetMap, onlineDatasetMap
+ import traceback
+ from .online_loader import OnlineStreamingDataLoader
+ import queue
+ from jax.sharding import Mesh
+ import threading
+
+ def batch_mesh_map(mesh):
+     class augmenters(pygrain.MapTransform):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+
+         def map(self, batch) -> Dict[str, jnp.array]:
+             return convert_to_global_tree(mesh, batch)
+     return augmenters
+
+ def get_dataset_grain(
+     data_name="cc12m",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     dataset = datasetMap[data_name]
+     data_source = dataset["source"](dataset_source)
+     augmenter = dataset["augmenter"](image_scale, method)
+
+     local_batch_size = batch_size // jax.process_count()
+
+     sampler = pygrain.IndexSampler(
+         num_records=len(data_source) if count is None else count,
+         shuffle=True,
+         seed=seed,
+         num_epochs=num_epochs,
+         shard_options=pygrain.ShardByJaxProcess(),
+     )
+
+     def get_trainset():
+         transformations = [
+             augmenter(),
+             pygrain.Batch(local_batch_size, drop_remainder=True),
+         ]
+
+         # if mesh != None:
+         #     transformations += [batch_mesh_map(mesh)]
+
+         loader = pygrain.DataLoader(
+             data_source=data_source,
+             sampler=sampler,
+             operations=transformations,
+             worker_count=worker_count,
+             read_options=pygrain.ReadOptions(
+                 read_thread_count, read_buffer_size
+             ),
+             worker_buffer_size=worker_buffer_size,
+         )
+         return loader
+
+
+     return {
+         "train": get_trainset,
+         "train_len": len(data_source),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         # "null_labels": null_labels,
+         # "null_labels_full": null_labels_full,
+         # "model": model,
+         # "tokenizer": tokenizer,
+     }
+
+ def generate_collate_fn():
+     auto_tokenize = AutoTextTokenizer(tensor_type="np")
+     def default_collate(batch):
+         try:
+             # urls = [sample["url"] for sample in batch]
+             captions = [sample["caption"] for sample in batch]
+             results = auto_tokenize(captions)
+             images = np.stack([sample["image"] for sample in batch], axis=0)
+             return {
+                 "image": images,
+                 "input_ids": results['input_ids'],
+                 "attention_mask": results['attention_mask'],
+             }
+         except Exception as e:
+             print("Error in collate function", e, [sample["image"].shape for sample in batch])
+             traceback.print_exc()
+
+     return default_collate
+
+ def get_dataset_online(
+     data_name="combined_online",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     local_batch_size = batch_size // jax.process_count()
+
+     sources = onlineDatasetMap[data_name]["source"]
+     dataloader = OnlineStreamingDataLoader(
+         sources,
+         batch_size=local_batch_size,
+         num_workers=worker_count,
+         num_threads=read_thread_count,
+         image_shape=(image_scale, image_scale),
+         global_process_count=jax.process_count(),
+         global_process_index=jax.process_index(),
+         prefetch=worker_buffer_size,
+         collate_fn=generate_collate_fn(),
+         default_split="train",
+     )
+
+     def get_trainset(mesh: Mesh = None):
+         if mesh != None:
+             class dataLoaderWithMesh:
+                 def __init__(self, dataloader, mesh):
+                     self.dataloader = dataloader
+                     self.mesh = mesh
+                     self.tmp_queue = queue.Queue(worker_buffer_size)
+                     def batch_loader():
+                         for batch in self.dataloader:
+                             try:
+                                 self.tmp_queue.put(convert_to_global_tree(mesh, batch))
+                             except Exception as e:
+                                 print("Error processing batch", e)
+                     self.loader_thread = threading.Thread(target=batch_loader)
+                     self.loader_thread.start()
+
+                 def __iter__(self):
+                     return self
+
+                 def __next__(self):
+                     return self.tmp_queue.get()
+
+             dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
+
+             return dataloader_with_mesh
+         return dataloader
+
+     return {
+         "train": get_trainset,
+         "train_len": len(dataloader) * jax.process_count(),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         # "null_labels": null_labels,
+         # "null_labels_full": null_labels_full,
+         # "model": model,
+         # "tokenizer": tokenizer,
+     }
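
Editor's note: a minimal driver for the grain-backed pipeline above, assuming a host with the GCS bucket mounted; the keyword values shown are just the function's own defaults, and the contents of each batch depend on the augmenter.

    data = get_dataset_grain("cc12m", batch_size=64, image_scale=256)
    loader = data["train"]()   # constructs the pygrain.DataLoader lazily
    steps_per_epoch = data["train_len"] // data["global_batch_size"]
    for batch in loader:
        # a dict produced by the augmenter, e.g. images plus tokenized captions
        break
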
{flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/data/online_loader.py
@@ -45,36 +45,43 @@ def fetch_single_image(image_url, timeout=None, retries=0):


  def default_image_processor(
-     image, image_shape,
+     image, image_shape,
      min_image_shape=(128, 128),
      upscale_interpolation=cv2.INTER_CUBIC,
      downscale_interpolation=cv2.INTER_AREA,
  ):
-     image = np.array(image)
-     original_height, original_width = image.shape[:2]
-     # check if the image is too small
-     if min(original_height, original_width) < min(min_image_shape):
-         return None, original_height, original_width
-     # check if wrong aspect ratio
-     if max(original_height, original_width) / min(original_height, original_width) > 2.4:
-         return None, original_height, original_width
-     # check if the variance is too low
-     if np.std(image) < 1e-5:
-         return None, original_height, original_width
-     # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-     downscale = max(original_width, original_height) > max(image_shape)
-     interpolation = downscale_interpolation if downscale else upscale_interpolation
-
-     image = A.longest_max_size(image, max(
-         image_shape), interpolation=interpolation)
-     image = A.pad(
-         image,
-         min_height=image_shape[0],
-         min_width=image_shape[1],
-         border_mode=cv2.BORDER_CONSTANT,
-         value=[255, 255, 255],
-     )
-     return image, original_height, original_width
+     try:
+         image = np.array(image)
+         if len(image.shape) != 3 or image.shape[2] != 3:
+             return None, 0, 0
+         original_height, original_width = image.shape[:2]
+         # check if the image is too small
+         if min(original_height, original_width) < min(min_image_shape):
+             return None, original_height, original_width
+         # check if wrong aspect ratio
+         if max(original_height, original_width) / min(original_height, original_width) > 2.4:
+             return None, original_height, original_width
+         # check if the variance is too low
+         if np.std(image) < 1e-5:
+             return None, original_height, original_width
+         # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         downscale = max(original_width, original_height) > max(image_shape)
+         interpolation = downscale_interpolation if downscale else upscale_interpolation
+
+         image = A.longest_max_size(image, max(
+             image_shape), interpolation=interpolation)
+         image = A.pad(
+             image,
+             min_height=image_shape[0],
+             min_width=image_shape[1],
+             border_mode=cv2.BORDER_CONSTANT,
+             value=[255, 255, 255],
+         )
+         return image, original_height, original_width
+     except Exception as e:
+         # print("Error processing image", e, image_shape, interpolation)
+         # traceback.print_exc()
+         return None, 0, 0


  def map_sample(
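
Editor's note: the rewritten default_image_processor keeps the old filters (minimum side length, aspect ratio above 2.4, near-zero variance) and adds a 3-channel check plus a blanket try/except, so malformed inputs now yield (None, 0, 0) instead of raising. A rough illustration of the filter behavior, assuming numpy, cv2, and albumentations are available as in the module (none of these inputs reach the resize calls):

    import numpy as np
    small = np.zeros((100, 100, 3), dtype=np.uint8)   # shorter side < min_image_shape
    out, h, w = default_image_processor(small, (256, 256))
    assert out is None and (h, w) == (100, 100)
    wide = np.random.randint(0, 255, (200, 500, 3), dtype=np.uint8)   # ratio 2.5 > 2.4
    assert default_image_processor(wide, (256, 256))[0] is None
    gray = np.zeros((300, 300), dtype=np.uint8)       # not 3-channel -> (None, 0, 0)
    assert default_image_processor(gray, (256, 256)) == (None, 0, 0)
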
@@ -120,14 +127,36 @@ def map_sample(
          # })
          pass

-
  def default_feature_extractor(sample):
+     url = None
+     if "url" in sample:
+         url = sample["url"]
+     elif "URL" in sample:
+         url = sample["URL"]
+     elif "image_url" in sample:
+         url = sample["image_url"]
+     else:
+         print("No url found in sample, skipping", sample.keys())
+
+     caption = None
+     if "caption" in sample:
+         caption = sample["caption"]
+     elif "CAPTION" in sample:
+         caption = sample["CAPTION"]
+     elif "txt" in sample:
+         caption = sample["txt"]
+     elif "TEXT" in sample:
+         caption = sample["TEXT"]
+     elif "text" in sample:
+         caption = sample["text"]
+     else:
+         print("No caption found in sample, skipping", sample.keys())
+
      return {
-         "url": sample["url"],
-         "caption": sample["caption"],
+         "url": url,
+         "caption": caption,
      }
-
-
+
  def map_batch(
      batch, num_threads=256, image_shape=(256, 256),
      min_image_shape=(128, 128),
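
Editor's note: default_feature_extractor now normalizes the key spellings found across LAION/COYO/CC-style dumps instead of assuming lowercase url/caption. For example:

    sample = {"URL": "http://example.com/1.jpg", "TEXT": "a red bicycle"}
    assert default_feature_extractor(sample) == {
        "url": "http://example.com/1.jpg",
        "caption": "a red bicycle",
    }
    # unknown keys fall through with a printed warning; url/caption stay None
    default_feature_extractor({"foo": 1})   # returns {"url": None, "caption": None}
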
@@ -301,15 +330,13 @@ class OnlineStreamingDataLoader():
          self.dataset = dataset.shard(
              num_shards=global_process_count, index=global_process_index)
          print(f"Dataset length: {len(dataset)}")
-         self.iterator = ImageBatchIterator(
-             self.dataset, image_shape=image_shape,
-             min_image_shape=min_image_shape,
-             num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-             timeout=timeout, retries=retries, image_processor=image_processor,
-             upscale_interpolation=upscale_interpolation,
-             downscale_interpolation=downscale_interpolation,
-             feature_extractor=feature_extractor
-         )
+         self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                            min_image_shape=min_image_shape,
+                                            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                            timeout=timeout, retries=retries, image_processor=image_processor,
+                                            upscale_interpolation=upscale_interpolation,
+                                            downscale_interpolation=downscale_interpolation,
+                                            feature_extractor=feature_extractor)
          self.batch_size = batch_size

          # Launch a thread to load batches in the background
@@ -320,7 +347,7 @@ class OnlineStreamingDataLoader():
              try:
                  self.batch_queue.put(collate_fn(batch))
              except Exception as e:
-                 print("Error processing batch", e)
+                 print("Error collating batch", e)

          self.loader_thread = threading.Thread(target=batch_loader)
          self.loader_thread.start()
@@ -333,4 +360,4 @@ class OnlineStreamingDataLoader():
          # return self.collate_fn(next(self.iterator))

      def __len__(self):
-         return len(self.dataset)
+         return len(self.dataset)
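
Editor's note: taken together, the online_loader changes are error handling, key normalization, formatting, and a collate error message; construction is unchanged. A hedged usage sketch follows: the dataset name is one of the commented-out examples in dataset_map.py above, the batch size and worker counts are placeholders, and the remaining constructor arguments are assumed to keep their defaults. The loader fetches, filters, and collates batches on a background thread, so the first next() blocks until a batch is queued.

    loader = OnlineStreamingDataLoader(
        "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",  # HF dataset name or gs:// path
        batch_size=16,
        num_workers=8,
        num_threads=64,
        image_shape=(256, 256),
        default_split="train",
    )
    batch = next(iter(loader))   # waits on the background batch_loader queue
    print(len(loader))           # samples in this process's shard, per __len__ above
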
{flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/attention.py
@@ -11,6 +11,7 @@ import einops
  import functools
  import math
  from .common import kernel_init
+ import jax.experimental.pallas.ops.tpu.flash_attention

  class EfficientAttention(nn.Module):
      """
{flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/simple_unet.py
@@ -50,7 +50,7 @@ class Unet(nn.Module):
              features=self.feature_depths[0],
              kernel_size=(3, 3),
              strides=(1, 1),
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              dtype=self.dtype,
              precision=self.precision
          )(x)
@@ -65,7 +65,7 @@ class Unet(nn.Module):
                  down_conv_type,
                  name=f"down_{i}_residual_{j}",
                  features=dim_in,
-                 kernel_init=self.kernel_init(1.0),
+                 kernel_init=self.kernel_init(scale=1.0),
                  kernel_size=(3, 3),
                  strides=(1, 1),
                  activation=self.activation,
@@ -85,7 +85,7 @@ class Unet(nn.Module):
                      force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
                      norm_inputs=attention_config.get("norm_inputs", True),
                      explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                     kernel_init=self.kernel_init(1.0),
+                     kernel_init=self.kernel_init(scale=1.0),
                      name=f"down_{i}_attention_{j}")(x, textcontext)
              # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
              downs.append(x)
@@ -108,7 +108,7 @@ class Unet(nn.Module):
              middle_conv_type,
              name=f"middle_res1_{j}",
              features=middle_dim_out,
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              kernel_size=(3, 3),
              strides=(1, 1),
              activation=self.activation,
@@ -129,13 +129,13 @@ class Unet(nn.Module):
                  force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                  norm_inputs=middle_attention.get("norm_inputs", True),
                  explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
-                 kernel_init=self.kernel_init(1.0),
+                 kernel_init=self.kernel_init(scale=1.0),
                  name=f"middle_attention_{j}")(x, textcontext)
          x = ResidualBlock(
              middle_conv_type,
              name=f"middle_res2_{j}",
              features=middle_dim_out,
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              kernel_size=(3, 3),
              strides=(1, 1),
              activation=self.activation,
@@ -157,7 +157,7 @@ class Unet(nn.Module):
                  up_conv_type,# if j == 0 else "separable",
                  name=f"up_{i}_residual_{j}",
                  features=dim_out,
-                 kernel_init=self.kernel_init(1.0),
+                 kernel_init=self.kernel_init(scale=1.0),
                  kernel_size=kernel_size,
                  strides=(1, 1),
                  activation=self.activation,
@@ -177,7 +177,7 @@ class Unet(nn.Module):
                      force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
                      norm_inputs=attention_config.get("norm_inputs", True),
                      explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-                     kernel_init=self.kernel_init(1.0),
+                     kernel_init=self.kernel_init(scale=1.0),
                      name=f"up_{i}_attention_{j}")(x, textcontext)
              # print("Upscaling ", i, x.shape)
              if i != len(feature_depths) - 1:
@@ -196,7 +196,7 @@ class Unet(nn.Module):
              features=self.feature_depths[0],
              kernel_size=(3, 3),
              strides=(1, 1),
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              dtype=self.dtype,
              precision=self.precision
          )(x)
@@ -207,7 +207,7 @@ class Unet(nn.Module):
              conv_type,
              name="final_residual",
              features=self.feature_depths[0],
-             kernel_init=self.kernel_init(1.0),
+             kernel_init=self.kernel_init(scale=1.0),
              kernel_size=(3,3),
              strides=(1, 1),
              activation=self.activation,
@@ -226,7 +226,7 @@ class Unet(nn.Module):
              kernel_size=(3, 3),
              strides=(1, 1),
              # activation=jax.nn.mish
-             kernel_init=self.kernel_init(0.0),
+             kernel_init=self.kernel_init(scale=0.0),
              dtype=self.dtype,
              precision=self.precision
          )(x)
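
Editor's note: all eleven simple_unet.py edits are the same mechanical change, passing scale to self.kernel_init by keyword. This matters if the field is bound the way UViT's is in the next hunk, as partial(kernel_init, scale=1.0): a positional call would then collide with the keyword already bound in the partial. A stand-in sketch of the failure mode (this kernel_init is illustrative, not the library's exact definition):

    from functools import partial

    def kernel_init(scale, dtype=None):   # stand-in for flaxdiff.models.common.kernel_init
        return lambda key, shape, dtype=dtype: scale   # illustrative initializer body

    init = partial(kernel_init, scale=1.0)
    init(scale=0.0)   # fine: a call-time keyword overrides the partial's binding
    # init(0.0)       # TypeError: got multiple values for argument 'scale'
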
{flaxdiff-0.1.35.5 → flaxdiff-0.1.36}/flaxdiff/models/simple_vit.py
@@ -70,7 +70,7 @@ class UViT(nn.Module):
      kernel_init: Callable = partial(kernel_init, scale=1.0)
      add_residualblock_output: bool = False
      norm_inputs: bool = False
-     explicitly_add_residual: bool = False
+     explicitly_add_residual: bool = True

      def setup(self):
          if self.norm_groups > 0: