keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (198) hide show
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -14,14 +14,13 @@
14
14
  import keras
15
15
 
16
16
  from keras_hub.src.api_export import keras_hub_export
17
- from keras_hub.src.models.backbone import Backbone
17
+ from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
18
18
 
19
- BN_AXIS = 3
20
19
  BN_EPSILON = 1.001e-5
21
20
 
22
21
 
23
22
  @keras_hub_export("keras_hub.models.DenseNetBackbone")
24
- class DenseNetBackbone(Backbone):
23
+ class DenseNetBackbone(FeaturePyramidBackbone):
25
24
  """Instantiates the DenseNet architecture.
26
25
 
27
26
  This class implements a DenseNet backbone as described in
@@ -35,7 +34,7 @@ class DenseNetBackbone(Backbone):
35
34
  include_rescaling: bool, whether to rescale the inputs. If set
36
35
  to `True`, inputs will be passed through a `Rescaling(1/255.0)`
37
36
  layer. Defaults to `True`.
38
- image_shape: optional shape tuple, defaults to (224, 224, 3).
37
+ image_shape: optional shape tuple, defaults to (None, None, 3).
39
38
  compression_ratio: float, compression rate at transition layers,
40
39
  defaults to 0.5.
41
40
  growth_rate: int, number of filters added by each dense block,
@@ -62,12 +61,14 @@ class DenseNetBackbone(Backbone):
62
61
  self,
63
62
  stackwise_num_repeats,
64
63
  include_rescaling=True,
65
- image_shape=(224, 224, 3),
64
+ image_shape=(None, None, 3),
66
65
  compression_ratio=0.5,
67
66
  growth_rate=32,
68
67
  **kwargs,
69
68
  ):
70
69
  # === Functional Model ===
70
+ data_format = keras.config.image_data_format()
71
+ channel_axis = -1 if data_format == "channels_last" else 1
71
72
  image_input = keras.layers.Input(shape=image_shape)
72
73
 
73
74
  x = image_input
@@ -75,37 +76,47 @@ class DenseNetBackbone(Backbone):
75
76
  x = keras.layers.Rescaling(1 / 255.0)(x)
76
77
 
77
78
  x = keras.layers.Conv2D(
78
- 64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv"
79
+ 64,
80
+ 7,
81
+ strides=2,
82
+ use_bias=False,
83
+ padding="same",
84
+ data_format=data_format,
85
+ name="conv1_conv",
79
86
  )(x)
80
87
  x = keras.layers.BatchNormalization(
81
- axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn"
88
+ axis=channel_axis, epsilon=BN_EPSILON, name="conv1_bn"
82
89
  )(x)
83
90
  x = keras.layers.Activation("relu", name="conv1_relu")(x)
84
91
  x = keras.layers.MaxPooling2D(
85
- 3, strides=2, padding="same", name="pool1"
92
+ 3, strides=2, padding="same", data_format=data_format, name="pool1"
86
93
  )(x)
87
94
 
95
+ pyramid_outputs = {}
88
96
  for stack_index in range(len(stackwise_num_repeats) - 1):
89
97
  index = stack_index + 2
90
98
  x = apply_dense_block(
91
99
  x,
100
+ channel_axis,
92
101
  stackwise_num_repeats[stack_index],
93
102
  growth_rate,
94
103
  name=f"conv{index}",
95
104
  )
105
+ pyramid_outputs[f"P{index}"] = x
96
106
  x = apply_transition_block(
97
- x, compression_ratio, name=f"pool{index}"
107
+ x, channel_axis, compression_ratio, name=f"pool{index}"
98
108
  )
99
109
 
100
110
  x = apply_dense_block(
101
111
  x,
112
+ channel_axis,
102
113
  stackwise_num_repeats[-1],
103
114
  growth_rate,
104
115
  name=f"conv{len(stackwise_num_repeats) + 1}",
105
116
  )
106
-
117
+ pyramid_outputs[f"P{len(stackwise_num_repeats) + 1}"] = x
107
118
  x = keras.layers.BatchNormalization(
108
- axis=BN_AXIS, epsilon=BN_EPSILON, name="bn"
119
+ axis=channel_axis, epsilon=BN_EPSILON, name="bn"
109
120
  )(x)
110
121
  x = keras.layers.Activation("relu", name="relu")(x)
111
122
 
@@ -117,6 +128,7 @@ class DenseNetBackbone(Backbone):
117
128
  self.compression_ratio = compression_ratio
118
129
  self.growth_rate = growth_rate
119
130
  self.image_shape = image_shape
131
+ self.pyramid_outputs = pyramid_outputs
120
132
 
121
133
  def get_config(self):
122
134
  config = super().get_config()
@@ -132,7 +144,7 @@ class DenseNetBackbone(Backbone):
132
144
  return config
133
145
 
134
146
 
135
- def apply_dense_block(x, num_repeats, growth_rate, name=None):
147
+ def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None):
136
148
  """A dense block.
137
149
 
138
150
  Args:
@@ -145,11 +157,13 @@ def apply_dense_block(x, num_repeats, growth_rate, name=None):
145
157
  name = f"dense_block_{keras.backend.get_uid('dense_block')}"
146
158
 
147
159
  for i in range(num_repeats):
148
- x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}")
160
+ x = apply_conv_block(
161
+ x, channel_axis, growth_rate, name=f"{name}_block_{i}"
162
+ )
149
163
  return x
150
164
 
151
165
 
152
- def apply_transition_block(x, compression_ratio, name=None):
166
+ def apply_transition_block(x, channel_axis, compression_ratio, name=None):
153
167
  """A transition block.
154
168
 
155
169
  Args:
@@ -157,24 +171,28 @@ def apply_transition_block(x, compression_ratio, name=None):
157
171
  compression_ratio: float, compression rate at transition layers.
158
172
  name: string, block label.
159
173
  """
174
+ data_format = keras.config.image_data_format()
160
175
  if name is None:
161
176
  name = f"transition_block_{keras.backend.get_uid('transition_block')}"
162
177
 
163
178
  x = keras.layers.BatchNormalization(
164
- axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn"
179
+ axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_bn"
165
180
  )(x)
166
181
  x = keras.layers.Activation("relu", name=f"{name}_relu")(x)
167
182
  x = keras.layers.Conv2D(
168
- int(x.shape[BN_AXIS] * compression_ratio),
183
+ int(x.shape[channel_axis] * compression_ratio),
169
184
  1,
170
185
  use_bias=False,
186
+ data_format=data_format,
171
187
  name=f"{name}_conv",
172
188
  )(x)
173
- x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x)
189
+ x = keras.layers.AveragePooling2D(
190
+ 2, strides=2, data_format=data_format, name=f"{name}_pool"
191
+ )(x)
174
192
  return x
175
193
 
176
194
 
177
- def apply_conv_block(x, growth_rate, name=None):
195
+ def apply_conv_block(x, channel_axis, growth_rate, name=None):
178
196
  """A building block for a dense block.
179
197
 
180
198
  Args:
@@ -182,19 +200,24 @@ def apply_conv_block(x, growth_rate, name=None):
182
200
  growth_rate: int, number of filters added by each dense block.
183
201
  name: string, block label.
184
202
  """
203
+ data_format = keras.config.image_data_format()
185
204
  if name is None:
186
205
  name = f"conv_block_{keras.backend.get_uid('conv_block')}"
187
206
 
188
207
  shortcut = x
189
208
  x = keras.layers.BatchNormalization(
190
- axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn"
209
+ axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_0_bn"
191
210
  )(x)
192
211
  x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x)
193
212
  x = keras.layers.Conv2D(
194
- 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv"
213
+ 4 * growth_rate,
214
+ 1,
215
+ use_bias=False,
216
+ data_format=data_format,
217
+ name=f"{name}_1_conv",
195
218
  )(x)
196
219
  x = keras.layers.BatchNormalization(
197
- axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn"
220
+ axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_1_bn"
198
221
  )(x)
199
222
  x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x)
200
223
  x = keras.layers.Conv2D(
@@ -202,9 +225,10 @@ def apply_conv_block(x, growth_rate, name=None):
202
225
  3,
203
226
  padding="same",
204
227
  use_bias=False,
228
+ data_format=data_format,
205
229
  name=f"{name}_2_conv",
206
230
  )(x)
207
- x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")(
231
+ x = keras.layers.Concatenate(axis=channel_axis, name=f"{name}_concat")(
208
232
  [shortcut, x]
209
233
  )
210
234
  return x
@@ -18,9 +18,6 @@ from keras_hub.src.models.distil_bert.distil_bert_backbone import (
18
18
  from keras_hub.src.models.distil_bert.distil_bert_presets import (
19
19
  backbone_presets,
20
20
  )
21
- from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
22
- DistilBertTokenizer,
23
- )
24
21
  from keras_hub.src.utils.preset_utils import register_presets
25
22
 
26
- register_presets(backbone_presets, (DistilBertBackbone, DistilBertTokenizer))
23
+ register_presets(backbone_presets, DistilBertBackbone)
@@ -13,19 +13,20 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import keras
16
- from absl import logging
17
16
 
18
17
  from keras_hub.src.api_export import keras_hub_export
19
- from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
20
- MaskedLMMaskGenerator,
18
+ from keras_hub.src.models.distil_bert.distil_bert_backbone import (
19
+ DistilBertBackbone,
21
20
  )
22
- from keras_hub.src.models.distil_bert.distil_bert_preprocessor import (
23
- DistilBertPreprocessor,
21
+ from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
22
+ DistilBertTokenizer,
24
23
  )
24
+ from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
25
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
25
26
 
26
27
 
27
28
  @keras_hub_export("keras_hub.models.DistilBertMaskedLMPreprocessor")
28
- class DistilBertMaskedLMPreprocessor(DistilBertPreprocessor):
29
+ class DistilBertMaskedLMPreprocessor(MaskedLMPreprocessor):
29
30
  """DistilBERT preprocessing for the masked language modeling task.
30
31
 
31
32
  This preprocessing layer will prepare inputs for a masked language modeling
@@ -119,76 +120,13 @@ class DistilBertMaskedLMPreprocessor(DistilBertPreprocessor):
119
120
  ```
120
121
  """
121
122
 
122
- def __init__(
123
- self,
124
- tokenizer,
125
- sequence_length=512,
126
- truncate="round_robin",
127
- mask_selection_rate=0.15,
128
- mask_selection_length=96,
129
- mask_token_rate=0.8,
130
- random_token_rate=0.1,
131
- **kwargs,
132
- ):
133
- super().__init__(
134
- tokenizer,
135
- sequence_length=sequence_length,
136
- truncate=truncate,
137
- **kwargs,
138
- )
139
- self.mask_selection_rate = mask_selection_rate
140
- self.mask_selection_length = mask_selection_length
141
- self.mask_token_rate = mask_token_rate
142
- self.random_token_rate = random_token_rate
143
- self.masker = None
144
-
145
- def build(self, input_shape):
146
- super().build(input_shape)
147
- # Defer masker creation to `build()` so that we can be sure tokenizer
148
- # assets have loaded when restoring a saved model.
149
- self.masker = MaskedLMMaskGenerator(
150
- mask_selection_rate=self.mask_selection_rate,
151
- mask_selection_length=self.mask_selection_length,
152
- mask_token_rate=self.mask_token_rate,
153
- random_token_rate=self.random_token_rate,
154
- vocabulary_size=self.tokenizer.vocabulary_size(),
155
- mask_token_id=self.tokenizer.mask_token_id,
156
- unselectable_token_ids=[
157
- self.tokenizer.cls_token_id,
158
- self.tokenizer.sep_token_id,
159
- self.tokenizer.pad_token_id,
160
- ],
161
- )
123
+ backbone_cls = DistilBertBackbone
124
+ tokenizer_cls = DistilBertTokenizer
162
125
 
126
+ @preprocessing_function
163
127
  def call(self, x, y=None, sample_weight=None):
164
- if y is not None or sample_weight is not None:
165
- logging.warning(
166
- f"{self.__class__.__name__} generates `y` and `sample_weight` "
167
- "based on your input data, but your data already contains `y` "
168
- "or `sample_weight`. Your `y` and `sample_weight` will be "
169
- "ignored."
170
- )
171
-
172
- x = super().call(x)
173
- token_ids, padding_mask = x["token_ids"], x["padding_mask"]
174
- masker_outputs = self.masker(token_ids)
175
- x = {
176
- "token_ids": masker_outputs["token_ids"],
177
- "padding_mask": padding_mask,
178
- "mask_positions": masker_outputs["mask_positions"],
179
- }
180
- y = masker_outputs["mask_ids"]
181
- sample_weight = masker_outputs["mask_weights"]
128
+ output = super().call(x, y=y, sample_weight=sample_weight)
129
+ x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output)
130
+ # Backbone has no segment ID input.
131
+ del x["segment_ids"]
182
132
  return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
183
-
184
- def get_config(self):
185
- config = super().get_config()
186
- config.update(
187
- {
188
- "mask_selection_rate": self.mask_selection_rate,
189
- "mask_selection_length": self.mask_selection_length,
190
- "mask_token_rate": self.mask_token_rate,
191
- "random_token_rate": self.random_token_rate,
192
- }
193
- )
194
- return config
@@ -16,20 +16,25 @@
16
16
  import keras
17
17
 
18
18
  from keras_hub.src.api_export import keras_hub_export
19
- from keras_hub.src.models.classifier import Classifier
20
19
  from keras_hub.src.models.distil_bert.distil_bert_backbone import (
21
20
  DistilBertBackbone,
22
21
  )
23
22
  from keras_hub.src.models.distil_bert.distil_bert_backbone import (
24
23
  distilbert_kernel_initializer,
25
24
  )
26
- from keras_hub.src.models.distil_bert.distil_bert_preprocessor import (
27
- DistilBertPreprocessor,
25
+ from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import (
26
+ DistilBertTextClassifierPreprocessor,
28
27
  )
28
+ from keras_hub.src.models.text_classifier import TextClassifier
29
29
 
30
30
 
31
- @keras_hub_export("keras_hub.models.DistilBertClassifier")
32
- class DistilBertClassifier(Classifier):
31
+ @keras_hub_export(
32
+ [
33
+ "keras_hub.models.DistilBertTextClassifier",
34
+ "keras_hub.models.DistilBertClassifier",
35
+ ]
36
+ )
37
+ class DistilBertTextClassifier(TextClassifier):
33
38
  """An end-to-end DistilBERT model for classification tasks.
34
39
 
35
40
  This model attaches a classification head to a
@@ -50,7 +55,7 @@ class DistilBertClassifier(Classifier):
50
55
  Args:
51
56
  backbone: A `keras_hub.models.DistilBert` instance.
52
57
  num_classes: int. Number of classes to predict.
53
- preprocessor: A `keras_hub.models.DistilBertPreprocessor` or `None`. If
58
+ preprocessor: A `keras_hub.models.DistilBertTextClassifierPreprocessor` or `None`. If
54
59
  `None`, this model will not apply preprocessing, and inputs should
55
60
  be preprocessed before calling the model.
56
61
  activation: Optional `str` or callable. The
@@ -69,12 +74,12 @@ class DistilBertClassifier(Classifier):
69
74
  labels = [0, 3]
70
75
 
71
76
  # Use a shorter sequence length.
72
- preprocessor = keras_hub.models.DistilBertPreprocessor.from_preset(
77
+ preprocessor = keras_hub.models.DistilBertTextClassifierPreprocessor.from_preset(
73
78
  "distil_bert_base_en_uncased",
74
79
  sequence_length=128,
75
80
  )
76
81
  # Pretrained classifier.
77
- classifier = keras_hub.models.DistilBertClassifier.from_preset(
82
+ classifier = keras_hub.models.DistilBertTextClassifier.from_preset(
78
83
  "distil_bert_base_en_uncased",
79
84
  num_classes=4,
80
85
  preprocessor=preprocessor,
@@ -102,7 +107,7 @@ class DistilBertClassifier(Classifier):
102
107
  labels = [0, 3]
103
108
 
104
109
  # Pretrained classifier without preprocessing.
105
- classifier = keras_hub.models.DistilBertClassifier.from_preset(
110
+ classifier = keras_hub.models.DistilBertTextClassifier.from_preset(
106
111
  "distil_bert_base_en_uncased",
107
112
  num_classes=4,
108
113
  preprocessor=None,
@@ -119,7 +124,7 @@ class DistilBertClassifier(Classifier):
119
124
  tokenizer = keras_hub.models.DistilBertTokenizer(
120
125
  vocabulary=vocab,
121
126
  )
122
- preprocessor = keras_hub.models.DistilBertPreprocessor(
127
+ preprocessor = keras_hub.models.DistilBertTextClassifierPreprocessor(
123
128
  tokenizer=tokenizer,
124
129
  sequence_length=128,
125
130
  )
@@ -131,7 +136,7 @@ class DistilBertClassifier(Classifier):
131
136
  intermediate_dim=512,
132
137
  max_sequence_length=128,
133
138
  )
134
- classifier = keras_hub.models.DistilBertClassifier(
139
+ classifier = keras_hub.models.DistilBertTextClassifier(
135
140
  backbone=backbone,
136
141
  preprocessor=preprocessor,
137
142
  num_classes=4,
@@ -141,7 +146,7 @@ class DistilBertClassifier(Classifier):
141
146
  """
142
147
 
143
148
  backbone_cls = DistilBertBackbone
144
- preprocessor_cls = DistilBertPreprocessor
149
+ preprocessor_cls = DistilBertTextClassifierPreprocessor
145
150
 
146
151
  def __init__(
147
152
  self,
@@ -16,20 +16,25 @@
16
16
  import keras
17
17
 
18
18
  from keras_hub.src.api_export import keras_hub_export
19
- from keras_hub.src.layers.preprocessing.multi_segment_packer import (
20
- MultiSegmentPacker,
19
+ from keras_hub.src.models.distil_bert.distil_bert_backbone import (
20
+ DistilBertBackbone,
21
21
  )
22
22
  from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
23
23
  DistilBertTokenizer,
24
24
  )
25
- from keras_hub.src.models.preprocessor import Preprocessor
26
- from keras_hub.src.utils.keras_utils import (
27
- convert_inputs_to_list_of_tensor_segments,
25
+ from keras_hub.src.models.text_classifier_preprocessor import (
26
+ TextClassifierPreprocessor,
28
27
  )
28
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
29
29
 
30
30
 
31
- @keras_hub_export("keras_hub.models.DistilBertPreprocessor")
32
- class DistilBertPreprocessor(Preprocessor):
31
+ @keras_hub_export(
32
+ [
33
+ "keras_hub.models.DistilBertTextClassifierPreprocessor",
34
+ "keras_hub.models.DistilBertPreprocessor",
35
+ ]
36
+ )
37
+ class DistilBertTextClassifierPreprocessor(TextClassifierPreprocessor):
33
38
  """A DistilBERT preprocessing layer which tokenizes and packs inputs.
34
39
 
35
40
  This preprocessing layer will do three things:
@@ -70,7 +75,7 @@ class DistilBertPreprocessor(Preprocessor):
70
75
 
71
76
  Directly calling the layer on data.
72
77
  ```python
73
- preprocessor = keras_hub.models.DistilBertPreprocessor.from_preset(
78
+ preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
74
79
  "distil_bert_base_en_uncased"
75
80
  )
76
81
  preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
@@ -79,13 +84,15 @@ class DistilBertPreprocessor(Preprocessor):
79
84
  vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
80
85
  vocab += ["The", "quick", "brown", "fox", "jumped", "."]
81
86
  tokenizer = keras_hub.models.DistilBertTokenizer(vocabulary=vocab)
82
- preprocessor = keras_hub.models.DistilBertPreprocessor(tokenizer)
87
+ preprocessor = keras_hub.models.DistilBertTextClassifierPreprocessor(
88
+ tokenizer
89
+ )
83
90
  preprocessor("The quick brown fox jumped.")
84
91
  ```
85
92
 
86
93
  Mapping with `tf.data.Dataset`.
87
94
  ```python
88
- preprocessor = keras_hub.models.DistilBertPreprocessor.from_preset(
95
+ preprocessor = keras_hub.models.TextClassifierPreprocessor.from_preset(
89
96
  "distil_bert_base_en_uncased"
90
97
  )
91
98
 
@@ -116,60 +123,13 @@ class DistilBertPreprocessor(Preprocessor):
116
123
  ```
117
124
  """
118
125
 
126
+ backbone_cls = DistilBertBackbone
119
127
  tokenizer_cls = DistilBertTokenizer
120
128
 
121
- def __init__(
122
- self,
123
- tokenizer,
124
- sequence_length=512,
125
- truncate="round_robin",
126
- **kwargs,
127
- ):
128
- super().__init__(**kwargs)
129
- self.tokenizer = tokenizer
130
- self.packer = None
131
- self.sequence_length = sequence_length
132
- self.truncate = truncate
133
-
134
- def build(self, input_shape):
135
- super().build(input_shape)
136
- # Defer masker creation to `build()` so that we can be sure tokenizer
137
- # assets have loaded when restoring a saved model.
138
- self.packer = MultiSegmentPacker(
139
- start_value=self.tokenizer.cls_token_id,
140
- end_value=self.tokenizer.sep_token_id,
141
- pad_value=self.tokenizer.pad_token_id,
142
- truncate=self.truncate,
143
- sequence_length=self.sequence_length,
144
- )
145
-
129
+ @preprocessing_function
146
130
  def call(self, x, y=None, sample_weight=None):
147
- x = convert_inputs_to_list_of_tensor_segments(x)
148
- x = [self.tokenizer(segment) for segment in x]
149
- token_ids, _ = self.packer(x)
150
- x = {
151
- "token_ids": token_ids,
152
- "padding_mask": token_ids != self.tokenizer.pad_token_id,
153
- }
131
+ output = super().call(x, y=y, sample_weight=sample_weight)
132
+ x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output)
133
+ # Backbone has no segment ID input.
134
+ del x["segment_ids"]
154
135
  return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
155
-
156
- def get_config(self):
157
- config = super().get_config()
158
- config.update(
159
- {
160
- "sequence_length": self.sequence_length,
161
- "truncate": self.truncate,
162
- }
163
- )
164
- return config
165
-
166
- @property
167
- def sequence_length(self):
168
- """The padded length of model input sequences."""
169
- return self._sequence_length
170
-
171
- @sequence_length.setter
172
- def sequence_length(self, value):
173
- self._sequence_length = value
174
- if self.packer is not None:
175
- self.packer.sequence_length = value
@@ -14,10 +14,18 @@
14
14
 
15
15
 
16
16
  from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.distil_bert.distil_bert_backbone import (
18
+ DistilBertBackbone,
19
+ )
17
20
  from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer
18
21
 
19
22
 
20
- @keras_hub_export("keras_hub.models.DistilBertTokenizer")
23
+ @keras_hub_export(
24
+ [
25
+ "keras_hub.tokenizers.DistilBertTokenizer",
26
+ "keras_hub.models.DistilBertTokenizer",
27
+ ]
28
+ )
21
29
  class DistilBertTokenizer(WordPieceTokenizer):
22
30
  """A DistilBERT tokenizer using WordPiece subword segmentation.
23
31
 
@@ -27,9 +35,6 @@ class DistilBertTokenizer(WordPieceTokenizer):
27
35
  models and provides a `from_preset()` method to automatically download
28
36
  a matching vocabulary for a DistilBERT preset.
29
37
 
30
- This tokenizer does not provide truncation or padding of inputs. It can be
31
- combined with a `keras_hub.models.DistilBertPreprocessor` layer for input packing.
32
-
33
38
  If input is a batch of strings (rank > 0), the layer will output a
34
39
  `tf.RaggedTensor` where the last dimension of the output is ragged.
35
40
 
@@ -70,45 +75,24 @@ class DistilBertTokenizer(WordPieceTokenizer):
70
75
  ```
71
76
  """
72
77
 
78
+ backbone_cls = DistilBertBackbone
79
+
73
80
  def __init__(
74
81
  self,
75
82
  vocabulary,
76
83
  lowercase=False,
77
- special_tokens_in_strings=False,
78
84
  **kwargs,
79
85
  ):
80
- self.cls_token = "[CLS]"
81
- self.sep_token = "[SEP]"
82
- self.pad_token = "[PAD]"
83
- self.mask_token = "[MASK]"
86
+ self._add_special_token("[CLS]", "cls_token")
87
+ self._add_special_token("[SEP]", "sep_token")
88
+ self._add_special_token("[PAD]", "pad_token")
89
+ self._add_special_token("[MASK]", "mask_token")
90
+ # Also add `tokenizer.start_token` and `tokenizer.end_token` for
91
+ # compatibility with other tokenizers.
92
+ self._add_special_token("[CLS]", "start_token")
93
+ self._add_special_token("[SEP]", "end_token")
84
94
  super().__init__(
85
95
  vocabulary=vocabulary,
86
96
  lowercase=lowercase,
87
- special_tokens=[
88
- self.cls_token,
89
- self.sep_token,
90
- self.pad_token,
91
- self.mask_token,
92
- ],
93
- special_tokens_in_strings=special_tokens_in_strings,
94
97
  **kwargs,
95
98
  )
96
-
97
- def set_vocabulary(self, vocabulary):
98
- super().set_vocabulary(vocabulary)
99
-
100
- if vocabulary is not None:
101
- self.cls_token_id = self.token_to_id(self.cls_token)
102
- self.sep_token_id = self.token_to_id(self.sep_token)
103
- self.pad_token_id = self.token_to_id(self.pad_token)
104
- self.mask_token_id = self.token_to_id(self.mask_token)
105
- else:
106
- self.cls_token_id = None
107
- self.sep_token_id = None
108
- self.pad_token_id = None
109
- self.mask_token_id = None
110
-
111
- def get_config(self):
112
- config = super().get_config()
113
- del config["special_tokens"] # Not configurable; set in __init__.
114
- return config
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.