keras-hub 0.21.1__py3-none-any.whl → 0.22.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. keras_hub/layers/__init__.py +9 -0
  2. keras_hub/models/__init__.py +47 -0
  3. keras_hub/src/layers/modeling/transformer_encoder.py +6 -3
  4. keras_hub/src/layers/preprocessing/multi_segment_packer.py +17 -3
  5. keras_hub/src/layers/preprocessing/start_end_packer.py +24 -6
  6. keras_hub/src/models/backbone.py +13 -10
  7. keras_hub/src/models/clip/clip_backbone.py +3 -102
  8. keras_hub/src/models/clip/clip_layers.py +295 -0
  9. keras_hub/src/models/clip/clip_preprocessor.py +57 -48
  10. keras_hub/src/models/clip/clip_text_encoder.py +2 -2
  11. keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
  12. keras_hub/src/models/deit/__init__.py +5 -0
  13. keras_hub/src/models/deit/deit_backbone.py +154 -0
  14. keras_hub/src/models/deit/deit_image_classifier.py +171 -0
  15. keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
  16. keras_hub/src/models/deit/deit_image_converter.py +8 -0
  17. keras_hub/src/models/deit/deit_layers.py +519 -0
  18. keras_hub/src/models/deit/deit_presets.py +49 -0
  19. keras_hub/src/models/dinov2/__init__.py +5 -0
  20. keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
  21. keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
  22. keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
  23. keras_hub/src/models/dinov2/dinov2_presets.py +89 -0
  24. keras_hub/src/models/esm/__init__.py +5 -0
  25. keras_hub/src/models/esm/esm_attention.py +95 -0
  26. keras_hub/src/models/esm/esm_backbone.py +229 -0
  27. keras_hub/src/models/esm/esm_classifier.py +184 -0
  28. keras_hub/src/models/esm/esm_classifier_preprocessor.py +135 -0
  29. keras_hub/src/models/esm/esm_encoder.py +134 -0
  30. keras_hub/src/models/esm/esm_masked_plm.py +117 -0
  31. keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +143 -0
  32. keras_hub/src/models/esm/esm_presets.py +53 -0
  33. keras_hub/src/models/esm/esm_tokenizer.py +82 -0
  34. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
  35. keras_hub/src/models/gemma/gemma_attention.py +1 -1
  36. keras_hub/src/models/gemma3/gemma3_backbone.py +2 -2
  37. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +1 -1
  38. keras_hub/src/models/hgnetv2/__init__.py +5 -0
  39. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +193 -0
  40. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +148 -0
  41. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +216 -0
  42. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py +14 -0
  43. keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +8 -0
  44. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +918 -0
  45. keras_hub/src/models/hgnetv2/hgnetv2_presets.py +58 -0
  46. keras_hub/src/models/llama3/llama3_presets.py +3 -3
  47. keras_hub/src/models/mistral/mistral_presets.py +17 -1
  48. keras_hub/src/models/mixtral/mixtral_presets.py +2 -2
  49. keras_hub/src/models/mobilenet/mobilenet_presets.py +4 -4
  50. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +2 -2
  51. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +2 -2
  52. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +17 -17
  53. keras_hub/src/models/qwen3/__init__.py +5 -0
  54. keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
  55. keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
  56. keras_hub/src/models/qwen3/qwen3_causal_lm.py +390 -0
  57. keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
  58. keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
  59. keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
  60. keras_hub/src/models/qwen3/qwen3_presets.py +73 -0
  61. keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +1 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
  64. keras_hub/src/models/roformer_v2/roformer_v2_attention.py +0 -2
  65. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
  66. keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
  67. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +31 -32
  68. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
  69. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
  71. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
  72. keras_hub/src/models/vit/vit_backbone.py +31 -11
  73. keras_hub/src/models/vit/vit_image_converter.py +0 -70
  74. keras_hub/src/models/vit/vit_layers.py +33 -18
  75. keras_hub/src/models/vit/vit_presets.py +11 -11
  76. keras_hub/src/utils/keras_utils.py +17 -0
  77. keras_hub/src/utils/preset_utils.py +19 -4
  78. keras_hub/src/utils/tensor_utils.py +14 -0
  79. keras_hub/src/utils/transformers/convert_deit.py +155 -0
  80. keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
  81. keras_hub/src/utils/transformers/convert_esm.py +159 -0
  82. keras_hub/src/utils/transformers/convert_llama3.py +6 -0
  83. keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
  84. keras_hub/src/utils/transformers/export/gemma.py +89 -0
  85. keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
  86. keras_hub/src/utils/transformers/preset_loader.py +14 -2
  87. keras_hub/src/version.py +1 -1
  88. keras_hub/tokenizers/__init__.py +1 -0
  89. {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/METADATA +4 -4
  90. {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/RECORD +92 -48
  91. keras_hub/src/models/clip/clip_encoder_block.py +0 -111
  92. keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
  93. {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/WHEEL +0 -0
  94. {keras_hub-0.21.1.dist-info → keras_hub-0.22.0.dev0.dist-info}/top_level.txt +0 -0
@@ -1,111 +0,0 @@
1
- from keras import dtype_policies
2
- from keras import layers
3
- from keras import ops
4
-
5
-
6
def quick_gelu(x):
    """Apply the "quick GELU" approximation used by CLIP.

    Computes `x * sigmoid(1.702 * x)`, a cheap approximation of GELU.
    """
    gate = ops.sigmoid(1.702 * x)
    return x * gate
8
-
9
-
10
- # TODO: Deprecate this in favor of `keras.layers.MultiHeadAttention` once the
11
- # dtype compatibility issue is resolved.
12
class CLIPMultiHeadAttention(layers.MultiHeadAttention):
    """`MultiHeadAttention` that restores the value dtype after softmax.

    The parent class may run the masked softmax in a different (higher
    precision) dtype; this subclass casts the resulting attention
    probabilities back to the compute dtype of the value projection so the
    subsequent matmul sees a consistent dtype.
    """

    def _masked_softmax(self, attention_scores, attention_mask=None):
        probs = super()._masked_softmax(attention_scores, attention_mask)
        target_dtype = self._value_dense.compute_dtype
        return ops.cast(probs, target_dtype)
18
-
19
-
20
class CLIPEncoderBlock(layers.Layer):
    """A single CLIP transformer encoder block.

    Pre-layer-norm architecture: layer norm -> multi-head self-attention ->
    residual add, then layer norm -> dense -> activation -> dense ->
    residual add.

    Args:
        hidden_dim: int. Width of the residual stream; must be divisible by
            `num_heads`.
        num_heads: int. Number of attention heads.
        intermediate_dim: int. Width of the MLP's hidden layer.
        intermediate_activation: str or callable. Activation for the MLP.
            The string `"quick_gelu"` selects the CLIP-specific
            `x * sigmoid(1.702 * x)` approximation.
        use_causal_mask: bool. Whether self-attention applies a causal mask
            (`True` for the text encoder).
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        intermediate_dim,
        intermediate_activation="quick_gelu",
        use_causal_mask=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "`hidden_dim` must be divisible by `num_heads`. "
                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        # Keep the raw (possibly string) value for `get_config()`; the
        # resolved callable below is only used to build the sublayer.
        self.intermediate_activation = intermediate_activation
        self.use_causal_mask = use_causal_mask

        if intermediate_activation == "quick_gelu":
            intermediate_activation = quick_gelu

        self.layer_norm_1 = layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_1"
        )
        self.attention = CLIPMultiHeadAttention(
            num_heads,
            hidden_dim // num_heads,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.layer_norm_2 = layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy, name="layer_norm_2"
        )
        self.dense_1 = layers.Dense(
            self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
        )
        self.activation = layers.Activation(
            intermediate_activation, dtype=self.dtype_policy, name="activation"
        )
        self.dense_2 = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
        )

    def build(self, input_shape):
        """Build all sublayers and pin the attention softmax to float32."""
        self.layer_norm_1.build(input_shape)
        self.attention.build(input_shape, input_shape, input_shape)
        # Before Keras 3.2, there was no setter for `dtype_policy`. Directly
        # assign a `DTypePolicy` instead.
        self.attention._softmax.dtype_policy = dtype_policies.DTypePolicy(
            "float32"
        )
        self.layer_norm_2.build(input_shape)
        self.dense_1.build(input_shape)
        # `dense_2` consumes the MLP hidden width, not the block input width.
        input_shape = self.dense_1.compute_output_shape(input_shape)
        self.dense_2.build(input_shape)

    def compute_output_shape(self, inputs_shape):
        # Same shape as the input except the channel axis is `hidden_dim`.
        outputs_shape = list(inputs_shape)
        outputs_shape[-1] = self.hidden_dim
        return outputs_shape

    def call(self, x, training=None):
        # Attention sub-block with residual connection.
        residual = x
        x = self.layer_norm_1(x)
        x = self.attention(
            x, x, x, training=training, use_causal_mask=self.use_causal_mask
        )
        x = ops.add(residual, x)

        # MLP sub-block with residual connection.
        residual = x
        x = self.dense_1(self.layer_norm_2(residual))
        x = self.activation(x)
        x = self.dense_2(x)
        x = ops.add(residual, x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "intermediate_activation": self.intermediate_activation,
                "use_causal_mask": self.use_causal_mask,
            }
        )
        return config
@@ -1,101 +0,0 @@
1
- from keras import layers
2
- from keras import ops
3
-
4
- from keras_hub.src.utils.keras_utils import standardize_data_format
5
-
6
-
7
class CLIPVisionEmbedding(layers.Layer):
    """Patch + class + position embedding for the CLIP vision encoder.

    Splits an image into non-overlapping `patch_size` patches via a strided
    convolution, prepends a learned class token, and adds learned position
    embeddings.

    Args:
        hidden_dim: int. Embedding width of each patch token.
        patch_size: int. Side length of each square patch.
        image_size: int. Side length of the (square) input image; only used
            to size the position-embedding table.
        data_format: str. `"channels_last"` or `"channels_first"`; `None`
            resolves to the Keras global default.
        dtype: Dtype or dtype policy forwarded to sublayers.
    """

    def __init__(
        self,
        hidden_dim,
        patch_size,
        image_size,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.image_size = int(image_size)
        data_format = standardize_data_format(data_format)
        self.data_format = data_format
        num_patches = (image_size // patch_size) ** 2
        # +1 for the prepended class token.
        self.num_positions = num_patches + 1

        self.patch_embedding = layers.Conv2D(
            hidden_dim,
            kernel_size=patch_size,
            strides=patch_size,
            data_format=data_format,
            use_bias=False,
            dtype=dtype,
            name="patch_embedding",
        )
        self.position_embedding = layers.Embedding(
            num_patches + 1, hidden_dim, dtype=dtype, name="position_embedding"
        )

    def build(self, input_shape):
        """Create the class token and position-id weights, build sublayers."""
        self.class_embedding = self.add_weight(
            shape=(self.hidden_dim,),
            initializer="random_normal",
            dtype=self.variable_dtype,
            name="class_embedding",
        )
        self.position_ids = self.add_weight(
            shape=(1, self.num_positions),
            initializer="zeros",
            # Let the backend determine the int dtype. For example, tf
            # requires int64 for correct device placement, whereas jax and torch
            # don't.
            dtype=int,
            trainable=False,
            name="position_ids",
        )
        self.patch_embedding.build(input_shape)
        self.position_embedding.build(self.position_ids.shape)

    def call(self, inputs, training=None):
        x = inputs
        batch_size = ops.shape(x)[0]
        patch_embeddings = self.patch_embedding(x, training=training)
        # Flatten the spatial grid to a token axis: (batch, tokens, hidden).
        if self.data_format == "channels_last":
            patch_embeddings = ops.reshape(
                patch_embeddings, (batch_size, -1, self.hidden_dim)
            )
        else:
            # channels_first: flatten then move channels to the last axis.
            patch_embeddings = ops.reshape(
                patch_embeddings, (batch_size, self.hidden_dim, -1)
            )
            patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
        # Broadcast the single class token across the batch and prepend it.
        class_embeddings = ops.expand_dims(self.class_embedding, axis=(0, 1))
        class_embeddings = ops.tile(class_embeddings, (batch_size, 1, 1))
        position_embeddings = self.position_embedding(self.position_ids)
        embeddings = ops.concatenate(
            [class_embeddings, patch_embeddings], axis=1
        )
        return ops.add(embeddings, position_embeddings)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
                "image_size": self.image_size,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        # Token count is statically known only when the spatial dims are.
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            if input_shape[1] is not None and input_shape[2] is not None:
                patch_num = input_shape[1] // self.patch_size
                output_shape[1] = patch_num**2 + 1
        else:
            if input_shape[2] is not None and input_shape[3] is not None:
                patch_num = input_shape[2] // self.patch_size
                output_shape[1] = patch_num**2 + 1
        return output_shape