easy-cs-rec-custommodel 0.8.6 (easy_cs_rec_custommodel-0.8.6-py2.py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,749 @@ easy_rec/python/layers/multihead_cross_attention.py (new file)
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf

from easy_rec.python.compat.layers import layer_norm as tf_layer_norm
from easy_rec.python.utils.activation import gelu
from easy_rec.python.utils.shape_utils import get_shape_list

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)


def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
  return output


def attention_layer(from_tensor,
                    to_tensor,
                    size_per_head,
                    num_attention_heads=1,
                    attention_mask=None,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None,
                    reuse=None):
  """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on
  "Attention Is All You Need". If `from_tensor` and `to_tensor` are the same,
  then this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.
  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].
  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.
  In practice, the multi-headed attention is done with transposes and
  reshapes rather than actual separate tensors.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    size_per_head: int. Size of each attention head.
    num_attention_heads: int. Number of attention heads.
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
      * from_seq_length, num_attention_heads * size_per_head]. If False, the
      output will be of shape [batch_size, from_seq_length, num_attention_heads
      * size_per_head].
    batch_size: (Optional) int. If the input is 2D, this might be the batch
      size of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq
      length of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.
    reuse: whether to reuse this layer

  Returns:
    float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
      true, this will be of shape [batch_size * from_seq_length,
      num_attention_heads * size_per_head]).

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """

  def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                           seq_length, width):
    output_tensor = tf.reshape(
        input_tensor, [batch_size, seq_length, num_attention_heads, width])

    output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
    return output_tensor

  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

  if len(from_shape) != len(to_shape):
    raise ValueError(
        'The rank of `from_tensor` must match the rank of `to_tensor`.')

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or
        to_seq_length is None):
      raise ValueError(
          'When passing in rank 2 tensors to attention_layer, the values '
          'for `batch_size`, `from_seq_length`, and `to_seq_length` '
          'must all be specified.')

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  from_tensor_2d = reshape_to_matrix(from_tensor)
  to_tensor_2d = reshape_to_matrix(to_tensor)

  # `query_layer` = [B*F, N*H]
  query_layer = tf.layers.dense(
      from_tensor_2d,
      num_attention_heads * size_per_head,
      activation=query_act,
      name='query',
      kernel_initializer=create_initializer(initializer_range),
      reuse=reuse)

  # `key_layer` = [B*T, N*H]
  key_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=key_act,
      name='key',
      kernel_initializer=create_initializer(initializer_range),
      reuse=reuse)

  # `value_layer` = [B*T, N*H]
  value_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=value_act,
      name='value',
      kernel_initializer=create_initializer(initializer_range),
      reuse=reuse)

  # `query_layer` = [B, N, F, H]
  query_layer = transpose_for_scores(query_layer, batch_size,
                                     num_attention_heads, from_seq_length,
                                     size_per_head)

  # `key_layer` = [B, N, T, H]
  key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                   to_seq_length, size_per_head)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  # `attention_scores` = [B, N, F, T]
  attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(size_per_head)))

  if attention_mask is not None:
    # `attention_mask` = [B, 1, F, T]
    attention_mask = tf.expand_dims(attention_mask, axis=[1])

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    attention_scores += adder

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_scores)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

  # `value_layer` = [B, T, N, H]
  value_layer = tf.reshape(
      value_layer,
      [batch_size, to_seq_length, num_attention_heads, size_per_head])

  # `value_layer` = [B, N, T, H]
  value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

  # `context_layer` = [B, N, F, H]
  context_layer = tf.matmul(attention_probs, value_layer)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

  if do_return_2d_tensor:
    # `context_layer` = [B*F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size * from_seq_length, num_attention_heads * size_per_head])
  else:
    # `context_layer` = [B, F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size, from_seq_length, num_attention_heads * size_per_head])

  return context_layer
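

# Illustrative usage sketch (not part of the original package source): a
# minimal self-attention call on a toy tensor. The shapes, the scope name and
# the helper name `_attention_layer_example` are assumptions for illustration.
def _attention_layer_example():
  """Builds a tiny self-attention graph; evaluate inside a tf.Session."""
  with tf.variable_scope('attention_layer_example'):
    # [batch_size=2, from_seq_length=5, from_width=64]
    inputs = tf.random_normal([2, 5, 64])
    # `from_tensor` == `to_tensor`, so this is self-attention.
    context = attention_layer(
        from_tensor=inputs,
        to_tensor=inputs,
        size_per_head=16,
        num_attention_heads=4)
    # `do_return_2d_tensor` defaults to False, so `context` is [2, 5, 4 * 16].
    return context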


def transformer_encoder(input_tensor,
                        attention_mask=None,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        intermediate_act_fn=gelu,
                        hidden_dropout_prob=0.1,
                        attention_probs_dropout_prob=0.1,
                        initializer_range=0.02,
                        reuse=None,
                        name='transformer'):
  """Multi-headed, multi-layer Transformer from "Attention is All You Need".

  This is almost an exact implementation of the original Transformer encoder.
  See the original paper:
  https://arxiv.org/abs/1706.03762

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    reuse: whether to reuse this encoder
    name: scope name prefix

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
      hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        'The hidden size (%d) is not a multiple of the number of attention '
        'heads (%d)' % (hidden_size, num_attention_heads))

  attention_head_size = int(hidden_size / num_attention_heads)
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  input_width = input_shape[2]

  # The Transformer performs sum residuals on all layers so the input needs
  # to be the same as the hidden size.
  if input_width != hidden_size:
    raise ValueError('The width of the input tensor (%d) != hidden size (%d)' %
                     (input_width, hidden_size))

  # We keep the representation as a 2D tensor to avoid re-shaping it back and
  # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
  # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
  # help the optimizer.
  prev_output = reshape_to_matrix(input_tensor)

  for layer_idx in range(num_hidden_layers):
    with tf.variable_scope('%s_layer_%d' % (name, layer_idx)):
      layer_input = prev_output

      with tf.variable_scope('attention'):
        with tf.variable_scope('self'):
          # [batch_size * from_seq_length, num_attention_heads * size_per_head]
          attention_output = attention_layer(
              from_tensor=layer_input,
              to_tensor=layer_input,
              size_per_head=attention_head_size,
              num_attention_heads=num_attention_heads,
              attention_mask=attention_mask,
              attention_probs_dropout_prob=attention_probs_dropout_prob,
              initializer_range=initializer_range,
              do_return_2d_tensor=True,
              batch_size=batch_size,
              from_seq_length=seq_length,
              to_seq_length=seq_length,
              reuse=reuse)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
        with tf.variable_scope('output', reuse=reuse):
          attention_output = tf.layers.dense(
              attention_output,
              hidden_size,
              kernel_initializer=create_initializer(initializer_range))
          attention_output = dropout(attention_output, hidden_dropout_prob)
          attention_output = layer_norm(attention_output + layer_input)

      # The activation is only applied to the "intermediate" hidden layer.
      with tf.variable_scope('intermediate', reuse=reuse):
        intermediate_output = tf.layers.dense(
            attention_output,
            intermediate_size,
            activation=intermediate_act_fn,
            kernel_initializer=create_initializer(initializer_range))

      # Down-project back to `hidden_size` then add the residual.
      with tf.variable_scope('output', reuse=reuse):
        layer_output = tf.layers.dense(
            intermediate_output,
            hidden_size,
            kernel_initializer=create_initializer(initializer_range))
        layer_output = dropout(layer_output, hidden_dropout_prob)
        layer_output = layer_norm(layer_output + attention_output)
        prev_output = layer_output

  final_output = reshape_from_matrix(prev_output, input_shape)
  return final_output
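

# Illustrative usage sketch (not part of the original package source): encoding
# a padded sequence with a small Transformer. Shapes, sizes and the helper name
# are assumptions; `hidden_size` must equal the input width and be divisible by
# `num_attention_heads`.
def _transformer_encoder_example():
  """Runs a 2-layer encoder over a [4, 10, 128] input; evaluate in a Session."""
  with tf.variable_scope('transformer_encoder_example'):
    token_embeddings = tf.random_normal([4, 10, 128])  # [batch, seq_len, hidden]
    input_mask = tf.ones([4, 10], dtype=tf.int32)  # 1 = real token, 0 = padding
    attention_mask = create_attention_mask_from_input_mask(
        token_embeddings, input_mask)
    encoded = transformer_encoder(
        token_embeddings,
        attention_mask=attention_mask,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=8,
        intermediate_size=256)
    return encoded  # [4, 10, 128]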


def cross_attention_block(from_tensor,
                          to_tensor,
                          layer_idx,
                          size_per_head,
                          cross_attention_mask=None,
                          self_attention_mask=None,
                          num_attention_heads=1,
                          intermediate_size=512,
                          hidden_dropout_prob=0.1,
                          attention_probs_dropout_prob=0.1,
                          initializer_range=0.02,
                          name=''):
  """Multi-headed cross attention block.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    layer_idx: int. layer id in the Transformer.
    size_per_head: int. Size of each attention head.
    cross_attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length], with 1 for positions that can be
      attended to and 0 in positions that should not be.
    self_attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, from_seq_length], with 1 for positions that can be
      attended to and 0 in positions that should not be.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    name: scope name prefix

  Returns:
    float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads * size_per_head], the output of the cross attention
      block.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  input_shape = get_shape_list(from_tensor, expected_rank=3)
  batch_size = input_shape[0]
  from_seq_length = input_shape[1]

  input_shape = get_shape_list(to_tensor, expected_rank=3)
  to_seq_length = input_shape[1]

  with tf.variable_scope('%scross_layer_%d' % (name, layer_idx)):
    with tf.variable_scope('attention'):
      with tf.variable_scope('cross'):
        # [batch_size * from_seq_length, num_attention_heads * size_per_head]
        cross_attention_output = attention_layer(
            from_tensor=from_tensor,
            to_tensor=to_tensor,
            size_per_head=size_per_head,
            num_attention_heads=num_attention_heads,
            attention_mask=cross_attention_mask,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            initializer_range=initializer_range,
            do_return_2d_tensor=True,
            batch_size=batch_size,
            from_seq_length=from_seq_length,
            to_seq_length=to_seq_length)

      with tf.variable_scope('self'):
        # [batch_size * from_seq_length, num_attention_heads * size_per_head]
        self_attention_output = attention_layer(
            from_tensor=cross_attention_output,
            to_tensor=cross_attention_output,
            size_per_head=size_per_head,
            num_attention_heads=num_attention_heads,
            attention_mask=self_attention_mask,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            initializer_range=initializer_range,
            do_return_2d_tensor=True,
            batch_size=batch_size,
            from_seq_length=from_seq_length,
            to_seq_length=from_seq_length)

      with tf.variable_scope('output'):
        attention_output = dropout(self_attention_output, hidden_dropout_prob)
        attention_output = layer_norm(attention_output + cross_attention_output)

    # The activation is only applied to the "intermediate" hidden layer.
    with tf.variable_scope('intermediate'):
      intermediate_output = tf.layers.dense(
          attention_output,
          intermediate_size,
          activation=tf.nn.relu,
          kernel_initializer=create_initializer(initializer_range))

    # Down-project back to `hidden_size` then add the residual.
    with tf.variable_scope('output'):
      layer_output = tf.layers.dense(
          intermediate_output,
          num_attention_heads * size_per_head,
          kernel_initializer=create_initializer(initializer_range))
      layer_output = dropout(layer_output, hidden_dropout_prob)
      # [batch_size * from_seq_length, num_attention_heads * size_per_head]
      layer_output = layer_norm(layer_output + attention_output)

  final_output = reshape_from_matrix(
      layer_output,
      [batch_size, from_seq_length, num_attention_heads * size_per_head])
  return final_output  # [batch_size, from_seq_length, num_attention_heads * size_per_head]


def cross_attention_tower(left_tensor,
                          right_tensor,
                          num_hidden_layers=1,
                          num_attention_heads=12,
                          left_size_per_head=64,
                          right_size_per_head=64,
                          left_intermediate_size=0,
                          right_intermediate_size=0,
                          left_input_mask=None,
                          right_input_mask=None,
                          hidden_dropout_prob=0.1,
                          attention_probs_dropout_prob=0.1,
                          initializer_range=0.02,
                          name=''):
  """Multi-headed, multi-layer cross attention block.

  Args:
    left_tensor: float Tensor of shape [batch_size, left_seq_length,
      from_width].
    right_tensor: float Tensor of shape [batch_size, right_seq_length,
      to_width].
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    left_size_per_head: int. Size of each attention head of left tower.
    right_size_per_head: int. Size of each attention head of right tower.
    left_intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer of left tower. Less than or equal to 0 means
      `num_attention_heads * left_size_per_head`.
    right_intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer of right tower. Less than or equal to 0 means
      `num_attention_heads * right_size_per_head`.
    left_input_mask: the mask for `left_tensor`
    right_input_mask: the mask for `right_tensor`
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    name: scope name prefix

  Returns:
    tuple of float Tensors of shape ([batch_size, left_seq_length, hidden_size],
      [batch_size, right_seq_length, hidden_size]),
      where hidden_size = num_attention_heads * size_per_head

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if left_intermediate_size <= 0:
    left_intermediate_size = num_attention_heads * left_size_per_head
  if right_intermediate_size <= 0:
    right_intermediate_size = num_attention_heads * right_size_per_head

  left_attention_mask = None
  if left_input_mask is not None:
    left_attention_mask = create_attention_mask_from_input_mask(
        left_tensor, left_input_mask)

  left_2_right_attention_mask = None
  if right_input_mask is not None:
    left_2_right_attention_mask = create_attention_mask_from_input_mask(
        left_tensor, right_input_mask)

  right_attention_mask = None
  if right_input_mask is not None:
    right_attention_mask = create_attention_mask_from_input_mask(
        right_tensor, right_input_mask)

  right_2_left_attention_mask = None
  if left_input_mask is not None:
    right_2_left_attention_mask = create_attention_mask_from_input_mask(
        right_tensor, left_input_mask)

  prev_left_output = left_tensor
  prev_right_output = right_tensor
  for layer_idx in range(num_hidden_layers):
    left_output = cross_attention_block(
        prev_left_output,
        prev_right_output,
        layer_idx,
        num_attention_heads=num_attention_heads,
        size_per_head=left_size_per_head,
        intermediate_size=left_intermediate_size,
        hidden_dropout_prob=hidden_dropout_prob,
        cross_attention_mask=left_2_right_attention_mask,
        self_attention_mask=left_attention_mask,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
        initializer_range=initializer_range,
        name='%sleft_to_right_' % name)
    right_output = cross_attention_block(
        prev_right_output,
        prev_left_output,
        layer_idx,
        num_attention_heads=num_attention_heads,
        size_per_head=right_size_per_head,
        intermediate_size=right_intermediate_size,
        hidden_dropout_prob=hidden_dropout_prob,
        cross_attention_mask=right_2_left_attention_mask,
        self_attention_mask=right_attention_mask,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
        initializer_range=initializer_range,
        name='%sright_to_left_' % name)
    prev_left_output = left_output
    prev_right_output = right_output
  return prev_left_output, prev_right_output
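

# Illustrative usage sketch (not part of the original package source): one pass
# of bidirectional cross attention between two towers, e.g. a text sequence and
# an item-feature sequence; each layer applies cross_attention_block in both
# directions. Shapes and the helper name are assumptions for illustration.
def _cross_attention_tower_example():
  """Cross-attends a [2, 8, 96] and a [2, 4, 96] sequence; run in a Session."""
  with tf.variable_scope('cross_attention_tower_example'):
    left = tf.random_normal([2, 8, 96])  # [batch, left_seq_length, width]
    right = tf.random_normal([2, 4, 96])  # [batch, right_seq_length, width]
    left_out, right_out = cross_attention_tower(
        left,
        right,
        num_hidden_layers=1,
        num_attention_heads=4,
        left_size_per_head=24,
        right_size_per_head=24)
    # left_out: [2, 8, 4 * 24], right_out: [2, 4, 4 * 24]
    return left_out, right_out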


def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return tf_layer_norm(
      inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)


def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
  ndims = input_tensor.shape.ndims
  if ndims < 2:
    raise ValueError('Input tensor must have at least rank 2. Shape = %s' %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  width = input_tensor.shape[-1]
  output_tensor = tf.reshape(input_tensor, [-1, width])
  return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
  """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
  if len(orig_shape_list) == 2:
    return output_tensor

  output_shape = get_shape_list(output_tensor)

  orig_dims = orig_shape_list[0:-1]
  width = output_shape[-1]

  return tf.reshape(output_tensor, orig_dims + [width])


def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens), so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=tf.stack([batch_size, from_seq_length, 1]), dtype=tf.float32)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask
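

# Illustrative usage sketch (not part of the original package source): turning
# a per-token padding mask of shape [batch_size, to_seq_length] into the
# [batch_size, from_seq_length, to_seq_length] mask expected by
# attention_layer. Values and the helper name are assumptions.
def _attention_mask_example():
  """Broadcasts a 2D padding mask across the `from` dimension."""
  from_tensor = tf.zeros([2, 3, 8])  # [batch_size, from_seq_length, width]
  to_mask = tf.constant([[1, 1, 0, 0], [1, 1, 1, 1]])  # [batch_size, to_seq_length]
  mask = create_attention_mask_from_input_mask(from_tensor, to_mask)
  return mask  # float32 Tensor of shape [2, 3, 4]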


def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name='token_type_embeddings',
                            reuse_token_type=None,
                            use_position_embeddings=True,
                            position_embedding_name='position_embeddings',
                            reuse_position_embedding=None,
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
  """Performs various post-processing on a word embedding tensor.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
    reuse_token_type: bool. Whether to reuse token type embedding variable.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    reuse_position_embedding: bool. Whether to reuse position embedding
      variable.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output
      tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      raise ValueError('`token_type_ids` must be specified if '
                       '`use_token_type` is True.')
    with tf.variable_scope('token_type', reuse=reuse_token_type):
      token_type_table = tf.get_variable(
          name=token_type_embedding_name,
          shape=[token_type_vocab_size, width],
          initializer=create_initializer(initializer_range))
      # This vocab will be small so we always do one-hot here, since it is
      # always faster for a small vocabulary.
      flat_token_type_ids = tf.reshape(token_type_ids, [-1])
      one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
      token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
      token_type_embeddings = tf.reshape(token_type_embeddings,
                                         [batch_size, seq_length, width])
      output += token_type_embeddings

  if use_position_embeddings:
    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      with tf.variable_scope(
          'position_embedding', reuse=reuse_position_embedding):
        full_position_embeddings = tf.get_variable(
            name=position_embedding_name,
            shape=[max_position_embeddings, width],
            initializer=create_initializer(initializer_range))
        # Since the position embedding table is a learned variable, we create
        # it using a (long) sequence length `max_position_embeddings`. The
        # actual sequence length might be shorter than this, for faster
        # training of tasks that do not have long sequences.
        #
        # So `full_position_embeddings` is effectively an embedding table
        # for position [0, 1, 2, ..., max_position_embeddings-1], and the
        # current sequence has positions [0, 1, 2, ... seq_length-1], so we
        # can just perform a slice.
        position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                       [seq_length, -1])
        num_dims = len(output.shape.as_list())

        # Only the last two dimensions are relevant (`seq_length` and `width`),
        # so we broadcast among the first dimensions, which is typically just
        # the batch size.
        position_broadcast_shape = []
        for _ in range(num_dims - 2):
          position_broadcast_shape.append(1)
        position_broadcast_shape.extend([seq_length, width])
        position_embeddings = tf.reshape(position_embeddings,
                                         position_broadcast_shape)
        output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  output_tensor = layer_norm(input_tensor, name)
  output_tensor = dropout(output_tensor, dropout_prob)
  return output_tensor
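

# Illustrative usage sketch (not part of the original package source): adding
# token-type and position embeddings to a word-embedding tensor. Shapes, vocab
# sizes and the helper name are assumptions for illustration.
def _embedding_postprocessor_example():
  """Post-processes a [2, 6, 32] embedding tensor; evaluate in a Session."""
  with tf.variable_scope('embedding_postprocessor_example'):
    word_embeddings = tf.random_normal([2, 6, 32])  # [batch, seq_length, width]
    token_type_ids = tf.zeros([2, 6], dtype=tf.int32)  # a single segment
    output = embedding_postprocessor(
        word_embeddings,
        use_token_type=True,
        token_type_ids=token_type_ids,
        token_type_vocab_size=2,
        use_position_embeddings=True,
        max_position_embeddings=64,
        dropout_prob=0.1)
    return output  # same shape as `word_embeddings`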