keras-hub 0.21.1.dev0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +9 -0
- keras_hub/models/__init__.py +47 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +6 -3
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +17 -3
- keras_hub/src/layers/preprocessing/start_end_packer.py +24 -6
- keras_hub/src/models/backbone.py +13 -10
- keras_hub/src/models/clip/clip_backbone.py +3 -102
- keras_hub/src/models/clip/clip_layers.py +295 -0
- keras_hub/src/models/clip/clip_preprocessor.py +57 -48
- keras_hub/src/models/clip/clip_text_encoder.py +2 -2
- keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
- keras_hub/src/models/deit/__init__.py +5 -0
- keras_hub/src/models/deit/deit_backbone.py +154 -0
- keras_hub/src/models/deit/deit_image_classifier.py +171 -0
- keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/deit/deit_image_converter.py +8 -0
- keras_hub/src/models/deit/deit_layers.py +519 -0
- keras_hub/src/models/deit/deit_presets.py +49 -0
- keras_hub/src/models/dinov2/__init__.py +5 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
- keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
- keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
- keras_hub/src/models/dinov2/dinov2_presets.py +89 -0
- keras_hub/src/models/esm/__init__.py +5 -0
- keras_hub/src/models/esm/esm_attention.py +95 -0
- keras_hub/src/models/esm/esm_backbone.py +229 -0
- keras_hub/src/models/esm/esm_classifier.py +184 -0
- keras_hub/src/models/esm/esm_classifier_preprocessor.py +135 -0
- keras_hub/src/models/esm/esm_encoder.py +134 -0
- keras_hub/src/models/esm/esm_masked_plm.py +117 -0
- keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +143 -0
- keras_hub/src/models/esm/esm_presets.py +53 -0
- keras_hub/src/models/esm/esm_tokenizer.py +82 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/gemma/gemma_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +2 -2
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +1 -1
- keras_hub/src/models/gemma3/gemma3_presets.py +25 -0
- keras_hub/src/models/hgnetv2/__init__.py +5 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +193 -0
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +148 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +216 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +8 -0
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +918 -0
- keras_hub/src/models/hgnetv2/hgnetv2_presets.py +58 -0
- keras_hub/src/models/llama3/llama3_presets.py +3 -3
- keras_hub/src/models/mistral/mistral_presets.py +17 -1
- keras_hub/src/models/mixtral/mixtral_presets.py +2 -2
- keras_hub/src/models/mobilenet/mobilenet_presets.py +4 -4
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +2 -2
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +17 -17
- keras_hub/src/models/qwen3/__init__.py +5 -0
- keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
- keras_hub/src/models/qwen3/qwen3_causal_lm.py +390 -0
- keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
- keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
- keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
- keras_hub/src/models/qwen3/qwen3_presets.py +73 -0
- keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +1 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
- keras_hub/src/models/roformer_v2/roformer_v2_attention.py +0 -2
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
- keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +31 -32
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/vit/vit_backbone.py +31 -11
- keras_hub/src/models/vit/vit_image_converter.py +0 -70
- keras_hub/src/models/vit/vit_layers.py +33 -18
- keras_hub/src/models/vit/vit_presets.py +11 -11
- keras_hub/src/utils/keras_utils.py +17 -0
- keras_hub/src/utils/preset_utils.py +19 -4
- keras_hub/src/utils/tensor_utils.py +14 -0
- keras_hub/src/utils/transformers/convert_deit.py +155 -0
- keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
- keras_hub/src/utils/transformers/convert_esm.py +159 -0
- keras_hub/src/utils/transformers/convert_llama3.py +6 -0
- keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
- keras_hub/src/utils/transformers/export/gemma.py +89 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
- keras_hub/src/utils/transformers/preset_loader.py +14 -2
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +1 -0
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/METADATA +4 -4
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/RECORD +93 -49
- keras_hub/src/models/clip/clip_encoder_block.py +0 -111
- keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/WHEEL +0 -0
- {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

@@ -1,4 +1,6 @@
 import keras
+from keras import backend
+from keras import distribution
 from keras import layers
 from keras import ops
 
@@ -96,26 +98,10 @@ class LatentRescaling(layers.Rescaling):
         return (self.backend.cast(inputs, dtype) / scale) + offset
 
 
-class
-    def call(
-        self,
-        latents,
-        positive_contexts,
-        negative_contexts,
-        positive_pooled_projections,
-        negative_pooled_projections,
-        timestep,
-    ):
+class TimestepBroadcastTo(layers.Layer):
+    def call(self, latents, timestep):
         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-
-        contexts = ops.concatenate(
-            [positive_contexts, negative_contexts], axis=0
-        )
-        pooled_projections = ops.concatenate(
-            [positive_pooled_projections, negative_pooled_projections], axis=0
-        )
-        timesteps = ops.concatenate([timestep, timestep], axis=0)
-        return latents, contexts, pooled_projections, timesteps
+        return timestep
 
 
 class ClassifierFreeGuidance(layers.Layer):
@@ -330,8 +316,8 @@ class StableDiffusion3Backbone(Backbone):
             name="diffuser",
         )
         self.vae = vae
-        self.
-            dtype=dtype, name="
+        self.timestep_broadcast_to = TimestepBroadcastTo(
+            dtype=dtype, name="timestep_broadcast_to"
         )
         self.cfg = ClassifierFreeGuidance(
             dtype=dtype, name="classifier_free_guidance"
@@ -538,6 +524,9 @@ class StableDiffusion3Backbone(Backbone):
         latents = self.vae.encode(images)
         return self.image_rescaling(latents)
 
+    def configure_scheduler(self, num_steps):
+        self.scheduler.set_sigmas(num_steps)
+
     def add_noise_step(self, latents, noises, step, num_steps):
         return self.scheduler.add_noise(latents, noises, step, num_steps)
 
@@ -562,11 +551,15 @@ class StableDiffusion3Backbone(Backbone):
 
         # Concatenation for classifier-free guidance.
         if guidance_scale is not None:
-
-
+            timestep = self.timestep_broadcast_to(latents, timestep)
+            timesteps = ops.concatenate([timestep, timestep], axis=0)
+            concated_latents = ops.concatenate([latents, latents], axis=0)
+            contexts = ops.concatenate([embeddings[0], embeddings[1]], axis=0)
+            pooled_projs = ops.concatenate(
+                [embeddings[2], embeddings[3]], axis=0
             )
         else:
-            timesteps =
+            timesteps = self.timestep_broadcast_to(latents, timestep)
             concated_latents = latents
             contexts = embeddings[0]
             pooled_projs = embeddings[2]
@@ -623,20 +616,26 @@ class StableDiffusion3Backbone(Backbone):
     def from_config(cls, config, custom_objects=None):
         config = config.copy()
 
-        # Propagate `dtype` to
+        # Propagate `dtype` to the VAE if needed.
         if "dtype" in config and config["dtype"] is not None:
             dtype_config = config["dtype"]
             if "dtype" not in config["vae"]["config"]:
                 config["vae"]["config"]["dtype"] = dtype_config
-
-
-
-
+
+        # Text encoders default to float16 dtype if not specified.
+        # TODO: JAX CPU doesn't support float16 in `nn.dot_product_attention`.
+        is_jax_cpu = (
+            backend.backend() == "jax"
+            and "cpu" in distribution.list_devices()[0].lower()
+        )
+        for text_encoder in ("clip_l", "clip_g", "t5"):
             if (
-
-                and
+                text_encoder in config
+                and config[text_encoder] is not None
+                and "dtype" not in config[text_encoder]["config"]
+                and not is_jax_cpu
             ):
-                config[
+                config[text_encoder]["config"]["dtype"] = "float16"
 
         # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
         config["vae"] = layers.deserialize(
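The refactor above replaces the old batching layer with a `TimestepBroadcastTo` layer plus explicit `keras.ops` concatenations in the denoising path. A minimal standalone sketch of that classifier-free-guidance batching pattern, with made-up shapes and variable names (not taken from the package):

```python
import numpy as np
from keras import ops

# Toy stand-ins for the real latents and text embeddings.
latents = np.zeros((2, 64, 64, 16), dtype="float32")    # (batch, h, w, c)
positive_ctx = np.zeros((2, 77, 4096), dtype="float32")
negative_ctx = np.zeros((2, 77, 4096), dtype="float32")
timestep = np.array(999, dtype="float32")               # scalar timestep

# Broadcast the scalar timestep to the batch dimension (what
# TimestepBroadcastTo does in the diff above).
timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])

# Duplicate everything so one diffuser call covers both the positive and
# negative branches of classifier-free guidance.
timesteps = ops.concatenate([timestep, timestep], axis=0)
cfg_latents = ops.concatenate([latents, latents], axis=0)
contexts = ops.concatenate([positive_ctx, negative_ctx], axis=0)

print(ops.shape(cfg_latents))  # (4, 64, 64, 16)
```

Batching the positive and negative branches this way lets a single diffuser forward pass serve both sides of the guidance update.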
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py

@@ -50,8 +50,12 @@ class StableDiffusion3TextToImagePreprocessor(TextToImagePreprocessor):
 
     def generate_preprocess(self, x):
         token_ids = {}
-        token_ids["clip_l"] = self.clip_l_preprocessor(
-
+        token_ids["clip_l"] = self.clip_l_preprocessor(
+            {"prompts": x, "images": None}
+        )["token_ids"]
+        token_ids["clip_g"] = self.clip_g_preprocessor(
+            {"prompts": x, "images": None}
+        )["token_ids"]
         if self.t5_preprocessor is not None:
             token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
         return token_ids
keras_hub/src/models/vit/vit_backbone.py

@@ -18,10 +18,10 @@ class ViTBackbone(Backbone):
 
     Args:
         image_shape: A tuple or list of 3 integers representing the shape of the
-            input image `(height, width, channels)
-
-
-
+            input image `(height, width, channels)`.
+        patch_size: int or (int, int). The size of each image patch, the input
+            image will be divided into patches of shape
+            `(patch_size_h, patch_size_w)`.
         num_layers: int. The number of transformer encoder layers.
         num_heads: int. specifying the number of attention heads in each
             Transformer encoder layer.
@@ -37,6 +37,10 @@ class ViTBackbone(Backbone):
         use_mha_bias: bool. Whether to use bias in the multi-head
             attention layers.
         use_mlp_bias: bool. Whether to use bias in the MLP layers.
+        use_class_token: bool. Whether to use class token to be part of
+            patch embedding. Defaults to `True`.
+        use_patch_bias: bool. Whether to use bias in Conv2d of patch embedding
+            layer. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`, specifying
             the data format for the input image. If `None`, defaults to
             `"channels_last"`.
@@ -58,6 +62,8 @@ class ViTBackbone(Backbone):
         layer_norm_epsilon=1e-6,
         use_mha_bias=True,
         use_mlp_bias=True,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -74,24 +80,34 @@ class ViTBackbone(Backbone):
                 f"at index {h_axis} (height) or {w_axis} (width). "
                 f"Image shape: {image_shape}"
             )
-
+
+        if isinstance(patch_size, int):
+            patch_size = (patch_size, patch_size)
+
+        if image_shape[h_axis] % patch_size[0] != 0:
+            raise ValueError(
+                f"Input height {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size[0]}."
+            )
+
+        if image_shape[w_axis] % patch_size[1] != 0:
             raise ValueError(
-                f"
-                f"
-                f"indices {h_axis} and {w_axis} respectively. Image shape: "
-                f"{image_shape}"
+                f"Input width {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size[1]}."
             )
 
         num_channels = image_shape[channels_axis]
 
         # === Functional Model ===
-        inputs = keras.layers.Input(shape=image_shape)
+        inputs = keras.layers.Input(shape=image_shape, name="images")
 
         x = ViTPatchingAndEmbedding(
-            image_size=image_shape[h_axis],
+            image_size=(image_shape[h_axis], image_shape[w_axis]),
             patch_size=patch_size,
             hidden_dim=hidden_dim,
             num_channels=num_channels,
+            use_class_token=use_class_token,
+            use_patch_bias=use_patch_bias,
             data_format=data_format,
             dtype=dtype,
             name="vit_patching_and_embedding",
@@ -130,6 +146,8 @@ class ViTBackbone(Backbone):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_mha_bias = use_mha_bias
         self.use_mlp_bias = use_mlp_bias
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = data_format
 
     def get_config(self):
@@ -147,6 +165,8 @@ class ViTBackbone(Backbone):
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "use_mha_bias": self.use_mha_bias,
                 "use_mlp_bias": self.use_mlp_bias,
+                "use_class_token": self.use_class_token,
+                "use_patch_bias": self.use_patch_bias,
             }
         )
         return config
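For reference, a hedged construction sketch using the new arguments. The layer sizes are arbitrary, and the constructor signature is assumed to match the parameters visible in the diff above:

```python
import numpy as np
import keras_hub

# Randomly initialized ViT exercising the flags added in this release.
backbone = keras_hub.models.ViTBackbone(
    image_shape=(224, 224, 3),
    patch_size=16,            # an int is now expanded to (16, 16)
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    mlp_dim=128,
    use_class_token=False,    # skip the class token in the patch embedding
    use_patch_bias=False,     # no bias in the patch-embedding Conv2D
)
features = backbone(np.random.uniform(size=(1, 224, 224, 3)))
print(features.shape)
```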
keras_hub/src/models/vit/vit_image_converter.py

@@ -1,78 +1,8 @@
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.vit.vit_backbone import ViTBackbone
-from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
 @keras_hub_export("keras_hub.layers.ViTImageConverter")
 class ViTImageConverter(ImageConverter):
-    """Converts images to the format expected by a ViT model.
-
-    This layer performs image normalization using mean and standard deviation
-    values. By default, it uses the same normalization as the
-    "google/vit-large-patch16-224" model on Hugging Face:
-    `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
-    ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
-    These defaults are suitable for models pretrained using this normalization.
-
-    Args:
-        norm_mean: list or tuple of floats. Mean values for image normalization.
-            Defaults to `[0.5, 0.5, 0.5]`.
-        norm_std: list or tuple of floats. Standard deviation values for
-            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
-        **kwargs: Additional keyword arguments passed to
-            `keras_hub.layers.preprocessing.ImageConverter`.
-
-    Examples:
-    ```python
-    import keras
-    import numpy as np
-    from keras_hub.src.layers import ViTImageConverter
-
-    # Example image (replace with your actual image data)
-    image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)
-
-    # Create a ViTImageConverter instance
-    converter = ViTImageConverter(
-        image_size=(28,28),
-        scale=1/255.
-    )
-    # Preprocess the image
-    preprocessed_image = converter(image)
-    ```
-    """
-
     backbone_cls = ViTBackbone
-
-    def __init__(
-        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.norm_mean = norm_mean
-        self.norm_std = norm_std
-
-    @preprocessing_function
-    def call(self, inputs):
-        # TODO: Remove this whole function. Why can just use scale and offset
-        # in the base class.
-        x = super().call(inputs)
-        if self.norm_mean:
-            norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
-            x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
-            x = x - norm_mean
-        if self.norm_std:
-            norm_std = self._expand_non_channel_dims(self.norm_std, x)
-            x, norm_std = self._convert_types(x, norm_std, x.dtype)
-            x = x / norm_std
-
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "norm_mean": self.norm_mean,
-                "norm_std": self.norm_std,
-            }
-        )
-        return config
keras_hub/src/models/vit/vit_layers.py

@@ -75,12 +75,13 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
     """Patches the image and embeds the patches.
 
     Args:
-        image_size: int. Size of the input image
-
-        patch_size: int. Size of each image patch.
+        image_size: (int, int). Size of the input image.
+        patch_size: (int, int). Size of each image patch.
         hidden_dim: int. Dimensionality of the patch embeddings.
         num_channels: int. Number of channels in the input image. Defaults to
             `3`.
+        use_class_token: bool. Whether to use class token to be part of
+            patch embedding. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
             `None` (which uses `"channels_last"`).
         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
@@ -92,12 +93,15 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
         patch_size,
         hidden_dim,
         num_channels=3,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
-
-
+        grid_size = tuple([s // p for s, p in zip(image_size, patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        num_positions = num_patches + 1 if use_class_token else num_patches
 
         # === Config ===
         self.image_size = image_size
@@ -106,19 +110,22 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
         self.num_channels = num_channels
         self.num_patches = num_patches
         self.num_positions = num_positions
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = standardize_data_format(data_format)
 
     def build(self, input_shape):
-
-
-
-
-
-
-
-
-
-
+        if self.use_class_token:
+            self.class_token = self.add_weight(
+                shape=(
+                    1,
+                    1,
+                    self.hidden_dim,
+                ),
+                initializer="random_normal",
+                dtype=self.variable_dtype,
+                name="class_token",
+            )
         self.patch_embedding = keras.layers.Conv2D(
             filters=self.hidden_dim,
             kernel_size=self.patch_size,
@@ -127,6 +134,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
             activation=None,
             dtype=self.dtype_policy,
             data_format=self.data_format,
+            use_bias=self.use_patch_bias,
             name="patch_embedding",
         )
         self.patch_embedding.build(input_shape)
@@ -153,10 +161,16 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
         patch_embeddings = ops.reshape(
             patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
         )
-        class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
         position_embeddings = self.position_embedding(self.position_ids)
-
-
+
+        if self.use_class_token:
+            class_token = ops.tile(
+                self.class_token, (embeddings_shape[0], 1, 1)
+            )
+            patch_embeddings = ops.concatenate(
+                [class_token, patch_embeddings], axis=1
+            )
+        return ops.add(patch_embeddings, position_embeddings)
 
     def compute_output_shape(self, input_shape):
         return (
@@ -175,6 +189,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
             "num_channels": self.num_channels,
             "num_patches": self.num_patches,
             "num_positions": self.num_positions,
+            "use_class_token": self.use_class_token,
             }
         )
         return config
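The `__init__` change above derives the patch grid from per-axis image and patch sizes instead of assuming a square image. A worked example of that arithmetic with made-up sizes:

```python
# Illustrative values only, not taken from a preset.
image_size = (224, 160)   # (height, width)
patch_size = (16, 16)
use_class_token = True

grid_size = tuple(s // p for s, p in zip(image_size, patch_size))    # (14, 10)
num_patches = grid_size[0] * grid_size[1]                            # 140
num_positions = num_patches + 1 if use_class_token else num_patches  # 141

print(grid_size, num_patches, num_positions)
```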
keras_hub/src/models/vit/vit_presets.py

@@ -11,7 +11,7 @@ backbone_presets = {
             "params": 85798656,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/3",
     },
     "vit_base_patch16_384_imagenet": {
         "metadata": {
@@ -22,7 +22,7 @@ backbone_presets = {
             "params": 86090496,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/3",
     },
     "vit_large_patch16_224_imagenet": {
         "metadata": {
@@ -33,7 +33,7 @@ backbone_presets = {
             "params": 303301632,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/3",
     },
     "vit_large_patch16_384_imagenet": {
         "metadata": {
@@ -44,7 +44,7 @@ backbone_presets = {
             "params": 303690752,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/3",
     },
     "vit_base_patch32_384_imagenet": {
         "metadata": {
@@ -55,7 +55,7 @@ backbone_presets = {
             "params": 87528192,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/2",
     },
     "vit_large_patch32_384_imagenet": {
         "metadata": {
@@ -66,7 +66,7 @@ backbone_presets = {
             "params": 305607680,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/2",
     },
     "vit_base_patch16_224_imagenet21k": {
         "metadata": {
@@ -77,7 +77,7 @@ backbone_presets = {
             "params": 85798656,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/2",
     },
     "vit_base_patch32_224_imagenet21k": {
         "metadata": {
@@ -88,7 +88,7 @@ backbone_presets = {
             "params": 87455232,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/2",
     },
     "vit_huge_patch14_224_imagenet21k": {
         "metadata": {
@@ -99,7 +99,7 @@ backbone_presets = {
             "params": 630764800,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/2",
     },
     "vit_large_patch16_224_imagenet21k": {
         "metadata": {
@@ -110,7 +110,7 @@ backbone_presets = {
             "params": 303301632,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/2",
     },
     "vit_large_patch32_224_imagenet21k": {
         "metadata": {
@@ -121,6 +121,6 @@ backbone_presets = {
             "params": 305510400,
             "path": "vit",
         },
-        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/
+        "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/2",
     },
 }
keras_hub/src/utils/keras_utils.py

@@ -71,6 +71,23 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
 
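The new branch above probes PyTorch's scaled-dot-product-attention entry points before enabling the fused op. A standalone sketch of the same availability check (the function name here is illustrative, not part of keras-hub):

```python
# Returns True when the flash-attention symbols used by the diff above are
# importable from the installed PyTorch; it does not inspect a concrete call.
def torch_flash_attention_importable():
    try:
        from torch.backends.cuda import SDPAParams  # noqa: F401
        from torch.backends.cuda import can_use_flash_attention  # noqa: F401
    except ImportError:
        return False
    return True

print(torch_flash_attention_importable())
```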
keras_hub/src/utils/preset_utils.py

@@ -1,5 +1,6 @@
 import collections
 import datetime
+import glob
 import inspect
 import json
 import os
@@ -317,7 +318,8 @@ def _validate_backbone(preset):
         )
 
     weights_path = os.path.join(preset, MODEL_WEIGHTS_FILE)
-
+    sharded_weights_path = os.path.join(preset, "model_*.weights.h5")
+    if not os.path.exists(weights_path) and not glob.glob(sharded_weights_path):
         raise FileNotFoundError(
             f"The weights file is missing from the preset directory `{preset}`."
         )
@@ -647,7 +649,10 @@ class KerasPresetLoader(PresetLoader):
         return check_config_class(self.config)
 
     def load_backbone(self, cls, load_weights, **kwargs):
-
+        config = self.config.copy()
+        backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
+        config["config"] = {**config["config"], **backbone_kwargs}
+        backbone = self._load_serialized_object(config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             self._load_backbone_weights(backbone)
@@ -732,7 +737,13 @@ class KerasPresetLoader(PresetLoader):
         with open(config_path, encoding="utf-8") as config_file:
             config = json.load(config_file)
         weight_map = config["weight_map"]
-
+        filenames = set()
+        for v in weight_map.values():
+            if isinstance(v, list):
+                filenames.update(v)
+            else:
+                filenames.add(v)
+        return sorted(filenames)
 
     def _load_backbone_weights(self, backbone):
         # Detect if the backbone is sharded or not.
@@ -772,7 +783,11 @@ class KerasPresetSaver:
         backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
         # If the size of the backbone is larger than `max_shard_size`, save
         # sharded weights.
-        if
+        if (
+            sharded_weights_available()
+            and max_shard_size is not None
+            and backbone_size_in_gb > max_shard_size
+        ):
             backbone_sharded_weights_config_path = os.path.join(
                 self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
             )
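A small sketch of the weights-presence check that the `_validate_backbone` change above performs, assuming the single-file name is `model.weights.h5` (the value of `MODEL_WEIGHTS_FILE`) and using a placeholder preset directory:

```python
import glob
import os

# A preset directory is treated as having weights if either the single-file
# or the sharded naming scheme is present.
preset_dir = "./my_preset"  # placeholder path
single_file = os.path.join(preset_dir, "model.weights.h5")
sharded_files = glob.glob(os.path.join(preset_dir, "model_*.weights.h5"))
has_weights = os.path.exists(single_file) or bool(sharded_files)
print(has_weights)
```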
keras_hub/src/utils/tensor_utils.py

@@ -21,6 +21,20 @@ except ImportError:
 NO_CONVERT_COUNTER = threading.local()
 
 
+def pad(x, shape, padding_side, pad_value):
+    if padding_side == "left":
+        x = x[..., ::-1]
+
+    outputs = x.to_tensor(
+        default_value=pad_value,
+        shape=shape,
+    )
+
+    if padding_side == "left":
+        outputs = outputs[..., ::-1]
+    return outputs
+
+
 @contextlib.contextmanager
 def no_convert_scope():
     try:
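The new `pad` helper converts a ragged batch to a dense tensor and supports left padding by reversing before and after the conversion. A sketch that reproduces the helper from the diff and exercises it on a toy `tf.RaggedTensor` (the example values are made up, and the helper is inlined here rather than imported from keras_hub's private utils):

```python
import tensorflow as tf

# Same body as the `pad` helper added in the diff above.
def pad(x, shape, padding_side, pad_value):
    if padding_side == "left":
        x = x[..., ::-1]           # reverse each row before densifying
    outputs = x.to_tensor(default_value=pad_value, shape=shape)
    if padding_side == "left":
        outputs = outputs[..., ::-1]  # reverse back, leaving pads on the left
    return outputs

batch = tf.ragged.constant([[1, 2, 3], [4, 5]])
print(pad(batch, shape=[2, 5], padding_side="right", pad_value=0).numpy())
# [[1 2 3 0 0]
#  [4 5 0 0 0]]
print(pad(batch, shape=[2, 5], padding_side="left", pad_value=0).numpy())
# [[0 0 1 2 3]
#  [0 0 0 4 5]]
```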