tf-models-nightly 2.17.0.dev20240617__py2.py3-none-any.whl → 2.20.0.dev20251205__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1257)
  1. official/__init__.py +1 -1
  2. official/common/__init__.py +1 -1
  3. official/common/dataset_fn.py +1 -1
  4. official/common/distribute_utils.py +27 -3
  5. official/common/distribute_utils_test.py +13 -12
  6. official/common/flags.py +24 -6
  7. official/common/registry_imports.py +1 -1
  8. official/common/streamz_counters.py +1 -1
  9. official/core/__init__.py +1 -1
  10. official/core/actions.py +1 -1
  11. official/core/actions_test.py +1 -1
  12. official/core/base_task.py +1 -1
  13. official/core/base_trainer.py +1 -1
  14. official/core/base_trainer_test.py +1 -1
  15. official/core/config_definitions.py +1 -1
  16. official/core/exp_factory.py +1 -1
  17. official/core/export_base.py +1 -1
  18. official/core/export_base_test.py +1 -1
  19. official/core/file_writers.py +1 -1
  20. official/core/file_writers_test.py +1 -1
  21. official/core/input_reader.py +1 -1
  22. official/core/registry.py +1 -1
  23. official/core/registry_test.py +1 -1
  24. official/core/savedmodel_checkpoint_manager.py +1 -1
  25. official/core/savedmodel_checkpoint_manager_test.py +1 -1
  26. official/core/task_factory.py +1 -1
  27. official/core/test_utils.py +1 -1
  28. official/core/tf_example_builder.py +1 -1
  29. official/core/tf_example_builder_test.py +1 -1
  30. official/core/tf_example_feature_key.py +1 -1
  31. official/core/tf_example_feature_key_test.py +1 -1
  32. official/core/train_lib.py +1 -3
  33. official/core/train_lib_test.py +1 -1
  34. official/core/train_utils.py +1 -1
  35. official/core/train_utils_test.py +1 -1
  36. official/legacy/__init__.py +1 -1
  37. official/legacy/albert/__init__.py +1 -1
  38. official/legacy/albert/configs.py +1 -1
  39. official/legacy/bert/__init__.py +1 -1
  40. official/legacy/bert/bert_models.py +1 -1
  41. official/legacy/bert/bert_models_test.py +1 -1
  42. official/legacy/bert/common_flags.py +1 -1
  43. official/legacy/bert/configs.py +1 -1
  44. official/legacy/bert/export_tfhub.py +1 -2
  45. official/legacy/bert/export_tfhub_test.py +1 -1
  46. official/legacy/bert/input_pipeline.py +1 -1
  47. official/legacy/bert/model_saving_utils.py +1 -1
  48. official/legacy/bert/model_training_utils.py +1 -1
  49. official/legacy/bert/model_training_utils_test.py +1 -1
  50. official/legacy/bert/run_classifier.py +1 -2
  51. official/legacy/bert/run_pretraining.py +1 -2
  52. official/legacy/bert/run_squad.py +1 -2
  53. official/legacy/bert/run_squad_helper.py +1 -1
  54. official/legacy/bert/serving.py +1 -1
  55. official/legacy/detection/__init__.py +1 -1
  56. official/legacy/detection/configs/__init__.py +1 -1
  57. official/legacy/detection/configs/base_config.py +1 -1
  58. official/legacy/detection/configs/factory.py +1 -1
  59. official/legacy/detection/configs/maskrcnn_config.py +1 -1
  60. official/legacy/detection/configs/olnmask_config.py +1 -1
  61. official/legacy/detection/configs/retinanet_config.py +1 -1
  62. official/legacy/detection/configs/shapemask_config.py +1 -1
  63. official/legacy/detection/dataloader/__init__.py +1 -1
  64. official/legacy/detection/dataloader/anchor.py +1 -1
  65. official/legacy/detection/dataloader/factory.py +1 -1
  66. official/legacy/detection/dataloader/input_reader.py +1 -1
  67. official/legacy/detection/dataloader/maskrcnn_parser.py +1 -1
  68. official/legacy/detection/dataloader/mode_keys.py +1 -1
  69. official/legacy/detection/dataloader/olnmask_parser.py +1 -1
  70. official/legacy/detection/dataloader/retinanet_parser.py +1 -1
  71. official/legacy/detection/dataloader/shapemask_parser.py +1 -1
  72. official/legacy/detection/dataloader/tf_example_decoder.py +1 -1
  73. official/legacy/detection/evaluation/__init__.py +1 -1
  74. official/legacy/detection/evaluation/coco_evaluator.py +1 -1
  75. official/legacy/detection/evaluation/coco_utils.py +1 -1
  76. official/legacy/detection/evaluation/factory.py +1 -1
  77. official/legacy/detection/executor/__init__.py +1 -1
  78. official/legacy/detection/executor/detection_executor.py +1 -1
  79. official/legacy/detection/executor/distributed_executor.py +1 -1
  80. official/legacy/detection/main.py +1 -1
  81. official/legacy/detection/modeling/__init__.py +1 -1
  82. official/legacy/detection/modeling/architecture/__init__.py +1 -1
  83. official/legacy/detection/modeling/architecture/factory.py +1 -1
  84. official/legacy/detection/modeling/architecture/fpn.py +1 -1
  85. official/legacy/detection/modeling/architecture/heads.py +1 -1
  86. official/legacy/detection/modeling/architecture/identity.py +1 -1
  87. official/legacy/detection/modeling/architecture/nn_blocks.py +1 -1
  88. official/legacy/detection/modeling/architecture/nn_ops.py +1 -1
  89. official/legacy/detection/modeling/architecture/resnet.py +1 -1
  90. official/legacy/detection/modeling/architecture/spinenet.py +1 -1
  91. official/legacy/detection/modeling/base_model.py +1 -1
  92. official/legacy/detection/modeling/checkpoint_utils.py +1 -1
  93. official/legacy/detection/modeling/factory.py +1 -1
  94. official/legacy/detection/modeling/learning_rates.py +1 -1
  95. official/legacy/detection/modeling/losses.py +1 -1
  96. official/legacy/detection/modeling/maskrcnn_model.py +1 -1
  97. official/legacy/detection/modeling/olnmask_model.py +1 -1
  98. official/legacy/detection/modeling/optimizers.py +1 -1
  99. official/legacy/detection/modeling/retinanet_model.py +1 -1
  100. official/legacy/detection/modeling/shapemask_model.py +1 -1
  101. official/legacy/detection/ops/__init__.py +1 -1
  102. official/legacy/detection/ops/nms.py +1 -1
  103. official/legacy/detection/ops/postprocess_ops.py +1 -1
  104. official/legacy/detection/ops/roi_ops.py +1 -1
  105. official/legacy/detection/ops/spatial_transform_ops.py +1 -1
  106. official/legacy/detection/ops/target_ops.py +1 -1
  107. official/legacy/detection/utils/__init__.py +1 -1
  108. official/legacy/detection/utils/box_utils.py +1 -1
  109. official/legacy/detection/utils/class_utils.py +1 -1
  110. official/legacy/detection/utils/dataloader_utils.py +1 -1
  111. official/legacy/detection/utils/input_utils.py +1 -1
  112. official/legacy/detection/utils/mask_utils.py +1 -1
  113. official/legacy/image_classification/__init__.py +1 -1
  114. official/legacy/image_classification/augment.py +1 -1
  115. official/legacy/image_classification/augment_test.py +1 -1
  116. official/legacy/image_classification/callbacks.py +1 -1
  117. official/legacy/image_classification/classifier_trainer.py +1 -1
  118. official/legacy/image_classification/classifier_trainer_test.py +1 -1
  119. official/legacy/image_classification/classifier_trainer_util_test.py +1 -1
  120. official/legacy/image_classification/configs/__init__.py +1 -1
  121. official/legacy/image_classification/configs/base_configs.py +1 -1
  122. official/legacy/image_classification/configs/configs.py +1 -1
  123. official/legacy/image_classification/dataset_factory.py +1 -1
  124. official/legacy/image_classification/efficientnet/__init__.py +1 -1
  125. official/legacy/image_classification/efficientnet/common_modules.py +1 -1
  126. official/legacy/image_classification/efficientnet/efficientnet_config.py +1 -1
  127. official/legacy/image_classification/efficientnet/efficientnet_model.py +1 -1
  128. official/legacy/image_classification/efficientnet/tfhub_export.py +1 -1
  129. official/legacy/image_classification/learning_rate.py +1 -1
  130. official/legacy/image_classification/learning_rate_test.py +1 -1
  131. official/legacy/image_classification/mnist_main.py +1 -2
  132. official/legacy/image_classification/mnist_test.py +1 -1
  133. official/legacy/image_classification/optimizer_factory.py +1 -1
  134. official/legacy/image_classification/optimizer_factory_test.py +1 -1
  135. official/legacy/image_classification/preprocessing.py +1 -1
  136. official/legacy/image_classification/resnet/__init__.py +1 -1
  137. official/legacy/image_classification/resnet/common.py +1 -1
  138. official/legacy/image_classification/resnet/imagenet_preprocessing.py +1 -1
  139. official/legacy/image_classification/resnet/resnet_config.py +1 -1
  140. official/legacy/image_classification/resnet/resnet_ctl_imagenet_main.py +1 -2
  141. official/legacy/image_classification/resnet/resnet_model.py +1 -1
  142. official/legacy/image_classification/resnet/resnet_runnable.py +1 -1
  143. official/legacy/image_classification/resnet/tfhub_export.py +1 -2
  144. official/legacy/image_classification/test_utils.py +1 -1
  145. official/legacy/image_classification/vgg/__init__.py +1 -1
  146. official/legacy/image_classification/vgg/vgg_config.py +1 -1
  147. official/legacy/image_classification/vgg/vgg_model.py +1 -1
  148. official/legacy/transformer/__init__.py +1 -1
  149. official/legacy/transformer/attention_layer.py +1 -1
  150. official/legacy/transformer/beam_search_v1.py +1 -1
  151. official/legacy/transformer/compute_bleu.py +1 -1
  152. official/legacy/transformer/compute_bleu_test.py +1 -1
  153. official/legacy/transformer/data_download.py +1 -1
  154. official/legacy/transformer/data_pipeline.py +1 -1
  155. official/legacy/transformer/embedding_layer.py +1 -1
  156. official/legacy/transformer/ffn_layer.py +1 -1
  157. official/legacy/transformer/metrics.py +1 -1
  158. official/legacy/transformer/misc.py +1 -1
  159. official/legacy/transformer/model_params.py +1 -1
  160. official/legacy/transformer/model_utils.py +1 -1
  161. official/legacy/transformer/model_utils_test.py +1 -1
  162. official/legacy/transformer/optimizer.py +1 -1
  163. official/legacy/transformer/transformer.py +1 -1
  164. official/legacy/transformer/transformer_forward_test.py +1 -1
  165. official/legacy/transformer/transformer_layers_test.py +1 -1
  166. official/legacy/transformer/transformer_main.py +1 -5
  167. official/legacy/transformer/transformer_main_test.py +1 -1
  168. official/legacy/transformer/transformer_test.py +1 -1
  169. official/legacy/transformer/translate.py +1 -2
  170. official/legacy/transformer/utils/__init__.py +1 -1
  171. official/legacy/transformer/utils/metrics.py +1 -1
  172. official/legacy/transformer/utils/tokenizer.py +1 -1
  173. official/legacy/transformer/utils/tokenizer_test.py +1 -1
  174. official/legacy/xlnet/__init__.py +1 -1
  175. official/legacy/xlnet/classifier_utils.py +1 -1
  176. official/legacy/xlnet/common_flags.py +1 -1
  177. official/legacy/xlnet/data_utils.py +1 -1
  178. official/legacy/xlnet/optimization.py +1 -1
  179. official/legacy/xlnet/preprocess_classification_data.py +1 -2
  180. official/legacy/xlnet/preprocess_pretrain_data.py +1 -2
  181. official/legacy/xlnet/preprocess_squad_data.py +1 -2
  182. official/legacy/xlnet/preprocess_utils.py +1 -1
  183. official/legacy/xlnet/run_classifier.py +1 -2
  184. official/legacy/xlnet/run_pretrain.py +1 -2
  185. official/legacy/xlnet/run_squad.py +1 -2
  186. official/legacy/xlnet/squad_utils.py +1 -1
  187. official/legacy/xlnet/training_utils.py +1 -1
  188. official/legacy/xlnet/xlnet_config.py +1 -1
  189. official/legacy/xlnet/xlnet_modeling.py +1 -1
  190. official/modeling/__init__.py +1 -1
  191. official/modeling/activations/__init__.py +1 -1
  192. official/modeling/activations/gelu.py +1 -1
  193. official/modeling/activations/gelu_test.py +1 -1
  194. official/modeling/activations/mish.py +1 -1
  195. official/modeling/activations/mish_test.py +1 -1
  196. official/modeling/activations/relu.py +1 -1
  197. official/modeling/activations/relu_test.py +1 -1
  198. official/modeling/activations/sigmoid.py +1 -1
  199. official/modeling/activations/sigmoid_test.py +1 -1
  200. official/modeling/activations/swish.py +1 -1
  201. official/modeling/activations/swish_test.py +1 -1
  202. official/modeling/grad_utils.py +1 -1
  203. official/modeling/grad_utils_test.py +1 -1
  204. official/modeling/hyperparams/__init__.py +1 -1
  205. official/modeling/hyperparams/base_config.py +27 -19
  206. official/modeling/hyperparams/base_config_test.py +32 -1
  207. official/modeling/hyperparams/oneof.py +1 -1
  208. official/modeling/hyperparams/oneof_test.py +1 -1
  209. official/modeling/hyperparams/params_dict.py +1 -1
  210. official/modeling/hyperparams/params_dict_test.py +1 -1
  211. official/modeling/multitask/__init__.py +1 -1
  212. official/modeling/multitask/base_model.py +1 -1
  213. official/modeling/multitask/base_trainer.py +1 -1
  214. official/modeling/multitask/base_trainer_test.py +1 -1
  215. official/modeling/multitask/configs.py +3 -3
  216. official/modeling/multitask/evaluator.py +1 -1
  217. official/modeling/multitask/evaluator_test.py +1 -1
  218. official/modeling/multitask/interleaving_trainer.py +1 -1
  219. official/modeling/multitask/interleaving_trainer_test.py +1 -1
  220. official/modeling/multitask/multitask.py +1 -1
  221. official/modeling/multitask/task_sampler.py +1 -1
  222. official/modeling/multitask/task_sampler_test.py +1 -1
  223. official/modeling/multitask/test_utils.py +1 -1
  224. official/modeling/multitask/train_lib.py +81 -14
  225. official/modeling/multitask/train_lib_test.py +1 -1
  226. official/modeling/optimization/__init__.py +1 -1
  227. official/modeling/optimization/adafactor_optimizer.py +1 -1
  228. official/modeling/optimization/configs/__init__.py +1 -1
  229. official/modeling/optimization/configs/learning_rate_config.py +1 -1
  230. official/modeling/optimization/configs/optimization_config.py +1 -1
  231. official/modeling/optimization/configs/optimization_config_test.py +1 -1
  232. official/modeling/optimization/configs/optimizer_config.py +1 -1
  233. official/modeling/optimization/ema_optimizer.py +1 -1
  234. official/modeling/optimization/lamb.py +1 -1
  235. official/modeling/optimization/lamb_test.py +1 -1
  236. official/modeling/optimization/lars.py +1 -1
  237. official/modeling/optimization/legacy_adamw.py +1 -1
  238. official/modeling/optimization/lr_schedule.py +1 -1
  239. official/modeling/optimization/lr_schedule_test.py +1 -1
  240. official/modeling/optimization/optimizer_factory.py +1 -1
  241. official/modeling/optimization/optimizer_factory_test.py +1 -1
  242. official/modeling/optimization/slide_optimizer.py +1 -1
  243. official/modeling/performance.py +1 -1
  244. official/modeling/privacy/__init__.py +1 -1
  245. official/modeling/privacy/configs.py +1 -1
  246. official/modeling/privacy/configs_test.py +1 -1
  247. official/modeling/privacy/ops.py +1 -1
  248. official/modeling/privacy/ops_test.py +1 -1
  249. official/modeling/tf_utils.py +1 -1
  250. official/modeling/tf_utils_test.py +1 -1
  251. official/nlp/__init__.py +1 -1
  252. official/nlp/configs/__init__.py +1 -1
  253. official/nlp/configs/bert.py +1 -1
  254. official/nlp/configs/electra.py +1 -1
  255. official/nlp/configs/encoders.py +1 -1
  256. official/nlp/configs/encoders_test.py +1 -1
  257. official/nlp/configs/experiment_configs.py +1 -1
  258. official/nlp/configs/finetuning_experiments.py +1 -1
  259. official/nlp/configs/pretraining_experiments.py +1 -1
  260. official/nlp/configs/wmt_transformer_experiments.py +1 -1
  261. official/nlp/continuous_finetune_lib.py +1 -1
  262. official/nlp/continuous_finetune_lib_test.py +1 -1
  263. official/nlp/data/__init__.py +1 -1
  264. official/nlp/data/classifier_data_lib.py +1 -1
  265. official/nlp/data/classifier_data_lib_test.py +1 -1
  266. official/nlp/data/create_finetuning_data.py +1 -2
  267. official/nlp/data/create_pretraining_data.py +1 -3
  268. official/nlp/data/create_pretraining_data_test.py +1 -1
  269. official/nlp/data/create_xlnet_pretraining_data.py +1 -3
  270. official/nlp/data/create_xlnet_pretraining_data_test.py +1 -1
  271. official/nlp/data/data_loader.py +1 -1
  272. official/nlp/data/data_loader_factory.py +1 -1
  273. official/nlp/data/data_loader_factory_test.py +1 -1
  274. official/nlp/data/dual_encoder_dataloader.py +1 -1
  275. official/nlp/data/dual_encoder_dataloader_test.py +1 -1
  276. official/nlp/data/pretrain_dataloader.py +1 -1
  277. official/nlp/data/pretrain_dataloader_test.py +1 -1
  278. official/nlp/data/pretrain_dynamic_dataloader.py +1 -1
  279. official/nlp/data/pretrain_dynamic_dataloader_test.py +1 -1
  280. official/nlp/data/pretrain_text_dataloader.py +1 -1
  281. official/nlp/data/question_answering_dataloader.py +1 -1
  282. official/nlp/data/question_answering_dataloader_test.py +1 -1
  283. official/nlp/data/sentence_prediction_dataloader.py +1 -1
  284. official/nlp/data/sentence_prediction_dataloader_test.py +1 -1
  285. official/nlp/data/sentence_retrieval_lib.py +1 -1
  286. official/nlp/data/squad_lib.py +1 -1
  287. official/nlp/data/squad_lib_sp.py +1 -1
  288. official/nlp/data/tagging_data_lib.py +1 -1
  289. official/nlp/data/tagging_data_lib_test.py +1 -1
  290. official/nlp/data/tagging_dataloader.py +1 -1
  291. official/nlp/data/tagging_dataloader_test.py +1 -1
  292. official/nlp/data/train_sentencepiece.py +1 -1
  293. official/nlp/data/wmt_dataloader.py +1 -1
  294. official/nlp/data/wmt_dataloader_test.py +1 -1
  295. official/nlp/metrics/__init__.py +1 -1
  296. official/nlp/metrics/bleu.py +1 -1
  297. official/nlp/metrics/bleu_test.py +1 -1
  298. official/nlp/modeling/__init__.py +1 -1
  299. official/nlp/modeling/layers/__init__.py +1 -1
  300. official/nlp/modeling/layers/attention.py +1 -1
  301. official/nlp/modeling/layers/attention_test.py +1 -1
  302. official/nlp/modeling/layers/bigbird_attention.py +1 -1
  303. official/nlp/modeling/layers/bigbird_attention_test.py +1 -1
  304. official/nlp/modeling/layers/block_diag_feedforward.py +1 -1
  305. official/nlp/modeling/layers/block_diag_feedforward_test.py +1 -1
  306. official/nlp/modeling/layers/block_sparse_attention.py +187 -44
  307. official/nlp/modeling/layers/block_sparse_attention_test.py +137 -7
  308. official/nlp/modeling/layers/cls_head.py +1 -1
  309. official/nlp/modeling/layers/cls_head_test.py +1 -1
  310. official/nlp/modeling/layers/factorized_embedding.py +1 -1
  311. official/nlp/modeling/layers/factorized_embedding_test.py +1 -1
  312. official/nlp/modeling/layers/gated_feedforward.py +2 -2
  313. official/nlp/modeling/layers/gated_feedforward_test.py +1 -1
  314. official/nlp/modeling/layers/gaussian_process.py +1 -1
  315. official/nlp/modeling/layers/gaussian_process_test.py +1 -1
  316. official/nlp/modeling/layers/kernel_attention.py +1 -1
  317. official/nlp/modeling/layers/kernel_attention_test.py +1 -1
  318. official/nlp/modeling/layers/masked_lm.py +1 -1
  319. official/nlp/modeling/layers/masked_lm_test.py +1 -1
  320. official/nlp/modeling/layers/masked_softmax.py +1 -1
  321. official/nlp/modeling/layers/masked_softmax_test.py +1 -1
  322. official/nlp/modeling/layers/mat_mul_with_margin.py +1 -2
  323. official/nlp/modeling/layers/mat_mul_with_margin_test.py +1 -1
  324. official/nlp/modeling/layers/mixing.py +1 -1
  325. official/nlp/modeling/layers/mixing_test.py +1 -1
  326. official/nlp/modeling/layers/mobile_bert_layers.py +1 -1
  327. official/nlp/modeling/layers/mobile_bert_layers_test.py +1 -1
  328. official/nlp/modeling/layers/moe.py +1 -1
  329. official/nlp/modeling/layers/moe_test.py +1 -1
  330. official/nlp/modeling/layers/multi_channel_attention.py +1 -1
  331. official/nlp/modeling/layers/multi_channel_attention_test.py +1 -1
  332. official/nlp/modeling/layers/multi_query_attention.py +222 -4
  333. official/nlp/modeling/layers/multi_query_attention_test.py +201 -1
  334. official/nlp/modeling/layers/on_device_embedding.py +1 -1
  335. official/nlp/modeling/layers/on_device_embedding_test.py +1 -1
  336. official/nlp/modeling/layers/pack_optimization.py +1 -1
  337. official/nlp/modeling/layers/pack_optimization_test.py +1 -1
  338. official/nlp/modeling/layers/per_dim_scale_attention.py +1 -1
  339. official/nlp/modeling/layers/per_dim_scale_attention_test.py +1 -1
  340. official/nlp/modeling/layers/position_embedding.py +1 -1
  341. official/nlp/modeling/layers/position_embedding_test.py +1 -1
  342. official/nlp/modeling/layers/relative_attention.py +1 -1
  343. official/nlp/modeling/layers/relative_attention_test.py +1 -1
  344. official/nlp/modeling/layers/reuse_attention.py +1 -1
  345. official/nlp/modeling/layers/reuse_attention_test.py +1 -1
  346. official/nlp/modeling/layers/reuse_transformer.py +1 -1
  347. official/nlp/modeling/layers/reuse_transformer_test.py +1 -1
  348. official/nlp/modeling/layers/rezero_transformer.py +20 -1
  349. official/nlp/modeling/layers/rezero_transformer_test.py +1 -1
  350. official/nlp/modeling/layers/routing.py +1 -1
  351. official/nlp/modeling/layers/routing_test.py +1 -1
  352. official/nlp/modeling/layers/self_attention_mask.py +1 -1
  353. official/nlp/modeling/layers/spectral_normalization.py +1 -1
  354. official/nlp/modeling/layers/spectral_normalization_test.py +1 -1
  355. official/nlp/modeling/layers/talking_heads_attention.py +1 -1
  356. official/nlp/modeling/layers/talking_heads_attention_test.py +1 -1
  357. official/nlp/modeling/layers/text_layers.py +1 -1
  358. official/nlp/modeling/layers/text_layers_test.py +1 -1
  359. official/nlp/modeling/layers/tn_expand_condense.py +1 -1
  360. official/nlp/modeling/layers/tn_expand_condense_test.py +1 -1
  361. official/nlp/modeling/layers/tn_transformer_expand_condense.py +1 -3
  362. official/nlp/modeling/layers/tn_transformer_test.py +1 -1
  363. official/nlp/modeling/layers/transformer.py +1 -1
  364. official/nlp/modeling/layers/transformer_encoder_block.py +273 -52
  365. official/nlp/modeling/layers/transformer_encoder_block_test.py +215 -11
  366. official/nlp/modeling/layers/transformer_scaffold.py +1 -1
  367. official/nlp/modeling/layers/transformer_scaffold_test.py +1 -1
  368. official/nlp/modeling/layers/transformer_test.py +1 -1
  369. official/nlp/modeling/layers/transformer_xl.py +1 -1
  370. official/nlp/modeling/layers/transformer_xl_test.py +1 -1
  371. official/nlp/modeling/layers/util.py +1 -1
  372. official/nlp/modeling/losses/__init__.py +1 -1
  373. official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy.py +1 -1
  374. official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py +1 -1
  375. official/nlp/modeling/models/__init__.py +1 -1
  376. official/nlp/modeling/models/bert_classifier.py +1 -1
  377. official/nlp/modeling/models/bert_classifier_test.py +1 -1
  378. official/nlp/modeling/models/bert_pretrainer.py +1 -1
  379. official/nlp/modeling/models/bert_pretrainer_test.py +1 -1
  380. official/nlp/modeling/models/bert_span_labeler.py +1 -1
  381. official/nlp/modeling/models/bert_span_labeler_test.py +1 -1
  382. official/nlp/modeling/models/bert_token_classifier.py +1 -1
  383. official/nlp/modeling/models/bert_token_classifier_test.py +1 -1
  384. official/nlp/modeling/models/dual_encoder.py +1 -1
  385. official/nlp/modeling/models/dual_encoder_test.py +1 -1
  386. official/nlp/modeling/models/electra_pretrainer.py +1 -1
  387. official/nlp/modeling/models/electra_pretrainer_test.py +1 -1
  388. official/nlp/modeling/models/seq2seq_transformer.py +1 -1
  389. official/nlp/modeling/models/seq2seq_transformer_test.py +1 -1
  390. official/nlp/modeling/models/t5.py +1 -1
  391. official/nlp/modeling/models/t5_test.py +1 -1
  392. official/nlp/modeling/models/xlnet.py +1 -1
  393. official/nlp/modeling/models/xlnet_test.py +1 -1
  394. official/nlp/modeling/networks/__init__.py +1 -1
  395. official/nlp/modeling/networks/albert_encoder.py +1 -1
  396. official/nlp/modeling/networks/albert_encoder_test.py +1 -1
  397. official/nlp/modeling/networks/bert_dense_encoder_test.py +1 -2
  398. official/nlp/modeling/networks/bert_encoder.py +1 -1
  399. official/nlp/modeling/networks/bert_encoder_test.py +1 -2
  400. official/nlp/modeling/networks/classification.py +1 -1
  401. official/nlp/modeling/networks/classification_test.py +1 -1
  402. official/nlp/modeling/networks/encoder_scaffold.py +1 -1
  403. official/nlp/modeling/networks/encoder_scaffold_test.py +1 -1
  404. official/nlp/modeling/networks/fnet.py +1 -1
  405. official/nlp/modeling/networks/fnet_test.py +1 -1
  406. official/nlp/modeling/networks/funnel_transformer.py +1 -1
  407. official/nlp/modeling/networks/funnel_transformer_test.py +1 -1
  408. official/nlp/modeling/networks/mobile_bert_encoder.py +6 -4
  409. official/nlp/modeling/networks/mobile_bert_encoder_test.py +1 -1
  410. official/nlp/modeling/networks/packed_sequence_embedding.py +1 -1
  411. official/nlp/modeling/networks/packed_sequence_embedding_test.py +1 -3
  412. official/nlp/modeling/networks/span_labeling.py +1 -1
  413. official/nlp/modeling/networks/span_labeling_test.py +1 -1
  414. official/nlp/modeling/networks/sparse_mixer.py +1 -1
  415. official/nlp/modeling/networks/sparse_mixer_test.py +1 -1
  416. official/nlp/modeling/networks/xlnet_base.py +1 -1
  417. official/nlp/modeling/networks/xlnet_base_test.py +1 -1
  418. official/nlp/modeling/ops/__init__.py +1 -1
  419. official/nlp/modeling/ops/beam_search.py +1 -1
  420. official/nlp/modeling/ops/beam_search_test.py +1 -1
  421. official/nlp/modeling/ops/decoding_module.py +1 -1
  422. official/nlp/modeling/ops/decoding_module_test.py +1 -1
  423. official/nlp/modeling/ops/sampling_module.py +3 -3
  424. official/nlp/modeling/ops/segment_extractor.py +1 -1
  425. official/nlp/modeling/ops/segment_extractor_test.py +1 -1
  426. official/nlp/optimization.py +1 -1
  427. official/nlp/serving/__init__.py +1 -1
  428. official/nlp/serving/export_savedmodel.py +1 -1
  429. official/nlp/serving/export_savedmodel_test.py +1 -1
  430. official/nlp/serving/export_savedmodel_util.py +1 -1
  431. official/nlp/serving/serving_modules.py +1 -1
  432. official/nlp/serving/serving_modules_test.py +1 -1
  433. official/nlp/tasks/__init__.py +1 -1
  434. official/nlp/tasks/dual_encoder.py +1 -2
  435. official/nlp/tasks/dual_encoder_test.py +1 -1
  436. official/nlp/tasks/electra_task.py +1 -1
  437. official/nlp/tasks/electra_task_test.py +1 -1
  438. official/nlp/tasks/masked_lm.py +1 -1
  439. official/nlp/tasks/masked_lm_determinism_test.py +1 -1
  440. official/nlp/tasks/masked_lm_test.py +1 -1
  441. official/nlp/tasks/question_answering.py +1 -1
  442. official/nlp/tasks/question_answering_test.py +1 -1
  443. official/nlp/tasks/sentence_prediction.py +1 -1
  444. official/nlp/tasks/sentence_prediction_test.py +1 -1
  445. official/nlp/tasks/tagging.py +1 -1
  446. official/nlp/tasks/tagging_test.py +1 -1
  447. official/nlp/tasks/translation.py +1 -1
  448. official/nlp/tasks/translation_test.py +1 -1
  449. official/nlp/tasks/utils.py +1 -1
  450. official/nlp/tools/__init__.py +1 -1
  451. official/nlp/tools/export_tfhub.py +1 -1
  452. official/nlp/tools/export_tfhub_lib.py +1 -2
  453. official/nlp/tools/export_tfhub_lib_test.py +1 -1
  454. official/nlp/tools/squad_evaluate_v1_1.py +1 -1
  455. official/nlp/tools/squad_evaluate_v2_0.py +1 -1
  456. official/nlp/tools/tf1_bert_checkpoint_converter_lib.py +1 -1
  457. official/nlp/tools/tf2_albert_encoder_checkpoint_converter.py +1 -1
  458. official/nlp/tools/tf2_bert_encoder_checkpoint_converter.py +1 -1
  459. official/nlp/tools/tokenization.py +1 -1
  460. official/nlp/tools/tokenization_test.py +1 -1
  461. official/nlp/train.py +1 -1
  462. official/projects/__init__.py +1 -1
  463. official/projects/bigbird/__init__.py +1 -1
  464. official/projects/bigbird/encoder.py +1 -1
  465. official/projects/bigbird/encoder_test.py +1 -1
  466. official/projects/bigbird/experiment_configs.py +1 -1
  467. official/projects/bigbird/recompute_grad.py +1 -1
  468. official/projects/bigbird/recomputing_dropout.py +1 -1
  469. official/projects/bigbird/stateless_dropout.py +1 -1
  470. official/projects/centernet/__init__.py +1 -1
  471. official/projects/centernet/common/__init__.py +1 -1
  472. official/projects/centernet/common/registry_imports.py +1 -1
  473. official/projects/centernet/configs/__init__.py +1 -1
  474. official/projects/centernet/configs/backbones.py +1 -1
  475. official/projects/centernet/configs/centernet.py +1 -1
  476. official/projects/centernet/configs/centernet_test.py +1 -1
  477. official/projects/centernet/dataloaders/__init__.py +1 -1
  478. official/projects/centernet/dataloaders/centernet_input.py +1 -1
  479. official/projects/centernet/losses/__init__.py +1 -1
  480. official/projects/centernet/losses/centernet_losses.py +1 -1
  481. official/projects/centernet/losses/centernet_losses_test.py +1 -1
  482. official/projects/centernet/modeling/__init__.py +1 -1
  483. official/projects/centernet/modeling/backbones/__init__.py +1 -1
  484. official/projects/centernet/modeling/backbones/hourglass.py +1 -1
  485. official/projects/centernet/modeling/backbones/hourglass_test.py +1 -1
  486. official/projects/centernet/modeling/centernet_model.py +2 -2
  487. official/projects/centernet/modeling/centernet_model_test.py +1 -1
  488. official/projects/centernet/modeling/heads/__init__.py +1 -1
  489. official/projects/centernet/modeling/heads/centernet_head.py +2 -2
  490. official/projects/centernet/modeling/heads/centernet_head_test.py +1 -1
  491. official/projects/centernet/modeling/layers/__init__.py +1 -1
  492. official/projects/centernet/modeling/layers/cn_nn_blocks.py +1 -1
  493. official/projects/centernet/modeling/layers/cn_nn_blocks_test.py +1 -1
  494. official/projects/centernet/modeling/layers/detection_generator.py +1 -1
  495. official/projects/centernet/modeling/layers/detection_generator_test.py +1 -1
  496. official/projects/centernet/ops/__init__.py +1 -1
  497. official/projects/centernet/ops/box_list.py +1 -1
  498. official/projects/centernet/ops/box_list_ops.py +1 -1
  499. official/projects/centernet/ops/loss_ops.py +1 -1
  500. official/projects/centernet/ops/nms_ops.py +1 -1
  501. official/projects/centernet/ops/preprocess_ops.py +1 -1
  502. official/projects/centernet/ops/target_assigner.py +1 -1
  503. official/projects/centernet/ops/target_assigner_test.py +1 -1
  504. official/projects/centernet/tasks/__init__.py +1 -1
  505. official/projects/centernet/tasks/centernet.py +1 -1
  506. official/projects/centernet/train.py +1 -1
  507. official/projects/centernet/utils/__init__.py +1 -1
  508. official/projects/centernet/utils/checkpoints/__init__.py +1 -1
  509. official/projects/centernet/utils/checkpoints/config_classes.py +1 -1
  510. official/projects/centernet/utils/checkpoints/config_data.py +1 -1
  511. official/projects/centernet/utils/checkpoints/load_weights.py +1 -1
  512. official/projects/centernet/utils/checkpoints/read_checkpoints.py +1 -1
  513. official/projects/centernet/utils/tf2_centernet_checkpoint_converter.py +1 -1
  514. official/projects/deepmac_maskrcnn/__init__.py +1 -1
  515. official/projects/deepmac_maskrcnn/common/__init__.py +1 -1
  516. official/projects/deepmac_maskrcnn/common/registry_imports.py +1 -1
  517. official/projects/deepmac_maskrcnn/configs/__init__.py +1 -1
  518. official/projects/deepmac_maskrcnn/configs/deep_mask_head_rcnn.py +1 -1
  519. official/projects/deepmac_maskrcnn/configs/deep_mask_head_rcnn_config_test.py +1 -1
  520. official/projects/deepmac_maskrcnn/modeling/__init__.py +1 -1
  521. official/projects/deepmac_maskrcnn/modeling/heads/__init__.py +1 -1
  522. official/projects/deepmac_maskrcnn/modeling/heads/hourglass_network.py +1 -1
  523. official/projects/deepmac_maskrcnn/modeling/heads/instance_heads.py +1 -3
  524. official/projects/deepmac_maskrcnn/modeling/heads/instance_heads_test.py +1 -2
  525. official/projects/deepmac_maskrcnn/modeling/maskrcnn_model.py +1 -3
  526. official/projects/deepmac_maskrcnn/modeling/maskrcnn_model_test.py +1 -3
  527. official/projects/deepmac_maskrcnn/serving/__init__.py +1 -1
  528. official/projects/deepmac_maskrcnn/serving/detection.py +1 -1
  529. official/projects/deepmac_maskrcnn/serving/detection_test.py +1 -1
  530. official/projects/deepmac_maskrcnn/serving/export_saved_model.py +1 -1
  531. official/projects/deepmac_maskrcnn/tasks/__init__.py +1 -1
  532. official/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py +1 -1
  533. official/projects/deepmac_maskrcnn/train.py +1 -1
  534. official/projects/detr/__init__.py +14 -0
  535. official/projects/detr/configs/__init__.py +14 -0
  536. official/projects/detr/configs/detr.py +277 -0
  537. official/projects/detr/configs/detr_test.py +51 -0
  538. official/projects/detr/dataloaders/__init__.py +14 -0
  539. official/projects/detr/dataloaders/coco.py +157 -0
  540. official/projects/detr/dataloaders/coco_test.py +111 -0
  541. official/projects/detr/dataloaders/detr_input.py +175 -0
  542. official/projects/detr/experiments/__init__.py +14 -0
  543. official/projects/detr/modeling/__init__.py +14 -0
  544. official/projects/detr/modeling/detr.py +345 -0
  545. official/projects/detr/modeling/detr_test.py +70 -0
  546. official/projects/detr/modeling/transformer.py +849 -0
  547. official/projects/detr/modeling/transformer_test.py +263 -0
  548. official/projects/detr/ops/__init__.py +14 -0
  549. official/projects/detr/ops/matchers.py +489 -0
  550. official/projects/detr/ops/matchers_test.py +95 -0
  551. official/projects/detr/optimization.py +151 -0
  552. official/projects/detr/serving/__init__.py +14 -0
  553. official/projects/detr/serving/export_module.py +103 -0
  554. official/projects/detr/serving/export_module_test.py +98 -0
  555. official/projects/detr/serving/export_saved_model.py +109 -0
  556. official/projects/detr/tasks/__init__.py +14 -0
  557. official/projects/detr/tasks/detection.py +433 -0
  558. official/projects/detr/tasks/detection_test.py +203 -0
  559. official/projects/detr/train.py +70 -0
  560. official/projects/maskconver/__init__.py +14 -0
  561. official/projects/maskconver/configs/__init__.py +14 -0
  562. official/projects/maskconver/configs/backbones.py +43 -0
  563. official/projects/maskconver/configs/decoders.py +36 -0
  564. official/projects/maskconver/configs/maskconver.py +523 -0
  565. official/projects/maskconver/configs/multiscale_maskconver.py +215 -0
  566. official/projects/maskconver/tasks/__init__.py +14 -0
  567. official/projects/maskconver/tasks/maskconver.py +641 -0
  568. official/projects/maskconver/tasks/multiscale_maskconver.py +278 -0
  569. official/projects/maskconver/train.py +30 -0
  570. official/projects/maxvit/__init__.py +1 -1
  571. official/projects/maxvit/configs/__init__.py +1 -1
  572. official/projects/maxvit/configs/backbones.py +1 -1
  573. official/projects/maxvit/configs/image_classification.py +1 -1
  574. official/projects/maxvit/configs/image_classification_test.py +1 -1
  575. official/projects/maxvit/configs/rcnn.py +1 -1
  576. official/projects/maxvit/configs/rcnn_test.py +1 -1
  577. official/projects/maxvit/configs/retinanet.py +1 -1
  578. official/projects/maxvit/configs/retinanet_test.py +1 -1
  579. official/projects/maxvit/configs/semantic_segmentation.py +1 -1
  580. official/projects/maxvit/configs/semantic_segmentation_test.py +1 -1
  581. official/projects/maxvit/modeling/__init__.py +1 -1
  582. official/projects/maxvit/modeling/common_ops.py +14 -1
  583. official/projects/maxvit/modeling/layers.py +1 -1
  584. official/projects/maxvit/modeling/maxvit.py +2 -2
  585. official/projects/maxvit/modeling/maxvit_test.py +1 -1
  586. official/projects/maxvit/registry_imports.py +1 -1
  587. official/projects/maxvit/train.py +1 -1
  588. official/projects/maxvit/train_test.py +1 -1
  589. official/projects/mobilebert/__init__.py +1 -1
  590. official/projects/mobilebert/distillation.py +1 -1
  591. official/projects/mobilebert/distillation_test.py +1 -1
  592. official/projects/mobilebert/export_tfhub.py +1 -1
  593. official/projects/mobilebert/model_utils.py +1 -1
  594. official/projects/mobilebert/run_distillation.py +1 -1
  595. official/projects/mobilebert/tf2_model_checkpoint_converter.py +1 -1
  596. official/projects/mobilebert/utils.py +1 -1
  597. official/projects/movinet/__init__.py +1 -1
  598. official/projects/movinet/configs/__init__.py +1 -1
  599. official/projects/movinet/configs/movinet.py +1 -1
  600. official/projects/movinet/configs/movinet_test.py +1 -1
  601. official/projects/movinet/modeling/__init__.py +1 -1
  602. official/projects/movinet/modeling/movinet.py +1 -1
  603. official/projects/movinet/modeling/movinet_layers.py +1 -1
  604. official/projects/movinet/modeling/movinet_layers_test.py +1 -1
  605. official/projects/movinet/modeling/movinet_model.py +1 -1
  606. official/projects/movinet/modeling/movinet_model_test.py +1 -1
  607. official/projects/movinet/modeling/movinet_test.py +1 -1
  608. official/projects/movinet/tools/__init__.py +1 -1
  609. official/projects/movinet/tools/convert_3d_2plus1d.py +1 -1
  610. official/projects/movinet/tools/convert_3d_2plus1d_test.py +1 -1
  611. official/projects/movinet/tools/export_saved_model.py +1 -1
  612. official/projects/movinet/tools/export_saved_model_test.py +6 -3
  613. official/projects/movinet/tools/quantize_movinet.py +1 -1
  614. official/projects/movinet/train.py +1 -1
  615. official/projects/movinet/train_test.py +1 -1
  616. official/projects/nhnet/__init__.py +1 -1
  617. official/projects/nhnet/configs.py +1 -1
  618. official/projects/nhnet/configs_test.py +1 -1
  619. official/projects/nhnet/decoder.py +1 -1
  620. official/projects/nhnet/decoder_test.py +1 -1
  621. official/projects/nhnet/evaluation.py +1 -3
  622. official/projects/nhnet/input_pipeline.py +1 -1
  623. official/projects/nhnet/models.py +1 -1
  624. official/projects/nhnet/models_test.py +1 -1
  625. official/projects/nhnet/optimizer.py +1 -1
  626. official/projects/nhnet/raw_data_process.py +1 -1
  627. official/projects/nhnet/raw_data_processor.py +1 -1
  628. official/projects/nhnet/trainer.py +1 -6
  629. official/projects/nhnet/trainer_test.py +1 -1
  630. official/projects/nhnet/utils.py +1 -1
  631. official/projects/panoptic/__init__.py +1 -1
  632. official/projects/panoptic/configs/__init__.py +1 -1
  633. official/projects/panoptic/configs/panoptic_deeplab.py +5 -6
  634. official/projects/panoptic/configs/panoptic_maskrcnn.py +1 -1
  635. official/projects/panoptic/tasks/__init__.py +1 -1
  636. official/projects/panoptic/tasks/panoptic_deeplab.py +1 -1
  637. official/projects/panoptic/tasks/panoptic_maskrcnn.py +3 -1
  638. official/projects/panoptic/train.py +1 -1
  639. official/projects/qat/__init__.py +1 -1
  640. official/projects/qat/nlp/__init__.py +1 -1
  641. official/projects/qat/nlp/configs/__init__.py +1 -1
  642. official/projects/qat/nlp/configs/finetuning_experiments.py +1 -1
  643. official/projects/qat/nlp/modeling/__init__.py +1 -1
  644. official/projects/qat/nlp/modeling/layers/__init__.py +1 -1
  645. official/projects/qat/nlp/modeling/layers/mobile_bert_layers.py +1 -1
  646. official/projects/qat/nlp/modeling/layers/multi_head_attention.py +1 -1
  647. official/projects/qat/nlp/modeling/layers/transformer_encoder_block.py +1 -1
  648. official/projects/qat/nlp/modeling/layers/transformer_encoder_block_test.py +1 -1
  649. official/projects/qat/nlp/modeling/models/__init__.py +1 -1
  650. official/projects/qat/nlp/modeling/models/bert_span_labeler.py +1 -1
  651. official/projects/qat/nlp/modeling/networks/__init__.py +1 -1
  652. official/projects/qat/nlp/modeling/networks/span_labeling.py +1 -1
  653. official/projects/qat/nlp/pretrained_checkpoint_converter.py +1 -3
  654. official/projects/qat/nlp/quantization/__init__.py +1 -1
  655. official/projects/qat/nlp/quantization/configs.py +1 -1
  656. official/projects/qat/nlp/quantization/configs_test.py +1 -2
  657. official/projects/qat/nlp/quantization/helper.py +1 -1
  658. official/projects/qat/nlp/quantization/schemes.py +1 -3
  659. official/projects/qat/nlp/quantization/wrappers.py +1 -1
  660. official/projects/qat/nlp/registry_imports.py +1 -1
  661. official/projects/qat/nlp/tasks/__init__.py +1 -1
  662. official/projects/qat/nlp/tasks/question_answering.py +1 -1
  663. official/projects/qat/nlp/tasks/question_answering_test.py +1 -1
  664. official/projects/qat/nlp/train.py +1 -1
  665. official/projects/qat/vision/__init__.py +1 -1
  666. official/projects/qat/vision/configs/__init__.py +1 -1
  667. official/projects/qat/vision/configs/common.py +1 -1
  668. official/projects/qat/vision/configs/image_classification.py +1 -1
  669. official/projects/qat/vision/configs/image_classification_test.py +1 -1
  670. official/projects/qat/vision/configs/retinanet.py +1 -1
  671. official/projects/qat/vision/configs/retinanet_test.py +1 -1
  672. official/projects/qat/vision/configs/semantic_segmentation.py +1 -1
  673. official/projects/qat/vision/configs/semantic_segmentation_test.py +1 -1
  674. official/projects/qat/vision/modeling/__init__.py +1 -1
  675. official/projects/qat/vision/modeling/factory.py +1 -3
  676. official/projects/qat/vision/modeling/factory_test.py +1 -3
  677. official/projects/qat/vision/modeling/heads/__init__.py +1 -1
  678. official/projects/qat/vision/modeling/heads/dense_prediction_heads.py +1 -3
  679. official/projects/qat/vision/modeling/heads/dense_prediction_heads_test.py +1 -2
  680. official/projects/qat/vision/modeling/layers/__init__.py +1 -1
  681. official/projects/qat/vision/modeling/layers/nn_blocks.py +1 -3
  682. official/projects/qat/vision/modeling/layers/nn_blocks_test.py +1 -2
  683. official/projects/qat/vision/modeling/layers/nn_layers.py +2 -2
  684. official/projects/qat/vision/modeling/layers/nn_layers_test.py +1 -2
  685. official/projects/qat/vision/modeling/segmentation_model.py +1 -2
  686. official/projects/qat/vision/n_bit/__init__.py +1 -1
  687. official/projects/qat/vision/n_bit/configs.py +1 -1
  688. official/projects/qat/vision/n_bit/configs_test.py +1 -3
  689. official/projects/qat/vision/n_bit/nn_blocks.py +1 -3
  690. official/projects/qat/vision/n_bit/nn_blocks_test.py +1 -2
  691. official/projects/qat/vision/n_bit/nn_layers.py +1 -1
  692. official/projects/qat/vision/n_bit/schemes.py +1 -3
  693. official/projects/qat/vision/quantization/__init__.py +1 -1
  694. official/projects/qat/vision/quantization/configs.py +1 -1
  695. official/projects/qat/vision/quantization/configs_test.py +1 -3
  696. official/projects/qat/vision/quantization/helper.py +1 -1
  697. official/projects/qat/vision/quantization/helper_test.py +1 -1
  698. official/projects/qat/vision/quantization/layer_transforms.py +1 -1
  699. official/projects/qat/vision/quantization/schemes.py +1 -3
  700. official/projects/qat/vision/registry_imports.py +1 -1
  701. official/projects/qat/vision/serving/__init__.py +1 -1
  702. official/projects/qat/vision/serving/export_module.py +1 -1
  703. official/projects/qat/vision/serving/export_saved_model.py +1 -1
  704. official/projects/qat/vision/serving/export_tflite.py +1 -1
  705. official/projects/qat/vision/tasks/__init__.py +1 -1
  706. official/projects/qat/vision/tasks/image_classification.py +1 -1
  707. official/projects/qat/vision/tasks/image_classification_test.py +1 -1
  708. official/projects/qat/vision/tasks/retinanet.py +1 -1
  709. official/projects/qat/vision/tasks/retinanet_test.py +1 -1
  710. official/projects/qat/vision/tasks/semantic_segmentation.py +1 -1
  711. official/projects/qat/vision/train.py +1 -1
  712. official/projects/roformer/__init__.py +1 -1
  713. official/projects/roformer/roformer.py +1 -1
  714. official/projects/roformer/roformer_attention.py +1 -1
  715. official/projects/roformer/roformer_attention_test.py +1 -1
  716. official/projects/roformer/roformer_encoder.py +1 -1
  717. official/projects/roformer/roformer_encoder_block.py +1 -1
  718. official/projects/roformer/roformer_encoder_block_test.py +1 -1
  719. official/projects/roformer/roformer_encoder_test.py +1 -1
  720. official/projects/roformer/roformer_experiments.py +1 -1
  721. official/projects/roformer/train.py +1 -1
  722. official/projects/teams/__init__.py +1 -1
  723. official/projects/teams/teams.py +1 -1
  724. official/projects/teams/teams_experiments.py +1 -1
  725. official/projects/teams/teams_pretrainer.py +1 -1
  726. official/projects/teams/teams_pretrainer_test.py +1 -1
  727. official/projects/teams/teams_task.py +1 -1
  728. official/projects/teams/teams_task_test.py +1 -1
  729. official/projects/teams/train.py +1 -1
  730. official/projects/triviaqa/__init__.py +1 -1
  731. official/projects/triviaqa/dataset.py +1 -1
  732. official/projects/triviaqa/download_and_prepare.py +1 -1
  733. official/projects/triviaqa/evaluate.py +1 -1
  734. official/projects/triviaqa/evaluation.py +1 -1
  735. official/projects/triviaqa/inputs.py +1 -1
  736. official/projects/triviaqa/modeling.py +1 -1
  737. official/projects/triviaqa/predict.py +1 -1
  738. official/projects/triviaqa/prediction.py +1 -1
  739. official/projects/triviaqa/preprocess.py +1 -1
  740. official/projects/triviaqa/sentencepiece_pb2.py +1 -1
  741. official/projects/triviaqa/train.py +1 -1
  742. official/projects/video_ssl/__init__.py +1 -1
  743. official/projects/video_ssl/configs/__init__.py +1 -1
  744. official/projects/video_ssl/configs/video_ssl.py +1 -1
  745. official/projects/video_ssl/configs/video_ssl_test.py +1 -1
  746. official/projects/video_ssl/dataloaders/__init__.py +1 -1
  747. official/projects/video_ssl/dataloaders/video_ssl_input.py +1 -1
  748. official/projects/video_ssl/dataloaders/video_ssl_input_test.py +1 -2
  749. official/projects/video_ssl/losses/__init__.py +1 -1
  750. official/projects/video_ssl/losses/losses.py +1 -2
  751. official/projects/video_ssl/modeling/__init__.py +1 -1
  752. official/projects/video_ssl/modeling/video_ssl_model.py +1 -3
  753. official/projects/video_ssl/ops/__init__.py +1 -1
  754. official/projects/video_ssl/ops/video_ssl_preprocess_ops.py +1 -1
  755. official/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py +1 -1
  756. official/projects/video_ssl/tasks/__init__.py +1 -1
  757. official/projects/video_ssl/tasks/linear_eval.py +1 -1
  758. official/projects/video_ssl/tasks/pretrain.py +1 -1
  759. official/projects/video_ssl/tasks/pretrain_test.py +1 -1
  760. official/projects/video_ssl/train.py +1 -1
  761. official/projects/volumetric_models/__init__.py +1 -1
  762. official/projects/volumetric_models/configs/__init__.py +1 -1
  763. official/projects/volumetric_models/configs/backbones.py +1 -1
  764. official/projects/volumetric_models/configs/decoders.py +1 -1
  765. official/projects/volumetric_models/configs/semantic_segmentation_3d.py +1 -1
  766. official/projects/volumetric_models/configs/semantic_segmentation_3d_test.py +1 -1
  767. official/projects/volumetric_models/dataloaders/__init__.py +1 -1
  768. official/projects/volumetric_models/dataloaders/segmentation_input_3d.py +1 -1
  769. official/projects/volumetric_models/dataloaders/segmentation_input_3d_test.py +1 -1
  770. official/projects/volumetric_models/evaluation/__init__.py +1 -1
  771. official/projects/volumetric_models/evaluation/segmentation_metrics.py +1 -1
  772. official/projects/volumetric_models/evaluation/segmentation_metrics_test.py +1 -1
  773. official/projects/volumetric_models/losses/__init__.py +1 -1
  774. official/projects/volumetric_models/losses/segmentation_losses.py +1 -1
  775. official/projects/volumetric_models/losses/segmentation_losses_test.py +1 -1
  776. official/projects/volumetric_models/modeling/__init__.py +1 -1
  777. official/projects/volumetric_models/modeling/backbones/__init__.py +1 -1
  778. official/projects/volumetric_models/modeling/backbones/unet_3d.py +1 -2
  779. official/projects/volumetric_models/modeling/backbones/unet_3d_test.py +1 -2
  780. official/projects/volumetric_models/modeling/decoders/__init__.py +1 -1
  781. official/projects/volumetric_models/modeling/decoders/factory.py +1 -3
  782. official/projects/volumetric_models/modeling/decoders/factory_test.py +1 -1
  783. official/projects/volumetric_models/modeling/decoders/unet_3d_decoder.py +1 -1
  784. official/projects/volumetric_models/modeling/decoders/unet_3d_decoder_test.py +1 -2
  785. official/projects/volumetric_models/modeling/factory.py +1 -3
  786. official/projects/volumetric_models/modeling/factory_test.py +1 -1
  787. official/projects/volumetric_models/modeling/heads/__init__.py +1 -1
  788. official/projects/volumetric_models/modeling/heads/segmentation_heads_3d.py +1 -1
  789. official/projects/volumetric_models/modeling/heads/segmentation_heads_3d_test.py +1 -1
  790. official/projects/volumetric_models/modeling/nn_blocks_3d.py +2 -3
  791. official/projects/volumetric_models/modeling/nn_blocks_3d_test.py +1 -2
  792. official/projects/volumetric_models/modeling/segmentation_model_test.py +1 -1
  793. official/projects/volumetric_models/registry_imports.py +1 -1
  794. official/projects/volumetric_models/serving/__init__.py +1 -1
  795. official/projects/volumetric_models/serving/export_saved_model.py +1 -1
  796. official/projects/volumetric_models/serving/semantic_segmentation_3d.py +1 -1
  797. official/projects/volumetric_models/serving/semantic_segmentation_3d_test.py +3 -3
  798. official/projects/volumetric_models/tasks/__init__.py +1 -1
  799. official/projects/volumetric_models/tasks/semantic_segmentation_3d.py +1 -1
  800. official/projects/volumetric_models/tasks/semantic_segmentation_3d_test.py +1 -1
  801. official/projects/volumetric_models/train.py +1 -1
  802. official/projects/volumetric_models/train_test.py +1 -1
  803. official/projects/waste_identification_ml/__init__.py +1 -1
  804. official/projects/waste_identification_ml/data_generation/__init__.py +1 -1
  805. official/projects/waste_identification_ml/data_generation/utils.py +1 -1
  806. official/projects/waste_identification_ml/data_generation/utils_test.py +1 -1
  807. official/projects/yolo/__init__.py +1 -1
  808. official/projects/yolo/common/__init__.py +1 -1
  809. official/projects/yolo/common/registry_imports.py +1 -1
  810. official/projects/yolo/configs/__init__.py +1 -1
  811. official/projects/yolo/configs/backbones.py +1 -1
  812. official/projects/yolo/configs/darknet_classification.py +1 -1
  813. official/projects/yolo/configs/decoders.py +1 -1
  814. official/projects/yolo/configs/yolo.py +1 -1
  815. official/projects/yolo/configs/yolov7.py +17 -1
  816. official/projects/yolo/dataloaders/__init__.py +1 -1
  817. official/projects/yolo/dataloaders/classification_input.py +1 -1
  818. official/projects/yolo/dataloaders/tf_example_decoder.py +1 -1
  819. official/projects/yolo/dataloaders/yolo_input.py +1 -1
  820. official/projects/yolo/losses/__init__.py +1 -1
  821. official/projects/yolo/losses/yolo_loss.py +1 -1
  822. official/projects/yolo/losses/yolo_loss_test.py +1 -1
  823. official/projects/yolo/losses/yolov7_loss.py +1 -1
  824. official/projects/yolo/losses/yolov7_loss_test.py +1 -1
  825. official/projects/yolo/modeling/__init__.py +1 -1
  826. official/projects/yolo/modeling/backbones/__init__.py +1 -1
  827. official/projects/yolo/modeling/backbones/darknet.py +1 -1
  828. official/projects/yolo/modeling/backbones/darknet_test.py +1 -1
  829. official/projects/yolo/modeling/backbones/yolov7.py +69 -1
  830. official/projects/yolo/modeling/backbones/yolov7_test.py +1 -1
  831. official/projects/yolo/modeling/decoders/__init__.py +1 -1
  832. official/projects/yolo/modeling/decoders/yolo_decoder.py +1 -1
  833. official/projects/yolo/modeling/decoders/yolo_decoder_test.py +1 -2
  834. official/projects/yolo/modeling/decoders/yolov7.py +90 -1
  835. official/projects/yolo/modeling/decoders/yolov7_test.py +1 -1
  836. official/projects/yolo/modeling/factory.py +1 -1
  837. official/projects/yolo/modeling/factory_test.py +1 -1
  838. official/projects/yolo/modeling/heads/__init__.py +1 -1
  839. official/projects/yolo/modeling/heads/yolo_head.py +1 -1
  840. official/projects/yolo/modeling/heads/yolo_head_test.py +1 -2
  841. official/projects/yolo/modeling/heads/yolov7_head.py +1 -1
  842. official/projects/yolo/modeling/heads/yolov7_head_test.py +1 -1
  843. official/projects/yolo/modeling/layers/__init__.py +1 -1
  844. official/projects/yolo/modeling/layers/detection_generator.py +1 -1
  845. official/projects/yolo/modeling/layers/detection_generator_test.py +1 -1
  846. official/projects/yolo/modeling/layers/nn_blocks.py +1 -1
  847. official/projects/yolo/modeling/layers/nn_blocks_test.py +1 -1
  848. official/projects/yolo/modeling/yolo_model.py +2 -2
  849. official/projects/yolo/modeling/yolov7_model.py +2 -2
  850. official/projects/yolo/ops/__init__.py +1 -1
  851. official/projects/yolo/ops/anchor.py +1 -1
  852. official/projects/yolo/ops/box_ops.py +1 -1
  853. official/projects/yolo/ops/box_ops_test.py +1 -1
  854. official/projects/yolo/ops/initializer_ops.py +1 -1
  855. official/projects/yolo/ops/kmeans_anchors.py +1 -1
  856. official/projects/yolo/ops/kmeans_anchors_test.py +1 -1
  857. official/projects/yolo/ops/loss_utils.py +1 -1
  858. official/projects/yolo/ops/math_ops.py +1 -1
  859. official/projects/yolo/ops/mosaic.py +1 -1
  860. official/projects/yolo/ops/preprocessing_ops.py +1 -1
  861. official/projects/yolo/ops/preprocessing_ops_test.py +1 -1
  862. official/projects/yolo/optimization/__init__.py +1 -1
  863. official/projects/yolo/optimization/configs/__init__.py +1 -1
  864. official/projects/yolo/optimization/configs/optimization_config.py +1 -1
  865. official/projects/yolo/optimization/configs/optimizer_config.py +1 -1
  866. official/projects/yolo/optimization/optimizer_factory.py +1 -1
  867. official/projects/yolo/optimization/sgd_torch.py +1 -1
  868. official/projects/yolo/serving/__init__.py +1 -1
  869. official/projects/yolo/serving/export_module_factory.py +1 -1
  870. official/projects/yolo/serving/export_saved_model.py +1 -1
  871. official/projects/yolo/serving/export_tflite.py +1 -1
  872. official/projects/yolo/serving/model_fn.py +1 -1
  873. official/projects/yolo/tasks/__init__.py +1 -1
  874. official/projects/yolo/tasks/image_classification.py +1 -1
  875. official/projects/yolo/tasks/task_utils.py +1 -1
  876. official/projects/yolo/tasks/yolo.py +1 -1
  877. official/projects/yolo/tasks/yolov7.py +1 -1
  878. official/projects/yolo/train.py +1 -1
  879. official/projects/yt8m/__init__.py +1 -1
  880. official/projects/yt8m/configs/__init__.py +1 -1
  881. official/projects/yt8m/configs/yt8m.py +1 -1
  882. official/projects/yt8m/configs/yt8m_test.py +1 -1
  883. official/projects/yt8m/modeling/__init__.py +1 -1
  884. official/projects/yt8m/modeling/backbones/__init__.py +1 -1
  885. official/projects/yt8m/modeling/backbones/dbof.py +1 -1
  886. official/projects/yt8m/modeling/backbones/dbof_test.py +1 -1
  887. official/projects/yt8m/modeling/heads/__init__.py +1 -1
  888. official/projects/yt8m/modeling/heads/logistic.py +1 -1
  889. official/projects/yt8m/modeling/heads/moe.py +1 -1
  890. official/projects/yt8m/modeling/nn_layers.py +1 -1
  891. official/projects/yt8m/modeling/nn_layers_test.py +1 -1
  892. official/projects/yt8m/modeling/yt8m_model.py +1 -1
  893. official/projects/yt8m/modeling/yt8m_model_test.py +1 -1
  894. official/projects/yt8m/modeling/yt8m_model_utils.py +1 -1
  895. official/projects/yt8m/modeling/yt8m_model_utils_test.py +1 -1
  896. official/projects/yt8m/tasks/__init__.py +1 -1
  897. official/projects/yt8m/tasks/yt8m_task.py +1 -1
  898. official/projects/yt8m/train.py +1 -1
  899. official/projects/yt8m/train_test.py +1 -1
  900. official/recommendation/__init__.py +1 -1
  901. official/recommendation/constants.py +1 -1
  902. official/recommendation/create_ncf_data.py +1 -2
  903. official/recommendation/data_pipeline.py +1 -1
  904. official/recommendation/data_preprocessing.py +1 -1
  905. official/recommendation/data_test.py +4 -4
  906. official/recommendation/movielens.py +1 -2
  907. official/recommendation/ncf_common.py +1 -1
  908. official/recommendation/ncf_input_pipeline.py +1 -1
  909. official/recommendation/ncf_keras_main.py +1 -1
  910. official/recommendation/ncf_test.py +1 -1
  911. official/recommendation/neumf_model.py +1 -1
  912. official/recommendation/popen_helper.py +1 -1
  913. official/recommendation/ranking/__init__.py +1 -1
  914. official/recommendation/ranking/common.py +1 -1
  915. official/recommendation/ranking/configs/__init__.py +1 -1
  916. official/recommendation/ranking/configs/config.py +14 -1
  917. official/recommendation/ranking/configs/config_test.py +1 -1
  918. official/recommendation/ranking/data/__init__.py +1 -1
  919. official/recommendation/ranking/data/data_pipeline.py +9 -2
  920. official/recommendation/ranking/data/data_pipeline_multi_hot.py +8 -2
  921. official/recommendation/ranking/data/data_pipeline_multi_hot_test.py +12 -6
  922. official/recommendation/ranking/data/data_pipeline_test.py +18 -8
  923. official/recommendation/ranking/task.py +102 -19
  924. official/recommendation/ranking/task_test.py +1 -1
  925. official/recommendation/ranking/train.py +1 -1
  926. official/recommendation/ranking/train_test.py +76 -31
  927. official/recommendation/stat_utils.py +1 -1
  928. official/recommendation/uplift/__init__.py +1 -1
  929. official/recommendation/uplift/keras_test_case.py +1 -1
  930. official/recommendation/uplift/keys.py +1 -1
  931. official/recommendation/uplift/layers/__init__.py +1 -1
  932. official/recommendation/uplift/layers/encoders/__init__.py +1 -1
  933. official/recommendation/uplift/layers/encoders/concat_features.py +1 -1
  934. official/recommendation/uplift/layers/encoders/concat_features_test.py +1 -1
  935. official/recommendation/uplift/layers/heads/__init__.py +1 -1
  936. official/recommendation/uplift/layers/heads/two_tower_logits_head.py +1 -1
  937. official/recommendation/uplift/layers/heads/two_tower_logits_head_test.py +1 -1
  938. official/recommendation/uplift/layers/uplift_networks/__init__.py +1 -1
  939. official/recommendation/uplift/layers/uplift_networks/base_uplift_networks.py +1 -1
  940. official/recommendation/uplift/layers/uplift_networks/two_tower_output_head.py +1 -1
  941. official/recommendation/uplift/layers/uplift_networks/two_tower_output_head_test.py +1 -1
  942. official/recommendation/uplift/layers/uplift_networks/two_tower_uplift_network.py +1 -1
  943. official/recommendation/uplift/layers/uplift_networks/two_tower_uplift_network_test.py +1 -1
  944. official/recommendation/uplift/losses/__init__.py +1 -1
  945. official/recommendation/uplift/losses/true_logits_loss.py +1 -1
  946. official/recommendation/uplift/losses/true_logits_loss_test.py +1 -1
  947. official/recommendation/uplift/metrics/__init__.py +1 -1
  948. official/recommendation/uplift/metrics/label_mean.py +1 -1
  949. official/recommendation/uplift/metrics/label_mean_test.py +1 -1
  950. official/recommendation/uplift/metrics/label_variance.py +1 -1
  951. official/recommendation/uplift/metrics/label_variance_test.py +1 -1
  952. official/recommendation/uplift/metrics/loss_metric.py +1 -1
  953. official/recommendation/uplift/metrics/loss_metric_test.py +1 -1
  954. official/recommendation/uplift/metrics/metric_configs.py +1 -1
  955. official/recommendation/uplift/metrics/poisson_metrics.py +1 -1
  956. official/recommendation/uplift/metrics/poisson_metrics_test.py +1 -1
  957. official/recommendation/uplift/metrics/sliced_metric.py +1 -1
  958. official/recommendation/uplift/metrics/sliced_metric_test.py +1 -1
  959. official/recommendation/uplift/metrics/treatment_fraction.py +1 -1
  960. official/recommendation/uplift/metrics/treatment_fraction_test.py +1 -1
  961. official/recommendation/uplift/metrics/treatment_sliced_metric.py +1 -1
  962. official/recommendation/uplift/metrics/treatment_sliced_metric_test.py +1 -1
  963. official/recommendation/uplift/metrics/uplift_mean.py +1 -1
  964. official/recommendation/uplift/metrics/uplift_mean_test.py +1 -1
  965. official/recommendation/uplift/metrics/variance.py +1 -1
  966. official/recommendation/uplift/metrics/variance_test.py +12 -10
  967. official/recommendation/uplift/models/__init__.py +1 -1
  968. official/recommendation/uplift/models/two_tower_uplift_model.py +1 -1
  969. official/recommendation/uplift/models/two_tower_uplift_model_test.py +1 -1
  970. official/recommendation/uplift/types.py +1 -1
  971. official/recommendation/uplift/utils.py +3 -3
  972. official/recommendation/uplift/utils_test.py +1 -1
  973. official/utils/__init__.py +1 -1
  974. official/utils/docs/__init__.py +1 -1
  975. official/utils/docs/build_orbit_api_docs.py +1 -1
  976. official/utils/docs/build_tfm_api_docs.py +1 -1
  977. official/utils/flags/__init__.py +1 -1
  978. official/utils/flags/_base.py +1 -1
  979. official/utils/flags/_benchmark.py +1 -1
  980. official/utils/flags/_conventions.py +1 -1
  981. official/utils/flags/_device.py +1 -1
  982. official/utils/flags/_distribution.py +1 -1
  983. official/utils/flags/_misc.py +1 -1
  984. official/utils/flags/_performance.py +1 -1
  985. official/utils/flags/core.py +1 -1
  986. official/utils/flags/flags_test.py +1 -1
  987. official/utils/hyperparams_flags.py +1 -1
  988. official/utils/misc/__init__.py +1 -1
  989. official/utils/misc/keras_utils.py +1 -1
  990. official/utils/misc/model_helpers.py +1 -1
  991. official/utils/misc/model_helpers_test.py +3 -3
  992. official/utils/testing/__init__.py +1 -1
  993. official/utils/testing/integration.py +1 -1
  994. official/utils/testing/mock_task.py +1 -1
  995. official/vision/__init__.py +1 -1
  996. official/vision/configs/__init__.py +1 -1
  997. official/vision/configs/backbones.py +3 -1
  998. official/vision/configs/backbones_3d.py +1 -2
  999. official/vision/configs/common.py +1 -3
  1000. official/vision/configs/decoders.py +1 -3
  1001. official/vision/configs/image_classification.py +1 -1
  1002. official/vision/configs/image_classification_test.py +1 -1
  1003. official/vision/configs/maskrcnn.py +1 -1
  1004. official/vision/configs/maskrcnn_test.py +1 -1
  1005. official/vision/configs/retinanet.py +2 -1
  1006. official/vision/configs/retinanet_test.py +1 -1
  1007. official/vision/configs/semantic_segmentation.py +7 -8
  1008. official/vision/configs/semantic_segmentation_test.py +1 -1
  1009. official/vision/configs/video_classification.py +1 -1
  1010. official/vision/configs/video_classification_test.py +1 -1
  1011. official/vision/data/__init__.py +1 -1
  1012. official/vision/data/create_coco_tf_record.py +1 -1
  1013. official/vision/data/fake_feature_generator.py +5 -2
  1014. official/vision/data/image_utils.py +1 -1
  1015. official/vision/data/image_utils_test.py +1 -1
  1016. official/vision/data/process_coco_few_shot_json_files.py +1 -1
  1017. official/vision/data/tf_example_builder.py +1 -1
  1018. official/vision/data/tf_example_builder_test.py +1 -1
  1019. official/vision/data/tf_example_feature_key.py +1 -1
  1020. official/vision/data/tfrecord_lib.py +1 -1
  1021. official/vision/data/tfrecord_lib_test.py +1 -1
  1022. official/vision/dataloaders/__init__.py +1 -1
  1023. official/vision/dataloaders/classification_input.py +1 -2
  1024. official/vision/dataloaders/decoder.py +1 -1
  1025. official/vision/dataloaders/input_reader.py +1 -1
  1026. official/vision/dataloaders/input_reader_factory.py +1 -1
  1027. official/vision/dataloaders/maskrcnn_input.py +1 -2
  1028. official/vision/dataloaders/parser.py +1 -1
  1029. official/vision/dataloaders/retinanet_input.py +1 -3
  1030. official/vision/dataloaders/segmentation_input.py +9 -4
  1031. official/vision/dataloaders/tf_example_decoder.py +1 -1
  1032. official/vision/dataloaders/tf_example_decoder_test.py +1 -2
  1033. official/vision/dataloaders/tf_example_label_map_decoder.py +1 -2
  1034. official/vision/dataloaders/tf_example_label_map_decoder_test.py +1 -2
  1035. official/vision/dataloaders/tfds_classification_decoders.py +1 -1
  1036. official/vision/dataloaders/tfds_detection_decoders.py +1 -1
  1037. official/vision/dataloaders/tfds_factory.py +1 -1
  1038. official/vision/dataloaders/tfds_factory_test.py +1 -1
  1039. official/vision/dataloaders/tfds_segmentation_decoders.py +1 -1
  1040. official/vision/dataloaders/tfexample_utils.py +1 -1
  1041. official/vision/dataloaders/utils.py +1 -2
  1042. official/vision/dataloaders/utils_test.py +1 -3
  1043. official/vision/dataloaders/video_input.py +1 -1
  1044. official/vision/dataloaders/video_input_test.py +1 -2
  1045. official/vision/evaluation/__init__.py +1 -1
  1046. official/vision/evaluation/coco_evaluator.py +1 -2
  1047. official/vision/evaluation/coco_utils.py +1 -3
  1048. official/vision/evaluation/coco_utils_test.py +1 -1
  1049. official/vision/evaluation/instance_metrics.py +1 -1
  1050. official/vision/evaluation/instance_metrics_test.py +1 -1
  1051. official/vision/evaluation/iou.py +1 -1
  1052. official/vision/evaluation/iou_test.py +1 -1
  1053. official/vision/evaluation/panoptic_quality.py +1 -1
  1054. official/vision/evaluation/panoptic_quality_evaluator.py +1 -1
  1055. official/vision/evaluation/panoptic_quality_evaluator_test.py +1 -1
  1056. official/vision/evaluation/panoptic_quality_test.py +1 -1
  1057. official/vision/evaluation/segmentation_metrics.py +1 -1
  1058. official/vision/evaluation/segmentation_metrics_test.py +1 -1
  1059. official/vision/evaluation/wod_detection_evaluator.py +1 -1
  1060. official/vision/losses/__init__.py +1 -1
  1061. official/vision/losses/focal_loss.py +1 -1
  1062. official/vision/losses/loss_utils.py +1 -1
  1063. official/vision/losses/maskrcnn_losses.py +1 -2
  1064. official/vision/losses/maskrcnn_losses_test.py +1 -1
  1065. official/vision/losses/retinanet_losses.py +1 -2
  1066. official/vision/losses/segmentation_losses.py +1 -1
  1067. official/vision/losses/segmentation_losses_test.py +1 -1
  1068. official/vision/modeling/__init__.py +1 -1
  1069. official/vision/modeling/backbones/__init__.py +1 -1
  1070. official/vision/modeling/backbones/efficientnet.py +1 -3
  1071. official/vision/modeling/backbones/efficientnet_test.py +1 -2
  1072. official/vision/modeling/backbones/factory.py +1 -3
  1073. official/vision/modeling/backbones/factory_test.py +1 -2
  1074. official/vision/modeling/backbones/mobiledet.py +1 -1
  1075. official/vision/modeling/backbones/mobiledet_test.py +1 -1
  1076. official/vision/modeling/backbones/mobilenet.py +73 -3
  1077. official/vision/modeling/backbones/mobilenet_test.py +12 -3
  1078. official/vision/modeling/backbones/resnet.py +1 -2
  1079. official/vision/modeling/backbones/resnet_3d.py +1 -2
  1080. official/vision/modeling/backbones/resnet_3d_test.py +1 -2
  1081. official/vision/modeling/backbones/resnet_deeplab.py +5 -4
  1082. official/vision/modeling/backbones/resnet_deeplab_test.py +21 -10
  1083. official/vision/modeling/backbones/resnet_test.py +1 -2
  1084. official/vision/modeling/backbones/resnet_unet.py +1 -2
  1085. official/vision/modeling/backbones/resnet_unet_test.py +1 -3
  1086. official/vision/modeling/backbones/revnet.py +1 -2
  1087. official/vision/modeling/backbones/revnet_test.py +1 -2
  1088. official/vision/modeling/backbones/spinenet.py +1 -3
  1089. official/vision/modeling/backbones/spinenet_mobile.py +1 -3
  1090. official/vision/modeling/backbones/spinenet_mobile_test.py +1 -2
  1091. official/vision/modeling/backbones/spinenet_test.py +1 -2
  1092. official/vision/modeling/backbones/vit.py +53 -27
  1093. official/vision/modeling/backbones/vit_specs.py +1 -1
  1094. official/vision/modeling/backbones/vit_test.py +12 -1
  1095. official/vision/modeling/classification_model.py +1 -2
  1096. official/vision/modeling/classification_model_test.py +1 -2
  1097. official/vision/modeling/decoders/__init__.py +1 -1
  1098. official/vision/modeling/decoders/aspp.py +1 -3
  1099. official/vision/modeling/decoders/aspp_test.py +1 -2
  1100. official/vision/modeling/decoders/factory.py +1 -3
  1101. official/vision/modeling/decoders/factory_test.py +1 -1
  1102. official/vision/modeling/decoders/fpn.py +1 -2
  1103. official/vision/modeling/decoders/fpn_test.py +1 -2
  1104. official/vision/modeling/decoders/nasfpn.py +1 -3
  1105. official/vision/modeling/decoders/nasfpn_test.py +1 -2
  1106. official/vision/modeling/factory.py +1 -1
  1107. official/vision/modeling/factory_3d.py +1 -2
  1108. official/vision/modeling/factory_test.py +1 -2
  1109. official/vision/modeling/heads/__init__.py +1 -1
  1110. official/vision/modeling/heads/dense_prediction_heads.py +1 -3
  1111. official/vision/modeling/heads/dense_prediction_heads_test.py +1 -3
  1112. official/vision/modeling/heads/instance_heads.py +3 -4
  1113. official/vision/modeling/heads/instance_heads_test.py +1 -2
  1114. official/vision/modeling/heads/segmentation_heads.py +2 -2
  1115. official/vision/modeling/heads/segmentation_heads_test.py +1 -2
  1116. official/vision/modeling/layers/__init__.py +1 -1
  1117. official/vision/modeling/layers/box_sampler.py +1 -2
  1118. official/vision/modeling/layers/deeplab.py +1 -1
  1119. official/vision/modeling/layers/deeplab_test.py +1 -1
  1120. official/vision/modeling/layers/detection_generator.py +1 -3
  1121. official/vision/modeling/layers/detection_generator_test.py +1 -3
  1122. official/vision/modeling/layers/edgetpu.py +1 -1
  1123. official/vision/modeling/layers/edgetpu_test.py +1 -1
  1124. official/vision/modeling/layers/mask_sampler.py +1 -2
  1125. official/vision/modeling/layers/nn_blocks.py +1 -2
  1126. official/vision/modeling/layers/nn_blocks_3d.py +1 -2
  1127. official/vision/modeling/layers/nn_blocks_3d_test.py +1 -2
  1128. official/vision/modeling/layers/nn_blocks_test.py +1 -3
  1129. official/vision/modeling/layers/nn_layers.py +1 -1
  1130. official/vision/modeling/layers/nn_layers_test.py +1 -2
  1131. official/vision/modeling/layers/roi_aligner.py +7 -5
  1132. official/vision/modeling/layers/roi_aligner_test.py +1 -2
  1133. official/vision/modeling/layers/roi_generator.py +1 -2
  1134. official/vision/modeling/layers/roi_sampler.py +1 -2
  1135. official/vision/modeling/maskrcnn_model.py +1 -1
  1136. official/vision/modeling/maskrcnn_model_test.py +1 -2
  1137. official/vision/modeling/models/__init__.py +1 -1
  1138. official/vision/modeling/retinanet_model.py +9 -8
  1139. official/vision/modeling/retinanet_model_test.py +1 -2
  1140. official/vision/modeling/segmentation_model.py +4 -4
  1141. official/vision/modeling/segmentation_model_test.py +1 -1
  1142. official/vision/modeling/video_classification_model.py +1 -1
  1143. official/vision/modeling/video_classification_model_test.py +1 -2
  1144. official/vision/ops/__init__.py +1 -1
  1145. official/vision/ops/anchor.py +1 -3
  1146. official/vision/ops/anchor_generator.py +1 -1
  1147. official/vision/ops/anchor_generator_test.py +1 -1
  1148. official/vision/ops/anchor_test.py +1 -2
  1149. official/vision/ops/augment.py +4 -16
  1150. official/vision/ops/augment_test.py +1 -1
  1151. official/vision/ops/box_matcher.py +1 -1
  1152. official/vision/ops/box_matcher_test.py +1 -1
  1153. official/vision/ops/box_ops.py +1 -2
  1154. official/vision/ops/iou_similarity.py +1 -1
  1155. official/vision/ops/iou_similarity_test.py +1 -1
  1156. official/vision/ops/mask_ops.py +1 -3
  1157. official/vision/ops/mask_ops_test.py +1 -2
  1158. official/vision/ops/nms.py +1 -2
  1159. official/vision/ops/preprocess_ops.py +40 -11
  1160. official/vision/ops/preprocess_ops_3d.py +6 -3
  1161. official/vision/ops/preprocess_ops_3d_test.py +1 -1
  1162. official/vision/ops/preprocess_ops_test.py +13 -7
  1163. official/vision/ops/sampling_ops.py +1 -2
  1164. official/vision/ops/spatial_transform_ops.py +1 -1
  1165. official/vision/ops/target_gather.py +1 -1
  1166. official/vision/ops/target_gather_test.py +1 -1
  1167. official/vision/registry_imports.py +1 -1
  1168. official/vision/serving/__init__.py +1 -1
  1169. official/vision/serving/detection.py +21 -1
  1170. official/vision/serving/detection_test.py +39 -1
  1171. official/vision/serving/export_base.py +1 -1
  1172. official/vision/serving/export_base_v2.py +1 -1
  1173. official/vision/serving/export_base_v2_test.py +1 -1
  1174. official/vision/serving/export_module_factory.py +1 -1
  1175. official/vision/serving/export_module_factory_test.py +1 -1
  1176. official/vision/serving/export_saved_model.py +1 -1
  1177. official/vision/serving/export_saved_model_lib.py +1 -1
  1178. official/vision/serving/export_saved_model_lib_test.py +1 -1
  1179. official/vision/serving/export_saved_model_lib_v2.py +1 -1
  1180. official/vision/serving/export_tfhub.py +1 -2
  1181. official/vision/serving/export_tfhub_lib.py +1 -3
  1182. official/vision/serving/export_tflite.py +1 -1
  1183. official/vision/serving/export_tflite_lib.py +1 -1
  1184. official/vision/serving/export_utils.py +1 -1
  1185. official/vision/serving/image_classification.py +1 -1
  1186. official/vision/serving/image_classification_test.py +1 -1
  1187. official/vision/serving/semantic_segmentation.py +6 -3
  1188. official/vision/serving/semantic_segmentation_test.py +71 -7
  1189. official/vision/serving/video_classification.py +1 -1
  1190. official/vision/serving/video_classification_test.py +1 -1
  1191. official/vision/tasks/__init__.py +1 -1
  1192. official/vision/tasks/image_classification.py +1 -1
  1193. official/vision/tasks/maskrcnn.py +1 -1
  1194. official/vision/tasks/retinanet.py +1 -1
  1195. official/vision/tasks/semantic_segmentation.py +1 -1
  1196. official/vision/tasks/video_classification.py +1 -1
  1197. official/vision/train.py +1 -1
  1198. official/vision/train_spatial_partitioning.py +1 -1
  1199. official/vision/utils/__init__.py +1 -1
  1200. official/vision/utils/object_detection/__init__.py +1 -1
  1201. official/vision/utils/object_detection/argmax_matcher.py +1 -1
  1202. official/vision/utils/object_detection/balanced_positive_negative_sampler.py +1 -1
  1203. official/vision/utils/object_detection/box_coder.py +1 -1
  1204. official/vision/utils/object_detection/box_list.py +1 -1
  1205. official/vision/utils/object_detection/box_list_ops.py +1 -1
  1206. official/vision/utils/object_detection/faster_rcnn_box_coder.py +1 -1
  1207. official/vision/utils/object_detection/matcher.py +1 -1
  1208. official/vision/utils/object_detection/minibatch_sampler.py +1 -1
  1209. official/vision/utils/object_detection/ops.py +1 -1
  1210. official/vision/utils/object_detection/preprocessor.py +1 -1
  1211. official/vision/utils/object_detection/region_similarity_calculator.py +1 -1
  1212. official/vision/utils/object_detection/shape_utils.py +1 -1
  1213. official/vision/utils/object_detection/target_assigner.py +1 -1
  1214. official/vision/utils/object_detection/visualization_utils.py +6 -1
  1215. official/vision/utils/ops_test.py +1 -1
  1216. official/vision/utils/summary_manager.py +1 -1
  1217. orbit/__init__.py +1 -1
  1218. orbit/actions/__init__.py +1 -1
  1219. orbit/actions/conditional_action.py +3 -2
  1220. orbit/actions/conditional_action_test.py +1 -1
  1221. orbit/actions/export_saved_model.py +1 -1
  1222. orbit/actions/export_saved_model_test.py +1 -1
  1223. orbit/actions/new_best_metric.py +2 -2
  1224. orbit/actions/new_best_metric_test.py +2 -2
  1225. orbit/actions/save_checkpoint_if_preempted.py +1 -1
  1226. orbit/controller.py +1 -1
  1227. orbit/controller_test.py +1 -1
  1228. orbit/examples/__init__.py +1 -1
  1229. orbit/examples/single_task/__init__.py +1 -1
  1230. orbit/examples/single_task/single_task_evaluator.py +1 -1
  1231. orbit/examples/single_task/single_task_evaluator_test.py +1 -1
  1232. orbit/examples/single_task/single_task_trainer.py +1 -1
  1233. orbit/examples/single_task/single_task_trainer_test.py +1 -1
  1234. orbit/runner.py +1 -1
  1235. orbit/standard_runner.py +1 -1
  1236. orbit/standard_runner_test.py +1 -1
  1237. orbit/utils/__init__.py +1 -1
  1238. orbit/utils/common.py +1 -1
  1239. orbit/utils/common_test.py +1 -1
  1240. orbit/utils/epoch_helper.py +1 -1
  1241. orbit/utils/loop_fns.py +7 -2
  1242. orbit/utils/summary_manager.py +1 -1
  1243. orbit/utils/summary_manager_interface.py +1 -1
  1244. orbit/utils/tpu_summaries.py +1 -1
  1245. orbit/utils/tpu_summaries_test.py +1 -1
  1246. tensorflow_models/__init__.py +1 -1
  1247. tensorflow_models/nlp/__init__.py +1 -1
  1248. tensorflow_models/tensorflow_models_test.py +1 -1
  1249. tensorflow_models/uplift/__init__.py +1 -1
  1250. tensorflow_models/vision/__init__.py +1 -1
  1251. {tf_models_nightly-2.17.0.dev20240617.dist-info → tf_models_nightly-2.20.0.dev20251205.dist-info}/METADATA +1 -1
  1252. tf_models_nightly-2.20.0.dev20251205.dist-info/RECORD +1256 -0
  1253. tf_models_nightly-2.17.0.dev20240617.dist-info/RECORD +0 -1220
  1254. {tf_models_nightly-2.17.0.dev20240617.dist-info → tf_models_nightly-2.20.0.dev20251205.dist-info}/AUTHORS +0 -0
  1255. {tf_models_nightly-2.17.0.dev20240617.dist-info → tf_models_nightly-2.20.0.dev20251205.dist-info}/LICENSE +0 -0
  1256. {tf_models_nightly-2.17.0.dev20240617.dist-info → tf_models_nightly-2.20.0.dev20251205.dist-info}/WHEEL +0 -0
  1257. {tf_models_nightly-2.17.0.dev20240617.dist-info → tf_models_nightly-2.20.0.dev20251205.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
14
14
 
15
15
  """Tests for block sparse attention layer."""
16
16
 
17
+ import math
18
+
17
19
  from absl.testing import parameterized
18
20
  import numpy as np
19
21
  import tensorflow as tf, tf_keras
@@ -25,9 +27,36 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
25
27
 
26
28
  @parameterized.named_parameters(
27
29
  ("key_value_same_proj", None, None, [40, 80]),
30
+ ("key_value_same_proj_mqa", None, None, [40, 80], False, 1),
31
+ ("key_value_same_proj_multi_query_blocks", None, None, [40, 80], True),
32
+ (
33
+ "key_value_same_proj_multi_query_blocks_mqa",
34
+ None,
35
+ None,
36
+ [40, 80],
37
+ True,
38
+ 1,
39
+ ),
28
40
  ("key_value_different_proj", 32, 60, [40, 60]),
41
+ ("key_value_different_proj_mqa", 32, 60, [40, 60], False, 1),
42
+ ("key_value_different_proj_multi_query_blocks", 32, 60, [40, 60], True),
43
+ (
44
+ "key_value_different_proj_multi_query_blocks_mqa",
45
+ 32,
46
+ 60,
47
+ [40, 60],
48
+ True,
49
+ 1,
50
+ ),
29
51
  )
30
- def test_non_masked_attention(self, value_dim, output_shape, output_dims):
52
+ def test_non_masked_attention(
53
+ self,
54
+ value_dim,
55
+ output_shape,
56
+ output_dims,
57
+ multi_query_blocks=False,
58
+ num_kv_heads=None,
59
+ ):
31
60
  """Test that the attention layer can be created without a mask tensor."""
32
61
  test_layer = block_sparse_attention.MultiHeadAttention(
33
62
  num_heads=12,
@@ -35,7 +64,8 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
35
64
  value_dim=value_dim,
36
65
  output_shape=output_shape,
37
66
  src_block_size=10,
38
- tgt_block_size=5,
67
+ tgt_block_size=20 if multi_query_blocks else 5,
68
+ num_kv_heads=num_kv_heads,
39
69
  )
40
70
  # Create a 3-dimensional input (the first dimension is implicit).
41
71
  query = tf_keras.Input(shape=(40, 80))
@@ -53,12 +83,41 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
53
83
  output = test_layer(query, query)
54
84
  self.assertEqual(output.shape.as_list(), [None, 40, 80])
55
85
 
56
- @parameterized.named_parameters(("with_bias", True), ("no_bias", False))
57
- def test_masked_attention(self, use_bias):
86
+ @parameterized.named_parameters(
87
+ ("with_bias", True),
88
+ ("with_bias_mqa", True, False, False, 1),
89
+ ("with_bias_multi_query_blocks", True, False, True),
90
+ ("with_bias_multi_query_blocks_mqa", True, False, True, 1),
91
+ ("no_bias", False),
92
+ ("no_bias_mqa", False, False, False, 1),
93
+ ("no_bias_multi_query_blocks", False, False, True),
94
+ ("no_bias_multi_query_blocks_mqa", False, False, True, 1),
95
+ ("with_sigmoid_attn", True, True),
96
+ ("with_sigmoid_attn_mqa", True, True, False, 1),
97
+ ("with_sigmoid_attn_multi_query_blocks", True, True, True),
98
+ ("with_sigmoid_attn_multi_query_blocks_mqa", True, True, True, 1),
99
+ )
100
+ def test_masked_attention(
101
+ self,
102
+ use_bias,
103
+ use_sigmoid_attn=False,
104
+ multi_query_blocks=False,
105
+ num_kv_heads=None,
106
+ ):
58
107
  """Test with a mask tensor."""
108
+ if use_sigmoid_attn:
109
+ sigmoid_attn_bias = -math.log(2)
110
+ else:
111
+ sigmoid_attn_bias = None
59
112
  test_layer = block_sparse_attention.MultiHeadAttention(
60
- num_heads=4, key_dim=2, use_bias=use_bias, src_block_size=2,
61
- tgt_block_size=1,
113
+ num_heads=4,
114
+ key_dim=2,
115
+ use_bias=use_bias,
116
+ src_block_size=2,
117
+ tgt_block_size=2 if multi_query_blocks else 1,
118
+ use_sigmoid_attn=use_sigmoid_attn,
119
+ sigmoid_attn_bias=sigmoid_attn_bias,
120
+ num_kv_heads=num_kv_heads,
62
121
  )
63
122
  # Create a 3-dimensional input (the first dimension is implicit).
64
123
  batch_size = 3
@@ -112,6 +171,77 @@ class BlockSparseAttentionTest(tf.test.TestCase, parameterized.TestCase):
112
171
  self.assertLen(test_layer._query_dense.trainable_variables, 1)
113
172
  self.assertLen(test_layer._output_dense.trainable_variables, 1)
114
173
 
174
+ @parameterized.named_parameters(
175
+ ("default_with_softmax", False),
176
+ ("default_with_sigmoid", True),
177
+ )
178
+ def test_default_masked_attention(
179
+ self,
180
+ use_sigmoid_attn=False,
181
+ ):
182
+ """Test with a mask tensor."""
183
+ seq_len = 8
184
+ if use_sigmoid_attn:
185
+ sigmoid_attn_bias = -math.log(seq_len)
186
+ else:
187
+ sigmoid_attn_bias = None
188
+ test_layer = block_sparse_attention.MultiHeadAttention(
189
+ num_heads=4,
190
+ key_dim=2,
191
+ use_bias=True,
192
+ src_block_size=seq_len,
193
+ tgt_block_size=seq_len,
194
+ use_sigmoid_attn=use_sigmoid_attn,
195
+ sigmoid_attn_bias=sigmoid_attn_bias,
196
+ )
197
+ # Create a 3-dimensional input (the first dimension is implicit).
198
+ batch_size = 3
199
+ query = tf_keras.Input(shape=(seq_len, 8))
200
+ value = tf_keras.Input(shape=(seq_len, 8))
201
+ mask_tensor = tf_keras.Input(shape=(seq_len, seq_len))
202
+ output = test_layer(query=query, value=value, attention_mask=mask_tensor)
203
+
204
+ # Create a model containing the test layer.
205
+ model = tf_keras.Model([query, value, mask_tensor], output)
206
+
207
+ # Generate data for the input (non-mask) tensors.
208
+ from_data = 10 * np.random.random_sample((batch_size, seq_len, 8))
209
+ to_data = 10 * np.random.random_sample((batch_size, seq_len, 8))
210
+
211
+ # Invoke the data with a random set of mask data. This should mask at
212
+ # least one element.
213
+ mask_data = np.random.randint(2, size=(batch_size, seq_len, seq_len))
214
+ masked_output_data = model.predict([from_data, to_data, mask_data])
215
+
216
+ # Invoke the same data, but with a null mask (where no elements are
217
+ # masked).
218
+ null_mask_data = np.ones((batch_size, seq_len, seq_len))
219
+ unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
220
+
221
+ # Because one data is masked and one is not, the outputs should not be
222
+ # the same.
223
+ self.assertNotAllClose(masked_output_data, unmasked_output_data)
224
+
225
+ # Tests the layer with three inputs: Q, K, V.
226
+ key = tf_keras.Input(shape=(seq_len, 8))
227
+ output = test_layer(
228
+ query, value=value, key=key, attention_mask=mask_tensor
229
+ )
230
+ model = tf_keras.Model([query, value, key, mask_tensor], output)
231
+
232
+ masked_output_data = model.predict(
233
+ [from_data, to_data, to_data, mask_data]
234
+ )
235
+ unmasked_output_data = model.predict(
236
+ [from_data, to_data, to_data, null_mask_data]
237
+ )
238
+ # Because one data is masked and one is not, the outputs should not be
239
+ # the same.
240
+ self.assertNotAllClose(masked_output_data, unmasked_output_data)
241
+
242
+ self.assertLen(test_layer._query_dense.trainable_variables, 2)
243
+ self.assertLen(test_layer._output_dense.trainable_variables, 2)
244
+
115
245
  def test_masked_attention_with_scores(self):
116
246
  """Test with a mask tensor."""
117
247
  test_layer = block_sparse_attention.MultiHeadAttention(
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@ class GatedFeedforward(tf_keras.layers.Layer):
37
37
  dropout: Dropout probability for the output dropout.
38
38
  use_gate: Whether to use gated linear units. If True, assuming `GELU` as the
39
39
  activation and omitting bias, will apply
40
- `GEGLU(x, W, V, W_2) = (GEGLU(xW) * xV)W2`; if False, will follow
40
+ `GEGLU(x, W, V, W_2) = (GELU(xW) * xV)W2`; if False, will follow
41
41
  "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) paper and
42
42
  apply `FFN(x, W, W_2) = GELU(xW_1)W_2.`
43
43
  num_blocks: The number of feedforward blocks to stack. Each block contains a
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
16
16
  # pylint: disable=g-classes-have-attributes
17
17
 
18
18
  from typing import Tuple
19
- # Import libraries
20
19
  import tensorflow as tf, tf_keras
21
20
 
22
21
  from official.modeling import tf_utils
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,10 +18,13 @@ Based on https://arxiv.org/pdf/1911.02150.pdf and
18
18
  https://arxiv.org/pdf/2305.13245.pdf.
19
19
  """
20
20
 
21
+ import math
21
22
  import string
22
23
  from typing import Optional, Sequence, Union
23
24
 
25
+ import gin
24
26
  import tensorflow as tf, tf_keras
27
+ from official.modeling import tf_utils
25
28
 
26
29
  _CHR_IDX = string.ascii_lowercase
27
30
 
@@ -78,7 +81,9 @@ def _get_output_shape(
78
81
  class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
79
82
  """Multi-query attention layer."""
80
83
 
81
- def __init__(self, num_kv_heads=None, **kwargs):
84
+ def __init__(
85
+ self, num_kv_heads=None, enable_gqa_optimization=False, **kwargs
86
+ ):
82
87
  # num_kv_heads defines the number of key/value heads. A value of 1 means
83
88
  # that the key/value heads are shared across all query heads. Any other
84
89
  # value must be less than num_heads and must divide num_heads exactly. If
@@ -86,6 +91,16 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
86
91
  # num_kv_heads.
87
92
  super().__init__(**kwargs)
88
93
  self._num_kv_heads = num_kv_heads or self._num_heads
94
+ # TODO(akandoor): Remove this flag once the GQA optimization is rolled out.
95
+ # This flag is used to enable order of K,G in the einsum equations.
96
+ # This optimization is only used in GQA, and is disabled by default.
97
+ # If enabled, the einsum equations are:
98
+ # 1. Dot product: "...SKH,...TKGH->...KGTS"
99
+ # 2. Combine: "...KGTS,...SKH->...TKGH"
100
+ # If disabled, the einsum equations are:
101
+ # 1. Dot product: "...SKH,...TKnH->...nKTS"
102
+ # 2. Combine: "...nKTS,...SKH->...TnKH"
103
+ self._enable_gqa_optimization = enable_gqa_optimization
89
104
  assert (
90
105
  self._num_kv_heads < self._num_heads
91
106
  ), "num_kv_heads must be less than num_heads."
@@ -93,6 +108,11 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
93
108
  self._num_heads % self._num_kv_heads == 0
94
109
  ), "num_kv_heads needs to divide num_heads exactly."
95
110
 
111
+ def get_config(self):
112
+ config = super().get_config()
113
+ config.update({"num_kv_heads": self._num_kv_heads})
114
+ return config
115
+
96
116
  def _build_from_signature(
97
117
  self,
98
118
  query: Union[tf.Tensor, tf.TensorShape],
@@ -138,8 +158,12 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
138
158
  output_dims = 2
139
159
  key_last_dims = [self._num_kv_heads, self._key_dim]
140
160
  value_last_dims = [self._num_kv_heads, self._value_dim]
141
- self._dot_product_equation = "...SKH,...TKnH->...nKTS"
142
- self._combine_equation = "...nKTS,...SKH->...TnKH"
161
+ if self._enable_gqa_optimization:
162
+ self._dot_product_equation = "...SKH,...TKGH->...KGTS"
163
+ self._combine_equation = "...KGTS,...SKH->...TKGH"
164
+ else:
165
+ self._dot_product_equation = "...SKH,...TKnH->...nKTS"
166
+ self._combine_equation = "...nKTS,...SKH->...TnKH"
143
167
 
144
168
  einsum_equation, bias_axes, output_rank = _build_proj_equation(
145
169
  free_dims=self._key_shape.rank - 1,
@@ -165,6 +189,9 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
165
189
  name="value",
166
190
  **self._get_common_kwargs_for_sublayer(),
167
191
  )
192
+ self._qkv_rank = (
193
+ output_rank if self._num_kv_heads > 1 else output_rank + 1
194
+ )
168
195
 
169
196
  def _compute_attention(
170
197
  self, query, key, value, attention_mask=None, training=None
@@ -206,3 +233,194 @@ class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
206
233
  ],
207
234
  )
208
235
  return attention_output, attention_scores
236
+
237
+
238
@tf_keras.utils.register_keras_serializable(package="Text")
@gin.configurable
class TalkingHeadsMultiQueryAttention(MultiHeadAttention):
  """Implements Talking-Heads Attention combined with multi/grouped-query attention.

  See https://arxiv.org/pdf/2003.02436 for more details.
  TODO(akandoor): Make num talking heads configurable. Currently, num talking
  heads is fixed to num query heads.

  This class subclasses the `MultiHeadAttention` layer defined in this module
  (the variant that supports `num_kv_heads`), reusing its `__init__` and
  `get_config`. It overrides `_build_from_signature` to add the two
  [num_heads, num_heads] talking-heads projection weights, and overrides
  `_compute_attention` to wrap the talking-heads computation in the reshapes
  required when key/value heads are shared across groups of query heads.
  """

  def _build_from_signature(
      self,
      query: Union[tf.Tensor, tf.TensorShape],
      value: Union[tf.Tensor, tf.TensorShape],
      key: Optional[Union[tf.Tensor, tf.TensorShape]] = None,
  ):
    """Builds the parent's layers and variables, then the talking-heads weights.

    Args:
      query: Query tensor or its shape.
      value: Value tensor or its shape.
      key: Optional key tensor or its shape.
    """
    # The parent builds the Q/K/V projections, the grouped einsum equations
    # and `self._qkv_rank`, all of which are required below.
    super()._build_from_signature(query=query, value=value, key=key)

    # Build an einsum that mixes the heads axis of [batch..., N, T..., S...]
    # scores with an [N, N] matrix: the heads letter is replaced by a fresh
    # letter taken just past the scores' own letters.
    num_batch_dims = self._qkv_rank - len(self._attention_axes) - 2
    attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
    scores_notation = _CHR_IDX[:attn_scores_rank]
    heads_letter = scores_notation[num_batch_dims]
    mixed_heads_letter = _CHR_IDX[attn_scores_rank]
    projection_notation = heads_letter + mixed_heads_letter
    projected_scores_notation = (
        scores_notation[:num_batch_dims]
        + mixed_heads_letter
        + scores_notation[num_batch_dims + 1:]
    )
    self._talking_heads_equation = "%s,%s->%s" % (
        scores_notation, projection_notation, projected_scores_notation)

    # Created in the same order and with the same names as before so that
    # existing checkpoints keep loading.
    with tf.init_scope():
      self._pre_softmax_weight = self._add_talking_heads_weight(
          "pre_softmax_weight")
      self._post_softmax_weight = self._add_talking_heads_weight(
          "post_softmax_weight")

  def _add_talking_heads_weight(self, name):
    """Creates one trainable [num_heads, num_heads] head-mixing matrix."""
    return self.add_weight(
        name,
        shape=(self._num_heads, self._num_heads),
        initializer=tf_utils.clone_initializer(self._kernel_initializer),
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)

  def _regroup_heads(self, scores, scores_shape):
    """Reshapes flat-head [B, N, T, S] scores into the grouped 5-D layout.

    The axis order must mirror what `_dot_product_equation` produced so that
    `_combine_equation` can consume it: [B, K, G, T, S] when the GQA
    optimization is enabled ("...KGTS"), otherwise [B, G, K, T, S]
    ("...nKTS"), where K = num_kv_heads and G = num_heads // num_kv_heads.

    Args:
      scores: Attention probabilities of shape [B, N, T, S].
      scores_shape: Dynamic shape of the grouped scores before flattening;
        only its first and last two entries are used.

    Returns:
      `scores` reshaped to the grouped 5-D layout.
    """
    num_groups = self._num_heads // self._num_kv_heads
    if self._enable_gqa_optimization:
      leading_head_dim, trailing_head_dim = self._num_kv_heads, num_groups
    else:
      leading_head_dim, trailing_head_dim = num_groups, self._num_kv_heads
    return tf.reshape(
        scores,
        [
            scores_shape[0],
            leading_head_dim,
            trailing_head_dim,
            scores_shape[-2],
            scores_shape[-1],
        ],
    )

  def _compute_attention(
      self, query, key, value, attention_mask=None, training=None
  ):
    """Applies talking-heads dot-product attention over grouped KV heads.

    Args:
      query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
      key: Projected key `Tensor` over the S source positions with
        `num_kv_heads` heads (see the parent's `_build_from_signature` for
        the exact layout).
      value: Projected value `Tensor` over the S source positions with
        `num_kv_heads` heads and `value_dim` features per head.
      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
        attention to certain positions.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed outputs of attention computation; when
        `num_kv_heads > 1` it is reshaped to `[B, T, N, value_dim]`.
      attention_scores: Multi-headed attention probabilities after the
        post-softmax talking-heads projection (before dropout), shaped
        `[B, N, T, S]`.
    """
    grouped = self._num_kv_heads > 1
    query_shape = tf.shape(query)
    if grouped:
      # Split the N query heads into [K, G] so each group of G query heads
      # attends against its shared key/value head.
      query = tf.reshape(
          query,
          [
              query_shape[0],
              query_shape[1],
              self._num_kv_heads,
              self._num_heads // self._num_kv_heads,
              query_shape[-1],
          ],
      )

    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    # Grouped layout: [B, K, G, T, S] with the GQA optimization enabled,
    # [B, G, K, T, S] otherwise.
    attention_scores = tf.einsum(self._dot_product_equation, key, query)

    # The talking-heads einsum mixes a single heads axis, so flatten the
    # grouped head axes into N before the projections.
    scores_shape = tf.shape(attention_scores)
    if grouped:
      attention_scores = tf.reshape(
          attention_scores,
          [
              scores_shape[0],
              self._num_heads,
              scores_shape[-2],
              scores_shape[-1],
          ],
      )

    # Linear projection across heads, softmax, then a second projection.
    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
                                 self._pre_softmax_weight)
    attention_scores = self._masked_softmax(attention_scores, attention_mask)
    attention_scores = tf.einsum(self._talking_heads_equation, attention_scores,
                                 self._post_softmax_weight)

    # `attention_scores` stays in the flat [B, N, T, S] layout for the return
    # value; only the combine einsum needs the grouped layout.
    scores_for_combine = (
        self._regroup_heads(attention_scores, scores_shape)
        if grouped else attention_scores)

    # This is actually dropping out entire tokens to attend to.
    attention_scores_dropout = self._dropout_layer(
        scores_for_combine, training=training)

    attention_output = tf.einsum(self._combine_equation,
                                 attention_scores_dropout, value)

    if grouped:
      # Merge the grouped head axes of the output back into N heads.
      attention_output = tf.reshape(
          attention_output,
          [
              query_shape[0],
              query_shape[1],
              self._num_heads,
              tf.shape(attention_output)[-1],
          ],
      )

    return attention_output, attention_scores
426
+