keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,9 @@ import keras
16
16
  from keras_hub.src.api_export import keras_hub_export
17
17
  from keras_hub.src.models.image_classifier import ImageClassifier
18
18
  from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
19
+ from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
20
+ ResNetImageClassifierPreprocessor,
21
+ )
19
22
 
20
23
 
21
24
  @keras_hub_export("keras_hub.models.ResNetImageClassifier")
@@ -42,7 +45,9 @@ class ResNetImageClassifier(ImageClassifier):
42
45
  ```python
43
46
  # Load preset and predict
44
47
  images = np.ones((2, 224, 224, 3), dtype="float32")
45
- classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
48
+ classifier = keras_hub.models.ResNetImageClassifier.from_preset(
49
+ "resnet_50_imagenet"
50
+ )
46
51
  classifier.predict(images)
47
52
  ```
48
53
 
@@ -51,13 +56,17 @@ class ResNetImageClassifier(ImageClassifier):
51
56
  # Load preset and train
52
57
  images = np.ones((2, 224, 224, 3), dtype="float32")
53
58
  labels = [0, 3]
54
- classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
59
+ classifier = keras_hub.models.ResNetImageClassifier.from_preset(
60
+ "resnet_50_imagenet"
61
+ )
55
62
  classifier.fit(x=images, y=labels, batch_size=2)
56
63
  ```
57
64
 
58
65
  Call `fit()` with custom loss, optimizer and backbone.
59
66
  ```python
60
- classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
67
+ classifier = keras_hub.models.ResNetImageClassifier.from_preset(
68
+ "resnet_50_imagenet"
69
+ )
61
70
  classifier.compile(
62
71
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
63
72
  optimizer=keras.optimizers.Adam(5e-5),
@@ -88,21 +97,36 @@ class ResNetImageClassifier(ImageClassifier):
88
97
  """
89
98
 
90
99
  backbone_cls = ResNetBackbone
100
+ preprocessor_cls = ResNetImageClassifierPreprocessor
91
101
 
92
102
  def __init__(
93
103
  self,
94
104
  backbone,
95
105
  num_classes,
96
- activation="softmax",
106
+ preprocessor=None,
107
+ pooling="avg",
108
+ activation=None,
97
109
  head_dtype=None,
98
- preprocessor=None, # adding this dummy arg for saved model test
99
- # TODO: once preprocessor flow is figured out, this needs to be updated
100
110
  **kwargs,
101
111
  ):
102
112
  head_dtype = head_dtype or backbone.dtype_policy
103
113
 
104
114
  # === Layers ===
105
115
  self.backbone = backbone
116
+ self.preprocessor = preprocessor
117
+ if pooling == "avg":
118
+ self.pooler = keras.layers.GlobalAveragePooling2D(
119
+ data_format=backbone.data_format, dtype=head_dtype
120
+ )
121
+ elif pooling == "max":
122
+ self.pooler = keras.layers.GlobalAveragePooling2D(
123
+ data_format=backbone.data_format, dtype=head_dtype
124
+ )
125
+ else:
126
+ raise ValueError(
127
+ "Unknown `pooling` type. Polling should be either `'avg'` or "
128
+ f"`'max'`. Received: pooling={pooling}."
129
+ )
106
130
  self.output_dense = keras.layers.Dense(
107
131
  num_classes,
108
132
  activation=activation,
@@ -113,6 +137,7 @@ class ResNetImageClassifier(ImageClassifier):
113
137
  # === Functional Model ===
114
138
  inputs = self.backbone.input
115
139
  x = self.backbone(inputs)
140
+ x = self.pooler(x)
116
141
  outputs = self.output_dense(x)
117
142
  super().__init__(
118
143
  inputs=inputs,
@@ -123,6 +148,7 @@ class ResNetImageClassifier(ImageClassifier):
123
148
  # === Config ===
124
149
  self.num_classes = num_classes
125
150
  self.activation = activation
151
+ self.pooling = pooling
126
152
 
127
153
  def get_config(self):
128
154
  # Backbone serialized in `super`
@@ -130,6 +156,7 @@ class ResNetImageClassifier(ImageClassifier):
130
156
  config.update(
131
157
  {
132
158
  "num_classes": self.num_classes,
159
+ "pooling": self.pooling,
133
160
  "activation": self.activation,
134
161
  }
135
162
  )
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.models.image_classifier_preprocessor import (
17
+ ImageClassifierPreprocessor,
18
+ )
19
+ from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
20
+ from keras_hub.src.models.resnet.resnet_image_converter import (
21
+ ResNetImageConverter,
22
+ )
23
+
24
+
25
+ @keras_hub_export("keras_hub.models.ResNetImageClassifierPreprocessor")
26
+ class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
27
+ backbone_cls = ResNetBackbone
28
+ image_converter_cls = ResNetImageConverter
@@ -12,10 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  from keras_hub.src.api_export import keras_hub_export
15
- from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
16
- from keras_hub.src.models.llama.llama_preprocessor import LlamaPreprocessor
15
+ from keras_hub.src.layers.preprocessing.resizing_image_converter import (
16
+ ResizingImageConverter,
17
+ )
18
+ from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
17
19
 
18
20
 
19
- @keras_hub_export("keras_hub.models.Llama3Preprocessor")
20
- class Llama3Preprocessor(LlamaPreprocessor):
21
- tokenizer_cls = Llama3Tokenizer
21
+ @keras_hub_export("keras_hub.layers.ResNetImageConverter")
22
+ class ResNetImageConverter(ResizingImageConverter):
23
+ backbone_cls = ResNetBackbone
@@ -0,0 +1,95 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ResNet preset configurations."""
15
+
16
+ backbone_presets = {
17
+ "resnet_18_imagenet": {
18
+ "metadata": {
19
+ "description": (
20
+ "18-layer ResNet model pre-trained on the ImageNet 1k dataset "
21
+ "at a 224x224 resolution."
22
+ ),
23
+ "params": 11186112,
24
+ "official_name": "ResNet",
25
+ "path": "resnet",
26
+ "model_card": "https://arxiv.org/abs/2110.00476",
27
+ },
28
+ "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/2",
29
+ },
30
+ "resnet_50_imagenet": {
31
+ "metadata": {
32
+ "description": (
33
+ "50-layer ResNet model pre-trained on the ImageNet 1k dataset "
34
+ "at a 224x224 resolution."
35
+ ),
36
+ "params": 23561152,
37
+ "official_name": "ResNet",
38
+ "path": "resnet",
39
+ "model_card": "https://arxiv.org/abs/2110.00476",
40
+ },
41
+ "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/2",
42
+ },
43
+ "resnet_101_imagenet": {
44
+ "metadata": {
45
+ "description": (
46
+ "101-layer ResNet model pre-trained on the ImageNet 1k dataset "
47
+ "at a 224x224 resolution."
48
+ ),
49
+ "params": 42605504,
50
+ "official_name": "ResNet",
51
+ "path": "resnet",
52
+ "model_card": "https://arxiv.org/abs/2110.00476",
53
+ },
54
+ "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/2",
55
+ },
56
+ "resnet_152_imagenet": {
57
+ "metadata": {
58
+ "description": (
59
+ "152-layer ResNet model pre-trained on the ImageNet 1k dataset "
60
+ "at a 224x224 resolution."
61
+ ),
62
+ "params": 58295232,
63
+ "official_name": "ResNet",
64
+ "path": "resnet",
65
+ "model_card": "https://arxiv.org/abs/2110.00476",
66
+ },
67
+ "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/2",
68
+ },
69
+ "resnet_v2_50_imagenet": {
70
+ "metadata": {
71
+ "description": (
72
+ "50-layer ResNetV2 model pre-trained on the ImageNet 1k "
73
+ "dataset at a 224x224 resolution."
74
+ ),
75
+ "params": 23561152,
76
+ "official_name": "ResNet",
77
+ "path": "resnet",
78
+ "model_card": "https://arxiv.org/abs/2110.00476",
79
+ },
80
+ "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/2",
81
+ },
82
+ "resnet_v2_101_imagenet": {
83
+ "metadata": {
84
+ "description": (
85
+ "101-layer ResNetV2 model pre-trained on the ImageNet 1k "
86
+ "dataset at a 224x224 resolution."
87
+ ),
88
+ "params": 42605504,
89
+ "official_name": "ResNet",
90
+ "path": "resnet",
91
+ "model_card": "https://arxiv.org/abs/2110.00476",
92
+ },
93
+ "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/2",
94
+ },
95
+ }
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,175 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import keras
18
+ from keras import ops
19
+
20
+ from keras_hub.src.bounding_box.converters import convert_format
21
+
22
+
23
+ class AnchorGenerator(keras.layers.Layer):
24
+ """Generates anchor boxes for object detection tasks.
25
+
26
+ This layer creates a set of anchor boxes (also known as default boxes or
27
+ priors) for use in object detection models, particularly those utilizing
28
+ Feature Pyramid Networks (FPN). It generates anchors across multiple
29
+ pyramid levels, with various scales and aspect ratios.
30
+
31
+ Feature Pyramid Levels:
32
+ - Levels typically range from 2 to 7 (P2 to P7), corresponding to different
33
+ resolutions of the input image.
34
+ - Each level l has a stride of 2^l pixels relative to the input image.
35
+ - Lower levels (e.g., P2) have higher resolution and are used for
36
+ detecting smaller objects.
37
+ - Higher levels (e.g., P7) have lower resolution and are used
38
+ for larger objects.
39
+
40
+ Args:
41
+ bounding_box_format (str): The format of the bounding boxes
42
+ to be generated. Expected to be a string like 'xyxy', 'xywh', etc.
43
+ min_level (int): Minimum level of the output feature pyramid.
44
+ max_level (int): Maximum level of the output feature pyramid.
45
+ num_scales (int): Number of intermediate scales added on each level.
46
+ For example, num_scales=2 adds one additional intermediate anchor
47
+ scale [2^0, 2^0.5] on each level.
48
+ aspect_ratios (list of float): Aspect ratios of anchors added on
49
+ each level. Each number indicates the ratio of width to height.
50
+ anchor_size (float): Scale of size of the base anchor relative to the
51
+ feature stride 2^level.
52
+
53
+ Call arguments:
54
+ images (Optional[Tensor]): An image tensor with shape `[B, H, W, C]` or
55
+ `[H, W, C]`. If provided, its shape will be used to determine anchor
56
+ sizes.
57
+
58
+ Returns:
59
+ Dict: A dictionary mapping feature levels
60
+ (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a tensor
61
+ of shape `(H/stride * W/stride * num_anchors_per_location, 4)`,
62
+ where H and W are the height and width of the image, stride is 2^level,
63
+ and num_anchors_per_location is `num_scales * len(aspect_ratios)`.
64
+
65
+ Example:
66
+ ```python
67
+ anchor_generator = AnchorGenerator(
68
+ bounding_box_format='xyxy',
69
+ min_level=3,
70
+ max_level=7,
71
+ num_scales=3,
72
+ aspect_ratios=[0.5, 1.0, 2.0],
73
+ anchor_size=4.0,
74
+ )
75
+ anchors = anchor_generator(images=keras.ops.ones(shape=(2, 640, 480, 3)))
76
+ ```
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ bounding_box_format,
82
+ min_level,
83
+ max_level,
84
+ num_scales,
85
+ aspect_ratios,
86
+ anchor_size,
87
+ **kwargs,
88
+ ):
89
+ super().__init__(**kwargs)
90
+ self.bounding_box_format = bounding_box_format
91
+ self.min_level = min_level
92
+ self.max_level = max_level
93
+ self.num_scales = num_scales
94
+ self.aspect_ratios = aspect_ratios
95
+ self.anchor_size = anchor_size
96
+ self.built = True
97
+
98
+ def call(self, images):
99
+ images_shape = ops.shape(images)
100
+ if len(images_shape) == 4:
101
+ image_shape = images_shape[1:-1]
102
+ else:
103
+ image_shape = images_shape[:-1]
104
+
105
+ image_shape = tuple(image_shape)
106
+
107
+ multilevel_boxes = {}
108
+ for level in range(self.min_level, self.max_level + 1):
109
+ boxes_l = []
110
+ # Calculate the feature map size for this level
111
+ feat_size_y = math.ceil(image_shape[0] / 2**level)
112
+ feat_size_x = math.ceil(image_shape[1] / 2**level)
113
+
114
+ # Calculate the stride (step size) for this level
115
+ stride_y = ops.cast(image_shape[0] / feat_size_y, "float32")
116
+ stride_x = ops.cast(image_shape[1] / feat_size_x, "float32")
117
+
118
+ # Generate anchor center points
119
+ # Start from stride/2 to center anchors on pixels
120
+ cx = ops.arange(stride_x / 2, image_shape[1], stride_x)
121
+ cy = ops.arange(stride_y / 2, image_shape[0], stride_y)
122
+
123
+ # Create a grid of anchor centers
124
+ cx_grid, cy_grid = ops.meshgrid(cx, cy)
125
+
126
+ for scale in range(self.num_scales):
127
+ for aspect_ratio in self.aspect_ratios:
128
+ # Calculate the intermediate scale factor
129
+ intermidate_scale = 2 ** (scale / self.num_scales)
130
+ # Calculate the base anchor size for this level and scale
131
+ base_anchor_size = (
132
+ self.anchor_size * 2**level * intermidate_scale
133
+ )
134
+ # Adjust anchor dimensions based on aspect ratio
135
+ aspect_x = aspect_ratio**0.5
136
+ aspect_y = aspect_ratio**-0.5
137
+ half_anchor_size_x = base_anchor_size * aspect_x / 2.0
138
+ half_anchor_size_y = base_anchor_size * aspect_y / 2.0
139
+
140
+ # Generate anchor boxes (y1, x1, y2, x2 format)
141
+ boxes = ops.stack(
142
+ [
143
+ cy_grid - half_anchor_size_y,
144
+ cx_grid - half_anchor_size_x,
145
+ cy_grid + half_anchor_size_y,
146
+ cx_grid + half_anchor_size_x,
147
+ ],
148
+ axis=-1,
149
+ )
150
+ boxes_l.append(boxes)
151
+ # Concat anchors on the same level to tensor shape HxWx(Ax4)
152
+ boxes_l = ops.concatenate(boxes_l, axis=-1)
153
+ boxes_l = ops.reshape(boxes_l, (-1, 4))
154
+ # Convert from the internal "yxyx" layout to the user-defined bounding box format
155
+ multilevel_boxes[f"P{level}"] = convert_format(
156
+ boxes_l,
157
+ source="yxyx",
158
+ target=self.bounding_box_format,
159
+ )
160
+ return multilevel_boxes
161
+
162
+ def compute_output_shape(self, input_shape):
163
+ multilevel_boxes_shape = {}
164
+ for level in range(self.min_level, self.max_level + 1):
165
+ multilevel_boxes_shape[f"P{level}"] = (None, None, 4)
166
+ return multilevel_boxes_shape
167
+
168
+ @property
169
+ def anchors_per_location(self):
170
+ """
171
+ The `anchors_per_location` property returns the number of anchors
172
+ generated per pixel location, which is equal to
173
+ `num_scales * len(aspect_ratios)`.
174
+ """
175
+ return self.num_scales * len(self.aspect_ratios)