keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
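Many of the renames above move task files from `*_classifier.py` to `*_text_classifier.py` and split the old monolithic `*_preprocessor.py` files into task-specific preprocessors. As a rough sketch of what the renamed API looks like from user code (assuming the public exports mirror these file moves; the preset name and inputs are illustrative):

import keras_hub

# BertClassifier is renamed to BertTextClassifier, mirroring
# bert_classifier.py → bert_text_classifier.py in the list above.
classifier = keras_hub.models.BertTextClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
)
classifier.predict(["What an amazing movie!", "A total waste of my time."])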
keras_hub/src/models/stable_diffusion_v3/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py
@@ -0,0 +1,103 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from keras import layers
+ from keras import ops
+
+
+ def quick_gelu(x):
+     return x * ops.sigmoid(1.702 * x)
+
+
+ class CLIPEncoderBlock(layers.Layer):
+     def __init__(
+         self,
+         hidden_dim,
+         num_heads,
+         intermediate_dim,
+         intermediate_activation="quick_gelu",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         if hidden_dim % num_heads != 0:
+             raise ValueError(
+                 "`hidden_dim` must be divisible by `num_heads`. "
+                 f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
+             )
+         self.hidden_dim = hidden_dim
+         self.num_heads = num_heads
+         self.intermediate_dim = intermediate_dim
+         self.intermediate_activation = intermediate_activation
+
+         if intermediate_activation == "quick_gelu":
+             intermediate_activation = quick_gelu
+
+         self.layer_norm_1 = layers.LayerNormalization(
+             epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1"
+         )
+         self.attention = layers.MultiHeadAttention(
+             num_heads,
+             hidden_dim // num_heads,
+             dtype=self.dtype_policy,
+             name="attention",
+         )
+         self.layer_norm_2 = layers.LayerNormalization(
+             epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2"
+         )
+         self.dense_1 = layers.Dense(
+             self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
+         )
+         self.activation = layers.Activation(
+             intermediate_activation, dtype=self.dtype_policy, name="activation"
+         )
+         self.dense_2 = layers.Dense(
+             self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
+         )
+
+     def build(self, input_shape):
+         self.layer_norm_1.build(input_shape)
+         self.attention.build(input_shape, input_shape, input_shape)
+         self.layer_norm_2.build(input_shape)
+         self.dense_1.build(input_shape)
+         input_shape = self.dense_1.compute_output_shape(input_shape)
+         self.dense_2.build(input_shape)
+
+     def compute_output_shape(self, inputs_shape):
+         outputs_shape = list(inputs_shape)
+         outputs_shape[-1] = self.hidden_dim
+         return outputs_shape
+
+     def call(self, x, training=None):
+         residual = x
+         x = self.layer_norm_1(x)
+         x = self.attention(x, x, x, training=training, use_causal_mask=True)
+         x = ops.add(residual, x)
+
+         residual = x
+         x = self.dense_1(self.layer_norm_2(residual))
+         x = self.activation(x)
+         x = self.dense_2(x)
+         x = ops.add(residual, x)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "num_heads": self.num_heads,
+                 "intermediate_dim": self.intermediate_dim,
+                 "intermediate_activation": self.intermediate_activation,
+             }
+         )
+         return config
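The block above is a pre-norm transformer encoder layer with a causal attention mask and CLIP's `quick_gelu` activation. A minimal sketch of exercising it on dummy data (all dimensions here are illustrative, not SD3's real configuration):

import numpy as np

from keras_hub.src.models.stable_diffusion_v3.clip_encoder_block import (
    CLIPEncoderBlock,
)

block = CLIPEncoderBlock(hidden_dim=64, num_heads=4, intermediate_dim=256)
x = np.random.rand(2, 77, 64).astype("float32")  # (batch, sequence, hidden)
y = block(x)
print(y.shape)  # (2, 77, 64): both residual branches preserve the shape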
keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py
@@ -0,0 +1,93 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+ from keras_hub.src.models.preprocessor import Preprocessor
+ from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import (
+     CLIPTokenizer,
+ )
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+ try:
+     import tensorflow as tf
+ except ImportError:
+     tf = None
+
+
+ class CLIPPreprocessor(Preprocessor):
+     tokenizer_cls = CLIPTokenizer
+
+     def __init__(
+         self,
+         tokenizer,
+         sequence_length=77,
+         add_start_token=True,
+         add_end_token=False,
+         to_lower=True,
+         pad_with_end_token=True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.tokenizer = tokenizer
+         self.sequence_length = sequence_length
+         self.add_start_token = add_start_token
+         self.add_end_token = add_end_token
+         self.to_lower = to_lower
+         self.pad_with_end_token = pad_with_end_token
+
+     def build(self, input_shape):
+         # Defer packer creation to `build()` so that we can be sure tokenizer
+         # assets have loaded when restoring a saved model.
+         pad_value = self.tokenizer.pad_token_id
+         if self.pad_with_end_token:
+             pad_value = self.tokenizer.end_token_id
+
+         self.packer = StartEndPacker(
+             start_value=self.tokenizer.start_token_id,
+             end_value=self.tokenizer.end_token_id,
+             pad_value=pad_value,
+             sequence_length=self.sequence_length,
+             return_padding_mask=True,
+         )
+         self.built = True
+
+     @preprocessing_function
+     def call(self, x, y=None, sample_weight=None, sequence_length=None):
+         if self.to_lower:
+             x = tf.strings.lower(x)
+         token_ids, padding_mask = self.packer(
+             self.tokenizer(x),
+             sequence_length=sequence_length or self.sequence_length,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+         x = {
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+         }
+         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "sequence_length": self.sequence_length,
+                 "add_start_token": self.add_start_token,
+                 "add_end_token": self.add_end_token,
+                 "to_lower": self.to_lower,
+                 "pad_with_end_token": self.pad_with_end_token,
+             }
+         )
+         return config
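A minimal sketch of the preprocessor on a toy vocabulary (the vocabulary, merges, and expected ids below are illustrative assumptions, not shipped assets):

from keras_hub.src.models.stable_diffusion_v3.clip_preprocessor import (
    CLIPPreprocessor,
)
from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import (
    CLIPTokenizer,
)

vocab = {"<|startoftext|>": 0, "<|endoftext|>": 1, "a</w>": 2}
merges = ["a </w>"]  # a single dummy merge rule
preprocessor = CLIPPreprocessor(
    tokenizer=CLIPTokenizer(vocabulary=vocab, merges=merges),
    sequence_length=5,
)
x = preprocessor("a")
# With the defaults above (start token added, no end token, padding with the
# end token id), token_ids should come out as [0, 2, 1, 1, 1].
print(x["token_ids"], x["padding_mask"])

Note how `pad_with_end_token=True` pads with the end token id rather than a dedicated pad token, which is why `build()` swaps the packer's `pad_value`.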
keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py
@@ -0,0 +1,149 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.token_and_position_embedding import (
+     TokenAndPositionEmbedding,
+ )
+ from keras_hub.src.models.stable_diffusion_v3.clip_encoder_block import (
+     CLIPEncoderBlock,
+ )
+
+
+ class CLIPTextEncoder(keras.Model):
+     def __init__(
+         self,
+         embedding_dim,
+         hidden_dim,
+         num_layers,
+         num_heads,
+         intermediate_dim,
+         intermediate_activation="quick_gelu",
+         intermediate_output_index=None,
+         vocabulary_size=49408,
+         sequence_length=77,
+         dtype=None,
+         **kwargs,
+     ):
+         if (
+             intermediate_output_index is not None
+             and intermediate_output_index < 0
+         ):
+             intermediate_output_index += num_layers
+
+         # === Layers ===
+         self.embedding = TokenAndPositionEmbedding(
+             vocabulary_size=vocabulary_size,
+             sequence_length=sequence_length,
+             embedding_dim=embedding_dim,
+             dtype=dtype,
+             name="embedding",
+         )
+         self.encoder_layers = [
+             CLIPEncoderBlock(
+                 hidden_dim,
+                 num_heads,
+                 intermediate_dim,
+                 intermediate_activation,
+                 dtype=dtype,
+             )
+             for _ in range(num_layers)
+         ]
+         self.layer_norm = layers.LayerNormalization(
+             epsilon=0.00001, dtype=dtype, name="layer_norm"
+         )
+         self.text_projection = layers.Dense(
+             hidden_dim,
+             use_bias=False,
+             dtype=dtype,
+             name="text_projection",
+         )
+
+         # === Functional Model ===
+         encoder_token_ids = layers.Input(
+             shape=(sequence_length,), dtype="int32", name="encoder_token_ids"
+         )
+         x = self.embedding(encoder_token_ids)
+         encoder_intermediate_output = None
+         # Encoder.
+         for i, block in enumerate(self.encoder_layers):
+             x = block(x)
+             if i == intermediate_output_index:
+                 encoder_intermediate_output = x
+         x = self.layer_norm(x)
+         encoder_output = x
+         if encoder_intermediate_output is not None:
+             encoder_intermediate_output = self.layer_norm(
+                 encoder_intermediate_output
+             )
+         # Projection.
+         indices = ops.expand_dims(
+             ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1
+         )
+         pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1)
+         pooled_output = ops.squeeze(pooled_output, axis=1)
+         projection_output = self.text_projection(pooled_output)
+
+         outputs = {
+             "encoder_sequence_output": encoder_output,
+             "encoder_pooled_output": pooled_output,
+             "encoder_projection_output": projection_output,
+         }
+         if intermediate_output_index is not None:
+             outputs["encoder_intermediate_output"] = encoder_intermediate_output
+
+         super().__init__(
+             inputs={"encoder_token_ids": encoder_token_ids},
+             outputs=outputs,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.embedding_dim = embedding_dim
+         self.hidden_dim = hidden_dim
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.intermediate_dim = intermediate_dim
+         self.intermediate_activation = intermediate_activation
+         self.intermediate_output_index = intermediate_output_index
+         self.vocabulary_size = vocabulary_size
+         self.sequence_length = sequence_length
+
+         if dtype is not None:
+             try:
+                 self.dtype_policy = keras.dtype_policies.get(dtype)
+             # Before Keras 3.2, there is no `keras.dtype_policies.get`.
+             except AttributeError:
+                 if isinstance(dtype, keras.DTypePolicy):
+                     dtype = dtype.name
+                 self.dtype_policy = keras.DTypePolicy(dtype)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "embedding_dim": self.embedding_dim,
+                 "hidden_dim": self.hidden_dim,
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
+                 "intermediate_dim": self.intermediate_dim,
+                 "intermediate_activation": self.intermediate_activation,
+                 "intermediate_output_index": self.intermediate_output_index,
+                 "vocabulary_size": self.vocabulary_size,
+                 "sequence_length": self.sequence_length,
+             }
+         )
+         return config
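A minimal sketch of instantiating the encoder with a deliberately tiny configuration (the real CLIP text towers are far larger; all numbers here are illustrative):

import numpy as np

from keras_hub.src.models.stable_diffusion_v3.clip_text_encoder import (
    CLIPTextEncoder,
)

encoder = CLIPTextEncoder(
    embedding_dim=32,
    hidden_dim=32,
    num_layers=2,
    num_heads=4,
    intermediate_dim=64,
    vocabulary_size=100,
    sequence_length=8,
)
token_ids = np.random.randint(0, 100, size=(1, 8)).astype("int32")
outputs = encoder({"encoder_token_ids": token_ids})
print(outputs["encoder_sequence_output"].shape)    # (1, 8, 32)
print(outputs["encoder_projection_output"].shape)  # (1, 32)

The pooled output is read from the position of the largest token id (the `ops.argmax` above), which relies on CLIP's convention that the end-of-text token has the highest id in the vocabulary.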
keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py
@@ -0,0 +1,167 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe
+
+ try:
+     import tensorflow as tf
+ except ImportError:
+     tf = None
+
+
+ class CLIPTokenizer(BytePairTokenizer):
+     def __init__(self, vocabulary=None, merges=None, **kwargs):
+         self.start_token = "<|startoftext|>"
+         self.end_token = "<|endoftext|>"
+
+         super().__init__(
+             vocabulary=vocabulary,
+             merges=merges,
+             unsplittable_tokens=[self.start_token, self.end_token],
+             **kwargs,
+         )
+
+     def set_vocabulary_and_merges(self, vocabulary, merges):
+         super().set_vocabulary_and_merges(vocabulary, merges)
+
+         if vocabulary is not None:
+             # Check for necessary special tokens.
+             if self.end_token not in self.get_vocabulary():
+                 raise ValueError(
+                     f"Cannot find token `'{self.end_token}'` in the provided "
+                     f"`vocabulary`. Please provide `'{self.end_token}'` in "
+                     "your `vocabulary` or use a pretrained `vocabulary` name."
+                 )
+
+             self.start_token_id = self.token_to_id(self.start_token)
+             self.end_token_id = self.token_to_id(self.end_token)
+             self.pad_token_id = 0
+         else:
+             self.end_token_id = None
+             self.start_token_id = None
+             self.pad_token_id = None
+
+     def _bpe_merge_and_update_cache(self, tokens):
+         """Process unseen tokens and add to cache."""
+         words = self._transform_bytes(tokens)
+
+         # In StableDiffusionV3, we need to add `</w>` to the last word.
+         words = tf.strings.reduce_join(words, axis=1, separator=" ")
+         words = tf.strings.join([words, "</w>"])
+         words = tf.strings.split(words, sep=" ")
+
+         tokenized_words = self._bpe_merge(words)
+
+         # For each word, join all its tokens with a whitespace,
+         # e.g., ["dragon", "fly"] => "dragon fly", for hashing purposes.
+         tokenized_words = tf.strings.reduce_join(
+             tokenized_words, axis=1, separator=" "
+         )
+         self.cache.insert(tokens, tokenized_words)
+
+     def tokenize(self, inputs):
+         self._check_vocabulary()
+         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+             inputs = tf.convert_to_tensor(inputs)
+
+         if self.add_prefix_space:
+             inputs = tf.strings.join([" ", inputs])
+
+         scalar_input = inputs.shape.rank == 0
+         if scalar_input:
+             inputs = tf.expand_dims(inputs, 0)
+
+         raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+
+         # Strip and remove empty tokens.
+         raw_tokens = tf.strings.strip(raw_tokens)
+         raw_tokens = tf.ragged.boolean_mask(raw_tokens, raw_tokens != "")
+
+         token_row_splits = raw_tokens.row_splits
+         flat_tokens = raw_tokens.flat_values
+
+         # Check cache.
+         cache_lookup = self.cache.lookup(flat_tokens)
+         cache_mask = cache_lookup == ""
+
+         has_unseen_words = tf.math.reduce_any(
+             (cache_lookup == "") & (flat_tokens != "")
+         )
+
+         def process_unseen_tokens():
+             unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+             self._bpe_merge_and_update_cache(unseen_tokens)
+             return self.cache.lookup(flat_tokens)
+
+         # If `has_unseen_words == True`, not all tokens are in the cache, so
+         # we process the unseen tokens. Otherwise, return the cache lookup.
+         tokenized_words = tf.cond(
+             has_unseen_words,
+             process_unseen_tokens,
+             lambda: cache_lookup,
+         )
+
+         tokens = tf.strings.split(tokenized_words, sep=" ")
+         if self.compute_dtype != tf.string:
+             # Encode merged tokens.
+             tokens = self.token_to_id_map.lookup(tokens)
+
+         # Unflatten to match input.
+         tokens = tf.RaggedTensor.from_row_splits(
+             tokens.flat_values,
+             tf.gather(tokens.row_splits, token_row_splits),
+         )
+
+         # Convert to a dense output if `sequence_length` is set.
+         if self.sequence_length:
+             output_shape = tokens.shape.as_list()
+             output_shape[-1] = self.sequence_length
+             tokens = tokens.to_tensor(shape=output_shape)
+
+         # Convert to a dense output if the input is scalar.
+         if scalar_input:
+             tokens = tf.squeeze(tokens, 0)
+             tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+         return tokens
+
+     def detokenize(self, inputs):
+         self._check_vocabulary()
+         inputs, unbatched, _ = convert_to_ragged_batch(inputs)
+         inputs = tf.cast(inputs, self.dtype)
+         unicode_text = tf.strings.reduce_join(
+             self.id_to_token_map.lookup(inputs), axis=-1
+         )
+
+         # When detokenizing, we need to remove </w> and extra whitespace.
+         unicode_text = tf.strings.regex_replace(unicode_text, r"</w>", " ")
+         unicode_text = tf.strings.strip(unicode_text)
+
+         split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
+         outputs = tf.strings.reduce_join(
+             self.unicode2byte.lookup(split_unicode_text), axis=-1
+         )
+
+         if unbatched:
+             outputs = tf.squeeze(outputs, 0)
+         return outputs
+
+     def get_config(self):
+         config = super().get_config()
+         # In the constructor, we pass the list of special tokens to the
+         # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
+         # delete it from the config here.
+         del config["unsplittable_tokens"]
+         return config
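A minimal sketch of the tokenizer's end-of-word convention on a toy vocabulary (the vocabulary and merges are illustrative assumptions, not the shipped ~49k-entry CLIP assets):

from keras_hub.src.models.stable_diffusion_v3.clip_tokenizer import (
    CLIPTokenizer,
)

# The last piece of each word carries the "</w>" suffix appended in
# `_bpe_merge_and_update_cache` above.
vocab = {"<|startoftext|>": 0, "<|endoftext|>": 1, "a</w>": 2, "cat</w>": 3}
merges = ["c a", "ca t</w>"]
tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges)
print(tokenizer("a cat"))  # expected ids: [2, 3]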