keras-hub-nightly 0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/vgg/vgg_image_classifier.py
@@ -0,0 +1,124 @@
+ # Copyright 2023 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_classifier import ImageClassifier
+ from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+
+
+ @keras_hub_export("keras_hub.models.VGGImageClassifier")
+ class VGGImageClassifier(ImageClassifier):
+     """VGG16 image classifier task model.
+
+     Args:
+         backbone: A `keras_hub.models.VGGBackbone` instance.
+         num_classes: int, number of classes to predict.
+         pooling: str, type of pooling layer. Must be one of "avg", "max".
+         activation: Optional `str` or callable, defaults to "softmax". The
+             activation function to use on the Dense layer. Set `activation=None`
+             to return the output logits.
+
+     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+     labels where `x` is a string and `y` is an integer from `[0, num_classes)`.
+     All `ImageClassifier` tasks include a `from_preset()` constructor which can be
+     used to load a pre-trained config and weights.
+
+     Examples:
+     Train from preset
+     ```python
+     # Load preset and train
+     images = np.ones((2, 224, 224, 3), dtype="float32")
+     labels = [0, 3]
+     classifier = keras_hub.models.VGGImageClassifier.from_preset(
+         'vgg_16_image_classifier')
+     classifier.fit(x=images, y=labels, batch_size=2)
+
+     # Re-compile (e.g., with a new learning rate).
+     classifier.compile(
+         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+         optimizer=keras.optimizers.Adam(5e-5),
+         jit_compile=True,
+     )
+
+     # Access backbone programmatically (e.g., to change `trainable`).
+     classifier.backbone.trainable = False
+     # Fit again.
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+     Custom backbone
+     ```python
+     images = np.ones((2, 224, 224, 3), dtype="float32")
+     labels = [0, 3]
+
+     backbone = keras_hub.models.VGGBackbone(
+         stackwise_num_repeats = [2, 2, 3, 3, 3],
+         stackwise_num_filters = [64, 128, 256, 512, 512],
+         image_shape = (224, 224, 3),
+         include_rescaling = False,
+         pooling = "avg",
+     )
+     classifier = keras_hub.models.VGGImageClassifier(
+         backbone=backbone,
+         num_classes=4,
+     )
+     classifier.fit(x=images, y=labels, batch_size=2)
+     ```
+     """
+
+     backbone_cls = VGGBackbone
+
+     def __init__(
+         self,
+         backbone,
+         num_classes,
+         activation="softmax",
+         preprocessor=None,  # adding this dummy arg for saved model test
+         # TODO: once preprocessor flow is figured out, this needs to be updated
+         **kwargs,
+     ):
+         # === Layers ===
+         self.backbone = backbone
+         self.output_dense = keras.layers.Dense(
+             num_classes,
+             activation=activation,
+             name="predictions",
+         )
+
+         # === Functional Model ===
+         inputs = self.backbone.input
+         x = self.backbone(inputs)
+         outputs = self.output_dense(x)
+
+         # Instantiate using Functional API Model constructor
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.num_classes = num_classes
+         self.activation = activation
+
+     def get_config(self):
+         # Backbone serialized in `super`
+         config = super().get_config()
+         config.update(
+             {
+                 "num_classes": self.num_classes,
+                 "activation": self.activation,
+             }
+         )
+         return config
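
The docstring above notes that passing `activation=None` makes the final Dense layer return logits instead of probabilities. A minimal sketch of that workflow, assuming the `VGGBackbone` arguments shown in the docstring example are still valid in this release:

```python
import numpy as np
import keras
import keras_hub

# Backbone arguments mirror the docstring example above (assumed valid here).
backbone = keras_hub.models.VGGBackbone(
    stackwise_num_repeats=[2, 2, 3, 3, 3],
    stackwise_num_filters=[64, 128, 256, 512, 512],
    image_shape=(224, 224, 3),
    include_rescaling=False,
    pooling="avg",
)

# `activation=None` leaves the final Dense layer linear, so the model
# outputs raw logits and the loss must be built with `from_logits=True`.
classifier = keras_hub.models.VGGImageClassifier(
    backbone=backbone,
    num_classes=4,
    activation=None,
)
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
)

images = np.ones((2, 224, 224, 3), dtype="float32")
labels = [0, 3]
classifier.fit(x=images, y=labels, batch_size=2)
logits = classifier.predict(images)  # shape (2, 4), unnormalized class scores
```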
keras_hub/src/models/vit_det/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
keras_hub/src/models/vit_det/vit_det_backbone.py
@@ -0,0 +1,204 @@
+ # Copyright 2024 The KerasCV Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.vit_det.vit_layers import AddPositionalEmbedding
+ from keras_hub.src.models.vit_det.vit_layers import ViTDetPatchingAndEmbedding
+ from keras_hub.src.models.vit_det.vit_layers import WindowedTransformerEncoder
+
+
+ @keras_hub_export("keras_hub.models.ViTDetBackbone")
+ class ViTDetBackbone(Backbone):
+     """An implementation of ViT image encoder.
+
+     The ViTDetBackbone uses a windowed transformer encoder and relative
+     positional encodings. The code has been adapted from [Segment Anything
+     paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
+     https://github.com/facebookresearch/segment-anything) and [Detectron2](
+     https://github.com/facebookresearch/detectron2).
+
+     Args:
+         hidden_size (int): The latent dimensionality to be projected
+             into in the output of each stacked windowed transformer encoder.
+         num_layers (int): The number of transformer encoder layers to
+             stack in the Vision Transformer.
+         intermediate_dim (int): The dimensionality of the hidden Dense
+             layer in the transformer MLP head.
+         num_heads (int): the number of heads to use in the
+             `MultiHeadAttentionWithRelativePE` layer of each transformer
+             encoder.
+         global_attention_layer_indices (list): Indexes for blocks using
+             global attention.
+         image_shape (tuple[int], optional): The size of the input image in
+             `(H, W, C)` format. Defaults to `(1024, 1024, 3)`.
+         include_rescaling (bool, optional): Whether to rescale the inputs. If
+             set to `True`, inputs will be passed through a
+             `Rescaling(1/255.0)` layer. Defaults to `True`.
+         patch_size (int, optional): the patch size to be supplied to the
+             Patching layer to turn input images into a flattened sequence of
+             patches. Defaults to `16`.
+         num_output_channels (int, optional): The number of channels (features)
+             in the output (image encodings). Defaults to `256`.
+         use_bias (bool, optional): Whether to use bias to project the keys,
+             queries, and values in the attention layer. Defaults to `True`.
+         use_abs_pos (bool, optional): Whether to add absolute positional
+             embeddings to the output patches. Defaults to `True`.
+         use_rel_pos (bool, optional): Whether to use relative positional
+             encodings in the attention layer. Defaults to `True`.
+         window_size (int, optional): The size of the window for windowed
+             attention in the transformer encoder blocks. Defaults to `14`.
+         layer_norm_epsilon (float, optional): The epsilon to use in the layer
+             normalization blocks in transformer encoder. Defaults to `1e-6`.
+
+     Examples:
+     ```python
+     input_data = np.ones((2, 224, 224, 3), dtype="float32")
+
+     # Pretrained ViTDetBackbone backbone.
+     model = keras_hub.models.ViTDetBackbone.from_preset("vit_det")
+     model(input_data)
+
+     # Randomly initialized ViTDetBackbone backbone with a custom config.
+     model = keras_hub.models.ViTDetBackbone(
+         image_shape = (16, 16, 3),
+         patch_size = 2,
+         hidden_size = 4,
+         num_layers = 2,
+         global_attention_layer_indices = [2, 5, 8, 11],
+         intermediate_dim = 4 * 4,
+         num_heads = 2,
+         num_output_channels = 2,
+         window_size = 2,
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         num_layers,
+         intermediate_dim,
+         num_heads,
+         global_attention_layer_indices,
+         include_rescaling=True,
+         image_shape=(1024, 1024, 3),
+         patch_size=16,
+         num_output_channels=256,
+         use_bias=True,
+         use_abs_pos=True,
+         use_rel_pos=True,
+         window_size=14,
+         layer_norm_epsilon=1e-6,
+         **kwargs
+     ):
+         # === Functional model ===
+         img_input = keras.layers.Input(shape=image_shape)
+         # Check that the input image is well specified.
+         if img_input.shape[-3] is None or img_input.shape[-2] is None:
+             raise ValueError(
+                 "Height and width of the image must be specified"
+                 " in `image_shape`."
+             )
+         if img_input.shape[-3] != img_input.shape[-2]:
+             raise ValueError(
+                 "Input image must be square i.e. the height must"
+                 " be equal to the width in the `image_shape`"
+                 " tuple/tensor."
+             )
+         img_size = img_input.shape[-3]
+         x = img_input
+         if include_rescaling:
+             # Use common rescaling strategy across keras_cv
+             x = keras.layers.Rescaling(1.0 / 255.0)(x)
+             # VITDet scales inputs based on the standard ImageNet mean/stddev.
+             x = (x - ops.array([0.485, 0.456, 0.406], dtype=x.dtype)) / (
+                 ops.array([0.229, 0.224, 0.225], dtype=x.dtype)
+             )
+         x = ViTDetPatchingAndEmbedding(
+             kernel_size=(patch_size, patch_size),
+             strides=(patch_size, patch_size),
+             embed_dim=hidden_size,
+         )(x)
+         if use_abs_pos:
+             x = AddPositionalEmbedding(img_size, patch_size, hidden_size)(x)
+         for i in range(num_layers):
+             x = WindowedTransformerEncoder(
+                 project_dim=hidden_size,
+                 intermediate_dim=intermediate_dim,
+                 num_heads=num_heads,
+                 use_bias=use_bias,
+                 use_rel_pos=use_rel_pos,
+                 window_size=(
+                     window_size
+                     if i not in global_attention_layer_indices
+                     else 0
+                 ),
+                 input_size=(img_size // patch_size, img_size // patch_size),
+             )(x)
+         x = keras.layers.Conv2D(
+             filters=num_output_channels, kernel_size=1, use_bias=False
+         )(x)
+         x = keras.layers.LayerNormalization(epsilon=1e-6)(x)
+         x = keras.layers.Conv2D(
+             filters=num_output_channels,
+             kernel_size=3,
+             padding="same",
+             use_bias=False,
+         )(x)
+         x = keras.layers.LayerNormalization(epsilon=1e-6)(x)
+
+         super().__init__(inputs=img_input, outputs=x, **kwargs)
+
+         # === Config ===
+         self.patch_size = patch_size
+         self.image_shape = image_shape
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.num_output_channels = num_output_channels
+         self.use_bias = use_bias
+         self.use_rel_pos = use_rel_pos
+         self.use_abs_pos = use_abs_pos
+         self.window_size = window_size
+         self.global_attention_layer_indices = global_attention_layer_indices
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.include_rescaling = include_rescaling
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "image_shape": self.image_shape,
+                 "include_rescaling": self.include_rescaling,
+                 "patch_size": self.patch_size,
+                 "hidden_size": self.hidden_size,
+                 "num_layers": self.num_layers,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "num_output_channels": self.num_output_channels,
+                 "use_bias": self.use_bias,
+                 "use_abs_pos": self.use_abs_pos,
+                 "use_rel_pos": self.use_rel_pos,
+                 "window_size": self.window_size,
+                 "global_attention_layer_indices": self.global_attention_layer_indices,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
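
Since `get_config()` above returns every constructor argument, the backbone architecture can be rebuilt from its config alone. A minimal round-trip sketch, assuming the standard `Backbone.from_config()` behavior of calling the constructor with the saved config; the tiny hyperparameters are arbitrary and only chosen to keep the model small:

```python
import numpy as np
import keras_hub

# A tiny, randomly initialized backbone, mirroring the scale of the
# docstring example above (hyperparameter values are arbitrary).
backbone = keras_hub.models.ViTDetBackbone(
    image_shape=(16, 16, 3),
    patch_size=2,
    hidden_size=4,
    num_layers=2,
    global_attention_layer_indices=[1],  # layer 1 gets window_size=0 (global attention)
    intermediate_dim=16,
    num_heads=2,
    num_output_channels=2,
    window_size=2,
)

# `get_config()` captures every constructor argument, so the architecture
# can be reconstructed from the config (weights are freshly initialized).
config = backbone.get_config()
restored = keras_hub.models.ViTDetBackbone.from_config(config)

features = restored(np.ones((1, 16, 16, 3), dtype="float32"))
print(features.shape)  # (1, 8, 8, 2): 16 // 2 patches per side, `num_output_channels` features
```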