keras-hub 0.21.1.dev0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. keras_hub/layers/__init__.py +9 -0
  2. keras_hub/models/__init__.py +47 -0
  3. keras_hub/src/layers/modeling/transformer_encoder.py +6 -3
  4. keras_hub/src/layers/preprocessing/multi_segment_packer.py +17 -3
  5. keras_hub/src/layers/preprocessing/start_end_packer.py +24 -6
  6. keras_hub/src/models/backbone.py +13 -10
  7. keras_hub/src/models/clip/clip_backbone.py +3 -102
  8. keras_hub/src/models/clip/clip_layers.py +295 -0
  9. keras_hub/src/models/clip/clip_preprocessor.py +57 -48
  10. keras_hub/src/models/clip/clip_text_encoder.py +2 -2
  11. keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
  12. keras_hub/src/models/deit/__init__.py +5 -0
  13. keras_hub/src/models/deit/deit_backbone.py +154 -0
  14. keras_hub/src/models/deit/deit_image_classifier.py +171 -0
  15. keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
  16. keras_hub/src/models/deit/deit_image_converter.py +8 -0
  17. keras_hub/src/models/deit/deit_layers.py +519 -0
  18. keras_hub/src/models/deit/deit_presets.py +49 -0
  19. keras_hub/src/models/dinov2/__init__.py +5 -0
  20. keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
  21. keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
  22. keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
  23. keras_hub/src/models/dinov2/dinov2_presets.py +89 -0
  24. keras_hub/src/models/esm/__init__.py +5 -0
  25. keras_hub/src/models/esm/esm_attention.py +95 -0
  26. keras_hub/src/models/esm/esm_backbone.py +229 -0
  27. keras_hub/src/models/esm/esm_classifier.py +184 -0
  28. keras_hub/src/models/esm/esm_classifier_preprocessor.py +135 -0
  29. keras_hub/src/models/esm/esm_encoder.py +134 -0
  30. keras_hub/src/models/esm/esm_masked_plm.py +117 -0
  31. keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +143 -0
  32. keras_hub/src/models/esm/esm_presets.py +53 -0
  33. keras_hub/src/models/esm/esm_tokenizer.py +82 -0
  34. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
  35. keras_hub/src/models/gemma/gemma_attention.py +1 -1
  36. keras_hub/src/models/gemma3/gemma3_backbone.py +2 -2
  37. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +1 -1
  38. keras_hub/src/models/gemma3/gemma3_presets.py +25 -0
  39. keras_hub/src/models/hgnetv2/__init__.py +5 -0
  40. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +193 -0
  41. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +148 -0
  42. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +216 -0
  43. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py +14 -0
  44. keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +8 -0
  45. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +918 -0
  46. keras_hub/src/models/hgnetv2/hgnetv2_presets.py +58 -0
  47. keras_hub/src/models/llama3/llama3_presets.py +3 -3
  48. keras_hub/src/models/mistral/mistral_presets.py +17 -1
  49. keras_hub/src/models/mixtral/mixtral_presets.py +2 -2
  50. keras_hub/src/models/mobilenet/mobilenet_presets.py +4 -4
  51. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +2 -2
  52. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +2 -2
  53. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +17 -17
  54. keras_hub/src/models/qwen3/__init__.py +5 -0
  55. keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
  56. keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
  57. keras_hub/src/models/qwen3/qwen3_causal_lm.py +390 -0
  58. keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
  59. keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
  60. keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
  61. keras_hub/src/models/qwen3/qwen3_presets.py +73 -0
  62. keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +1 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
  65. keras_hub/src/models/roformer_v2/roformer_v2_attention.py +0 -2
  66. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
  67. keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
  68. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +31 -32
  69. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
  71. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
  72. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
  73. keras_hub/src/models/vit/vit_backbone.py +31 -11
  74. keras_hub/src/models/vit/vit_image_converter.py +0 -70
  75. keras_hub/src/models/vit/vit_layers.py +33 -18
  76. keras_hub/src/models/vit/vit_presets.py +11 -11
  77. keras_hub/src/utils/keras_utils.py +17 -0
  78. keras_hub/src/utils/preset_utils.py +19 -4
  79. keras_hub/src/utils/tensor_utils.py +14 -0
  80. keras_hub/src/utils/transformers/convert_deit.py +155 -0
  81. keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
  82. keras_hub/src/utils/transformers/convert_esm.py +159 -0
  83. keras_hub/src/utils/transformers/convert_llama3.py +6 -0
  84. keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
  85. keras_hub/src/utils/transformers/export/gemma.py +89 -0
  86. keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
  87. keras_hub/src/utils/transformers/preset_loader.py +14 -2
  88. keras_hub/src/version.py +1 -1
  89. keras_hub/tokenizers/__init__.py +1 -0
  90. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/METADATA +4 -4
  91. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/RECORD +93 -49
  92. keras_hub/src/models/clip/clip_encoder_block.py +0 -111
  93. keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
  94. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/WHEEL +0 -0
  95. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/top_level.txt +0 -0
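
This release adds new model families (DeiT, DINOV2, ESM, HGNetV2, Qwen3), new Gemma3/Mistral presets, Hugging Face conversion and export utilities, and reworked Stable Diffusion 3 and ViT internals. A minimal loading sketch; the local path and Hugging Face handle below are assumptions for illustration, not values taken from this diff:

```python
# Illustrative only: paths and handles are assumed, not part of this diff.
import keras_hub

# Generic preset loading (existing keras-hub API) covers the newly added
# architectures (DeiT, DINOV2, ESM, HGNetV2, Qwen3) as well.
backbone = keras_hub.models.Backbone.from_preset("./path/to/preset_dir")

# The new convert_* utilities come into play when loading Transformers
# checkpoints, e.g. via an "hf://" handle (assumed handle shown):
# causal_lm = keras_hub.models.CausalLM.from_preset("hf://Qwen/Qwen3-0.6B")
```

The hunks below reproduce a subset of the changed files.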
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -1,4 +1,6 @@
  import keras
+ from keras import backend
+ from keras import distribution
  from keras import layers
  from keras import ops

@@ -96,26 +98,10 @@ class LatentRescaling(layers.Rescaling):
          return (self.backend.cast(inputs, dtype) / scale) + offset


- class ClassifierFreeGuidanceConcatenate(layers.Layer):
-     def call(
-         self,
-         latents,
-         positive_contexts,
-         negative_contexts,
-         positive_pooled_projections,
-         negative_pooled_projections,
-         timestep,
-     ):
+ class TimestepBroadcastTo(layers.Layer):
+     def call(self, latents, timestep):
          timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-         latents = ops.concatenate([latents, latents], axis=0)
-         contexts = ops.concatenate(
-             [positive_contexts, negative_contexts], axis=0
-         )
-         pooled_projections = ops.concatenate(
-             [positive_pooled_projections, negative_pooled_projections], axis=0
-         )
-         timesteps = ops.concatenate([timestep, timestep], axis=0)
-         return latents, contexts, pooled_projections, timesteps
+         return timestep


  class ClassifierFreeGuidance(layers.Layer):
@@ -330,8 +316,8 @@ class StableDiffusion3Backbone(Backbone):
              name="diffuser",
          )
          self.vae = vae
-         self.cfg_concat = ClassifierFreeGuidanceConcatenate(
-             dtype=dtype, name="classifier_free_guidance_concat"
+         self.timestep_broadcast_to = TimestepBroadcastTo(
+             dtype=dtype, name="timestep_broadcast_to"
          )
          self.cfg = ClassifierFreeGuidance(
              dtype=dtype, name="classifier_free_guidance"
@@ -538,6 +524,9 @@ class StableDiffusion3Backbone(Backbone):
          latents = self.vae.encode(images)
          return self.image_rescaling(latents)

+     def configure_scheduler(self, num_steps):
+         self.scheduler.set_sigmas(num_steps)
+
      def add_noise_step(self, latents, noises, step, num_steps):
          return self.scheduler.add_noise(latents, noises, step, num_steps)

@@ -562,11 +551,15 @@ class StableDiffusion3Backbone(Backbone):

          # Concatenation for classifier-free guidance.
          if guidance_scale is not None:
-             concated_latents, contexts, pooled_projs, timesteps = (
-                 self.cfg_concat(latents, *embeddings, timestep)
+             timestep = self.timestep_broadcast_to(latents, timestep)
+             timesteps = ops.concatenate([timestep, timestep], axis=0)
+             concated_latents = ops.concatenate([latents, latents], axis=0)
+             contexts = ops.concatenate([embeddings[0], embeddings[1]], axis=0)
+             pooled_projs = ops.concatenate(
+                 [embeddings[2], embeddings[3]], axis=0
              )
          else:
-             timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+             timesteps = self.timestep_broadcast_to(latents, timestep)
              concated_latents = latents
              contexts = embeddings[0]
              pooled_projs = embeddings[2]
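
The batch doubling previously hidden inside `ClassifierFreeGuidanceConcatenate` is now written out with `keras.ops` at the call site. A minimal sketch of the pattern, where `positive`/`negative` stand in for the conditional and unconditional embeddings (names are illustrative, not from the diff):

```python
from keras import ops

def cfg_batch(latents, positive, negative, timestep):
    # Broadcast the scalar timestep to the batch dimension of `latents`.
    timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
    # Double the batch so a single diffuser call covers both CFG branches.
    latents = ops.concatenate([latents, latents], axis=0)
    contexts = ops.concatenate([positive, negative], axis=0)
    timesteps = ops.concatenate([timestep, timestep], axis=0)
    return latents, contexts, timesteps
```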
@@ -623,20 +616,26 @@ class StableDiffusion3Backbone(Backbone):
      def from_config(cls, config, custom_objects=None):
          config = config.copy()

-         # Propagate `dtype` to text encoders if needed.
+         # Propagate `dtype` to the VAE if needed.
          if "dtype" in config and config["dtype"] is not None:
              dtype_config = config["dtype"]
              if "dtype" not in config["vae"]["config"]:
                  config["vae"]["config"]["dtype"] = dtype_config
-             if "dtype" not in config["clip_l"]["config"]:
-                 config["clip_l"]["config"]["dtype"] = dtype_config
-             if "dtype" not in config["clip_g"]["config"]:
-                 config["clip_g"]["config"]["dtype"] = dtype_config
+
+         # Text encoders default to float16 dtype if not specified.
+         # TODO: JAX CPU doesn't support float16 in `nn.dot_product_attention`.
+         is_jax_cpu = (
+             backend.backend() == "jax"
+             and "cpu" in distribution.list_devices()[0].lower()
+         )
+         for text_encoder in ("clip_l", "clip_g", "t5"):
              if (
-                 config["t5"] is not None
-                 and "dtype" not in config["t5"]["config"]
+                 text_encoder in config
+                 and config[text_encoder] is not None
+                 and "dtype" not in config[text_encoder]["config"]
+                 and not is_jax_cpu
              ):
-                 config["t5"]["config"]["dtype"] = dtype_config
+                 config[text_encoder]["config"]["dtype"] = "float16"

          # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
          config["vae"] = layers.deserialize(

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
@@ -169,6 +169,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
          guidance_scale=7.0,
          seed=None,
      ):
+         self.backbone.configure_scheduler(num_steps)
          return super().generate(
              inputs,
              num_steps=num_steps,

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -184,6 +184,7 @@ class StableDiffusion3Inpaint(Inpaint):
          guidance_scale=7.0,
          seed=None,
      ):
+         self.backbone.configure_scheduler(num_steps)
          return super().generate(
              inputs,
              num_steps=num_steps,

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -141,6 +141,7 @@ class StableDiffusion3TextToImage(TextToImage):
          guidance_scale=7.0,
          seed=None,
      ):
+         self.backbone.configure_scheduler(num_steps)
          return super().generate(
              inputs,
              num_steps=num_steps,
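
All three Stable Diffusion 3 tasks now call `configure_scheduler(num_steps)` on the backbone before delegating to `super().generate(...)`, so the scheduler sigmas match the requested step count. A hedged usage sketch; the preset handle is an assumption, while the `generate()` arguments come from the hunks above:

```python
import keras_hub

text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium"  # assumed preset handle
)
# generate() now configures the scheduler for `num_steps` before sampling.
image = text_to_image.generate(
    "a photograph of an astronaut riding a horse",
    num_steps=28,
    guidance_scale=7.0,
)
```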

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
@@ -50,8 +50,12 @@ class StableDiffusion3TextToImagePreprocessor(TextToImagePreprocessor):

      def generate_preprocess(self, x):
          token_ids = {}
-         token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
-         token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"]
+         token_ids["clip_l"] = self.clip_l_preprocessor(
+             {"prompts": x, "images": None}
+         )["token_ids"]
+         token_ids["clip_g"] = self.clip_g_preprocessor(
+             {"prompts": x, "images": None}
+         )["token_ids"]
          if self.t5_preprocessor is not None:
              token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
          return token_ids
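
The CLIP preprocessors are now invoked with a dict of prompts and (optional) images rather than a bare prompt batch, matching the reworked `clip_preprocessor.py` in the file list. A sketch of the new call convention; `clip_l_preprocessor` stands in for a constructed CLIP preprocessor instance:

```python
# Illustrative call; the variable name is assumed.
token_ids = clip_l_preprocessor(
    {"prompts": ["a photo of a cat"], "images": None}
)["token_ids"]
```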

keras_hub/src/models/vit/vit_backbone.py
@@ -18,10 +18,10 @@ class ViTBackbone(Backbone):

      Args:
          image_shape: A tuple or list of 3 integers representing the shape of the
-             input image `(height, width, channels)`, `height` and `width` must
-             be equal.
-         patch_size: int. The size of each image patch, the input image will be
-             divided into patches of shape `(patch_size, patch_size)`.
+             input image `(height, width, channels)`.
+         patch_size: int or (int, int). The size of each image patch, the input
+             image will be divided into patches of shape
+             `(patch_size_h, patch_size_w)`.
          num_layers: int. The number of transformer encoder layers.
          num_heads: int. specifying the number of attention heads in each
              Transformer encoder layer.
@@ -37,6 +37,10 @@ class ViTBackbone(Backbone):
          use_mha_bias: bool. Whether to use bias in the multi-head
              attention layers.
          use_mlp_bias: bool. Whether to use bias in the MLP layers.
+         use_class_token: bool. Whether to use class token to be part of
+             patch embedding. Defaults to `True`.
+         use_patch_bias: bool. Whether to use bias in Conv2d of patch embedding
+             layer. Defaults to `True`.
          data_format: str. `"channels_last"` or `"channels_first"`, specifying
              the data format for the input image. If `None`, defaults to
              `"channels_last"`.
@@ -58,6 +62,8 @@ class ViTBackbone(Backbone):
          layer_norm_epsilon=1e-6,
          use_mha_bias=True,
          use_mlp_bias=True,
+         use_class_token=True,
+         use_patch_bias=True,
          data_format=None,
          dtype=None,
          **kwargs,
@@ -74,24 +80,34 @@ class ViTBackbone(Backbone):
                  f"at index {h_axis} (height) or {w_axis} (width). "
                  f"Image shape: {image_shape}"
              )
-         if image_shape[h_axis] != image_shape[w_axis]:
+
+         if isinstance(patch_size, int):
+             patch_size = (patch_size, patch_size)
+
+         if image_shape[h_axis] % patch_size[0] != 0:
+             raise ValueError(
+                 f"Input height {image_shape[h_axis]} should be divisible by "
+                 f"patch size {patch_size[0]}."
+             )
+
+         if image_shape[w_axis] % patch_size[1] != 0:
              raise ValueError(
-                 f"Image height and width must be equal. Found height: "
-                 f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
-                 f"indices {h_axis} and {w_axis} respectively. Image shape: "
-                 f"{image_shape}"
+                 f"Input width {image_shape[h_axis]} should be divisible by "
+                 f"patch size {patch_size[1]}."
              )

          num_channels = image_shape[channels_axis]

          # === Functional Model ===
-         inputs = keras.layers.Input(shape=image_shape)
+         inputs = keras.layers.Input(shape=image_shape, name="images")

          x = ViTPatchingAndEmbedding(
-             image_size=image_shape[h_axis],
+             image_size=(image_shape[h_axis], image_shape[w_axis]),
              patch_size=patch_size,
              hidden_dim=hidden_dim,
              num_channels=num_channels,
+             use_class_token=use_class_token,
+             use_patch_bias=use_patch_bias,
              data_format=data_format,
              dtype=dtype,
              name="vit_patching_and_embedding",
@@ -130,6 +146,8 @@ class ViTBackbone(Backbone):
          self.layer_norm_epsilon = layer_norm_epsilon
          self.use_mha_bias = use_mha_bias
          self.use_mlp_bias = use_mlp_bias
+         self.use_class_token = use_class_token
+         self.use_patch_bias = use_patch_bias
          self.data_format = data_format

      def get_config(self):
@@ -147,6 +165,8 @@ class ViTBackbone(Backbone):
                  "layer_norm_epsilon": self.layer_norm_epsilon,
                  "use_mha_bias": self.use_mha_bias,
                  "use_mlp_bias": self.use_mlp_bias,
+                 "use_class_token": self.use_class_token,
+                 "use_patch_bias": self.use_patch_bias,
              }
          )
          return config
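
With these changes `ViTBackbone` accepts non-square inputs and rectangular patches, as long as height and width are divisible by the corresponding patch dimension, and exposes the `use_class_token`/`use_patch_bias` flags. A construction sketch using the arguments visible in the hunks above; the hyperparameter values are illustrative, not a published configuration:

```python
import keras_hub

backbone = keras_hub.models.ViTBackbone(
    image_shape=(256, 128, 3),  # height and width no longer need to match
    patch_size=(16, 8),         # int or (patch_h, patch_w)
    num_layers=4,
    num_heads=4,
    hidden_dim=64,
    mlp_dim=128,
    use_class_token=True,
    use_patch_bias=True,
)
```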

keras_hub/src/models/vit/vit_image_converter.py
@@ -1,78 +1,8 @@
  from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
  from keras_hub.src.models.vit.vit_backbone import ViTBackbone
- from keras_hub.src.utils.tensor_utils import preprocessing_function


  @keras_hub_export("keras_hub.layers.ViTImageConverter")
  class ViTImageConverter(ImageConverter):
-     """Converts images to the format expected by a ViT model.
-
-     This layer performs image normalization using mean and standard deviation
-     values. By default, it uses the same normalization as the
-     "google/vit-large-patch16-224" model on Hugging Face:
-     `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
-     ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
-     These defaults are suitable for models pretrained using this normalization.
-
-     Args:
-         norm_mean: list or tuple of floats. Mean values for image normalization.
-             Defaults to `[0.5, 0.5, 0.5]`.
-         norm_std: list or tuple of floats. Standard deviation values for
-             image normalization. Defaults to `[0.5, 0.5, 0.5]`.
-         **kwargs: Additional keyword arguments passed to
-             `keras_hub.layers.preprocessing.ImageConverter`.
-
-     Examples:
-     ```python
-     import keras
-     import numpy as np
-     from keras_hub.src.layers import ViTImageConverter
-
-     # Example image (replace with your actual image data)
-     image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)
-
-     # Create a ViTImageConverter instance
-     converter = ViTImageConverter(
-         image_size=(28,28),
-         scale=1/255.
-     )
-     # Preprocess the image
-     preprocessed_image = converter(image)
-     ```
-     """
-
      backbone_cls = ViTBackbone
-
-     def __init__(
-         self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
-     ):
-         super().__init__(**kwargs)
-         self.norm_mean = norm_mean
-         self.norm_std = norm_std
-
-     @preprocessing_function
-     def call(self, inputs):
-         # TODO: Remove this whole function. Why can just use scale and offset
-         # in the base class.
-         x = super().call(inputs)
-         if self.norm_mean:
-             norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
-             x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
-             x = x - norm_mean
-         if self.norm_std:
-             norm_std = self._expand_non_channel_dims(self.norm_std, x)
-             x, norm_std = self._convert_types(x, norm_std, x.dtype)
-             x = x / norm_std
-
-         return x
-
-     def get_config(self):
-         config = super().get_config()
-         config.update(
-             {
-                 "norm_mean": self.norm_mean,
-                 "norm_std": self.norm_std,
-             }
-         )
-         return config
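
The custom `norm_mean`/`norm_std` handling is gone; as the removed TODO notes, the base `ImageConverter` scale and offset arguments cover the same normalization. A sketch of an equivalent configuration, assuming `[0, 255]` inputs normalized with mean = std = 0.5, so that `(x / 255 - 0.5) / 0.5 == x * (1 / 127.5) - 1.0`:

```python
import keras_hub

converter = keras_hub.layers.ViTImageConverter(
    image_size=(224, 224),
    scale=1 / 127.5,  # replaces division by 255 followed by / norm_std
    offset=-1.0,      # replaces subtraction of norm_mean
)
```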

keras_hub/src/models/vit/vit_layers.py
@@ -75,12 +75,13 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
      """Patches the image and embeds the patches.

      Args:
-         image_size: int. Size of the input image (height or width).
-             Assumed to be square.
-         patch_size: int. Size of each image patch.
+         image_size: (int, int). Size of the input image.
+         patch_size: (int, int). Size of each image patch.
          hidden_dim: int. Dimensionality of the patch embeddings.
          num_channels: int. Number of channels in the input image. Defaults to
              `3`.
+         use_class_token: bool. Whether to use class token to be part of
+             patch embedding. Defaults to `True`.
          data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
              `None` (which uses `"channels_last"`).
          **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
@@ -92,12 +93,15 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
          patch_size,
          hidden_dim,
          num_channels=3,
+         use_class_token=True,
+         use_patch_bias=True,
          data_format=None,
          **kwargs,
      ):
          super().__init__(**kwargs)
-         num_patches = (image_size // patch_size) ** 2
-         num_positions = num_patches + 1
+         grid_size = tuple([s // p for s, p in zip(image_size, patch_size)])
+         num_patches = grid_size[0] * grid_size[1]
+         num_positions = num_patches + 1 if use_class_token else num_patches

          # === Config ===
          self.image_size = image_size
@@ -106,19 +110,22 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
          self.num_channels = num_channels
          self.num_patches = num_patches
          self.num_positions = num_positions
+         self.use_class_token = use_class_token
+         self.use_patch_bias = use_patch_bias
          self.data_format = standardize_data_format(data_format)

      def build(self, input_shape):
-         self.class_token = self.add_weight(
-             shape=(
-                 1,
-                 1,
-                 self.hidden_dim,
-             ),
-             initializer="random_normal",
-             dtype=self.variable_dtype,
-             name="class_token",
-         )
+         if self.use_class_token:
+             self.class_token = self.add_weight(
+                 shape=(
+                     1,
+                     1,
+                     self.hidden_dim,
+                 ),
+                 initializer="random_normal",
+                 dtype=self.variable_dtype,
+                 name="class_token",
+             )
          self.patch_embedding = keras.layers.Conv2D(
              filters=self.hidden_dim,
              kernel_size=self.patch_size,
@@ -127,6 +134,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
              activation=None,
              dtype=self.dtype_policy,
              data_format=self.data_format,
+             use_bias=self.use_patch_bias,
              name="patch_embedding",
          )
          self.patch_embedding.build(input_shape)
@@ -153,10 +161,16 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
          patch_embeddings = ops.reshape(
              patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
          )
-         class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
          position_embeddings = self.position_embedding(self.position_ids)
-         embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
-         return ops.add(embeddings, position_embeddings)
+
+         if self.use_class_token:
+             class_token = ops.tile(
+                 self.class_token, (embeddings_shape[0], 1, 1)
+             )
+             patch_embeddings = ops.concatenate(
+                 [class_token, patch_embeddings], axis=1
+             )
+         return ops.add(patch_embeddings, position_embeddings)

      def compute_output_shape(self, input_shape):
          return (
@@ -175,6 +189,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
                  "num_channels": self.num_channels,
                  "num_patches": self.num_patches,
                  "num_positions": self.num_positions,
+                 "use_class_token": self.use_class_token,
              }
          )
          return config
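
A worked example of the new patch and position bookkeeping shown above; the sizes are chosen for illustration:

```python
image_size = (224, 160)
patch_size = (16, 16)
grid_size = tuple(s // p for s, p in zip(image_size, patch_size))  # (14, 10)
num_patches = grid_size[0] * grid_size[1]                          # 140
num_positions = num_patches + 1                                    # 141 with the class token
```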

keras_hub/src/models/vit/vit_presets.py
@@ -11,7 +11,7 @@ backbone_presets = {
              "params": 85798656,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/2",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/3",
      },
      "vit_base_patch16_384_imagenet": {
          "metadata": {
@@ -22,7 +22,7 @@ backbone_presets = {
              "params": 86090496,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/2",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/3",
      },
      "vit_large_patch16_224_imagenet": {
          "metadata": {
@@ -33,7 +33,7 @@ backbone_presets = {
              "params": 303301632,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/2",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/3",
      },
      "vit_large_patch16_384_imagenet": {
          "metadata": {
@@ -44,7 +44,7 @@ backbone_presets = {
              "params": 303690752,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/2",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/3",
      },
      "vit_base_patch32_384_imagenet": {
          "metadata": {
@@ -55,7 +55,7 @@ backbone_presets = {
              "params": 87528192,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/1",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/2",
      },
      "vit_large_patch32_384_imagenet": {
          "metadata": {
@@ -66,7 +66,7 @@ backbone_presets = {
              "params": 305607680,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/1",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/2",
      },
      "vit_base_patch16_224_imagenet21k": {
          "metadata": {
@@ -77,7 +77,7 @@ backbone_presets = {
              "params": 85798656,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/1",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/2",
      },
      "vit_base_patch32_224_imagenet21k": {
          "metadata": {
@@ -88,7 +88,7 @@ backbone_presets = {
              "params": 87455232,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/1",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/2",
      },
      "vit_huge_patch14_224_imagenet21k": {
          "metadata": {
@@ -99,7 +99,7 @@ backbone_presets = {
              "params": 630764800,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/1",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/2",
      },
      "vit_large_patch16_224_imagenet21k": {
          "metadata": {
@@ -110,7 +110,7 @@ backbone_presets = {
              "params": 303301632,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/1",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/2",
      },
      "vit_large_patch32_224_imagenet21k": {
          "metadata": {
@@ -121,6 +121,6 @@ backbone_presets = {
              "params": 305510400,
              "path": "vit",
          },
-         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/1",
+         "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/2",
      },
  }

keras_hub/src/utils/keras_utils.py
@@ -71,6 +71,23 @@ def fused_attention_op_available():
              )
              return False
          return True
+     elif (
+         hasattr(keras.config, "is_flash_attention_enabled")
+         and keras.config.backend() == "torch"
+     ):
+         try:
+             from torch.backends.cuda import SDPAParams as SDPAParams
+             from torch.backends.cuda import (
+                 can_use_flash_attention as can_use_flash_attention,
+             )
+         except ImportError:
+             logging.warning(
+                 "Flash attention is not supported in your current PyTorch "
+                 "version. Please update it by following the official guide: "
+                 "https://pytorch.org/get-started/locally/"
+             )
+             return False
+         return True
      else:
          return False
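
`fused_attention_op_available()` now also returns `True` on the PyTorch backend when the flash-attention SDPA helpers import cleanly. A hedged sketch of how a caller might gate on it; the fallback math below is illustrative, not code from this package:

```python
from keras import ops
from keras_hub.src.utils.keras_utils import fused_attention_op_available

def attend(query, key, value):
    # query/key/value: (batch, seq, num_heads, head_dim)
    if fused_attention_op_available():
        # Fused kernel (flash attention on JAX or PyTorch when supported).
        return ops.dot_product_attention(query, key, value)
    # Unfused reference path.
    scale = 1.0 / (query.shape[-1] ** 0.5)
    scores = ops.einsum("bqhd,bkhd->bhqk", query, key)
    probs = ops.softmax(scores * scale, axis=-1)
    return ops.einsum("bhqk,bkhd->bqhd", probs, value)
```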
 
keras_hub/src/utils/preset_utils.py
@@ -1,5 +1,6 @@
  import collections
  import datetime
+ import glob
  import inspect
  import json
  import os
@@ -317,7 +318,8 @@ def _validate_backbone(preset):
          )

      weights_path = os.path.join(preset, MODEL_WEIGHTS_FILE)
-     if not os.path.exists(weights_path):
+     sharded_weights_path = os.path.join(preset, "model_*.weights.h5")
+     if not os.path.exists(weights_path) and not glob.glob(sharded_weights_path):
          raise FileNotFoundError(
              f"The weights file is missing from the preset directory `{preset}`."
          )
@@ -647,7 +649,10 @@ class KerasPresetLoader(PresetLoader):
          return check_config_class(self.config)

      def load_backbone(self, cls, load_weights, **kwargs):
-         backbone = self._load_serialized_object(self.config, **kwargs)
+         config = self.config.copy()
+         backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
+         config["config"] = {**config["config"], **backbone_kwargs}
+         backbone = self._load_serialized_object(config, **kwargs)
          if load_weights:
              jax_memory_cleanup(backbone)
              self._load_backbone_weights(backbone)
@@ -732,7 +737,13 @@ class KerasPresetLoader(PresetLoader):
          with open(config_path, encoding="utf-8") as config_file:
              config = json.load(config_file)
          weight_map = config["weight_map"]
-         return sorted(set(weight_map.values()))
+         filenames = set()
+         for v in weight_map.values():
+             if isinstance(v, list):
+                 filenames.update(v)
+             else:
+                 filenames.add(v)
+         return sorted(filenames)

      def _load_backbone_weights(self, backbone):
          # Detect if the backbone is sharded or not.
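
`_get_sharded_filenames` now tolerates `weight_map` entries that are either a single shard filename or a list of filenames. An assumed, purely illustrative config shape and the resulting filename set:

```python
# Keys and filenames below are made up for the example.
sharded_config = {
    "weight_map": {
        "token_embedding": "model_00000.weights.h5",
        # A variable spread across several shards maps to a list of files.
        "transformer_layer_0": ["model_00000.weights.h5", "model_00001.weights.h5"],
    }
}

filenames = set()
for v in sharded_config["weight_map"].values():
    if isinstance(v, list):
        filenames.update(v)
    else:
        filenames.add(v)
print(sorted(filenames))  # ['model_00000.weights.h5', 'model_00001.weights.h5']
```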
@@ -772,7 +783,11 @@ class KerasPresetSaver:
          backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
          # If the size of the backbone is larger than `max_shard_size`, save
          # sharded weights.
-         if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
+         if (
+             sharded_weights_available()
+             and max_shard_size is not None
+             and backbone_size_in_gb > max_shard_size
+         ):
              backbone_sharded_weights_config_path = os.path.join(
                  self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
              )

keras_hub/src/utils/tensor_utils.py
@@ -21,6 +21,20 @@ except ImportError:
  NO_CONVERT_COUNTER = threading.local()


+ def pad(x, shape, padding_side, pad_value):
+     if padding_side == "left":
+         x = x[..., ::-1]
+
+     outputs = x.to_tensor(
+         default_value=pad_value,
+         shape=shape,
+     )
+
+     if padding_side == "left":
+         outputs = outputs[..., ::-1]
+     return outputs
+
+
  @contextlib.contextmanager
  def no_convert_scope():
      try:
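
The new `pad` helper densifies a ragged batch with right padding by default, or left padding by reversing the ragged dimension before and after `to_tensor`. A usage sketch, assuming a TensorFlow `RaggedTensor` input:

```python
import tensorflow as tf
from keras_hub.src.utils.tensor_utils import pad

x = tf.ragged.constant([[1, 2, 3], [4, 5]])
padded = pad(x, shape=(2, 4), padding_side="left", pad_value=0)
print(padded)
# [[0 1 2 3]
#  [0 0 4 5]]
```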