easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the package's registry page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,162 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Add embedding column for EmbeddingVariable which is only available on pai."""
4
+
5
+ from tensorflow.python.framework import dtypes
6
+ from tensorflow.python.framework import ops
7
+ from tensorflow.python.framework import sparse_tensor
8
+ from tensorflow.python.framework import tensor_shape
9
+ from tensorflow.python.ops import array_ops
10
+ from tensorflow.python.ops import embedding_ops
11
+ from tensorflow.python.ops import math_ops
12
+ from tensorflow.python.ops import sparse_ops
13
+
14
+
15
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries whose id is negative, keeping ids and weights in sync.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` whose values pair one-to-one with
      `sparse_ids.values`, or None.

  Returns:
    `(sparse_ids, sparse_weights)` restricted to the entries with id >= 0;
    `sparse_weights` is returned as None when it was None.
  """
  keep = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is None:
    return sparse_ops.sparse_retain(sparse_ids, keep), None
  # AND-ing with an all-true tensor shaped like the weights leaves the mask
  # values unchanged; presumably it ties the mask to the weights' shape.
  keep = math_ops.logical_and(
      keep, array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  kept_ids = sparse_ops.sparse_retain(sparse_ids, keep)
  kept_weights = sparse_ops.sparse_retain(sparse_weights, keep)
  return kept_ids, kept_weights
26
+
27
+
28
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune entries with non-positive weights (<= 0) from ids and weights.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` whose values pair one-to-one with
      `sparse_ids.values`, or None, in which case both inputs are returned
      unchanged.

  Returns:
    `(sparse_ids, sparse_weights)` with every position whose weight is not
    strictly greater than 0 removed from both tensors.
  """
  if sparse_weights is not None:
    # `greater` (not `greater_equal`): zero weights are pruned as well.
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
35
+
36
+
37
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner='mean',
                                 default_id=None,
                                 name=None,
                                 partition_strategy='div',
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  Fixed so that it can be used with Pai EmbeddingVariables: `embedding_weights`
  is passed to `embedding_lookup_sparse` as given; it is converted to a tensor
  only to supply graph context for the name scope.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
  may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight (unless `combiner` is "sum"). For an entry with no
  features, the embedding vector for `default_id` is returned, or the 0-vector
  if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Works with both TF1 (`Dimension`-valued) and TF2 (int-valued) static shapes.

  Args:
    embedding_weights: A list of `P` float `Tensor`s or values representing
      partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
      created by partitioning along dimension 0. The total unpartitioned
      shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
      vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is None or an empty list/tuple.
  """
  if embedding_weights is None or (isinstance(embedding_weights, (list, tuple))
                                   and not embedding_weights):
    raise ValueError('Missing embedding_weights %s.' % embedding_weights)

  # Converted only so the name scope below gets graph context; the original
  # object is what is handed to the lookup.
  # NOTE(review): for a list of partitions whose first dims differ,
  # convert_to_tensor cannot stack them -- confirm list inputs are not used.
  embed_tensors = [ops.convert_to_tensor(embedding_weights)]
  with ops.name_scope(name, 'embedding_lookup',
                      embed_tensors + [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    rank_dim = sparse_ids.dense_shape.get_shape()[0]
    # TF1 shapes yield a `Dimension` (read via `.value`); TF2 shapes yield a
    # plain int or None. Normalize to int-or-None so both versions work.
    static_rank = getattr(rank_dim, 'value', rank_dim)
    original_rank = (
        array_ops.size(original_shape) if static_rank is None else static_rank)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != 'sum':
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    # The lookup is always given int64 ids.
    if sparse_ids.values.dtype != dtypes.int64:
      sparse_ids = sparse_tensor.SparseTensor(
          indices=sparse_ids.indices,
          values=math_ops.cast(sparse_ids.values, dtypes.int64),
          dense_shape=sparse_ids.dense_shape)

    result = embedding_ops.embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Empty rows were filled with id 0 above; replace their embeddings with
      # zeros. Broadcast is_row_empty to the shape of the lookup result for
      # use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids into the higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_rank = None if static_rank is None else static_rank - 1
    final_result.set_shape(
        tensor_shape.unknown_shape(final_rank).concatenate(
            result.get_shape()[1:]))
    return final_result
@@ -0,0 +1,316 @@
1
+ # -*- encoding:utf-8 -*-
2
+
3
+ import logging
4
+ import os
5
+
6
+ import numpy as np
7
+ from tensorflow.core.protobuf import saver_pb2
8
+ from tensorflow.python.framework import dtypes
9
+ from tensorflow.python.framework import ops
10
+ # from tensorflow.python.ops import math_ops
11
+ # from tensorflow.python.ops import logging_ops
12
+ from tensorflow.python.ops import array_ops
13
+ from tensorflow.python.ops import control_flow_ops
14
+ from tensorflow.python.ops import script_ops
15
+ from tensorflow.python.ops import state_ops
16
+ from tensorflow.python.platform import gfile
17
+ from tensorflow.python.training import saver
18
+
19
+ from easy_rec.python.utils import constant
20
+
21
# Optional dependencies: horovod and sparse_operation_kit DynamicVariables.
# On any exception, `dynamic_variable_ops` and `dynamic_variable` are both set
# to None, which disables the DynamicVariable code paths.
# NOTE(review): `hvd` is left unbound when the horovod import itself fails,
# yet the saver methods call `hvd.rank()` / `hvd.size()` for
# embedding-parallel variables; confirm those paths only run with horovod.
try:
  import horovod.tensorflow as hvd
  from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
  from easy_rec.python.compat import dynamic_variable
except Exception:
  dynamic_variable_ops = None
  dynamic_variable = None

# Optionally load the compiled `load_embed` op from `easy_rec.ops_dir`. On
# failure a warning is logged and `load_embed_lib` is None; embedding restore
# then uses a numpy `py_func` instead.
try:
  from tensorflow.python.framework.load_library import load_op_library
  import easy_rec
  load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so')
  load_embed_lib = load_op_library(load_embed_lib_path)
except Exception as ex:
  logging.warning('load libload_embed.so failed: %s' % str(ex))
  load_embed_lib = None
37
+
38
+
39
+ def _get_embed_part_id(embed_file):
40
+ embed_file = embed_file.split('/')[-1]
41
+ embed_file = embed_file.split('.')[0]
42
+ embed_id = embed_file.split('-')[-1]
43
+ return int(embed_id)
44
+
45
+
46
+ class EmbeddingParallelSaver(saver.Saver):
47
+
48
+ def __init__(self,
49
+ var_list=None,
50
+ reshape=False,
51
+ sharded=False,
52
+ max_to_keep=5,
53
+ keep_checkpoint_every_n_hours=10000.0,
54
+ name=None,
55
+ restore_sequentially=False,
56
+ saver_def=None,
57
+ builder=None,
58
+ defer_build=False,
59
+ allow_empty=False,
60
+ write_version=saver_pb2.SaverDef.V2,
61
+ pad_step_number=False,
62
+ save_relative_paths=False,
63
+ filename=None):
64
+ self._kv_vars = []
65
+ self._embed_vars = []
66
+ tf_vars = []
67
+ embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
68
+ for var in var_list:
69
+ if dynamic_variable is not None and isinstance(
70
+ var, dynamic_variable.DynamicVariable):
71
+ self._kv_vars.append(var)
72
+ elif var.name in embed_para_vars:
73
+ logging.info('save shard embedding %s part_id=%d part_shape=%s' %
74
+ (var.name, hvd.rank(), var.get_shape()))
75
+ self._embed_vars.append(var)
76
+ else:
77
+ tf_vars.append(var)
78
+ super(EmbeddingParallelSaver, self).__init__(
79
+ tf_vars,
80
+ reshape=reshape,
81
+ sharded=sharded,
82
+ max_to_keep=max_to_keep,
83
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
84
+ name=name,
85
+ restore_sequentially=restore_sequentially,
86
+ saver_def=saver_def,
87
+ builder=builder,
88
+ defer_build=defer_build,
89
+ allow_empty=allow_empty,
90
+ write_version=write_version,
91
+ pad_step_number=pad_step_number,
92
+ save_relative_paths=save_relative_paths,
93
+ filename=filename)
94
+ self._is_build = False
95
+
96
+ def _has_embed_vars(self):
97
+ return (len(self._kv_vars) + len(self._embed_vars)) > 0
98
+
99
+ def _save_dense_embedding(self, embed_var):
100
+ logging.info('task[%d] save_dense_embed: %s' % (hvd.rank(), embed_var.name))
101
+
102
+ def _save_embed(embed, filename, var_name):
103
+ task_id = hvd.rank()
104
+ filename = filename.decode('utf-8')
105
+ var_name = var_name.decode('utf-8').replace('/', '__')
106
+ embed_dir = filename + '-embedding/'
107
+ logging.info('task[%d] save_dense_embed: %s to %s' %
108
+ (task_id, var_name, embed_dir))
109
+ if not gfile.Exists(embed_dir):
110
+ gfile.MakeDirs(embed_dir)
111
+ embed_file = filename + '-embedding/embed-' + var_name + '-part-%d.bin' % task_id
112
+ with gfile.GFile(embed_file, 'wb') as fout:
113
+ fout.write(embed.tobytes())
114
+
115
+ if task_id == 0:
116
+ # clear old embedding tables
117
+ embed_pattern = filename + '-embedding/embed-' + var_name + '-part-*.bin'
118
+ embed_files = gfile.Glob(embed_pattern)
119
+ for embed_file in embed_files:
120
+ embed_id = _get_embed_part_id(embed_file)
121
+ if embed_id >= hvd.size():
122
+ gfile.DeleteRecursively(embed_file)
123
+ return np.asarray([embed_file], order='C', dtype=np.object)
124
+
125
+ file_name = ops.get_default_graph().get_tensor_by_name(
126
+ self.saver_def.filename_tensor_name)
127
+ save_paths = script_ops.py_func(_save_embed,
128
+ [embed_var, file_name, embed_var.name],
129
+ dtypes.string)
130
+ return save_paths
131
+
132
def _load_dense_embedding(self, embed_var):
  """Build an op that restores a partitioned dense embedding variable.

  Each horovod worker (``hvd.rank()`` of ``hvd.size()``) owns the global rows
  ``i`` with ``i % hvd.size() == hvd.rank()``, stored at local row
  ``i // hvd.size()``.

  The checkpoint holds the table as raw float32 files named
  ``<ckpt>-embedding/embed-<var>-part-*.bin``, where ``<var>`` is the variable
  name with '/' replaced by '__'. Part file ``p`` of ``n`` files holds the
  global rows ``p, p + n, p + 2n, ...``. The number of files is taken from the
  glob rather than from ``hvd.size()``. Every worker scans all files and keeps
  only the rows it owns that fit into its local partition.

  If ``load_embed_lib`` is available it performs the load; otherwise a python
  ``py_func`` fallback is used.

  Args:
    embed_var: the local dense embedding partition. Its static shape is read
      as ``[embed_part_size, embed_dim]``.

  Returns:
    The ``state_ops.assign`` of the loaded values to ``embed_var``, created
    under a control dependency on ``embed_var._initializer_op``. In the python
    fallback, rows not found in any part file are set to 0 because the buffer
    starts zero-filled and overwrites the whole variable.
  """
  file_name = ops.get_default_graph().get_tensor_by_name(
      self.saver_def.filename_tensor_name)
  embed_dim = embed_var.get_shape()[-1]
  embed_part_size = embed_var.get_shape()[0]

  def _load_embed(embed, embed_dim, embed_part_size, part_id, part_num,
                  filename, var_name):
    # `embed` (the variable value fed into py_func) is not read here.
    filename = filename.decode('utf-8')
    var_name = var_name.decode('utf-8').replace('/', '__')
    embed_pattern = filename + '-embedding/embed-' + var_name + '-part-*.bin'
    embed_files = gfile.Glob(embed_pattern)

    embed_files.sort(key=_get_embed_part_id)

    logging.info('task[%d] embed_files=%s embed_dim=%d embed_part_size=%d' %
                 (part_id, ','.join(embed_files), embed_dim, embed_part_size))

    part_embed_vals = np.zeros([embed_part_size, embed_dim], dtype=np.float32)
    part_update_cnt = 0
    file_num = len(embed_files)
    # py_func hands over the sizes as numpy int32 scalars. Use python ints so
    # the product cannot overflow int32 for very large tables.
    total_rows = int(embed_part_size) * int(part_num)
    for embed_file in embed_files:
      part_id_o = _get_embed_part_id(embed_file)
      with gfile.GFile(embed_file, 'rb') as fin:
        embed_val = np.frombuffer(fin.read(), np.float32)
      embed_val = embed_val.reshape([-1, embed_dim])
      # Row j of part file p is global row p + j * file_num.
      embed_ids_o = part_id_o + np.arange(len(embed_val)) * file_num
      # Keep only the rows owned by this worker that fit the local partition.
      sel_ids = np.where(
          np.logical_and((embed_ids_o % part_num) == part_id,
                         embed_ids_o < total_rows))[0]
      part_update_cnt += len(sel_ids)
      # Global row i lands at local row i // part_num. Integer division avoids
      # the float round trip of `/` followed by an int64 cast.
      embed_ids_n = embed_ids_o[sel_ids] // part_num
      part_embed_vals[embed_ids_n] = embed_val[sel_ids]
    logging.info('task[%d] load_part_cnt=%d' % (part_id, part_update_cnt))
    return part_embed_vals

  with ops.control_dependencies([embed_var._initializer_op]):
    if load_embed_lib is not None:
      embed_val = load_embed_lib.load_embed(
          task_index=hvd.rank(),
          task_num=hvd.size(),
          embed_dim=embed_dim,
          embed_part_size=embed_part_size,
          var_name='embed-' + embed_var.name.replace('/', '__'),
          ckpt_path=file_name)
    else:
      embed_val = script_ops.py_func(_load_embed, [
          embed_var, embed_dim, embed_part_size,
          hvd.rank(),
          hvd.size(), file_name, embed_var.name
      ], dtypes.float32)
    embed_val.set_shape(embed_var.get_shape())
    return state_ops.assign(embed_var, embed_val)
186
+
187
def _save_kv_embedding(self, sok_var):
  """Build an op that writes this worker's shard of a key-value embedding.

  The shard exported from ``sok_var`` is written as two raw binary files:

    <ckpt>-embedding/embed-<var>-part-<rank>.key   (``indices.tobytes()``)
    <ckpt>-embedding/embed-<var>-part-<rank>.val   (``values.tobytes()``)

  Here ``<var>`` is the variable name with '/' replaced by '__' and ``<rank>``
  is ``hvd.rank()``. Worker 0 also deletes ``.key`` files (and their ``.val``
  siblings) whose part id, as parsed by ``_get_embed_part_id``, is
  ``>= hvd.size()``.

  Args:
    sok_var: the dynamic embedding variable to save. Its ``handle``,
      ``key_type``, ``handle_dtype`` and ``name`` are read.

  Returns:
    A string tensor produced by ``py_func`` holding the ``[key_file,
    val_file]`` paths written by this worker.
  """
  indices, values = dynamic_variable_ops.dummy_var_export(
      sok_var.handle, key_type=sok_var.key_type, dtype=sok_var.handle_dtype)
  file_name = ops.get_default_graph().get_tensor_by_name(
      self.saver_def.filename_tensor_name)

  def _save_key_vals(indices, values, filename, var_name):
    var_name = var_name.decode('utf-8').replace('/', '__')
    filename = filename.decode('utf-8')
    sok_dir = filename + '-embedding/'
    if not gfile.Exists(sok_dir):
      gfile.MakeDirs(sok_dir)
    task_id = hvd.rank()
    file_prefix = sok_dir + 'embed-' + var_name + '-part-'
    key_file = file_prefix + '%d.key' % task_id
    with gfile.GFile(key_file, 'wb') as fout:
      fout.write(indices.tobytes())
    val_file = file_prefix + '%d.val' % task_id
    with gfile.GFile(val_file, 'wb') as fout:
      fout.write(values.tobytes())

    if task_id == 0:
      # Remove parts whose id is beyond the current worker count. Distinct
      # loop names keep key_file/val_file pointing at the files written
      # above, since they are returned below.
      for stale_key_file in gfile.Glob(file_prefix + '*.key'):
        if _get_embed_part_id(stale_key_file) >= hvd.size():
          gfile.DeleteRecursively(stale_key_file)
          stale_val_file = stale_key_file[:-4] + '.val'
          if gfile.Exists(stale_val_file):
            gfile.DeleteRecursively(stale_val_file)

    # `np.object` was removed in NumPy 1.24; the builtin `object` is the
    # same dtype.
    return np.asarray([key_file, val_file], order='C', dtype=object)

  save_paths = script_ops.py_func(_save_key_vals,
                                  [indices, values, file_name, sok_var.name],
                                  dtypes.string)
  return save_paths
224
+
225
def _load_kv_embedding(self, sok_var):
  """Build an op that restores this worker's shard of a key-value embedding.

  Every ``<ckpt>-embedding/embed-<var>-part-*.key`` file is scanned, where
  ``<var>`` is the variable name with '/' replaced by '__'. Keys with
  ``key % hvd.size() == hvd.rank()`` are kept, together with the matching
  rows of the sibling ``.val`` file. Selection is by key value, so the number
  of part files need not match the current worker count.

  If ``load_embed_lib`` is available it performs the load; otherwise a python
  ``py_func`` fallback is used, and a warning is logged.

  Args:
    sok_var: the dynamic embedding variable to restore. Its ``handle``,
      ``name``, ``_dimension`` and ``_initializer_op`` are read.

  Returns:
    The ``dynamic_variable_ops.dummy_var_assign`` op, created under a control
    dependency on ``sok_var._initializer_op``.
  """

  def _load_key_vals(filename, var_name):
    # Python fallback: returns (int64 keys, float32 [N, dim] values) for this
    # worker, in shuffled order.
    var_name = var_name.decode('utf-8').replace('/', '__')
    filename = filename.decode('utf-8')
    key_file_pattern = filename + '-embedding/embed-' + var_name + '-part-*.key'
    logging.info('key_file_pattern=%s filename=%s var_name=%s var=%s' %
                 (key_file_pattern, filename, var_name, str(sok_var)))
    key_files = gfile.Glob(key_file_pattern)
    logging.info('key_file_pattern=%s file_num=%d' %
                 (key_file_pattern, len(key_files)))
    embed_dim = sok_var._dimension
    all_keys = []
    all_vals = []
    for key_file in key_files:
      with gfile.GFile(key_file, 'rb') as fin:
        tmp_keys = np.frombuffer(fin.read(), dtype=np.int64)
      tmp_ids = np.where(tmp_keys % hvd.size() == hvd.rank())[0]
      if len(tmp_ids) == 0:
        # Nothing for this worker in this part, but later parts may still
        # hold its keys, so keep scanning instead of stopping.
        continue
      all_keys.append(tmp_keys.take(tmp_ids, axis=0))
      logging.info('part_keys.shape=%s %s %s' % (str(
          tmp_keys.shape), str(tmp_ids.shape), str(all_keys[-1].shape)))

      # Values sit next to the keys: '...-part-N.key' -> '...-part-N.val',
      # the suffix the saver writes.
      val_file = key_file[:-4] + '.val'
      with gfile.GFile(val_file, 'rb') as fin:
        tmp_vals = np.frombuffer(
            fin.read(), dtype=np.float32).reshape([-1, embed_dim])
      all_vals.append(tmp_vals.take(tmp_ids, axis=0))
      logging.info('part_vals.shape=%s %s %s' % (str(
          tmp_vals.shape), str(tmp_ids.shape), str(all_vals[-1].shape)))

    if not all_keys:
      # np.concatenate rejects an empty list; hand back empty arrays instead.
      logging.warning('no embedding keys found for task[%d] in %s' %
                      (hvd.rank(), key_file_pattern))
      return (np.zeros([0], dtype=np.int64),
              np.zeros([0, embed_dim], dtype=np.float32))

    all_keys = np.concatenate(all_keys, axis=0)
    all_vals = np.concatenate(all_vals, axis=0)

    shuffle_ids = np.arange(len(all_keys))
    np.random.shuffle(shuffle_ids)
    all_keys = all_keys.take(shuffle_ids, axis=0)
    all_vals = all_vals.take(shuffle_ids, axis=0)
    return all_keys, all_vals

  file_name = ops.get_default_graph().get_tensor_by_name(
      self.saver_def.filename_tensor_name)
  if load_embed_lib is not None:
    keys, vals = load_embed_lib.load_kv_embed(
        task_index=hvd.rank(),
        task_num=hvd.size(),
        embed_dim=sok_var._dimension,
        var_name='embed-' + sok_var.name.replace('/', '__'),
        ckpt_path=file_name)
  else:
    logging.warning('libload_embed.so not loaded, will use python script_ops')
    keys, vals = script_ops.py_func(_load_key_vals, [file_name, sok_var.name],
                                    (dtypes.int64, dtypes.float32))
  with ops.control_dependencies([sok_var._initializer_op]):
    return dynamic_variable_ops.dummy_var_assign(sok_var.handle, keys, vals)
281
+
282
def build(self):
  """Build the base saver, then wire the embedding save/restore ops into it.

  Returns immediately if the saver is already built. Otherwise, when there
  are embedding variables:

  * restore: the kv and dense embedding load ops are grouped with the base
    restore op, and ``saver_def.restore_op_name`` is pointed at that group.
  * save: each worker's embedding save ops, plus the base save tensor on
    horovod rank 0 only, become control dependencies of a new identity of
    the filename tensor. That identity replaces ``saver_def.save_tensor_name``.
  """
  if self._is_built:
    return
  super(EmbeddingParallelSaver, self).build()
  graph = ops.get_default_graph()
  saver_def = self.saver_def

  if saver_def.restore_op_name and self._has_embed_vars():
    restore_deps = [self._load_kv_embedding(var) for var in self._kv_vars]
    restore_deps.extend(
        self._load_dense_embedding(var) for var in self._embed_vars)
    restore_deps.append(graph.get_operation_by_name(saver_def.restore_op_name))
    saver_def.restore_op_name = control_flow_ops.group(restore_deps).name

  if saver_def.save_tensor_name and self._has_embed_vars():
    filename_tensor = graph.get_tensor_by_name(saver_def.filename_tensor_name)
    save_deps = [self._save_kv_embedding(var) for var in self._kv_vars]
    save_deps.extend(
        self._save_dense_embedding(var) for var in self._embed_vars)
    base_save_tensor = graph.get_tensor_by_name(saver_def.save_tensor_name)
    # Non-embedding variables are written by the first worker only.
    if hvd.rank() == 0:
      save_deps.append(base_save_tensor)
    with ops.control_dependencies(save_deps):
      new_save_tensor = array_ops.identity(filename_tensor)
    saver_def.save_tensor_name = new_save_tensor.name
@@ -0,0 +1,116 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import os
5
+
6
+ import tensorflow as tf
7
+ from tensorflow.python.estimator import run_config as run_config_lib
8
+ from tensorflow.python.util import compat
9
+ from tensorflow_estimator.python.estimator.training import _assert_eval_spec
10
+ from tensorflow_estimator.python.estimator.training import _TrainingExecutor
11
+
12
+ from easy_rec.python.compat.exporter import FinalExporter
13
+ from easy_rec.python.utils import estimator_utils
14
+
15
+ from tensorflow_estimator.python.estimator.training import _ContinuousEvalListener # NOQA
16
+
17
+ from tensorflow.python.distribute import estimator_training as distribute_coordinator_training # NOQA
18
+
19
# With TF 2.x installed, switch to the v1 compatibility API (provides tf.gfile).
# NOTE: this is a plain string comparison of version strings; it orders
# '1.x' below and '2.x' at or above '2.0'.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
gfile = tf.gfile
22
+
23
+
24
class TrainDoneListener(_ContinuousEvalListener):
  """Continuous-eval listener that stops evaluation once training has ended.

  Training counts as ended when the ``ESTIMATOR_TRAIN_DONE`` marker file
  exists in the estimator's model_dir.
  """

  def __init__(self, estimator):
    self._model_dir = estimator.model_dir
    self._train_done_file = os.path.join(self._model_dir,
                                         'ESTIMATOR_TRAIN_DONE')

  @property
  def train_done_file(self):
    """Path of the marker file that signals the end of training."""
    return self._train_done_file

  def after_eval(self, eval_result):
    """Decide whether continuous evaluation should keep running.

    Args:
      eval_result: An `_EvalResult` instance.

    Returns:
      True if a checkpoint newer than the one just evaluated exists, or if
      the train-done marker file does not exist yet; False otherwise, which
      stops continuous evaluation.
    """
    evaluated_ckpt = eval_result.checkpoint_path
    has_pending_ckpt = False
    if evaluated_ckpt is not None:
      ckpt_dir = os.path.dirname(evaluated_ckpt).rstrip('/') + '/'
      newest_ckpt = estimator_utils.latest_checkpoint(ckpt_dir)
      has_pending_ckpt = newest_ckpt != evaluated_ckpt
      if has_pending_ckpt:
        # Newer checkpoints are still waiting to be evaluated.
        logging.info(
            'TrainDoneListener: latest_ckpt_path[%s] != last_ckpt_path[%s]' %
            (newest_ckpt, evaluated_ckpt))
    return has_pending_ckpt or not gfile.Exists(self._train_done_file)
56
+
57
+
58
def train_and_evaluate(estimator, train_spec, eval_spec):
  """Train and evaluate an estimator, following `tf.estimator.train_and_evaluate`.

  Compared with the stock flow, this version:

  * passes a `TrainDoneListener` to `_TrainingExecutor` as the continuous-eval
    listener;
  * on the evaluator, after the executor returns, runs every `FinalExporter`
    in `eval_spec.exporters` whose `<model_dir>/export/<name>` directory does
    not exist yet, using the latest checkpoint. This covers runs that stop on
    num_epochs, or with unlimited steps, before a final export happened;
  * on the chief, writes the train-done marker file afterwards.

  Args:
    estimator: the `Estimator` to train and evaluate.
    train_spec: a `TrainSpec`.
    eval_spec: an `EvalSpec`; validated up front by `_assert_eval_spec`.

  Returns:
    The result of `_TrainingExecutor.run()`, or None when the run is handed
    over to the distribute coordinator.

  Raises:
    ValueError: if this is an evaluator task with a task id greater than 0.
  """
  _assert_eval_spec(eval_spec)  # fail fast if eval_spec is invalid.

  train_done_listener = TrainDoneListener(estimator)
  executor = _TrainingExecutor(
      estimator=estimator,
      train_spec=train_spec,
      eval_spec=eval_spec,
      continuous_eval_listener=train_done_listener)
  config = estimator.config

  # In a distributed environment with `distribute_coordinator_mode` set, the
  # distribute coordinator drives training and evaluation instead.
  if distribute_coordinator_training.should_run_distribute_coordinator(config):
    logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
    distribute_coordinator_training.train_and_evaluate(
        estimator, train_spec, eval_spec, _TrainingExecutor)
    return None

  is_evaluator_task = config.task_type == run_config_lib.TaskType.EVALUATOR
  if is_evaluator_task and config.task_id > 0:
    raise ValueError(
        'For distributed training, there can only be one `evaluator` task '
        '(with task id 0). Given task id {}'.format(config.task_id))

  result = executor.run()

  if estimator_utils.is_evaluator():
    export_root = os.path.join(compat.as_str_any(estimator.model_dir), 'export')
    final_exporters = [
        exp for exp in eval_spec.exporters if isinstance(exp, FinalExporter)
    ]
    for exp in final_exporters:
      export_path = os.path.join(export_root, compat.as_str_any(exp.name))
      if gfile.IsDirectory(export_path + '/'):
        continue  # already exported; avoid a duplicate export
      exp.export(
          estimator=estimator,
          export_path=export_path,
          checkpoint_path=estimator_utils.latest_checkpoint(
              estimator.model_dir),
          eval_result=None,
          is_the_final_export=True)

  if estimator_utils.is_chief():
    with gfile.GFile(train_done_listener.train_done_file, 'w') as fout:
      fout.write('Train Done.')

  return result
112
+
113
+
114
def estimator_train_done(estimator):
  """Return True if the ``ESTIMATOR_TRAIN_DONE`` marker exists in model_dir.

  The chief writes this marker at the end of ``train_and_evaluate``.
  """
  marker_path = os.path.join(estimator.model_dir, 'ESTIMATOR_TRAIN_DONE')
  return gfile.Exists(marker_path)