keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.16.1.dev202410040340__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (51)
  1. keras_hub/api/layers/__init__.py +3 -3
  2. keras_hub/api/models/__init__.py +10 -1
  3. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  4. keras_hub/src/layers/preprocessing/image_converter.py +164 -34
  5. keras_hub/src/models/backbone.py +3 -9
  6. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  7. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  8. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
  9. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  10. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  11. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  12. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
  13. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
  14. keras_hub/src/models/densenet/densenet_image_classifier.py +0 -128
  15. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  16. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  17. keras_hub/src/models/image_classifier.py +147 -2
  18. keras_hub/src/models/image_classifier_preprocessor.py +3 -3
  19. keras_hub/src/models/image_segmenter.py +0 -5
  20. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  21. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -109
  22. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  23. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  24. keras_hub/src/models/preprocessor.py +3 -5
  25. keras_hub/src/models/resnet/resnet_backbone.py +1 -11
  26. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  27. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  28. keras_hub/src/models/sam/__init__.py +5 -0
  29. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  30. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  31. keras_hub/src/models/sam/sam_presets.py +3 -3
  32. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  33. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
  34. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  35. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
  36. keras_hub/src/models/task.py +39 -36
  37. keras_hub/src/models/vae/__init__.py +1 -0
  38. keras_hub/src/models/vae/vae_backbone.py +172 -0
  39. keras_hub/src/models/vae/vae_layers.py +740 -0
  40. keras_hub/src/models/vgg/vgg_backbone.py +1 -20
  41. keras_hub/src/models/vgg/vgg_image_classifier.py +108 -29
  42. keras_hub/src/tokenizers/tokenizer.py +3 -6
  43. keras_hub/src/utils/preset_utils.py +103 -61
  44. keras_hub/src/utils/timm/preset_loader.py +8 -9
  45. keras_hub/src/version_utils.py +1 -1
  46. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA +1 -1
  47. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD +49 -41
  48. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  49. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  50. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/WHEEL +0 -0
  51. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py (deleted)
@@ -1,320 +0,0 @@
-import math
-
-from keras import layers
-from keras import ops
-
-from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
-class VAEAttention(layers.Layer):
-    def __init__(self, filters, groups=32, data_format=None, **kwargs):
-        super().__init__(**kwargs)
-        self.filters = filters
-        self.data_format = standardize_data_format(data_format)
-        gn_axis = -1 if self.data_format == "channels_last" else 1
-
-        self.group_norm = layers.GroupNormalization(
-            groups=groups,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="group_norm",
-        )
-        self.query_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="query_conv2d",
-        )
-        self.key_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="key_conv2d",
-        )
-        self.value_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="value_conv2d",
-        )
-        self.softmax = layers.Softmax(dtype="float32")
-        self.output_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="output_conv2d",
-        )
-
-        self.groups = groups
-        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
-
-    def build(self, input_shape):
-        self.group_norm.build(input_shape)
-        self.query_conv2d.build(input_shape)
-        self.key_conv2d.build(input_shape)
-        self.value_conv2d.build(input_shape)
-        self.output_conv2d.build(input_shape)
-
-    def call(self, inputs, training=None):
-        x = self.group_norm(inputs)
-        query = self.query_conv2d(x)
-        key = self.key_conv2d(x)
-        value = self.value_conv2d(x)
-
-        if self.data_format == "channels_first":
-            query = ops.transpose(query, (0, 2, 3, 1))
-            key = ops.transpose(key, (0, 2, 3, 1))
-            value = ops.transpose(value, (0, 2, 3, 1))
-        shape = ops.shape(inputs)
-        b = shape[0]
-        query = ops.reshape(query, (b, -1, self.filters))
-        key = ops.reshape(key, (b, -1, self.filters))
-        value = ops.reshape(value, (b, -1, self.filters))
-
-        # Compute attention.
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
-        )
-        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
-        attention_scores = ops.einsum("abc,adc->abd", query, key)
-        attention_scores = ops.cast(
-            self.softmax(attention_scores), self.compute_dtype
-        )
-        # [B, H1 * W1, C], [B, H0 * W0, H1 * W1] -> [B, H0 * W0, C]
-        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
-        x = ops.reshape(attention_output, shape)
-
-        x = self.output_conv2d(x)
-        if self.data_format == "channels_first":
-            x = ops.transpose(x, (0, 3, 1, 2))
-        x = ops.add(x, inputs)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "filters": self.filters,
-                "groups": self.groups,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
-
-
-def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
-    data_format = standardize_data_format(data_format)
-    gn_axis = -1 if data_format == "channels_last" else 1
-    input_filters = x.shape[gn_axis]
-
-    residual = x
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm1",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv1",
-    )(x)
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm2",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv2",
-    )(x)
-    if input_filters != filters:
-        residual = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=data_format,
-            dtype=dtype,
-            name=f"{name}_residual_projection",
-        )(residual)
-    x = layers.Add(dtype=dtype)([residual, x])
-    return x
-
-
-class VAEImageDecoder(Backbone):
-    """Decoder for the VAE model used in Stable Diffusion 3.
-
-    Args:
-        stackwise_num_filters: list of ints. The number of filters for each
-            stack.
-        stackwise_num_blocks: list of ints. The number of blocks for each stack.
-        output_channels: int. The number of channels in the output.
-        latent_shape: tuple. The shape of the latent image.
-        data_format: `None` or str. If specified, either `"channels_last"` or
-            `"channels_first"`. The ordering of the dimensions in the
-            inputs. `"channels_last"` corresponds to inputs with shape
-            `(batch_size, height, width, channels)`
-            while `"channels_first"` corresponds to inputs with shape
-            `(batch_size, channels, height, width)`. It defaults to the
-            `image_data_format` value found in your Keras config file at
-            `~/.keras/keras.json`. If you never set it, then it will be
-            `"channels_last"`.
-        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
-            to use for the model's computations and weights.
-    """
-
-    def __init__(
-        self,
-        stackwise_num_filters,
-        stackwise_num_blocks,
-        output_channels=3,
-        latent_shape=(None, None, 16),
-        data_format=None,
-        dtype=None,
-        **kwargs,
-    ):
-        data_format = standardize_data_format(data_format)
-        gn_axis = -1 if data_format == "channels_last" else 1
-
-        # === Functional Model ===
-        latent_inputs = layers.Input(shape=latent_shape)
-
-        x = layers.Conv2D(
-            stackwise_num_filters[0],
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="input_projection",
-        )(latent_inputs)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block0",
-        )
-        x = VAEAttention(
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_attention",
-        )(x)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block1",
-        )
-
-        # Stacks.
-        for i, filters in enumerate(stackwise_num_filters):
-            for j in range(stackwise_num_blocks[i]):
-                x = apply_resnet_block(
-                    x,
-                    filters,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"block{i}_{j}",
-                )
-            if i != len(stackwise_num_filters) - 1:
-                # No upsampling in the last block.
-                x = layers.UpSampling2D(
-                    2,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}",
-                )(x)
-                x = layers.Conv2D(
-                    filters,
-                    3,
-                    1,
-                    padding="same",
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}_conv",
-                )(x)
-
-        # Output block.
-        x = layers.GroupNormalization(
-            groups=32,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="output_norm",
-        )(x)
-        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
-        image_outputs = layers.Conv2D(
-            output_channels,
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="output_projection",
-        )(x)
-        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
-
-        # === Config ===
-        self.stackwise_num_filters = stackwise_num_filters
-        self.stackwise_num_blocks = stackwise_num_blocks
-        self.output_channels = output_channels
-        self.latent_shape = latent_shape
-
-    @property
-    def scaling_factor(self):
-        """The scaling factor for the latent space.
-
-        This is used to scale the latent space to have unit variance when
-        training the diffusion model.
-        """
-        return 1.5305
-
-    @property
-    def shift_factor(self):
-        """The shift factor for the latent space.
-
-        This is used to shift the latent space to have zero mean when
-        training the diffusion model.
-        """
-        return 0.0609
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "stackwise_num_filters": self.stackwise_num_filters,
-                "stackwise_num_blocks": self.stackwise_num_blocks,
-                "output_channels": self.output_channels,
-                "latent_shape": self.latent_shape,
-            }
-        )
-        return config
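Note on the removal above: per files 38 and 39 in the list, the decoder's functionality moves into the new keras_hub/src/models/vae/vae_backbone.py and vae_layers.py. For code pinned to older nightlies, here is a minimal usage sketch of the removed class, assuming the pre-removal import path and the default channels_last image data format. The stack sizes below are illustrative placeholders, not the Stable Diffusion 3 preset, and the latent unscaling follows the convention the scaling_factor and shift_factor docstrings describe (divide by the scale, then add the shift), as is typical in Stable Diffusion 3 pipelines.

import numpy as np

from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
    VAEImageDecoder,
)

# Illustrative stack sizes (not the SD3 preset): four stacks upsample three
# times, so a 32x32x16 latent decodes to a 256x256 RGB image.
decoder = VAEImageDecoder(
    stackwise_num_filters=[128, 128, 64, 32],
    stackwise_num_blocks=[1, 1, 1, 1],
    output_channels=3,
    latent_shape=(None, None, 16),
)

# Diffusion-model latents are normalized to roughly zero mean and unit
# variance; undo that normalization before decoding.
latents = np.random.normal(size=(1, 32, 32, 16)).astype("float32")
latents = latents / decoder.scaling_factor + decoder.shift_factor
images = decoder.predict(latents)  # shape: (1, 256, 256, 3)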