keras-hub-nightly 0.16.1.dev202410030339__py3-none-any.whl → 0.16.1.dev202410040340__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (24)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +9 -0
  3. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  4. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
  5. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  6. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  7. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  8. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
  9. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
  10. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  11. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  12. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
  13. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  14. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
  15. keras_hub/src/models/task.py +20 -15
  16. keras_hub/src/models/vae/__init__.py +1 -0
  17. keras_hub/src/models/vae/vae_backbone.py +172 -0
  18. keras_hub/src/models/vae/vae_layers.py +740 -0
  19. keras_hub/src/version_utils.py +1 -1
  20. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA +1 -1
  21. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD +23 -14
  22. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  23. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/WHEEL +0 -0
  24. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/top_level.txt +0 -0
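The headline addition is the new deeplab_v3 model family, wired into the public API via the updated keras_hub/api/layers and keras_hub/api/models exports. A minimal sketch of how the new task class would likely be used, assuming the standard KerasHub from_preset workflow; the preset id below is hypothetical and not taken from this diff:

    import numpy as np

    import keras_hub

    # Hypothetical usage of the DeepLabV3 segmenter added in this release.
    # The class name follows the files listed above; the preset id is
    # illustrative only.
    segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
        "deeplab_v3_plus_resnet50_pascalvoc"
    )
    images = np.random.uniform(size=(1, 512, 512, 3)).astype("float32")
    masks = segmenter.predict(images)  # per-pixel class scores

The diff also moves the Stable Diffusion 3 VAE out of stable_diffusion_3/vae_image_decoder.py into a shared keras_hub/src/models/vae package (vae_backbone.py, vae_layers.py); the full removal appears at the end of this diff.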
keras_hub/src/version_utils.py
@@ -1,7 +1,7 @@
 from keras_hub.src.api_export import keras_hub_export
 
 # Unique source of truth for the version number.
-__version__ = "0.16.1.dev202410030339"
+__version__ = "0.16.1.dev202410040340"
 
 
 @keras_hub_export("keras_hub.version")
{keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: keras-hub-nightly
-Version: 0.16.1.dev202410030339
+Version: 0.16.1.dev202410040340
 Summary: Industry-strength Natural Language Processing extensions for Keras.
 Home-page: https://github.com/keras-team/keras-hub
 Author: Keras team
{keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD
@@ -1,15 +1,15 @@
 keras_hub/__init__.py,sha256=QGdXyHgYt6cMUAP1ebxwc6oR86dE0dkMxNy2eOCQtFo,855
 keras_hub/api/__init__.py,sha256=spMxsgqzjpeuC8rY4WP-2kAZ2qwwKRSbFwddXgUjqQE,524
 keras_hub/api/bounding_box/__init__.py,sha256=T8R_X7BPm0et1xaZq8565uJmid7dylsSFSj4V-rGuFQ,1097
-keras_hub/api/layers/__init__.py,sha256=P1Zn4sjTx1OnmlRyX8-QRxSe-2gkvyQ-90BzCjqr3oU,2227
+keras_hub/api/layers/__init__.py,sha256=XImD0tHdnDR1a7q3u-Pw-VRMASi9sDtrV6hr2beVYTw,2331
 keras_hub/api/metrics/__init__.py,sha256=So8Ec-lOcTzn_UUMmAdzDm8RKkPu2dbRUm2px8gpUEI,381
-keras_hub/api/models/__init__.py,sha256=dyancDilnzbHByiTYQNhqfm6JFeZH_DKHl4PZuvWoA0,13994
+keras_hub/api/models/__init__.py,sha256=sMfVpa2N90cG7qjkwSEI_x3uCvZNwQqFbedn5wcUzbE,14311
 keras_hub/api/samplers/__init__.py,sha256=n-_SEXxr2LNUzK2FqVFN7alsrkx1P_HOVTeLZKeGCdE,730
 keras_hub/api/tokenizers/__init__.py,sha256=_f-r_cyUM2fjBB7iO84ThOdqqsAxHNIewJ2EBDlM0cA,2524
 keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version_utils.py,sha256=Q7sWkBqN11QJLqnWmwU9B2XhXWRKLr1vv199Ud-cp4A,222
+keras_hub/src/version_utils.py,sha256=9pXCOZsdoqgw8IovZaLqnG-5LTgH73BlKBZXdOvDpYc,222
 keras_hub/src/bounding_box/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/bounding_box/converters.py,sha256=a5po8DBm87oz2EXfi-0uEZHCMlCJPIb4-MaZIdYx3Dg,17865
 keras_hub/src/bounding_box/formats.py,sha256=YmskOz2BOSat7NaE__J9VfpSNGPJJR0znSzA4lp8MMI,3868
@@ -56,13 +56,13 @@ keras_hub/src/models/feature_pyramid_backbone.py,sha256=clEW-TTQSVJ_5qFNdDF0iABk
 keras_hub/src/models/image_classifier.py,sha256=yt6cjhPfqs8A_eWXBsXdXFzn-aRgH2rVHUq7Zu7CyK8,7804
 keras_hub/src/models/image_classifier_preprocessor.py,sha256=YdewYfMPVHI7gdhbBI-zVcy4NSfg0bhiOHTmGEKoOYI,2668
 keras_hub/src/models/image_segmenter.py,sha256=C1bzIO59pG58iist5GLn_qnlotDpcAVxPV_8a68BkAc,2876
-keras_hub/src/models/image_segmenter_preprocessor.py,sha256=vJoZc1OebQWlqUP_ygCS7P1Pyq1KmmUc-0V_-maDzX4,2658
+keras_hub/src/models/image_segmenter_preprocessor.py,sha256=IMmVJWBc0VZ1-5jLmFmmwQ3q_oQnhIfCE9A6nS1ss8Q,3743
 keras_hub/src/models/masked_lm.py,sha256=uXO_dE_hILlOC9jNr6oK6IHi9IGUqLyNGvr6nMt8Rk0,3576
 keras_hub/src/models/masked_lm_preprocessor.py,sha256=g8vrnyYwqdnSw5xppROM1Gzo_jmMWKYZoQCsKdfrFKk,5656
 keras_hub/src/models/preprocessor.py,sha256=pJodz7KRVncvsC3o4qoKDYWP2J0a8E9CD6oVGYgJzIM,7970
 keras_hub/src/models/seq_2_seq_lm.py,sha256=w0gX-5YZjatfvAJmFAgSHyqS_BLqc8FF8DPLGK8mrgI,1864
 keras_hub/src/models/seq_2_seq_lm_preprocessor.py,sha256=HUHRbWRG5SF1pPpotGzBhXlrMh4pLFxgAoFk05FIrB4,9687
-keras_hub/src/models/task.py,sha256=MfrzIoj3XFaRiNlUg-K6D8l-ylWfpzBjjmSy-guXtG8,13935
+keras_hub/src/models/task.py,sha256=2iapEFHvzyl0ASlH6yzQA2OHSr1jV1V-pLtagHdBncQ,14416
 keras_hub/src/models/text_classifier.py,sha256=VBDvQUHTpJPqKp7A4VAtm35FOmJ3yMo0DW6GdX67xG0,4159
 keras_hub/src/models/text_classifier_preprocessor.py,sha256=EoWp-GHnaLnAKTdAzDmC-soAV92ATF3QozdubdV2WXI,4722
 keras_hub/src/models/text_to_image.py,sha256=N42l1W8YEUBHOdGiT4BQNqzTpgjB2O5dtLU5FbKpMy0,10792
@@ -115,6 +115,13 @@ keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py,sha256=zEMCLy9eCiBEpA_xM
 keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py,sha256=JbdQV1ZHFX4_PcJhemuHQ5YPkJC-XujgbTyjCrdL7nk,8556
 keras_hub/src/models/deberta_v3/disentangled_self_attention.py,sha256=3l7Hy7JfiZDDDFE6uTqSuFjg802kXD9acA7aHKRdzJk,13122
 keras_hub/src/models/deberta_v3/relative_embedding.py,sha256=3WIQ1nWcEhfWF0U9DcKyYz3AAhO3Pmg7ykpzrYe0Jgw,2886
+keras_hub/src/models/deeplab_v3/__init__.py,sha256=FHAUPM4a1DJj4EsNTbYEd1riNq__uHU4eB3t3Z1zgj0,288
+keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py,sha256=WyFhuLcjFPFVuNL09bvW_Jsja7WjEw3zazKCOwbFDTM,7709
+keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py,sha256=mRkH3HdhpV0fCcQcVXEvIX7SNk-bAMb3SAHzgK-FD5c,371
+keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py,sha256=hR9S6lNYamY0EBDBo3e1qTCiwtftmLXrN-UYuzfw5Io,581
+keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py,sha256=qmEiolOOriLAojXB67xXW9IOo717kaCGeDVZJLaGY98,7834
+keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py,sha256=JE-Uv_CXDVnwNkTcy5GtzrIQXvcXSP7fdVpfhmk8qhg,122
+keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py,sha256=tiMDcCFANHMUx3IVtW3r1P_JTazgPPsbW4IktIytKEU,3650
 keras_hub/src/models/densenet/__init__.py,sha256=r7StyamnWeeZxOk9r4ZYNbS_YVhu9YGPyXhNxljvdPg,269
 keras_hub/src/models/densenet/densenet_backbone.py,sha256=dN9lUwKzO3E2HthNV2x54ozeBEQ0ilNs5uYHshFQpT0,6723
 keras_hub/src/models/densenet/densenet_image_classifier.py,sha256=ptuV6PwgoUpmrSPqX7-a85IpWsElwcCv_G5IVkP9E_Q,530
@@ -263,14 +270,13 @@ keras_hub/src/models/sam/sam_presets.py,sha256=AfGUKNOkz0G11OMYqVebXKgEBar1qpIkA
 keras_hub/src/models/sam/sam_prompt_encoder.py,sha256=2foB7900QbzQfZjBo335XYsdjmhOnVT8fKD1CubJNVE,11801
 keras_hub/src/models/sam/sam_transformer.py,sha256=L2bdxdc2RUF1juRZ0F0Z6r0gTva1sUwEdjItJmKKf6w,5730
 keras_hub/src/models/stable_diffusion_3/__init__.py,sha256=ZKYQuaRObyhKq8GVAHmoRvlXp6FpU8ChvutVCHyXKuc,343
-keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py,sha256=9SLbOpAv50q8yv8I6H4DHbsIgwNo8TJmwZfAH8Ew6Zw,2827
+keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py,sha256=vtVhieAv277mAiZj7Kvvqg_Ba7klfQxZVk4PPxNNQ0s,3062
 keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=ntmxjDJtZbHDGVPPAnasVZyoOTp5bbMPhxM30SYmpoQ,25711
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=9rWSG0C23_pwN1pymZbial3GX_UM4tmDLXtB4kTQ04w,22599
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=gfF5ZOhJx03IQTPnb2Nf65i3pNz-fQlhdAJ3DjKHHZ8,658
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py,sha256=XH4osHG9EE1sJpfj7rf0bCqrIHpeXaswFoEojWnE0pw,4419
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=D-U5T6UYKzraHLAgMa-LLcd40ZmX_5rmlybawT4ooHY,21398
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=sBGVRFd-bYxcqydydOB70XpOtpTt6AVrTR3LV-LBFXY,662
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py,sha256=8nP3ejDOd1hqjYXJzbri62PgtclxGydw-8bw-qHIPdc,4414
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py,sha256=TB0KESt5dnFYiS292PbzB0LdiH23AD6aTSTGmQEuzGM,2742
 keras_hub/src/models/stable_diffusion_3/t5_encoder.py,sha256=oV7P1uwCKdGiD93zXq7kmqX0elMZQU4UvBa8wg6P1hs,5113
-keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py,sha256=j4nrvfYhW-4njhhk4PFf-bWQF-EHzplvaT15Q7s5Pb4,10056
 keras_hub/src/models/t5/__init__.py,sha256=OWyoUeDY3v4DnO8Ry02DWV1bNSVGcC89PF9oCftyi1s,233
 keras_hub/src/models/t5/t5_backbone.py,sha256=AtE2VudEUkm7hE3p6JP_CfEAjt4pwgSKOBQ0B0BggQc,10258
 keras_hub/src/models/t5/t5_layer_norm.py,sha256=R8KPHFOq9N3SD013WjtloLWRzaEMNEyY0fbViNEFVXQ,630
@@ -279,6 +285,9 @@ keras_hub/src/models/t5/t5_preprocessor.py,sha256=UVOnCHUJF_MBcOyfR9G9oeRUEoN3Xo
 keras_hub/src/models/t5/t5_presets.py,sha256=95zU4cTNEZMH2yiCLptA9zhu2D4mE1Cay18K91nt7jM,3005
 keras_hub/src/models/t5/t5_tokenizer.py,sha256=pLTu15JeYSpVmy-2600vBc-Mxn_uHyTKts4PI2MxxBM,2517
 keras_hub/src/models/t5/t5_transformer_layer.py,sha256=uDeP84F1x7xJxki5iKe12Zn6eWD_4yVjoFXMuod-a3A,5347
+keras_hub/src/models/vae/__init__.py,sha256=i3UaSW4IJf76O7lSPE1dyxOVjuHx8iAYKivqvUbDHOw,62
+keras_hub/src/models/vae/vae_backbone.py,sha256=aYf1sGteFJ7FyR3X8Ek6QBjAT5GjRtQTK2jXhYVJeM4,6671
+keras_hub/src/models/vae/vae_layers.py,sha256=N83CYM1zgbl1EIjAOs3cFCkJEdxvbXkgM9ghKyljFAg,27752
 keras_hub/src/models/vgg/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/models/vgg/vgg_backbone.py,sha256=QnEDKn5n9bA9p3nvt5fBHnAssvnLxR0qv-oB372Ts0U,3702
 keras_hub/src/models/vgg/vgg_image_classifier.py,sha256=Dtq_HIJP6fHe8m7ZVLVn8IbHEsVMFWLvWMmn8TU1ntw,6600
@@ -350,7 +359,7 @@ keras_hub/src/utils/transformers/convert_mistral.py,sha256=kVhN9h1ZFVhwkNW8p3wnS
 keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYumf66hIid07k5NLqoeWAJgPnaLs,10649
 keras_hub/src/utils/transformers/preset_loader.py,sha256=GS44hZUuGQCtzsyn8z44ZpHdftd3DFemwV2hx2bQa-U,2738
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=rPK-Uw1CG0DX0d_UAD-r2cG9fw8GI8bvAlrcXfQ9g4c,3323
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/METADATA,sha256=tLxESmpHL96pjwqK1gteBF1IdJ_CKtgBOvGEIG9gfyU,7458
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/RECORD,,
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/METADATA,sha256=BuGQEiANSPxHuXnmW4K_EEYaGBE2ENMMSMARmCthK70,7458
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/RECORD,,
keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py
@@ -1,320 +0,0 @@
-import math
-
-from keras import layers
-from keras import ops
-
-from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
-class VAEAttention(layers.Layer):
-    def __init__(self, filters, groups=32, data_format=None, **kwargs):
-        super().__init__(**kwargs)
-        self.filters = filters
-        self.data_format = standardize_data_format(data_format)
-        gn_axis = -1 if self.data_format == "channels_last" else 1
-
-        self.group_norm = layers.GroupNormalization(
-            groups=groups,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="group_norm",
-        )
-        self.query_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="query_conv2d",
-        )
-        self.key_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="key_conv2d",
-        )
-        self.value_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="value_conv2d",
-        )
-        self.softmax = layers.Softmax(dtype="float32")
-        self.output_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="output_conv2d",
-        )
-
-        self.groups = groups
-        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
-
-    def build(self, input_shape):
-        self.group_norm.build(input_shape)
-        self.query_conv2d.build(input_shape)
-        self.key_conv2d.build(input_shape)
-        self.value_conv2d.build(input_shape)
-        self.output_conv2d.build(input_shape)
-
-    def call(self, inputs, training=None):
-        x = self.group_norm(inputs)
-        query = self.query_conv2d(x)
-        key = self.key_conv2d(x)
-        value = self.value_conv2d(x)
-
-        if self.data_format == "channels_first":
-            query = ops.transpose(query, (0, 2, 3, 1))
-            key = ops.transpose(key, (0, 2, 3, 1))
-            value = ops.transpose(value, (0, 2, 3, 1))
-        shape = ops.shape(inputs)
-        b = shape[0]
-        query = ops.reshape(query, (b, -1, self.filters))
-        key = ops.reshape(key, (b, -1, self.filters))
-        value = ops.reshape(value, (b, -1, self.filters))
-
-        # Compute attention.
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
-        )
-        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
-        attention_scores = ops.einsum("abc,adc->abd", query, key)
-        attention_scores = ops.cast(
-            self.softmax(attention_scores), self.compute_dtype
-        )
-        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
-        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
-        x = ops.reshape(attention_output, shape)
-
-        x = self.output_conv2d(x)
-        if self.data_format == "channels_first":
-            x = ops.transpose(x, (0, 3, 1, 2))
-        x = ops.add(x, inputs)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "filters": self.filters,
-                "groups": self.groups,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
-
-
-def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
-    data_format = standardize_data_format(data_format)
-    gn_axis = -1 if data_format == "channels_last" else 1
-    input_filters = x.shape[gn_axis]
-
-    residual = x
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm1",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv1",
-    )(x)
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm2",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv2",
-    )(x)
-    if input_filters != filters:
-        residual = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=data_format,
-            dtype=dtype,
-            name=f"{name}_residual_projection",
-        )(residual)
-    x = layers.Add(dtype=dtype)([residual, x])
-    return x
-
-
-class VAEImageDecoder(Backbone):
-    """Decoder for the VAE model used in Stable Diffusion 3.
-
-    Args:
-        stackwise_num_filters: list of ints. The number of filters for each
-            stack.
-        stackwise_num_blocks: list of ints. The number of blocks for each stack.
-        output_channels: int. The number of channels in the output.
-        latent_shape: tuple. The shape of the latent image.
-        data_format: `None` or str. If specified, either `"channels_last"` or
-            `"channels_first"`. The ordering of the dimensions in the
-            inputs. `"channels_last"` corresponds to inputs with shape
-            `(batch_size, height, width, channels)`
-            while `"channels_first"` corresponds to inputs with shape
-            `(batch_size, channels, height, width)`. It defaults to the
-            `image_data_format` value found in your Keras config file at
-            `~/.keras/keras.json`. If you never set it, then it will be
-            `"channels_last"`.
-        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
-            to use for the model's computations and weights.
-    """
-
-    def __init__(
-        self,
-        stackwise_num_filters,
-        stackwise_num_blocks,
-        output_channels=3,
-        latent_shape=(None, None, 16),
-        data_format=None,
-        dtype=None,
-        **kwargs,
-    ):
-        data_format = standardize_data_format(data_format)
-        gn_axis = -1 if data_format == "channels_last" else 1
-
-        # === Functional Model ===
-        latent_inputs = layers.Input(shape=latent_shape)
-
-        x = layers.Conv2D(
-            stackwise_num_filters[0],
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="input_projection",
-        )(latent_inputs)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block0",
-        )
-        x = VAEAttention(
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_attention",
-        )(x)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block1",
-        )
-
-        # Stacks.
-        for i, filters in enumerate(stackwise_num_filters):
-            for j in range(stackwise_num_blocks[i]):
-                x = apply_resnet_block(
-                    x,
-                    filters,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"block{i}_{j}",
-                )
-            if i != len(stackwise_num_filters) - 1:
-                # No upsampling in the last block.
-                x = layers.UpSampling2D(
-                    2,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}",
-                )(x)
-                x = layers.Conv2D(
-                    filters,
-                    3,
-                    1,
-                    padding="same",
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}_conv",
-                )(x)
-
-        # Output block.
-        x = layers.GroupNormalization(
-            groups=32,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="output_norm",
-        )(x)
-        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
-        image_outputs = layers.Conv2D(
-            output_channels,
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="output_projection",
-        )(x)
-        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
-
-        # === Config ===
-        self.stackwise_num_filters = stackwise_num_filters
-        self.stackwise_num_blocks = stackwise_num_blocks
-        self.output_channels = output_channels
-        self.latent_shape = latent_shape
-
-    @property
-    def scaling_factor(self):
-        """The scaling factor for the latent space.
-
-        This is used to scale the latent space to have unit variance when
-        training the diffusion model.
-        """
-        return 1.5305
-
-    @property
-    def shift_factor(self):
-        """The shift factor for the latent space.
-
-        This is used to shift the latent space to have zero mean when
-        training the diffusion model.
-        """
-        return 0.0609
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "stackwise_num_filters": self.stackwise_num_filters,
-                "stackwise_num_blocks": self.stackwise_num_blocks,
-                "output_channels": self.output_channels,
-                "latent_shape": self.latent_shape,
-            }
-        )
-        return config
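
The removed decoder's scaling_factor (1.5305) and shift_factor (0.0609) are the Stable Diffusion 3 latent normalization constants; per the RECORD changes above, the decoder's role is taken over by the new shared keras_hub/src/models/vae package. A minimal sketch of how such factors are conventionally applied around a VAE, with illustrative helper names that are not part of this diff:

    import numpy as np

    # Latent normalization constants from the removed decoder.
    SCALING_FACTOR = 1.5305
    SHIFT_FACTOR = 0.0609

    def normalize_latents(latents):
        # Map raw VAE latents to roughly zero mean / unit variance for
        # the diffusion model: z_norm = (z - shift) * scale.
        return (latents - SHIFT_FACTOR) * SCALING_FACTOR

    def denormalize_latents(latents):
        # Inverse mapping applied before decoding latents back to pixels:
        # z = z_norm / scale + shift.
        return latents / SCALING_FACTOR + SHIFT_FACTOR

    z = np.random.normal(size=(1, 64, 64, 16)).astype("float32")
    assert np.allclose(denormalize_latents(normalize_latents(z)), z, atol=1e-5)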