easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,739 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import re
9
+ import time
10
+ from collections import OrderedDict
11
+
12
+ import tensorflow as tf
13
+ from tensorflow.python.client import session as tf_session
14
+ from tensorflow.python.eager import context
15
+ from tensorflow.python.framework import ops
16
+ from tensorflow.python.ops import variables
17
+ from tensorflow.python.platform import gfile
18
+ from tensorflow.python.saved_model import signature_constants
19
+ from tensorflow.python.training import basic_session_run_hooks
20
+ from tensorflow.python.training import saver
21
+
22
+ from easy_rec.python.builders import optimizer_builder
23
+ from easy_rec.python.compat import optimizers
24
+ from easy_rec.python.compat import sync_replicas_optimizer
25
+ from easy_rec.python.compat.early_stopping import custom_early_stop_hook
26
+ from easy_rec.python.compat.early_stopping import deadline_stop_hook
27
+ from easy_rec.python.compat.early_stopping import find_early_stop_var
28
+ from easy_rec.python.compat.early_stopping import oss_stop_hook
29
+ from easy_rec.python.compat.early_stopping import stop_if_no_decrease_hook
30
+ from easy_rec.python.compat.early_stopping import stop_if_no_increase_hook
31
+ from easy_rec.python.compat.ops import GraphKeys
32
+ from easy_rec.python.input.input import Input
33
+ from easy_rec.python.layers.utils import _tensor_to_tensorinfo
34
+ from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
35
+ from easy_rec.python.protos.train_pb2 import DistributionStrategy
36
+ from easy_rec.python.utils import constant
37
+ from easy_rec.python.utils import embedding_utils
38
+ from easy_rec.python.utils import estimator_utils
39
+ from easy_rec.python.utils import hvd_utils
40
+ from easy_rec.python.utils import pai_util
41
+ from easy_rec.python.utils.multi_optimizer import MultiOptimizer
42
+
43
+ from easy_rec.python.compat.embedding_parallel_saver import EmbeddingParallelSaver # NOQA
44
+
45
+ try:
46
+ import horovod.tensorflow as hvd
47
+ except Exception:
48
+ hvd = None
49
+
50
+ try:
51
+ from sparse_operation_kit import experiment as sok
52
+ from easy_rec.python.compat import sok_optimizer
53
+ except Exception:
54
+ sok = None
55
+
56
# Compare the numeric major version. A plain string comparison orders
# versions lexicographically, so e.g. '10.0' would sort below '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    # Rebind the module-level `tf` name to the TF1 compatibility API.
    tf = tf.compat.v1

# Replace Estimator's member-override check with a pass-through; the
# EasyRecEstimator subclass below overrides evaluate() and train().
tf.estimator.Estimator._assert_members_are_not_overridden = lambda x: x
60
+
61
+
62
+ class EasyRecEstimator(tf.estimator.Estimator):
63
+
64
def __init__(self, pipeline_config, model_cls, run_config, params):
    """Keep the pipeline config and model class, then init the base class.

    Args:
      pipeline_config: must be an EasyRecConfig; its model_dir is handed
        to the base class as the model directory.
      model_cls: stored on the instance; called elsewhere in this class to
        construct the model.
      run_config: forwarded to the base class as `config`.
      params: forwarded to the base class as `params`.
    """
    self._pipeline_config = pipeline_config
    self._model_cls = model_cls
    assert isinstance(self._pipeline_config, EasyRecConfig)

    base_kwargs = dict(
        model_fn=self._model_fn,
        model_dir=pipeline_config.model_dir,
        config=run_config,
        params=params)
    super(EasyRecEstimator, self).__init__(**base_kwargs)
74
+
75
def evaluate(self,
             input_fn,
             steps=None,
             hooks=None,
             checkpoint_path=None,
             name=None):
    """Restore the input creator's state, then run the base evaluate.

    Args:
      input_fn: input function; must expose an `input_creator` attribute
        with a `restore(checkpoint_path)` method.
      steps, hooks, checkpoint_path, name: forwarded positionally to the
        base class evaluate().

    Returns:
      Whatever the base class evaluate() returns.
    """
    # support for datahub/kafka offset restore
    input_creator = input_fn.input_creator
    input_creator.restore(checkpoint_path)

    base = super(EasyRecEstimator, self)
    return base.evaluate(input_fn, steps, hooks, checkpoint_path, name)
85
+
86
def train(self,
          input_fn,
          hooks=None,
          steps=None,
          max_steps=None,
          saving_listeners=None):
    """Restore the input creator's state, then run the base train.

    The state is restored from the latest checkpoint under model_dir when
    one exists; otherwise from train_config.fine_tune_checkpoint when that
    field is set. If fine_tune_checkpoint is a directory, it is resolved to
    the latest checkpoint inside it and the resolved path is written back
    to train_config.

    Args:
      input_fn: input function; must expose an `input_creator` attribute
        with a `restore(checkpoint_path)` method.
      hooks, steps, max_steps, saving_listeners: forwarded positionally to
        the base class train().

    Returns:
      Whatever the base class train() returns.

    Raises:
      ValueError: if fine_tune_checkpoint is a directory that holds no
        checkpoint.
    """
    # support for datahub/kafka offset restore
    checkpoint_path = estimator_utils.latest_checkpoint(self.model_dir)
    if checkpoint_path is not None:
        input_fn.input_creator.restore(checkpoint_path)
    elif self.train_config.HasField('fine_tune_checkpoint'):
        fine_tune_ckpt = self.train_config.fine_tune_checkpoint
        if fine_tune_ckpt.endswith('/') or gfile.IsDirectory(
                fine_tune_ckpt + '/'):
            ckpt_dir = fine_tune_ckpt
            fine_tune_ckpt = estimator_utils.latest_checkpoint(ckpt_dir)
            if fine_tune_ckpt is None:
                # Fail with a clear message instead of assigning None to
                # the config field and passing None on to restore().
                raise ValueError(
                    'fine_tune_checkpoint[%s] is directory, but no checkpoint '
                    'is found in it' % ckpt_dir)
            logging.info(
                'fine_tune_checkpoint[%s] is directory, will use the latest checkpoint: %s',
                ckpt_dir, fine_tune_ckpt)
            self.train_config.fine_tune_checkpoint = fine_tune_ckpt
        input_fn.input_creator.restore(fine_tune_ckpt)
    return super(EasyRecEstimator, self).train(input_fn, hooks, steps,
                                               max_steps, saving_listeners)
108
+
109
@property
def feature_configs(self):
    """Return the feature configs of the pipeline.

    Prefers the pipeline-level `feature_configs` list when it is non-empty,
    and otherwise falls back to `feature_config.features`.

    Raises:
      AssertionError: if neither location holds any feature config.
    """
    pipeline_config = self._pipeline_config
    if len(pipeline_config.feature_configs) > 0:
        return pipeline_config.feature_configs
    if pipeline_config.feature_config and len(
            pipeline_config.feature_config.features) > 0:
        return pipeline_config.feature_config.features
    # Raise explicitly: `assert False` is stripped under `python -O`, which
    # would make this property silently return None.
    raise AssertionError(
        'One of feature_configs and feature_config.features must be configured.'
    )
118
+
119
@property
def model_config(self):
    """Return the `model_config` section of the pipeline config."""
    pipeline_config = self._pipeline_config
    return pipeline_config.model_config
122
+
123
@property
def eval_config(self):
    """Return the `eval_config` section of the pipeline config."""
    pipeline_config = self._pipeline_config
    return pipeline_config.eval_config
126
+
127
@property
def train_config(self):
    """Return the `train_config` section of the pipeline config."""
    pipeline_config = self._pipeline_config
    return pipeline_config.train_config
130
+
131
@property
def incr_save_config(self):
    """Return train_config.incr_save_config, or None when it is not set."""
    train_cfg = self.train_config
    if train_cfg.HasField('incr_save_config'):
        return train_cfg.incr_save_config
    return None
135
+
136
@property
def export_config(self):
    """Return the `export_config` section of the pipeline config."""
    pipeline_config = self._pipeline_config
    return pipeline_config.export_config
139
+
140
@property
def embedding_parallel(self):
    """Tell whether train_distribute is one of the embedding-parallel modes.

    True when train_config.train_distribute is SokStrategy or
    EmbeddingParallelStrategy, False otherwise.
    """
    parallel_strategies = (
        DistributionStrategy.SokStrategy,
        DistributionStrategy.EmbeddingParallelStrategy,
    )
    return self.train_config.train_distribute in parallel_strategies
145
+
146
@property
def saver_cls(self):
    """Pick the saver class matching the distribution strategy.

    Returns EmbeddingParallelSaver when embedding parallel is enabled, and
    the standard saver.Saver otherwise.
    """
    # when embedding parallel is used, will use the extended
    # saver class (EmbeddingParallelSaver) to save sharded embedding
    if self.embedding_parallel:
        return EmbeddingParallelSaver
    return saver.Saver
154
+
155
+ def _train_model_fn(self, features, labels, run_config):
156
+ tf.keras.backend.set_learning_phase(1)
157
+ model = self._model_cls(
158
+ self.model_config,
159
+ self.feature_configs,
160
+ features,
161
+ labels,
162
+ is_training=True)
163
+ predict_dict = model.build_predict_graph()
164
+ loss_dict = model.build_loss_graph()
165
+
166
+ regularization_losses = tf.get_collection(
167
+ tf.GraphKeys.REGULARIZATION_LOSSES)
168
+ if regularization_losses:
169
+ regularization_losses = [
170
+ reg_loss.get() if hasattr(reg_loss, 'get') else reg_loss
171
+ for reg_loss in regularization_losses
172
+ ]
173
+ regularization_losses = tf.add_n(
174
+ regularization_losses, name='regularization_loss')
175
+ loss_dict['regularization_loss'] = regularization_losses
176
+
177
+ variational_dropout_loss = tf.get_collection('variational_dropout_loss')
178
+ if variational_dropout_loss:
179
+ variational_dropout_loss = tf.add_n(
180
+ variational_dropout_loss, name='variational_dropout_loss')
181
+ loss_dict['variational_dropout_loss'] = variational_dropout_loss
182
+
183
+ loss = tf.add_n(list(loss_dict.values()))
184
+ loss_dict['total_loss'] = loss
185
+ for key in loss_dict:
186
+ tf.summary.scalar(key, loss_dict[key], family='loss')
187
+
188
+ if Input.DATA_OFFSET in features:
189
+ task_index, task_num = estimator_utils.get_task_index_and_num()
190
+ data_offset_var = tf.get_variable(
191
+ name=Input.DATA_OFFSET,
192
+ dtype=tf.string,
193
+ shape=[task_num],
194
+ collections=[tf.GraphKeys.GLOBAL_VARIABLES, Input.DATA_OFFSET],
195
+ trainable=False)
196
+ update_offset = tf.assign(data_offset_var[task_index],
197
+ features[Input.DATA_OFFSET])
198
+ ops.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_offset)
199
+ else:
200
+ data_offset_var = None
201
+
202
+ # update op, usually used for batch-norm
203
+ update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
204
+ if update_ops:
205
+ # register for increment update, such as batchnorm moving_mean and moving_variance
206
+ global_vars = {x.name: x for x in tf.global_variables()}
207
+ for x in update_ops:
208
+ if isinstance(x, ops.Operation) and x.inputs[0].name in global_vars:
209
+ ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
210
+ global_vars[x.inputs[0].name])
211
+ update_op = tf.group(*update_ops, name='update_barrier')
212
+ with tf.control_dependencies([update_op]):
213
+ loss = tf.identity(loss, name='total_loss')
214
+
215
+ # build optimizer
216
+ if len(self.train_config.optimizer_config) == 1:
217
+ optimizer_config = self.train_config.optimizer_config[0]
218
+ optimizer, learning_rate = optimizer_builder.build(optimizer_config)
219
+ tf.summary.scalar('learning_rate', learning_rate[0])
220
+ else:
221
+ optimizer_config = self.train_config.optimizer_config
222
+ all_opts = []
223
+ for opti_id, tmp_config in enumerate(optimizer_config):
224
+ with tf.name_scope('optimizer_%d' % opti_id):
225
+ opt, learning_rate = optimizer_builder.build(tmp_config)
226
+ tf.summary.scalar('learning_rate', learning_rate[0])
227
+ all_opts.append(opt)
228
+ grouped_vars = model.get_grouped_vars(len(all_opts))
229
+ assert len(grouped_vars) == len(optimizer_config), \
230
+ 'the number of var group(%d) != the number of optimizers(%d)' \
231
+ % (len(grouped_vars), len(optimizer_config))
232
+ optimizer = MultiOptimizer(all_opts, grouped_vars)
233
+
234
+ if self.train_config.train_distribute == DistributionStrategy.SokStrategy:
235
+ optimizer = sok_optimizer.OptimizerWrapper(optimizer)
236
+
237
+ hooks = []
238
+ if estimator_utils.has_hvd():
239
+ assert not self.train_config.sync_replicas, \
240
+ 'sync_replicas should not be set when using horovod'
241
+ bcast_hook = hvd_utils.BroadcastGlobalVariablesHook(0)
242
+ hooks.append(bcast_hook)
243
+
244
+ # for distributed and synced training
245
+ if self.train_config.sync_replicas and run_config.num_worker_replicas > 1:
246
+ logging.info('sync_replicas: num_worker_replias = %d' %
247
+ run_config.num_worker_replicas)
248
+ if pai_util.is_on_pai():
249
+ optimizer = tf.train.SyncReplicasOptimizer(
250
+ optimizer,
251
+ replicas_to_aggregate=run_config.num_worker_replicas,
252
+ total_num_replicas=run_config.num_worker_replicas,
253
+ sparse_accumulator_type=self.train_config.sparse_accumulator_type)
254
+ else:
255
+ optimizer = sync_replicas_optimizer.SyncReplicasOptimizer(
256
+ optimizer,
257
+ replicas_to_aggregate=run_config.num_worker_replicas,
258
+ total_num_replicas=run_config.num_worker_replicas)
259
+ hooks.append(
260
+ optimizer.make_session_run_hook(run_config.is_chief, num_tokens=0))
261
+
262
+ # add barrier for no strategy case
263
+ if run_config.num_worker_replicas > 1 and \
264
+ self.train_config.train_distribute == DistributionStrategy.NoStrategy:
265
+ hooks.append(
266
+ estimator_utils.ExitBarrierHook(run_config.num_worker_replicas,
267
+ run_config.is_chief, self.model_dir))
268
+
269
+ if self.export_config.enable_early_stop:
270
+ eval_dir = os.path.join(self._model_dir, 'eval_val')
271
+ logging.info('will use early stop, eval_events_dir=%s' % eval_dir)
272
+ if self.export_config.HasField('early_stop_func'):
273
+ hooks.append(
274
+ custom_early_stop_hook(
275
+ self,
276
+ eval_dir=eval_dir,
277
+ custom_stop_func=self.export_config.early_stop_func,
278
+ custom_stop_func_params=self.export_config.early_stop_params))
279
+ elif self.export_config.metric_bigger:
280
+ hooks.append(
281
+ stop_if_no_increase_hook(
282
+ self,
283
+ self.export_config.best_exporter_metric,
284
+ self.export_config.max_check_steps,
285
+ eval_dir=eval_dir))
286
+ else:
287
+ hooks.append(
288
+ stop_if_no_decrease_hook(
289
+ self,
290
+ self.export_config.best_exporter_metric,
291
+ self.export_config.max_check_steps,
292
+ eval_dir=eval_dir))
293
+
294
+ if self.train_config.enable_oss_stop_signal:
295
+ hooks.append(oss_stop_hook(self))
296
+
297
+ if self.train_config.HasField('dead_line'):
298
+ hooks.append(deadline_stop_hook(self, self.train_config.dead_line))
299
+
300
+ summaries = ['global_gradient_norm']
301
+ if self.train_config.summary_model_vars:
302
+ summaries.extend(['gradient_norm', 'gradients'])
303
+
304
+ gradient_clipping_by_norm = self.train_config.gradient_clipping_by_norm
305
+ if gradient_clipping_by_norm <= 0:
306
+ gradient_clipping_by_norm = None
307
+
308
+ gradient_multipliers = None
309
+ if self.train_config.optimizer_config[0].HasField(
310
+ 'embedding_learning_rate_multiplier'):
311
+ gradient_multipliers = {
312
+ var: self.train_config.optimizer_config[0]
313
+ .embedding_learning_rate_multiplier
314
+ for var in tf.trainable_variables()
315
+ if 'embedding_weights:' in var.name or
316
+ '/embedding_weights/part_' in var.name
317
+ }
318
+
319
+ # optimize loss
320
+ # colocate_gradients_with_ops=True means to compute gradients
321
+ # on the same device on which op is processes in forward process
322
+ all_train_vars = []
323
+ if len(self.train_config.freeze_gradient) > 0:
324
+ for one_var in tf.trainable_variables():
325
+ is_freeze = False
326
+ for x in self.train_config.freeze_gradient:
327
+ if re.search(x, one_var.name) is not None:
328
+ logging.info('will freeze gradients of %s' % one_var.name)
329
+ is_freeze = True
330
+ break
331
+ if not is_freeze:
332
+ all_train_vars.append(one_var)
333
+ else:
334
+ all_train_vars = tf.trainable_variables()
335
+
336
+ if self.embedding_parallel:
337
+ logging.info('embedding_parallel is enabled')
338
+
339
+ train_op = optimizers.optimize_loss(
340
+ loss=loss,
341
+ global_step=tf.train.get_global_step(),
342
+ learning_rate=None,
343
+ clip_gradients=gradient_clipping_by_norm,
344
+ optimizer=optimizer,
345
+ gradient_multipliers=gradient_multipliers,
346
+ variables=all_train_vars,
347
+ summaries=summaries,
348
+ colocate_gradients_with_ops=True,
349
+ not_apply_grad_after_first_step=run_config.is_chief and
350
+ self._pipeline_config.data_config.chief_redundant,
351
+ name='', # Preventing scope prefix on all variables.
352
+ incr_save=(self.incr_save_config is not None),
353
+ embedding_parallel=self.embedding_parallel)
354
+
355
+ # online evaluation
356
+ metric_update_op_dict = None
357
+ if self.eval_config.eval_online:
358
+ metric_update_op_dict = {}
359
+ metric_dict = model.build_metric_graph(self.eval_config)
360
+ for k, v in metric_dict.items():
361
+ metric_update_op_dict['%s/batch' % k] = v[1]
362
+ if isinstance(v[1], tf.Tensor):
363
+ tf.summary.scalar('%s/batch' % k, v[1])
364
+ train_op = tf.group([train_op] + list(metric_update_op_dict.values()))
365
+ if estimator_utils.is_chief():
366
+ hooks.append(
367
+ estimator_utils.OnlineEvaluationHook(
368
+ metric_dict=metric_dict, output_dir=self.model_dir))
369
+
370
+ if self.train_config.HasField('fine_tune_checkpoint'):
371
+ fine_tune_ckpt = self.train_config.fine_tune_checkpoint
372
+ logging.warning('will restore from %s' % fine_tune_ckpt)
373
+ fine_tune_ckpt_var_map = self.train_config.fine_tune_ckpt_var_map
374
+ force_restore = self.train_config.force_restore_shape_compatible
375
+ restore_hook = model.restore(
376
+ fine_tune_ckpt,
377
+ include_global_step=False,
378
+ ckpt_var_map_path=fine_tune_ckpt_var_map,
379
+ force_restore_shape_compatible=force_restore)
380
+ if restore_hook is not None:
381
+ hooks.append(restore_hook)
382
+
383
+ # logging
384
+ logging_dict = OrderedDict()
385
+ logging_dict['step'] = tf.train.get_global_step()
386
+ logging_dict['lr'] = learning_rate[0]
387
+ logging_dict.update(loss_dict)
388
+ if metric_update_op_dict is not None:
389
+ logging_dict.update(metric_update_op_dict)
390
+
391
+ log_step_count_steps = self.train_config.log_step_count_steps
392
+ logging_hook = basic_session_run_hooks.LoggingTensorHook(
393
+ logging_dict,
394
+ every_n_iter=log_step_count_steps,
395
+ formatter=estimator_utils.tensor_log_format_func)
396
+ hooks.append(logging_hook)
397
+
398
+ if self.train_config.train_distribute in [
399
+ DistributionStrategy.CollectiveAllReduceStrategy,
400
+ DistributionStrategy.MirroredStrategy,
401
+ DistributionStrategy.MultiWorkerMirroredStrategy
402
+ ]:
403
+ # for multi worker strategy, we could not replace the
404
+ # inner CheckpointSaverHook, so just use it.
405
+ scaffold = tf.train.Scaffold()
406
+ else:
407
+ var_list = (
408
+ tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
409
+ tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
410
+
411
+ # exclude data_offset_var
412
+ var_list = [x for x in var_list if x != data_offset_var]
413
+ # early_stop flag will not be saved in checkpoint
414
+ # and could not be restored from checkpoint
415
+ early_stop_var = find_early_stop_var(var_list)
416
+ var_list = [x for x in var_list if x != early_stop_var]
417
+
418
+ initialize_var_list = [
419
+ x for x in var_list if 'WorkQueue' not in str(type(x))
420
+ ]
421
+
422
+ # incompatible shape restore will not be saved in checkpoint
423
+ # but must be able to restore from checkpoint
424
+ incompatiable_shape_restore = tf.get_collection('T_E_M_P_RESTROE')
425
+
426
+ local_init_ops = [tf.train.Scaffold.default_local_init_op()]
427
+ if data_offset_var is not None and estimator_utils.is_chief():
428
+ local_init_ops.append(tf.initializers.variables([data_offset_var]))
429
+ if early_stop_var is not None and estimator_utils.is_chief():
430
+ local_init_ops.append(tf.initializers.variables([early_stop_var]))
431
+ if len(incompatiable_shape_restore) > 0:
432
+ local_init_ops.append(
433
+ tf.initializers.variables(incompatiable_shape_restore))
434
+
435
+ scaffold = tf.train.Scaffold(
436
+ saver=self.saver_cls(
437
+ var_list=var_list,
438
+ sharded=True,
439
+ max_to_keep=self.train_config.keep_checkpoint_max,
440
+ save_relative_paths=True),
441
+ local_init_op=tf.group(local_init_ops),
442
+ ready_for_local_init_op=tf.report_uninitialized_variables(
443
+ var_list=initialize_var_list))
444
+ # saver hook
445
+ saver_hook = estimator_utils.CheckpointSaverHook(
446
+ checkpoint_dir=self.model_dir,
447
+ save_secs=self._config.save_checkpoints_secs,
448
+ save_steps=self._config.save_checkpoints_steps,
449
+ scaffold=scaffold,
450
+ write_graph=self.train_config.write_graph,
451
+ data_offset_var=data_offset_var,
452
+ increment_save_config=self.incr_save_config)
453
+ if estimator_utils.is_chief() or self.embedding_parallel:
454
+ hooks.append(saver_hook)
455
+ if estimator_utils.is_chief():
456
+ hooks.append(
457
+ basic_session_run_hooks.StepCounterHook(
458
+ every_n_steps=log_step_count_steps, output_dir=self.model_dir))
459
+
460
+ # profiling hook
461
+ if self.train_config.is_profiling and estimator_utils.is_chief():
462
+ profile_hook = tf.train.ProfilerHook(
463
+ save_steps=log_step_count_steps, output_dir=self.model_dir)
464
+ hooks.append(profile_hook)
465
+
466
+ return tf.estimator.EstimatorSpec(
467
+ mode=tf.estimator.ModeKeys.TRAIN,
468
+ loss=loss,
469
+ predictions=predict_dict,
470
+ train_op=train_op,
471
+ scaffold=scaffold,
472
+ training_hooks=hooks)
473
+
474
def _eval_model_fn(self, features, labels, run_config):
    """Build the EstimatorSpec used in EVAL mode.

    Args:
      features: input features handed to the model constructor.
      labels: labels handed to the model constructor.
      run_config: run configuration; not read by this method.

    Returns:
      A tf.estimator.EstimatorSpec (mode EVAL) holding the summed loss, the
      prediction dict, the metric ops and a scaffold with a sharded saver.
    """
    tf.keras.backend.set_learning_phase(0)
    build_begin = time.time()

    model = self._model_cls(
        self.model_config,
        self.feature_configs,
        features,
        labels,
        is_training=False)
    predict_dict = model.build_predict_graph()
    loss_dict = model.build_loss_graph()
    loss = tf.add_n(list(loss_dict.values()))
    loss_dict['total_loss'] = loss

    metric_dict = model.build_metric_graph(self.eval_config)
    # Every loss (total_loss included) is also reported as a mean metric;
    # the key prefix keeps it in the same family as the train loss.
    for loss_name, loss_value in loss_dict.items():
        metric_dict['loss/loss/' + loss_name] = tf.metrics.mean(loss_value)
    tf.logging.info('metric_dict keys: %s' % metric_dict.keys())

    # Everything the saver has to handle: global variables plus any
    # explicitly registered saveable objects.
    saveables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    saveables = saveables + ops.get_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS)

    metric_variables = ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)
    model_ready_for_local_init_op = tf.variables_initializer(metric_variables)

    eval_saver = self.saver_cls(
        var_list=saveables, sharded=True, save_relative_paths=True)
    scaffold = tf.train.Scaffold(
        saver=eval_saver,
        ready_for_local_init_op=model_ready_for_local_init_op)

    elapsed = time.time() - build_begin
    tf.logging.info('eval graph construct finished. Time %.3fs' % elapsed)

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        scaffold=scaffold,
        predictions=predict_dict,
        eval_metric_ops=metric_dict)
513
+
514
def _distribute_eval_model_fn(self, features, labels, run_config):
    """Build the EVAL-mode EstimatorSpec for distributed evaluation.

    Unlike _eval_model_fn, the scaffold here uses a plain tf.train.Saver
    restricted to the global variables that are not metric variables.

    Args:
      features: input features handed to the model constructor.
      labels: labels handed to the model constructor.
      run_config: run configuration; not read by this method.

    Returns:
      A tf.estimator.EstimatorSpec (mode EVAL) holding the summed loss, the
      prediction dict, the metric ops and the scaffold described above.
    """
    tf.keras.backend.set_learning_phase(0)
    start = time.time()
    model = self._model_cls(
        self.model_config,
        self.feature_configs,
        features,
        labels,
        is_training=False)
    predict_dict = model.build_predict_graph()
    loss_dict = model.build_loss_graph()
    loss = tf.add_n(list(loss_dict.values()))
    loss_dict['total_loss'] = loss
    metric_dict = model.build_metric_graph(self.eval_config)
    for loss_key in loss_dict.keys():
        loss_tensor = loss_dict[loss_key]
        # add key-prefix to make loss metric key in the same family of train loss
        metric_dict['loss/loss/' + loss_key] = tf.metrics.mean(loss_tensor)
    tf.logging.info('metric_dict keys: %s' % metric_dict.keys())

    end = time.time()
    tf.logging.info('eval graph construct finished. Time %.3fs' % (end - start))

    # Metric variables are (re)initialized through the scaffold and are kept
    # out of the saver, so only the remaining global variables are handled
    # by the checkpoint saver.
    global_variables = tf.global_variables()
    metric_variables = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
    model_ready_for_local_init_op = tf.variables_initializer(metric_variables)
    remain_variables = list(
        set(global_variables).difference(set(metric_variables)))
    cur_saver = tf.train.Saver(var_list=remain_variables, sharded=True)
    scaffold = tf.train.Scaffold(
        saver=cur_saver, ready_for_local_init_op=model_ready_for_local_init_op)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        predictions=predict_dict,
        eval_metric_ops=metric_dict,
        scaffold=scaffold)
566
+
567
def _export_model_fn(self, features, labels, run_config, params):
    """Build the PREDICT-mode EstimatorSpec used when exporting the model.

    Besides the prediction outputs, this registers asset files (the train
    pipeline.config, user supplied assets and the fg json) in the
    ASSET_FILEPATHS collection and restores the DENSE_UPDATE_VARIABLES
    collection from the file of that name in the model directory, if present.

    Args:
      features: input features handed to the model constructor.
      labels: not used; the model is always built with labels=None.
      run_config: run configuration; not read by this method.
      params: dict-like; an optional 'asset_files' entry maps asset names to
        file paths and is used only when export_config.asset_files is empty.

    Returns:
      A tf.estimator.EstimatorSpec (mode PREDICT) with the output dict as
      predictions and as the default serving signature.
    """
    tf.keras.backend.set_learning_phase(0)
    model = self._model_cls(
        self.model_config,
        self.feature_configs,
        features,
        labels=None,
        is_training=False)
    model.build_predict_graph()

    export_config = self._pipeline_config.export_config
    outputs = {}
    logging.info('building default outputs')
    outputs.update(model.build_output_dict())
    if export_config.export_features:
        logging.info('building output features')
        outputs.update(model.build_feature_output_dict())
    if export_config.export_rtp_outputs:
        logging.info('building RTP outputs')
        outputs.update(model.build_rtp_output_dict())

    for out in outputs:
        tf.logging.info(
            'output %s shape: %s type: %s' %
            (out, outputs[out].get_shape().as_list(), outputs[out].dtype))
    export_outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.estimator.export.PredictOutput(outputs)
    }

    # save train pipeline.config for debug purpose
    pipeline_path = os.path.join(self._model_dir, 'pipeline.config')
    if gfile.Exists(pipeline_path):
        ops.add_to_collection(
            tf.GraphKeys.ASSET_FILEPATHS,
            tf.constant(pipeline_path, dtype=tf.string, name='pipeline.config'))
    else:
        print('train pipeline_path(%s) does not exist' % pipeline_path)

    # restore DENSE_UPDATE_VARIABLES collection
    dense_train_var_path = os.path.join(self.model_dir,
                                        constant.DENSE_UPDATE_VARIABLES)
    if gfile.Exists(dense_train_var_path):
        with gfile.GFile(dense_train_var_path, 'r') as fin:
            var_name_to_id_map = json.load(fin)
        # Re-add the variables in ascending id order so the collection order
        # follows the ids stored in the file.
        var_name_id_lst = [
            (x, var_name_to_id_map[x]) for x in var_name_to_id_map
        ]
        var_name_id_lst.sort(key=lambda x: x[1])
        all_vars = {x.op.name: x for x in tf.global_variables()}
        for var_name, _ in var_name_id_lst:
            assert var_name in all_vars, 'dense_train_var[%s] is not found' % var_name
            ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
                                  all_vars[var_name])

    # add more asset files
    if len(export_config.asset_files) > 0:
        for asset_file in export_config.asset_files:
            # a leading '!' is a marker, not part of the path
            if asset_file.startswith('!'):
                asset_file = asset_file[1:]
            _, asset_name = os.path.split(asset_file)
            ops.add_to_collection(
                ops.GraphKeys.ASSET_FILEPATHS,
                tf.constant(asset_file, dtype=tf.string, name=asset_name))
    elif 'asset_files' in params:
        for asset_name in params['asset_files']:
            asset_file = params['asset_files'][asset_name]
            ops.add_to_collection(
                tf.GraphKeys.ASSET_FILEPATHS,
                tf.constant(asset_file, dtype=tf.string, name=asset_name))

    if self._pipeline_config.HasField('fg_json_path'):
        fg_path = self._pipeline_config.fg_json_path
        # startswith() instead of fg_path[0]: a field that is set but empty
        # must not raise IndexError here.
        if fg_path.startswith('!'):
            fg_path = fg_path[1:]
        ops.add_to_collection(
            tf.GraphKeys.ASSET_FILEPATHS,
            tf.constant(fg_path, dtype=tf.string, name='fg.json'))

    var_list = (
        ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
        ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))

    scaffold = tf.train.Scaffold(
        saver=self.saver_cls(
            var_list=var_list, sharded=True, save_relative_paths=True))

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        loss=None,
        scaffold=scaffold,
        predictions=outputs,
        export_outputs=export_outputs)
660
+
661
def _model_fn(self, features, labels, mode, config, params):
    """Dispatch to the mode-specific model function.

    Before dispatching, the current mode and a few settings are published
    through os.environ, and, when fg_json_path is set, the RTP fg config and
    the input node information are written to their graph collections.

    Args:
      features: input features forwarded to the selected model function.
      labels: labels forwarded to the selected model function.
      mode: one of the tf.estimator.ModeKeys values.
      config: run configuration forwarded as run_config.
      params: extra parameters; forwarded only in PREDICT mode.

    Returns:
      The EstimatorSpec produced for TRAIN, EVAL or PREDICT; None for any
      other mode value.
    """
    os.environ['tf.estimator.mode'] = mode
    os.environ['tf.estimator.ModeKeys.TRAIN'] = tf.estimator.ModeKeys.TRAIN

    pipeline_config = self._pipeline_config
    if pipeline_config.feature_config.embedding_on_cpu:
        os.environ['place_embedding_on_cpu'] = 'True'
    if pipeline_config.fg_json_path:
        EasyRecEstimator._write_rtp_fg_config_to_col(
            fg_config_path=pipeline_config.fg_json_path)
        EasyRecEstimator._write_rtp_inputs_to_col(features)

    if self.embedding_parallel:
        embedding_utils.set_embedding_parallel()

    mode_keys = tf.estimator.ModeKeys
    if mode == mode_keys.TRAIN:
        return self._train_model_fn(features, labels, config)
    if mode == mode_keys.EVAL:
        return self._eval_model_fn(features, labels, config)
    if mode == mode_keys.PREDICT:
        return self._export_model_fn(features, labels, config, params)
    return None
680
+
681
@staticmethod
def _write_rtp_fg_config_to_col(fg_config=None, fg_config_path=None):
    """Write RTP config to RTP-specified graph collections.

    The config is stored as a JSON string in the first slot of the
    RANK_SERVICE_FG_CONF collection, replacing any value already there.

    Args:
      fg_config: JSON-dict RTP config. If set, fg_config_path will be ignored.
      fg_config_path: path to the RTP config file; a leading '!' is stripped
        before the file is opened.

    Raises:
      ValueError: if neither fg_config nor fg_config_path is given.
    """
    if fg_config is None:
        # Fail with a clear message instead of an AttributeError on None.
        if fg_config_path is None:
            raise ValueError(
                'either fg_config or fg_config_path must be specified')
        if fg_config_path.startswith('!'):
            fg_config_path = fg_config_path[1:]
        with gfile.GFile(fg_config_path, 'r') as f:
            fg_config = json.load(f)
    col = ops.get_collection_ref(GraphKeys.RANK_SERVICE_FG_CONF)
    if len(col) == 0:
        col.append(json.dumps(fg_config))
    else:
        col[0] = json.dumps(fg_config)
699
+
700
@staticmethod
def _write_rtp_inputs_to_col(features):
    """Write input nodes information to RTP-specified graph collections.

    Each feature is converted with _tensor_to_tensorinfo and the resulting
    name -> info mapping is stored as a JSON string in the first slot of the
    RANK_SERVICE_FEATURE_NODE collection.

    Args:
      features: the feature dictionary used as model input.
    """
    feature_info_map = {
        name: _tensor_to_tensorinfo(tensor)
        for name, tensor in features.items()
    }
    col = ops.get_collection_ref(GraphKeys.RANK_SERVICE_FEATURE_NODE)
    serialized = json.dumps(feature_info_map)
    if col:
        # overwrite the previously stored description
        col[0] = serialized
    else:
        col.append(serialized)
716
+
717
def export_checkpoint(self,
                      export_path=None,
                      serving_input_receiver_fn=None,
                      checkpoint_path=None,
                      mode=tf.estimator.ModeKeys.PREDICT):
    """Rebuild the model graph, restore a checkpoint and save it again.

    Args:
      export_path: path passed to the saver's save() call.
      serving_input_receiver_fn: callable returning an object with a
        `features` attribute (and optionally `labels`) used to build the
        graph.
      checkpoint_path: checkpoint to restore; when empty, the latest
        checkpoint found under the model directory is used.
      mode: tf.estimator.ModeKeys value handed to the model function.

    Raises:
      ValueError: if no checkpoint path is given and none can be found.
    """
    with context.graph_mode():
        # Fall back to the latest checkpoint of the model directory.
        checkpoint_path = (
            checkpoint_path or
            estimator_utils.latest_checkpoint(self._model_dir))
        if not checkpoint_path:
            raise ValueError("Couldn't find trained model at %s." % self._model_dir)

        with ops.Graph().as_default():
            receiver = serving_input_receiver_fn()
            spec = self._call_model_fn(
                features=receiver.features,
                labels=getattr(receiver, 'labels', None),
                mode=mode,
                config=self.config)
            with tf_session.Session(config=self._session_config) as sess:
                # Prefer the saver provided by the model's scaffold.
                graph_saver = spec.scaffold.saver
                if not graph_saver:
                    graph_saver = saver.Saver(sharded=True)
                graph_saver.restore(sess, checkpoint_path)
                graph_saver.save(sess, export_path)