keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl
This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +16 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
- keras_hub/src/models/clip/clip_preprocessor.py +147 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
- keras_hub/src/models/densenet/__init__.py +6 -0
- keras_hub/src/models/densenet/densenet_backbone.py +11 -8
- keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
- keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
- keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
- keras_hub/src/models/densenet/densenet_presets.py +56 -0
- keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
- keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
- keras_hub/src/models/text_to_image.py +295 -0
- keras_hub/src/utils/timm/convert_densenet.py +107 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +31 -23
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
- /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
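For orientation, this release splits the CLIP components out of `stable_diffusion_v3` into a standalone `clip` model directory and renames `stable_diffusion_v3` to `stable_diffusion_3`, adding a backbone, a text-to-image task, a scheduler, and a VAE decoder. Below is a sketch of the import paths implied by the moves above; the module paths come from the file list, but the class names (`CLIPTokenizer`, `CLIPPreprocessor`, `StableDiffusion3Backbone`, `StableDiffusion3TextToImage`) are inferred from the file names and are not confirmed by this diff.

```python
# Hypothetical import sketch based only on the file moves listed above.
# Module paths are taken from the diff; class names are assumptions.
from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
    StableDiffusion3Backbone,
)
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
    StableDiffusion3TextToImage,
)
```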
keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py
@@ -19,7 +19,8 @@ from keras import models
 from keras import ops

 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.models.
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.utils.keras_utils import gelu_approximate
 from keras_hub.src.utils.keras_utils import standardize_data_format


@@ -79,8 +80,8 @@ class AdjustablePositionEmbedding(PositionEmbedding):
         width = width or self.width
         shape = ops.shape(inputs)
         feature_length = shape[-1]
-        top = ops.floor_divide(self.height - height, 2)
-        left = ops.floor_divide(self.width - width, 2)
+        top = ops.cast(ops.floor_divide(self.height - height, 2), "int32")
+        left = ops.cast(ops.floor_divide(self.width - width, 2), "int32")
         position_embedding = ops.convert_to_tensor(self.position_embeddings)
         position_embedding = ops.reshape(
             position_embedding, (self.height, self.width, feature_length)
@@ -166,6 +167,305 @@ class TimestepEmbedding(layers.Layer):
         return output_shape


+
+class DismantledBlock(layers.Layer):
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim,
+        mlp_ratio=4.0,
+        use_projection=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_ratio = mlp_ratio
+        self.use_projection = use_projection
+
+        head_dim = hidden_dim // num_heads
+        self.head_dim = head_dim
+        mlp_hidden_dim = int(hidden_dim * mlp_ratio)
+        self.mlp_hidden_dim = mlp_hidden_dim
+        num_modulations = 6 if use_projection else 2
+        self.num_modulations = num_modulations
+
+        self.adaptive_norm_modulation = models.Sequential(
+            [
+                layers.Activation("silu", dtype=self.dtype_policy),
+                layers.Dense(
+                    num_modulations * hidden_dim, dtype=self.dtype_policy
+                ),
+            ],
+            name="adaptive_norm_modulation",
+        )
+        self.norm1 = layers.LayerNormalization(
+            epsilon=1e-6,
+            center=False,
+            scale=False,
+            dtype="float32",
+            name="norm1",
+        )
+        self.attention_qkv = layers.Dense(
+            hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
+        )
+        if use_projection:
+            self.attention_proj = layers.Dense(
+                hidden_dim, dtype=self.dtype_policy, name="attention_proj"
+            )
+            self.norm2 = layers.LayerNormalization(
+                epsilon=1e-6,
+                center=False,
+                scale=False,
+                dtype="float32",
+                name="norm2",
+            )
+            self.mlp = models.Sequential(
+                [
+                    layers.Dense(
+                        mlp_hidden_dim,
+                        activation=gelu_approximate,
+                        dtype=self.dtype_policy,
+                    ),
+                    layers.Dense(
+                        hidden_dim,
+                        dtype=self.dtype_policy,
+                    ),
+                ],
+                name="mlp",
+            )
+
+    def build(self, inputs_shape, timestep_embedding):
+        self.adaptive_norm_modulation.build(timestep_embedding)
+        self.attention_qkv.build(inputs_shape)
+        self.norm1.build(inputs_shape)
+        if self.use_projection:
+            self.attention_proj.build(inputs_shape)
+            self.norm2.build(inputs_shape)
+            self.mlp.build(inputs_shape)
+
+    def _modulate(self, inputs, shift, scale):
+        shift = ops.expand_dims(shift, axis=1)
+        scale = ops.expand_dims(scale, axis=1)
+        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
+
+    def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
+        batch_size = ops.shape(inputs)[0]
+        if self.use_projection:
+            modulation = self.adaptive_norm_modulation(
+                timestep_embedding, training=training
+            )
+            modulation = ops.reshape(
+                modulation, (batch_size, 6, self.hidden_dim)
+            )
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+            ) = ops.unstack(modulation, 6, axis=1)
+            qkv = self.attention_qkv(
+                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
+                training=training,
+            )
+            qkv = ops.reshape(
+                qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+            )
+            q, k, v = ops.unstack(qkv, 3, axis=2)
+            return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
+        else:
+            modulation = self.adaptive_norm_modulation(
+                timestep_embedding, training=training
+            )
+            modulation = ops.reshape(
+                modulation, (batch_size, 2, self.hidden_dim)
+            )
+            shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
+            qkv = self.attention_qkv(
+                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
+                training=training,
+            )
+            qkv = ops.reshape(
+                qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+            )
+            q, k, v = ops.unstack(qkv, 3, axis=2)
+            return (q, k, v)
+
+    def _compute_post_attention(
+        self, inputs, inputs_intermediates, training=None
+    ):
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
+        attn = self.attention_proj(inputs, training=training)
+        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
+        x = ops.add(
+            x,
+            ops.multiply(
+                ops.expand_dims(gate_mlp, axis=1),
+                self.mlp(
+                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
+                    training=training,
+                ),
+            ),
+        )
+        return x
+
+    def call(
+        self,
+        inputs,
+        timestep_embedding=None,
+        inputs_intermediates=None,
+        pre_attention=True,
+        training=None,
+    ):
+        if pre_attention:
+            return self._compute_pre_attention(
+                inputs, timestep_embedding, training=training
+            )
+        else:
+            return self._compute_post_attention(
+                inputs, inputs_intermediates, training=training
+            )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "mlp_ratio": self.mlp_ratio,
+                "use_projection": self.use_projection,
+            }
+        )
+        return config
+
+
+class MMDiTBlock(layers.Layer):
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim,
+        mlp_ratio=4.0,
+        use_context_projection=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.mlp_ratio = mlp_ratio
+        self.use_context_projection = use_context_projection
+
+        head_dim = hidden_dim // num_heads
+        self.head_dim = head_dim
+        self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
+        self._dot_product_equation = "aecd,abcd->acbe"
+        self._combine_equation = "acbe,aecd->abcd"
+
+        self.x_block = DismantledBlock(
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_ratio=mlp_ratio,
+            use_projection=True,
+            dtype=self.dtype_policy,
+            name="x_block",
+        )
+        self.context_block = DismantledBlock(
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_ratio=mlp_ratio,
+            use_projection=use_context_projection,
+            dtype=self.dtype_policy,
+            name="context_block",
+        )
+        self.softmax = layers.Softmax(dtype="float32")
+
+    def build(self, inputs_shape, context_shape, timestep_embedding_shape):
+        self.x_block.build(inputs_shape, timestep_embedding_shape)
+        self.context_block.build(context_shape, timestep_embedding_shape)
+
+    def _compute_attention(self, query, key, value):
+        query = ops.multiply(
+            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
+        )
+        attention_scores = ops.einsum(self._dot_product_equation, key, query)
+        attention_scores = self.softmax(attention_scores)
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+        batch_size = ops.shape(attention_output)[0]
+        attention_output = ops.reshape(
+            attention_output, (batch_size, -1, self.num_heads * self.head_dim)
+        )
+        return attention_output
+
+    def call(self, inputs, context, timestep_embedding, training=None):
+        # Compute pre-attention.
+        x = inputs
+        if self.use_context_projection:
+            context_qkv, context_intermediates = self.context_block(
+                context,
+                timestep_embedding=timestep_embedding,
+                training=training,
+            )
+        else:
+            context_qkv = self.context_block(
+                context,
+                timestep_embedding=timestep_embedding,
+                training=training,
+            )
+        context_len = ops.shape(context_qkv[0])[1]
+        x_qkv, x_intermediates = self.x_block(
+            x, timestep_embedding=timestep_embedding, training=training
+        )
+        q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
+        k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
+        v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
+
+        # Compute attention.
+        attention = self._compute_attention(q, k, v)
+        context_attention = attention[:, :context_len]
+        x_attention = attention[:, context_len:]
+
+        # Compute post-attention.
+        x = self.x_block(
+            x_attention,
+            inputs_intermediates=x_intermediates,
+            pre_attention=False,
+            training=training,
+        )
+        if self.use_context_projection:
+            context = self.context_block(
+                context_attention,
+                inputs_intermediates=context_intermediates,
+                pre_attention=False,
+                training=training,
+            )
+            return x, context
+        else:
+            return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "mlp_ratio": self.mlp_ratio,
+                "use_context_projection": self.use_context_projection,
+            }
+        )
+        return config
+
+    def compute_output_shape(
+        self, inputs_shape, context_shape, timestep_embedding_shape
+    ):
+        if self.use_context_projection:
+            return inputs_shape, context_shape
+        else:
+            return inputs_shape
+
+
 class OutputLayer(layers.Layer):
     def __init__(self, hidden_dim, output_dim, **kwargs):
         super().__init__(**kwargs)
@@ -186,11 +486,11 @@ class OutputLayer(layers.Layer):
             epsilon=1e-6,
             center=False,
             scale=False,
-            dtype=
+            dtype="float32",
             name="norm",
         )
         self.output_dense = layers.Dense(
-            output_dim,
+            output_dim,
             use_bias=True,
             dtype=self.dtype_policy,
             name="output_dense",
@@ -227,6 +527,11 @@ class OutputLayer(layers.Layer):
         )
         return config

+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.output_dim
+        return outputs_shape
+

 class Unpatch(layers.Layer):
     def __init__(self, patch_size, output_dim, **kwargs):
@@ -263,18 +568,48 @@ class Unpatch(layers.Layer):
         return [inputs_shape[0], None, None, self.output_dim]


-class MMDiT(
+class MMDiT(Backbone):
+    """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3.
+
+    MMDiT is introduced in [
+    Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
+    https://arxiv.org/abs/2403.03206).
+
+    Args:
+        patch_size: int. The size of each square patch in the input image.
+        hidden_dim: int. The size of the transformer hidden state at the end
+            of each transformer layer.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+        position_size: int. The size of the height and width for the position
+            embedding.
+        mlp_ratio: float. The ratio of the mlp hidden dim to the transformer
+        latent_shape: tuple. The shape of the latent image.
+        context_shape: tuple. The shape of the context.
+        pooled_projection_shape: tuple. The shape of the pooled projection.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights.
+    """
+
     def __init__(
         self,
         patch_size,
-        num_heads,
         hidden_dim,
-
+        num_layers,
+        num_heads,
         position_size,
-        output_dim,
         mlp_ratio=4.0,
         latent_shape=(64, 64, 16),
-        context_shape=(
+        context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
         data_format=None,
         dtype=None,
@@ -287,6 +622,7 @@ class MMDiT(keras.Model):
         )
         image_height = latent_shape[0] // patch_size
         image_width = latent_shape[1] // patch_size
+        output_dim = latent_shape[-1]
         output_dim_in_final = patch_size**2 * output_dim
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
@@ -331,11 +667,11 @@ class MMDiT(keras.Model):
                 num_heads,
                 hidden_dim,
                 mlp_ratio,
-                use_context_projection=not (i ==
+                use_context_projection=not (i == num_layers - 1),
                 dtype=dtype,
                 name=f"joint_block_{i}",
             )
-            for i in range(
+            for i in range(num_layers)
         ]
         self.output_layer = OutputLayer(
             hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
@@ -391,33 +727,22 @@ class MMDiT(keras.Model):
         self.patch_size = patch_size
         self.num_heads = num_heads
         self.hidden_dim = hidden_dim
-        self.
+        self.num_layers = num_layers
         self.position_size = position_size
-        self.output_dim = output_dim
         self.mlp_ratio = mlp_ratio
         self.latent_shape = latent_shape
         self.context_shape = context_shape
         self.pooled_projection_shape = pooled_projection_shape

-        if dtype is not None:
-            try:
-                self.dtype_policy = keras.dtype_policies.get(dtype)
-            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-            except AttributeError:
-                if isinstance(dtype, keras.DTypePolicy):
-                    dtype = dtype.name
-                self.dtype_policy = keras.DTypePolicy(dtype)
-
     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "patch_size": self.patch_size,
-                "num_heads": self.num_heads,
                 "hidden_dim": self.hidden_dim,
-                "
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
                 "position_size": self.position_size,
-                "output_dim": self.output_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "latent_shape": self.latent_shape,
                 "context_shape": self.context_shape,
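The hunks above add `DismantledBlock` and `MMDiTBlock` in full, so their constructor and call signatures can be read directly from the diff. Below is a minimal sketch of exercising the new `MMDiTBlock` on its own; the import path follows the renamed module, the block is an internal building block rather than a documented public API, and the shapes and hyperparameters are illustrative only.

```python
import numpy as np

# Internal module path taken from the rename above; not a public API surface.
from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiTBlock

# The last axis of both token streams must equal hidden_dim, because the
# block adds the attention and MLP outputs back onto its inputs residually.
block = MMDiTBlock(num_heads=4, hidden_dim=256, mlp_ratio=4.0)

latents = np.random.rand(1, 64, 256).astype("float32")  # patched image tokens
context = np.random.rand(1, 77, 256).astype("float32")  # text-encoder tokens
timestep = np.random.rand(1, 256).astype("float32")  # timestep embedding

# With use_context_projection=True (the default), the block returns updated
# image tokens and updated context tokens.
latents, context = block(latents, context=context, timestep_embedding=timestep)
print(latents.shape, context.shape)  # (1, 64, 256) (1, 77, 256)
```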