keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409270338__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (357)
  1. keras_hub/__init__.py +0 -13
  2. keras_hub/api/__init__.py +0 -13
  3. keras_hub/api/bounding_box/__init__.py +0 -13
  4. keras_hub/api/layers/__init__.py +3 -13
  5. keras_hub/api/metrics/__init__.py +0 -13
  6. keras_hub/api/models/__init__.py +16 -13
  7. keras_hub/api/samplers/__init__.py +0 -13
  8. keras_hub/api/tokenizers/__init__.py +1 -13
  9. keras_hub/api/utils/__init__.py +0 -13
  10. keras_hub/src/__init__.py +0 -13
  11. keras_hub/src/api_export.py +0 -14
  12. keras_hub/src/bounding_box/__init__.py +0 -13
  13. keras_hub/src/bounding_box/converters.py +0 -13
  14. keras_hub/src/bounding_box/formats.py +0 -13
  15. keras_hub/src/bounding_box/iou.py +1 -13
  16. keras_hub/src/bounding_box/to_dense.py +0 -14
  17. keras_hub/src/bounding_box/to_ragged.py +0 -13
  18. keras_hub/src/bounding_box/utils.py +0 -13
  19. keras_hub/src/bounding_box/validate_format.py +0 -14
  20. keras_hub/src/layers/__init__.py +0 -13
  21. keras_hub/src/layers/modeling/__init__.py +0 -13
  22. keras_hub/src/layers/modeling/alibi_bias.py +0 -13
  23. keras_hub/src/layers/modeling/cached_multi_head_attention.py +0 -14
  24. keras_hub/src/layers/modeling/f_net_encoder.py +0 -14
  25. keras_hub/src/layers/modeling/masked_lm_head.py +0 -14
  26. keras_hub/src/layers/modeling/position_embedding.py +0 -14
  27. keras_hub/src/layers/modeling/reversible_embedding.py +0 -14
  28. keras_hub/src/layers/modeling/rotary_embedding.py +0 -14
  29. keras_hub/src/layers/modeling/sine_position_encoding.py +0 -14
  30. keras_hub/src/layers/modeling/token_and_position_embedding.py +0 -14
  31. keras_hub/src/layers/modeling/transformer_decoder.py +0 -14
  32. keras_hub/src/layers/modeling/transformer_encoder.py +0 -14
  33. keras_hub/src/layers/modeling/transformer_layer_utils.py +0 -14
  34. keras_hub/src/layers/preprocessing/__init__.py +0 -13
  35. keras_hub/src/layers/preprocessing/audio_converter.py +0 -13
  36. keras_hub/src/layers/preprocessing/image_converter.py +0 -13
  37. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +0 -15
  38. keras_hub/src/layers/preprocessing/multi_segment_packer.py +0 -14
  39. keras_hub/src/layers/preprocessing/preprocessing_layer.py +0 -14
  40. keras_hub/src/layers/preprocessing/random_deletion.py +0 -14
  41. keras_hub/src/layers/preprocessing/random_swap.py +0 -14
  42. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -13
  43. keras_hub/src/layers/preprocessing/start_end_packer.py +0 -15
  44. keras_hub/src/metrics/__init__.py +0 -13
  45. keras_hub/src/metrics/bleu.py +0 -14
  46. keras_hub/src/metrics/edit_distance.py +0 -14
  47. keras_hub/src/metrics/perplexity.py +0 -14
  48. keras_hub/src/metrics/rouge_base.py +0 -14
  49. keras_hub/src/metrics/rouge_l.py +0 -14
  50. keras_hub/src/metrics/rouge_n.py +0 -14
  51. keras_hub/src/models/__init__.py +0 -13
  52. keras_hub/src/models/albert/__init__.py +0 -14
  53. keras_hub/src/models/albert/albert_backbone.py +0 -14
  54. keras_hub/src/models/albert/albert_masked_lm.py +0 -14
  55. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +0 -14
  56. keras_hub/src/models/albert/albert_presets.py +0 -14
  57. keras_hub/src/models/albert/albert_text_classifier.py +0 -14
  58. keras_hub/src/models/albert/albert_text_classifier_preprocessor.py +0 -14
  59. keras_hub/src/models/albert/albert_tokenizer.py +0 -14
  60. keras_hub/src/models/backbone.py +0 -14
  61. keras_hub/src/models/bart/__init__.py +0 -14
  62. keras_hub/src/models/bart/bart_backbone.py +0 -14
  63. keras_hub/src/models/bart/bart_presets.py +0 -13
  64. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +0 -15
  65. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +0 -15
  66. keras_hub/src/models/bart/bart_tokenizer.py +0 -15
  67. keras_hub/src/models/bert/__init__.py +0 -14
  68. keras_hub/src/models/bert/bert_backbone.py +0 -14
  69. keras_hub/src/models/bert/bert_masked_lm.py +0 -14
  70. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +0 -14
  71. keras_hub/src/models/bert/bert_presets.py +0 -13
  72. keras_hub/src/models/bert/bert_text_classifier.py +0 -14
  73. keras_hub/src/models/bert/bert_text_classifier_preprocessor.py +0 -14
  74. keras_hub/src/models/bert/bert_tokenizer.py +0 -14
  75. keras_hub/src/models/bloom/__init__.py +0 -14
  76. keras_hub/src/models/bloom/bloom_attention.py +0 -13
  77. keras_hub/src/models/bloom/bloom_backbone.py +0 -14
  78. keras_hub/src/models/bloom/bloom_causal_lm.py +0 -15
  79. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +0 -15
  80. keras_hub/src/models/bloom/bloom_decoder.py +0 -13
  81. keras_hub/src/models/bloom/bloom_presets.py +0 -13
  82. keras_hub/src/models/bloom/bloom_tokenizer.py +0 -15
  83. keras_hub/src/models/causal_lm.py +0 -14
  84. keras_hub/src/models/causal_lm_preprocessor.py +0 -13
  85. keras_hub/src/models/clip/__init__.py +0 -0
  86. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -15
  87. keras_hub/src/models/clip/clip_preprocessor.py +134 -0
  88. keras_hub/src/models/clip/clip_text_encoder.py +139 -0
  89. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +65 -41
  90. keras_hub/src/models/csp_darknet/__init__.py +0 -13
  91. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +0 -13
  92. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -13
  93. keras_hub/src/models/deberta_v3/__init__.py +0 -14
  94. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +0 -15
  95. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +0 -15
  96. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +0 -14
  97. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -13
  98. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +0 -15
  99. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +0 -14
  100. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +0 -15
  101. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +0 -14
  102. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +0 -14
  103. keras_hub/src/models/deberta_v3/relative_embedding.py +0 -14
  104. keras_hub/src/models/densenet/__init__.py +5 -13
  105. keras_hub/src/models/densenet/densenet_backbone.py +11 -21
  106. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -17
  107. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  108. keras_hub/src/models/{stable_diffusion_v3/__init__.py → densenet/densenet_image_converter.py} +10 -0
  109. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  110. keras_hub/src/models/distil_bert/__init__.py +0 -14
  111. keras_hub/src/models/distil_bert/distil_bert_backbone.py +0 -15
  112. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +0 -15
  113. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +0 -14
  114. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -13
  115. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +0 -15
  116. keras_hub/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +0 -15
  117. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +0 -15
  118. keras_hub/src/models/efficientnet/__init__.py +0 -13
  119. keras_hub/src/models/efficientnet/efficientnet_backbone.py +0 -13
  120. keras_hub/src/models/efficientnet/fusedmbconv.py +0 -14
  121. keras_hub/src/models/efficientnet/mbconv.py +0 -14
  122. keras_hub/src/models/electra/__init__.py +0 -14
  123. keras_hub/src/models/electra/electra_backbone.py +0 -14
  124. keras_hub/src/models/electra/electra_presets.py +0 -13
  125. keras_hub/src/models/electra/electra_tokenizer.py +0 -14
  126. keras_hub/src/models/f_net/__init__.py +0 -14
  127. keras_hub/src/models/f_net/f_net_backbone.py +0 -15
  128. keras_hub/src/models/f_net/f_net_masked_lm.py +0 -15
  129. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +0 -14
  130. keras_hub/src/models/f_net/f_net_presets.py +0 -13
  131. keras_hub/src/models/f_net/f_net_text_classifier.py +0 -15
  132. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +0 -15
  133. keras_hub/src/models/f_net/f_net_tokenizer.py +0 -15
  134. keras_hub/src/models/falcon/__init__.py +0 -14
  135. keras_hub/src/models/falcon/falcon_attention.py +0 -13
  136. keras_hub/src/models/falcon/falcon_backbone.py +0 -13
  137. keras_hub/src/models/falcon/falcon_causal_lm.py +0 -14
  138. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +0 -14
  139. keras_hub/src/models/falcon/falcon_presets.py +0 -13
  140. keras_hub/src/models/falcon/falcon_tokenizer.py +0 -15
  141. keras_hub/src/models/falcon/falcon_transformer_decoder.py +0 -13
  142. keras_hub/src/models/feature_pyramid_backbone.py +0 -13
  143. keras_hub/src/models/gemma/__init__.py +0 -14
  144. keras_hub/src/models/gemma/gemma_attention.py +0 -13
  145. keras_hub/src/models/gemma/gemma_backbone.py +0 -15
  146. keras_hub/src/models/gemma/gemma_causal_lm.py +0 -15
  147. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +0 -14
  148. keras_hub/src/models/gemma/gemma_decoder_block.py +0 -13
  149. keras_hub/src/models/gemma/gemma_presets.py +0 -13
  150. keras_hub/src/models/gemma/gemma_tokenizer.py +0 -14
  151. keras_hub/src/models/gemma/rms_normalization.py +0 -14
  152. keras_hub/src/models/gpt2/__init__.py +0 -14
  153. keras_hub/src/models/gpt2/gpt2_backbone.py +0 -15
  154. keras_hub/src/models/gpt2/gpt2_causal_lm.py +0 -15
  155. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +0 -14
  156. keras_hub/src/models/gpt2/gpt2_preprocessor.py +0 -15
  157. keras_hub/src/models/gpt2/gpt2_presets.py +0 -13
  158. keras_hub/src/models/gpt2/gpt2_tokenizer.py +0 -15
  159. keras_hub/src/models/gpt_neo_x/__init__.py +0 -13
  160. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +0 -14
  161. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +0 -14
  162. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +0 -14
  163. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +0 -14
  164. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +0 -14
  165. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +0 -14
  166. keras_hub/src/models/image_classifier.py +0 -13
  167. keras_hub/src/models/image_classifier_preprocessor.py +0 -13
  168. keras_hub/src/models/image_segmenter.py +0 -13
  169. keras_hub/src/models/llama/__init__.py +0 -14
  170. keras_hub/src/models/llama/llama_attention.py +0 -13
  171. keras_hub/src/models/llama/llama_backbone.py +0 -13
  172. keras_hub/src/models/llama/llama_causal_lm.py +0 -13
  173. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +0 -15
  174. keras_hub/src/models/llama/llama_decoder.py +0 -13
  175. keras_hub/src/models/llama/llama_layernorm.py +0 -13
  176. keras_hub/src/models/llama/llama_presets.py +0 -13
  177. keras_hub/src/models/llama/llama_tokenizer.py +0 -14
  178. keras_hub/src/models/llama3/__init__.py +0 -14
  179. keras_hub/src/models/llama3/llama3_backbone.py +0 -14
  180. keras_hub/src/models/llama3/llama3_causal_lm.py +0 -13
  181. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +0 -14
  182. keras_hub/src/models/llama3/llama3_presets.py +0 -13
  183. keras_hub/src/models/llama3/llama3_tokenizer.py +0 -14
  184. keras_hub/src/models/masked_lm.py +0 -13
  185. keras_hub/src/models/masked_lm_preprocessor.py +0 -13
  186. keras_hub/src/models/mistral/__init__.py +0 -14
  187. keras_hub/src/models/mistral/mistral_attention.py +0 -13
  188. keras_hub/src/models/mistral/mistral_backbone.py +0 -14
  189. keras_hub/src/models/mistral/mistral_causal_lm.py +0 -14
  190. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +0 -14
  191. keras_hub/src/models/mistral/mistral_layer_norm.py +0 -13
  192. keras_hub/src/models/mistral/mistral_presets.py +0 -13
  193. keras_hub/src/models/mistral/mistral_tokenizer.py +0 -14
  194. keras_hub/src/models/mistral/mistral_transformer_decoder.py +0 -13
  195. keras_hub/src/models/mix_transformer/__init__.py +0 -13
  196. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +0 -13
  197. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -13
  198. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +0 -13
  199. keras_hub/src/models/mobilenet/__init__.py +0 -13
  200. keras_hub/src/models/mobilenet/mobilenet_backbone.py +0 -13
  201. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -13
  202. keras_hub/src/models/opt/__init__.py +0 -14
  203. keras_hub/src/models/opt/opt_backbone.py +0 -15
  204. keras_hub/src/models/opt/opt_causal_lm.py +0 -15
  205. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +0 -13
  206. keras_hub/src/models/opt/opt_presets.py +0 -13
  207. keras_hub/src/models/opt/opt_tokenizer.py +0 -15
  208. keras_hub/src/models/pali_gemma/__init__.py +0 -13
  209. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +0 -13
  210. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +0 -13
  211. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +0 -13
  212. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +0 -14
  213. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +0 -13
  214. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +0 -13
  215. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +0 -13
  216. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +0 -13
  217. keras_hub/src/models/phi3/__init__.py +0 -14
  218. keras_hub/src/models/phi3/phi3_attention.py +0 -13
  219. keras_hub/src/models/phi3/phi3_backbone.py +0 -13
  220. keras_hub/src/models/phi3/phi3_causal_lm.py +0 -13
  221. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +0 -14
  222. keras_hub/src/models/phi3/phi3_decoder.py +0 -13
  223. keras_hub/src/models/phi3/phi3_layernorm.py +0 -13
  224. keras_hub/src/models/phi3/phi3_presets.py +0 -13
  225. keras_hub/src/models/phi3/phi3_rotary_embedding.py +0 -13
  226. keras_hub/src/models/phi3/phi3_tokenizer.py +0 -13
  227. keras_hub/src/models/preprocessor.py +51 -32
  228. keras_hub/src/models/resnet/__init__.py +0 -14
  229. keras_hub/src/models/resnet/resnet_backbone.py +0 -13
  230. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -13
  231. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +0 -14
  232. keras_hub/src/models/resnet/resnet_image_converter.py +0 -13
  233. keras_hub/src/models/resnet/resnet_presets.py +0 -13
  234. keras_hub/src/models/retinanet/__init__.py +0 -13
  235. keras_hub/src/models/retinanet/anchor_generator.py +0 -14
  236. keras_hub/src/models/retinanet/box_matcher.py +0 -14
  237. keras_hub/src/models/retinanet/non_max_supression.py +0 -14
  238. keras_hub/src/models/roberta/__init__.py +0 -14
  239. keras_hub/src/models/roberta/roberta_backbone.py +0 -15
  240. keras_hub/src/models/roberta/roberta_masked_lm.py +0 -15
  241. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +0 -14
  242. keras_hub/src/models/roberta/roberta_presets.py +0 -13
  243. keras_hub/src/models/roberta/roberta_text_classifier.py +0 -15
  244. keras_hub/src/models/roberta/roberta_text_classifier_preprocessor.py +0 -14
  245. keras_hub/src/models/roberta/roberta_tokenizer.py +0 -15
  246. keras_hub/src/models/sam/__init__.py +0 -13
  247. keras_hub/src/models/sam/sam_backbone.py +0 -14
  248. keras_hub/src/models/sam/sam_image_segmenter.py +0 -14
  249. keras_hub/src/models/sam/sam_layers.py +0 -14
  250. keras_hub/src/models/sam/sam_mask_decoder.py +0 -14
  251. keras_hub/src/models/sam/sam_prompt_encoder.py +0 -14
  252. keras_hub/src/models/sam/sam_transformer.py +0 -14
  253. keras_hub/src/models/seq_2_seq_lm.py +0 -13
  254. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +0 -13
  255. keras_hub/src/models/stable_diffusion_3/__init__.py +9 -0
  256. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +80 -0
  257. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -39
  258. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +631 -0
  259. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +31 -0
  260. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +138 -0
  261. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +83 -0
  262. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -20
  263. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +320 -0
  264. keras_hub/src/models/t5/__init__.py +0 -14
  265. keras_hub/src/models/t5/t5_backbone.py +0 -14
  266. keras_hub/src/models/t5/t5_layer_norm.py +0 -14
  267. keras_hub/src/models/t5/t5_multi_head_attention.py +0 -14
  268. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -16
  269. keras_hub/src/models/t5/t5_presets.py +0 -13
  270. keras_hub/src/models/t5/t5_tokenizer.py +0 -14
  271. keras_hub/src/models/t5/t5_transformer_layer.py +0 -14
  272. keras_hub/src/models/task.py +0 -14
  273. keras_hub/src/models/text_classifier.py +0 -13
  274. keras_hub/src/models/text_classifier_preprocessor.py +0 -13
  275. keras_hub/src/models/text_to_image.py +282 -0
  276. keras_hub/src/models/vgg/__init__.py +0 -13
  277. keras_hub/src/models/vgg/vgg_backbone.py +0 -13
  278. keras_hub/src/models/vgg/vgg_image_classifier.py +0 -13
  279. keras_hub/src/models/vit_det/__init__.py +0 -13
  280. keras_hub/src/models/vit_det/vit_det_backbone.py +0 -14
  281. keras_hub/src/models/vit_det/vit_layers.py +0 -15
  282. keras_hub/src/models/whisper/__init__.py +0 -14
  283. keras_hub/src/models/whisper/whisper_audio_converter.py +0 -15
  284. keras_hub/src/models/whisper/whisper_backbone.py +0 -15
  285. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +0 -13
  286. keras_hub/src/models/whisper/whisper_decoder.py +0 -14
  287. keras_hub/src/models/whisper/whisper_encoder.py +0 -14
  288. keras_hub/src/models/whisper/whisper_presets.py +0 -14
  289. keras_hub/src/models/whisper/whisper_tokenizer.py +0 -14
  290. keras_hub/src/models/xlm_roberta/__init__.py +0 -14
  291. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +0 -15
  292. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +0 -15
  293. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +0 -14
  294. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -13
  295. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +0 -15
  296. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +0 -15
  297. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +0 -15
  298. keras_hub/src/models/xlnet/__init__.py +0 -13
  299. keras_hub/src/models/xlnet/relative_attention.py +0 -14
  300. keras_hub/src/models/xlnet/xlnet_backbone.py +0 -14
  301. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +0 -14
  302. keras_hub/src/models/xlnet/xlnet_encoder.py +0 -14
  303. keras_hub/src/samplers/__init__.py +0 -13
  304. keras_hub/src/samplers/beam_sampler.py +0 -14
  305. keras_hub/src/samplers/contrastive_sampler.py +0 -14
  306. keras_hub/src/samplers/greedy_sampler.py +0 -14
  307. keras_hub/src/samplers/random_sampler.py +0 -14
  308. keras_hub/src/samplers/sampler.py +0 -14
  309. keras_hub/src/samplers/serialization.py +0 -14
  310. keras_hub/src/samplers/top_k_sampler.py +0 -14
  311. keras_hub/src/samplers/top_p_sampler.py +0 -14
  312. keras_hub/src/tests/__init__.py +0 -13
  313. keras_hub/src/tests/test_case.py +0 -14
  314. keras_hub/src/tokenizers/__init__.py +0 -13
  315. keras_hub/src/tokenizers/byte_pair_tokenizer.py +0 -14
  316. keras_hub/src/tokenizers/byte_tokenizer.py +0 -14
  317. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +0 -14
  318. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +0 -14
  319. keras_hub/src/tokenizers/tokenizer.py +23 -27
  320. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +0 -15
  321. keras_hub/src/tokenizers/word_piece_tokenizer.py +0 -14
  322. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +0 -15
  323. keras_hub/src/utils/__init__.py +0 -13
  324. keras_hub/src/utils/imagenet/__init__.py +0 -13
  325. keras_hub/src/utils/imagenet/imagenet_utils.py +0 -13
  326. keras_hub/src/utils/keras_utils.py +0 -14
  327. keras_hub/src/utils/pipeline_model.py +0 -14
  328. keras_hub/src/utils/preset_utils.py +32 -76
  329. keras_hub/src/utils/python_utils.py +0 -13
  330. keras_hub/src/utils/tensor_utils.py +0 -14
  331. keras_hub/src/utils/timm/__init__.py +0 -13
  332. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  333. keras_hub/src/utils/timm/convert_resnet.py +0 -13
  334. keras_hub/src/utils/timm/preset_loader.py +3 -13
  335. keras_hub/src/utils/transformers/__init__.py +0 -13
  336. keras_hub/src/utils/transformers/convert_albert.py +0 -13
  337. keras_hub/src/utils/transformers/convert_bart.py +0 -13
  338. keras_hub/src/utils/transformers/convert_bert.py +0 -13
  339. keras_hub/src/utils/transformers/convert_distilbert.py +0 -13
  340. keras_hub/src/utils/transformers/convert_gemma.py +0 -13
  341. keras_hub/src/utils/transformers/convert_gpt2.py +0 -13
  342. keras_hub/src/utils/transformers/convert_llama3.py +0 -13
  343. keras_hub/src/utils/transformers/convert_mistral.py +0 -13
  344. keras_hub/src/utils/transformers/convert_pali_gemma.py +0 -13
  345. keras_hub/src/utils/transformers/preset_loader.py +1 -15
  346. keras_hub/src/utils/transformers/safetensor_utils.py +9 -15
  347. keras_hub/src/version_utils.py +1 -15
  348. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/METADATA +30 -27
  349. keras_hub_nightly-0.16.1.dev202409270338.dist-info/RECORD +351 -0
  350. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  351. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +0 -149
  352. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  353. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  354. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  355. keras_hub_nightly-0.16.1.dev202409250340.dist-info/RECORD +0 -342
  356. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/WHEEL +0 -0
  357. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/top_level.txt +0 -0
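Most of the changes above promote the experimental stable_diffusion_v3 code into a public stable_diffusion_3 model family (backbone, text-to-image task, preprocessor, scheduler, and VAE image decoder), split the shared CLIP pieces out into keras_hub/src/models/clip, and add DenseNet presets plus image preprocessing. As a rough, hypothetical sketch only — the class name and preset handle below are inferred from the new file names (stable_diffusion_3_text_to_image.py, stable_diffusion_3_presets.py) and are not documented in this diff — the new task might be used along these lines:

    import keras_hub

    # Hypothetical usage; names are inferred from the file listing above and
    # may not match the final published API exactly.
    text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
        "stable_diffusion_3_medium"
    )
    images = text_to_image.generate("a photograph of an astronaut riding a horse")

The raw diffs of the three deleted stable_diffusion_v3 files follow.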
keras_hub/src/models/stable_diffusion_v3/mmdit_block.py (deleted)
@@ -1,317 +0,0 @@
-# Copyright 2024 The KerasHub Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-from keras import layers
-from keras import models
-from keras import ops
-
-from keras_hub.src.utils.keras_utils import gelu_approximate
-
-
-class DismantledBlock(layers.Layer):
-    def __init__(
-        self,
-        num_heads,
-        hidden_dim,
-        mlp_ratio=4.0,
-        use_projection=True,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.mlp_ratio = mlp_ratio
-        self.use_projection = use_projection
-
-        head_dim = hidden_dim // num_heads
-        self.head_dim = head_dim
-        mlp_hidden_dim = int(hidden_dim * mlp_ratio)
-        self.mlp_hidden_dim = mlp_hidden_dim
-        num_modulations = 6 if use_projection else 2
-        self.num_modulations = num_modulations
-
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulations * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm1 = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype=self.dtype_policy,
-            name="norm1",
-        )
-        self.attention_qkv = layers.Dense(
-            hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
-        )
-        if use_projection:
-            self.attention_proj = layers.Dense(
-                hidden_dim, dtype=self.dtype_policy, name="attention_proj"
-            )
-            self.norm2 = layers.LayerNormalization(
-                epsilon=1e-6,
-                center=False,
-                scale=False,
-                dtype=self.dtype_policy,
-                name="norm2",
-            )
-            self.mlp = models.Sequential(
-                [
-                    layers.Dense(
-                        mlp_hidden_dim,
-                        activation=gelu_approximate,
-                        dtype=self.dtype_policy,
-                    ),
-                    layers.Dense(
-                        hidden_dim,
-                        dtype=self.dtype_policy,
-                    ),
-                ],
-                name="mlp",
-            )
-
-    def build(self, inputs_shape, timestep_embedding):
-        self.adaptive_norm_modulation.build(timestep_embedding)
-        self.attention_qkv.build(inputs_shape)
-        self.norm1.build(inputs_shape)
-        if self.use_projection:
-            self.attention_proj.build(inputs_shape)
-            self.norm2.build(inputs_shape)
-            self.mlp.build(inputs_shape)
-
-    def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
-        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
-
-    def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
-        batch_size = ops.shape(inputs)[0]
-        if self.use_projection:
-            modulation = self.adaptive_norm_modulation(
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 6, self.hidden_dim)
-            )
-            (
-                shift_msa,
-                scale_msa,
-                gate_msa,
-                shift_mlp,
-                scale_mlp,
-                gate_mlp,
-            ) = ops.unstack(modulation, 6, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
-            )
-            qkv = ops.reshape(
-                qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
-            )
-            q, k, v = ops.unstack(qkv, 3, axis=2)
-            return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
-        else:
-            modulation = self.adaptive_norm_modulation(
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 2, self.hidden_dim)
-            )
-            shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
-            )
-            qkv = ops.reshape(
-                qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
-            )
-            q, k, v = ops.unstack(qkv, 3, axis=2)
-            return (q, k, v)
-
-    def _compute_post_attention(
-        self, inputs, inputs_intermediates, training=None
-    ):
-        x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
-        attn = self.attention_proj(inputs, training=training)
-        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
-        x = ops.add(
-            x,
-            ops.multiply(
-                ops.expand_dims(gate_mlp, axis=1),
-                self.mlp(
-                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
-                    training=training,
-                ),
-            ),
-        )
-        return x
-
-    def call(
-        self,
-        inputs,
-        timestep_embedding=None,
-        inputs_intermediates=None,
-        pre_attention=True,
-        training=None,
-    ):
-        if pre_attention:
-            return self._compute_pre_attention(
-                inputs, timestep_embedding, training=training
-            )
-        else:
-            return self._compute_post_attention(
-                inputs, inputs_intermediates, training=training
-            )
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "num_heads": self.num_heads,
-                "hidden_dim": self.hidden_dim,
-                "mlp_ratio": self.mlp_ratio,
-                "use_projection": self.use_projection,
-            }
-        )
-        return config
-
-
-class MMDiTBlock(layers.Layer):
-    def __init__(
-        self,
-        num_heads,
-        hidden_dim,
-        mlp_ratio=4.0,
-        use_context_projection=True,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.mlp_ratio = mlp_ratio
-        self.use_context_projection = use_context_projection
-
-        head_dim = hidden_dim // num_heads
-        self.head_dim = head_dim
-        self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
-        self._dot_product_equation = "aecd,abcd->acbe"
-        self._combine_equation = "acbe,aecd->abcd"
-
-        self.x_block = DismantledBlock(
-            num_heads=num_heads,
-            hidden_dim=hidden_dim,
-            mlp_ratio=mlp_ratio,
-            use_projection=True,
-            dtype=self.dtype_policy,
-            name="x_block",
-        )
-        self.context_block = DismantledBlock(
-            num_heads=num_heads,
-            hidden_dim=hidden_dim,
-            mlp_ratio=mlp_ratio,
-            use_projection=use_context_projection,
-            dtype=self.dtype_policy,
-            name="context_block",
-        )
-
-    def build(self, inputs_shape, context_shape, timestep_embedding_shape):
-        self.x_block.build(inputs_shape, timestep_embedding_shape)
-        self.context_block.build(context_shape, timestep_embedding_shape)
-
-    def _compute_attention(self, query, key, value):
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = ops.nn.softmax(attention_scores, axis=-1)
-        attention_output = ops.einsum(
-            self._combine_equation, attention_scores, value
-        )
-        batch_size = ops.shape(attention_output)[0]
-        attention_output = ops.reshape(
-            attention_output, (batch_size, -1, self.num_heads * self.head_dim)
-        )
-        return attention_output
-
-    def call(self, inputs, context, timestep_embedding, training=None):
-        # Compute pre-attention.
-        x = inputs
-        if self.use_context_projection:
-            context_qkv, context_intermediates = self.context_block(
-                context,
-                timestep_embedding=timestep_embedding,
-                training=training,
-            )
-        else:
-            context_qkv = self.context_block(
-                context,
-                timestep_embedding=timestep_embedding,
-                training=training,
-            )
-        context_len = ops.shape(context_qkv[0])[1]
-        x_qkv, x_intermediates = self.x_block(
-            x, timestep_embedding=timestep_embedding, training=training
-        )
-        q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
-        k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
-        v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
-
-        # Compute attention.
-        attention = self._compute_attention(q, k, v)
-        context_attention = attention[:, :context_len]
-        x_attention = attention[:, context_len:]
-
-        # Compute post-attention.
-        x = self.x_block(
-            x_attention,
-            inputs_intermediates=x_intermediates,
-            pre_attention=False,
-            training=training,
-        )
-        if self.use_context_projection:
-            context = self.context_block(
-                context_attention,
-                inputs_intermediates=context_intermediates,
-                pre_attention=False,
-                training=training,
-            )
-            return x, context
-        else:
-            return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "num_heads": self.num_heads,
-                "hidden_dim": self.hidden_dim,
-                "mlp_ratio": self.mlp_ratio,
-                "use_context_projection": self.use_context_projection,
-            }
-        )
-        return config
-
-    def compute_output_shape(
-        self, inputs_shape, context_shape, timestep_embedding_shape
-    ):
-        if self.use_context_projection:
-            return inputs_shape, context_shape
-        else:
-            return inputs_shape
keras_hub/src/models/stable_diffusion_v3/vae_attention.py (deleted)
@@ -1,126 +0,0 @@
-# Copyright 2024 The KerasHub Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-from keras import layers
-from keras import ops
-
-from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
-class VAEAttention(layers.Layer):
-    def __init__(self, filters, groups=32, data_format=None, **kwargs):
-        super().__init__(**kwargs)
-        self.filters = filters
-        self.data_format = standardize_data_format(data_format)
-        gn_axis = -1 if self.data_format == "channels_last" else 1
-
-        self.group_norm = layers.GroupNormalization(
-            groups=groups,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype=self.dtype_policy,
-            name="group_norm",
-        )
-        self.query_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="query_conv2d",
-        )
-        self.key_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="key_conv2d",
-        )
-        self.value_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="value_conv2d",
-        )
-        self.softmax = layers.Softmax(dtype="float32")
-        self.output_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="output_conv2d",
-        )
-
-        self.groups = groups
-        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
-
-    def build(self, input_shape):
-        self.group_norm.build(input_shape)
-        self.query_conv2d.build(input_shape)
-        self.key_conv2d.build(input_shape)
-        self.value_conv2d.build(input_shape)
-        self.output_conv2d.build(input_shape)
-
-    def call(self, inputs, training=None):
-        x = self.group_norm(inputs)
-        query = self.query_conv2d(x)
-        key = self.key_conv2d(x)
-        value = self.value_conv2d(x)
-
-        if self.data_format == "channels_first":
-            query = ops.transpose(query, (0, 2, 3, 1))
-            key = ops.transpose(key, (0, 2, 3, 1))
-            value = ops.transpose(value, (0, 2, 3, 1))
-        shape = ops.shape(inputs)
-        b = shape[0]
-        query = ops.reshape(query, (b, -1, self.filters))
-        key = ops.reshape(key, (b, -1, self.filters))
-        value = ops.reshape(value, (b, -1, self.filters))
-
-        # Compute attention.
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
-        )
-        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
-        attention_scores = ops.einsum("abc,adc->abd", query, key)
-        attention_scores = ops.cast(
-            self.softmax(attention_scores), self.compute_dtype
-        )
-        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C]
-        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
-        x = ops.reshape(attention_output, shape)
-
-        x = self.output_conv2d(x)
-        if self.data_format == "channels_first":
-            x = ops.transpose(x, (0, 3, 1, 2))
-        x = ops.add(x, inputs)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "filters": self.filters,
-                "groups": self.groups,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py (deleted)
@@ -1,186 +0,0 @@
-# Copyright 2024 The KerasHub Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import keras
-from keras import layers
-
-from keras_hub.src.models.stable_diffusion_v3.vae_attention import VAEAttention
-from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
-class VAEImageDecoder(keras.Model):
-    def __init__(
-        self,
-        stackwise_num_filters,
-        stackwise_num_blocks,
-        output_channels=3,
-        latent_shape=(None, None, 16),
-        data_format=None,
-        dtype=None,
-        **kwargs,
-    ):
-        data_format = standardize_data_format(data_format)
-        gn_axis = -1 if data_format == "channels_last" else 1
-
-        # === Functional Model ===
-        latent_inputs = layers.Input(shape=latent_shape)
-
-        x = layers.Conv2D(
-            stackwise_num_filters[0],
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="input_projection",
-        )(latent_inputs)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block0",
-        )
-        x = VAEAttention(
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_attention",
-        )(x)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block1",
-        )
-
-        # Stacks.
-        for i, filters in enumerate(stackwise_num_filters):
-            for j in range(stackwise_num_blocks[i]):
-                x = apply_resnet_block(
-                    x,
-                    filters,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"block{i}_{j}",
-                )
-            if i != len(stackwise_num_filters) - 1:
-                # No upsamling in the last blcok.
-                x = layers.UpSampling2D(
-                    2,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}",
-                )(x)
-                x = layers.Conv2D(
-                    filters,
-                    3,
-                    1,
-                    padding="same",
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}_conv",
-                )(x)
-
-        # Ouput block.
-        x = layers.GroupNormalization(
-            groups=32,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype=dtype,
-            name="output_norm",
-        )(x)
-        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
-        image_outputs = layers.Conv2D(
-            output_channels,
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="output_projection",
-        )(x)
-        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
-
-        # === Config ===
-        self.stackwise_num_filters = stackwise_num_filters
-        self.stackwise_num_blocks = stackwise_num_blocks
-        self.output_channels = output_channels
-        self.latent_shape = latent_shape
-
-        if dtype is not None:
-            try:
-                self.dtype_policy = keras.dtype_policies.get(dtype)
-            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-            except AttributeError:
-                if isinstance(dtype, keras.DTypePolicy):
-                    dtype = dtype.name
-                self.dtype_policy = keras.DTypePolicy(dtype)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "stackwise_num_filters": self.stackwise_num_filters,
-                "stackwise_num_blocks": self.stackwise_num_blocks,
-                "output_channels": self.output_channels,
-                "image_shape": self.latent_shape,
-            }
-        )
-        return config
-
-
-def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
-    data_format = standardize_data_format(data_format)
-    gn_axis = -1 if data_format == "channels_last" else 1
-    input_filters = x.shape[gn_axis]
-
-    residual = x
-    x = layers.GroupNormalization(
-        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv1",
-    )(x)
-    x = layers.GroupNormalization(
-        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
-    )(x)
-    x = layers.Activation("swish")(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv2",
-    )(x)
-    if input_filters != filters:
-        residual = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=data_format,
-            dtype=dtype,
-            name=f"{name}_residual_projection",
-        )(residual)
-    x = layers.Add(dtype=dtype)([residual, x])
-    return x