keras-hub-nightly 0.16.1.dev202410020340-py3-none-any.whl → 0.16.1.dev202410040340-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- keras_hub/api/layers/__init__.py +3 -3
- keras_hub/api/models/__init__.py +10 -1
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +164 -34
- keras_hub/src/models/backbone.py +3 -9
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +0 -128
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +3 -3
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -109
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/preprocessor.py +3 -5
- keras_hub/src/models/resnet/resnet_backbone.py +1 -11
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_presets.py +3 -3
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
- keras_hub/src/models/task.py +39 -36
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +172 -0
- keras_hub/src/models/vae/vae_layers.py +740 -0
- keras_hub/src/models/vgg/vgg_backbone.py +1 -20
- keras_hub/src/models/vgg/vgg_image_classifier.py +108 -29
- keras_hub/src/tokenizers/tokenizer.py +3 -6
- keras_hub/src/utils/preset_utils.py +103 -61
- keras_hub/src/utils/timm/preset_loader.py +8 -9
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD +49 -41
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py (deleted)
@@ -1,320 +0,0 @@
-import math
-
-from keras import layers
-from keras import ops
-
-from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
-class VAEAttention(layers.Layer):
-    def __init__(self, filters, groups=32, data_format=None, **kwargs):
-        super().__init__(**kwargs)
-        self.filters = filters
-        self.data_format = standardize_data_format(data_format)
-        gn_axis = -1 if self.data_format == "channels_last" else 1
-
-        self.group_norm = layers.GroupNormalization(
-            groups=groups,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="group_norm",
-        )
-        self.query_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="query_conv2d",
-        )
-        self.key_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="key_conv2d",
-        )
-        self.value_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="value_conv2d",
-        )
-        self.softmax = layers.Softmax(dtype="float32")
-        self.output_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="output_conv2d",
-        )
-
-        self.groups = groups
-        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
-
-    def build(self, input_shape):
-        self.group_norm.build(input_shape)
-        self.query_conv2d.build(input_shape)
-        self.key_conv2d.build(input_shape)
-        self.value_conv2d.build(input_shape)
-        self.output_conv2d.build(input_shape)
-
-    def call(self, inputs, training=None):
-        x = self.group_norm(inputs)
-        query = self.query_conv2d(x)
-        key = self.key_conv2d(x)
-        value = self.value_conv2d(x)
-
-        if self.data_format == "channels_first":
-            query = ops.transpose(query, (0, 2, 3, 1))
-            key = ops.transpose(key, (0, 2, 3, 1))
-            value = ops.transpose(value, (0, 2, 3, 1))
-        shape = ops.shape(inputs)
-        b = shape[0]
-        query = ops.reshape(query, (b, -1, self.filters))
-        key = ops.reshape(key, (b, -1, self.filters))
-        value = ops.reshape(value, (b, -1, self.filters))
-
-        # Compute attention.
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
-        )
-        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
-        attention_scores = ops.einsum("abc,adc->abd", query, key)
-        attention_scores = ops.cast(
-            self.softmax(attention_scores), self.compute_dtype
-        )
-        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
-        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
-        x = ops.reshape(attention_output, shape)
-
-        x = self.output_conv2d(x)
-        if self.data_format == "channels_first":
-            x = ops.transpose(x, (0, 3, 1, 2))
-        x = ops.add(x, inputs)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "filters": self.filters,
-                "groups": self.groups,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
-
-
-def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
-    data_format = standardize_data_format(data_format)
-    gn_axis = -1 if data_format == "channels_last" else 1
-    input_filters = x.shape[gn_axis]
-
-    residual = x
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm1",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv1",
-    )(x)
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm2",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv2",
-    )(x)
-    if input_filters != filters:
-        residual = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=data_format,
-            dtype=dtype,
-            name=f"{name}_residual_projection",
-        )(residual)
-    x = layers.Add(dtype=dtype)([residual, x])
-    return x
-
-
-class VAEImageDecoder(Backbone):
-    """Decoder for the VAE model used in Stable Diffusion 3.
-
-    Args:
-        stackwise_num_filters: list of ints. The number of filters for each
-            stack.
-        stackwise_num_blocks: list of ints. The number of blocks for each stack.
-        output_channels: int. The number of channels in the output.
-        latent_shape: tuple. The shape of the latent image.
-        data_format: `None` or str. If specified, either `"channels_last"` or
-            `"channels_first"`. The ordering of the dimensions in the
-            inputs. `"channels_last"` corresponds to inputs with shape
-            `(batch_size, height, width, channels)`
-            while `"channels_first"` corresponds to inputs with shape
-            `(batch_size, channels, height, width)`. It defaults to the
-            `image_data_format` value found in your Keras config file at
-            `~/.keras/keras.json`. If you never set it, then it will be
-            `"channels_last"`.
-        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
-            to use for the model's computations and weights.
-    """
-
-    def __init__(
-        self,
-        stackwise_num_filters,
-        stackwise_num_blocks,
-        output_channels=3,
-        latent_shape=(None, None, 16),
-        data_format=None,
-        dtype=None,
-        **kwargs,
-    ):
-        data_format = standardize_data_format(data_format)
-        gn_axis = -1 if data_format == "channels_last" else 1
-
-        # === Functional Model ===
-        latent_inputs = layers.Input(shape=latent_shape)
-
-        x = layers.Conv2D(
-            stackwise_num_filters[0],
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="input_projection",
-        )(latent_inputs)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block0",
-        )
-        x = VAEAttention(
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_attention",
-        )(x)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block1",
-        )
-
-        # Stacks.
-        for i, filters in enumerate(stackwise_num_filters):
-            for j in range(stackwise_num_blocks[i]):
-                x = apply_resnet_block(
-                    x,
-                    filters,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"block{i}_{j}",
-                )
-            if i != len(stackwise_num_filters) - 1:
-                # No upsampling in the last block.
-                x = layers.UpSampling2D(
-                    2,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}",
-                )(x)
-                x = layers.Conv2D(
-                    filters,
-                    3,
-                    1,
-                    padding="same",
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}_conv",
-                )(x)
-
-        # Output block.
-        x = layers.GroupNormalization(
-            groups=32,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="output_norm",
-        )(x)
-        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
-        image_outputs = layers.Conv2D(
-            output_channels,
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="output_projection",
-        )(x)
-        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
-
-        # === Config ===
-        self.stackwise_num_filters = stackwise_num_filters
-        self.stackwise_num_blocks = stackwise_num_blocks
-        self.output_channels = output_channels
-        self.latent_shape = latent_shape
-
-    @property
-    def scaling_factor(self):
-        """The scaling factor for the latent space.
-
-        This is used to scale the latent space to have unit variance when
-        training the diffusion model.
-        """
-        return 1.5305
-
-    @property
-    def shift_factor(self):
-        """The shift factor for the latent space.
-
-        This is used to shift the latent space to have zero mean when
-        training the diffusion model.
-        """
-        return 0.0609
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "stackwise_num_filters": self.stackwise_num_filters,
-                "stackwise_num_blocks": self.stackwise_num_blocks,
-                "output_channels": self.output_channels,
-                "latent_shape": self.latent_shape,
-            }
-        )
-        return config
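The deleted VAEAttention layer above is plain single-head self-attention over flattened spatial positions. The following standalone NumPy sketch reproduces only its reshape-and-einsum bookkeeping, with the 1x1 query/key/value convolutions replaced by random per-pixel projections and all sizes made up for illustration:

import numpy as np

b, h, w, c = 2, 8, 8, 64  # hypothetical batch, height, width, filters
x = np.random.randn(b, h, w, c).astype("float32")

# A 1x1 convolution is a per-pixel matmul, so a (c, c) matrix stands in
# for each of the query/key/value projections.
wq, wk, wv = [np.random.randn(c, c).astype("float32") for _ in range(3)]
query = (x @ wq).reshape(b, -1, c)  # [B, H*W, C]
key = (x @ wk).reshape(b, -1, c)
value = (x @ wv).reshape(b, -1, c)

# Scale queries by 1 / sqrt(filters), matching _inverse_sqrt_filters.
query = query / np.sqrt(c)

# [B, H*W, C], [B, H*W, C] -> [B, H*W, H*W] attention scores.
scores = np.einsum("abc,adc->abd", query, key)
scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
scores = scores / scores.sum(axis=-1, keepdims=True)  # softmax over keys

# Attention-weighted sum of values, then restore the spatial layout.
out = np.einsum("abc,adb->adc", value, scores).reshape(b, h, w, c)
print(out.shape)  # (2, 8, 8, 64)

The layer's residual add and output projection are omitted here; the point is that once the feature map is flattened to [B, H*W, C], the two einsums are exactly a batched Q @ K^T followed by attention-weighted sums over value positions.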
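For reference, a minimal construction sketch for the removed VAEImageDecoder, runnable only against the older wheel in this diff (the import path no longer exists in the newer one). The stackwise values are illustrative assumptions, not the Stable Diffusion 3 preset configuration:

import numpy as np
from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
    VAEImageDecoder,
)

decoder = VAEImageDecoder(
    # Hypothetical config; filter counts must divide evenly by the
    # GroupNormalization group count of 32.
    stackwise_num_filters=[128, 64, 32, 32],
    stackwise_num_blocks=[1, 1, 1, 1],
    output_channels=3,
    latent_shape=(None, None, 16),
)

# Upsampling happens after every stack except the last, so four stacks
# give three 2x upsamples: an 8x8 latent decodes to a 64x64 image.
latents = np.random.randn(1, 8, 8, 16).astype("float32")
images = decoder.predict(latents)
print(images.shape)  # (1, 64, 64, 3)

Its scaling_factor and shift_factor properties expose the constants used to normalize latents during diffusion training, so a caller decoding raw diffusion outputs would undo that normalization before invoking the decoder.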