easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,301 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.layers import multihead_cross_attention
7
+ from easy_rec.python.utils.activation import get_activation
8
+ from easy_rec.python.utils.shape_utils import get_shape_list
9
+
10
# Use the TF1-style API under TF2. `tf.__version__` is a string, so compare
# the integer major version rather than relying on lexicographic ordering
# (where e.g. '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
12
+
13
+
14
class Uniter(object):
  """UNITER: UNiversal Image-TExt Representation Learning.

  See the original paper:
  https://arxiv.org/abs/1909.11740

  The `image`, `general` and `text` feature groups are turned into a single
  token sequence, prefixed by a learned `[CLS]` embedding, and encoded by a
  transformer. The encoder output at the `[CLS]` position is the multi-modal
  feature. An optional `other` feature group is concatenated to it, either
  as-is or after `other_feature_dnn`.
  """

  def __init__(self, model_config, feature_configs, features, uniter_config,
               input_layer):
    """Collect the per-group features and validate the configuration.

    Args:
      model_config: model config; only its `feature_groups` are read here.
      feature_configs: iterable of feature configs, used to look up the
        embedding dims (`embedding_dim` / `raw_input_dim`) and `max_seq_len`.
      features: input features, passed through to `input_layer`.
      uniter_config: the uniter layer config (hidden size, dropout, ...).
      input_layer: object exposing `has_group(name)` and callable as
        `input_layer(features, group_name, ...)`.

    Raises:
      AssertionError: if none of the supported feature groups exists, a group
        lists a feature twice, the dims inside a group differ, or the
        position-embedding config does not cover the text `max_seq_len`.
    """
    self._model_config = uniter_config
    tower_num = 0
    self._img_features = None
    if input_layer.has_group('image'):
      self._img_features, _ = input_layer(features, 'image')
      tower_num += 1
    self._general_features = None
    if input_layer.has_group('general'):
      self._general_features, _ = input_layer(features, 'general')
      tower_num += 1
    self._txt_seq_features = None
    if input_layer.has_group('text'):
      self._txt_seq_features, _, _ = input_layer(
          features, 'text', is_combine=False)
      tower_num += 1
    # Decided before `other` is counted: `other` features never enter the
    # token sequence, so they never need a token type.
    self._use_token_type = tower_num > 1
    self._other_features = None
    if input_layer.has_group('other'):  # e.g. statistical feature
      self._other_features, _ = input_layer(features, 'other')
      tower_num += 1
    assert tower_num > 0, 'there must be one of the feature groups: [image, text, general, other]'

    self._general_feature_num = 0
    self._txt_feature_num, self._img_feature_num = 0, 0
    general_feature_names = set()
    img_feature_names, txt_feature_names = set(), set()
    for fea_group in model_config.feature_groups:
      if fea_group.group_name == 'general':
        self._general_feature_num = len(fea_group.feature_names)
        general_feature_names = set(fea_group.feature_names)
        assert self._general_feature_num == len(general_feature_names), (
            'there are duplicate features in `general` feature group')
      elif fea_group.group_name == 'image':
        self._img_feature_num = len(fea_group.feature_names)
        img_feature_names = set(fea_group.feature_names)
        assert self._img_feature_num == len(img_feature_names), (
            'there are duplicate features in `image` feature group')
      elif fea_group.group_name == 'text':
        self._txt_feature_num = len(fea_group.feature_names)
        txt_feature_names = set(fea_group.feature_names)
        assert self._txt_feature_num == len(txt_feature_names), (
            'there are duplicate features in `text` feature group')

    # Several features of one modality also need distinct token types.
    if self._txt_feature_num > 1 or self._img_feature_num > 1:
      self._use_token_type = True
    # One token type per text feature, plus one shared by all image features
    # and one shared by all general features.
    self._token_type_vocab_size = self._txt_feature_num
    if self._img_feature_num > 0:
      self._token_type_vocab_size += 1
    if self._general_feature_num > 0:
      self._token_type_vocab_size += 1

    max_seq_len = 0
    txt_fea_emb_dim_list = []
    general_emb_dim_list = []
    img_fea_emb_dim_list = []
    for feature_config in feature_configs:
      fea_name = feature_config.input_names[0]
      if feature_config.HasField('feature_name'):
        fea_name = feature_config.feature_name
      if fea_name in img_feature_names:
        # image features are sized by `raw_input_dim`, not `embedding_dim`
        img_fea_emb_dim_list.append(feature_config.raw_input_dim)
      if fea_name in general_feature_names:
        general_emb_dim_list.append(feature_config.embedding_dim)
      if fea_name in txt_feature_names:
        txt_fea_emb_dim_list.append(feature_config.embedding_dim)
        if feature_config.HasField('max_seq_len'):
          assert feature_config.max_seq_len > 0, (
              'feature config `max_seq_len` must be greater than 0 for feature: '
              + fea_name)
          max_seq_len = max(max_seq_len, feature_config.max_seq_len)

    self._check_consistent_dims(txt_fea_emb_dim_list, self._txt_feature_num,
                                'text')
    self._check_consistent_dims(img_fea_emb_dim_list, self._img_feature_num,
                                'image')
    self._check_consistent_dims(general_emb_dim_list,
                                self._general_feature_num, 'general')

    if self._txt_feature_num > 0 and uniter_config.use_position_embeddings:
      assert uniter_config.max_position_embeddings > 0, (
          'model config `max_position_embeddings` must be greater than 0. ')
      assert uniter_config.max_position_embeddings >= max_seq_len, (
          'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
          '`max_seq_len`, which is %d' % max_seq_len)

    self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
    self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
    self._general_emb_size = general_emb_dim_list[
        0] if general_emb_dim_list else 0
    if self._img_features is not None:
      assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'

  @staticmethod
  def _check_consistent_dims(dims, expected_num, group_name):
    """Assert a group's dims are all equal and one was found per feature."""
    assert len(set(dims)) <= 1 and len(dims) == expected_num, (
        'Uniter requires that all `%s` feature dimensions must be consistent.'
        % group_name)

  def text_embeddings(self, token_type_id, dropout_prob=None):
    """Embed the `general` features and every `text` sequence feature.

    Args:
      token_type_id: token type id of the `general` tokens. The i-th text
        feature gets `token_type_id + i`, shifted by one more when `general`
        tokens are present.
      dropout_prob: dropout probability forwarded to the embedding
        post-processor; `None` means `hidden_dropout_prob` from the config.

    Returns:
      A `(features, masks)` tuple of two equal-length lists, one entry per
      token group: the `general` group first (if any), then each text
      feature. Every mask is int32, 1 for valid tokens and 0 for padding.
    """
    if dropout_prob is None:
      dropout_prob = self._model_config.hidden_dropout_prob
    all_txt_features = []
    input_masks = []
    hidden_size = self._model_config.hidden_size
    if self._general_features is not None:
      general_features = self._general_features
      if self._general_emb_size != hidden_size:
        # Run a linear projection of `hidden_size`. The layer keeps its
        # historical name `txt_projection` so existing checkpoints restore.
        general_features = tf.reshape(
            general_features, shape=[-1, self._general_emb_size])
        general_features = tf.layers.dense(
            general_features, hidden_size, name='txt_projection')
      general_features = tf.reshape(
          general_features, shape=[-1, self._general_feature_num, hidden_size])

      batch_size = tf.shape(general_features)[0]
      general_features = multihead_cross_attention.embedding_postprocessor(
          general_features,
          use_token_type=self._use_token_type,
          token_type_ids=tf.ones(
              shape=tf.stack([batch_size, self._general_feature_num]),
              dtype=tf.int32) * token_type_id,
          token_type_vocab_size=self._token_type_vocab_size,
          reuse_token_type=tf.AUTO_REUSE,
          use_position_embeddings=False,
          dropout_prob=dropout_prob)

      all_txt_features.append(general_features)
      # every general token is always valid
      input_masks.append(
          tf.ones(
              shape=tf.stack([batch_size, self._general_feature_num]),
              dtype=tf.int32))

    if self._txt_seq_features is not None:
      # text token types start right after the general group (if any)
      token_type_id += len(all_txt_features)
      for i, (seq_fea, seq_len) in enumerate(self._txt_seq_features):
        batch_size, max_seq_len, emb_size = get_shape_list(seq_fea, 3)
        if emb_size != hidden_size:
          seq_fea = tf.reshape(seq_fea, shape=[-1, emb_size])
          seq_fea = tf.layers.dense(
              seq_fea, hidden_size, name='txt_seq_projection_%d' % i)
          seq_fea = tf.reshape(seq_fea, shape=[-1, max_seq_len, hidden_size])

        seq_fea = multihead_cross_attention.embedding_postprocessor(
            seq_fea,
            use_token_type=self._use_token_type,
            token_type_ids=tf.ones(
                shape=tf.stack([batch_size, max_seq_len]), dtype=tf.int32) *
            (i + token_type_id),
            token_type_vocab_size=self._token_type_vocab_size,
            reuse_token_type=tf.AUTO_REUSE,
            use_position_embeddings=self._model_config.use_position_embeddings,
            max_position_embeddings=self._model_config.max_position_embeddings,
            position_embedding_name='txt_position_embeddings_%d' % i,
            dropout_prob=dropout_prob)
        all_txt_features.append(seq_fea)

        # 1 for the first `seq_len` positions of each row, 0 for the padding
        # after them; a vectorized replacement for a per-row `tf.map_fn`.
        input_masks.append(
            tf.sequence_mask(
                tf.cast(seq_len, tf.int32), maxlen=max_seq_len,
                dtype=tf.int32))

    return all_txt_features, input_masks

  def image_embeddings(self, dropout_prob=None):
    """Project the image features to `hidden_size` and post-process them.

    Args:
      dropout_prob: dropout probability forwarded to the embedding
        post-processor; `None` means `hidden_dropout_prob` from the config.

    Returns:
      The post-processed image token embeddings (built from the input
      reshaped to `[batch, img_feature_num, hidden_size]`), or None when
      there is no `image` feature group.
    """
    if self._img_features is None:
      return None
    if dropout_prob is None:
      dropout_prob = self._model_config.hidden_dropout_prob
    hidden_size = self._model_config.hidden_size
    image_features = self._img_features
    if self._img_emb_size != hidden_size:
      # Run a linear projection of `hidden_size`
      image_features = tf.reshape(
          image_features, shape=[-1, self._img_emb_size])
      image_features = tf.layers.dense(
          image_features, hidden_size, name='img_projection')
    image_features = tf.reshape(
        image_features, shape=[-1, self._img_feature_num, hidden_size])

    batch_size = tf.shape(image_features)[0]
    img_fea = multihead_cross_attention.embedding_postprocessor(
        image_features,
        use_token_type=self._use_token_type,
        # every image token uses token type 0
        token_type_ids=tf.zeros(
            shape=tf.stack([batch_size, self._img_feature_num]),
            dtype=tf.int32),
        token_type_vocab_size=self._token_type_vocab_size,
        reuse_token_type=tf.AUTO_REUSE,
        use_position_embeddings=self._model_config.use_position_embeddings,
        max_position_embeddings=self._model_config.max_position_embeddings,
        position_embedding_name='img_position_embeddings',
        dropout_prob=dropout_prob)
    return img_fea

  def __call__(self, is_training, *args, **kwargs):
    """Build the UNITER output feature.

    Args:
      is_training: whether the graph is built for training; dropout is
        disabled otherwise.
      *args: unused.
      **kwargs: `l2_reg` is passed to `other_feature_dnn` (default 0).

    Returns:
      The transformer output at the `[CLS]` position concatenated on the last
      axis with the `other` features (raw or after `other_feature_dnn`). When
      only one of the two parts exists it is returned as-is.
    """
    # Disable dropout with local values instead of writing 0.0 into
    # `self._model_config`. That proto is supplied by the caller, so mutating
    # it would leak into later graph builds and any config serialized
    # afterwards.
    if is_training:
      hidden_dropout_prob = self._model_config.hidden_dropout_prob
      attention_probs_dropout_prob = (
          self._model_config.attention_probs_dropout_prob)
    else:
      hidden_dropout_prob = 0.0
      attention_probs_dropout_prob = 0.0

    sub_modules = []

    img_fea = self.image_embeddings(dropout_prob=hidden_dropout_prob)
    # image tokens use token type 0, so the next group starts at 1
    start_token_id = 1 if self._img_feature_num > 0 else 0
    txt_features, txt_masks = self.text_embeddings(
        start_token_id, dropout_prob=hidden_dropout_prob)

    if img_fea is not None:
      batch_size = tf.shape(img_fea)[0]
    elif txt_features:
      batch_size = tf.shape(txt_features[0])[0]
    else:
      batch_size = None

    hidden_size = self._model_config.hidden_size
    if batch_size is not None:
      all_features = []
      masks = []
      # learned [CLS] embedding, tiled across the batch
      cls_emb = tf.get_variable(name='cls_emb', shape=[1, 1, hidden_size])
      cls_emb = tf.tile(cls_emb, [batch_size, 1, 1])
      all_features.append(cls_emb)
      masks.append(tf.ones(shape=tf.stack([batch_size, 1]), dtype=tf.int32))

      if img_fea is not None:
        all_features.append(img_fea)
        masks.append(
            tf.ones(
                shape=tf.stack([batch_size, self._img_feature_num]),
                dtype=tf.int32))

      if txt_features:
        all_features.extend(txt_features)
        masks.extend(txt_masks)

      all_fea = tf.concat(all_features, axis=1)
      input_mask = tf.concat(masks, axis=1)
      attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
          from_tensor=all_fea, to_mask=input_mask)
      hidden_act = get_activation(self._model_config.hidden_act)
      attention_fea = multihead_cross_attention.transformer_encoder(
          all_fea,
          hidden_size=hidden_size,
          num_hidden_layers=self._model_config.num_hidden_layers,
          num_attention_heads=self._model_config.num_attention_heads,
          attention_mask=attention_mask,
          intermediate_size=self._model_config.intermediate_size,
          intermediate_act_fn=hidden_act,
          hidden_dropout_prob=hidden_dropout_prob,
          attention_probs_dropout_prob=attention_probs_dropout_prob,
          initializer_range=self._model_config.initializer_range,
          name='uniter')  # shape: [batch_size, seq_length, hidden_size]
      tf.logging.info('attention_fea: %s', attention_fea.shape)
      mm_fea = attention_fea[:, 0, :]  # [CLS] feature
      sub_modules.append(mm_fea)

    if self._other_features is not None:
      if self._model_config.HasField('other_feature_dnn'):
        l2_reg = kwargs.get('l2_reg', 0)
        other_dnn_layer = dnn.DNN(self._model_config.other_feature_dnn, l2_reg,
                                  'other_dnn', is_training)
        other_fea = other_dnn_layer(self._other_features)
      else:
        other_fea = self._other_features
      sub_modules.append(other_fea)

    if len(sub_modules) == 1:
      return sub_modules[0]
    return tf.concat(sub_modules, axis=-1)
@@ -0,0 +1,248 @@
1
+ # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Common util functions used by layers."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import json
21
+
22
+ from google.protobuf import struct_pb2
23
+ from google.protobuf.descriptor import FieldDescriptor
24
+ from tensorflow.python.framework import ops
25
+ from tensorflow.python.framework import sparse_tensor
26
+ from tensorflow.python.ops import variables
27
+
28
+ try:
29
+ from tensorflow.python.ops import kv_variable_ops
30
+ except ImportError:
31
+ kv_variable_ops = None
32
+
33
# Position cache shared by _process_item and unique_name_in_collection.
# Keys are '<id(collection list)>#<item name>' (built by _collection_item_key);
# values are the index of that item's JSON string inside the collection list.
ColumnNameInCollection = dict()
34
+
35
+
36
+ def _tensor_to_map(tensor):
37
+ return {
38
+ 'node_path': tensor.name,
39
+ 'shape': tensor.shape.as_list() if tensor.shape else None,
40
+ 'dtype': tensor.dtype.name
41
+ }
42
+
43
+
44
def _tensor_to_tensorinfo(tensor):
  """Build the tensor-info dict stored in collection items.

  A SparseTensor yields {'is_dense': False, 'values', 'indices',
  'dense_shape'}, each component described by _tensor_to_map. Any other
  tensor yields {'is_dense': True} merged with _tensor_to_map(tensor).
  """
  if not isinstance(tensor, sparse_tensor.SparseTensor):
    dense_info = {'is_dense': True}
    dense_info.update(_tensor_to_map(tensor))
    return dense_info
  return {
      'is_dense': False,
      'values': _tensor_to_map(tensor.values),
      'indices': _tensor_to_map(tensor.indices),
      'dense_shape': _tensor_to_map(tensor.dense_shape),
  }
55
+
56
+
57
def add_tensor_to_collection(collection_name, name, tensor):
  """Store `tensor`'s info as the collection item called `name`.

  The tensor info (see _tensor_to_tensorinfo) plus a 'name' key is merged
  into the item via update_attr_to_collection.
  """
  item_attrs = dict(_tensor_to_tensorinfo(tensor), name=name)
  update_attr_to_collection(collection_name, item_attrs)
61
+
62
+
63
def append_tensor_to_collection(collection_name, name, key, tensor):
  """Append `tensor`'s info to the list attribute `key` of item `name`."""
  append_attr_to_collection(collection_name, name, key,
                            _tensor_to_tensorinfo(tensor))
66
+
67
+
68
+ def _collection_item_key(col, name):
69
+ return '%d#%s' % (id(col), name)
70
+
71
+
72
def _process_item(collection_name, name, func):
  """Create or update the JSON item called `name` in a graph collection.

  Items are JSON strings kept in the list returned by
  `ops.get_collection_ref(collection_name)`. The index of each item is cached
  in ColumnNameInCollection under `_collection_item_key(col, name)` so later
  calls can rewrite the item in place.

  Args:
    collection_name: graph collection that holds the JSON items.
    name: item name; always stored under the item's 'name' key.
    func: callable that mutates the decoded item dict in place.
  """
  col = ops.get_collection_ref(collection_name)
  # add id(col) because col may re-new sometimes
  key = _collection_item_key(col, name)

  idx_found = ColumnNameInCollection.get(key)
  if idx_found is not None and idx_found < len(col):
    item_found = json.loads(col[idx_found])
    if item_found.get('name') == name:
      func(item_found)
      col[idx_found] = json.dumps(item_found)
      return

  # Cache miss, or a stale cache entry. id() values are reused once an old
  # collection list has been garbage-collected, so a cached index may point
  # past the end of (or at another item in) a new list that happens to share
  # the id. Any item created by this function in the current list would have
  # overwritten the cache entry, so start a fresh item here.
  # Seed 'name' so that items first created through
  # append_attr_to_collection (whose func never sets it) pass the name check
  # above on later calls instead of failing with KeyError.
  item_found = {'name': name}
  func(item_found)
  col.append(json.dumps(item_found))
  ColumnNameInCollection[key] = len(col) - 1
95
+
96
+
97
def append_attr_to_collection(collection_name, name, key, value):
  """Append `value` to the list stored under `key` of collection item `name`.

  The list is created on first use; the item itself is created by
  _process_item when it is not cached yet.
  """
  _process_item(collection_name, name,
                lambda item: item.setdefault(key, []).append(value))
105
+
106
+
107
def update_attr_to_collection(collection_name, attrs):
  """Merge the dict `attrs` into the collection item named `attrs['name']`."""
  item_name = attrs['name']
  _process_item(collection_name, item_name, lambda item: item.update(attrs))
113
+
114
+
115
def unique_name_in_collection(collection_name, name):
  """Return `name`, or `name_<k>` for the smallest k >= 1, not yet cached.

  A name counts as taken when ColumnNameInCollection already holds a key for
  it on the current collection list; the collection contents are not read.
  """
  col = ops.get_collection_ref(collection_name)
  candidate, suffix = name, 0
  while _collection_item_key(col, candidate) in ColumnNameInCollection:
    suffix += 1
    candidate = '%s_%d' % (name, suffix)
  return candidate
126
+
127
+
128
def gen_embedding_attrs(column=None,
                        variable=None,
                        bucket_size=None,
                        combiner=None,
                        is_embedding_var=None):
  """Build the attribute dict describing an embedding column's weights.

  Args:
    column: feature column; its `name` becomes the item name (required
      despite the None default).
    variable: embedding weights (required despite the None default); its
      `name` is stored as 'weights_op_path'.
    bucket_size: stored as-is.
    combiner: stored as-is.
    is_embedding_var: kept for backward compatibility; the stored value is
      derived from the type of `variable` in every branch below.

  Returns:
    dict with 'name', 'bucket_size', 'combiner', 'is_embedding_var',
    'weights_op_path', plus 'embedding_var_keys' / 'embedding_var_values'
    when `variable` is a (possibly partitioned) EmbeddingVariable.
  """
  attrs = dict()
  attrs['name'] = column.name
  attrs['bucket_size'] = bucket_size
  attrs['combiner'] = combiner
  attrs['is_embedding_var'] = is_embedding_var
  attrs['weights_op_path'] = variable.name
  if kv_variable_ops is None:
    # kv_variable_ops failed to import: EmbeddingVariable is unavailable.
    attrs['is_embedding_var'] = False
  elif isinstance(variable, kv_variable_ops.EmbeddingVariable):
    attrs['is_embedding_var'] = True
    attrs['embedding_var_keys'] = variable._shared_name + '-keys'
    attrs['embedding_var_values'] = variable._shared_name + '-values'
  elif (isinstance(variable, variables.PartitionedVariable) and
        isinstance(variable._get_variable_list()[0],
                   kv_variable_ops.EmbeddingVariable)):
    # A partitioned EmbeddingVariable is still an embedding variable; mark it
    # like the unpartitioned branch does instead of keeping whatever value
    # the caller passed in.
    attrs['is_embedding_var'] = True
    shards = list(variable)
    attrs['embedding_var_keys'] = [v._shared_name + '-keys' for v in shards]
    attrs['embedding_var_values'] = [
        v._shared_name + '-values' for v in shards
    ]
  else:
    attrs['is_embedding_var'] = False
  return attrs
155
+
156
+
157
def mark_input_src(name, src_desc):
  """Add {'name': name, 'src': src_desc} as JSON to RANK_SERVICE_INPUT_SRC."""
  collection_key = ops.GraphKeys.RANK_SERVICE_INPUT_SRC
  payload = json.dumps(dict(name=name, src=src_desc))
  ops.add_to_collection(collection_key, payload)
163
+
164
+
165
def is_proto_message(pb_obj, field):
  """Return True iff `field` is a message-typed field of protobuf `pb_obj`.

  Objects without a DESCRIPTOR, and names that are not fields in it, give
  False.
  """
  if not hasattr(pb_obj, 'DESCRIPTOR'):
    return False
  fields_by_name = pb_obj.DESCRIPTOR.fields_by_name
  return (field in fields_by_name and
          fields_by_name[field].type == FieldDescriptor.TYPE_MESSAGE)
172
+
173
+
174
class Parameter(object):
  """Uniform read access to layer params from a pb message or a pb Struct.

  `params` is a protobuf message when `is_struct` is False and a
  `struct_pb2.Struct` when it is True. Unknown attributes and items are
  forwarded to `params`; nested messages / Structs are wrapped in a new
  Parameter that shares this instance's L2 regularizer.
  """

  # Own attribute names that __getattr__ must never forward to `params`.
  # They only reach __getattr__ while unset (e.g. on a bare instance that
  # copy/pickle build before restoring __dict__), and forwarding would then
  # recurse forever through `self.is_struct`.
  _OWN_ATTRS = ('params', 'is_struct', '_l2_reg')

  def __init__(self, params, is_struct, l2_reg=None):
    self.params = params
    self.is_struct = is_struct
    self._l2_reg = l2_reg

  @staticmethod
  def make_from_pb(config):
    """Wrap a protobuf message config."""
    return Parameter(config, False)

  def get_pb_config(self):
    """Return the wrapped protobuf message (not valid for Struct params)."""
    assert not self.is_struct, 'Struct parameter can not convert to pb config'
    return self.params

  @property
  def l2_regularizer(self):
    return self._l2_reg

  @l2_regularizer.setter
  def l2_regularizer(self, value):
    self._l2_reg = value

  def _lookup(self, key):
    """Read `key` from `params`, wrapping nested messages / Structs."""
    if self.is_struct:
      if key not in self.params:
        return None
      value = self.params[key]
      if isinstance(value, struct_pb2.Struct):
        return Parameter(value, True, self._l2_reg)
      return value
    value = getattr(self.params, key)
    if is_proto_message(self.params, key):
      return Parameter(value, False, self._l2_reg)
    return value

  def __getattr__(self, key):
    # Dunder probes (__deepcopy__, __setstate__, __getstate__, ...) must not
    # be delegated: copy.deepcopy would otherwise call the wrapped message's
    # __deepcopy__ and return a bare message, and Struct params would answer
    # None for every probe.
    if key in self._OWN_ATTRS or (key.startswith('__') and
                                  key.endswith('__')):
      raise AttributeError(key)
    return self._lookup(key)

  def __getitem__(self, key):
    return self._lookup(key)

  def get_or_default(self, key, def_val):
    """Return `key` from params, or `def_val` when it is not set.

    Struct params: float values are cast to type(def_val) unless def_val is
    None (presumably because Struct numbers are stored as doubles).
    Message params: values with a length (repeated fields, but also strings)
    fall back to def_val when empty; other fields are returned only when
    HasField reports them set, and def_val is returned when HasField raises
    ValueError.
    """
    if self.is_struct:
      if key in self.params:
        if def_val is None:
          return self.params[key]
        value = self.params[key]
        if type(value) == float:
          return type(def_val)(value)
        return value
      return def_val
    else:  # pb message
      value = getattr(self.params, key, def_val)
      if hasattr(value, '__len__'):  # repeated
        return value if len(value) > 0 else def_val
      try:
        if self.params.HasField(key):
          return value
      except ValueError:
        pass
      return def_val  # maybe not equal to the default value of msg field

  def check_required(self, keys):
    """Raise KeyError if any of `keys` is missing (Struct params only)."""
    if not self.is_struct:
      return
    if not isinstance(keys, (list, tuple)):
      keys = [keys]
    for key in keys:
      if key not in self.params:
        raise KeyError('%s must be set in params' % key)

  def has_field(self, key):
    """Whether `key` is set: membership for Struct, HasField for messages."""
    if self.is_struct:
      return key in self.params
    else:
      return self.params.HasField(key)
@@ -0,0 +1,130 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import json
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.compat.feature_column.feature_column import _SharedEmbeddingColumn # NOQA
9
+ from easy_rec.python.compat.feature_column.feature_column_v2 import EmbeddingColumn # NOQA
10
+
11
# Use the TF1-style API (tf.get_variable, tf.random_uniform, tf.log, ...) on
# TF 2.x. Compare the numeric major version: a plain string comparison is
# lexicographic and would order e.g. '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
13
+
14
+
15
class VariationalDropoutLayer(object):
  """Rank features by variational dropout.

  Applies the Dropout idea to the input feature layer and learns a dropout
  rate (as a trainable logit) per feature, or per embedding dimension when
  `embedding_wise_variational_dropout` is set.
  paper: Dropout Feature Ranking for Deep Learning Models
  arXiv: 1712.08645
  """

  def __init__(self,
               variational_dropout_config,
               features_dimension,
               is_training=False,
               name=''):
    """Create the dropout-logit variable and register the feature layout.

    Args:
      variational_dropout_config: config providing `regularization_lambda`
        and `embedding_wise_variational_dropout`.
      features_dimension: dict whose values are per-feature widths, summed
        into `features_total_dimension`; presumably feature name -> embedding
        dimension (TODO confirm against callers).
      is_training: when False, sample_noisy_input scales deterministically
        instead of sampling.
      name: suffix of the logit variable ('all' gives plain 'logit_p'); also
        recorded in the 'variational_dropout' collection.
    """
    self._config = variational_dropout_config
    self.features_dimension = features_dimension
    self.features_total_dimension = sum(self.features_dimension.values())

    # One logit per embedding dimension when "embedding wise", else one per
    # feature.
    if self.variational_dropout_wise():
      param_count = self.features_total_dimension
    else:
      param_count = len(self.features_dimension)
    self._dropout_param_size = param_count
    self.drop_param_shape = [param_count]
    self.evaluate = not is_training

    if name == 'all':
      logit_p_name = 'logit_p'
    else:
      logit_p_name = 'logit_p_%s' % name
    # initializer=None: tf.get_variable's default initializer is used.
    self.logit_p = tf.get_variable(
        name=logit_p_name,
        shape=self.drop_param_shape,
        dtype=tf.float32,
        initializer=None)
    tf.add_to_collection(
        'variational_dropout',
        json.dumps([name, list(self.features_dimension.items())]))

  def get_lambda(self):
    """Weight of the dropout regularization loss."""
    return self._config.regularization_lambda

  def variational_dropout_wise(self):
    """Whether dropout is learned per embedding dimension (vs. per feature)."""
    return self._config.embedding_wise_variational_dropout

  def build_expand_index(self, batch_size):
    """Build gather_nd indices that spread per-feature values over dimensions.

    Rows are [sample, feature] pairs: for each sample in range(batch_size),
    one row per embedding dimension, labelled with the feature it belongs
    to. E.g. widths [2, 1] give [[0, 0], [0, 0], [0, 1], [1, 0], [1, 0],
    [1, 1], ...].

    Returns:
      int tensor of shape [batch_size * features_total_dimension, 2].
    """
    feature_ids = [[feature_idx]
                   for feature_idx, width in enumerate(
                       self.features_dimension.values())
                   for _ in range(width)]
    feature_col = tf.tile(feature_ids, [batch_size, 1])
    sample_ids = tf.expand_dims(tf.range(batch_size), 1)
    sample_ids = tf.tile(sample_ids, [1, self.features_total_dimension])
    sample_col = tf.reshape(sample_ids, [-1, 1])
    return tf.concat([sample_col, feature_col], 1)

  def sample_noisy_input(self, input):
    """Apply learned dropout to `input`.

    Assumes `input` is [batch, features_total_dimension] (it is multiplied
    by masks reshaped to that width) -- TODO confirm.
    Training: multiply by a relaxed Bernoulli keep mask.
    Evaluation: scale by the expected keep rate 1 - sigmoid(logit_p).
    """
    batch_size = tf.shape(input)[0]
    if not self.evaluate:
      keep_mask = self.sampled_from_logit_p(batch_size)
      keep_mask = tf.reshape(keep_mask, [-1, self.features_total_dimension])
      return input * keep_mask

    logit_rows = tf.expand_dims(self.logit_p, 0)
    logit_rows = tf.tile(logit_rows, [batch_size, 1])
    drop_prob = tf.sigmoid(logit_rows)
    if not self.variational_dropout_wise():
      # Copy each feature's probability onto all of its embedding dimensions.
      gather_index = self.build_expand_index(batch_size)
      drop_prob = tf.gather_nd(drop_prob, gather_index)
      drop_prob = tf.reshape(drop_prob, [-1, self.features_total_dimension])
    return input * (1 - drop_prob)

  def sampled_from_logit_p(self, num_samples):
    """Draw `num_samples` relaxed keep masks from the dropout logits.

    Per-feature masks are expanded to embedding dimensions (flattened; the
    caller reshapes them).
    """
    logit_rows = tf.tile(tf.expand_dims(self.logit_p, 0), [num_samples, 1])
    keep_mask = self.concrete_dropout_neuron(tf.sigmoid(logit_rows))
    if self.variational_dropout_wise():
      return keep_mask
    # from feature_num to embedding_dim_num
    return tf.gather_nd(keep_mask, self.build_expand_index(num_samples))

  def concrete_dropout_neuron(self, dropout_p, temp=1.0 / 10.0):
    """Concrete (relaxed Bernoulli) keep mask for drop probabilities.

    Computes 1 - sigmoid((log p - log(1 - p) + log u - log(1 - u)) / temp)
    with u ~ Uniform[0, 1); a small epsilon guards every log.
    """
    eps = np.finfo(float).eps
    noise = tf.random_uniform(
        tf.shape(dropout_p), dtype=tf.float32, seed=None, name='unif_noise')

    relaxed_logit = (
        tf.log(dropout_p + eps) - tf.log(1. - dropout_p + eps) +
        tf.log(noise + eps) - tf.log(1. - noise + eps))

    drop_sample = tf.sigmoid(relaxed_logit / temp)
    return 1 - drop_sample

  def __call__(self, output_features):
    """Return the dropped-out features and register the regularization loss.

    Adds get_lambda() / batch_size * sum(1 - sigmoid(logit_p)) to the
    'variational_dropout_loss' collection.
    """
    batch_size = tf.shape(output_features)[0]
    noisy_input = self.sample_noisy_input(output_features)
    keep_rate = 1. - tf.sigmoid(self.logit_p)
    loss_scale = self.get_lambda() / tf.cast(batch_size, dtype=tf.float32)
    dropout_loss = loss_scale * tf.reduce_sum(keep_rate, axis=0)
    tf.add_to_collection('variational_dropout_loss', dropout_loss)
    return noisy_input
File without changes