keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198) hide show
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,186 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ from keras import layers
16
+
17
+ from keras_hub.src.models.stable_diffusion_v3.vae_attention import VAEAttention
18
+ from keras_hub.src.utils.keras_utils import standardize_data_format
19
+
20
+
21
class VAEImageDecoder(keras.Model):
    """Functional decoder that maps latents to images.

    The graph is assembled in `__init__` as follows:

    1. A 3x3 input projection to `stackwise_num_filters[0]` channels,
       followed by a resnet block, a `VAEAttention` layer and a second
       resnet block.
    2. One stack per entry of `stackwise_num_filters`, each made of
       `stackwise_num_blocks[i]` resnet blocks. Every stack except the last
       is followed by a 2x `UpSampling2D` and a 3x3 convolution.
    3. Group normalization, a swish activation and a 3x3 output projection
       to `output_channels` channels.

    Args:
        stackwise_num_filters: sequence of ints. Number of filters used by
            each stack. The first entry is also used by the input blocks.
        stackwise_num_blocks: sequence of ints. Number of resnet blocks in
            each stack. Indexed in parallel with `stackwise_num_filters`.
        output_channels: int. Number of channels produced by the output
            projection. Defaults to `3`.
        latent_shape: tuple. Shape (without the batch dimension) passed to
            `layers.Input`. Defaults to `(None, None, 16)`.
        data_format: `None` or str. Normalized with
            `standardize_data_format` and forwarded to the convolution,
            upsampling and attention layers.
        dtype: `None`, str or `keras.DTypePolicy`. Forwarded to every layer
            and, when not `None`, also set as the model's dtype policy.
        **kwargs: forwarded to `keras.Model.__init__`.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        latent_shape=(None, None, 16),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        # Group normalization must run over the channel axis.
        gn_axis = -1 if data_format == "channels_last" else 1

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape)

        x = layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="input_projection",
        )(latent_inputs)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block0",
        )
        x = VAEAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_attention",
        )(x)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block1",
        )

        # Stacks.
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                x = apply_resnet_block(
                    x,
                    filters,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"block{i}_{j}",
                )
            if i != len(stackwise_num_filters) - 1:
                # No upsampling in the last block.
                x = layers.UpSampling2D(
                    2,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}",
                )(x)
                x = layers.Conv2D(
                    filters,
                    3,
                    1,
                    padding="same",
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}_conv",
                )(x)

        # Output block.
        x = layers.GroupNormalization(
            groups=32,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=dtype,
            name="output_norm",
        )(x)
        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
        image_outputs = layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="output_projection",
        )(x)
        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = output_channels
        self.latent_shape = latent_shape

        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def get_config(self):
        """Return the constructor arguments needed to rebuild the model.

        Every key added here matches an `__init__` parameter name so the
        resulting dict can be passed back to the constructor as keyword
        arguments.
        """
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
                # Must be `latent_shape` (the `__init__` parameter name);
                # the previous `image_shape` key is not accepted by the
                # constructor.
                "latent_shape": self.latent_shape,
            }
        )
        return config
142
+
143
+
144
def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
    """Apply a pre-activation residual block to `x`.

    The main path is GroupNormalization -> swish -> 3x3 conv, applied twice.
    The result is added to the block input. When the input channel count
    differs from `filters`, the input is first projected with a 1x1
    convolution so both branches have the same number of channels.

    Args:
        x: input tensor. Its channel count is read from the axis selected
            by `data_format`.
        filters: int. Number of output filters of both 3x3 convolutions and
            of the optional residual projection.
        data_format: `None` or str. Normalized with
            `standardize_data_format`.
        dtype: `None`, str or dtype policy. Forwarded to every layer created
            by this block.
        name: str. Prefix used for the names of the normalization and
            convolution layers.

    Returns:
        The output tensor of the residual block.
    """
    data_format = standardize_data_format(data_format)
    gn_axis = -1 if data_format == "channels_last" else 1
    input_filters = x.shape[gn_axis]

    # NOTE(review): `groups` is fixed at 32, so channel counts presumably
    # need to be multiples of 32 -- confirm against `GroupNormalization`.
    residual = x
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
    )(x)
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv1",
    )(x)
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
    )(x)
    # Pass `dtype` here as well so this activation follows the same dtype
    # policy as every other layer in the block.
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv2",
    )(x)
    if input_filters != filters:
        # Match the channel count of the main path before the addition.
        residual = layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_residual_projection",
        )(residual)
    x = layers.Add(dtype=dtype)([residual, x])
    return x
@@ -14,7 +14,6 @@
14
14
 
15
15
  from keras_hub.src.models.t5.t5_backbone import T5Backbone
16
16
  from keras_hub.src.models.t5.t5_presets import backbone_presets
17
- from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
18
17
  from keras_hub.src.utils.preset_utils import register_presets
19
18
 
20
- register_presets(backbone_presets, (T5Backbone, T5Tokenizer))
19
+ register_presets(backbone_presets, T5Backbone)
@@ -13,12 +13,18 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.models.t5.t5_backbone import T5Backbone
16
17
  from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
17
18
  SentencePieceTokenizer,
18
19
  )
19
20
 
20
21
 
21
- @keras_hub_export("keras_hub.models.T5Tokenizer")
22
+ @keras_hub_export(
23
+ [
24
+ "keras_hub.tokenizers.T5Tokenizer",
25
+ "keras_hub.models.T5Tokenizer",
26
+ ]
27
+ )
22
28
  class T5Tokenizer(SentencePieceTokenizer):
23
29
  """T5 tokenizer layer based on SentencePiece.
24
30
 
@@ -74,27 +80,11 @@ class T5Tokenizer(SentencePieceTokenizer):
74
80
  ```
75
81
  """
76
82
 
77
- def __init__(self, proto, **kwargs):
78
- self.end_token = "</s>"
79
- self.pad_token = "<pad>"
83
+ backbone_cls = T5Backbone
80
84
 
85
+ def __init__(self, proto, **kwargs):
86
+ # T5 uses the same start token as end token, i.e., "</s>".
87
+ self._add_special_token("</s>", "end_token")
88
+ self._add_special_token("</s>", "start_token")
89
+ self._add_special_token("<pad>", "pad_token")
81
90
  super().__init__(proto=proto, **kwargs)
82
-
83
- def set_proto(self, proto):
84
- super().set_proto(proto)
85
- if proto is not None:
86
- for token in [self.end_token, self.pad_token]:
87
- if token not in self.get_vocabulary():
88
- raise ValueError(
89
- f"Cannot find token `'{token}'` in the provided "
90
- f"`vocabulary`. Please provide `'{token}'` in your "
91
- "`vocabulary` or use a pretrained `vocabulary` name."
92
- )
93
- self.end_token_id = self.token_to_id(self.end_token)
94
- self.pad_token_id = self.token_to_id(self.pad_token)
95
- # T5 uses the same start token as end token, i.e., "<\s>".
96
- self.start_token_id = self.end_token_id
97
- else:
98
- self.end_token_id = None
99
- self.pad_token_id = None
100
- self.start_token_id = None
@@ -22,18 +22,11 @@ from rich import table as rich_table
22
22
  from keras_hub.src.api_export import keras_hub_export
23
23
  from keras_hub.src.utils.keras_utils import print_msg
24
24
  from keras_hub.src.utils.pipeline_model import PipelineModel
25
- from keras_hub.src.utils.preset_utils import CONFIG_FILE
26
- from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
27
25
  from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
28
26
  from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
29
- from keras_hub.src.utils.preset_utils import check_config_class
30
- from keras_hub.src.utils.preset_utils import check_file_exists
31
- from keras_hub.src.utils.preset_utils import check_format
32
- from keras_hub.src.utils.preset_utils import get_file
33
- from keras_hub.src.utils.preset_utils import jax_memory_cleanup
34
- from keras_hub.src.utils.preset_utils import list_presets
35
- from keras_hub.src.utils.preset_utils import list_subclasses
36
- from keras_hub.src.utils.preset_utils import load_serialized_object
27
+ from keras_hub.src.utils.preset_utils import builtin_presets
28
+ from keras_hub.src.utils.preset_utils import find_subclass
29
+ from keras_hub.src.utils.preset_utils import get_preset_loader
37
30
  from keras_hub.src.utils.preset_utils import save_serialized_object
38
31
  from keras_hub.src.utils.python_utils import classproperty
39
32
 
@@ -56,12 +49,17 @@ class Task(PipelineModel):
56
49
  to load a pre-trained config and weights. Calling `from_preset()` on a task
57
50
  will automatically instantiate a `keras_hub.models.Backbone` and
58
51
  `keras_hub.models.Preprocessor`.
52
+
53
+ Args:
54
+ compile: boolean, defaults to `True`. If `True` will compile the model
55
+ with default parameters on construction. Model can still be
56
+ recompiled with a new loss, optimizer and metrics before training.
59
57
  """
60
58
 
61
59
  backbone_cls = None
62
60
  preprocessor_cls = None
63
61
 
64
- def __init__(self, *args, **kwargs):
62
+ def __init__(self, *args, compile=True, **kwargs):
65
63
  super().__init__(*args, **kwargs)
66
64
  self._functional_layer_ids = set(
67
65
  id(layer) for layer in self._flatten_layers()
@@ -69,6 +67,9 @@ class Task(PipelineModel):
69
67
  self._initialized = True
70
68
  if self.backbone is not None:
71
69
  self.dtype_policy = self._backbone.dtype_policy
70
+ if compile:
71
+ # Default compilation.
72
+ self.compile()
72
73
 
73
74
  def preprocess_samples(self, x, y=None, sample_weight=None):
74
75
  if self.preprocessor is not None:
@@ -131,13 +132,7 @@ class Task(PipelineModel):
131
132
  @classproperty
132
133
  def presets(cls):
133
134
  """List built-in presets for a `Task` subclass."""
134
- presets = list_presets(cls)
135
- # We can also load backbone presets.
136
- if cls.backbone_cls is not None:
137
- presets.update(cls.backbone_cls.presets)
138
- for subclass in list_subclasses(cls):
139
- presets.update(subclass.presets)
140
- return presets
135
+ return builtin_presets(cls)
141
136
 
142
137
  @classmethod
143
138
  def from_preset(
@@ -149,10 +144,10 @@ class Task(PipelineModel):
149
144
  """Instantiate a `keras_hub.models.Task` from a model preset.
150
145
 
151
146
  A preset is a directory of configs, weights and other file assets used
152
- to save and load a pre-trained model. The `preset` can be passed as a
147
+ to save and load a pre-trained model. The `preset` can be passed as
153
148
  one of:
154
149
 
155
- 1. a built in preset identifier like `'bert_base_en'`
150
+ 1. a built-in preset identifier like `'bert_base_en'`
156
151
  2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
157
152
  3. a Hugging Face handle like `'hf://user/bert_base_en'`
158
153
  4. a path to a local preset directory like `'./bert_base_en'`
@@ -162,16 +157,16 @@ class Task(PipelineModel):
162
157
 
163
158
  This constructor can be called in one of two ways. Either from a task
164
159
  specific base class like `keras_hub.models.CausalLM.from_preset()`, or
165
- from a model class like `keras_hub.models.BertClassifier.from_preset()`.
160
+ from a model class like `keras_hub.models.BertTextClassifier.from_preset()`.
166
161
  If calling from a base class, the subclass of the returning object
167
162
  will be inferred from the config in the preset directory.
168
163
 
169
164
  Args:
170
- preset: string. A built in preset identifier, a Kaggle Models
165
+ preset: string. A built-in preset identifier, a Kaggle Models
171
166
  handle, a Hugging Face handle, or a path to a local directory.
172
- load_weights: bool. If `True`, the weights will be loaded into the
173
- model architecture. If `False`, the weights will be randomly
174
- initialized.
167
+ load_weights: bool. If `True`, saved weights will be loaded into
168
+ the model architecture. If `False`, all weights will be
169
+ randomly initialized.
175
170
 
176
171
  Examples:
177
172
  ```python
@@ -181,100 +176,37 @@ class Task(PipelineModel):
181
176
  )
182
177
 
183
178
  # Load a Bert classification task.
184
- model = keras_hub.models.Classifier.from_preset(
179
+ model = keras_hub.models.TextClassifier.from_preset(
185
180
  "bert_base_en",
186
181
  num_classes=2,
187
182
  )
188
183
  ```
189
184
  """
190
- format = check_format(preset)
191
-
192
- if format == "transformers":
193
- if cls.backbone_cls is None:
194
- raise ValueError("Backbone class is None")
195
- if cls.preprocessor_cls is None:
196
- raise ValueError("Preprocessor class is None")
197
-
198
- backbone = cls.backbone_cls.from_preset(preset)
199
- preprocessor = cls.preprocessor_cls.from_preset(preset)
200
- return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
201
-
202
185
  if cls == Task:
203
186
  raise ValueError(
204
187
  "Do not call `Task.from_preset()` directly. Instead call a "
205
188
  "particular task class, e.g. "
206
- "`keras_hub.models.Classifier.from_preset()` or "
207
- "`keras_hub.models.BertClassifier.from_preset()`."
208
- )
209
- if "backbone" in kwargs:
210
- raise ValueError(
211
- "You cannot pass a `backbone` argument to the `from_preset` "
212
- f"method. Instead, call the {cls.__name__} default "
213
- "constructor with a `backbone` argument. "
214
- f"Received: backbone={kwargs['backbone']}."
189
+ "`keras_hub.models.TextClassifier.from_preset()`."
215
190
  )
216
191
 
217
- # Check if we should load a `task.json` directly.
218
- load_task_config = False
219
- if check_file_exists(preset, TASK_CONFIG_FILE):
220
- task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
221
- if issubclass(task_preset_cls, cls):
222
- load_task_config = True
223
- if load_task_config:
224
- # Task case.
225
- task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
226
- task = load_serialized_object(preset, TASK_CONFIG_FILE)
227
- if load_weights:
228
- jax_memory_cleanup(task)
229
- if check_file_exists(preset, TASK_WEIGHTS_FILE):
230
- task.load_task_weights(get_file(preset, TASK_WEIGHTS_FILE))
231
- task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
232
- task.preprocessor.tokenizer.load_preset_assets(preset)
233
- return task
234
-
235
- # Backbone case.
236
- # If `task.json` doesn't exist or the task preset class is different
237
- # from the calling class, create the task based on `config.json`.
238
- backbone_preset_cls = check_config_class(preset, CONFIG_FILE)
239
- if backbone_preset_cls is not cls.backbone_cls:
240
- subclasses = list_subclasses(cls)
241
- subclasses = tuple(
242
- filter(
243
- lambda x: x.backbone_cls == backbone_preset_cls,
244
- subclasses,
245
- )
246
- )
247
- if len(subclasses) == 0:
248
- raise ValueError(
249
- f"No registered subclass of `{cls.__name__}` can load "
250
- f"a `{backbone_preset_cls.__name__}`."
251
- )
252
- if len(subclasses) > 1:
253
- names = ", ".join(f"`{x.__name__}`" for x in subclasses)
254
- raise ValueError(
255
- f"Ambiguous call to `{cls.__name__}.from_preset()`. "
256
- f"Found multiple possible subclasses {names}. "
257
- "Please call `from_preset` on a subclass directly."
258
- )
259
- cls = subclasses[0]
260
- # Forward dtype to the backbone.
261
- backbone_kwargs = {}
262
- if "dtype" in kwargs:
263
- backbone_kwargs = {"dtype": kwargs.pop("dtype")}
264
- backbone = backbone_preset_cls.from_preset(
265
- preset, load_weights=load_weights, **backbone_kwargs
266
- )
267
- if "preprocessor" in kwargs:
268
- preprocessor = kwargs.pop("preprocessor")
269
- else:
270
- preprocessor = cls.preprocessor_cls.from_preset(preset)
271
- return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
192
+ loader = get_preset_loader(preset)
193
+ backbone_cls = loader.check_backbone_class()
194
+ # Detect the correct subclass if we need to.
195
+ if cls.backbone_cls != backbone_cls:
196
+ cls = find_subclass(preset, cls, backbone_cls)
197
+ # Specifically for classifiers, we never load task weights if
198
+ # num_classes is supplied. We handle this in the task base class because
199
+ # it is the same logic for classifiers regardless of modality (text,
200
+ # images, audio).
201
+ load_task_weights = "num_classes" not in kwargs
202
+ return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
272
203
 
273
204
  def load_task_weights(self, filepath):
274
205
  """Load only the tasks specific weights not in the backbone."""
275
206
  if not str(filepath).endswith(".weights.h5"):
276
207
  raise ValueError(
277
- "The filename must end in `.weights.h5`. Received: filepath={filepath}"
208
+ "The filename must end in `.weights.h5`. "
209
+ f"Received: filepath={filepath}"
278
210
  )
279
211
  backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
280
212
  keras.saving.load_weights(
@@ -361,7 +293,9 @@ class Task(PipelineModel):
361
293
  print_fn = print_msg
362
294
 
363
295
  def highlight_number(x):
364
- return f"[color(45)]{x}[/]" if x is None else f"[color(34)]{x}[/]"
296
+ if x is None:
297
+ f"[color(45)]{x}[/]"
298
+ return f"[color(34)]{x:,}[/]" # Format number with commas.
365
299
 
366
300
  def highlight_symbol(x):
367
301
  return f"[color(33)]{x}[/]"
@@ -369,6 +303,10 @@ class Task(PipelineModel):
369
303
  def bold_text(x):
370
304
  return f"[bold]{x}[/]"
371
305
 
306
+ def highlight_shape(shape):
307
+ highlighted = [highlight_number(x) for x in shape]
308
+ return "(" + ", ".join(highlighted) + ")"
309
+
372
310
  if self.preprocessor:
373
311
  # Create a rich console for printing. Capture for non-interactive logging.
374
312
  if print_fn:
@@ -380,27 +318,44 @@ class Task(PipelineModel):
380
318
  console = rich_console.Console(highlight=False)
381
319
 
382
320
  column_1 = rich_table.Column(
383
- "Tokenizer (type)",
321
+ "Layer (type)",
384
322
  justify="left",
385
- width=int(0.5 * line_length),
323
+ width=int(0.6 * line_length),
386
324
  )
387
325
  column_2 = rich_table.Column(
388
- "Vocab #",
326
+ "Config",
389
327
  justify="right",
390
- width=int(0.5 * line_length),
328
+ width=int(0.4 * line_length),
391
329
  )
392
330
  table = rich_table.Table(
393
331
  column_1, column_2, width=line_length, show_lines=True
394
332
  )
333
+
334
+ def add_layer(layer, info):
335
+ layer_name = markup.escape(layer.name)
336
+ layer_class = highlight_symbol(
337
+ markup.escape(layer.__class__.__name__)
338
+ )
339
+ table.add_row(
340
+ f"{layer_name} ({layer_class})",
341
+ info,
342
+ )
343
+
395
344
  tokenizer = self.preprocessor.tokenizer
396
- tokenizer_name = markup.escape(tokenizer.name)
397
- tokenizer_class = highlight_symbol(
398
- markup.escape(tokenizer.__class__.__name__)
399
- )
400
- table.add_row(
401
- f"{tokenizer_name} ({tokenizer_class})",
402
- highlight_number(f"{tokenizer.vocabulary_size():,}"),
403
- )
345
+ if tokenizer:
346
+ info = "Vocab size: "
347
+ info += highlight_number(tokenizer.vocabulary_size())
348
+ add_layer(tokenizer, info)
349
+ image_converter = self.preprocessor.image_converter
350
+ if image_converter:
351
+ info = "Image size: "
352
+ info += highlight_shape(image_converter.image_size())
353
+ add_layer(image_converter, info)
354
+ audio_converter = self.preprocessor.audio_converter
355
+ if audio_converter:
356
+ info = "Audio shape: "
357
+ info += highlight_shape(audio_converter.audio_shape())
358
+ add_layer(audio_converter, info)
404
359
 
405
360
  # Print to the console.
406
361
  preprocessor_name = markup.escape(self.preprocessor.name)
@@ -17,25 +17,36 @@ from keras_hub.src.api_export import keras_hub_export
17
17
  from keras_hub.src.models.task import Task
18
18
 
19
19
 
20
- @keras_hub_export("keras_hub.models.Classifier")
21
- class Classifier(Task):
20
+ @keras_hub_export(
21
+ [
22
+ "keras_hub.models.TextClassifier",
23
+ "keras_hub.models.Classifier",
24
+ ]
25
+ )
26
+ class TextClassifier(Task):
22
27
  """Base class for all classification tasks.
23
28
 
24
- `Classifier` tasks wrap a `keras_hub.models.Backbone` and
29
+ `TextClassifier` tasks wrap a `keras_hub.models.Backbone` and
25
30
  a `keras_hub.models.Preprocessor` to create a model that can be used for
26
- sequence classification. `Classifier` tasks take an additional
31
+ sequence classification. `TextClassifier` tasks take an additional
27
32
  `num_classes` argument, controlling the number of predicted output classes.
28
33
 
29
34
  To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
30
35
  labels where `x` is a string and `y` is an integer from `[0, num_classes)`.
31
36
 
32
- All `Classifier` tasks include a `from_preset()` constructor which can be
37
+ All `TextClassifier` tasks include a `from_preset()` constructor which can be
33
38
  used to load a pre-trained config and weights.
34
39
 
40
+ Some, but not all, classification presets include classification head
41
+ weights in a `task.weights.h5` file. For these presets, you can omit passing
42
+ `num_classes` to restore the saved classification head. For all presets, if
43
+ `num_classes` is passed as a kwarg to `from_preset()`, the classification
44
+ head will be randomly initialized.
45
+
35
46
  Example:
36
47
  ```python
37
48
  # Load a BERT classifier with pre-trained weights.
38
- classifier = keras_hub.models.Classifier.from_preset(
49
+ classifier = keras_hub.models.TextClassifier.from_preset(
39
50
  "bert_base_en",
40
51
  num_classes=2,
41
52
  )
@@ -52,11 +63,6 @@ class Classifier(Task):
52
63
  ```
53
64
  """
54
65
 
55
- def __init__(self, *args, **kwargs):
56
- super().__init__(*args, **kwargs)
57
- # Default compilation.
58
- self.compile()
59
-
60
66
  def compile(
61
67
  self,
62
68
  optimizer="auto",
@@ -65,9 +71,9 @@ class Classifier(Task):
65
71
  metrics="auto",
66
72
  **kwargs,
67
73
  ):
68
- """Configures the `Classifier` task for training.
74
+ """Configures the `TextClassifier` task for training.
69
75
 
70
- The `Classifier` task extends the default compilation signature of
76
+ The `TextClassifier` task extends the default compilation signature of
71
77
  `keras.Model.compile` with defaults for `optimizer`, `loss`, and
72
78
  `metrics`. To override these defaults, pass any value
73
79
  to these arguments during compilation.