easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,571 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import six
6
+ import tensorflow as tf
7
+ from google.protobuf import struct_pb2
8
+
9
+ from easy_rec.python.layers.common_layers import EnhancedInputLayer
10
+ from easy_rec.python.layers.keras import MLP
11
+ from easy_rec.python.layers.keras import EmbeddingLayer
12
+ from easy_rec.python.layers.utils import Parameter
13
+ from easy_rec.python.protos import backbone_pb2
14
+ from easy_rec.python.utils.dag import DAG
15
+ from easy_rec.python.utils.load_class import load_keras_layer
16
+ from easy_rec.python.utils.tf_utils import add_elements_to_collection
17
+
18
# Switch to the TF1-compatible API when running on TensorFlow 2 or newer.
# The major version is compared numerically: comparing the raw version
# strings is lexicographic, so e.g. '10.0' >= '2.0' would be False.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
20
+
21
+
22
+ class Package(object):
23
+ """A sub DAG of tf ops for reuse."""
24
+ __packages = {}
25
+
26
+ @staticmethod
27
+ def has_backbone_block(name):
28
+ if 'backbone' not in Package.__packages:
29
+ return False
30
+ backbone = Package.__packages['backbone']
31
+ return backbone.has_block(name)
32
+
33
+ @staticmethod
34
+ def backbone_block_outputs(name):
35
+ if 'backbone' not in Package.__packages:
36
+ return None
37
+ backbone = Package.__packages['backbone']
38
+ return backbone.block_outputs(name)
39
+
40
def __init__(self, config, features, input_layer, l2_reg=None):
    """Build the block DAG and the layer objects of a package.

    The constructor makes two passes over `config.blocks`: the first one
    registers every block as a DAG node and creates its layer object(s), the
    second one wires the DAG edges from each block's declared inputs. The
    finished package is stored in the class-level registry under
    `config.name`.

    Args:
      config: package config; `name`, `blocks`, `concat_blocks` and
        `output_blocks` are read here.
      features: input features, handed to the input layer helpers.
      input_layer: object providing `has_group`, `get_raw_features` and
        `get_bucketized_features`.
      l2_reg: regularizer stored on the instance for later layer creation.

    Raises:
      ValueError: a block has no input, or an input-type block has more than
        one input.
      KeyError: an input refers to an unknown feature group / block, an
        input-type block does not use `feature_group_name`, or a block is
        named like a feature group.
      AssertionError: one of the sanity checks on block / input counts fails.
    """
    self._config = config
    self._features = features
    self._input_layer = input_layer
    self._l2_reg = l2_reg
    self._dag = DAG()
    self._name_to_blocks = {}
    self._name_to_layer = {}
    self.reset_input_config(None)
    self._block_outputs = {}
    self._package_input = None
    self._feature_group_inputs = {}
    # Only the package named 'backbone' creates variables without reuse.
    reuse = None if config.name == 'backbone' else tf.AUTO_REUSE
    input_feature_groups = self._feature_group_inputs

    # Pass 1: register blocks and create their layers.
    for block in config.blocks:
        if len(block.inputs) == 0:
            raise ValueError('block takes at least one input: %s' % block.name)
        self._dag.add_node(block.name)
        self._name_to_blocks[block.name] = block
        layer = block.WhichOneof('layer')
        if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
            if len(block.inputs) != 1:
                raise ValueError('input layer `%s` takes only one input' %
                                 block.name)
            one_input = block.inputs[0]
            name = one_input.WhichOneof('name')
            if name != 'feature_group_name':
                raise KeyError(
                    '`feature_group_name` should be set for input layer: ' +
                    block.name)
            group = one_input.feature_group_name
            if not input_layer.has_group(group):
                raise KeyError('invalid feature group name: ' + group)
            if group in input_feature_groups:
                # The feature group was already consumed by an earlier block.
                # Fix: compare against the oneof name 'input_layer'; the
                # original compared the string with the `input_layer`
                # argument object, which is never equal, so this warning
                # could never be emitted.
                if layer == 'input_layer':
                    # NOTE(review): no layer is registered for this block in
                    # this branch - confirm whether the cached input layer
                    # should be reused here.
                    logging.warning('input `%s` already exists in other block' %
                                    group)
                elif layer == 'raw_input':
                    input_fn = input_feature_groups[group]
                    self._name_to_layer[block.name] = input_fn
                elif layer == 'embedding_layer':
                    inputs, vocab, weights = input_feature_groups[group]
                    block.embedding_layer.vocab_size = vocab
                    params = Parameter.make_from_pb(block.embedding_layer)
                    input_fn = EmbeddingLayer(params, block.name)
                    self._name_to_layer[block.name] = input_fn
            else:
                if layer == 'input_layer':
                    input_fn = EnhancedInputLayer(self._input_layer,
                                                  self._features, group, reuse)
                    input_feature_groups[group] = input_fn
                elif layer == 'raw_input':
                    input_fn = self._input_layer.get_raw_features(
                        self._features, group)
                    input_feature_groups[group] = input_fn
                else:  # embedding_layer
                    inputs, vocab, weights = (
                        self._input_layer.get_bucketized_features(
                            self._features, group))
                    block.embedding_layer.vocab_size = vocab
                    params = Parameter.make_from_pb(block.embedding_layer)
                    input_fn = EmbeddingLayer(params, block.name)
                    input_feature_groups[group] = (inputs, vocab, weights)
                    logging.info('add an embedding layer %s with vocab size %d',
                                 block.name, vocab)
                self._name_to_layer[block.name] = input_fn
        else:
            self.define_layers(layer, block, block.name, reuse)

        # sequential layers
        for i, layer_cnf in enumerate(block.layers):
            layer = layer_cnf.WhichOneof('layer')
            name_i = '%s_l%d' % (block.name, i)
            self.define_layers(layer, layer_cnf, name_i, reuse)

    num_groups = len(input_feature_groups)
    num_blocks = len(self._name_to_blocks) - num_groups
    assert num_blocks > 0, 'there must be at least one block in backbone'

    # Pass 2: connect every non-input block to its inputs.
    num_pkg_input = 0
    for block in config.blocks:
        layer = block.WhichOneof('layer')
        if layer in {'input_layer', 'raw_input', 'embedding_layer'}:
            continue
        name = block.name
        if name in input_feature_groups:
            raise KeyError('block name can not be one of feature groups:' + name)
        for input_node in block.inputs:
            input_type = input_node.WhichOneof('name')
            input_name = getattr(input_node, input_type)
            if input_type == 'use_package_input':
                assert input_name, 'use_package_input can not set false'
                num_pkg_input += 1
                continue
            if input_type == 'package_name':
                num_pkg_input += 1
                self._dag.add_node_if_not_exists(input_name)
                self._dag.add_edge(input_name, name)
                if input_node.HasField('package_input'):
                    pkg_input_name = input_node.package_input
                    self._dag.add_node_if_not_exists(pkg_input_name)
                    self._dag.add_edge(pkg_input_name, input_name)
                continue
            iname = input_name
            if iname in self._name_to_blocks:
                assert iname != name, 'input name can not equal to block name:' + iname
                self._dag.add_edge(iname, name)
            else:
                is_fea_group = input_type == 'feature_group_name'
                if is_fea_group and input_layer.has_group(iname):
                    # The input names a feature group without a block of its
                    # own: synthesize an input_layer block for it.
                    logging.info('adding an input_layer block: ' + iname)
                    new_block = backbone_pb2.Block()
                    new_block.name = iname
                    input_cfg = backbone_pb2.Input()
                    input_cfg.feature_group_name = iname
                    new_block.inputs.append(input_cfg)
                    new_block.input_layer.CopyFrom(backbone_pb2.InputLayer())
                    self._name_to_blocks[iname] = new_block
                    self._dag.add_node(iname)
                    self._dag.add_edge(iname, name)
                    if iname in input_feature_groups:
                        fn = input_feature_groups[iname]
                    else:
                        fn = EnhancedInputLayer(self._input_layer,
                                                self._features, iname)
                        input_feature_groups[iname] = fn
                    self._name_to_layer[iname] = fn
                elif Package.has_backbone_block(iname):
                    # The input is a block of the 'backbone' package: make
                    # this package depend on it in the backbone's DAG.
                    backbone = Package.__packages['backbone']
                    backbone._dag.add_node_if_not_exists(self._config.name)
                    backbone._dag.add_edge(iname, self._config.name)
                    num_pkg_input += 1
                else:
                    raise KeyError(
                        'invalid input name `%s`, must be the name of either a feature group or an another block'
                        % iname)
    num_groups = len(input_feature_groups)
    assert num_pkg_input > 0 or num_groups > 0, 'there must be at least one input layer/feature group'

    # Without explicit outputs, fall back to concatenating all DAG leaves.
    if len(config.concat_blocks) == 0 and len(config.output_blocks) == 0:
        leaf = self._dag.all_leaves()
        logging.warning(
            '%s has no `concat_blocks` or `output_blocks`, try to concat all leaf blocks: %s'
            % (config.name, ','.join(leaf)))
        self._config.concat_blocks.extend(leaf)

    Package.__packages[self._config.name] = self
    logging.info('%s layers: %s' %
                 (config.name, ','.join(self._name_to_layer.keys())))
185
+
186
+ def define_layers(self, layer, layer_cnf, name, reuse):
187
+ if layer == 'keras_layer':
188
+ layer_obj = self.load_keras_layer(layer_cnf.keras_layer, name, reuse)
189
+ self._name_to_layer[name] = layer_obj
190
+ elif layer == 'recurrent':
191
+ keras_layer = layer_cnf.recurrent.keras_layer
192
+ for i in range(layer_cnf.recurrent.num_steps):
193
+ name_i = '%s_%d' % (name, i)
194
+ layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
195
+ self._name_to_layer[name_i] = layer_obj
196
+ elif layer == 'repeat':
197
+ keras_layer = layer_cnf.repeat.keras_layer
198
+ for i in range(layer_cnf.repeat.num_repeat):
199
+ name_i = '%s_%d' % (name, i)
200
+ layer_obj = self.load_keras_layer(keras_layer, name_i, reuse)
201
+ self._name_to_layer[name_i] = layer_obj
202
+
203
+ def reset_input_config(self, config):
204
+ self.input_config = config
205
+
206
+ def set_package_input(self, pkg_input):
207
+ self._package_input = pkg_input
208
+
209
+ def has_block(self, name):
210
+ return name in self._name_to_blocks
211
+
212
+ def block_outputs(self, name):
213
+ return self._block_outputs.get(name, None)
214
+
215
def block_input(self, config, block_outputs, training=None, **kwargs):
    """Resolve, transform and merge the inputs declared by a block.

    Each entry of `config.inputs` is resolved from one of: this package's
    package input, the result of calling another registered package, the
    outputs already present in `block_outputs`, or the outputs of the
    'backbone' package. Optional `input_slice` / `input_fn` expressions are
    then applied to the resolved feature.

    Args:
      config: block config; `inputs`, `name`, `merge_inputs_into_list`,
        `input_concat_axis` and `extra_input_fn` are read here.
      block_outputs: mapping from block name to its already computed output.
      training: passed on to the packages that get called.
      **kwargs: forwarded to the package referenced by `package_name`.

    Returns:
      The list of input features when `config.merge_inputs_into_list` is
      set, otherwise the result of `merge_inputs`; in both cases passed
      through `config.extra_input_fn` when that field is set.

    Raises:
      KeyError: a referenced package does not exist, or an input resolves to
        None.
      ValueError: re-raised from `merge_inputs` after being logged.
    """
    inputs = []
    for input_node in config.inputs:
        input_type = input_node.WhichOneof('name')
        input_name = getattr(input_node, input_type)
        if input_type == 'use_package_input':
            input_feature = self._package_input
            # Name used only for the error message below.
            input_name = 'package_input'
        elif input_type == 'package_name':
            if input_name not in Package.__packages:
                raise KeyError('package name `%s` does not exists' % input_name)
            package = Package.__packages[input_name]
            if input_node.HasField('reset_input'):
                package.reset_input_config(input_node.reset_input)
            if input_node.HasField('package_input'):
                pkg_input_name = input_node.package_input
                # Prefer an output of this package; otherwise call the
                # package of that name to produce the input.
                if pkg_input_name in block_outputs:
                    pkg_input = block_outputs[pkg_input_name]
                else:
                    if pkg_input_name not in Package.__packages:
                        raise KeyError('package name `%s` does not exists' %
                                       pkg_input_name)
                    inner_package = Package.__packages[pkg_input_name]
                    pkg_input = inner_package(training)
                if input_node.HasField('package_input_fn'):
                    # NOTE(review): eval runs an expression taken from the
                    # config - only safe when configs are trusted.
                    fn = eval(input_node.package_input_fn)
                    pkg_input = fn(pkg_input)
                package.set_package_input(pkg_input)
            input_feature = package(training, **kwargs)
        elif input_name in block_outputs:
            input_feature = block_outputs[input_name]
        else:
            # Fall back to a block of the 'backbone' package.
            input_feature = Package.backbone_block_outputs(input_name)

        if input_feature is None:
            raise KeyError('input name `%s` does not exists' % input_name)

        # The input is still resolved (and validated) above before being
        # dropped here.
        if input_node.ignore_input:
            continue
        if input_node.HasField('input_slice'):
            # input_slice is appended to `x`, e.g. an indexing expression.
            fn = eval('lambda x: x' + input_node.input_slice.strip())
            input_feature = fn(input_feature)
        if input_node.HasField('input_fn'):
            with tf.name_scope(config.name):
                fn = eval(input_node.input_fn)
                input_feature = fn(input_feature)
        inputs.append(input_feature)

    if config.merge_inputs_into_list:
        output = inputs
    else:
        try:
            output = merge_inputs(inputs, config.input_concat_axis, config.name)
        except ValueError as e:
            # Exceptions may or may not carry a `message` attribute.
            msg = getattr(e, 'message', str(e))
            logging.error('merge inputs of block %s failed: %s', config.name, msg)
            raise e

    if config.HasField('extra_input_fn'):
        fn = eval(config.extra_input_fn)
        output = fn(output)
    return output
277
+
278
def __call__(self, is_training, **kwargs):
    """Run `call` inside a name scope named after this package's config."""
    scope_name = self._config.name
    with tf.name_scope(scope_name):
        result = self.call(is_training, **kwargs)
    return result
281
+
282
def call(self, is_training, **kwargs):
    """Evaluate all blocks in topological order and build the package output.

    Args:
      is_training: forwarded to input layers, `block_input` and `call_layer`.
      **kwargs: forwarded to `block_input` and `call_layer`.

    Returns:
      The list of outputs of `output_blocks` when that list is non-empty,
      otherwise the result of merging the outputs of `concat_blocks`.

    Raises:
      ValueError: a configured output block produced no output, or merging
        the outputs failed.
    """
    results = {}
    self._block_outputs = results  # reset
    order = self._dag.topological_sort()
    logging.info(self._config.name + ' topological order: ' + ','.join(order))

    for block in order:
        if block not in self._name_to_blocks:
            # DAG nodes that are not blocks must be other packages.
            assert block in Package.__packages, 'invalid block: ' + block
            continue
        config = self._name_to_blocks[block]

        if config.layers:  # sequential layers
            logging.info('call sequential %d layers' % len(config.layers))
            output = self.block_input(config, results, is_training, **kwargs)
            for i, layer in enumerate(config.layers):
                name_i = '%s_l%d' % (block, i)
                output = self.call_layer(output, layer, name_i, is_training,
                                         **kwargs)
            results[block] = output
            continue

        # just one of layer
        layer = config.WhichOneof('layer')
        if layer is None:  # identity layer
            results[block] = self.block_input(config, results, is_training,
                                              **kwargs)
            continue
        if layer == 'raw_input':
            results[block] = self._name_to_layer[block]
            continue
        if layer == 'input_layer':
            input_fn = self._name_to_layer[block]
            input_config = config.input_layer
            if self.input_config is not None:
                # A package-level override replaces the block's own config.
                input_config = self.input_config
                input_fn.reset(input_config, is_training)
            results[block] = input_fn(input_config, is_training)
            continue
        if layer == 'embedding_layer':
            input_fn = self._name_to_layer[block]
            feature_group = config.inputs[0].feature_group_name
            inputs, _, weights = self._feature_group_inputs[feature_group]
            results[block] = input_fn([inputs, weights], is_training)
            continue

        with tf.name_scope(block + '_input'):
            inputs = self.block_input(config, results, is_training, **kwargs)
        results[block] = self.call_layer(inputs, config, block, is_training,
                                         **kwargs)

    def collect(names):
        # Look up the output of every named block, failing on a missing one.
        picked = []
        for block_name in names:
            if block_name not in results:
                raise ValueError('No output `%s` of backbone to be concat' %
                                 block_name)
            picked.append(results[block_name])
        return picked

    outputs = collect(self._config.output_blocks)
    if outputs:
        return outputs

    outputs = collect(self._config.concat_blocks)
    try:
        return merge_inputs(outputs, msg='backbone')
    except ValueError as e:
        msg = getattr(e, 'message', str(e))
        logging.error("merge backbone's output failed: %s", msg)
        raise e
349
+
350
def load_keras_layer(self, layer_conf, name, reuse=None):
  """Instantiate the Keras layer class named by `layer_conf.class_name`.

  Args:
    layer_conf: layer config exposing `class_name`, the `params` oneof and
      `st_params`.
    name: name passed to the layer constructor.
    reuse: forwarded as `reuse=` to customized layers whose `__init__`
      declares such a parameter.

  Returns:
    A `(layer, customize)` tuple, where `customize` is the flag returned by
    the module-level `load_keras_layer` lookup.

  Raises:
    ValueError: if the class name cannot be resolved.
  """
  layer_cls, customize = load_keras_layer(layer_conf.class_name)
  if layer_cls is None:
    raise ValueError('Invalid keras layer class name: ' +
                     layer_conf.class_name)

  param_type = layer_conf.WhichOneof('params')
  if customize:
    if param_type is None or param_type == 'st_params':
      params = Parameter(layer_conf.st_params, True, l2_reg=self._l2_reg)
    else:
      pb_params = getattr(layer_conf, param_type)
      params = Parameter(pb_params, False, l2_reg=self._l2_reg)

    # Find a `signature` implementation: the funcsigs backports keep their
    # original priority, the standard library is the last resort (it only
    # provides `inspect.signature` on Python 3).
    try:
      from funcsigs import signature
    except ImportError:
      try:
        from sklearn.externals.funcsigs import signature
      except ImportError:
        import inspect
        signature = getattr(inspect, 'signature', None)

    if signature is None:
      # Nothing can introspect the constructor: keep the historical
      # behaviour of assuming that `reuse` is accepted.
      logging.warning('import funcsigs failed')
      has_reuse = True
    else:
      has_reuse = 'reuse' in signature(layer_cls.__init__).parameters

    if has_reuse:
      layer = layer_cls(params, name=name, reuse=reuse)
    else:
      layer = layer_cls(params, name=name)
    return layer, customize

  if param_type is None:  # internal keras layer
    layer = layer_cls(name=name)
    return layer, customize

  assert param_type == 'st_params', 'internal keras layer only support st_params'
  try:
    kwargs = convert_to_dict(layer_conf.st_params)
    logging.info('call %s layer with params %r' %
                 (layer_conf.class_name, kwargs))
    layer = layer_cls(name=name, **kwargs)
  except TypeError as e:
    # Keyword construction was rejected: retry with the values passed
    # positionally.
    logging.warning(e)
    # Build a real list so that the log line shows the values instead of a
    # lazy `map` object on Python 3.
    args = [format_value(v) for v in layer_conf.st_params.values()]
    logging.info('try to call %s layer with params %r' %
                 (layer_conf.class_name, args))
    layer = layer_cls(*args, name=name)
  return layer, customize
399
+
400
def call_keras_layer(self, inputs, name, training, **kwargs):
  """Invoke the previously built Keras layer registered under `name`.

  Customized layers receive `training` and every extra keyword argument;
  failures are logged and re-raised. Other layers are first called with
  `training` only and, if that raises TypeError, called again with the
  inputs alone.
  """
  keras_layer, is_customized = self._name_to_layer[name]
  class_name = keras_layer.__class__.__name__

  if not is_customized:
    try:
      result = keras_layer(inputs, training=training)
      if class_name == 'BatchNormalization':
        # Expose the layer's update ops through the UPDATE_OPS collection.
        add_elements_to_collection(keras_layer.updates,
                                   tf.GraphKeys.UPDATE_OPS)
    except TypeError:
      result = keras_layer(inputs)
    return result

  try:
    return keras_layer(inputs, training=training, **kwargs)
  except Exception as e:
    reason = getattr(e, 'message', str(e))
    logging.error('call keras layer %s (%s) failed: %s' %
                  (name, class_name, reason))
    raise e
419
+
420
def call_layer(self, inputs, config, name, training, **kwargs):
  """Run `inputs` through the layer kind selected by `config`'s `layer` oneof.

  Supported kinds:
    * `keras_layer`: a single call of the layer registered under `name`.
    * `lambda`: the configured expression is evaluated and applied.
    * `repeat`: the layers `<name>_<i>` are each applied to the (optionally
      sliced / transformed) inputs; the outputs are returned as a list, as a
      single value when there is only one, or concatenated when
      `output_concat_axis` is set.
    * `recurrent`: the layers `<name>_<i>` are applied one after another,
      each consuming the previous output. With `fixed_input_index` set, the
      input at that position is kept unchanged across steps and removed from
      the final result.

  Args:
    inputs: tensor or list/tuple of tensors fed to the layer.
    config: block or layer config exposing the `layer` oneof.
    name: base name used to look up the predefined keras layers.
    training: forwarded to the keras layers.
    **kwargs: forwarded to `call_keras_layer`.

  Returns:
    The layer output: a single value or a list, depending on the layer kind.

  Raises:
    NotImplementedError: if the configured layer kind is not handled here.
  """
  layer_name = config.WhichOneof('layer')
  if layer_name == 'keras_layer':
    return self.call_keras_layer(inputs, name, training, **kwargs)

  if layer_name == 'lambda':
    # `lambda` is a Python keyword, so the field needs getattr.
    conf = getattr(config, 'lambda')
    # NOTE(security): the expression is executed verbatim with eval; only
    # load backbone configs from trusted sources.
    fn = eval(conf.expression)
    return fn(inputs)

  if layer_name == 'repeat':
    conf = config.repeat
    n_loop = conf.num_repeat
    outputs = []
    for i in range(n_loop):
      name_i = '%s_%d' % (name, i)
      ly_inputs = inputs
      if conf.HasField('input_slice'):
        # NOTE(security): config text is evaluated, see the note above.
        fn = eval('lambda x, i: x' + conf.input_slice.strip())
        ly_inputs = fn(ly_inputs, i)
      if conf.HasField('input_fn'):
        with tf.name_scope(config.name):
          fn = eval(conf.input_fn)
          ly_inputs = fn(ly_inputs, i)
      output = self.call_keras_layer(ly_inputs, name_i, training, **kwargs)
      outputs.append(output)
    if len(outputs) == 1:
      return outputs[0]
    if conf.HasField('output_concat_axis'):
      return tf.concat(outputs, conf.output_concat_axis)
    return outputs

  if layer_name == 'recurrent':
    conf = config.recurrent
    fixed_input_index = -1
    if conf.HasField('fixed_input_index'):
      fixed_input_index = conf.fixed_input_index
    has_fixed_input = fixed_input_index >= 0

    output = inputs
    if has_fixed_input:
      assert type(inputs) in (tuple, list), '%s inputs must be a list' % name
      # Work on a private list: the elements are reassigned and one is
      # deleted below, which must neither alter the caller's list nor fail
      # on a tuple.
      output = list(inputs)

    for i in range(conf.num_steps):
      name_i = '%s_%d' % (name, i)
      output_i = self.call_keras_layer(output, name_i, training, **kwargs)
      if not has_fixed_input:
        output = output_i
        continue
      # Feed the step output back into every slot except the fixed one.
      j = 0
      for idx in range(len(output)):
        if idx == fixed_input_index:
          continue
        if type(output_i) in (tuple, list):
          output[idx] = output_i[j]
        else:
          output[idx] = output_i
        j += 1

    if has_fixed_input:
      del output[fixed_input_index]
      if len(output) == 1:
        return output[0]
    return output

  # str() keeps the intended error when no oneof field is set (None).
  raise NotImplementedError('Unsupported backbone layer:' + str(layer_name))
480
+
481
+
482
class Backbone(object):
  """Configurable Backbone Network."""

  def __init__(self, config, features, input_layer, l2_reg=None):
    """Build the main package from `config` plus every extra package."""
    self._config = config
    self._l2_reg = l2_reg

    # The top-level blocks are wrapped into one package named 'backbone'.
    root_pkg = backbone_pb2.BlockPackage()
    root_pkg.name = 'backbone'
    root_pkg.blocks.MergeFrom(config.blocks)
    if config.concat_blocks:
      root_pkg.concat_blocks.extend(config.concat_blocks)
    if config.output_blocks:
      root_pkg.output_blocks.extend(config.output_blocks)
    self._main_pkg = Package(root_pkg, features, input_layer, l2_reg)

    # The extra packages are constructed but not kept here; presumably the
    # Package constructor registers them for lookup by name - verify in
    # Package.
    for extra_pkg in config.packages:
      Package(extra_pkg, features, input_layer, l2_reg)

  def __call__(self, is_training, **kwargs):
    """Run the main package, then the optional `top_mlp` on its output."""
    result = self._main_pkg(is_training, **kwargs)
    if not self._config.HasField('top_mlp'):
      return result

    mlp_params = Parameter.make_from_pb(self._config.top_mlp)
    mlp_params.l2_regularizer = self._l2_reg
    top_mlp = MLP(mlp_params, name='backbone_top_mlp')
    if type(result) in (list, tuple):
      # Several outputs are joined along the last axis before the MLP.
      result = tf.concat(result, axis=-1)
    return top_mlp(result, training=is_training, **kwargs)

  @classmethod
  def wide_embed_dim(cls, config):
    """Return the `wide_output_dim` found in the packages and main blocks."""
    dim = None
    for sub_pkg in config.packages:
      dim = get_wide_embed_dim(sub_pkg.blocks, dim)
    return get_wide_embed_dim(config.blocks, dim)
517
+
518
+
519
def get_wide_embed_dim(blocks, wide_embed_dim=None):
  """Return the `wide_output_dim` declared by the input-layer blocks.

  Args:
    blocks: iterable of block configs.
    wide_embed_dim: dimension already found elsewhere, if any.

  Returns:
    The dimension found, or `wide_embed_dim` unchanged when no input layer
    sets `wide_output_dim`.

  Raises:
    AssertionError: if a block declares a dimension that differs from the
      one already known.
  """
  for blk in blocks:
    if blk.WhichOneof('layer') != 'input_layer':
      continue
    if not blk.input_layer.HasField('wide_output_dim'):
      continue
    found_dim = blk.input_layer.wide_output_dim
    if not wide_embed_dim:
      wide_embed_dim = found_dim
      continue
    assert wide_embed_dim == found_dim, 'wide_output_dim must be consistent'
  return wide_embed_dim
530
+
531
+
532
def merge_inputs(inputs, axis=-1, msg=''):
  """Combine several block inputs into one value.

  A single input is returned as is. If any input is a list, the inputs are
  flattened into one list (non-list inputs are appended as elements);
  otherwise the inputs are concatenated with `tf.concat` along `axis`.

  Args:
    inputs: sequence of inputs to merge.
    axis: axis used when the inputs are concatenated.
    msg: context text used in log and error messages.

  Raises:
    ValueError: if `inputs` is empty.
  """
  num_inputs = len(inputs)
  if num_inputs == 0:
    raise ValueError('no inputs to be concat:' + msg)
  if num_inputs == 1:
    return inputs[0]

  is_list = [type(item) == list for item in inputs]
  if any(is_list):
    if not all(is_list):
      logging.warning('%s: try to merge inputs into list' % msg)
    # Always build a new list; the input lists are left untouched.
    merged = []
    for item, item_is_list in zip(inputs, is_list):
      if item_is_list:
        merged.extend(item)
      else:
        merged.append(item)
    return merged

  if axis != -1:
    logging.info('concat inputs %s axis=%d' % (msg, axis))
  return tf.concat(inputs, axis=axis)
551
+
552
+
553
def format_value(value):
  """Convert one `Struct` value into a plain Python value.

  * floats holding a whole number become `int`; other floats, including
    inf and nan, are returned unchanged;
  * text becomes `str`;
  * `ListValue` becomes a list of converted elements;
  * `Struct` becomes a dict (see `convert_to_dict`);
  * anything else is returned as is.
  """
  value_type = type(value)
  if value_type == float:
    # is_integer() is False for inf/nan, which int() would refuse to
    # convert.
    return int(value) if value.is_integer() else value
  if value_type == six.text_type:
    return str(value)
  if value_type == struct_pb2.ListValue:
    # A real list: on Python 3 `map` would hand callers a one-shot iterator.
    return [format_value(item) for item in value]
  if value_type == struct_pb2.Struct:
    return convert_to_dict(value)
  return value
565
+
566
+
567
def convert_to_dict(struct):
  """Return a dict of `struct`'s items with `str` keys and formatted values."""
  return {str(field): format_value(raw) for field, raw in struct.items()}