easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the registry's release page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,176 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+
8
+ if tf.__version__ >= '2.0':
9
+ tf = tf.compat.v1
10
+
11
+
12
class CapsuleLayer:
  """Capsule layer that routes a behaviour sequence into high-level capsules.

  Low-level sequence features are projected with a learned bilinear matrix
  and iteratively routed into at most ``max_k`` capsules of size
  ``high_dim``; see ``__call__``.
  """

  def __init__(self, capsule_config, is_training):
    """Store routing hyper-parameters.

    Args:
      capsule_config: config exposing max_seq_len, max_k, high_dim, num_iters,
        routing_logits_scale, routing_logits_stddev, squash_pow, scale_ratio
        and const_caps_num.
      is_training: if true, initial routing logits are sampled from a
        truncated normal on every call; otherwise a fixed seeded uniform draw
        (see ``_fixed_routing_logits``) is used so inference is deterministic.
    """
    # max_seq_len: max behaviour sequence length (history length)
    self._max_seq_len = capsule_config.max_seq_len
    # max_k: max number of high capsules
    self._max_k = capsule_config.max_k
    # high_dim: high capsule vector dimension
    self._high_dim = capsule_config.high_dim
    # number of routing (Expectation-Maximization) iterations
    self._num_iters = capsule_config.num_iters
    # multiplier for routing logits computed from l2-normalized features;
    # a value <= 0 switches to the unnormalized dot product instead
    self._routing_logits_scale = capsule_config.routing_logits_scale
    # stddev (training) / upper bound (inference) of the initial routing logits
    self._routing_logits_stddev = capsule_config.routing_logits_stddev
    # exponent used in squash()
    self._squash_pow = capsule_config.squash_pow
    # output scale used in squash()
    self._scale_ratio = capsule_config.scale_ratio
    # if true always use max_k capsules, else clip(int(log(seq_len)), 1, max_k)
    self._const_caps_num = capsule_config.const_caps_num
    self._is_training = is_training

  def squash(self, inputs):
    """Squash inputs over the last dimension.

    Each vector v is rescaled to
    ``(n / (1 + n)) ** squash_pow * scale_ratio * v / sqrt(n)`` where
    ``n = max(|v|^2, 1e-8)`` (the floor avoids division by zero).
    A histogram summary of the scale factor is also recorded.
    """
    input_norm = tf.reduce_sum(tf.square(inputs), keep_dims=True, axis=-1)
    input_norm_eps = tf.maximum(input_norm, 1e-8)
    scale_factor = tf.pow(input_norm_eps / (1 + input_norm_eps), self._squash_pow) * \
        self._scale_ratio / tf.sqrt(input_norm_eps)
    tf.summary.histogram('capsule/squash_scale_factor', scale_factor)
    return scale_factor * inputs

  def _build_capsule_simi(self, high_capsules, capsule_num):
    """Diagnostic: mean pairwise cosine similarity of the valid capsules.

    Args:
      high_capsules: rank-3 tensor [batch, n, dim]; for sample b only the
        first ``capsule_num[b]`` rows along axis 1 are considered.
      capsule_num: int tensor [batch] of valid row counts.

    Returns:
      Scalar tensor: over samples with more than one valid row, the mean of
      (average pairwise cosine similarity + 1) / 2. It is 0 / 0 (NaN) when
      no sample has more than one valid row.
    """
    high_capsule_mask = tf.sequence_mask(capsule_num,
                                         tf.shape(high_capsules)[1])
    high_capsules = high_capsules * tf.to_float(high_capsule_mask[:, :, None])
    high_capsules = tf.nn.l2_normalize(high_capsules, axis=-1)
    # (sum c)^2 - sum c^2 == sum over ordered pairs i != j of c_i * c_j
    sum_sqr = tf.square(tf.reduce_sum(high_capsules, axis=1))
    sqr_sum = tf.reduce_sum(tf.square(high_capsules), axis=1)
    simi = sum_sqr - sqr_sum

    div = tf.maximum(tf.to_float(capsule_num * (capsule_num - 1)), 1.0)
    simi = tf.reduce_sum(simi, axis=1) / div

    is_multi = tf.to_float(capsule_num > 1)
    avg_simi = tf.reduce_sum((simi + 1) * is_multi) / \
        (2.0 * tf.reduce_sum(is_multi))
    return avg_simi

  def _fixed_routing_logits(self):
    """Deterministic initial routing logits used when not training.

    A private ``RandomState`` seeded with 28 yields exactly the values that
    ``np.random.seed(28); np.random.uniform(...)`` would, without reseeding
    numpy's global generator as a side effect.

    Returns:
      numpy array of shape [max_seq_len, max_k], uniform in
      [0, routing_logits_stddev).
    """
    rng = np.random.RandomState(28)
    return rng.uniform(
        high=self._routing_logits_stddev,
        size=[self._max_seq_len, self._max_k])

  def __call__(self, seq_feas, seq_lens):
    """Capsule layer implementation.

    Args:
      seq_feas: tensor of shape batch_size x seq_len x low_fea_dim (bsd); the
        sequence axis is clipped or zero-padded to max_seq_len.
      seq_lens: int tensor of shape batch_size; clipped to max_seq_len.

    Returns:
      high_capsules: tensor of shape batch_size x max_k x high_dim, with the
        rows beyond num_high_capsules zeroed.
      num_high_capsules: int32 tensor of shape batch_size.

    Raises:
      ValueError: if num_iters is not positive (no routing iteration would
        run, so no capsules could be produced).
    """
    if self._num_iters <= 0:
      raise ValueError(
          'capsule num_iters must be greater than 0, got %s' % self._num_iters)

    # Pad or clip the sequence axis to exactly max_seq_len.
    seq_feas = tf.cond(
        tf.greater(tf.shape(seq_feas)[1], self._max_seq_len),
        lambda: seq_feas[:, :self._max_seq_len, :],
        lambda: tf.cond(
            tf.less(tf.shape(seq_feas)[1], self._max_seq_len),
            lambda: tf.pad(seq_feas, [[0, 0],
                                      [0, self._max_seq_len - tf.shape(seq_feas)[1]],
                                      [0, 0]]),
            lambda: seq_feas))
    seq_lens = tf.minimum(seq_lens, self._max_seq_len)

    batch_size = tf.shape(seq_lens)[0]
    # Initial routing logits: batch_size x max_seq_len x max_k (bsh).
    if self._is_training:
      routing_logits = tf.truncated_normal(
          [batch_size, self._max_seq_len, self._max_k],
          stddev=self._routing_logits_stddev)
    else:
      routing_logits = tf.constant(
          self._fixed_routing_logits(), dtype=tf.float32)
      routing_logits = tf.tile(routing_logits[None, :, :], [batch_size, 1, 1])
    routing_logits = tf.stop_gradient(routing_logits)

    low_fea_dim = seq_feas.get_shape()[-1]
    # Map low capsule features to high capsule space: low_fea_dim x high_dim.
    bilinear_matrix = tf.get_variable(
        dtype=tf.float32, shape=[low_fea_dim, self._high_dim], name='capsule/S')
    seq_feas_high = tf.tensordot(seq_feas, bilinear_matrix, axes=1)
    seq_feas_high_stop = tf.stop_gradient(seq_feas_high)
    seq_feas_high_norm = tf.nn.l2_normalize(seq_feas_high_stop, -1)

    if self._const_caps_num:
      logging.info('will use constant number of capsules: %d', self._max_k)
      num_high_capsules = tf.zeros_like(seq_lens, dtype=tf.int32) + self._max_k
    else:
      logging.info(
          'will use log(seq_len) number of capsules, max_capsules: %d',
          self._max_k)
      num_high_capsules = tf.maximum(
          1, tf.minimum(self._max_k,
                        tf.to_int32(tf.log(tf.to_float(seq_lens)))))

    # batch_size x max_seq_len (bs): 1.0 for valid sequence positions.
    mask = tf.cast(tf.sequence_mask(seq_lens, self._max_seq_len), tf.float32)
    # batch_size x max_k (bh): 1.0 for capsules in use.
    mask_cap = tf.cast(
        tf.sequence_mask(num_high_capsules, self._max_k), tf.float32)
    # batch_size x 1 x max_k (b1h): +1e32 for used capsules, -1e32 otherwise,
    # so tf.minimum pushes unused capsules to ~-inf before the softmax.
    max_cap_thresh = (mask_cap[:, None, :] * 2 - 1) * 1e32

    for iter_id in range(self._num_iters):
      # Soft assignment of each position to the capsules (bsh).
      routing_logits = tf.minimum(routing_logits, max_cap_thresh)
      routing_logits = tf.nn.softmax(routing_logits, axis=2)
      routing_logits = routing_logits * mask[:, :, None]

      logits_simi = self._build_capsule_simi(routing_logits, seq_lens)
      tf.summary.scalar('capsule/rlogits_simi_%d' % iter_id, logits_simi)

      seq_fea_simi = self._build_capsule_simi(seq_feas_high_stop, seq_lens)
      tf.summary.scalar('capsule/seq_fea_simi_%d' % iter_id, seq_fea_simi)

      is_last_iter = iter_id + 1 == self._num_iters
      # batch_size x max_k x high_dim (bse,bsh->bhe). Only the last iteration
      # aggregates the non-stopped features, so gradients reach the bilinear
      # matrix through this final aggregation alone.
      high_capsules = tf.einsum(
          'bse, bsh->bhe',
          seq_feas_high if is_last_iter else seq_feas_high_stop,
          routing_logits)

      # NOTE(review): the 'caspule/...' tags look like a typo for 'capsule';
      # they are kept unchanged so existing summary tags stay stable.
      if is_last_iter:
        capsule_simi = self._build_capsule_simi(high_capsules,
                                                num_high_capsules)
        tf.summary.scalar('caspule/simi_%d' % iter_id, capsule_simi)
        tf.summary.scalar('capsule/before_squash',
                          tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
        high_capsules = self.squash(high_capsules)
        tf.summary.scalar('capsule/after_squash',
                          tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
        capsule_simi_final = self._build_capsule_simi(high_capsules,
                                                      num_high_capsules)
        tf.summary.scalar('caspule/simi_final', capsule_simi_final)
        break

      # Intermediate iteration: normalize capsules and recompute the routing
      # logits from position/capsule agreement (bse, bhe->bsh).
      high_capsules = tf.nn.l2_normalize(high_capsules, -1)
      capsule_simi = self._build_capsule_simi(high_capsules, num_high_capsules)
      tf.summary.scalar('caspule/simi_%d' % iter_id, capsule_simi)
      if self._routing_logits_scale > 0:
        if iter_id == 0:
          logging.info('routing_logits_scale = %.2f',
                       self._routing_logits_scale)
        routing_logits = tf.einsum('bse, bhe->bsh', seq_feas_high_norm,
                                   high_capsules) * self._routing_logits_scale
      else:
        routing_logits = tf.einsum('bse, bhe->bsh', seq_feas_high_stop,
                                   high_capsules)

    # Zero out the unused capsule slots.
    high_capsule_mask = tf.sequence_mask(num_high_capsules, self._max_k)
    high_capsules = high_capsules * tf.to_float(high_capsule_mask[:, :, None])
    return high_capsules, num_high_capsules
@@ -0,0 +1,390 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.layers import multihead_cross_attention
7
+ from easy_rec.python.utils.shape_utils import get_shape_list
8
+
9
+ if tf.__version__ >= '2.0':
10
+ tf = tf.compat.v1
11
+
12
+
13
+ class CMBF(object):
14
+ """CMBF: Cross-Modal-Based Fusion Recommendation Algorithm.
15
+
16
+ This is almost an exact implementation of the original CMBF model.
17
+ See the original paper:
18
+ https://www.mdpi.com/1424-8220/21/16/5275
19
+ """
20
+
21
+ def __init__(self, model_config, feature_configs, features, cmbf_config,
22
+ input_layer):
23
+ self._model_config = cmbf_config
24
+
25
+ has_feature = False
26
+ self._img_features = None
27
+ if input_layer.has_group('image'):
28
+ self._img_features, _ = input_layer(features, 'image')
29
+ has_feature = True
30
+ self._general_features = None
31
+ if input_layer.has_group('general'):
32
+ self._general_features, _ = input_layer(features, 'general')
33
+ has_feature = True
34
+ self._txt_seq_features = None
35
+ if input_layer.has_group('text'):
36
+ self._txt_seq_features, _, _ = input_layer(
37
+ features, 'text', is_combine=False)
38
+ has_feature = True
39
+ self._other_features = None
40
+ if input_layer.has_group('other'): # e.g. statistical feature
41
+ self._other_features, _ = input_layer(features, 'other')
42
+ has_feature = True
43
+ assert has_feature, 'there must be one of the feature groups: [image, text, general, other]'
44
+
45
+ self._general_feature_num, self._img_feature_num = 0, 0
46
+ self._txt_feature_num = 0
47
+ general_feature_names, txt_seq_feature_names = set(), set()
48
+ img_feature_names = set()
49
+ for fea_group in model_config.feature_groups:
50
+ if fea_group.group_name == 'general':
51
+ self._general_feature_num = len(fea_group.feature_names)
52
+ general_feature_names = set(fea_group.feature_names)
53
+ assert self._general_feature_num == len(general_feature_names), (
54
+ 'there are duplicate features in `general` feature group')
55
+ elif fea_group.group_name == 'image':
56
+ self._img_feature_num = len(fea_group.feature_names)
57
+ img_feature_names = set(fea_group.feature_names)
58
+ assert self._img_feature_num == len(img_feature_names), (
59
+ 'there are duplicate features in `image` feature group')
60
+ elif fea_group.group_name == 'text':
61
+ txt_seq_feature_names = set(fea_group.feature_names)
62
+ self._txt_feature_num = len(fea_group.feature_names)
63
+ assert self._txt_feature_num == len(txt_seq_feature_names), (
64
+ 'there are duplicate features in `text` feature group')
65
+
66
+ max_seq_len = 0
67
+ txt_fea_emb_dim_list = []
68
+ general_emb_dim_list = []
69
+ img_fea_emb_dim_list = []
70
+ for feature_config in feature_configs:
71
+ fea_name = feature_config.input_names[0]
72
+ if feature_config.HasField('feature_name'):
73
+ fea_name = feature_config.feature_name
74
+ if fea_name in img_feature_names:
75
+ img_fea_emb_dim_list.append(feature_config.raw_input_dim)
76
+ if fea_name in general_feature_names:
77
+ general_emb_dim_list.append(feature_config.embedding_dim)
78
+ if fea_name in txt_seq_feature_names:
79
+ txt_fea_emb_dim_list.append(feature_config.embedding_dim)
80
+ if feature_config.HasField('max_seq_len'):
81
+ assert feature_config.max_seq_len > 0, (
82
+ 'feature config `max_seq_len` must be greater than 0 for feature: '
83
+ + fea_name)
84
+ if feature_config.max_seq_len > max_seq_len:
85
+ max_seq_len = feature_config.max_seq_len
86
+
87
+ unique_dim_num = len(set(txt_fea_emb_dim_list))
88
+ assert unique_dim_num <= 1 and len(
89
+ txt_fea_emb_dim_list
90
+ ) == self._txt_feature_num, (
91
+ 'CMBF requires that all `text` feature dimensions must be consistent.')
92
+ unique_dim_num = len(set(general_emb_dim_list))
93
+ assert unique_dim_num <= 1 and len(
94
+ general_emb_dim_list
95
+ ) == self._general_feature_num, (
96
+ 'CMBF requires that all `general` feature dimensions must be consistent.'
97
+ )
98
+ unique_dim_num = len(set(img_fea_emb_dim_list))
99
+ assert unique_dim_num <= 1 and len(
100
+ img_fea_emb_dim_list
101
+ ) == self._img_feature_num, (
102
+ 'CMBF requires that all `image` feature dimensions must be consistent.')
103
+
104
+ if cmbf_config.use_position_embeddings:
105
+ assert cmbf_config.max_position_embeddings > 0, (
106
+ 'model config `max_position_embeddings` must be greater than 0. '
107
+ 'It must be set when `use_position_embeddings` is true (default)')
108
+ assert cmbf_config.max_position_embeddings >= max_seq_len, (
109
+ 'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
110
+ '`max_seq_len`, which is %d' % max_seq_len)
111
+
112
+ self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
113
+ self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
114
+ self._general_emb_size = general_emb_dim_list[
115
+ 0] if general_emb_dim_list else 0
116
+ self._head_num = cmbf_config.multi_head_num
117
+ self._img_head_num = cmbf_config.image_multi_head_num
118
+ self._txt_head_num = cmbf_config.text_multi_head_num
119
+ self._txt_head_size = cmbf_config.text_head_size
120
+ self._img_head_size = cmbf_config.image_head_size
121
+ self._img_patch_num = cmbf_config.image_feature_patch_num
122
+ self._img_self_attention_layer_num = cmbf_config.image_self_attention_layer_num
123
+ self._txt_self_attention_layer_num = cmbf_config.text_self_attention_layer_num
124
+ self._cross_modal_layer_num = cmbf_config.cross_modal_layer_num
125
+ print('txt_feature_num: {0}, img_feature_num: {1}, txt_seq_feature_num: {2}'
126
+ .format(self._general_feature_num, self._img_feature_num,
127
+ len(self._txt_seq_features) if self._txt_seq_features else 0))
128
+ print('txt_embedding_size: {0}, img_embedding_size: {1}'.format(
129
+ self._txt_emb_size, self._img_emb_size))
130
+ if self._img_features is not None:
131
+ assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'
132
+
133
def image_self_attention_tower(self):
  """Build the image self-attention tower.

  The input of the image tower can be one of:

  1. multiple image embeddings, each corresponding to a patch, a ROI (region
     of interest), or a frame of a video (`image feature num > 1`);
  2. one big image embedding composed by stacking multiple patch embeddings
     (`image_feature_patch_num > 1`);
  3. one conventional image embedding extracted by an image model; it is
     reduced/expanded to `image_feature_dim` elements and each element is
     mapped to a `hidden_size` vector.

  Embeddings whose size differs from the tower's hidden size are linearly
  projected (dense layer `img_projection`) before single-modal learning.

  Returns:
    None when there are no image features. When
    `image_self_attention_layer_num <= 0`, the (projected) image features
    reshaped to [batch_size, image_feature_num, multi_head_num *
    image_cross_head_size]. Otherwise the output of the image transformer
    encoder, whose hidden size is `image_head_size * image_multi_head_num`.
  """
  if self._img_features is None:
    return None
  config = self._model_config
  img_fea_num = self._img_feature_num

  if self._img_self_attention_layer_num <= 0:
    # No image self-attention: only align the features with the hidden size
    # used by the cross-modal tower. Always reshape so the result is rank 3
    # even when no projection is needed.
    hidden_size = config.multi_head_num * config.image_cross_head_size
    return self._project_image_features(self._img_features,
                                        self._img_emb_size, img_fea_num,
                                        hidden_size)

  hidden_size = self._img_head_size * self._img_head_num
  image_features = self._img_features
  if img_fea_num > 1:  # video frames or ROIs (Region Of Interest)
    image_features = self._project_image_features(self._img_features,
                                                  self._img_emb_size,
                                                  img_fea_num, hidden_size)
  elif img_fea_num == 1:
    if self._img_patch_num > 1:
      # Image feature dimension is patch_num * per-patch embedding size.
      # Use a local size instead of overwriting `self._img_emb_size`, so that
      # building the tower twice does not divide the size twice.
      patch_num = self._img_patch_num
      patch_emb_size = self._img_emb_size // patch_num
      assert patch_emb_size * patch_num == self._img_emb_size, (
          'image feature dimension must equal to '
          '`image_feature_patch_num * embedding_size_per_patch`')
      # Reshape to [batch, patch_num, hidden] even when no projection is
      # needed; previously the 2-D input leaked into the encoder.
      image_features = self._project_image_features(self._img_features,
                                                    patch_emb_size, patch_num,
                                                    hidden_size)
    else:
      img_fea_num = config.image_feature_dim
      if img_fea_num != self._img_emb_size:
        image_features = tf.layers.dense(
            image_features, img_fea_num, name='img_projection')
      # Convert each element of the image feature into a feature vector:
      # [batch, dim, 1] * [1, dim, hidden] -> [batch, dim, hidden].
      img_mapping_matrix = tf.get_variable(
          'img_map_matrix', [1, img_fea_num, hidden_size], dtype=tf.float32)
      image_features = tf.expand_dims(image_features, -1) * img_mapping_matrix

  # `hidden_size` is image_multi_head_num * image_head_size, so the image
  # tower must use `image_multi_head_num` heads (`multi_head_num` belongs to
  # the cross-modal fusion layer).
  return multihead_cross_attention.transformer_encoder(
      image_features,
      hidden_size=hidden_size,
      num_hidden_layers=self._img_self_attention_layer_num,
      num_attention_heads=self._img_head_num,
      intermediate_size=hidden_size * 4,
      hidden_dropout_prob=config.hidden_dropout_prob,
      attention_probs_dropout_prob=config.attention_probs_dropout_prob,
      name='image_self_attention')


@staticmethod
def _project_image_features(features, emb_size, seq_num, hidden_size):
  """Reshape image features to [batch, seq_num, hidden_size].

  A dense layer named `img_projection` is applied only when `emb_size`
  differs from `hidden_size`; otherwise the features are only reshaped.
  """
  features = tf.reshape(features, shape=[-1, emb_size])
  if emb_size != hidden_size:
    features = tf.layers.dense(features, hidden_size, name='img_projection')
  return tf.reshape(features, shape=[-1, seq_num, hidden_size])
208
+
209
def text_self_attention_tower(self):
  """Build the text self-attention tower.

  General (non-sequence) text features and every text sequence feature are
  projected to `text_head_size * text_multi_head_num` when their size
  differs, concatenated along the token axis and fed to a transformer
  encoder. Sequence features additionally go through
  `embedding_postprocessor` (token-type / position embeddings, dropout).

  Returns:
    A tuple `(txt_attention_fea, input_mask, input_masks)`:
      txt_attention_fea: output of the text transformer encoder.
      input_mask: int32 mask concatenated over all text parts, or None when
        there are no sequence features.
      input_masks: list of per-part int32 masks; the general-feature mask
        first (if any), then one mask per sequence feature.
    `(None, None, None)` when there are no text features at all.
  """
  import logging

  config = self._model_config
  hidden_size = self._txt_head_size * self._txt_head_num
  token_embeddings = []
  input_masks = []

  if self._general_features is not None:
    general_features = self._general_features
    if self._general_emb_size != hidden_size:
      # Run a linear projection of `hidden_size`
      general_features = tf.reshape(
          general_features, shape=[-1, self._general_emb_size])
      general_features = tf.layers.dense(
          general_features, hidden_size, name='txt_projection')
    general_features = tf.reshape(
        general_features, shape=[-1, self._general_feature_num, hidden_size])
    token_embeddings.append(general_features)
    # Every general feature is always a valid token.
    batch_size = tf.shape(general_features)[0]
    input_masks.append(
        tf.ones(
            shape=tf.stack([batch_size, self._general_feature_num]),
            dtype=tf.int32))

  if self._txt_seq_features is not None:
    token_type_vocab_size = len(self._txt_seq_features)
    for i, (seq_fea, seq_len) in enumerate(self._txt_seq_features):
      batch_size, max_seq_len, emb_size = get_shape_list(seq_fea, 3)
      if emb_size != hidden_size:
        seq_fea = tf.reshape(seq_fea, shape=[-1, emb_size])
        seq_fea = tf.layers.dense(
            seq_fea, hidden_size, name='txt_seq_projection_%d' % i)
        seq_fea = tf.reshape(seq_fea, shape=[-1, max_seq_len, hidden_size])

      seq_fea = multihead_cross_attention.embedding_postprocessor(
          seq_fea,
          use_token_type=config.use_token_type,
          token_type_ids=tf.ones(
              shape=tf.stack([batch_size, max_seq_len]), dtype=tf.int32) * i,
          token_type_vocab_size=token_type_vocab_size,
          reuse_token_type=tf.AUTO_REUSE,
          use_position_embeddings=config.use_position_embeddings,
          max_position_embeddings=config.max_position_embeddings,
          position_embedding_name='position_embeddings_%d' % i,
          dropout_prob=config.text_seq_emb_dropout_prob)
      token_embeddings.append(seq_fea)
      # 1 for the first `seq_len` positions of each row, 0 for padding;
      # vectorized replacement of a per-row tf.map_fn.
      input_masks.append(
          tf.sequence_mask(
              tf.cast(seq_len, tf.int32), maxlen=max_seq_len,
              dtype=tf.int32))

  if not token_embeddings:
    return None, None, None

  if self._txt_seq_features is None:
    # Only general features: every token is valid, no attention mask needed.
    txt_features = token_embeddings[0]
    input_mask = None
    attention_mask = None
  else:
    txt_features = tf.concat(token_embeddings, axis=1)
    input_mask = tf.concat(input_masks, axis=1)
    attention_mask = (
        multihead_cross_attention.create_attention_mask_from_input_mask(
            from_tensor=txt_features, to_mask=input_mask))

  # `hidden_size` is text_multi_head_num * text_head_size, so the text tower
  # must use `text_multi_head_num` heads (`multi_head_num` belongs to the
  # cross-modal fusion layer).
  txt_attention_fea = multihead_cross_attention.transformer_encoder(
      txt_features,
      hidden_size=hidden_size,
      num_hidden_layers=self._txt_self_attention_layer_num,
      num_attention_heads=self._txt_head_num,
      attention_mask=attention_mask,
      intermediate_size=hidden_size * 4,
      hidden_dropout_prob=config.hidden_dropout_prob,
      attention_probs_dropout_prob=config.attention_probs_dropout_prob,
      name='text_self_attention')
  logging.info('txt_attention_fea: %s', txt_attention_fea.shape)
  return txt_attention_fea, input_mask, input_masks
291
+
292
def merge_text_embedding(self, txt_embeddings, input_masks):
  """Flatten the text tower output into one vector per example.

  General-feature tokens are kept as-is. The tokens of each text sequence
  feature are average-pooled over their valid (mask == 1) positions.

  Args:
    txt_embeddings: rank-3 tensor [batch_size, txt_seq_length, hidden_size]
      whose token axis is laid out in the same order as `input_masks`.
    input_masks: list of int masks, one per text part (general-feature mask
      first when general features exist), as returned by
      `text_self_attention_tower`.

  Returns:
    A 2-D tensor [batch_size, (general_feature_num + txt_seq_num) *
    hidden_size].
  """
  shape = get_shape_list(txt_embeddings)
  batch_size, hidden_size = shape[0], shape[2]
  if self._txt_seq_features is None:
    return tf.reshape(txt_embeddings, shape=[-1, shape[1] * hidden_size])

  pooled_parts = []
  first_seq_mask = 0
  if self._general_feature_num > 0:
    pooled_parts.append(
        tf.slice(txt_embeddings, [0, 0, 0],
                 [batch_size, self._general_feature_num, hidden_size]))
    first_seq_mask = 1  # input_masks[0] belongs to the general features

  begin = self._general_feature_num
  for seq_mask in input_masks[first_seq_mask:]:
    size = tf.shape(seq_mask)[1]
    seq_emb = tf.slice(txt_embeddings, [0, begin, 0],
                       [batch_size, size, hidden_size])
    mask = tf.expand_dims(tf.cast(seq_mask, tf.float32), -1)
    # Masked average pooling -> [batch_size, 1, hidden_size].
    emb_sum = tf.reduce_sum(seq_emb * mask, axis=1, keepdims=True)
    count = tf.reduce_sum(mask, axis=1, keepdims=True)
    # An empty sequence (length 0) would give 0 / 0 = NaN; clamp the count so
    # such rows pool to zeros instead.
    pooled_parts.append(emb_sum / tf.maximum(count, 1.0))
    begin = begin + size

  num_tokens = self._general_feature_num + len(input_masks) - first_seq_mask
  merged = tf.concat(pooled_parts, axis=1)
  return tf.reshape(merged, shape=[-1, num_tokens * hidden_size])
327
+
328
def __call__(self, is_training, *args, **kwargs):
  """Build the CMBF layer and return the fused feature tensor.

  When `is_training` is falsy, all dropout probabilities read from the model
  config (hidden, attention-probs and text-sequence-embedding dropout) are
  set to 0 while the graph is built, then restored. The config object may
  be shared with other graphs built in the same process, so it must not
  stay mutated.

  Args:
    is_training: whether a training graph is being built.
    *args: ignored.
    **kwargs: optional `l2_reg` passed to the `other_feature_dnn`.

  Returns:
    Tensor concatenating, along the last axis, the pooled image embedding,
    the merged text embedding and the other features (whichever exist).

  Raises:
    ValueError: if there are no image, text or other features.
  """
  if is_training:
    return self._build_output(is_training, **kwargs)

  config = self._model_config
  dropout_fields = ('hidden_dropout_prob', 'attention_probs_dropout_prob',
                    'text_seq_emb_dropout_prob')
  saved = {field: getattr(config, field) for field in dropout_fields}
  for field in dropout_fields:
    setattr(config, field, 0.0)
  try:
    return self._build_output(is_training, **kwargs)
  finally:
    for field, value in saved.items():
      setattr(config, field, value)


def _build_output(self, is_training, **kwargs):
  """Run the single-modal towers, cross-modal fusion and output concat."""
  import logging

  config = self._model_config
  img_attention_fea = self.image_self_attention_tower()
  txt_attention_fea, input_mask, input_masks = (
      self.text_self_attention_tower())

  all_fea = []
  # Explicit identity checks: `None in [tensor, ...]` would go through
  # Tensor.__eq__.
  if img_attention_fea is not None and txt_attention_fea is not None:
    img_embeddings, txt_embeddings = (
        multihead_cross_attention.cross_attention_tower(
            img_attention_fea,
            txt_attention_fea,
            num_hidden_layers=self._cross_modal_layer_num,
            num_attention_heads=self._head_num,
            right_input_mask=input_mask,
            left_size_per_head=config.image_cross_head_size,
            left_intermediate_size=4 * config.image_cross_head_size *
            self._head_num,
            right_size_per_head=config.text_cross_head_size,
            right_intermediate_size=4 * config.text_cross_head_size *
            self._head_num,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob))
    logging.info('img_embeddings: %s', img_embeddings.shape)
    logging.info('txt_embeddings: %s', txt_embeddings.shape)
    all_fea = [
        # average over the image token axis
        tf.reduce_mean(img_embeddings, axis=1),
        self.merge_text_embedding(txt_embeddings, input_masks),
    ]
  elif img_attention_fea is not None:  # only has image tower
    all_fea = [tf.reduce_mean(img_attention_fea, axis=1)]
  elif txt_attention_fea is not None:  # only has text tower
    all_fea = [self.merge_text_embedding(txt_attention_fea, input_masks)]

  if self._other_features is not None:  # e.g. statistical features
    if config.HasField('other_feature_dnn'):
      other_dnn_layer = dnn.DNN(config.other_feature_dnn,
                                kwargs.get('l2_reg', 0), 'other_dnn',
                                is_training)
      all_fea.append(other_dnn_layer(self._other_features))
    else:
      all_fea.append(self._other_features)

  if not all_fea:
    raise ValueError(
        'CMBF requires at least one of image, text or other features.')
  return tf.concat(all_fea, axis=-1)