easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,1036 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ import sys
12
+ import time
13
+
14
+ import numpy as np
15
+ import six
16
+ import tensorflow as tf
17
+ from tensorflow.core.framework.summary_pb2 import Summary
18
+ from tensorflow.python.client import device_lib
19
+ from tensorflow.python.framework import errors_impl
20
+ from tensorflow.python.framework import meta_graph
21
+ from tensorflow.python.framework import ops
22
+ from tensorflow.python.ops import array_ops
23
+ from tensorflow.python.platform import gfile
24
+ from tensorflow.python.training import basic_session_run_hooks
25
+ from tensorflow.python.training import session_run_hook
26
+ from tensorflow.python.training.summary_io import SummaryWriterCache
27
+
28
+ from easy_rec.python.ops.incr_record import get_sparse_indices
29
+ from easy_rec.python.ops.incr_record import kv_resource_incr_gather
30
+ from easy_rec.python.utils import constant
31
+ from easy_rec.python.utils import embedding_utils
32
+ from easy_rec.python.utils import shape_utils
33
+
34
+ from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer # NOQA
35
+
36
+ try:
37
+ import horovod.tensorflow as hvd
38
+ except Exception:
39
+ hvd = None
40
+
41
+ try:
42
+ from sparse_operation_kit import experiment as sok
43
+ except Exception:
44
+ sok = None
45
+
46
+ try:
47
+ from kafka import KafkaProducer, KafkaAdminClient
48
+ from kafka.admin import NewTopic
49
+ except ImportError as ex:
50
+ logging.warning('kafka-python is not installed: %s' % str(ex))
51
+
52
# Switch to the TF1 compatibility API when running on TensorFlow 2.x or
# later.  The major version is compared as an integer because a plain string
# comparison mis-orders versions (e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
# Short aliases for the hook base classes used throughout this module.
SessionRunHook = session_run_hook.SessionRunHook
CheckpointSaverHook = basic_session_run_hooks.CheckpointSaverHook
56
+
57
+
58
def tensor_log_format_func(tensor_dict):
  """Render a dict of logged values as a single log line.

  Args:
    tensor_dict: mapping from tag to value; an optional 'step' entry is
      rendered as a leading 'global step N: ' prefix instead of a regular
      'tag = value' item.

  Returns:
    The prefix (empty when there is no 'step' key) followed by the remaining
    entries formatted as 'tag = value' and joined with ', ', in the dict's
    iteration order.
  """
  header = ('global step %s: ' % tensor_dict['step']
            if 'step' in tensor_dict else '')
  body = ', '.join('%s = %s' % (tag, tensor_dict[tag])
                   for tag in tensor_dict
                   if tag != 'step')
  return header + body
69
+
70
+
71
class ExitBarrierHook(SessionRunHook):
  """Barrier hook that lets the master and the workers exit together.

  Once training finishes the master still has to run evaluation and export
  the model, so it leaves later than the workers.  Every task therefore
  checks in on a shared counter queue and waits until all of them arrived.
  """

  def __init__(self, num_worker, is_chief, model_dir):
    """Store the barrier settings; the graph ops are created in `begin`.

    Args:
      num_worker: number of tasks that take part in the barrier (the log
        message in `begin` counts the master as one of them).
      is_chief: whether this task is the chief.
      model_dir: directory in which the chief creates the at-exit flag file.
    """
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._model_dir = model_dir
    # queues and the ops working on them, all populated in begin()
    self._queue = None
    self._signal_que = None
    self._enque = None
    self._deque = None
    self._que_size = None
    self._send = None
    self._recv = None

  def begin(self):
    """Count the number of workers and masters, and setup barrier queue."""
    tf.logging.info('number workers(including master) = %d' % self._num_worker)
    ps_cpu = tf.DeviceSpec(job='ps', task=0, device_type='CPU', device_index=0)
    # NOTE(review): where the device scope ends cannot be recovered from the
    # flattened source; the two queues are created inside it and the ops
    # using them are assumed to follow it -- confirm against the real file.
    with tf.device(ps_cpu):
      # counts the tasks that already reached the barrier
      self._queue = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.float32],
          shapes=[()],
          name='exit_counter',
          shared_name='exit_counter')
      # carries the flag file name from the chief to the other tasks
      self._signal_que = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.string],
          shapes=[()],
          name='exit_counter_signal',
          shared_name='exit_counter_signal')
    self._enque = self._queue.enqueue(1.0)
    self._que_size = self._queue.size()
    self._deque = self._queue.dequeue()
    if not self._is_chief:
      self._recv = self._signal_que.dequeue()
      self._flag_file = None
    else:
      flag_name = 'atexit_sync_' + str(int(time.time()))
      self._flag_file = os.path.join(self._model_dir, flag_name)
      self._send = self._signal_que.enqueue([self._flag_file])

  def after_create_session(self, session, coord):
    """Clean up the queue after create session.

    Sometimes the ps does not exit, so elements enqueued by the previous run
    are still in the queue; the chief drains them here.
    """
    if not self._is_chief:
      return
    # clear the queue
    leftover = session.run(self._que_size)
    while leftover > 0:
      session.run(self._deque)
      leftover = session.run(self._que_size)
    logging.info('exit counter cleared: %d' % leftover)

  @staticmethod
  def _check_flag_file(is_chief, flag_file):
    """At-exit callback: the chief writes the flag file, others wait for it."""
    logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
                 (is_chief, flag_file))
    if not is_chief:
      while not gfile.Exists(flag_file):
        time.sleep(1)
      return
    with gfile.GFile(flag_file, 'w') as fout:
      fout.write('atexit time: %d' % int(time.time()))

  def end(self, session):
    """Ensure when all workers and master enqueue an element, then exit."""
    import atexit

    session.run(self._enque)
    arrived = session.run(self._que_size)
    while arrived < self._num_worker:
      arrived = session.run(self._que_size)
      time.sleep(5)
      tf.logging.info(
          'waiting for other worker to exit, finished %d, total %d' %
          (arrived, self._num_worker))
    # prepare on_exit synchronize base on self._flag_file
    if not self._is_chief:
      self._flag_file = session.run(self._recv)
    else:
      # one signal per non-chief task
      for _ in range(self._num_worker - 1):
        session.run(self._send)

    atexit.register(
        self._check_flag_file,
        is_chief=self._is_chief,
        flag_file=self._flag_file)
    logging.info('ExitBarrier passed')
163
+
164
+
165
class EvaluateExitBarrierHook(SessionRunHook):
  """Exit barrier for evaluation that also fetches the metric values.

  All tasks check in on a shared counter queue and wait until everyone has
  arrived.  The chief then evaluates `metric_ops`, keeps the values in
  `eval_result`, and sends the name of an at-exit flag file to the other
  tasks.
  """

  def __init__(self, num_worker, is_chief, model_dir, metric_ops=None):
    """Store the barrier settings; the graph ops are created in `begin`.

    Args:
      num_worker: number of tasks that take part in the barrier (the log
        message in `begin` counts the master as one of them).
      is_chief: whether this task is the chief.
      model_dir: directory in which the chief creates the at-exit flag file.
      metric_ops: fetches passed to `session.run` in `end`; when None no
        metrics are evaluated and `eval_result` stays None.
    """
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._model_dir = model_dir
    # queues and the ops working on them, all populated in begin()
    self._queue = None
    self._signal_que = None
    self._que_size = None
    self._enque = None
    self._deque = None
    self._send = None
    self._recv = None
    self._flag_file = None
    self.metric_ops = metric_ops
    # metric values fetched by the chief in end()
    self.eval_result = None

  def begin(self):
    """Count the number of workers and masters, and setup barrier queue."""
    tf.logging.info('number workers(including master) = %d' % self._num_worker)
    # NOTE(review): where the device scope ends cannot be recovered from the
    # flattened source; the two queues are created inside it and the ops
    # using them are assumed to follow it -- confirm against the real file.
    with tf.device(
        tf.DeviceSpec(job='ps', task=0, device_type='CPU', device_index=0)):
      # counts the tasks that already reached the barrier
      self._queue = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.float32],
          shapes=[()],
          name='exit_counter',
          shared_name='exit_counter')
      # carries the flag file name from the chief to the other tasks
      self._signal_que = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.string],
          shapes=[()],
          name='exit_counter_signal',
          shared_name='exit_counter_signal')
    self._enque = self._queue.enqueue(1.0)
    self._que_size = self._queue.size()
    self._deque = self._queue.dequeue()
    if self._is_chief:
      self._flag_file = os.path.join(self._model_dir,
                                     'atexit_sync_' + str(int(time.time())))
      self._send = self._signal_que.enqueue([self._flag_file])
    else:
      self._recv = self._signal_que.dequeue()
      self._flag_file = None

  def after_create_session(self, session, coord):
    """Clean up the queue after create session.

    Sometimes the ps does not exit, so elements enqueued by the previous run
    are still in the queue; the chief drains them here.
    """
    if self._is_chief:
      # clear the queue
      que_size = session.run(self._que_size)
      while que_size > 0:
        session.run(self._deque)
        que_size = session.run(self._que_size)
      logging.info('exit counter cleared: %d' % que_size)

  def end(self, session):
    """Wait for all tasks, fetch the metrics and arrange the at-exit sync."""
    session.run(self._enque)
    que_size = session.run(self._que_size)
    while que_size < self._num_worker:
      que_size = session.run(self._que_size)
      time.sleep(5)
      tf.logging.info(
          'waiting for other worker to exit, finished %d, total %d' %
          (que_size, self._num_worker))
    # prepare on_exit synchronize base on self._flag_file
    if self._is_chief:
      # session.run rejects a None fetch, so skip the evaluation when the
      # hook was built with the default metric_ops=None.
      if self.metric_ops is not None:
        self.eval_result = session.run(self.metric_ops)
      # one signal per non-chief task
      for _ in range(self._num_worker - 1):
        session.run(self._send)
    else:
      self._flag_file = session.run(self._recv)

    def _check_flag_file(is_chief, flag_file):
      """At-exit callback: chief writes the flag file, others wait for it."""
      logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
                   (is_chief, flag_file))
      if is_chief:
        with gfile.GFile(flag_file, 'w') as fout:
          fout.write('atexit time: %d' % int(time.time()))
      else:
        while not gfile.Exists(flag_file):
          time.sleep(1)

    from atexit import register
    register(
        _check_flag_file, is_chief=self._is_chief, flag_file=self._flag_file)
    # Every task (chief included) runs metric_ops once more and discards the
    # result.  NOTE(review): the reason is not evident from this file; the
    # call is kept as-is apart from the None guard.
    if self.metric_ops is not None:
      session.run(self.metric_ops)

    logging.info('ExitBarrier passed')
262
+
263
+
264
class ProgressHook(SessionRunHook):
  """Writes the training progress (fraction of total steps) to a file.

  Only the chief opens and writes the progress file; on every other task
  the callbacks do nothing.
  """

  def __init__(self, num_steps, filename, is_chief):
    """Initializes a `ProgressHook`.

    Args:
      num_steps: total train steps
      filename: progress file name
      is_chief: is chief worker or not
    """
    self._is_chief = is_chief
    self._num_steps = num_steps
    # NOTE(review): the flattened source does not show whether these two are
    # set on every task or on the chief only; only chief code paths read them.
    self._last_progress_cnt = 0
    self._progress_interval = 0.01  # 1%
    if is_chief:
      self._progress_file = gfile.GFile(filename, 'w')
      self._progress_file.write('0.00\n')

  def before_run(self, run_context):
    """Ask for the global step on the chief; request nothing elsewhere."""
    if not self._is_chief:
      return None
    return tf.train.SessionRunArgs([tf.train.get_global_step()])

  def after_run(
      self,
      run_context,  # pylint: disable=unused-argument
      run_values):
    """Append a progress line each time another interval has been passed."""
    if not self._is_chief:
      return
    step = run_values.results[0]
    # true division: the module imports division from __future__
    fraction = step / self._num_steps
    bucket = int(fraction / self._progress_interval)
    if bucket < self._last_progress_cnt + 1:
      return
    self._progress_file.write('%.2f\n' % fraction)
    self._progress_file.flush()
    self._last_progress_cnt = bucket
    logging.info('Training Progress: %.2f' % fraction)

  def end(self, session):
    """Write the final 1.00 line if it was not reached, then close the file."""
    if not self._is_chief:
      return
    if self._last_progress_cnt < 1 / self._progress_interval:
      self._progress_file.write('1.00\n')
    self._progress_file.close()
305
+
306
+
307
+ class CheckpointSaverHook(CheckpointSaverHook):
308
+ """Saves checkpoints every N steps or seconds."""
309
+
310
+ def __init__(self,
311
+ checkpoint_dir,
312
+ save_secs=None,
313
+ save_steps=None,
314
+ saver=None,
315
+ checkpoint_basename='model.ckpt',
316
+ scaffold=None,
317
+ listeners=None,
318
+ write_graph=True,
319
+ data_offset_var=None,
320
+ increment_save_config=None):
321
+ """Initializes a `CheckpointSaverHook`.
322
+
323
+ Args:
324
+ checkpoint_dir: `str`, base directory for the checkpoint files.
325
+ save_secs: `int`, save every N secs.
326
+ save_steps: `int`, save every N steps.
327
+ saver: `Saver` object, used for saving.
328
+ checkpoint_basename: `str`, base name for the checkpoint files.
329
+ scaffold: `Scaffold`, use to get saver object.
330
+ listeners: List of `CheckpointSaverListener` subclass instances.
331
+ Used for callbacks that run immediately before or after this hook saves
332
+ the checkpoint.
333
+ write_graph: whether to save graph.pbtxt.
334
+ data_offset_var: data offset variable.
335
+ increment_save_config: parameters for saving increment checkpoints.
336
+
337
+ Raises:
338
+ ValueError: One of `save_steps` or `save_secs` should be set.
339
+ ValueError: At most one of saver or scaffold should be set.
340
+ """
341
+ super(CheckpointSaverHook, self).__init__(
342
+ checkpoint_dir,
343
+ save_secs=save_secs,
344
+ save_steps=save_steps,
345
+ saver=saver,
346
+ checkpoint_basename=checkpoint_basename,
347
+ scaffold=scaffold,
348
+ listeners=listeners)
349
+ self._cuda_profile_start = 0
350
+ self._cuda_profile_stop = 0
351
+ self._steps_per_run = 1
352
+ self._write_graph = write_graph
353
+ self._data_offset_var = data_offset_var
354
+
355
+ self._task_idx, self._task_num = get_task_index_and_num()
356
+
357
+ if increment_save_config is not None:
358
+ self._kafka_timeout_ms = os.environ.get('KAFKA_TIMEOUT', 600) * 1000
359
+ logging.info('KAFKA_TIMEOUT: %dms' % self._kafka_timeout_ms)
360
+ self._kafka_max_req_size = os.environ.get('KAFKA_MAX_REQ_SIZE',
361
+ 1024 * 1024 * 64)
362
+ logging.info('KAFKA_MAX_REQ_SIZE: %d' % self._kafka_max_req_size)
363
+ self._kafka_max_msg_size = os.environ.get('KAFKA_MAX_MSG_SIZE',
364
+ 1024 * 1024 * 1024)
365
+ logging.info('KAFKA_MAX_MSG_SIZE: %d' % self._kafka_max_msg_size)
366
+
367
+ self._dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
368
+ self._sparse_name_to_ids = embedding_utils.get_sparse_name_to_ids()
369
+
370
+ with gfile.GFile(
371
+ os.path.join(checkpoint_dir, constant.DENSE_UPDATE_VARIABLES),
372
+ 'w') as fout:
373
+ json.dump(self._dense_name_to_ids, fout, indent=2)
374
+
375
+ save_secs = increment_save_config.dense_save_secs
376
+ save_steps = increment_save_config.dense_save_steps
377
+ self._dense_timer = SecondOrStepTimer(
378
+ every_secs=save_secs if save_secs > 0 else None,
379
+ every_steps=save_steps if save_steps > 0 else None)
380
+ save_secs = increment_save_config.sparse_save_secs
381
+ save_steps = increment_save_config.sparse_save_steps
382
+ self._sparse_timer = SecondOrStepTimer(
383
+ every_secs=save_secs if save_secs > 0 else None,
384
+ every_steps=save_steps if save_steps > 0 else None)
385
+
386
+ self._dense_timer.update_last_triggered_step(0)
387
+ self._sparse_timer.update_last_triggered_step(0)
388
+
389
+ self._sparse_indices = []
390
+ self._sparse_values = []
391
+ sparse_train_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
392
+ for sparse_var, indice_dtype in sparse_train_vars:
393
+ with ops.control_dependencies([tf.train.get_global_step()]):
394
+ with ops.colocate_with(sparse_var):
395
+ sparse_indice = get_sparse_indices(
396
+ var_name=sparse_var.op.name, ktype=indice_dtype)
397
+ # sparse_indice = sparse_indice.global_indices
398
+ self._sparse_indices.append(sparse_indice)
399
+ if 'EmbeddingVariable' in str(type(sparse_var)):
400
+ self._sparse_values.append(
401
+ kv_resource_incr_gather(
402
+ sparse_var._handle, sparse_indice,
403
+ np.zeros(sparse_var.shape.as_list(), dtype=np.float32)))
404
+ # sparse_var.sparse_read(sparse_indice))
405
+ else:
406
+ self._sparse_values.append(
407
+ array_ops.gather(sparse_var, sparse_indice))
408
+
409
+ self._kafka_producer = None
410
+ self._incr_save_dir = None
411
+ if increment_save_config.HasField('kafka'):
412
+ self._topic = increment_save_config.kafka.topic
413
+ logging.info('increment save topic: %s' % self._topic)
414
+
415
+ admin_clt = KafkaAdminClient(
416
+ bootstrap_servers=increment_save_config.kafka.server,
417
+ request_timeout_ms=self._kafka_timeout_ms,
418
+ api_version_auto_timeout_ms=self._kafka_timeout_ms)
419
+ if self._topic not in admin_clt.list_topics():
420
+ admin_clt.create_topics(
421
+ new_topics=[
422
+ NewTopic(
423
+ name=self._topic,
424
+ num_partitions=1,
425
+ replication_factor=1,
426
+ topic_configs={
427
+ 'max.message.bytes': self._kafka_max_msg_size
428
+ })
429
+ ],
430
+ validate_only=False)
431
+ logging.info('create increment save topic: %s' % self._topic)
432
+ admin_clt.close()
433
+
434
+ servers = increment_save_config.kafka.server.split(',')
435
+ self._kafka_producer = KafkaProducer(
436
+ bootstrap_servers=servers,
437
+ max_request_size=self._kafka_max_req_size,
438
+ api_version_auto_timeout_ms=self._kafka_timeout_ms,
439
+ request_timeout_ms=self._kafka_timeout_ms)
440
+ elif increment_save_config.HasField('fs'):
441
+ fs = increment_save_config.fs
442
+ if fs.relative:
443
+ self._incr_save_dir = os.path.join(checkpoint_dir, fs.incr_save_dir)
444
+ else:
445
+ self._incr_save_dir = fs.incr_save_dir
446
+ if not self._incr_save_dir.endswith('/'):
447
+ self._incr_save_dir += '/'
448
+ if not gfile.IsDirectory(self._incr_save_dir):
449
+ gfile.MakeDirs(self._incr_save_dir)
450
+ elif increment_save_config.HasField('datahub'):
451
+ raise NotImplementedError('datahub increment saving is in development.')
452
+ else:
453
+ raise ValueError(
454
+ 'incr_update not specified correctly, must be oneof: kafka,fs')
455
+
456
+ self._debug_save_update = increment_save_config.debug_save_update
457
+ else:
458
+ self._dense_timer = None
459
+ self._sparse_timer = None
460
+
461
def after_create_session(self, session, coord):
  """Optionally dump the graph, then write a checkpoint for the start step.

  Args:
    session: the session that has just been created.
    coord: coordinator handed in by the hook framework; not used here.
  """
  start_step = session.run(self._global_step_tensor)

  if self._write_graph:
    # The graph and saver_def are written here instead of in begin():
    # other hooks are allowed to change the graph and add variables in
    # their begin(), and the graph is finalized only after all begin calls.
    default_graph = tf.get_default_graph()
    tf.train.write_graph(
        default_graph.as_graph_def(add_shapes=True), self._checkpoint_dir,
        'graph.pbtxt')

    saver = self._get_saver()
    saver_def = saver.saver_def if saver else None
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=default_graph.as_graph_def(add_shapes=True),
        saver_def=saver_def)
    self._summary_writer.add_graph(default_graph)
    self._summary_writer.add_meta_graph(meta_graph_def)

  # a checkpoint is always written for the starting step
  self._save(session, start_step)
  self._timer.update_last_triggered_step(start_step)
480
+
481
def before_run(self, run_context):  # pylint: disable=unused-argument
  """Ask every run call to also fetch the global step tensor."""
  step_fetch = self._global_step_tensor
  return tf.train.SessionRunArgs(step_fetch)
483
+
484
def _send_dense(self, global_step, session):
  """Serialize the dense update variables and publish them as one message.

  The message is an int32 header `[0, msg_num, global_step]` followed by one
  `(variable id, element count)` pair per variable, then the raw bytes of
  each variable value in the same order. It is sent to kafka and / or
  written under the increment save directory, depending on which of them
  is configured.

  Args:
    global_step: training step the values belong to.
    session: session used to read the variable values.
  """
  dense_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
  dense_vals = session.run(dense_vars)
  logging.info('global_step=%d, increment save dense variables' % global_step)

  msg_num = len(dense_vals)
  var_ids = [self._dense_name_to_ids[var.op.name] for var in dense_vars]

  # header: the leading 0 marks a dense update message
  header = [0, msg_num, global_step]
  for var_id, val in zip(var_ids, dense_vals):
    header.extend([var_id, val.size])

  # body: header bytes, then every variable's raw buffer
  chunks = [np.array(header, dtype=np.int32).tobytes()]
  chunks.extend(val.tobytes() for val in dense_vals)
  bytes_buf = b''.join(chunks)

  update_name = 'dense_update_%d' % global_step

  if self._kafka_producer is not None:
    send_res = self._kafka_producer.send(
        self._topic, bytes_buf, key=update_name.encode('utf-8'))
    logging.info('kafka send dense: %d exception: %s' %
                 (global_step, send_res.exception))

  if self._incr_save_dir is not None:
    save_path = os.path.join(self._incr_save_dir, update_name)
    with gfile.GFile(save_path, 'wb') as fout:
      fout.write(bytes_buf)
    # the .done flag file is written only after the payload file
    with gfile.GFile(save_path + '.done', 'w') as fout:
      fout.write(update_name)

  if self._debug_save_update and self._incr_save_dir is None:
    # debug copy, stored next to the regular checkpoints
    base_dir, _ = os.path.split(self._save_path)
    debug_dir = os.path.join(base_dir, 'incr_save/')
    if not gfile.Exists(debug_dir):
      gfile.MakeDirs(debug_dir)
    with gfile.GFile(os.path.join(debug_dir, update_name), 'wb') as fout:
      fout.write(bytes_buf)

  logging.info(
      'global_step=%d, increment update dense variables, msg_num=%d' %
      (global_step, msg_num))
531
+
532
def _send_sparse(self, global_step, session):
  """Serialize the updated sparse rows and publish them as one message.

  Only variables that report at least one updated index are included. The
  message is an int32 header `[1, msg_num, global_step]` followed by one
  `(embedding id, key count)` pair per variable, then for each variable the
  raw bytes of its keys and of its values. Nothing is sent when no variable
  has updates.

  Args:
    global_step: training step the values belong to.
    session: session used to read indices and values.
  """
  sparse_vars = ops.get_collection(constant.SPARSE_UPDATE_VARIABLES)
  fetched = session.run(self._sparse_indices + self._sparse_values)
  # first half of the results are indices, second half the matching values
  half = len(fetched) // 2

  sel_ids = [i for i in range(half) if len(fetched[i]) > 0]
  sel_keys = [fetched[i] for i in sel_ids]
  sel_vals = [fetched[i + half] for i in sel_ids]
  sel_vars = [sparse_vars[i][0] for i in sel_ids]
  sel_embed_ids = [self._sparse_name_to_ids[var.name] for var in sel_vars]

  msg_num = len(sel_ids)
  if msg_num == 0:
    logging.warning('there are no sparse updates, will skip this send: %d' %
                    global_step)
    return

  # header: the leading 1 marks a sparse update message
  header = [1, msg_num, global_step]
  for embed_id, keys in zip(sel_embed_ids, sel_keys):
    header.extend([embed_id, len(keys)])
  chunks = [np.array(header, dtype=np.int32).tobytes()]

  # body: keys then values for every selected variable
  for keys, vals, var in zip(sel_keys, sel_vals, sel_vars):
    is_kv_var = 'EmbeddingVariable' in str(type(var))
    if not is_kv_var and var._save_slice_info is not None:
      # for non kv embedding variables, add partition offset to the keys
      keys += var._save_slice_info.var_offset[0]
    chunks.append(keys.tobytes())
    chunks.append(vals.tobytes())
  bytes_buf = b''.join(chunks)

  update_name = 'sparse_update_%d' % global_step

  if self._kafka_producer is not None:
    send_res = self._kafka_producer.send(
        self._topic, bytes_buf, key=update_name.encode('utf-8'))
    logging.info('kafka send sparse: %d %s' %
                 (global_step, send_res.exception))

  if self._incr_save_dir is not None:
    save_path = os.path.join(self._incr_save_dir, update_name)
    with gfile.GFile(save_path, 'wb') as fout:
      fout.write(bytes_buf)
    # the .done flag file is written only after the payload file
    with gfile.GFile(save_path + '.done', 'w') as fout:
      fout.write(update_name)

  if self._debug_save_update and self._incr_save_dir is None:
    # debug copy, stored next to the regular checkpoints
    base_dir, _ = os.path.split(self._save_path)
    debug_dir = os.path.join(base_dir, 'incr_save/')
    if not gfile.Exists(debug_dir):
      gfile.MakeDirs(debug_dir)
    with gfile.GFile(os.path.join(debug_dir, update_name), 'wb') as fout:
      fout.write(bytes_buf)

  logging.info(
      'global_step=%d, increment update sparse variables, msg_num=%d, msg_size=%d'
      % (global_step, msg_num, len(bytes_buf)))
599
+
600
def after_run(self, run_context, run_values):
  """Run the regular checkpoint logic, then the increment update timers.

  Args:
    run_context: context of the finished run call; provides the session.
    run_values: results of the run call; `results` holds the global step
      requested in before_run.
  """
  super(CheckpointSaverHook, self).after_run(run_context, run_values)

  session = run_context.session
  # the fetched step was read before the train op ran, so advance it
  expected_step = run_values.results + self._steps_per_run
  # -1 means the up to date global step has not been read yet
  fresh_step = -1

  dense_timer = self._dense_timer
  if (dense_timer is not None and
      dense_timer.should_trigger_for_step(expected_step)):
    fresh_step = session.run(self._global_step_tensor)
    dense_timer.update_last_triggered_step(fresh_step)
    self._send_dense(fresh_step, session)

  sparse_timer = self._sparse_timer
  if (sparse_timer is not None and
      sparse_timer.should_trigger_for_step(expected_step)):
    if fresh_step < 0:
      fresh_step = session.run(self._global_step_tensor)
    sparse_timer.update_last_triggered_step(fresh_step)
    self._send_sparse(fresh_step, session)
617
+
618
def _save(self, session, step):
  """Saves the latest checkpoint, returns should_stop.

  Args:
    session: session holding the variables to save.
    step: global step used to name the checkpoint.

  Return:
    True if any listener requested that training be stopped.
  """
  logging.info('Saving checkpoints for %d into %s.', step, self._save_path)

  for listener in self._listeners:
    listener.before_save(session, step)

  if self._data_offset_var is not None:
    # merge the non-empty json encoded offsets into a single dict and store
    # it alongside the checkpoint
    merged_offsets = {}
    for raw_offset in session.run(self._data_offset_var):
      if raw_offset:
        merged_offsets.update(json.loads(raw_offset))
    offset_path = os.path.join(self._checkpoint_dir,
                               'model.ckpt-%d.offset' % step)
    with gfile.GFile(offset_path, 'w') as fout:
      json.dump(merged_offsets, fout)

  saver = self._get_saver()
  saver.save(
      session,
      self._save_path,
      global_step=step,
      write_meta_graph=self._write_graph)

  session_log = tf.SessionLog(
      status=tf.SessionLog.CHECKPOINT, checkpoint_path=self._save_path)
  self._summary_writer.add_session_log(session_log, step)

  # every listener is notified, even after one of them asked to stop
  should_stop = False
  for listener in self._listeners:
    if listener.after_save(session, step):
      logging.info(
          'A CheckpointSaverListener requested that training be stopped. '
          'listener: {}'.format(listener))
      should_stop = True
  return should_stop
655
+
656
def end(self, session):
  """Finish checkpointing and flush pending increment updates.

  Dense / sparse updates are sent once more unless the final global step
  equals the step they were last sent at.

  Args:
    session: session that is about to be closed.
  """
  final_step = session.run(self._global_step_tensor)
  super(CheckpointSaverHook, self).end(session)

  dense_timer = self._dense_timer
  if dense_timer is not None:
    if final_step != dense_timer.last_triggered_step():
      dense_timer.update_last_triggered_step(final_step)
      self._send_dense(final_step, session)

  sparse_timer = self._sparse_timer
  if sparse_timer is not None:
    if final_step != sparse_timer.last_triggered_step():
      sparse_timer.update_last_triggered_step(final_step)
      self._send_sparse(final_step, session)
667
+
668
+
669
class NumpyCheckpointRestoreHook(SessionRunHook):
  """Restore variable from numpy checkpoint."""

  def __init__(self, ckpt_path, name2var_map):
    """Initializes a `NumpyCheckpointRestoreHook`.

    Args:
      ckpt_path: numpy checkpoint path to restore from
      name2var_map: var name in numpy ckpt to variable map
    """
    self._ckpt_path = ckpt_path
    self._name2var_map = name2var_map
    self._restore_op = None

  def begin(self):
    """Build the assign ops and verify that every variable can be restored.

    Variables missing from the checkpoint are listed in a
    `<ckpt_path minus its last 4 characters>_not_inited.txt` file. The method
    asserts when a shape does not match or a required variable is missing.
    """
    ckpt_data = np.load(self._ckpt_path)
    missing_vars = {}
    assign_ops = []
    shape_mismatch = False

    with tf.variable_scope('', reuse=True):
      for var_name, var in six.iteritems(self._name2var_map):
        var_shape = var.get_shape().as_list()

        if var_name not in ckpt_data.keys():
          # names containing Momentum or global_step may be absent without
          # being reported as an error
          if 'Momentum' not in var_name and 'global_step' not in var_name:
            logging.error('variable [%s] not found in ckpt' % var_name)
            missing_vars[var_name] = ','.join([str(s) for s in var_shape])
          continue

        var_data = ckpt_data[var_name]
        if list(var_data.shape) != var_shape:
          logging.error(
              'variable [%s] shape not match %r vs %r' %
              (var.name.split(':')[0], var_shape, list(var_data.shape)))
          shape_mismatch = True
          continue

        assign_ops.append(var.assign(var_data))
    self._restore_op = tf.group(assign_ops)

    # the report file is written even when nothing is missing
    report_path = self._ckpt_path[:-4] + '_not_inited.txt'
    with gfile.GFile(report_path, 'w') as f:
      for var_name in sorted(missing_vars.keys()):
        f.write('%s:%s\n' % (var_name, missing_vars[var_name]))

    assert not shape_mismatch, 'exist variable shape not match, restore failed'
    assert len(missing_vars.keys()) == 0, \
        'exist variable shape not inited, restore failed'

  def after_create_session(self, session, coord):
    """Run the assign ops built in begin()."""
    assert self._restore_op is not None
    logging.info('running numpy checkpoint restore_op')
    session.run(self._restore_op)
717
+
718
+
719
class IncompatibleShapeRestoreHook(SessionRunHook):
  """Restore variable with incompatible shapes."""

  def __init__(self, incompatible_shape_var_map):
    """Initializes a `IncompatibleShapeRestoreHook`.

    Args:
      incompatible_shape_var_map: a variables mapping with incompatible shapes,
        map from real variable to temp variable, real variable is the variable
        used in model, temp variable is the variable restored from checkpoint.
    """
    self._incompatible_shape_var_map = incompatible_shape_var_map
    self._restore_op = None

  def begin(self):
    """Build one assign op per variable pair and group them."""
    assign_ops = []
    for var, var_tmp in six.iteritems(self._incompatible_shape_var_map):
      # pad or clip the restored value to the shape used by the model
      resized = shape_utils.pad_or_clip_nd(var_tmp, var.get_shape().as_list())
      assign_ops.append(var.assign(resized))
      logging.info(
          'Assign variable[%s] from shape%s to shape%s' %
          (var.name, var_tmp.get_shape().as_list(), var.get_shape().as_list()))
    self._restore_op = tf.group(assign_ops)

  def after_create_session(self, session, coord):
    """Run the assign ops built in begin()."""
    assert self._restore_op is not None
    logging.info('running incompatible shape variable restore_op')
    session.run(self._restore_op)
749
+
750
+
751
class MultipleCheckpointsRestoreHook(SessionRunHook):
  """Restore variables from multiple checkpoints.

  Each global variable is restored from the first checkpoint (in the order
  given) that contains it; later checkpoints holding the same variable are
  skipped for that variable with a warning.
  """
  SEP = ';'

  def __init__(self, ckpt_paths):
    """Initializes a `MultipleCheckpointsRestoreHook`.

    Args:
      ckpt_paths: multiple checkpoint paths, separated by ;
    """
    self._ckpt_path_list = ckpt_paths.split(self.SEP)
    self._saver_list = []

  def begin(self):
    """Create one saver per checkpoint path.

    Asserts that every global variable is found in at least one checkpoint.
    """
    global_variables = tf.global_variables()
    var_names = [re.sub(':[0-9]$', '', var.name) for var in global_variables]
    restore_status = {var_name: False for var_name in var_names}
    for ckpt_path in self._ckpt_path_list:
      logging.info('read variable from %s' % ckpt_path)
      ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
      ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
      # ckpt_var2shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
      name2var = {}
      for var in global_variables:
        var_name = re.sub(':[0-9]$', '', var.name)
        if var_name in ckpt_var2shape_map:
          if restore_status[var_name]:
            logging.warning(
                'variable %s find in more than one checkpoint, skipped %s' %
                (var_name, ckpt_path))
            continue
          name2var[var_name] = var
          restore_status[var_name] = True
      # A checkpoint may end up with nothing left to restore, e.g. when all
      # of its variables were already taken from an earlier checkpoint.
      # Without allow_empty the Saver constructor raises
      # "No variables to save" for an empty mapping; with it the saver is a
      # no-op, which keeps _saver_list aligned with _ckpt_path_list.
      saver = tf.train.Saver(name2var, allow_empty=True)
      self._saver_list.append(saver)

    restore_check = True
    for var_name, stat in six.iteritems(restore_status):
      if not stat:
        logging.error('var %s not find in checkpoints' % var_name)
        restore_check = False

    assert restore_check, 'failed to find all variables in checkpoints provided'

  def after_create_session(self, session, coord):
    """Restore every checkpoint with its matching saver."""
    logging.info('running multiple checkpoint restore hook')
    for saver, ckpt_path in zip(self._saver_list, self._ckpt_path_list):
      logging.info('restore checkpoint from %s' % ckpt_path)
      saver.restore(session, ckpt_path)
801
+
802
+
803
class OnlineEvaluationHook(SessionRunHook):
  """Evaluate metrics at the end of a session and persist the results."""

  def __init__(self, metric_dict, output_dir):
    """Initializes an `OnlineEvaluationHook`.

    Args:
      metric_dict: map from metric name to an indexable object whose first
        element is the tensor to evaluate.
      output_dir: directory receiving the summary events and result file.
    """
    self._metric_dict = metric_dict
    self._output_dir = output_dir
    self._summary_writer = SummaryWriterCache.get(self._output_dir)

  def end(self, session):
    """Evaluate the metrics, write a summary and a json result file."""
    metric_tensors = {}
    for name, metric in self._metric_dict.items():
      metric_tensors[name] = metric[0]
    metric_values = session.run(metric_tensors)
    tf.logging.info('Eval metric: %s' % metric_values)

    global_step = session.run(tf.train.get_or_create_global_step())

    summary = Summary()
    for name, value in metric_values.items():
      summary.value.add(tag=name, simple_value=value)
    self._summary_writer.add_summary(summary, global_step=global_step)
    self._summary_writer.flush()

    result_name = 'online_eval_result.txt-%s' % global_step
    eval_result_file = os.path.join(self._output_dir, result_name)
    logging.info('Saving online eval result to file %s' % eval_result_file)
    with gfile.GFile(eval_result_file, 'w') as ofile:
      # convert numpy float to python float
      result_to_write = {
          name: metric_values[name].item() for name in sorted(metric_values)
      }
      ofile.write(json.dumps(result_to_write, indent=2))
833
+
834
+
835
def parse_tf_config():
  """Read cluster and task information from the TF_CONFIG environment variable.

  Return:
    a tuple (cluster, task_type, task_index); when TF_CONFIG is unset or
    blank the result is ({}, 'master', 0), and a missing cluster / task
    entry falls back to the same defaults.
  """
  tf_config_str = os.environ.get('TF_CONFIG', '')
  # a blank TF_CONFIG is treated like an unset one instead of failing in
  # json.loads
  if not tf_config_str.strip():
    return {}, 'master', 0

  tf_config = json.loads(tf_config_str)
  cluster = tf_config.get('cluster', {})
  task = tf_config.get('task', {})
  task_type = task.get('type', 'master')
  task_index = task.get('index', 0)
  return cluster, task_type, task_index
848
+
849
+
850
def get_task_index_and_num():
  """Get the index of the current task and the number of training tasks.

  Return:
    a tuple (task_index, task_num): horovod rank and size when horovod is
    active, otherwise values derived from TF_CONFIG, where a chief / master
    entry takes index 0 and shifts every other task by one.
  """
  if hvd is not None and 'HOROVOD_RANK' in os.environ:
    return hvd.rank(), hvd.size()

  cluster, task_type, task_index = parse_tf_config()
  if 'worker' not in cluster or task_type == 'evaluator':
    return 0, 1

  worker_num = len(cluster['worker'])
  has_leader = 'chief' in cluster or 'master' in cluster
  if not has_leader:
    return task_index, worker_num

  # the chief / master occupies index 0, the other tasks are shifted by one
  if task_type not in ['chief', 'master']:
    task_index += 1
  return task_index, worker_num + 1
865
+
866
+
867
def get_ckpt_version(ckpt_path):
  """Extract the step number from a checkpoint path.

  Args:
    ckpt_path: such as xx/model.ckpt-2000 or xx/model.ckpt-2000.meta

  Return:
    ckpt_version: such as 2000
  """
  file_name = os.path.split(ckpt_path)[1]
  stem, suffix = os.path.splitext(file_name)
  # for a bare prefix such as model.ckpt-2000 the version sits in the
  # "extension" part, otherwise it is at the end of the stem
  versioned = suffix if suffix.startswith('.ckpt-') else stem
  return int(versioned.rsplit('-', 1)[-1])
882
+
883
+
884
def get_latest_checkpoint_from_checkpoint_path(checkpoint_path,
                                               ignore_ckpt_error):
  """Resolve a checkpoint prefix from a directory or a checkpoint path.

  Args:
    checkpoint_path: a checkpoint directory, or a checkpoint prefix such as
      xx/model.ckpt-2000
    ignore_ckpt_error: when False, a missing checkpoint triggers an assert

  Return:
    the checkpoint prefix to use, or None if no checkpoint is found
  """
  not_exist_msg = 'fine_tune_checkpoint(%s) is not exists.' % checkpoint_path

  is_dir = checkpoint_path.endswith('/') or gfile.IsDirectory(
      checkpoint_path + '/')
  if not is_dir:
    # a checkpoint prefix is valid only if its .index file exists
    if gfile.Exists(checkpoint_path + '.index'):
      logging.info('update fine_tune_checkpoint to %s' % checkpoint_path)
      return checkpoint_path
    assert ignore_ckpt_error, not_exist_msg
    return None

  checkpoint_dir = checkpoint_path
  if not checkpoint_dir.endswith('/'):
    checkpoint_dir = checkpoint_dir + '/'
  if not gfile.Exists(checkpoint_dir):
    assert ignore_ckpt_error, not_exist_msg
    return None

  ckpt_path = latest_checkpoint(checkpoint_dir)
  if ckpt_path:
    logging.info(
        'fine_tune_checkpoint is directory, will use the latest checkpoint: %s'
        % ckpt_path)
  else:
    assert ignore_ckpt_error, not_exist_msg
  return ckpt_path
907
+
908
+
909
def latest_checkpoint(model_dir):
  """Find the latest checkpoint under a directory.

  Args:
    model_dir: model directory

  Return:
    model_path: xx/model.ckpt-2000, or None if there is no checkpoint
  """
  try:
    index_files = gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))
    if not index_files:
      return None

    # order by checkpoint version so that the newest one comes last
    if len(index_files) > 1:
      index_files.sort(key=lambda path: get_ckpt_version(path))

    # drop the .index extension to get the checkpoint prefix
    newest_prefix, _ = os.path.splitext(index_files[-1])
    return newest_prefix
  except errors_impl.NotFoundError:
    return None
930
+
931
+
932
def get_trained_steps(model_dir):
  """Return the step of the latest checkpoint in model_dir, 0 if none."""
  ckpt_path = latest_checkpoint(model_dir)
  if ckpt_path is None:
    return 0
  step_str = ckpt_path.split('-')[-1]
  return int(step_str)
938
+
939
+
940
def master_to_chief():
  """Rename the master role to chief in the TF_CONFIG environment variable.

  The cluster entry `master` is moved to `chief` and, if the current task is
  the master, its task type is changed to chief. The updated config is
  written back to os.environ['TF_CONFIG'].

  Return:
    the updated tf_config dict, or None if TF_CONFIG is not set
  """
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    # change master to chief
    if 'master' in tf_config['cluster']:
      tf_config['cluster']['chief'] = tf_config['cluster']['master']
      # remove the old master entry; deleting chief here would drop the
      # entry that was just added and leave the cluster without a chief
      del tf_config['cluster']['master']
      if tf_config['task']['type'] == 'master':
        tf_config['task']['type'] = 'chief'
    os.environ['TF_CONFIG'] = json.dumps(tf_config)
    return tf_config
  else:
    return None
953
+
954
+
955
def chief_to_master():
  """Rename the chief role to master in the TF_CONFIG environment variable.

  Return:
    the updated tf_config dict, or None if TF_CONFIG is not set
  """
  raw_config = os.environ.get('TF_CONFIG')
  if raw_config is None:
    return None

  tf_config = json.loads(raw_config)
  cluster = tf_config['cluster']
  if 'chief' in cluster:
    # move the chief entry to master
    cluster['master'] = cluster.pop('chief')
    task = tf_config['task']
    if task['type'] == 'chief':
      task['type'] = 'master'
  os.environ['TF_CONFIG'] = json.dumps(tf_config)
  return tf_config
968
+
969
+
970
def is_ps():
  """Return True if TF_CONFIG declares the current task as a ps."""
  raw_config = os.environ.get('TF_CONFIG')
  if raw_config is None:
    return False
  tf_config = json.loads(raw_config)
  if 'task' not in tf_config:
    return False
  return tf_config['task']['type'] == 'ps'
976
+
977
+
978
def is_chief():
  """Return True if the current task acts as the chief.

  With horovod active this is rank 0; otherwise the task type from TF_CONFIG
  must be chief or master. Without any task information the result is True.
  """
  if has_hvd():
    return hvd.rank() == 0

  raw_config = os.environ.get('TF_CONFIG')
  if raw_config is None:
    return True
  tf_config = json.loads(raw_config)
  if 'task' not in tf_config:
    return True
  return tf_config['task']['type'] in ['chief', 'master']
987
+
988
+
989
def is_master():
  """Return True if the task type is master, or if no task is configured."""
  raw_config = os.environ.get('TF_CONFIG')
  if raw_config is None:
    return True
  tf_config = json.loads(raw_config)
  if 'task' not in tf_config:
    return True
  return tf_config['task']['type'] == 'master'
995
+
996
+
997
def is_evaluator():
  """Return True if TF_CONFIG declares the current task as an evaluator."""
  raw_config = os.environ.get('TF_CONFIG')
  if raw_config is None:
    return False
  tf_config = json.loads(raw_config)
  if 'task' not in tf_config:
    return False
  return tf_config['task']['type'] == 'evaluator'
1003
+
1004
+
1005
def has_hvd():
  """Return True if horovod is importable and HOROVOD_RANK is set."""
  if hvd is None:
    return False
  return 'HOROVOD_RANK' in os.environ
1007
+
1008
+
1009
def has_sok():
  """Return True if sok is importable and ENABLE_SOK is set."""
  if sok is None:
    return False
  return 'ENABLE_SOK' in os.environ
1011
+
1012
+
1013
def init_hvd():
  """Initialize horovod and export its rank as HOROVOD_RANK.

  Exits the process with status 1 when horovod is not installed.
  """
  if hvd is None:
    install_hint = (
        'horovod is not installed: HOROVOD_WITH_TENSORFLOW=1 pip install horovod'
    )
    logging.error(install_hint)
    sys.exit(1)

  hvd.init()
  rank = hvd.rank()
  os.environ['HOROVOD_RANK'] = str(rank)
1022
+
1023
+
1024
def init_sok():
  """Initialize sok and mark it as enabled through ENABLE_SOK.

  Return:
    True if sok was initialized, False if sok is not installed or its
    initialization failed.
  """
  if sok is None:
    logging.warning('sok is not installed')
    return False

  try:
    sok.init()
  except Exception as ex:
    # best effort: training continues without sok, but the real cause is
    # logged instead of being reported as a missing installation
    logging.warning('sok init failed: %s' % str(ex))
    return False

  os.environ['ENABLE_SOK'] = '1'
  return True
1032
+
1033
+
1034
def get_available_gpus():
  """Return the names of the local devices whose device type is GPU."""
  gpu_names = []
  for device in device_lib.list_local_devices():
    if device.device_type == 'GPU':
      gpu_names.append(device.name)
  return gpu_names