easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/layers/keras/attention.py
@@ -0,0 +1,267 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ """Attention layers that can be used in sequence DNN/CNN models.
+
+ This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
+ Attention is formed by three tensors: Query, Key and Value.
+ """
+ import tensorflow as tf
+ from tensorflow.python.keras.layers import Layer
+
+
+ class Attention(Layer):
+   """Dot-product attention layer, a.k.a. Luong-style attention.
+
+   Inputs are a list with 2 or 3 elements:
+   1. A `query` tensor of shape `(batch_size, Tq, dim)`.
+   2. A `value` tensor of shape `(batch_size, Tv, dim)`.
+   3. An optional `key` tensor of shape `(batch_size, Tv, dim)`. If none is
+      supplied, `value` will be used as the `key`.
+
+   The calculation follows the steps:
+   1. Calculate attention scores using `query` and `key` with shape
+      `(batch_size, Tq, Tv)`.
+   2. Use scores to calculate a softmax distribution with shape
+      `(batch_size, Tq, Tv)`.
+   3. Use the softmax distribution to create a linear combination of `value`
+      with shape `(batch_size, Tq, dim)`.
+
+   Args:
+     use_scale: If `True`, will create a scalar variable to scale the
+       attention scores.
+     dropout: Float between 0 and 1. Fraction of the units to drop for the
+       attention scores. Defaults to `0.0`.
+     seed: A Python integer to use as random seed in case of `dropout`.
+     score_mode: Function to use to compute attention scores, one of
+       `{"dot", "concat"}`. `"dot"` refers to the dot product between the
+       query and key vectors. `"concat"` refers to the hyperbolic tangent
+       of the concatenation of the `query` and `key` vectors.
+
+   Call Args:
+     inputs: List of the following tensors:
+       - `query`: Query tensor of shape `(batch_size, Tq, dim)`.
+       - `value`: Value tensor of shape `(batch_size, Tv, dim)`.
+       - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If
+         not given, will use `value` for both `key` and `value`, which is
+         the most common case.
+     mask: List of the following tensors:
+       - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`.
+         If given, the output will be zero at the positions where
+         `mask==False`.
+       - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`.
+         If given, will apply the mask such that values at positions
+         where `mask==False` do not contribute to the result.
+     return_attention_scores: bool, if `True`, returns the attention scores
+       (after masking and softmax) as an additional output argument.
+     training: Python boolean indicating whether the layer should behave in
+       training mode (adding dropout) or in inference mode (no dropout).
+     use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds
+       a mask such that position `i` cannot attend to positions `j > i`.
+       This prevents the flow of information from the future towards the
+       past. Defaults to `False`.
+
+   Output:
+     Attention outputs of shape `(batch_size, Tq, dim)`.
+     (Optional) Attention scores after masking and softmax with shape
+     `(batch_size, Tq, Tv)`.
+   """
+
+   def __init__(self, params, name='attention', reuse=None, **kwargs):
+     super(Attention, self).__init__(name=name, **kwargs)
+     self.use_scale = params.get_or_default('use_scale', False)
+     self.scale_by_dim = params.get_or_default('scale_by_dim', False)
+     self.score_mode = params.get_or_default('score_mode', 'dot')
+     if self.score_mode not in ['dot', 'concat']:
+       raise ValueError('Invalid value for argument score_mode. '
+                        "Expected one of {'dot', 'concat'}. "
+                        'Received: score_mode=%s' % self.score_mode)
+     self.dropout = params.get_or_default('dropout', 0.0)
+     self.seed = params.get_or_default('seed', None)
+     self.scale = None
+     self.concat_score_weight = None
+     self._return_attention_scores = params.get_or_default(
+         'return_attention_scores', False)
+     self.use_causal_mask = params.get_or_default('use_causal_mask', False)
+
+   @property
+   def return_attention_scores(self):
+     return self._return_attention_scores
+
+   def build(self, input_shape):
+     self._validate_inputs(input_shape)
+     if self.use_scale:
+       self.scale = self.add_weight(
+           name='scale',
+           shape=(),
+           initializer='ones',
+           dtype=self.dtype,
+           trainable=True,
+       )
+     if self.score_mode == 'concat':
+       self.concat_score_weight = self.add_weight(
+           name='concat_score_weight',
+           shape=(),
+           initializer='ones',
+           dtype=self.dtype,
+           trainable=True,
+       )
+     super(Attention, self).build(input_shape)  # Be sure to call this somewhere!
+
+   def _calculate_scores(self, query, key):
+     """Calculates attention scores as a query-key dot product.
+
+     Args:
+       query: Query tensor of shape `(batch_size, Tq, dim)`.
+       key: Key tensor of shape `(batch_size, Tv, dim)`.
+
+     Returns:
+       Tensor of shape `(batch_size, Tq, Tv)`.
+     """
+     if self.score_mode == 'dot':
+       scores = tf.matmul(query, tf.transpose(key, [0, 2, 1]))
+       if self.scale is not None:
+         scores *= self.scale
+       elif self.scale_by_dim:
+         dk = tf.cast(tf.shape(key)[-1], tf.float32)
+         scores /= tf.math.sqrt(dk)
+     elif self.score_mode == 'concat':
+       # Reshape tensors to enable broadcasting.
+       # Reshape into [batch_size, Tq, 1, dim].
+       q_reshaped = tf.expand_dims(query, axis=-2)
+       # Reshape into [batch_size, 1, Tv, dim].
+       k_reshaped = tf.expand_dims(key, axis=-3)
+       if self.scale is not None:
+         scores = self.concat_score_weight * tf.reduce_sum(
+             tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1)
+       else:
+         scores = self.concat_score_weight * tf.reduce_sum(
+             tf.tanh(q_reshaped + k_reshaped), axis=-1)
+     return scores
+
+   def _apply_scores(self, scores, value, scores_mask=None, training=False):
+     """Applies attention scores to the given value tensor.
+
+     To use this method in your attention layer, follow the steps:
+
+     * Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of
+       shape `(batch_size, Tv)` to calculate the attention `scores`.
+     * Pass `scores` and `value` tensors to this method. The method applies
+       `scores_mask`, calculates
+       `attention_distribution = softmax(scores)`, then returns
+       `matmul(attention_distribution, value)`.
+     * Apply `query_mask` and return the result.
+
+     Args:
+       scores: Scores float tensor of shape `(batch_size, Tq, Tv)`.
+       value: Value tensor of shape `(batch_size, Tv, dim)`.
+       scores_mask: A boolean mask tensor of shape `(batch_size, 1, Tv)`
+         or `(batch_size, Tq, Tv)`. If given, scores at positions where
+         `scores_mask==False` do not contribute to the result. It must
+         contain at least one `True` value in each line along the last
+         dimension.
+       training: Python boolean indicating whether the layer should behave
+         in training mode (adding dropout) or in inference mode
+         (no dropout).
+
+     Returns:
+       Tensor of shape `(batch_size, Tq, dim)`.
+       Attention scores after masking and softmax with shape
+       `(batch_size, Tq, Tv)`.
+     """
+     if scores_mask is not None:
+       padding_mask = tf.logical_not(scores_mask)
+       # Bias so padding positions do not contribute to attention
+       # distribution. Note 65504. is the max float16 value.
+       max_value = 65504.0 if scores.dtype == 'float16' else 1.0e9
+       scores -= max_value * tf.cast(padding_mask, dtype=scores.dtype)
+
+     weights = tf.nn.softmax(scores, axis=-1)
+     if training and self.dropout > 0:
+       weights = tf.nn.dropout(weights, 1.0 - self.dropout, seed=self.seed)
+     return tf.matmul(weights, value), weights
+
+   def _calculate_score_mask(self, scores, v_mask, use_causal_mask):
+     if use_causal_mask:
+       # Creates a lower triangular mask, so position i cannot attend to
+       # positions j > i. This prevents the flow of information from the
+       # future into the past.
+       score_shape = tf.shape(scores)
+       # causal_mask_shape = [1, Tq, Tv].
+       mask_shape = (1, score_shape[-2], score_shape[-1])
+       ones_mask = tf.ones(shape=mask_shape, dtype='int32')
+       row_index = tf.cumsum(ones_mask, axis=-2)
+       col_index = tf.cumsum(ones_mask, axis=-1)
+       causal_mask = tf.greater_equal(row_index, col_index)
+
+       if v_mask is not None:
+         # Mask of shape [batch_size, 1, Tv].
+         v_mask = tf.expand_dims(v_mask, axis=-2)
+         return tf.logical_and(v_mask, causal_mask)
+       return causal_mask
+     else:
+       # If not using causal mask, return the value mask as is,
+       # or None if the value mask is not provided.
+       return v_mask
+
+   def call(self, inputs, mask=None, training=False, **kwargs):
+     self._validate_inputs(inputs=inputs, mask=mask)
+     q = inputs[0]
+     v = inputs[1]
+     k = inputs[2] if len(inputs) > 2 else v
+     q_mask = mask[0] if mask else None
+     v_mask = mask[1] if mask else None
+     scores = self._calculate_scores(query=q, key=k)
+     scores_mask = self._calculate_score_mask(scores, v_mask,
+                                              self.use_causal_mask)
+     result, attention_scores = self._apply_scores(
+         scores=scores, value=v, scores_mask=scores_mask, training=training)
+     if q_mask is not None:
+       # Mask of shape [batch_size, Tq, 1].
+       q_mask = tf.expand_dims(q_mask, axis=-1)
+       result *= tf.cast(q_mask, dtype=result.dtype)
+     if self._return_attention_scores:
+       return result, attention_scores
+     return result
+
+   def compute_mask(self, inputs, mask=None):
+     self._validate_inputs(inputs=inputs, mask=mask)
+     if mask is None or mask[0] is None:
+       return None
+     return tf.convert_to_tensor(mask[0])
+
+   def compute_output_shape(self, input_shape):
+     """Returns shape of value tensor dim, but for query tensor length."""
+     return list(input_shape[0][:-1]), input_shape[1][-1]
+
+   def _validate_inputs(self, inputs, mask=None):
+     """Validates arguments of the call method."""
+     class_name = self.__class__.__name__
+     if not isinstance(inputs, list):
+       raise ValueError('{class_name} layer must be called on a list of inputs, '
+                        'namely [query, value] or [query, value, key]. '
+                        'Received: inputs={inputs}.'.format(
+                            class_name=class_name, inputs=inputs))
+     if len(inputs) < 2 or len(inputs) > 3:
+       raise ValueError('%s layer accepts inputs list of length 2 or 3, '
+                        'namely [query, value] or [query, value, key]. '
+                        'Received length: %d.' % (class_name, len(inputs)))
+     if mask is not None:
+       if not isinstance(mask, list):
+         raise ValueError(
+             '{class_name} layer mask must be a list, '
+             'namely [query_mask, value_mask]. Received: mask={mask}.'.format(
+                 class_name=class_name, mask=mask))
+       if len(mask) < 2 or len(mask) > 3:
+         raise ValueError(
+             '{class_name} layer accepts mask list of length 2 or 3. '
+             'Received: inputs={inputs}, mask={mask}.'.format(
+                 class_name=class_name, inputs=inputs, mask=mask))
+
+   def get_config(self):
+     base_config = super(Attention, self).get_config()
+     config = {
+         'use_scale': self.use_scale,
+         'score_mode': self.score_mode,
+         'dropout': self.dropout,
+     }
+     return dict(list(base_config.items()) + list(config.items()))
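
Editor's note: the snippet below is a minimal, standalone sketch of the `dot` scoring path implemented by the Attention layer above (scores = Q x K^T, optional scaling by sqrt(dim), softmax, weighted sum of V). It deliberately bypasses the layer's constructor, which expects EasyRec's `Parameter` wrapper that is not part of this hunk; all tensor names and sizes in the sketch are illustrative only.

import tensorflow as tf

# Toy shapes: a (batch, Tq, dim) query attends over (batch, Tv, dim) values.
batch, Tq, Tv, dim = 2, 4, 6, 8
query = tf.random.normal([batch, Tq, dim])
value = tf.random.normal([batch, Tv, dim])
key = value  # when no key is supplied, value doubles as the key

scores = tf.matmul(query, key, transpose_b=True)  # (batch, Tq, Tv)
scores /= tf.math.sqrt(tf.cast(dim, tf.float32))  # the scale_by_dim branch
weights = tf.nn.softmax(scores, axis=-1)          # attention distribution
output = tf.matmul(weights, value)                # (batch, Tq, dim)
print(output.shape)  # (2, 4, 8)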
easy_rec/python/layers/keras/auxiliary_loss.py
@@ -0,0 +1,47 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import logging
+
+ import tensorflow as tf
+
+ from easy_rec.python.loss import contrastive_loss
+
+
+ class AuxiliaryLoss(tf.keras.layers.Layer):
+   """Compute auxiliary loss, usually used for contrastive learning."""
+
+   def __init__(self, params, name='auxiliary_loss', reuse=None, **kwargs):
+     super(AuxiliaryLoss, self).__init__(name=name, **kwargs)
+     params.check_required('loss_type')
+     self.loss_type = params.get_or_default('loss_type', None)
+     self.loss_weight = params.get_or_default('loss_weight', 1.0)
+     logging.info('init layer `%s` with loss type: %s and weight: %f' %
+                  (self.name, self.loss_type, self.loss_weight))
+     self.temperature = params.get_or_default('temperature', 0.1)
+
+   def call(self, inputs, training=None, **kwargs):
+     if self.loss_type is None:
+       logging.warning('loss_type is None in auxiliary loss layer')
+       return 0
+
+     loss_dict = kwargs['loss_dict']
+     loss_value = 0
+
+     if self.loss_type == 'l2_loss':
+       x1, x2 = inputs
+       loss = contrastive_loss.l2_loss(x1, x2)
+       loss_value = loss if self.loss_weight == 1.0 else loss * self.loss_weight
+       loss_dict['%s_l2_loss' % self.name] = loss_value
+     elif self.loss_type == 'info_nce':
+       query, positive = inputs
+       loss = contrastive_loss.info_nce_loss(
+           query, positive, temperature=self.temperature)
+       loss_value = loss if self.loss_weight == 1.0 else loss * self.loss_weight
+       loss_dict['%s_info_nce_loss' % self.name] = loss_value
+     elif self.loss_type == 'nce_loss':
+       x1, x2 = inputs
+       loss = contrastive_loss.nce_loss(x1, x2, temperature=self.temperature)
+       loss_value = loss if self.loss_weight == 1.0 else loss * self.loss_weight
+       loss_dict['%s_nce_loss' % self.name] = loss_value
+
+     return loss_value
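
Editor's note: AuxiliaryLoss expects the caller to pass a mutable `loss_dict` through `**kwargs` and returns the (optionally weighted) loss value while recording it under a per-layer key. The sketch below reproduces that bookkeeping pattern with a plain squared-error stand-in, since `contrastive_loss.l2_loss` is defined elsewhere in the package and not shown in this hunk; the function and variable names here are hypothetical.

import tensorflow as tf

def auxiliary_l2(x1, x2, loss_dict, name='aux', loss_weight=0.5):
  # Stand-in for contrastive_loss.l2_loss (the packaged implementation lives
  # in easy_rec/python/loss/contrastive_loss.py and is not shown here).
  loss = tf.reduce_mean(tf.reduce_sum(tf.square(x1 - x2), axis=-1))
  # Same weighting-and-recording pattern as AuxiliaryLoss.call above.
  loss_value = loss if loss_weight == 1.0 else loss * loss_weight
  loss_dict['%s_l2_loss' % name] = loss_value
  return loss_value

loss_dict = {}
x1 = tf.random.normal([4, 16])
x2 = tf.random.normal([4, 16])
print(auxiliary_l2(x1, x2, loss_dict), list(loss_dict.keys()))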
easy_rec/python/layers/keras/blocks.py
@@ -0,0 +1,262 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ """Convenience blocks for building models."""
+ import logging
+
+ import tensorflow as tf
+ from tensorflow.python.keras.initializers import Constant
+ from tensorflow.python.keras.layers import Dense
+ from tensorflow.python.keras.layers import Dropout
+ from tensorflow.python.keras.layers import Lambda
+ from tensorflow.python.keras.layers import Layer
+
+ from easy_rec.python.layers.keras.activation import activation_layer
+ from easy_rec.python.layers.utils import Parameter
+ from easy_rec.python.utils.shape_utils import pad_or_truncate_sequence
+ from easy_rec.python.utils.tf_utils import add_elements_to_collection
+
+ if tf.__version__ >= '2.0':
+   tf = tf.compat.v1
+
+
+ class MLP(Layer):
+   """Sequential multi-layer perceptron (MLP) block.
+
+   Attributes:
+     units: Sequential list of layer sizes.
+     use_bias: Whether to include a bias term.
+     activation: Type of activation to use on all except the last layer.
+     final_activation: Type of activation to use on last layer.
+     **kwargs: Extra args passed to the Keras Layer base class.
+   """
+
+   def __init__(self, params, name='mlp', reuse=None, **kwargs):
+     super(MLP, self).__init__(name=name, **kwargs)
+     self.layer_name = name  # for add to output
+     params.check_required('hidden_units')
+     use_bn = params.get_or_default('use_bn', True)
+     use_final_bn = params.get_or_default('use_final_bn', True)
+     use_bias = params.get_or_default('use_bias', False)
+     use_final_bias = params.get_or_default('use_final_bias', False)
+     dropout_rate = list(params.get_or_default('dropout_ratio', []))
+     activation = params.get_or_default('activation', 'relu')
+     initializer = params.get_or_default('initializer', 'he_uniform')
+     final_activation = params.get_or_default('final_activation', None)
+     use_bn_after_act = params.get_or_default('use_bn_after_activation', False)
+     units = list(params.hidden_units)
+     logging.info(
+         'MLP(%s) units: %s, dropout: %r, activate=%s, use_bn=%r, final_bn=%r,'
+         ' final_activate=%s, bias=%r, initializer=%s, bn_after_activation=%r' %
+         (name, units, dropout_rate, activation, use_bn, use_final_bn,
+          final_activation, use_bias, initializer, use_bn_after_act))
+     assert len(units) > 0, 'MLP(%s) requires at least one hidden unit' % name
+     self.reuse = reuse
+     self.add_to_outputs = params.get_or_default('add_to_outputs', False)
+
+     num_dropout = len(dropout_rate)
+     self._sub_layers = []
+     for i, num_units in enumerate(units[:-1]):
+       name = 'layer_%d' % i
+       drop_rate = dropout_rate[i] if i < num_dropout else 0.0
+       self.add_rich_layer(num_units, use_bn, drop_rate, activation, initializer,
+                           use_bias, use_bn_after_act, name,
+                           params.l2_regularizer)
+
+     n = len(units) - 1
+     drop_rate = dropout_rate[n] if num_dropout > n else 0.0
+     name = 'layer_%d' % n
+     self.add_rich_layer(units[-1], use_final_bn, drop_rate, final_activation,
+                         initializer, use_final_bias, use_bn_after_act, name,
+                         params.l2_regularizer)
+
+   def add_rich_layer(self,
+                      num_units,
+                      use_bn,
+                      dropout_rate,
+                      activation,
+                      initializer,
+                      use_bias,
+                      use_bn_after_activation,
+                      name,
+                      l2_reg=None):
+     act_layer = activation_layer(activation, name='%s/act' % name)
+     if use_bn and not use_bn_after_activation:
+       dense = Dense(
+           units=num_units,
+           use_bias=use_bias,
+           kernel_initializer=initializer,
+           kernel_regularizer=l2_reg,
+           name='%s/dense' % name)
+       self._sub_layers.append(dense)
+       bn = tf.keras.layers.BatchNormalization(
+           name='%s/bn' % name, trainable=True)
+       self._sub_layers.append(bn)
+       self._sub_layers.append(act_layer)
+     else:
+       dense = Dense(
+           num_units,
+           use_bias=use_bias,
+           kernel_initializer=initializer,
+           kernel_regularizer=l2_reg,
+           name='%s/dense' % name)
+       self._sub_layers.append(dense)
+       self._sub_layers.append(act_layer)
+       if use_bn and use_bn_after_activation:
+         bn = tf.keras.layers.BatchNormalization(name='%s/bn' % name)
+         self._sub_layers.append(bn)
+
+     if 0.0 < dropout_rate < 1.0:
+       dropout = Dropout(dropout_rate, name='%s/dropout' % name)
+       self._sub_layers.append(dropout)
+     elif dropout_rate >= 1.0:
+       raise ValueError('invalid dropout_ratio: %.3f' % dropout_rate)
+
+   def call(self, x, training=None, **kwargs):
+     """Performs the forward computation of the block."""
+     for layer in self._sub_layers:
+       cls = layer.__class__.__name__
+       if cls in ('Dropout', 'BatchNormalization', 'Dice'):
+         x = layer(x, training=training)
+         if cls in ('BatchNormalization', 'Dice') and training:
+           add_elements_to_collection(layer.updates, tf.GraphKeys.UPDATE_OPS)
+       else:
+         x = layer(x)
+     if self.add_to_outputs and 'prediction_dict' in kwargs:
+       outputs = kwargs['prediction_dict']
+       outputs[self.layer_name] = tf.squeeze(x, axis=1)
+       logging.info('add `%s` to model outputs' % self.layer_name)
+     return x
+
+
+ class Highway(Layer):
+
+   def __init__(self, params, name='highway', reuse=None, **kwargs):
+     super(Highway, self).__init__(name=name, **kwargs)
+     self.emb_size = params.get_or_default('emb_size', None)
+     self.num_layers = params.get_or_default('num_layers', 1)
+     self.activation = params.get_or_default('activation', 'relu')
+     self.dropout_rate = params.get_or_default('dropout_rate', 0.0)
+     self.init_gate_bias = params.get_or_default('init_gate_bias', -3.0)
+     self.act_layer = activation_layer(self.activation)
+     self.dropout_layer = Dropout(
+         self.dropout_rate) if self.dropout_rate > 0.0 else None
+     self.project_layer = None
+     self.gate_bias_initializer = Constant(self.init_gate_bias)
+     self.gates = []  # T
+     self.transforms = []  # H
+     self.multiply_layer = tf.keras.layers.Multiply()
+     self.add_layer = tf.keras.layers.Add()
+
+   def build(self, input_shape):
+     dim = input_shape[-1]
+     if self.emb_size is not None and dim != self.emb_size:
+       self.project_layer = Dense(self.emb_size, name='input_projection')
+       dim = self.emb_size
+     self.carry_gate = Lambda(lambda x: 1.0 - x, output_shape=(dim,))
+     for i in range(self.num_layers):
+       gate = Dense(
+           units=dim,
+           bias_initializer=self.gate_bias_initializer,
+           activation='sigmoid',
+           name='gate_%d' % i)
+       self.gates.append(gate)
+       self.transforms.append(Dense(units=dim))
+
+   def call(self, inputs, training=None, **kwargs):
+     value = inputs
+     if self.project_layer is not None:
+       value = self.project_layer(inputs)
+     for i in range(self.num_layers):
+       gate = self.gates[i](value)
+       transformed = self.act_layer(self.transforms[i](value))
+       if self.dropout_layer is not None:
+         transformed = self.dropout_layer(transformed, training=training)
+       transformed_gated = self.multiply_layer([gate, transformed])
+       identity_gated = self.multiply_layer([self.carry_gate(gate), value])
+       value = self.add_layer([transformed_gated, identity_gated])
+     return value
+
+
+ class Gate(Layer):
+   """Weighted sum gate."""
+
+   def __init__(self, params, name='gate', reuse=None, **kwargs):
+     super(Gate, self).__init__(name=name, **kwargs)
+     self.weight_index = params.get_or_default('weight_index', 0)
+     if params.has_field('mlp'):
+       mlp_cfg = Parameter.make_from_pb(params.mlp)
+       mlp_cfg.l2_regularizer = params.l2_regularizer
+       self.top_mlp = MLP(mlp_cfg, name='top_mlp')
+     else:
+       self.top_mlp = None
+
+   def call(self, inputs, training=None, **kwargs):
+     assert len(
+         inputs
+     ) > 1, 'input of Gate layer must be a list containing at least 2 elements'
+     weights = inputs[self.weight_index]
+     j = 0
+     for i, x in enumerate(inputs):
+       if i == self.weight_index:
+         continue
+       if j == 0:
+         output = weights[:, j, None] * x
+       else:
+         output += weights[:, j, None] * x
+       j += 1
+     if self.top_mlp is not None:
+       output = self.top_mlp(output, training=training)
+     return output
+
+
+ class TextCNN(Layer):
+   """Text CNN Model.
+
+   References
+   - [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882)
+   """
+
+   def __init__(self, params, name='text_cnn', reuse=None, **kwargs):
+     super(TextCNN, self).__init__(name=name, **kwargs)
+     self.config = params.get_pb_config()
+     self.pad_seq_length = self.config.pad_sequence_length
+     if self.pad_seq_length <= 0:
+       logging.warning(
+           'running text cnn with pad_sequence_length <= 0, model predictions may be unstable'
+       )
+     self.conv_layers = []
+     self.pool_layer = tf.keras.layers.GlobalMaxPool1D()
+     self.concat_layer = tf.keras.layers.Concatenate(axis=-1)
+     for size, filters in zip(self.config.filter_sizes, self.config.num_filters):
+       conv = tf.keras.layers.Conv1D(
+           filters=int(filters),
+           kernel_size=int(size),
+           activation=self.config.activation)
+       self.conv_layers.append(conv)
+     if self.config.HasField('mlp'):
+       p = Parameter.make_from_pb(self.config.mlp)
+       p.l2_regularizer = params.l2_regularizer
+       self.mlp = MLP(p, name='mlp', reuse=reuse)
+     else:
+       self.mlp = None
+
+   def call(self, inputs, training=None, **kwargs):
+     """Input shape: 3D tensor with shape `(batch_size, steps, input_dim)`."""
+     assert isinstance(inputs, (list, tuple))
+     assert len(inputs) >= 2
+     seq_emb, seq_len = inputs[:2]
+
+     if self.pad_seq_length > 0:
+       seq_emb, seq_len = pad_or_truncate_sequence(seq_emb, seq_len,
+                                                   self.pad_seq_length)
+     pooled_outputs = []
+     for layer in self.conv_layers:
+       conv = layer(seq_emb)
+       pooled = self.pool_layer(conv)
+       pooled_outputs.append(pooled)
+     net = self.concat_layer(pooled_outputs)
+     if self.mlp is not None:
+       output = self.mlp(net, training=training)
+     else:
+       output = net
+     return output
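
Editor's note: with its defaults (use_bn=True, use_bn_after_activation=False), each hidden layer assembled by MLP.add_rich_layer above is a Dense (no bias) -> BatchNormalization -> activation -> optional Dropout stack. The sketch below builds an equivalent single block with plain Keras; it is an illustration under those assumptions, not the packaged API, and the helper name is hypothetical.

import tensorflow as tf

def mlp_block(num_units, drop_rate=0.0):
  # One hidden layer in the default configuration of add_rich_layer():
  # Dense without bias, then BatchNorm, then ReLU, then optional Dropout.
  layers = [
      tf.keras.layers.Dense(num_units, use_bias=False,
                            kernel_initializer='he_uniform'),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.ReLU(),
  ]
  if 0.0 < drop_rate < 1.0:
    layers.append(tf.keras.layers.Dropout(drop_rate))
  return tf.keras.Sequential(layers)

x = tf.random.normal([8, 32])
y = mlp_block(64, drop_rate=0.1)(x, training=True)
print(y.shape)  # (8, 64)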