keras-hub-nightly 0.16.1.dev202410030339__py3-none-any.whl → 0.16.1.dev202410040340__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +9 -0
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
- keras_hub/src/models/task.py +20 -15
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +172 -0
- keras_hub/src/models/vae/vae_layers.py +740 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD +23 -14
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/top_level.txt +0 -0
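This nightly ships two structural additions alongside the usual version stamp: a new DeepLabV3 segmentation family (backbone, layers, presets, image converter, preprocessor, segmenter task) and a standalone VAE package under keras_hub/src/models/vae that replaces the Stable Diffusion 3-specific vae_image_decoder.py. A minimal sketch of how the new segmenter task is meant to be called — the class name follows keras-hub's export conventions, and the preset name is an assumption for illustration, not taken from this diff:

import numpy as np
import keras_hub

# Hypothetical preset name; the real names live in deeplab_v3_presets.py.
segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
    "deeplab_v3_plus_resnet50_pascalvoc"
)
# Per-pixel class predictions for a batch of RGB images.
images = np.random.uniform(0, 255, size=(1, 512, 512, 3)).astype("float32")
masks = segmenter.predict(images)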
keras_hub/src/version_utils.py
CHANGED
-__version__ = "0.16.1.dev202410030339"
+__version__ = "0.16.1.dev202410040340"

{keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: keras-hub-nightly
-Version: 0.16.1.dev202410030339
+Version: 0.16.1.dev202410040340
 Summary: Industry-strength Natural Language Processing extensions for Keras.
 Home-page: https://github.com/keras-team/keras-hub
 Author: Keras team
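The only functional change in version_utils.py and METADATA is the nightly datestamp. A quick way to confirm which build is installed — assuming the keras_hub.version() export that version_utils.py provides:

import keras_hub

# Should print 0.16.1.dev202410040340 after upgrading to this wheel.
print(keras_hub.version())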
{keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 keras_hub/__init__.py,sha256=QGdXyHgYt6cMUAP1ebxwc6oR86dE0dkMxNy2eOCQtFo,855
 keras_hub/api/__init__.py,sha256=spMxsgqzjpeuC8rY4WP-2kAZ2qwwKRSbFwddXgUjqQE,524
 keras_hub/api/bounding_box/__init__.py,sha256=T8R_X7BPm0et1xaZq8565uJmid7dylsSFSj4V-rGuFQ,1097
-keras_hub/api/layers/__init__.py,sha256=
+keras_hub/api/layers/__init__.py,sha256=XImD0tHdnDR1a7q3u-Pw-VRMASi9sDtrV6hr2beVYTw,2331
 keras_hub/api/metrics/__init__.py,sha256=So8Ec-lOcTzn_UUMmAdzDm8RKkPu2dbRUm2px8gpUEI,381
-keras_hub/api/models/__init__.py,sha256=
+keras_hub/api/models/__init__.py,sha256=sMfVpa2N90cG7qjkwSEI_x3uCvZNwQqFbedn5wcUzbE,14311
 keras_hub/api/samplers/__init__.py,sha256=n-_SEXxr2LNUzK2FqVFN7alsrkx1P_HOVTeLZKeGCdE,730
 keras_hub/api/tokenizers/__init__.py,sha256=_f-r_cyUM2fjBB7iO84ThOdqqsAxHNIewJ2EBDlM0cA,2524
 keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version_utils.py,sha256=
+keras_hub/src/version_utils.py,sha256=9pXCOZsdoqgw8IovZaLqnG-5LTgH73BlKBZXdOvDpYc,222
 keras_hub/src/bounding_box/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/bounding_box/converters.py,sha256=a5po8DBm87oz2EXfi-0uEZHCMlCJPIb4-MaZIdYx3Dg,17865
 keras_hub/src/bounding_box/formats.py,sha256=YmskOz2BOSat7NaE__J9VfpSNGPJJR0znSzA4lp8MMI,3868
@@ -56,13 +56,13 @@ keras_hub/src/models/feature_pyramid_backbone.py,sha256=clEW-TTQSVJ_5qFNdDF0iABk
 keras_hub/src/models/image_classifier.py,sha256=yt6cjhPfqs8A_eWXBsXdXFzn-aRgH2rVHUq7Zu7CyK8,7804
 keras_hub/src/models/image_classifier_preprocessor.py,sha256=YdewYfMPVHI7gdhbBI-zVcy4NSfg0bhiOHTmGEKoOYI,2668
 keras_hub/src/models/image_segmenter.py,sha256=C1bzIO59pG58iist5GLn_qnlotDpcAVxPV_8a68BkAc,2876
-keras_hub/src/models/image_segmenter_preprocessor.py,sha256=
+keras_hub/src/models/image_segmenter_preprocessor.py,sha256=IMmVJWBc0VZ1-5jLmFmmwQ3q_oQnhIfCE9A6nS1ss8Q,3743
 keras_hub/src/models/masked_lm.py,sha256=uXO_dE_hILlOC9jNr6oK6IHi9IGUqLyNGvr6nMt8Rk0,3576
 keras_hub/src/models/masked_lm_preprocessor.py,sha256=g8vrnyYwqdnSw5xppROM1Gzo_jmMWKYZoQCsKdfrFKk,5656
 keras_hub/src/models/preprocessor.py,sha256=pJodz7KRVncvsC3o4qoKDYWP2J0a8E9CD6oVGYgJzIM,7970
 keras_hub/src/models/seq_2_seq_lm.py,sha256=w0gX-5YZjatfvAJmFAgSHyqS_BLqc8FF8DPLGK8mrgI,1864
 keras_hub/src/models/seq_2_seq_lm_preprocessor.py,sha256=HUHRbWRG5SF1pPpotGzBhXlrMh4pLFxgAoFk05FIrB4,9687
-keras_hub/src/models/task.py,sha256=
+keras_hub/src/models/task.py,sha256=2iapEFHvzyl0ASlH6yzQA2OHSr1jV1V-pLtagHdBncQ,14416
 keras_hub/src/models/text_classifier.py,sha256=VBDvQUHTpJPqKp7A4VAtm35FOmJ3yMo0DW6GdX67xG0,4159
 keras_hub/src/models/text_classifier_preprocessor.py,sha256=EoWp-GHnaLnAKTdAzDmC-soAV92ATF3QozdubdV2WXI,4722
 keras_hub/src/models/text_to_image.py,sha256=N42l1W8YEUBHOdGiT4BQNqzTpgjB2O5dtLU5FbKpMy0,10792
@@ -115,6 +115,13 @@ keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py,sha256=zEMCLy9eCiBEpA_xM
 keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py,sha256=JbdQV1ZHFX4_PcJhemuHQ5YPkJC-XujgbTyjCrdL7nk,8556
 keras_hub/src/models/deberta_v3/disentangled_self_attention.py,sha256=3l7Hy7JfiZDDDFE6uTqSuFjg802kXD9acA7aHKRdzJk,13122
 keras_hub/src/models/deberta_v3/relative_embedding.py,sha256=3WIQ1nWcEhfWF0U9DcKyYz3AAhO3Pmg7ykpzrYe0Jgw,2886
+keras_hub/src/models/deeplab_v3/__init__.py,sha256=FHAUPM4a1DJj4EsNTbYEd1riNq__uHU4eB3t3Z1zgj0,288
+keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py,sha256=WyFhuLcjFPFVuNL09bvW_Jsja7WjEw3zazKCOwbFDTM,7709
+keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py,sha256=mRkH3HdhpV0fCcQcVXEvIX7SNk-bAMb3SAHzgK-FD5c,371
+keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py,sha256=hR9S6lNYamY0EBDBo3e1qTCiwtftmLXrN-UYuzfw5Io,581
+keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py,sha256=qmEiolOOriLAojXB67xXW9IOo717kaCGeDVZJLaGY98,7834
+keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py,sha256=JE-Uv_CXDVnwNkTcy5GtzrIQXvcXSP7fdVpfhmk8qhg,122
+keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py,sha256=tiMDcCFANHMUx3IVtW3r1P_JTazgPPsbW4IktIytKEU,3650
 keras_hub/src/models/densenet/__init__.py,sha256=r7StyamnWeeZxOk9r4ZYNbS_YVhu9YGPyXhNxljvdPg,269
 keras_hub/src/models/densenet/densenet_backbone.py,sha256=dN9lUwKzO3E2HthNV2x54ozeBEQ0ilNs5uYHshFQpT0,6723
 keras_hub/src/models/densenet/densenet_image_classifier.py,sha256=ptuV6PwgoUpmrSPqX7-a85IpWsElwcCv_G5IVkP9E_Q,530
@@ -263,14 +270,13 @@ keras_hub/src/models/sam/sam_presets.py,sha256=AfGUKNOkz0G11OMYqVebXKgEBar1qpIkA
 keras_hub/src/models/sam/sam_prompt_encoder.py,sha256=2foB7900QbzQfZjBo335XYsdjmhOnVT8fKD1CubJNVE,11801
 keras_hub/src/models/sam/sam_transformer.py,sha256=L2bdxdc2RUF1juRZ0F0Z6r0gTva1sUwEdjItJmKKf6w,5730
 keras_hub/src/models/stable_diffusion_3/__init__.py,sha256=ZKYQuaRObyhKq8GVAHmoRvlXp6FpU8ChvutVCHyXKuc,343
-keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py,sha256=
+keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py,sha256=vtVhieAv277mAiZj7Kvvqg_Ba7klfQxZVk4PPxNNQ0s,3062
 keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=ntmxjDJtZbHDGVPPAnasVZyoOTp5bbMPhxM30SYmpoQ,25711
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py,sha256=
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=D-U5T6UYKzraHLAgMa-LLcd40ZmX_5rmlybawT4ooHY,21398
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=sBGVRFd-bYxcqydydOB70XpOtpTt6AVrTR3LV-LBFXY,662
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py,sha256=8nP3ejDOd1hqjYXJzbri62PgtclxGydw-8bw-qHIPdc,4414
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py,sha256=TB0KESt5dnFYiS292PbzB0LdiH23AD6aTSTGmQEuzGM,2742
 keras_hub/src/models/stable_diffusion_3/t5_encoder.py,sha256=oV7P1uwCKdGiD93zXq7kmqX0elMZQU4UvBa8wg6P1hs,5113
-keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py,sha256=j4nrvfYhW-4njhhk4PFf-bWQF-EHzplvaT15Q7s5Pb4,10056
 keras_hub/src/models/t5/__init__.py,sha256=OWyoUeDY3v4DnO8Ry02DWV1bNSVGcC89PF9oCftyi1s,233
 keras_hub/src/models/t5/t5_backbone.py,sha256=AtE2VudEUkm7hE3p6JP_CfEAjt4pwgSKOBQ0B0BggQc,10258
 keras_hub/src/models/t5/t5_layer_norm.py,sha256=R8KPHFOq9N3SD013WjtloLWRzaEMNEyY0fbViNEFVXQ,630
@@ -279,6 +285,9 @@ keras_hub/src/models/t5/t5_preprocessor.py,sha256=UVOnCHUJF_MBcOyfR9G9oeRUEoN3Xo
 keras_hub/src/models/t5/t5_presets.py,sha256=95zU4cTNEZMH2yiCLptA9zhu2D4mE1Cay18K91nt7jM,3005
 keras_hub/src/models/t5/t5_tokenizer.py,sha256=pLTu15JeYSpVmy-2600vBc-Mxn_uHyTKts4PI2MxxBM,2517
 keras_hub/src/models/t5/t5_transformer_layer.py,sha256=uDeP84F1x7xJxki5iKe12Zn6eWD_4yVjoFXMuod-a3A,5347
+keras_hub/src/models/vae/__init__.py,sha256=i3UaSW4IJf76O7lSPE1dyxOVjuHx8iAYKivqvUbDHOw,62
+keras_hub/src/models/vae/vae_backbone.py,sha256=aYf1sGteFJ7FyR3X8Ek6QBjAT5GjRtQTK2jXhYVJeM4,6671
+keras_hub/src/models/vae/vae_layers.py,sha256=N83CYM1zgbl1EIjAOs3cFCkJEdxvbXkgM9ghKyljFAg,27752
 keras_hub/src/models/vgg/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/models/vgg/vgg_backbone.py,sha256=QnEDKn5n9bA9p3nvt5fBHnAssvnLxR0qv-oB372Ts0U,3702
 keras_hub/src/models/vgg/vgg_image_classifier.py,sha256=Dtq_HIJP6fHe8m7ZVLVn8IbHEsVMFWLvWMmn8TU1ntw,6600
@@ -350,7 +359,7 @@ keras_hub/src/utils/transformers/convert_mistral.py,sha256=kVhN9h1ZFVhwkNW8p3wnS
 keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYumf66hIid07k5NLqoeWAJgPnaLs,10649
 keras_hub/src/utils/transformers/preset_loader.py,sha256=GS44hZUuGQCtzsyn8z44ZpHdftd3DFemwV2hx2bQa-U,2738
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=rPK-Uw1CG0DX0d_UAD-r2cG9fw8GI8bvAlrcXfQ9g4c,3323
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/METADATA,sha256=
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/WHEEL,sha256=
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/top_level.txt,sha256=
-keras_hub_nightly-0.16.1.dev202410030339.dist-info/RECORD,,
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/METADATA,sha256=BuGQEiANSPxHuXnmW4K_EEYaGBE2ENMMSMARmCthK70,7458
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.16.1.dev202410040340.dist-info/RECORD,,
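Each RECORD row has the form path,sha256=<digest>,size, where the digest is the urlsafe-base64-encoded sha256 of the file with padding stripped, per the wheel spec (PEP 427). A small sketch for verifying any row above against an installed file:

import base64
import hashlib

def record_digest(path):
    # RECORD stores sha256 digests urlsafe-base64-encoded without padding,
    # followed by the file size in bytes.
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return digest.rstrip(b"=").decode("ascii"), len(data)

# Compare against the version_utils.py row above (expected size 222).
print(record_digest("keras_hub/src/version_utils.py"))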
keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py
DELETED
@@ -1,320 +0,0 @@
-import math
-
-from keras import layers
-from keras import ops
-
-from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
-class VAEAttention(layers.Layer):
-    def __init__(self, filters, groups=32, data_format=None, **kwargs):
-        super().__init__(**kwargs)
-        self.filters = filters
-        self.data_format = standardize_data_format(data_format)
-        gn_axis = -1 if self.data_format == "channels_last" else 1
-
-        self.group_norm = layers.GroupNormalization(
-            groups=groups,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="group_norm",
-        )
-        self.query_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="query_conv2d",
-        )
-        self.key_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="key_conv2d",
-        )
-        self.value_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="value_conv2d",
-        )
-        self.softmax = layers.Softmax(dtype="float32")
-        self.output_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="output_conv2d",
-        )
-
-        self.groups = groups
-        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
-
-    def build(self, input_shape):
-        self.group_norm.build(input_shape)
-        self.query_conv2d.build(input_shape)
-        self.key_conv2d.build(input_shape)
-        self.value_conv2d.build(input_shape)
-        self.output_conv2d.build(input_shape)
-
-    def call(self, inputs, training=None):
-        x = self.group_norm(inputs)
-        query = self.query_conv2d(x)
-        key = self.key_conv2d(x)
-        value = self.value_conv2d(x)
-
-        if self.data_format == "channels_first":
-            query = ops.transpose(query, (0, 2, 3, 1))
-            key = ops.transpose(key, (0, 2, 3, 1))
-            value = ops.transpose(value, (0, 2, 3, 1))
-        shape = ops.shape(inputs)
-        b = shape[0]
-        query = ops.reshape(query, (b, -1, self.filters))
-        key = ops.reshape(key, (b, -1, self.filters))
-        value = ops.reshape(value, (b, -1, self.filters))
-
-        # Compute attention.
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
-        )
-        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
-        attention_scores = ops.einsum("abc,adc->abd", query, key)
-        attention_scores = ops.cast(
-            self.softmax(attention_scores), self.compute_dtype
-        )
-        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
-        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
-        x = ops.reshape(attention_output, shape)
-
-        x = self.output_conv2d(x)
-        if self.data_format == "channels_first":
-            x = ops.transpose(x, (0, 3, 1, 2))
-        x = ops.add(x, inputs)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "filters": self.filters,
-                "groups": self.groups,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
-
-
-def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
-    data_format = standardize_data_format(data_format)
-    gn_axis = -1 if data_format == "channels_last" else 1
-    input_filters = x.shape[gn_axis]
-
-    residual = x
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm1",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv1",
-    )(x)
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm2",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv2",
-    )(x)
-    if input_filters != filters:
-        residual = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=data_format,
-            dtype=dtype,
-            name=f"{name}_residual_projection",
-        )(residual)
-    x = layers.Add(dtype=dtype)([residual, x])
-    return x
-
-
-class VAEImageDecoder(Backbone):
-    """Decoder for the VAE model used in Stable Diffusion 3.
-
-    Args:
-        stackwise_num_filters: list of ints. The number of filters for each
-            stack.
-        stackwise_num_blocks: list of ints. The number of blocks for each stack.
-        output_channels: int. The number of channels in the output.
-        latent_shape: tuple. The shape of the latent image.
-        data_format: `None` or str. If specified, either `"channels_last"` or
-            `"channels_first"`. The ordering of the dimensions in the
-            inputs. `"channels_last"` corresponds to inputs with shape
-            `(batch_size, height, width, channels)`
-            while `"channels_first"` corresponds to inputs with shape
-            `(batch_size, channels, height, width)`. It defaults to the
-            `image_data_format` value found in your Keras config file at
-            `~/.keras/keras.json`. If you never set it, then it will be
-            `"channels_last"`.
-        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
-            to use for the model's computations and weights.
-    """
-
-    def __init__(
-        self,
-        stackwise_num_filters,
-        stackwise_num_blocks,
-        output_channels=3,
-        latent_shape=(None, None, 16),
-        data_format=None,
-        dtype=None,
-        **kwargs,
-    ):
-        data_format = standardize_data_format(data_format)
-        gn_axis = -1 if data_format == "channels_last" else 1
-
-        # === Functional Model ===
-        latent_inputs = layers.Input(shape=latent_shape)
-
-        x = layers.Conv2D(
-            stackwise_num_filters[0],
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="input_projection",
-        )(latent_inputs)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block0",
-        )
-        x = VAEAttention(
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_attention",
-        )(x)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block1",
-        )
-
-        # Stacks.
-        for i, filters in enumerate(stackwise_num_filters):
-            for j in range(stackwise_num_blocks[i]):
-                x = apply_resnet_block(
-                    x,
-                    filters,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"block{i}_{j}",
-                )
-            if i != len(stackwise_num_filters) - 1:
-                # No upsampling in the last block.
-                x = layers.UpSampling2D(
-                    2,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}",
-                )(x)
-                x = layers.Conv2D(
-                    filters,
-                    3,
-                    1,
-                    padding="same",
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}_conv",
-                )(x)
-
-        # Output block.
-        x = layers.GroupNormalization(
-            groups=32,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="output_norm",
-        )(x)
-        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
-        image_outputs = layers.Conv2D(
-            output_channels,
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="output_projection",
-        )(x)
-        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
-
-        # === Config ===
-        self.stackwise_num_filters = stackwise_num_filters
-        self.stackwise_num_blocks = stackwise_num_blocks
-        self.output_channels = output_channels
-        self.latent_shape = latent_shape
-
-    @property
-    def scaling_factor(self):
-        """The scaling factor for the latent space.
-
-        This is used to scale the latent space to have unit variance when
-        training the diffusion model.
-        """
-        return 1.5305
-
-    @property
-    def shift_factor(self):
-        """The shift factor for the latent space.
-
-        This is used to shift the latent space to have zero mean when
-        training the diffusion model.
-        """
-        return 0.0609
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "stackwise_num_filters": self.stackwise_num_filters,
-                "stackwise_num_blocks": self.stackwise_num_blocks,
-                "output_channels": self.output_channels,
-                "image_shape": self.latent_shape,
-            }
-        )
-        return config
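The scaling_factor and shift_factor properties of the deleted decoder record how SD3 latents were normalized for diffusion training (unit variance, zero mean), so decoding has to undo that normalization before the latents enter the network. A minimal sketch of that inverse transform with an illustrative latent tensor — whether the new vae_backbone.py carries these constants forward is an assumption, not shown in this diff:

import numpy as np

# Invert the training-time normalization described in the docstrings above:
# divide by scaling_factor, then add shift_factor back.
scaling_factor, shift_factor = 1.5305, 0.0609
latents = np.random.normal(size=(1, 64, 64, 16)).astype("float32")
denormalized = latents / scaling_factor + shift_factor
# `denormalized` is what the decoder's input_projection layer consumes.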