keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,136 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.image_classifier import ImageClassifier
18
+ from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
19
+
20
+
21
@keras_hub_export("keras_hub.models.ResNetImageClassifier")
class ResNetImageClassifier(ImageClassifier):
    """Image classification task built on a ResNet backbone.

    The backbone output is passed through a single `Dense` layer named
    `"predictions"` that emits one value per class.

    Args:
        backbone: A `keras_hub.models.ResNetBackbone` instance.
        num_classes: int. The number of classes to predict.
        activation: `None`, str or callable. The activation applied by the
            `Dense` head. Pass `activation=None` to get the raw logits.
            Defaults to `"softmax"`.
        head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`.
            The dtype used for the classification head's computations and
            weights. When left unset, the backbone's dtype policy is used.
        preprocessor: Placeholder argument; it is accepted but not used by
            this constructor.

    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
    where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
    All `ImageClassifier` tasks include a `from_preset()` constructor which
    can be used to load a pre-trained config and weights.

    Examples:

    Call `predict()` to run inference.
    ```python
    # Load preset and run inference
    images = np.ones((2, 224, 224, 3), dtype="float32")
    classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
    classifier.predict(images)
    ```

    Call `fit()` on a single batch.
    ```python
    # Load preset and train
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Call `fit()` with custom loss, optimizer and backbone.
    ```python
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
    )
    classifier.backbone.trainable = False
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Custom backbone.
    ```python
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    backbone = keras_hub.models.ResNetBackbone(
        stackwise_num_filters=[64, 64, 64],
        stackwise_num_blocks=[2, 2, 2],
        stackwise_num_strides=[1, 2, 2],
        block_type="basic_block",
        use_pre_activation=True,
        include_rescaling=False,
        pooling="avg",
    )
    classifier = keras_hub.models.ResNetImageClassifier(
        backbone=backbone,
        num_classes=4,
    )
    classifier.fit(x=images, y=labels, batch_size=2)
    ```
    """

    backbone_cls = ResNetBackbone

    def __init__(
        self,
        backbone,
        num_classes,
        activation="softmax",
        head_dtype=None,
        preprocessor=None,  # adding this dummy arg for saved model test
        # TODO: once preprocessor flow is figured out, this needs to be updated
        **kwargs,
    ):
        # No explicit head dtype: reuse the backbone's dtype policy.
        if not head_dtype:
            head_dtype = backbone.dtype_policy

        # === Layers ===
        self.backbone = backbone
        self.output_dense = keras.layers.Dense(
            num_classes,
            activation=activation,
            dtype=head_dtype,
            name="predictions",
        )

        # === Functional Model ===
        # Reuse the backbone's own input so the task and the backbone share
        # one functional graph.
        image_input = self.backbone.input
        features = self.backbone(image_input)
        predictions = self.output_dense(features)
        super().__init__(inputs=image_input, outputs=predictions, **kwargs)

        # === Config ===
        self.num_classes = num_classes
        self.activation = activation

    def get_config(self):
        """Return the parent config extended with the head settings."""
        # Backbone serialized in `super`
        config = super().get_config()
        config["num_classes"] = self.num_classes
        config["activation"] = self.activation
        return config
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
16
+ from keras_hub.src.models.roberta.roberta_presets import backbone_presets
17
+ from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
18
+ from keras_hub.src.utils.preset_utils import register_presets
19
+
20
# Import-time side effect: pass the RoBERTa preset table to `register_presets`
# together with the backbone and tokenizer classes. Presumably this is what
# lets `from_preset()` resolve RoBERTa preset names -- confirm in
# `keras_hub/src/utils/preset_utils.py`.
+ register_presets(backbone_presets, (RobertaBackbone, RobertaTokenizer))
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import keras
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.layers.modeling.token_and_position_embedding import (
20
+ TokenAndPositionEmbedding,
21
+ )
22
+ from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
23
+ from keras_hub.src.models.backbone import Backbone
24
+
25
+
26
def roberta_kernel_initializer(stddev=0.02):
    """Build the truncated-normal initializer used for RoBERTa weights.

    Args:
        stddev: float. Standard deviation handed to
            `keras.initializers.TruncatedNormal`. Defaults to `0.02`.

    Returns:
        A new `keras.initializers.TruncatedNormal` instance.
    """
    initializer = keras.initializers.TruncatedNormal(stddev=stddev)
    return initializer
28
+
29
+
30
@keras_hub_export("keras_hub.models.RobertaBackbone")
class RobertaBackbone(Backbone):
    """RoBERTa encoder backbone.

    A bi-directional Transformer-based encoder following
    ["RoBERTa: A Robustly Optimized BERT Pretraining Approach"](https://arxiv.org/abs/1907.11692).
    The network covers the embedding lookups and the transformer layers; the
    masked language model head used during pretraining is not part of it.

    The model takes a dict with `"token_ids"` and `"padding_mask"` entries
    and returns the output of the last transformer layer.

    The default constructor gives a fully customizable, randomly initialized
    RoBERTa encoder with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the
    `from_preset()` constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://github.com/facebookresearch/fairseq).

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention
            heads.
        hidden_dim: int. The size of the transformer encoding layer.
        intermediate_dim: int. The output dimension of the first Dense layer
            in a two-layer feedforward network for each transformer.
        dropout: float. Dropout probability for the Transformer encoder.
            Defaults to `0.1`.
        max_sequence_length: int. The maximum sequence length this encoder
            can consume. The sequence length of the input must be less than
            `max_sequence_length`. This determines the variable shape for
            positional embeddings. Defaults to `512`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:
    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained RoBERTa encoder
    model = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
    model(input_data)

    # Randomly initialized RoBERTa model with custom config
    model = keras_hub.models.RobertaBackbone(
        vocabulary_size=50265,
        num_layers=4,
        num_heads=4,
        hidden_dim=256,
        intermediate_dim=512,
        max_sequence_length=128,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout=0.1,
        max_sequence_length=512,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.embeddings = TokenAndPositionEmbedding(
            vocabulary_size=vocabulary_size,
            sequence_length=max_sequence_length,
            embedding_dim=hidden_dim,
            embeddings_initializer=roberta_kernel_initializer(),
            dtype=dtype,
            name="embeddings",
        )
        # Expose the token embedding sub-layer directly on the backbone.
        self.token_embedding = self.embeddings.token_embedding
        self.embeddings_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,  # Original paper uses this epsilon value
            dtype=dtype,
            name="embeddings_layer_norm",
        )
        self.embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="embeddings_dropout",
        )
        # One encoder block per requested layer, each with its own freshly
        # created kernel initializer.
        self.transformer_layers = []
        for layer_index in range(num_layers):
            encoder_block = TransformerEncoder(
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                activation="gelu",
                dropout=dropout,
                layer_norm_epsilon=1e-5,
                kernel_initializer=roberta_kernel_initializer(),
                dtype=dtype,
                name=f"transformer_layer_{layer_index}",
            )
            self.transformer_layers.append(encoder_block)

        # === Functional Model ===
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        padding_mask = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        # Embed, normalize, apply dropout, then run the encoder stack.
        hidden_states = self.embeddings(token_ids)
        hidden_states = self.embeddings_layer_norm(hidden_states)
        hidden_states = self.embeddings_dropout(hidden_states)
        for encoder_block in self.transformer_layers:
            hidden_states = encoder_block(
                hidden_states, padding_mask=padding_mask
            )
        model_inputs = {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
        super().__init__(
            inputs=model_inputs,
            outputs=hidden_states,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
        self.start_token_index = 0

    def get_config(self):
        """Return the parent config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            vocabulary_size=self.vocabulary_size,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            hidden_dim=self.hidden_dim,
            intermediate_dim=self.intermediate_dim,
            dropout=self.dropout,
            max_sequence_length=self.max_sequence_length,
        )
        return config
@@ -0,0 +1,209 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import keras
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.models.classifier import Classifier
20
+ from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
21
+ from keras_hub.src.models.roberta.roberta_backbone import (
22
+ roberta_kernel_initializer,
23
+ )
24
+ from keras_hub.src.models.roberta.roberta_preprocessor import (
25
+ RobertaPreprocessor,
26
+ )
27
+
28
+
29
@keras_hub_export("keras_hub.models.RobertaClassifier")
class RobertaClassifier(Classifier):
    """RoBERTa backbone topped with a classification head.

    A classification head is attached to a `keras_hub.models.RobertaBackbone`
    instance, turning the backbone outputs into logits suitable for a
    classification task. The head reads the backbone output at
    `backbone.start_token_index`, then applies dropout, a `tanh` dense layer,
    dropout again, and a final dense layer with `num_classes` units. To use
    this model with pre-trained weights, see the `from_preset()` constructor.

    A `preprocessor` layer may optionally be supplied, in which case
    preprocessing is applied automatically to raw inputs during `fit()`,
    `predict()`, and `evaluate()`. Creating the model with `from_preset()`
    does this by default.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://github.com/facebookresearch/fairseq).

    Args:
        backbone: A `keras_hub.models.RobertaBackbone` instance.
        num_classes: int. Number of classes to predict.
        preprocessor: A `keras_hub.models.RobertaPreprocessor` or `None`. If
            `None`, this model will not apply preprocessing, and inputs should
            be preprocessed before calling the model.
        activation: Optional `str` or callable. The activation function to use
            on the model outputs. Set `activation="softmax"` to return output
            probabilities. Defaults to `None`.
        hidden_dim: int. The size of the pooler layer. When left unset (or
            otherwise falsy), `backbone.hidden_dim` is used.
        dropout: float. The dropout probability value, applied to the pooled
            output, and after the first dense layer. Defaults to `0.0`.

    Examples:

    Raw string data.
    ```python
    features = ["The quick brown fox jumped.", "I forgot my homework."]
    labels = [0, 3]

    # Pretrained classifier.
    classifier = keras_hub.models.RobertaClassifier.from_preset(
        "roberta_base_en",
        num_classes=4,
    )
    classifier.fit(x=features, y=labels, batch_size=2)
    classifier.predict(x=features, batch_size=2)

    # Re-compile (e.g., with a new learning rate).
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
        jit_compile=True,
    )
    # Access backbone programmatically (e.g., to change `trainable`).
    classifier.backbone.trainable = False
    # Fit again.
    classifier.fit(x=features, y=labels, batch_size=2)
    ```

    Preprocessed integer data.
    ```python
    features = {
        "token_ids": np.ones(shape=(2, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2),
    }
    labels = [0, 3]

    # Pretrained classifier without preprocessing.
    classifier = keras_hub.models.RobertaClassifier.from_preset(
        "roberta_base_en",
        num_classes=4,
        preprocessor=None,
    )
    classifier.fit(x=features, y=labels, batch_size=2)
    ```

    Custom backbone and vocabulary.
    ```python
    features = ["a quick fox", "a fox quick"]
    labels = [0, 3]

    vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
    vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
    merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
    merges += ["Ġ f", "o x", "Ġf ox"]
    tokenizer = keras_hub.models.RobertaTokenizer(
        vocabulary=vocab,
        merges=merges
    )
    preprocessor = keras_hub.models.RobertaPreprocessor(
        tokenizer=tokenizer,
        sequence_length=128,
    )
    backbone = keras_hub.models.RobertaBackbone(
        vocabulary_size=20,
        num_layers=4,
        num_heads=4,
        hidden_dim=256,
        intermediate_dim=512,
        max_sequence_length=128
    )
    classifier = keras_hub.models.RobertaClassifier(
        backbone=backbone,
        preprocessor=preprocessor,
        num_classes=4,
    )
    classifier.fit(x=features, y=labels, batch_size=2)
    ```
    """

    backbone_cls = RobertaBackbone
    preprocessor_cls = RobertaPreprocessor

    def __init__(
        self,
        backbone,
        num_classes,
        preprocessor=None,
        activation=None,
        hidden_dim=None,
        dropout=0.0,
        **kwargs,
    ):
        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor
        # Every head layer reuses the backbone's dtype policy.
        head_dtype = backbone.dtype_policy
        self.pooled_dropout = keras.layers.Dropout(
            dropout, dtype=head_dtype, name="pooled_dropout"
        )
        # A falsy `hidden_dim` (e.g. `None`) falls back to the backbone width.
        if not hidden_dim:
            hidden_dim = backbone.hidden_dim
        self.pooled_dense = keras.layers.Dense(
            hidden_dim, activation="tanh", dtype=head_dtype, name="pooled_dense"
        )
        self.output_dropout = keras.layers.Dropout(
            dropout, dtype=head_dtype, name="output_dropout"
        )
        self.output_dense = keras.layers.Dense(
            num_classes,
            kernel_initializer=roberta_kernel_initializer(),
            activation=activation,
            dtype=head_dtype,
            name="logits",
        )

        # === Functional Model ===
        inputs = backbone.input
        sequence_output = backbone(inputs)
        # Keep only the position given by `backbone.start_token_index`.
        pooled = sequence_output[:, backbone.start_token_index, :]
        pooled = self.pooled_dropout(pooled)
        pooled = self.pooled_dense(pooled)
        pooled = self.output_dropout(pooled)
        outputs = self.output_dense(pooled)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.num_classes = num_classes
        self.activation = keras.activations.get(activation)
        self.hidden_dim = hidden_dim
        self.dropout = dropout

    def get_config(self):
        """Return the parent config plus the classification head arguments.

        Returns:
            The dictionary from the parent `get_config()`, with `num_classes`,
            the serialized `activation`, `hidden_dim` and `dropout` added.
        """
        config = super().get_config()
        config["num_classes"] = self.num_classes
        config["activation"] = keras.activations.serialize(self.activation)
        config["hidden_dim"] = self.hidden_dim
        config["dropout"] = self.dropout
        return config
@@ -0,0 +1,136 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import keras
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead
20
+ from keras_hub.src.models.masked_lm import MaskedLM
21
+ from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
22
+ from keras_hub.src.models.roberta.roberta_backbone import (
23
+ roberta_kernel_initializer,
24
+ )
25
+ from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import (
26
+ RobertaMaskedLMPreprocessor,
27
+ )
28
+
29
+
30
@keras_hub_export("keras_hub.models.RobertaMaskedLM")
class RobertaMaskedLM(MaskedLM):
    """RoBERTa backbone topped with a masked language modeling head.

    This model trains RoBERTa on a masked language modeling task: it predicts
    labels for a number of masked tokens in the input data. The head receives
    the backbone outputs together with a `mask_positions` input, which is
    added alongside the backbone's own inputs. To use this model with
    pre-trained weights, see the `from_preset()` method.

    A `preprocessor` layer may optionally be supplied, in which case inputs
    can be raw string features during `fit()`, `predict()`, and `evaluate()`.
    Inputs will be tokenized and dynamically masked during training and
    evaluation. Creating the model with `from_preset()` does this by default.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://github.com/facebookresearch/fairseq).

    Args:
        backbone: A `keras_hub.models.RobertaBackbone` instance.
        preprocessor: A `keras_hub.models.RobertaMaskedLMPreprocessor` or
            `None`. If `None`, this model will not apply preprocessing, and
            inputs should be preprocessed before calling the model.

    Examples:

    Raw string data.
    ```python
    features = ["The quick brown fox jumped.", "I forgot my homework."]

    # Pretrained language model.
    masked_lm = keras_hub.models.RobertaMaskedLM.from_preset(
        "roberta_base_en",
    )
    masked_lm.fit(x=features, batch_size=2)

    # Re-compile (e.g., with a new learning rate).
    masked_lm.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
        jit_compile=True,
    )
    # Access backbone programmatically (e.g., to change `trainable`).
    masked_lm.backbone.trainable = False
    # Fit again.
    masked_lm.fit(x=features, batch_size=2)
    ```

    Preprocessed integer data.
    ```python
    # Create a preprocessed dataset where 0 is the mask token.
    features = {
        "token_ids": np.array([[1, 2, 0, 4, 0, 6, 7, 8]] * 2),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1]] * 2),
        "mask_positions": np.array([[2, 4]] * 2)
    }
    # Labels are the original masked values.
    labels = [[3, 5]] * 2

    masked_lm = keras_hub.models.RobertaMaskedLM.from_preset(
        "roberta_base_en",
        preprocessor=None,
    )

    masked_lm.fit(x=features, y=labels, batch_size=2)
    ```
    """

    backbone_cls = RobertaBackbone
    preprocessor_cls = RobertaMaskedLMPreprocessor

    def __init__(self, backbone, preprocessor=None, **kwargs):
        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor
        # The head is handed the backbone's token embedding layer, its
        # vocabulary size and its dtype policy.
        self.masked_lm_head = MaskedLMHead(
            vocabulary_size=backbone.vocabulary_size,
            token_embedding=backbone.token_embedding,
            intermediate_activation="gelu",
            kernel_initializer=roberta_kernel_initializer(),
            dtype=backbone.dtype_policy,
            name="mlm_head",
        )

        # === Functional Model ===
        # Start from a copy of the backbone inputs, then add the extra
        # `mask_positions` input consumed by the head.
        inputs = {**backbone.input}
        inputs["mask_positions"] = keras.Input(
            shape=(None,), dtype="int32", name="mask_positions"
        )
        sequence_output = backbone(backbone.input)
        outputs = self.masked_lm_head(
            sequence_output, inputs["mask_positions"]
        )
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)