easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,878 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ from __future__ import absolute_import
5
+ from __future__ import division
6
+ from __future__ import print_function
7
+
8
+ import json
9
+ import logging
10
+ import math
11
+ import os
12
+ import time
13
+
14
+ import six
15
+ import tensorflow as tf
16
+ from tensorflow.core.protobuf import saved_model_pb2
17
+
18
+ import easy_rec
19
+ from easy_rec.python.builders import strategy_builder
20
+ from easy_rec.python.compat import estimator_train
21
+ from easy_rec.python.compat import exporter
22
+ from easy_rec.python.input.input import Input
23
+ from easy_rec.python.model.easy_rec_estimator import EasyRecEstimator
24
+ from easy_rec.python.model.easy_rec_model import EasyRecModel
25
+ from easy_rec.python.protos.train_pb2 import DistributionStrategy
26
+ from easy_rec.python.utils import config_util
27
+ from easy_rec.python.utils import constant
28
+ from easy_rec.python.utils import estimator_utils
29
+ from easy_rec.python.utils import fg_util
30
+ from easy_rec.python.utils import load_class
31
+ from easy_rec.python.utils.config_util import get_eval_input_path
32
+ from easy_rec.python.utils.config_util import get_model_dir_path
33
+ from easy_rec.python.utils.config_util import get_train_input_path
34
+ from easy_rec.python.utils.config_util import set_eval_input_path
35
+ from easy_rec.python.utils.export_big_model import export_big_model
36
+ from easy_rec.python.utils.export_big_model import export_big_model_to_oss
37
+
38
# horovod is optional: when it cannot be imported, hvd is None and the
# horovod-specific GPU setup in _create_estimator is skipped.
try:
    import horovod.tensorflow as hvd
except Exception:
    hvd = None

# Compare the integer major version; a plain string comparison such as
# tf.__version__ >= '2.0' is lexicographic and would misclassify '10.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    from tensorflow.core.protobuf import config_pb2

    ConfigProto = config_pb2.ConfigProto
    GPUOptions = config_pb2.GPUOptions

    # from here on, `tf` refers to the TF1 compatibility API
    tf = tf.compat.v1
else:
    GPUOptions = tf.GPUOptions
    ConfigProto = tf.ConfigProto

load_class.auto_import()

# When the version of tensorflow is > 1.8, setting strip_default_attrs to true
# makes saved_model inference crash (core dump), for example:
# [libprotobuf FATAL external/protobuf_archive/src/google/protobuf/map.h:1058]
# CHECK failed: it != end(): key not found: new_axis_mask
# so the exporters from easy_rec.python.compat.exporter are used, which
# temporarily change strip_default_attrs of _SavedModelExporter in
# tf.estimator.exporter to false by default.
FinalExporter = exporter.FinalExporter
LatestExporter = exporter.LatestExporter
BestExporter = exporter.BestExporter
65
+
66
+
67
def _get_input_fn(data_config,
                  feature_configs,
                  data_path=None,
                  export_config=None,
                  check_mode=False,
                  **kwargs):
    """Create the input function consumed by the estimator.

    Args:
      data_config: dataset config; its input_type selects the Input class
      feature_configs: FeatureConfig
      data_path: input_data_path
      export_config: configuration for exporting models,
        only used to build input_fn when exporting models
      check_mode: forwarded unchanged to the Input constructor
      **kwargs: extra keyword arguments forwarded to the Input constructor

    Returns:
      the input function returned by create_input of the Input object
    """
    # InputType.items() yields (name, value) pairs; invert them so the
    # configured enum value can be mapped back to the Input class name.
    name_by_value = {
        enum_value: enum_name
        for enum_name, enum_value in data_config.InputType.items()
    }
    reader_cls = Input.create_class(name_by_value[data_config.input_type])

    worker_index, worker_count = estimator_utils.get_task_index_and_num()
    reader = reader_cls(
        data_config,
        feature_configs,
        data_path,
        task_index=worker_index,
        task_num=worker_count,
        check_mode=check_mode,
        **kwargs)
    return reader.create_input(export_config)
100
+
101
+
102
def _create_estimator(pipeline_config, distribution=None, params=None):
    """Build an EasyRecEstimator together with its RunConfig.

    Args:
      pipeline_config: pipeline config; model_config, train_config and
        model_dir are read from it
      distribution: used as both train_distribute and eval_distribute of
        the RunConfig
      params: optional dict passed to EasyRecEstimator; the key
        'log_device_placement' is read here. A new empty dict is used
        when omitted or None.

    Returns:
      a tuple (estimator, run_config)
    """
    # avoid a shared mutable default argument
    if params is None:
        params = {}

    model_config = pipeline_config.model_config
    train_config = pipeline_config.train_config
    gpu_options = GPUOptions(allow_growth=True)

    logging.info('train_config.train_distribute=%s[value=%d]',
                 DistributionStrategy.Name(train_config.train_distribute),
                 train_config.train_distribute)

    # set gpu options only under hvd scenes
    if hvd is not None and train_config.train_distribute in [
            DistributionStrategy.EmbeddingParallelStrategy,
            DistributionStrategy.SokStrategy,
            DistributionStrategy.HorovodStrategy
    ]:
        local_rnk = hvd.local_rank()
        gpus = tf.config.experimental.list_physical_devices('GPU')
        logging.info('local_rnk=%d num_gpus=%d', local_rnk, len(gpus))
        if len(gpus) > 0:
            # make only the GPU indexed by the local rank visible
            tf.config.experimental.set_visible_devices(gpus[local_rnk], 'GPU')
            gpu_options.visible_device_list = str(local_rnk)

    session_config = ConfigProto(
        gpu_options=gpu_options,
        allow_soft_placement=True,
        log_device_placement=params.get('log_device_placement', False),
        inter_op_parallelism_threads=train_config.inter_op_parallelism_threads,
        intra_op_parallelism_threads=train_config.intra_op_parallelism_threads)

    if constant.NO_ARITHMETRIC_OPTI in os.environ:
        logging.info('arithmetic_optimization is closed to improve performance')
        rewrite_options = session_config.graph_options.rewrite_options
        rewrite_options.arithmetic_optimization = rewrite_options.OFF

    session_config.device_filters.append('/job:ps')
    model_cls = EasyRecModel.create_class(model_config.model_class)

    save_checkpoints_steps = None
    save_checkpoints_secs = None
    if train_config.HasField('save_checkpoints_steps'):
        save_checkpoints_steps = train_config.save_checkpoints_steps
    if train_config.HasField('save_checkpoints_secs'):
        save_checkpoints_secs = train_config.save_checkpoints_secs
    # if both `save_checkpoints_steps` and `save_checkpoints_secs` are not set,
    # use the default value of save_checkpoints_steps
    if save_checkpoints_steps is None and save_checkpoints_secs is None:
        save_checkpoints_steps = train_config.save_checkpoints_steps

    run_config = tf.estimator.RunConfig(
        model_dir=pipeline_config.model_dir,
        # train_config.log_step_count_steps is intentionally not forwarded
        log_step_count_steps=None,
        save_summary_steps=train_config.save_summary_steps,
        save_checkpoints_steps=save_checkpoints_steps,
        save_checkpoints_secs=save_checkpoints_secs,
        keep_checkpoint_max=train_config.keep_checkpoint_max,
        train_distribute=distribution,
        eval_distribute=distribution,
        session_config=session_config)

    estimator = EasyRecEstimator(
        pipeline_config, model_cls, run_config=run_config, params=params)
    return estimator, run_config
164
+
165
+
166
def _create_eval_export_spec(pipeline_config, eval_data, check_mode=False):
    """Assemble the EvalSpec, including the exporters chosen by export_config.

    Args:
      pipeline_config: pipeline config; data_config, eval_config and
        export_config are read from it
      eval_data: data path handed to the evaluation input function
      check_mode: forwarded to the export input function

    Returns:
      a tf.estimator.EvalSpec named 'val'

    Raises:
      ValueError: if export_config.exporter_type is not one of
        'final', 'latest', 'best' or 'none'
    """
    data_config = pipeline_config.data_config
    eval_config = pipeline_config.eval_config
    export_config = pipeline_config.export_config
    feature_configs = config_util.get_compatible_feature_configs(
        pipeline_config)

    # steps stay None unless an explicit number of eval examples is configured
    eval_steps = None
    if eval_config.num_examples > 0:
        batches = float(eval_config.num_examples) / data_config.batch_size
        eval_steps = int(math.ceil(batches))
        logging.info('eval_steps = %d' % eval_steps)

    input_fn_kwargs = {'pipeline_config': pipeline_config}
    if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
        input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path

    # input function used as serving_input_receiver_fn by the exporters
    export_input_fn = _get_input_fn(
        data_config,
        feature_configs,
        data_path=None,
        export_config=export_config,
        check_mode=check_mode,
        **input_fn_kwargs)

    exporter_type = export_config.exporter_type
    if exporter_type == 'final':
        exporters = [
            FinalExporter(
                name='final', serving_input_receiver_fn=export_input_fn)
        ]
    elif exporter_type == 'latest':
        exporters = [
            LatestExporter(
                name='latest',
                serving_input_receiver_fn=export_input_fn,
                exports_to_keep=export_config.exports_to_keep)
        ]
    elif exporter_type == 'best':
        logging.info(
            'will use BestExporter, metric is %s, the bigger the better: %d' %
            (export_config.best_exporter_metric, export_config.metric_bigger))

        def _metric_cmp_fn(best_eval_result, current_eval_result):
            # True when the current result beats the best one seen so far
            logging.info('metric: best = %s current = %s' %
                         (str(best_eval_result), str(current_eval_result)))
            if export_config.metric_bigger:
                return (best_eval_result[export_config.best_exporter_metric] <
                        current_eval_result[export_config.best_exporter_metric])
            return (best_eval_result[export_config.best_exporter_metric] >
                    current_eval_result[export_config.best_exporter_metric])

        exporters = [
            BestExporter(
                name='best',
                serving_input_receiver_fn=export_input_fn,
                compare_fn=_metric_cmp_fn,
                exports_to_keep=export_config.exports_to_keep)
        ]
    elif exporter_type == 'none':
        exporters = []
    else:
        raise ValueError('Unknown exporter type %s' % exporter_type)

    eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
                                  **input_fn_kwargs)
    # set throttle_secs to a small number, so that we can control evaluation
    # interval steps by checkpoint saving steps
    return tf.estimator.EvalSpec(
        name='val',
        input_fn=eval_input_fn,
        steps=eval_steps,
        throttle_secs=10,
        exporters=exporters)
238
+
239
+
240
def _check_model_dir(model_dir, continue_train):
    """Create model_dir if missing, or verify it holds no checkpoints.

    Args:
      model_dir: directory where the model is saved
      continue_train: when False, an existing model_dir must not contain
        any file matching model.ckpt-*.meta

    Raises:
      AssertionError: if continue_train is False and model_dir already
        contains checkpoint meta files
    """
    if not tf.gfile.IsDirectory(model_dir):
        if continue_train:
            logging.info('%s does not exists, create it automatically' %
                         model_dir)
        tf.gfile.MakeDirs(model_dir)
        return

    if continue_train:
        return

    # Raise explicitly instead of using an `assert` statement so that the
    # guard still runs when Python is started with -O.
    if len(tf.gfile.Glob(model_dir + '/model.ckpt-*.meta')) != 0:
        raise AssertionError(
            'model_dir[=%s] already exists and not empty(if you '
            'want to continue train on current model_dir please '
            'delete dir %s or specify --continue_train[internal use only])' %
            (model_dir, model_dir))
254
+
255
+
256
def _get_ckpt_path(pipeline_config, checkpoint_path):
    """Resolve the checkpoint path to use.

    Args:
      pipeline_config: pipeline config; model_dir is used as a fallback
        when checkpoint_path is not specified
      checkpoint_path: a checkpoint path or a directory of checkpoints;
        None or '' means not specified

    Returns:
      checkpoint_path itself when it is not a directory; otherwise the
      result of estimator_utils.latest_checkpoint on the directory
      (checkpoint_path, or pipeline_config.model_dir as the fallback)

    Raises:
      AssertionError: if checkpoint_path is not specified and
        pipeline_config.model_dir is not a directory
    """
    if checkpoint_path is not None and checkpoint_path != '':
        if tf.gfile.IsDirectory(checkpoint_path):
            return estimator_utils.latest_checkpoint(checkpoint_path)
        return checkpoint_path

    # Raise explicitly: `assert False` is removed under -O, which would
    # leave the result unbound and surface as an UnboundLocalError.
    if not tf.gfile.IsDirectory(pipeline_config.model_dir):
        raise AssertionError('pipeline_config.model_dir(%s) does not exist' %
                             pipeline_config.model_dir)

    ckpt_path = estimator_utils.latest_checkpoint(pipeline_config.model_dir)
    logging.info('checkpoint_path is not specified, '
                 'will use latest checkpoint %s from %s' %
                 (ckpt_path, pipeline_config.model_dir))
    return ckpt_path
271
+
272
+
273
def train_and_evaluate(pipeline_config_path, continue_train=False):
    """Train and evaluate a EasyRec model defined in pipeline_config_path.

    Loads the pipeline config from the given path and delegates the work
    to _train_and_evaluate_impl.

    Args:
      pipeline_config_path: a path to EasyRecConfig object, specifies
        train_config, model_config, data_config and eval_config
      continue_train: whether to restart train from an existing
        checkpoint

    Returns:
      the pipeline config loaded from pipeline_config_path

    Raises:
      AssertionError: if pipeline_config_path does not exist
    """
    # Raise explicitly so that the check is kept when running with -O.
    if not tf.gfile.Exists(pipeline_config_path):
        raise AssertionError('pipeline_config_path not exists')

    pipeline_config = config_util.get_configs_from_pipeline_file(
        pipeline_config_path)

    _train_and_evaluate_impl(pipeline_config, continue_train)

    return pipeline_config
294
+
295
+
296
def _train_and_evaluate_impl(pipeline_config,
                             continue_train=False,
                             check_mode=False,
                             fit_on_eval=False,
                             fit_on_eval_steps=None):
  """Build the estimator from pipeline_config and run training.

  With SokStrategy / EmbeddingParallelStrategy only estimator.train is run;
  otherwise training and evaluation are driven by
  estimator_train.train_and_evaluate.

  Args:
    pipeline_config: EasyRecConfig instance. train_config.sync_replicas is
      reset to False on it when train_distribute is not NoStrategy.
    continue_train: forwarded to _check_model_dir on the chief.
    check_mode: forwarded to the train input_fn and the eval spec builders.
    fit_on_eval: if True, keep training on the eval data once the regular
      training is finished (skipped on the evaluator task).
    fit_on_eval_steps: extra steps to train on the eval data; it is added to
      the number of steps already trained. None leaves max_steps unset.

  Returns:
    The estimator, so that callers can continue with estimator.train.

  Raises:
    AssertionError: if neither num_steps nor num_epochs is positive.
  """
  train_config = pipeline_config.train_config
  data_config = pipeline_config.data_config
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)

  # sync_replicas is not combined with a distribution strategy
  if (train_config.train_distribute != DistributionStrategy.NoStrategy and
      train_config.sync_replicas):
    logging.warning(
        'will set sync_replicas to False, because train_distribute[%s] != NoStrategy'
        % pipeline_config.train_config.train_distribute)
    pipeline_config.train_config.sync_replicas = False

  train_data = get_train_input_path(pipeline_config)
  eval_data = get_eval_input_path(pipeline_config)

  distribution = strategy_builder.build(train_config)
  params = {}
  if train_config.is_profiling:
    params['log_device_placement'] = True
  estimator, run_config = _create_estimator(
      pipeline_config, distribution=distribution, params=params)

  # only the chief validates model_dir and records the config / version
  version_file = os.path.join(pipeline_config.model_dir, 'version')
  if estimator_utils.is_chief():
    _check_model_dir(pipeline_config.model_dir, continue_train)
    config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
    with tf.gfile.GFile(version_file, 'w') as f:
      f.write(easy_rec.__version__ + '\n')

  train_steps = None
  if train_config.HasField('num_steps') and train_config.num_steps > 0:
    train_steps = train_config.num_steps
  assert train_steps is not None or data_config.num_epochs > 0, (
      'either num_steps and num_epochs must be set to an integer > 0.')

  if train_steps and data_config.num_epochs:
    logging.info('Both num_steps and num_epochs are set.')
    is_sync = train_config.sync_replicas
    batch_size = data_config.batch_size
    epoch_str = 'sample_num * %d / %d' % (data_config.num_epochs, batch_size)
    if is_sync:
      _, worker_num = estimator_utils.get_task_index_and_num()
      epoch_str += ' / ' + str(worker_num)
    logging.info('Will train min(%d, %s) steps...' % (train_steps, epoch_str))

  input_fn_kwargs = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path

  # create train input
  train_input_fn = _get_input_fn(
      data_config,
      feature_configs,
      train_data,
      check_mode=check_mode,
      **input_fn_kwargs)
  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=train_steps)
  # saving_listeners is not guaranteed to exist on TrainSpec, so every
  # estimator.train call below shares the same None fallback.
  saving_listeners = getattr(train_spec, 'saving_listeners', None)

  embedding_parallel = train_config.train_distribute in (
      DistributionStrategy.SokStrategy,
      DistributionStrategy.EmbeddingParallelStrategy)

  if embedding_parallel:
    estimator.train(
        input_fn=train_input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=saving_listeners)
    train_input_fn.input_creator.stop()
  else:
    # create eval spec, currently only a single Eval Spec is allowed.
    eval_spec = _create_eval_export_spec(
        pipeline_config, eval_data, check_mode=check_mode)
    estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
  logging.info('Train and evaluate finish')

  if fit_on_eval and (not estimator_utils.is_evaluator()):
    tf.reset_default_graph()
    logging.info('Start continue training on eval data')
    eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
                                  **input_fn_kwargs)
    if fit_on_eval_steps is not None:
      # wait estimator train done to get the correct train_steps
      while not estimator_train.estimator_train_done(estimator):
        time.sleep(1)
      train_steps = estimator_utils.get_trained_steps(estimator.model_dir)
      logging.info('\ttrain_steps=%d fit_on_eval_steps=%d' %
                   (train_steps, fit_on_eval_steps))
      fit_on_eval_steps += train_steps
    # Do not use estimator_train.train_and_evaluate as it starts tf.Server,
    # which is redundant and reports port not available error.
    estimator.train(
        input_fn=eval_input_fn,
        max_steps=fit_on_eval_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=saving_listeners)
    logging.info('Finished training on eval data')
  # return estimator for custom training using estimator.train
  return estimator
401
+
402
+
403
def evaluate(pipeline_config,
             eval_checkpoint_path='',
             eval_data_path=None,
             eval_result_filename='eval_result.txt'):
  """Evaluate an EasyRec model described by pipeline_config.

  The metrics are computed on the eval data, logged, and dumped as json into
  eval_result_filename under the model_dir of the pipeline config.

  Args:
    pipeline_config: either EasyRecConfig path or its instance
    eval_checkpoint_path: if specified, will use this model instead of
      the model found through model_dir in pipeline_config
    eval_data_path: eval data path, default use eval data in pipeline_config,
      could be a path or a list of paths
    eval_result_filename: file name (relative to model_dir) the metrics are
      written to.

  Returns:
    A dict of evaluation metrics as configured in pipeline_config.

  Raises:
    AssertionError: presumably when the pipeline config path does not
      exist (raised while loading the config) -- TODO confirm.
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  if eval_data_path is not None:
    logging.info('Evaluating on data: %s' % eval_data_path)
    set_eval_input_path(pipeline_config, eval_data_path)

  train_config = pipeline_config.train_config
  eval_data = get_eval_input_path(pipeline_config)

  # Under TF_CONFIG a ps task only serves variables; the master starts a
  # server when the cluster has ps tasks and evaluates through it.
  server_target = None
  if 'TF_CONFIG' in os.environ:
    tf_config = estimator_utils.chief_to_master()
    from tensorflow.python.training import server_lib
    task_type = tf_config['task']['type']
    if task_type == 'ps':
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = server_lib.Server(
          cluster, job_name='ps', task_index=tf_config['task']['index'])
      server.join()
    elif task_type == 'master':
      if 'ps' in tf_config['cluster']:
        cluster = tf.train.ClusterSpec(tf_config['cluster'])
        server = server_lib.Server(cluster, job_name='master', task_index=0)
        server_target = server.target
        print('server_target = %s' % server_target)

  distribution = strategy_builder.build(train_config)
  estimator, run_config = _create_estimator(pipeline_config, distribution)
  eval_spec = _create_eval_export_spec(pipeline_config, eval_data)
  ckpt_path = _get_ckpt_path(pipeline_config, eval_checkpoint_path)

  if not server_target:
    # Plain estimator evaluation. Wrapping this call in a
    # replica_device_setter scope was tried upstream but did not place the
    # variables on the parameter server.
    eval_result = estimator.evaluate(
        eval_spec.input_fn, eval_spec.steps, checkpoint_path=ckpt_path)
    eval_spec.input_fn.input_creator.stop()
  else:
    # evaluate with parameter server
    input_iter = eval_spec.input_fn(
        mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
    input_feas, input_lbls = input_iter.get_next()
    from tensorflow.python.training.device_setter import replica_device_setter
    from tensorflow.python.framework.ops import device
    from tensorflow.python.training.monitored_session import MonitoredSession
    from tensorflow.python.training.monitored_session import ChiefSessionCreator
    placement = replica_device_setter(
        worker_device='/job:master/task:0', cluster=cluster)
    with device(placement):
      estimator_spec = estimator._eval_model_fn(input_feas, input_lbls,
                                                run_config)

    session_config = ConfigProto(
        allow_soft_placement=True, log_device_placement=True)
    chief_sess_creator = ChiefSessionCreator(
        master=server_target,
        checkpoint_filename_with_path=ckpt_path,
        config=session_config)
    # each metric entry is a (value_op, update_op) pair
    eval_metric_ops = estimator_spec.eval_metric_ops
    update_ops = [pair[1] for pair in eval_metric_ops.values()]
    metric_ops = {name: pair[0] for name, pair in eval_metric_ops.items()}
    update_op = tf.group(update_ops)
    with MonitoredSession(
        session_creator=chief_sess_creator,
        hooks=None,
        stop_grace_period_secs=120) as sess:
      # run the metric updates until the input is exhausted
      exhausted = False
      while not exhausted:
        try:
          sess.run(update_op)
        except tf.errors.OutOfRangeError:
          exhausted = True
      eval_result = sess.run(metric_ops)
  logging.info('Evaluate finish')

  print('eval_result = ', eval_result)
  logging.info('eval_result = {0}'.format(eval_result))
  # write eval result to file
  eval_result_file = os.path.join(pipeline_config.model_dir,
                                  eval_result_filename)
  logging.info('save eval result to file %s' % eval_result_file)
  with tf.gfile.GFile(eval_result_file, 'w') as ofile:
    # binary values are skipped; .item() turns numpy scalars into python ones
    result_to_write = {
        key: eval_result[key].item()
        for key in sorted(eval_result)
        if not isinstance(eval_result[key], six.binary_type)
    }
    ofile.write(json.dumps(result_to_write, indent=2))
  return eval_result
522
+
523
+
524
def distribute_evaluate(pipeline_config,
                        eval_checkpoint_path='',
                        eval_data_path=None,
                        eval_result_filename='distribute_eval_result.txt'):
  """Evaluate an EasyRec model with the master / worker / ps tasks of TF_CONFIG.

  Every master / worker task runs the metric update ops on its input until
  the input is exhausted; the master task writes the final metrics as json
  into eval_result_filename under model_dir.

  Args:
    pipeline_config: either EasyRecConfig path or its instance
    eval_checkpoint_path: if specified, will use this model instead of
      the model found through model_dir in pipeline_config
    eval_data_path: eval data path, default use eval data in pipeline_config,
      could be a path or a list of paths
    eval_result_filename: file name (relative to model_dir) the metrics are
      written to by the master task.

  Returns:
    A dict of evaluation metrics taken from the exit barrier hook. An empty
    dict is returned when data_config has a sampler, or when no master /
    worker server could be started (no TF_CONFIG, or a cluster without ps).

  Raises:
    AssertionError: if the temporary results directory cannot be created.
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if eval_data_path is not None:
    logging.info('Evaluating on data: %s' % eval_data_path)
    set_eval_input_path(pipeline_config, eval_data_path)
  train_config = pipeline_config.train_config
  eval_data = get_eval_input_path(pipeline_config)
  data_config = pipeline_config.data_config
  if data_config.HasField('sampler'):
    logging.warning(
        'It is not accuracy to use eval with negative sampler, recommand to use hitrate.py!'
    )
    return {}

  model_dir = get_model_dir_path(pipeline_config)
  eval_tmp_results_dir = os.path.join(model_dir, 'distribute_eval_tmp_results')
  if not tf.gfile.IsDirectory(eval_tmp_results_dir):
    logging.info('create eval tmp results dir {}'.format(eval_tmp_results_dir))
    tf.gfile.MakeDirs(eval_tmp_results_dir)
  assert tf.gfile.IsDirectory(
      eval_tmp_results_dir), 'tmp results dir not create success.'
  # exported through the environment for code outside this function
  os.environ['eval_tmp_results_dir'] = eval_tmp_results_dir

  server_target = None
  cur_job_name = None
  # Initialised up front: without a server target no session is run, and the
  # final return used to fail with UnboundLocalError.
  eval_result = {}
  if 'TF_CONFIG' in os.environ:
    tf_config = estimator_utils.chief_to_master()

    from tensorflow.python.training import server_lib
    task_type = tf_config['task']['type']
    if task_type == 'ps':
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = server_lib.Server(
          cluster, job_name='ps', task_index=tf_config['task']['index'])
      server.join()
    elif task_type in ('master', 'worker') and 'ps' in tf_config['cluster']:
      # master and worker tasks start their server in the same way
      cur_job_name = task_type
      cur_task_index = tf_config['task']['index']
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = server_lib.Server(
          cluster, job_name=cur_job_name, task_index=cur_task_index)
      server_target = server.target
      print('server_target = %s' % server_target)

  if server_target:
    from tensorflow.python.training.device_setter import replica_device_setter
    from tensorflow.python.framework.ops import device
    from tensorflow.python.training.monitored_session import MonitoredSession
    from tensorflow.python.training.monitored_session import ChiefSessionCreator
    from tensorflow.python.training.monitored_session import WorkerSessionCreator
    from easy_rec.python.utils.estimator_utils import EvaluateExitBarrierHook
    is_master = cur_job_name == 'master'
    cur_work_device = '/job:' + cur_job_name + '/task:' + str(cur_task_index)
    cur_ps_num = len(tf_config['cluster']['ps'])
    with device(
        replica_device_setter(
            ps_tasks=cur_ps_num, worker_device=cur_work_device,
            cluster=cluster)):
      distribution = strategy_builder.build(train_config)
      estimator, run_config = _create_estimator(pipeline_config, distribution)
      eval_spec = _create_eval_export_spec(pipeline_config, eval_data)
      ckpt_path = _get_ckpt_path(pipeline_config, eval_checkpoint_path)
      ckpt_dir = os.path.dirname(ckpt_path)
      input_iter = eval_spec.input_fn(
          mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
      input_feas, input_lbls = input_iter.get_next()
      estimator_spec = estimator._distribute_eval_model_fn(
          input_feas, input_lbls, run_config)

    session_config = ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True,
        device_filters=['/job:ps',
                        '/job:worker/task:%d' % cur_task_index])
    if is_master:
      # The master restores every variable except the metric variables from
      # the checkpoint; the metric variables are initialised through
      # ready_for_local_init_op.
      metric_variables = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
      model_ready_for_local_init_op = tf.variables_initializer(metric_variables)
      global_variables = tf.global_variables()
      remain_variables = list(
          set(global_variables).difference(set(metric_variables)))
      cur_saver = tf.train.Saver(var_list=remain_variables, sharded=True)
      cur_scaffold = tf.train.Scaffold(
          saver=cur_saver,
          ready_for_local_init_op=model_ready_for_local_init_op)
      cur_sess_creator = ChiefSessionCreator(
          scaffold=cur_scaffold,
          master=server_target,
          checkpoint_filename_with_path=ckpt_path,
          config=session_config)
    else:
      cur_sess_creator = WorkerSessionCreator(
          master=server_target, config=session_config)

    # each metric entry is a (value_op, update_op) pair
    eval_metric_ops = estimator_spec.eval_metric_ops
    update_ops = [pair[1] for pair in eval_metric_ops.values()]
    metric_ops = {name: pair[0] for name, pair in eval_metric_ops.items()}
    update_op = tf.group(update_ops)

    # A cluster made of master + ps only has no 'worker' entry; count it as
    # zero workers instead of failing with KeyError.
    cluster_def = tf_config['cluster']
    worker_tasks = cluster_def['worker'] if 'worker' in cluster_def else []
    cur_worker_num = len(worker_tasks) + 1

    cur_stop_grace_period_secs = 120 if is_master else 10
    cur_hooks = EvaluateExitBarrierHook(cur_worker_num, is_master, ckpt_dir,
                                        metric_ops)
    with MonitoredSession(
        session_creator=cur_sess_creator,
        hooks=[cur_hooks],
        stop_grace_period_secs=cur_stop_grace_period_secs) as sess:
      # run the metric updates until the input is exhausted
      while True:
        try:
          sess.run(update_op)
        except tf.errors.OutOfRangeError:
          break
    eval_result = cur_hooks.eval_result
  else:
    logging.warning(
        'no master/worker server target available, distributed evaluation is skipped'
    )

  logging.info('Evaluate finish')

  # write eval result to file
  model_dir = pipeline_config.model_dir
  eval_result_file = os.path.join(model_dir, eval_result_filename)
  logging.info('save eval result to file %s' % eval_result_file)
  if cur_job_name == 'master':
    print('eval_result = ', eval_result)
    logging.info('eval_result = {0}'.format(eval_result))
    with tf.gfile.GFile(eval_result_file, 'w') as ofile:
      result_to_write = {'eval_method': 'distribute'}
      for key in sorted(eval_result):
        value = eval_result[key]
        # skip logging binary data
        if isinstance(value, six.binary_type):
          continue
        # convert numpy float to python float
        result_to_write[key] = value.item()

      ofile.write(json.dumps(result_to_write))
  return eval_result
693
+
694
+
695
def predict(pipeline_config, checkpoint_path='', data_path=None):
  """Run prediction with an EasyRec model described by pipeline_config.

  The predictions are computed on the eval input of the pipeline config,
  or on data_path when it is given.

  Args:
    pipeline_config: either EasyRecConfig path or its instance
    checkpoint_path: if specified, will use this model instead of
      the model found through model_dir in pipeline_config
    data_path: data path, default use eval data in pipeline_config,
      could be a path or a list of paths

  Returns:
    Whatever estimator.predict returns for the eval input_fn
    (presumably an iterable of per-example prediction dicts).

  Raises:
    AssertionError: presumably when the pipeline config path does not
      exist (raised while loading the config) -- TODO confirm.
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)

  # an explicit data_path overrides the eval input of the config
  if data_path is not None:
    logging.info('Predict on data: %s' % data_path)
    set_eval_input_path(pipeline_config, data_path)
  input_path = get_eval_input_path(pipeline_config)

  strategy = strategy_builder.build(pipeline_config.train_config)
  estimator, _ = _create_estimator(pipeline_config, strategy)
  spec = _create_eval_export_spec(pipeline_config, input_path)

  resolved_ckpt = _get_ckpt_path(pipeline_config, checkpoint_path)
  predictions = estimator.predict(spec.input_fn, checkpoint_path=resolved_ckpt)
  logging.info('Predict finish')
  return predictions
732
+
733
+
734
def export(export_dir,
           pipeline_config,
           checkpoint_path='',
           asset_files=None,
           verbose=False,
           **extra_params):
  """Export the model described by pipeline_config.

  Args:
    export_dir: base directory where the model should be exported; it is
      created when it does not exist
    pipeline_config: proto.EasyRecConfig instance or file path
      specify proto.EasyRecConfig
    checkpoint_path: if specified, will use this model instead of
      the model found through model_dir in pipeline_config
    asset_files: extra files to add to assets, comma separated;
      if asset file variable in graph need to be renamed,
      specify by new_file_name:file_path. Paths starting with oss: or hdfs:
      are never split on ':'; empty entries are ignored.
    verbose: dumps debug information
    extra_params: keys related to write embedding to redis/oss
      redis_url, redis_passwd, redis_threads, redis_batch_size,
      redis_timeout, redis_expire if export embedding to redis;
      redis_embedding_version: if specified, will kill export to redis
      --
      oss_path, oss_endpoint, oss_ak, oss_sk, oss_timeout,
      oss_expire, oss_write_kv, oss_embedding_version

  Returns:
    the directory where model is exported; when oss_path or redis_url is
    given, the result of export_big_model_to_oss / export_big_model.

  Raises:
    AssertionError: presumably when the pipeline config path does not
      exist (raised while loading the config) -- TODO confirm.
  """
  if not tf.gfile.Exists(export_dir):
    tf.gfile.MakeDirs(export_dir)

  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)

  # create estimator
  params = {'log_device_placement': verbose}
  if asset_files:
    logging.info('will add asset files: %s' % asset_files)
    asset_file_dict = {}
    for asset_file in asset_files.split(','):
      asset_file = asset_file.strip()
      if not asset_file:
        # e.g. a trailing comma; there is no file to register
        continue
      # "name:path" renames the asset, but the ':' of an oss:/hdfs: scheme
      # is part of the path itself.
      if ':' not in asset_file or asset_file.startswith(('oss:', 'hdfs:')):
        _, asset_name = os.path.split(asset_file)
      else:
        asset_name, asset_file = asset_file.split(':', 1)
      asset_file_dict[asset_name] = asset_file
    params['asset_files'] = asset_file_dict
  estimator, _ = _create_estimator(pipeline_config, params=params)

  # construct serving input fn
  export_config = pipeline_config.export_config
  data_config = pipeline_config.data_config
  input_fn_kwargs = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
  serving_input_fn = _get_input_fn(data_config, feature_configs, None,
                                   export_config, **input_fn_kwargs)
  ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path)

  if 'oss_path' in extra_params:
    if pipeline_config.train_config.HasField('incr_save_config'):
      incr_save_config = pipeline_config.train_config.incr_save_config
      extra_params['incr_update'] = {}
      incr_save_type = incr_save_config.WhichOneof('incr_update')
      logging.info('incr_save_type=%s' % incr_save_type)
      if incr_save_type:
        extra_params['incr_update'][incr_save_type] = getattr(
            incr_save_config, incr_save_type)
    return export_big_model_to_oss(export_dir, pipeline_config, extra_params,
                                   serving_input_fn, estimator, ckpt_path,
                                   verbose)

  if 'redis_url' in extra_params:
    return export_big_model(export_dir, pipeline_config, extra_params,
                            serving_input_fn, estimator, ckpt_path, verbose)

  final_export_dir = estimator.export_savedmodel(
      export_dir_base=export_dir,
      serving_input_receiver_fn=serving_input_fn,
      checkpoint_path=ckpt_path,
      strip_default_attrs=True)

  # add export ts as version info
  saved_model = saved_model_pb2.SavedModel()
  if not isinstance(final_export_dir, six.string_types):
    # the export directory may come back as bytes
    final_export_dir = final_export_dir.decode('utf-8')
  # the last non-empty path component is the export timestamp directory
  export_ts = [x for x in final_export_dir.split('/') if x][-1]
  saved_pb_path = os.path.join(final_export_dir, 'saved_model.pb')
  with tf.gfile.GFile(saved_pb_path, 'rb') as fin:
    saved_model.ParseFromString(fin.read())
  saved_model.meta_graphs[0].meta_info_def.meta_graph_version = export_ts
  with tf.gfile.GFile(saved_pb_path, 'wb') as fout:
    fout.write(saved_model.SerializeToString())

  logging.info('model has been exported to %s successfully' % final_export_dir)
  return final_export_dir
841
+
842
+
843
def export_checkpoint(pipeline_config=None,
                      export_path='',
                      checkpoint_path='',
                      asset_files=None,
                      verbose=False,
                      mode=tf.estimator.ModeKeys.PREDICT):
  """Export the EasyRec model as checkpoint.

  Args:
    pipeline_config: either EasyRecConfig path or its instance
    export_path: forwarded to estimator.export_checkpoint
    checkpoint_path: if specified, will use this model instead of
      the model found through model_dir in pipeline_config
    asset_files: extra asset files, passed to the estimator params as is
    verbose: stored as log_device_placement in the estimator params
    mode: estimator mode forwarded to estimator.export_checkpoint
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
  data_config = pipeline_config.data_config

  extra_input_args = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    extra_input_args['fg_json_path'] = pipeline_config.fg_json_path

  # create estimator
  estimator_params = {'log_device_placement': verbose}
  if asset_files:
    logging.info('will add asset files: %s' % asset_files)
    # NOTE(review): export() stores a {name: path} dict under 'asset_files'
    # while the raw value is forwarded here -- confirm the estimator
    # accepts both forms.
    estimator_params['asset_files'] = asset_files
  estimator, _ = _create_estimator(pipeline_config, params=estimator_params)

  # construct serving input fn
  serving_input_fn = _get_input_fn(data_config, feature_configs, None,
                                   pipeline_config.export_config,
                                   **extra_input_args)
  resolved_ckpt = _get_ckpt_path(pipeline_config, checkpoint_path)
  estimator.export_checkpoint(
      export_path=export_path,
      serving_input_receiver_fn=serving_input_fn,
      checkpoint_path=resolved_ckpt,
      mode=mode)

  logging.info('model checkpoint has been exported successfully')