keras-hub-nightly 0.19.0.dev202501080345__py3-none-any.whl → 0.19.0.dev202501150344__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (25)
  1. keras_hub/src/metrics/bleu.py +3 -2
  2. keras_hub/src/models/basnet/basnet_backbone.py +1 -1
  3. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +3 -3
  4. keras_hub/src/models/densenet/densenet_backbone.py +3 -3
  5. keras_hub/src/models/flux/flux_text_to_image.py +1 -1
  6. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +2 -2
  7. keras_hub/src/models/resnet/resnet_backbone.py +1 -1
  8. keras_hub/src/models/retinanet/feature_pyramid.py +5 -5
  9. keras_hub/src/models/stable_diffusion_3/mmdit.py +254 -58
  10. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +15 -2
  11. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -2
  12. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +12 -0
  13. keras_hub/src/models/vit/vit_layers.py +1 -1
  14. keras_hub/src/tokenizers/byte_tokenizer.py +1 -2
  15. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +3 -0
  16. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +1 -2
  17. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +3 -0
  18. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  19. keras_hub/src/utils/timm/convert_efficientnet.py +1 -1
  20. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  21. keras_hub/src/version_utils.py +1 -1
  22. {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501150344.dist-info}/METADATA +13 -2
  23. {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501150344.dist-info}/RECORD +25 -25
  24. {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501150344.dist-info}/WHEEL +1 -1
  25. {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501150344.dist-info}/top_level.txt +0 -0
@@ -329,8 +329,9 @@ class Bleu(keras.metrics.Metric):
                 return tf.squeeze(inputs, axis=-1)
             else:
                 raise ValueError(
-                    f"{tensor_name} must be of rank {base_rank}, {base_rank+1} "
-                    f"or {base_rank+2}. Found rank: {inputs.shape.rank}"
+                    f"{tensor_name} must be of rank {base_rank}, "
+                    f"{base_rank + 1}, or {base_rank + 2}. "
+                    f"Found rank: {inputs.shape.rank}"
                 )

         y_true = validate_and_fix_rank(y_true, "y_true", 1)
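For reference, a minimal sketch of how the reworded message renders (`tensor_name` and `base_rank` come from the surrounding `validate_and_fix_rank` helper; the found rank here is illustrative):

    tensor_name, base_rank = "y_true", 1
    message = (
        f"{tensor_name} must be of rank {base_rank}, "
        f"{base_rank + 1}, or {base_rank + 2}. "
        f"Found rank: {4}"
    )
    # -> "y_true must be of rank 1, 2, or 3. Found rank: 4"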
@@ -219,7 +219,7 @@ def get_resnet_block(_resnet, block_num):
     else:
         x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]]
     y = _resnet.get_layer(
-        f"stack{block_num}_block{num_blocks[block_num]-1}_add"
+        f"stack{block_num}_block{num_blocks[block_num] - 1}_add"
     ).output
     return keras.models.Model(
         inputs=x,
@@ -88,13 +88,13 @@ class SpatialPyramidPooling(keras.layers.Layer):
                         dilation_rate=dilation_rate,
                         use_bias=False,
                         data_format=self.data_format,
-                        name=f"aspp_conv_{i+2}",
+                        name=f"aspp_conv_{i + 2}",
                     ),
                     keras.layers.BatchNormalization(
-                        axis=self.channel_axis, name=f"aspp_bn_{i+2}"
+                        axis=self.channel_axis, name=f"aspp_bn_{i + 2}"
                     ),
                     keras.layers.Activation(
-                        self.activation, name=f"aspp_activation_{i+2}"
+                        self.activation, name=f"aspp_activation_{i + 2}"
                     ),
                 ]
             )
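This hunk (and the similar ones below) looks like a formatter-driven cleanup: binary operators inside f-string replacement fields gain surrounding spaces. The rendered strings, and therefore the layer names, are unchanged; a quick check:

    i = 0
    # The old and new spellings format identically.
    assert f"aspp_conv_{i+2}" == f"aspp_conv_{i + 2}" == "aspp_conv_2"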
@@ -81,14 +81,14 @@ class DenseNetBackbone(FeaturePyramidBackbone):
                 channel_axis,
                 stackwise_num_repeats[stack_index],
                 growth_rate,
-                name=f"stack{stack_index+1}",
+                name=f"stack{stack_index + 1}",
             )
             pyramid_outputs[f"P{index}"] = x
             x = apply_transition_block(
                 x,
                 channel_axis,
                 compression_ratio,
-                name=f"transition{stack_index+1}",
+                name=f"transition{stack_index + 1}",
             )

         x = apply_dense_block(
@@ -140,7 +140,7 @@ def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None):

     for i in range(num_repeats):
         x = apply_conv_block(
-            x, channel_axis, growth_rate, name=f"{name}_block{i+1}"
+            x, channel_axis, growth_rate, name=f"{name}_block{i + 1}"
         )
     return x

@@ -81,7 +81,7 @@ class FluxTextToImage(TextToImage):

     def fit(self, *args, **kwargs):
         raise NotImplementedError(
-            "Currently, `fit` is not supported for " "`FluxTextToImage`."
+            "Currently, `fit` is not supported for `FluxTextToImage`."
         )

     def generate_step(
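The change above merges two adjacent string literals that Python concatenates at compile time anyway, so the raised message is byte-for-byte identical:

    old = "Currently, `fit` is not supported for " "`FluxTextToImage`."
    new = "Currently, `fit` is not supported for `FluxTextToImage`."
    assert old == new  # implicit literal concatenation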
@@ -5,7 +5,7 @@ backbone_presets = {
     "pali_gemma_3b_mix_224": {
         "metadata": {
             "description": (
-                "image size 224, mix fine tuned, text sequence " "length is 256"
+                "image size 224, mix fine tuned, text sequence length is 256"
             ),
             "params": 2923335408,
             "path": "pali_gemma",
@@ -45,7 +45,7 @@ backbone_presets = {
     "pali_gemma_3b_896": {
         "metadata": {
             "description": (
-                "image size 896, pre trained, text sequence length " "is 512"
+                "image size 896, pre trained, text sequence length is 512"
             ),
             "params": 2927759088,
             "path": "pali_gemma",
@@ -177,7 +177,7 @@ class ResNetBackbone(FeaturePyramidBackbone):
                 use_bias=False,
                 padding="same",
                 dtype=dtype,
-                name=f"conv{conv_index+1}_conv",
+                name=f"conv{conv_index + 1}_conv",
             )(x)

             if not use_pre_activation:
@@ -209,9 +209,9 @@ class FeaturePyramid(keras.layers.Layer):
             )
             if i == backbone_max_level + 1 and self.use_p5:
                 self.output_conv_layers[level].build(
-                    (None, None, None, input_shapes[f"P{i-1}"][-1])
+                    (None, None, None, input_shapes[f"P{i - 1}"][-1])
                     if self.data_format == "channels_last"
-                    else (None, input_shapes[f"P{i-1}"][1], None, None)
+                    else (None, input_shapes[f"P{i - 1}"][1], None, None)
                 )
             else:
                 self.output_conv_layers[level].build(
@@ -277,7 +277,7 @@ class FeaturePyramid(keras.layers.Layer):
             if i < backbone_max_level:
                 # for the top most output, it doesn't need to merge with any
                 # upper stream outputs
-                upstream_output = self.top_down_op(output_features[f"P{i+1}"])
+                upstream_output = self.top_down_op(output_features[f"P{i + 1}"])
                 output = self.merge_op([output, upstream_output])
             output_features[level] = (
                 self.lateral_batch_norm_layers[level](output)
@@ -296,9 +296,9 @@ class FeaturePyramid(keras.layers.Layer):
         for i in range(backbone_max_level + 1, self.max_level + 1):
             level = f"P{i}"
             feats_in = (
-                inputs[f"P{i-1}"]
+                inputs[f"P{i - 1}"]
                 if i == backbone_max_level + 1 and self.use_p5
-                else output_features[f"P{i-1}"]
+                else output_features[f"P{i - 1}"]
             )
             if i > backbone_max_level + 1:
                 feats_in = self.activation(feats_in)
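As a reading aid for the `f"P{i - 1}"` lookups above, each coarser pyramid level is derived from the level below it; a small sketch with illustrative level numbers:

    backbone_max_level, max_level = 5, 7
    extra_levels = [
        f"P{i}" for i in range(backbone_max_level + 1, max_level + 1)
    ]
    # -> ["P6", "P7"]; P6 is computed from P5, P7 from P6.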
@@ -15,9 +15,8 @@ class AdaptiveLayerNormalization(layers.Layer):

     Args:
         embedding_dim: int. The size of each embedding vector.
-        residual_modulation: bool. Whether to output the modulation parameters
-            of the residual connection within the block of the diffusion
-            transformers. Defaults to `False`.
+        num_modulations: int. The number of the modulation parameters. The
+            available values are `2`, `6` and `9`. Defaults to `2`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
@@ -28,11 +27,17 @@ class AdaptiveLayerNormalization(layers.Layer):
     https://arxiv.org/abs/2212.09748).
     """

-    def __init__(self, hidden_dim, residual_modulation=False, **kwargs):
+    def __init__(self, hidden_dim, num_modulations=2, **kwargs):
         super().__init__(**kwargs)
-        self.hidden_dim = int(hidden_dim)
-        self.residual_modulation = bool(residual_modulation)
-        num_modulations = 6 if self.residual_modulation else 2
+        hidden_dim = int(hidden_dim)
+        num_modulations = int(num_modulations)
+        if num_modulations not in (2, 6, 9):
+            raise ValueError(
+                "`num_modulations` must be `2`, `6` or `9`. "
+                f"Received: num_modulations={num_modulations}"
+            )
+        self.hidden_dim = hidden_dim
+        self.num_modulations = num_modulations

         self.silu = layers.Activation("silu", dtype=self.dtype_policy)
         self.dense = layers.Dense(
@@ -52,40 +57,84 @@ class AdaptiveLayerNormalization(layers.Layer):
         self.norm.build(inputs_shape)

     def call(self, inputs, embeddings, training=None):
-        x = inputs
+        hidden_states = inputs
         emb = self.dense(self.silu(embeddings), training=training)
-        if self.residual_modulation:
-            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                ops.split(emb, 6, axis=1)
-            )
+        if self.num_modulations == 9:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                shift_msa2,
+                scale_msa2,
+                gate_msa2,
+            ) = ops.split(emb, self.num_modulations, axis=1)
+        elif self.num_modulations == 6:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+            ) = ops.split(emb, self.num_modulations, axis=1)
         else:
-            shift_msa, scale_msa = ops.split(emb, 2, axis=1)
+            shift_msa, scale_msa = ops.split(emb, self.num_modulations, axis=1)
+
         scale_msa = ops.expand_dims(scale_msa, axis=1)
         shift_msa = ops.expand_dims(shift_msa, axis=1)
-        x = ops.add(
-            ops.multiply(
-                self.norm(x, training=training),
-                ops.add(1.0, scale_msa),
-            ),
-            shift_msa,
+        norm_hidden_states = ops.cast(
+            self.norm(hidden_states, training=training), scale_msa.dtype
+        )
+        hidden_states = ops.add(
+            ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa)), shift_msa
         )
-        if self.residual_modulation:
-            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+        if self.num_modulations == 9:
+            scale_msa2 = ops.expand_dims(scale_msa2, axis=1)
+            shift_msa2 = ops.expand_dims(shift_msa2, axis=1)
+            hidden_states2 = ops.add(
+                ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa2)),
+                shift_msa2,
+            )
+            return (
+                hidden_states,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                hidden_states2,
+                gate_msa2,
+            )
+        elif self.num_modulations == 6:
+            return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
         else:
-            return x
+            return hidden_states

     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "hidden_dim": self.hidden_dim,
-                "residual_modulation": self.residual_modulation,
+                "num_modulations": self.num_modulations,
             }
         )
         return config

     def compute_output_shape(self, inputs_shape, embeddings_shape):
-        if self.residual_modulation:
+        if self.num_modulations == 9:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                inputs_shape,
+                embeddings_shape,
+            )
+        elif self.num_modulations == 6:
             return (
                 inputs_shape,
                 embeddings_shape,
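A minimal sketch of what `num_modulations` controls, using `keras.ops` directly (shapes are illustrative): the embedding projection is split into that many modulation tensors, 2 for shift/scale only, 6 to add attention and MLP gates, 9 to add the parameters for a second attention branch in dual-attention blocks.

    from keras import ops

    hidden_dim, num_modulations = 64, 9
    emb = ops.ones((1, num_modulations * hidden_dim))
    chunks = ops.split(emb, num_modulations, axis=1)
    assert len(chunks) == 9  # nine (1, 64) modulation tensors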
@@ -345,6 +394,27 @@ class TimestepEmbedding(layers.Layer):
         return output_shape


+def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
+    """Helper function to instantiate `LayerNormalization` layers."""
+    q_norm = None
+    k_norm = None
+    if qk_norm is None:
+        pass
+    elif qk_norm == "rms_norm":
+        q_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=q_norm_name
+        )
+        k_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=k_norm_name
+        )
+    else:
+        raise NotImplementedError(
+            "Supported `qk_norm` are `'rms_norm'` and `None`. "
+            f"Received: qk_norm={qk_norm}."
+        )
+    return q_norm, k_norm
+
+
 class DismantledBlock(layers.Layer):
     """A dismantled block used to compute pre- and post-attention.

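The helper centralizes logic that was previously duplicated inside `DismantledBlock` (where, as the hunk further down shows, the old code built `k_norm` with `name="q_norm"`). A usage sketch:

    q_norm, k_norm = get_qk_norm("rms_norm")  # two RMS-scaled LayerNorms
    q_norm, k_norm = get_qk_norm(None)        # (None, None): no QK norm
    # get_qk_norm("layer_norm") raises NotImplementedError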
@@ -356,6 +426,8 @@ class DismantledBlock(layers.Layer):
             the end of the block.
         qk_norm: Optional str. Whether to normalize the query and key tensors.
             Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use a dual attention in the
+            block. Defaults to `False`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
     """
@@ -367,6 +439,7 @@ class DismantledBlock(layers.Layer):
         mlp_ratio=4.0,
         use_projection=True,
         qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -375,6 +448,7 @@ class DismantledBlock(layers.Layer):
         self.mlp_ratio = mlp_ratio
         self.use_projection = use_projection
         self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention

         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -384,7 +458,7 @@ class DismantledBlock(layers.Layer):
         if use_projection:
             self.ada_layer_norm = AdaptiveLayerNormalization(
                 hidden_dim,
-                residual_modulation=True,
+                num_modulations=9 if use_dual_attention else 6,
                 dtype=self.dtype_policy,
                 name="ada_layer_norm",
             )
@@ -395,18 +469,10 @@ class DismantledBlock(layers.Layer):
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
-        if qk_norm is not None and qk_norm == "rms_norm":
-            self.q_norm = layers.LayerNormalization(
-                epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
-            )
-            self.k_norm = layers.LayerNormalization(
-                epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
-            )
-        elif qk_norm is not None:
-            raise NotImplementedError(
-                "Supported `qk_norm` are `'rms_norm'` and `None`. "
-                f"Received: qk_norm={qk_norm}."
-            )
+        q_norm, k_norm = get_qk_norm(qk_norm)
+        if q_norm is not None:
+            self.q_norm = q_norm
+            self.k_norm = k_norm
         if use_projection:
             self.attention_proj = layers.Dense(
                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
@@ -426,6 +492,19 @@ class DismantledBlock(layers.Layer):
             name="mlp",
         )

+        if use_dual_attention:
+            self.attention_qkv2 = layers.Dense(
+                hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv2"
+            )
+            q_norm2, k_norm2 = get_qk_norm(qk_norm, "q_norm2", "k_norm2")
+            if q_norm is not None:
+                self.q_norm2 = q_norm2
+                self.k_norm2 = k_norm2
+            if use_projection:
+                self.attention_proj2 = layers.Dense(
+                    hidden_dim, dtype=self.dtype_policy, name="attention_proj2"
+                )
+
     def build(self, inputs_shape, timestep_embedding):
         self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
@@ -437,6 +516,13 @@ class DismantledBlock(layers.Layer):
             self.attention_proj.build(inputs_shape)
         self.norm2.build(inputs_shape)
         self.mlp.build(inputs_shape)
+        if self.use_dual_attention:
+            self.attention_qkv2.build(inputs_shape)
+            if self.qk_norm is not None:
+                self.q_norm2.build([None, None, self.num_heads, self.head_dim])
+                self.k_norm2.build([None, None, self.num_heads, self.head_dim])
+            if self.use_projection:
+                self.attention_proj2.build(inputs_shape)

     def _modulate(self, inputs, shift, scale):
         inputs = ops.cast(inputs, self.compute_dtype)
@@ -456,8 +542,12 @@ class DismantledBlock(layers.Layer):
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
             if self.qk_norm is not None:
-                q = self.q_norm(q, training=training)
-                k = self.k_norm(k, training=training)
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
             x = self.ada_layer_norm(
@@ -469,8 +559,12 @@ class DismantledBlock(layers.Layer):
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
             if self.qk_norm is not None:
-                q = self.q_norm(q, training=training)
-                k = self.k_norm(k, training=training)
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
             return (q, k, v)

     def _compute_post_attention(
@@ -495,22 +589,95 @@ class DismantledBlock(layers.Layer):
         )
         return x

+    def _compute_pre_attention_with_dual_attention(
+        self, inputs, timestep_embedding, training=None
+    ):
+        batch_size = ops.shape(inputs)[0]
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2 = (
+            self.ada_layer_norm(inputs, timestep_embedding, training=training)
+        )
+        # Compute the main attention
+        qkv = self.attention_qkv(x, training=training)
+        qkv = ops.reshape(
+            qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q, k, v = ops.unstack(qkv, 3, axis=2)
+        if self.qk_norm is not None:
+            q = ops.cast(self.q_norm(q, training=training), self.compute_dtype)
+            k = ops.cast(self.k_norm(k, training=training), self.compute_dtype)
+        # Compute the dual attention
+        qkv2 = self.attention_qkv2(x2, training=training)
+        qkv2 = ops.reshape(
+            qkv2, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q2, k2, v2 = ops.unstack(qkv2, 3, axis=2)
+        if self.qk_norm is not None:
+            q2 = ops.cast(
+                self.q_norm2(q2, training=training), self.compute_dtype
+            )
+            k2 = ops.cast(
+                self.k_norm2(k2, training=training), self.compute_dtype
+            )
+        return (
+            (q, k, v),
+            (q2, k2, v2),
+            (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2),
+        )
+
+    def _compute_post_attention_with_dual_attention(
+        self, inputs, inputs2, inputs_intermediates, training=None
+    ):
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2 = (
+            inputs_intermediates
+        )
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
+        gate_msa2 = ops.expand_dims(gate_msa2, axis=1)
+        attn = self.attention_proj(inputs, training=training)
+        x = ops.add(x, ops.multiply(gate_msa, attn))
+        attn2 = self.attention_proj2(inputs2, training=training)
+        x = ops.add(x, ops.multiply(gate_msa2, attn2))
+        x = ops.add(
+            x,
+            ops.multiply(
+                gate_mlp,
+                self.mlp(
+                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
+                    training=training,
+                ),
+            ),
+        )
+        return x
+
     def call(
         self,
         inputs,
         timestep_embedding=None,
         inputs_intermediates=None,
+        inputs2=None,  # For the dual attention.
         pre_attention=True,
         training=None,
     ):
         if pre_attention:
-            return self._compute_pre_attention(
-                inputs, timestep_embedding, training=training
-            )
+            if self.use_dual_attention:
+                return self._compute_pre_attention_with_dual_attention(
+                    inputs, timestep_embedding, training=training
+                )
+            else:
+                return self._compute_pre_attention(
+                    inputs, timestep_embedding, training=training
+                )
         else:
-            return self._compute_post_attention(
-                inputs, inputs_intermediates, training=training
-            )
+            if self.use_dual_attention:
+                return self._compute_post_attention_with_dual_attention(
+                    inputs, inputs2, inputs_intermediates, training=training
+                )
+            else:
+                return self._compute_post_attention(
+                    inputs, inputs_intermediates, training=training
+                )

     def get_config(self):
         config = super().get_config()
@@ -521,6 +688,7 @@ class DismantledBlock(layers.Layer):
                 "mlp_ratio": self.mlp_ratio,
                 "use_projection": self.use_projection,
                 "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
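A minimal sketch of the merge arithmetic in `_compute_post_attention_with_dual_attention`, written against `keras.ops` with illustrative shapes: both attention branches are gated independently and summed into the residual stream before the MLP.

    from keras import ops

    def merge_dual_attention(x, attn, attn2, gate_msa, gate_msa2):
        # x, attn, attn2: (B, N, D); gate_msa, gate_msa2: (B, D).
        gate_msa = ops.expand_dims(gate_msa, axis=1)
        gate_msa2 = ops.expand_dims(gate_msa2, axis=1)
        x = ops.add(x, ops.multiply(gate_msa, attn))
        return ops.add(x, ops.multiply(gate_msa2, attn2))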
@@ -542,6 +710,8 @@ class MMDiTBlock(layers.Layer):
             layer at the end of the context block.
         qk_norm: Optional str. Whether to normalize the query and key tensors.
             Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use a dual attention in the
+            block. Defaults to `False`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.

@@ -557,6 +727,7 @@ class MMDiTBlock(layers.Layer):
         mlp_ratio=4.0,
         use_context_projection=True,
         qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -565,6 +736,7 @@ class MMDiTBlock(layers.Layer):
         self.mlp_ratio = mlp_ratio
         self.use_context_projection = use_context_projection
         self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention

         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -576,6 +748,7 @@ class MMDiTBlock(layers.Layer):
             mlp_ratio=mlp_ratio,
             use_projection=True,
             qk_norm=qk_norm,
+            use_dual_attention=use_dual_attention,
             dtype=self.dtype_policy,
             name="x_block",
         )
@@ -602,8 +775,6 @@ class MMDiTBlock(layers.Layer):
         if hasattr(ops, "dot_product_attention") and hasattr(
             keras.config, "is_flash_attention_enabled"
         ):
-            # `ops.dot_product_attention` is slower than the vanilla
-            # implementation in the tensorflow backend.
             encoded = ops.dot_product_attention(
                 query,
                 key,
@@ -643,9 +814,14 @@ class MMDiTBlock(layers.Layer):
             training=training,
         )
         context_len = ops.shape(context_qkv[0])[1]
-        x_qkv, x_intermediates = self.x_block(
-            x, timestep_embedding=timestep_embedding, training=training
-        )
+        if self.x_block.use_dual_attention:
+            x_qkv, x_qkv2, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
+        else:
+            x_qkv, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
         q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
         k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
         v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
@@ -656,12 +832,23 @@ class MMDiTBlock(layers.Layer):
         x_attention = attention[:, context_len:]

         # Compute post-attention.
-        x = self.x_block(
-            x_attention,
-            inputs_intermediates=x_intermediates,
-            pre_attention=False,
-            training=training,
-        )
+        if self.x_block.use_dual_attention:
+            q2, k2, v2 = x_qkv2
+            x_attention2 = self._compute_attention(q2, k2, v2)
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                inputs2=x_attention2,
+                pre_attention=False,
+                training=training,
+            )
+        else:
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                pre_attention=False,
+                training=training,
+            )
         if self.use_context_projection:
             context = self.context_block(
                 context_attention,
@@ -682,6 +869,7 @@ class MMDiTBlock(layers.Layer):
                 "mlp_ratio": self.mlp_ratio,
                 "use_context_projection": self.use_context_projection,
                 "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
@@ -761,6 +949,9 @@ class MMDiT(Backbone):
         qk_norm: Optional str. Whether to normalize the query and key tensors in
             the intermediate blocks. Available options are `None` and
             `"rms_norm"`. Defaults to `None`.
+        dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this is
+            for 3.5 version. Defaults to `None`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -786,6 +977,7 @@ class MMDiT(Backbone):
         context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
         qk_norm=None,
+        dual_attention_indices=None,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -799,6 +991,7 @@ class MMDiT(Backbone):
         image_width = latent_shape[1] // patch_size
         output_dim = latent_shape[-1]
         output_dim_in_final = patch_size**2 * output_dim
+        dual_attention_indices = dual_attention_indices or ()
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
@@ -840,6 +1033,7 @@ class MMDiT(Backbone):
                 mlp_ratio,
                 use_context_projection=not (i == num_layers - 1),
                 qk_norm=qk_norm,
+                use_dual_attention=i in dual_attention_indices,
                 dtype=dtype,
                 name=f"joint_block_{i}",
             )
@@ -910,6 +1104,7 @@ class MMDiT(Backbone):
         self.context_shape = context_shape
         self.pooled_projection_shape = pooled_projection_shape
         self.qk_norm = qk_norm
+        self.dual_attention_indices = dual_attention_indices

     def get_config(self):
         config = super().get_config()
@@ -925,6 +1120,7 @@ class MMDiT(Backbone):
                 "context_shape": self.context_shape,
                 "pooled_projection_shape": self.pooled_projection_shape,
                 "qk_norm": self.qk_norm,
+                "dual_attention_indices": self.dual_attention_indices,
             }
         )
         return config
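`dual_attention_indices` is consumed with a simple membership test when the joint blocks are built (`use_dual_attention=i in dual_attention_indices`); an illustrative expansion:

    num_layers = 4
    dual_attention_indices = (0, 1)  # illustrative values
    flags = [i in dual_attention_indices for i in range(num_layers)]
    # -> [True, True, False, False]: only the first two blocks get the
    # second attention branch.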
@@ -205,7 +205,10 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_qk_norm: Optional str. Whether to normalize the query and key
             tensors for each transformer in MMDiT. Available options are `None`
             and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
-            and to `"rms_norm" for 3.5 version.
+            and to `"rms_norm"` for 3.5 version.
+        mmdit_dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this is
+            for 3.5 version. Defaults to `None`.
         vae: The VAE used for transformations between pixel space and latent
             space.
         clip_l: The CLIP text encoder for encoding the inputs.
@@ -253,6 +256,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_depth=4,
         mmdit_position_size=192,
         mmdit_qk_norm=None,
+        mmdit_dual_attention_indices=None,
         vae=vae,
         clip_l=clip_l,
         clip_g=clip_g,
@@ -268,6 +272,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_num_heads,
         mmdit_position_size,
         mmdit_qk_norm,
+        mmdit_dual_attention_indices,
         vae,
         clip_l,
         clip_g,
@@ -319,6 +324,7 @@ class StableDiffusion3Backbone(Backbone):
             context_shape=context_shape,
             pooled_projection_shape=pooled_projection_shape,
             qk_norm=mmdit_qk_norm,
+            dual_attention_indices=mmdit_dual_attention_indices,
             data_format=data_format,
             dtype=dtype,
             name="diffuser",
@@ -454,6 +460,7 @@ class StableDiffusion3Backbone(Backbone):
         self.mmdit_num_heads = mmdit_num_heads
         self.mmdit_position_size = mmdit_position_size
         self.mmdit_qk_norm = mmdit_qk_norm
+        self.mmdit_dual_attention_indices = mmdit_dual_attention_indices
         self.latent_channels = latent_channels
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
@@ -590,6 +597,9 @@ class StableDiffusion3Backbone(Backbone):
                 "mmdit_num_heads": self.mmdit_num_heads,
                 "mmdit_position_size": self.mmdit_position_size,
                 "mmdit_qk_norm": self.mmdit_qk_norm,
+                "mmdit_dual_attention_indices": (
+                    self.mmdit_dual_attention_indices
+                ),
                 "vae": layers.serialize(self.vae),
                 "clip_l": layers.serialize(self.clip_l),
                 "clip_g": layers.serialize(self.clip_g),
@@ -638,7 +648,10 @@ class StableDiffusion3Backbone(Backbone):
             )

         # To maintain backward compatibility, we need to ensure that
-        # `mmdit_qk_norm` is included in the config.
+        # `mmdit_qk_norm` and `mmdit_dual_attention_indices` is included in the
+        # config.
         if "mmdit_qk_norm" not in config:
             config["mmdit_qk_norm"] = None
+        if "mmdit_dual_attention_indices" not in config:
+            config["mmdit_dual_attention_indices"] = None
         return cls(**config)
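The backfill keeps configs saved by older releases loadable; the pattern is equivalent to a `setdefault` on the deserialized dict (the fragment below is illustrative, not a real saved config):

    config = {"mmdit_qk_norm": None}  # older config lacks the new key
    if "mmdit_dual_attention_indices" not in config:
        config["mmdit_dual_attention_indices"] = None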
@@ -82,8 +82,7 @@ class StableDiffusion3Inpaint(Inpaint):

     def fit(self, *args, **kwargs):
         raise NotImplementedError(
-            "Currently, `fit` is not supported for "
-            "`StableDiffusion3Inpaint`."
+            "Currently, `fit` is not supported for `StableDiffusion3Inpaint`."
         )

     def generate_step(
@@ -13,6 +13,18 @@ backbone_presets = {
     },
     "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
     },
+    "stable_diffusion_3.5_medium": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, including CLIP L and CLIP G text "
+                "encoders, MMDiT-X generative model, and VAE autoencoder. "
+                "Developed by Stability AI."
+            ),
+            "params": 3371793763,
+            "path": "stable_diffusion_3",
+        },
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3.5_medium/1",
+    },
     "stable_diffusion_3.5_large": {
         "metadata": {
             "description": (
@@ -351,7 +351,7 @@ class ViTEncoder(keras.layers.Layer):
                 attention_dropout=self.attention_dropout,
                 layer_norm_epsilon=self.layer_norm_epsilon,
                 dtype=self.dtype_policy,
-                name=f"tranformer_block_{i+1}",
+                name=f"tranformer_block_{i + 1}",
             )
             encoder_block.build((None, None, self.hidden_dim))
             self.encoder_layers.append(encoder_block)
@@ -150,8 +150,7 @@ class ByteTokenizer(tokenizer.Tokenizer):
     ):
         if not is_int_dtype(dtype):
             raise ValueError(
-                "Output dtype must be an integer type. "
-                f"Received: dtype={dtype}"
+                f"Output dtype must be an integer type. Received: dtype={dtype}"
             )

         # Check normalization_form.
@@ -1,5 +1,7 @@
 import io

+from keras_hub.src.utils.tensor_utils import assert_tf_libs_installed
+
 try:
     import sentencepiece as spm
     import tensorflow as tf
@@ -77,6 +79,7 @@ def compute_sentence_piece_proto(
     tf.Tensor([ 4 8 12 5 9 14 5 6 13 4 7 10 11 6 13],
     shape=(15,), dtype=int32)
     """
+    assert_tf_libs_installed("compute_sentence_piece_proto")

     if spm is None:
         raise ImportError(
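`assert_tf_libs_installed` makes the trainer fail fast with a clear message when TensorFlow and tensorflow-text are missing, instead of a late import error. With the libraries installed the call works as before; a sketch (corpus and vocabulary size are illustrative):

    import tensorflow as tf
    import keras_hub

    data = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
    proto = keras_hub.tokenizers.compute_sentence_piece_proto(
        data, vocabulary_size=15
    )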
@@ -203,8 +203,7 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
     ) -> None:
         if not is_int_dtype(dtype):
             raise ValueError(
-                "Output dtype must be an integer type. "
-                f"Received: dtype={dtype}"
+                f"Output dtype must be an integer type. Received: dtype={dtype}"
             )

         # Check normalization_form.
@@ -1,5 +1,6 @@
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.tokenizers.word_piece_tokenizer import pretokenize
+from keras_hub.src.utils.tensor_utils import assert_tf_libs_installed

 try:
     import tensorflow as tf
@@ -117,6 +118,8 @@ def compute_word_piece_vocabulary(
     inputs.map(tokenizer.tokenize)
     ```
     """  # noqa: E501
+    assert_tf_libs_installed("compute_word_piece_vocabulary")
+
     # Read data files.
     if not isinstance(data, (list, tf.data.Dataset)):
         raise ValueError(
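Same pattern here: the guard turns a missing-dependency failure into an immediate, descriptive error. A minimal usage sketch (data and vocabulary size are illustrative):

    import tensorflow as tf
    import keras_hub

    data = tf.data.Dataset.from_tensor_slices(
        ["the quick brown fox", "the earth is round"]
    )
    vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
        data, vocabulary_size=20
    )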
@@ -59,9 +59,11 @@ def convert_weights(backbone, loader, timm_config):
     num_stacks = len(backbone.stackwise_num_repeats)
     for stack_index in range(num_stacks):
         for block_idx in range(backbone.stackwise_num_repeats[stack_index]):
-            keras_name = f"stack{stack_index+1}_block{block_idx+1}"
+            keras_name = f"stack{stack_index + 1}_block{block_idx + 1}"
             hf_name = (
-                f"features.denseblock{stack_index+1}.denselayer{block_idx+1}"
+                "features."
+                f"denseblock{stack_index + 1}"
+                f".denselayer{block_idx + 1}"
             )
             port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.norm1")
             port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
@@ -69,8 +71,8 @@ def convert_weights(backbone, loader, timm_config):
         port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")

     for stack_index in range(num_stacks - 1):
-        keras_transition_name = f"transition{stack_index+1}"
-        hf_transition_name = f"features.transition{stack_index+1}"
+        keras_transition_name = f"transition{stack_index + 1}"
+        hf_transition_name = f"features.transition{stack_index + 1}"
         port_batch_normalization(
             f"{keras_transition_name}_bn", f"{hf_transition_name}.norm"
         )
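For orientation, the names the converter pairs up after this change (index values illustrative):

    stack_index, block_idx = 0, 0
    keras_name = f"stack{stack_index + 1}_block{block_idx + 1}"
    hf_name = (
        "features."
        f"denseblock{stack_index + 1}"
        f".denselayer{block_idx + 1}"
    )
    # keras_name == "stack1_block1"
    # hf_name == "features.denseblock1.denselayer1"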
@@ -268,7 +268,7 @@ def convert_weights(backbone, loader, timm_config):
             # 97 is the start of the lowercase alphabet.
             letter_identifier = chr(block_idx + 97)

-            keras_block_prefix = f"block{stack_index+1}{letter_identifier}_"
+            keras_block_prefix = f"block{stack_index + 1}{letter_identifier}_"
             hf_block_prefix = f"blocks.{stack_index}.{block_idx}."

             if block_type == "v1":
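The letter identifier above comes from the lowercase ASCII range (97 is "a"), so consecutive blocks in a stack map to suffixed prefixes:

    stack_index, block_idx = 0, 1
    letter_identifier = chr(block_idx + 97)  # "b"
    keras_block_prefix = f"block{stack_index + 1}{letter_identifier}_"
    # -> "block1b_"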
@@ -89,7 +89,7 @@ def convert_weights(backbone, loader, timm_config):
         for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
             if version == "v1":
                 keras_name = f"stack{stack_index}_block{block_idx}"
-                hf_name = f"layer{stack_index+1}.{block_idx}"
+                hf_name = f"layer{stack_index + 1}.{block_idx}"
             else:
                 keras_name = f"stack{stack_index}_block{block_idx}"
                 hf_name = f"stages.{stack_index}.blocks.{block_idx}"
@@ -1,7 +1,7 @@
 from keras_hub.src.api_export import keras_hub_export

 # Unique source of truth for the version number.
-__version__ = "0.19.0.dev202501080345"
+__version__ = "0.19.0.dev202501150344"


 @keras_hub_export("keras_hub.version")
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: keras-hub-nightly
-Version: 0.19.0.dev202501080345
+Version: 0.19.0.dev202501150344
 Summary: Industry-strength Natural Language Processing extensions for Keras.
 Home-page: https://github.com/keras-team/keras-hub
 Author: Keras team
@@ -31,6 +31,17 @@ Requires-Dist: tensorflow-text
 Provides-Extra: extras
 Requires-Dist: rouge-score; extra == "extras"
 Requires-Dist: sentencepiece; extra == "extras"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary

 # KerasHub: Multi-framework Pretrained Models
 [![](https://github.com/keras-team/keras-hub/workflows/Tests/badge.svg?branch=master)](https://github.com/keras-team/keras-hub/actions?query=workflow%3ATests+branch%3Amaster)
@@ -9,7 +9,7 @@ keras_hub/api/tokenizers/__init__.py,sha256=mtJgQy1spfQnPAkeLoeinsT_W9iCWHlJXwzc
 keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version_utils.py,sha256=PULLssHhM5UCqNcvlGVVyAQo9rDMkhLThc6DLU9Kz2g,222
+keras_hub/src/version_utils.py,sha256=XXUJ1oMuODMzez6Sqr-8PGIem4zG0YD-78PlkMxNEXI,222
 keras_hub/src/bounding_box/__init__.py,sha256=7i6KnGupN4AVivR_dFjQyuuTbI0GkHy8d-aMXeqZdU8,95
 keras_hub/src/bounding_box/converters.py,sha256=UUp1hwegpDZyIo8sh9TLNy1v6JjwmvwzL6wmHFMAtbk,21916
 keras_hub/src/bounding_box/formats.py,sha256=YmskOz2BOSat7NaE__J9VfpSNGPJJR0znSzA4lp8MMI,3868
@@ -43,7 +43,7 @@ keras_hub/src/layers/preprocessing/random_deletion.py,sha256=x23nRo0ir2J4Ps42i9X
 keras_hub/src/layers/preprocessing/random_swap.py,sha256=w2z7yNQsII5g4sEFi4GXfgxIc1S6UUt3a8YWZew_f4U,9504
 keras_hub/src/layers/preprocessing/start_end_packer.py,sha256=lY2K937z6JucxNe7VknynhhjrcUfFigU6mqIdv2gS-Y,7973
 keras_hub/src/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-keras_hub/src/metrics/bleu.py,sha256=N4hnlCIFk558nZAHxGlzLYBx6gwpWS3Wvw1iFM69xiA,13665
+keras_hub/src/metrics/bleu.py,sha256=pnid5azpAxO6vKEfUtAby3nH29OGbwYKgVGOGeoaA3I,13694
 keras_hub/src/metrics/edit_distance.py,sha256=kjhe8uNjvv8aN49RyrKAbNi7a8_OlB8fMza0J_CfNQg,6353
 keras_hub/src/metrics/perplexity.py,sha256=dDUQcfE5JbAruG3spEkgue6IjHcynqgmGpJAqWg22Tw,6139
 keras_hub/src/metrics/rouge_base.py,sha256=Pt2DUznhTTeR-fX1nQ_wSbPtmuTgxQTvrGpu8LRVapE,6264
@@ -87,7 +87,7 @@ keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py,sha256=3_e-ULIcm_3DK
 keras_hub/src/models/bart/bart_tokenizer.py,sha256=Q7IXmIwXzhPSN427oQRyF9ufoExQGS184Yo_4boaOZo,2811
 keras_hub/src/models/basnet/__init__.py,sha256=4N6XvIUYYJl5xtoaL3_9fawUX_qP3WmTYNEEU7tn8Gw,253
 keras_hub/src/models/basnet/basnet.py,sha256=JA58Q9lmygdSOm5MUaPAlaL6B8XnmqCcRaGrk9c8P3Q,4287
-keras_hub/src/models/basnet/basnet_backbone.py,sha256=t_52WW6jetONS7AnPf9YsiMLDqOjVwjNuayQEv6ZAk4,13503
+keras_hub/src/models/basnet/basnet_backbone.py,sha256=P-jogkYIu9j7_28fl2RFQRMl87BXz1wcY_LtIrxBy1E,13505
 keras_hub/src/models/basnet/basnet_image_converter.py,sha256=DwzAwtZeggYw_qyRQ-Abnnm885Wobv3wClxRzOTscI0,342
 keras_hub/src/models/basnet/basnet_preprocessor.py,sha256=uM504utaXODSqR5zpKnopRuaV_l84zCg06RkNoNSKIs,510
 keras_hub/src/models/basnet/basnet_presets.py,sha256=z6tR2q_EvYnUmGfsWIWYfmR_8gvWYPH3QmtpAu_T8f8,63
@@ -135,11 +135,11 @@ keras_hub/src/models/deeplab_v3/__init__.py,sha256=FHAUPM4a1DJj4EsNTbYEd1riNq__u
 keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py,sha256=dH7HHu_NAnE-HP6ivOL7fFLQZHt_MWmehlMccLljhPc,7764
 keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py,sha256=mRkH3HdhpV0fCcQcVXEvIX7SNk-bAMb3SAHzgK-FD5c,371
 keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py,sha256=hR9S6lNYamY0EBDBo3e1qTCiwtftmLXrN-UYuzfw5Io,581
-keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py,sha256=qmEiolOOriLAojXB67xXW9IOo717kaCGeDVZJLaGY98,7834
+keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py,sha256=mz9nG55gdXSTDE96AXgeTCwUFB95DIpTuqrvWIt5Lco,7840
 keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py,sha256=ZKYY8A7mV2QvwXwjDUd9xAbVHo58-Hgj_IqNUbuyCIU,625
 keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py,sha256=pubi30sPJKLOpz9fRQff2FZt_53KBvwf2uyaJ5YL7J8,3726
 keras_hub/src/models/densenet/__init__.py,sha256=r7StyamnWeeZxOk9r4ZYNbS_YVhu9YGPyXhNxljvdPg,269
-keras_hub/src/models/densenet/densenet_backbone.py,sha256=5QawyB4EhyaXpmm8l_QUYveU7kEet3jRD3s94XAz8Tw,6738
+keras_hub/src/models/densenet/densenet_backbone.py,sha256=f2nfsXyXQert2aYHq-L-JZtp8inq1fs1K47rzZQ9nTI,6744
 keras_hub/src/models/densenet/densenet_image_classifier.py,sha256=ye-Ix3oU42pfsDoh-h1PG4di1kzldO0ZO7Nj304p_X4,544
 keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py,sha256=xDZbTw_h6pjLDzf8QmbDyMnMsFzgh-dPX1ldg9kddhg,563
 keras_hub/src/models/densenet/densenet_image_converter.py,sha256=DoxYlJVZ9uaabFhVjWOmzvhONoc8KNcQj2vQ6Z1AUpU,354
@@ -186,7 +186,7 @@ keras_hub/src/models/flux/flux_layers.py,sha256=wevcAEbayBD8bVm-21FBi2LQ13pZzB99
 keras_hub/src/models/flux/flux_maths.py,sha256=2pnHW8HW7V2JZ8HIrUwE-UU4klpFQaOkoAvG5nWVfyY,7502
 keras_hub/src/models/flux/flux_model.py,sha256=K92PyeFHIp8SwXuxhv__XCEaQ2wqSW1jOb97I4S24Rw,8991
 keras_hub/src/models/flux/flux_presets.py,sha256=z7C_FbI1_F5YETXuWpc7Yh_0w-5N0eBQy6Oks_X9W88,54
-keras_hub/src/models/flux/flux_text_to_image.py,sha256=mI_QxOzjXl3b5s7Q1LZemceCdeboqPD5ilEPEEyer40,4169
+keras_hub/src/models/flux/flux_text_to_image.py,sha256=Rf5dD2EhG0bE8Gyg9sqaA8YEexS1kdraofIkxiZDjvc,4166
 keras_hub/src/models/flux/flux_text_to_image_preprocessor.py,sha256=Fs9jr97QtmRUbRRz1kITpkuhDM2GoV3n0XSFC-qQA14,2252
 keras_hub/src/models/gemma/__init__.py,sha256=rVzOJMJ39bgVlT8UdC0t8PlN2c237GKTBmfHIsbPuOQ,251
 keras_hub/src/models/gemma/gemma_attention.py,sha256=1CVN5z9GKoU8TuNMih2_MweDkpd98xSqdic9F8xIBE8,8317
@@ -257,7 +257,7 @@ keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py,sha256=AViEs6YltUqWnIVo7
 keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py,sha256=F57y0fZ0wYYxfGIjfrJc1W9uQpViYFx5bvFjj5CqUbI,4814
 keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py,sha256=24ABQ1vGlppV-KfWh0YqJjzM_Lu2GIwvyJ4X2XXie_A,5616
 keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py,sha256=5yM_jUtrFsWIieiwfFBoP7mtPmQAwywkeLKbd7fhmzk,371
-keras_hub/src/models/pali_gemma/pali_gemma_presets.py,sha256=O648iwzs0wooiQCfDQ-n0wOtzIOEDGXRSwSb_Brx2Ck,8985
+keras_hub/src/models/pali_gemma/pali_gemma_presets.py,sha256=Ka1ChUUSKw-yY2th3QtmNtkeXt0krYfwhkHrScioMls,8979
 keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py,sha256=ljTiADHo0Ok88q-jVzwJIle2C8xcxnudLTsBLzIySaM,2415
 keras_hub/src/models/pali_gemma/pali_gemma_vit.py,sha256=ViPKfGksbxBGJ3iS3M_KWxRc8Ie4LF7rWWUKDiqECJE,18285
 keras_hub/src/models/phi3/__init__.py,sha256=zIbf1MU-ks91mEkjTRJAsk51N3BBnXDF2JM1vO-13PQ,245
@@ -271,7 +271,7 @@ keras_hub/src/models/phi3/phi3_presets.py,sha256=sb2ce7Gq1OikFEf2KIYG69rFKHYKj8q
 keras_hub/src/models/phi3/phi3_rotary_embedding.py,sha256=wqiRn8nETNcLc5Vsm_d_8s11Ro6ibWZbWvODdLqIOo4,5013
 keras_hub/src/models/phi3/phi3_tokenizer.py,sha256=bOPH14wTVVHJHq8mgzXLjsgvKMNhfO8eayevAPpjYVA,1992
 keras_hub/src/models/resnet/__init__.py,sha256=C5UqlQ6apm8WSp1bnrxB6Bi3BGaknxRQs-r3b2wpaGA,257
-keras_hub/src/models/resnet/resnet_backbone.py,sha256=3acTjdWbnos8l_TPxYLgoV3Y4V_vJ_o1AqGhiQu459k,31274
+keras_hub/src/models/resnet/resnet_backbone.py,sha256=Q7nlqcTXZzjqd0e-DsjHC4ok58yOX7qxseotym3uZpM,31276
 keras_hub/src/models/resnet/resnet_image_classifier.py,sha256=nf35EKDzvBkfhHsK-s6Ks0nbhvKO7HEOYZm94YckyWE,510
 keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py,sha256=fM7gyQ0qB-RRuI4USJkRD6q9-HVfuC71e-BLTo-UhHQ,543
 keras_hub/src/models/resnet/resnet_image_converter.py,sha256=fgTxihJznGFss-y3Z-jp0JE3X1gaaB2y-f2KMwrT8Pk,342
@@ -279,7 +279,7 @@ keras_hub/src/models/resnet/resnet_presets.py,sha256=cryfXlC_FSEN_jrexKIh5aVbzp8
 keras_hub/src/models/retinanet/__init__.py,sha256=veWIFvMN6151M69l7FvTcI-IIEe_8dLmNO5NLOszQ1c,275
 keras_hub/src/models/retinanet/anchor_generator.py,sha256=0OgKSW3OKmbc0cOPHF6FYTAzn7fcHklg665PGSwAaDM,6504
 keras_hub/src/models/retinanet/box_matcher.py,sha256=l820r1R-ByqiyVgmZ0YFjjz0njchDda-wItzLn1X84o,10834
-keras_hub/src/models/retinanet/feature_pyramid.py,sha256=VxLcOEjJSXIDu30oMcZEYdVlpHaOP3IutZNwh0N3uHQ,17604
+keras_hub/src/models/retinanet/feature_pyramid.py,sha256=hbdrj6X-D2SlwOp2h1WcBlTdSAlLmFK43X7OrkJRoMA,17614
 keras_hub/src/models/retinanet/non_max_supression.py,sha256=PMOLlRw-EnyEmhlUhJjEbHf1xXiplN95pUxQbiJQbN4,20996
 keras_hub/src/models/retinanet/prediction_head.py,sha256=xWHt21-SS2t7vCmTONlR1lSbJXhml5jx68V8MGbGybg,7863
 keras_hub/src/models/retinanet/retinanet_backbone.py,sha256=BJBPJLxpOCOU0Br7b4JsgCZBHQHLAhxLqo9BHNIsl1g,5659
@@ -314,11 +314,11 @@ keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py,sha256=
 keras_hub/src/models/segformer/segformer_presets.py,sha256=ET39ospixkTaCsjoMLdJrr3wlGvTAQu5prleVC5lMZI,4793
 keras_hub/src/models/stable_diffusion_3/__init__.py,sha256=ZKYQuaRObyhKq8GVAHmoRvlXp6FpU8ChvutVCHyXKuc,343
 keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py,sha256=vtVhieAv277mAiZj7Kvvqg_Ba7klfQxZVk4PPxNNQ0s,3062
-keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=poJlz-xt06hgOtn_Bw5YQDxZtDBc9L4Vo0ahhGwPly4,33340
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=u0Wwtbl5b-1z_vn07TRw4jpkVYrReZeHbWqQIrZjyCA,23368
+keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=0gq2tcIqcbiGKKDDj3vrRsF67U3qE9g706XPs2BfCOY,40979
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=w8lsMampk34M9xQi96mEnXmkaKQqFQtoFTW8zP7ilEA,24078
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py,sha256=oQcVCWOwrdUTrr_JNekoMqdSlKYMGz5tG6v8uD25lTc,5479
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py,sha256=aZMIC-GYjLhdU_yM7fJEznApCo1zwRAgwQbW0tCW0xY,6399
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=z6wrfv8rCqLBzn7_edRcKCIDQRTNUgLqyr-LLp55-IE,1680
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py,sha256=t4uw920Jn1k80air3WRGimKf71aMVu6q73oWFH348vk,6384
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=x7Ez4L955MJE4ABtBy-63YpU9XpR0Ro8QWPzYYJs1yE,2167
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py,sha256=Yt-UIatVKANjjKFCFEj1rIHhOrt8hqefKKQJIAWcTLc,4567
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py,sha256=m5PdVSgTcYuqd7jOQ8wD4PAnMa7wY2WdhwpK3hdydhM,2756
 keras_hub/src/models/stable_diffusion_3/t5_encoder.py,sha256=oV7P1uwCKdGiD93zXq7kmqX0elMZQU4UvBa8wg6P1hs,5113
@@ -344,7 +344,7 @@ keras_hub/src/models/vit/vit_backbone.py,sha256=kGmRZO4u-1q4PBcbhJbiWVIEVYAcp2H4
 keras_hub/src/models/vit/vit_image_classifier.py,sha256=lMVxiD1_6drx7XQ7P7YzlqnFP7kT1zlMe84f-T3SDQI,6332
 keras_hub/src/models/vit/vit_image_classifier_preprocessor.py,sha256=wu6YcBlXMWB9sKCPvmNdGBZKTLQt_HyHWS6P9nyDwsk,504
 keras_hub/src/models/vit/vit_image_converter.py,sha256=5xVF04BzMcdTDc6aErAYj3_BuGmVd3zoJMcH1ho4T0g,2561
-keras_hub/src/models/vit/vit_layers.py,sha256=s4j3n3qnJnv6W9AdUkNsO3Vsi_BhxEGECYkaLVCU6XY,13238
+keras_hub/src/models/vit/vit_layers.py,sha256=Zsz-ARPY49S1nXLUtpFwtPfw31D-vCtKesEo_2JIKPA,13240
 keras_hub/src/models/vit/vit_presets.py,sha256=zZhxUleOom1ie3gn0Mi-_xhhdFEEsnqSQyKADV2L38k,4479
 keras_hub/src/models/vit_det/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/models/vit_det/vit_det_backbone.py,sha256=DOZ5J7c1t5PAZ6y0pMmBoQTMOUup7UoUrYVfCs69ltY,7697
@@ -383,13 +383,13 @@ keras_hub/src/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSu
 keras_hub/src/tests/test_case.py,sha256=oGWoUhlKgjVMNIjvUVnQR-k5iKvodztHsFMOs669Trw,27402
 keras_hub/src/tokenizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/tokenizers/byte_pair_tokenizer.py,sha256=WeUlHMAf5y_MUjFIfVhEcFoOZu-z4kkSj-Dq-pegM9w,24052
-keras_hub/src/tokenizers/byte_tokenizer.py,sha256=c1a41eVuLzGmBtscQ0RxPIqFi41m_604KJ9fdpPR7Sc,10437
+keras_hub/src/tokenizers/byte_tokenizer.py,sha256=GPIKaddXugbfckfhodADsBpaYb72DgFMs_xfXHnK4qU,10418
 keras_hub/src/tokenizers/sentence_piece_tokenizer.py,sha256=nOqkpa2nHitITpdowPHdwxiN87e8huLW8Dt2gozVnhI,9350
-keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py,sha256=LhUxwcaDKt5V58DBzK9Sh4D-hOL80SHGpL4LavWbq74,4642
+keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py,sha256=caqgV9N4lH97zBviFPdpwo_O95AaJBEJLQv6Icq3Hs8,4774
 keras_hub/src/tokenizers/tokenizer.py,sha256=v0Ka5ayrBwpsGBlkIadXK-b4RsMTbhV6BZrvKullbxY,9722
-keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py,sha256=KxuVsUx3ntGsuqaQ-gnFWFfoVLsl5Hag7rBk6xfq-fQ,13572
+keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py,sha256=hRv_XxoPIPDpHfO0ZttSOv_M89sMaFpvmllojvKz_ac,13553
 keras_hub/src/tokenizers/word_piece_tokenizer.py,sha256=vP6AZgbzsRiuPCt3W_n94nsF7XiERnagWcH_rqJHtVU,19943
-keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py,sha256=Zz1SGgArykxBVWnS5YV-ViqyMOrw3j3i_i_jto96zCg,6610
+keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py,sha256=cylrs02ZrYQ1TuZr9oyS3NrVbDwGctA3VXbIh1pFJMQ,6743
 keras_hub/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/utils/keras_utils.py,sha256=0yKIfFuO_IqAH8vHbG3ncRmCVKg__xRGfQtLYWZ8YuA,1695
 keras_hub/src/utils/pipeline_model.py,sha256=jgzB6NQPSl0KOu08N-TazfOnXnUJbZjH2EXXhx25Ftg,9084
@@ -399,9 +399,9 @@ keras_hub/src/utils/tensor_utils.py,sha256=YVJesN91bk-OzJXY1mOKBppuY8noBU7zhPQNX
 keras_hub/src/utils/imagenet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/utils/imagenet/imagenet_utils.py,sha256=MvIvv1WJo51ZXBxy4S7t_DsN3ZMtJWlC4cmRvKM2kIA,39304
 keras_hub/src/utils/timm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-keras_hub/src/utils/timm/convert_densenet.py,sha256=V-GRjWuDnlh3b1EMxqahwZ3GMwSgOa3v0HOfb2ZZ-d0,3342
-keras_hub/src/utils/timm/convert_efficientnet.py,sha256=wkOKTLS_N_VKy1CQQGjSlD_TPSOOmCMMXQvbjravN6g,17098
-keras_hub/src/utils/timm/convert_resnet.py,sha256=ee8eTml0ffJKE8avzGoLFcpjPF63DsvoIUArAGa8Ngg,5832
+keras_hub/src/utils/timm/convert_densenet.py,sha256=fu8HBIQis5o3ib2tyI2qnmYScVrVIQySok8vTfa1qJ8,3393
+keras_hub/src/utils/timm/convert_efficientnet.py,sha256=SgEIlyyinS04qoQpEgh3WazHq544zNUCCpfmWh3EjSs,17100
+keras_hub/src/utils/timm/convert_resnet.py,sha256=8JFkVtdpy5z9h83LJ97rD-a8FRejXPZvMNksNuStqjM,5834
 keras_hub/src/utils/timm/convert_vgg.py,sha256=MT5jGnLrzenPpe66Af_Lp1IdR9KGtsSrcmn6_UPqHvQ,2419
 keras_hub/src/utils/timm/preset_loader.py,sha256=cdZDjthZdTD2myMOenQar4ACyi7VTuIzNRg24LuqS-4,3374
 keras_hub/src/utils/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -417,7 +417,7 @@ keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYum
 keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
 keras_hub/src/utils/transformers/preset_loader.py,sha256=DgGJXbTSB9Na8FIR-YWWVqQPOFxHwWrGm41EwcS_EFs,3797
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
-keras_hub_nightly-0.19.0.dev202501080345.dist-info/METADATA,sha256=WWsYYpkd-P_ryoA3jId3bNDKaMQOJfy7eBeYQ7N_D6w,7260
-keras_hub_nightly-0.19.0.dev202501080345.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
-keras_hub_nightly-0.19.0.dev202501080345.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
-keras_hub_nightly-0.19.0.dev202501080345.dist-info/RECORD,,
+keras_hub_nightly-0.19.0.dev202501150344.dist-info/METADATA,sha256=FhbHeGMBpOfmdE1bEoJdl34xtvV3n85LqFH_5STUyUo,7498
+keras_hub_nightly-0.19.0.dev202501150344.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+keras_hub_nightly-0.19.0.dev202501150344.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.19.0.dev202501150344.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.7.0)
+Generator: setuptools (75.8.0)
 Root-Is-Purelib: true
 Tag: py3-none-any