easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,137 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import json
4
+ import logging
5
+ import os
6
+
7
+ import psutil
8
+ import tensorflow as tf
9
+ from tensorflow.python.summary import summary_iterator
10
+
11
# Select the TF1-style gfile module for the installed TensorFlow.
# NOTE: the version check is a lexicographic string comparison.
gfile = tf.compat.v1.gfile if tf.__version__ >= '2.0' else tf.gfile
15
+
16
+
17
def get_all_eval_result(event_file_pattern):
  """Collect all evaluation results recorded in TF event files.

  Args:
    event_file_pattern: glob pattern of the event files to read.

  Returns:
    A list of dicts, one per event that carries a summary with at least one
    scalar (`simple_value`) entry. Each dict maps 'global_step' to the event
    step and every scalar summary tag to its value. Events without any scalar
    value are skipped.
  """
  all_eval_result = []
  for event_file in gfile.Glob(event_file_pattern):
    for event in summary_iterator.summary_iterator(event_file):
      if not event.HasField('summary'):
        continue
      event_eval_result = {'global_step': event.step}
      for value in event.summary.value:
        if value.HasField('simple_value'):
          event_eval_result[value.tag] = value.simple_value
      # 'global_step' is always present; keep the event only if at least one
      # metric was recorded on top of it.
      if len(event_eval_result) >= 2:
        all_eval_result.append(event_eval_result)
  return all_eval_result
38
+
39
+
40
def save_eval_metrics(model_dir, metric_save_path, has_evaluator=True):
  """Save all evaluation metrics found under model_dir as JSON lines.

  In a distributed job (TF_CONFIG set) only one task collects the metrics:
  the evaluator when has_evaluator is True, otherwise the master/chief.
  Without TF_CONFIG (standalone mode) the current process collects them.

  Args:
    model_dir: train model directory; event files are read from its
      'eval_val/' sub-directory, falling back to 'eval/'.
    metric_save_path: metric saving path; one JSON object is written per
      line. Nothing is written when no evaluation result is found.
    has_evaluator: evaluation is done on a separate evaluator, not on master.

  Raises:
    AssertionError: if neither eval directory exists, or if has_evaluator is
      False and the cluster has neither a master nor a chief.
  """

  def _get_eval_event_file_pattern():
    eval_dir = os.path.join(model_dir, 'eval_val/')
    if not gfile.Exists(eval_dir):
      eval_dir = os.path.join(model_dir, 'eval/')
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not gfile.Exists(eval_dir):
      raise AssertionError('eval_val or eval does not exist')
    event_file_pattern = os.path.join(eval_dir, '*.tfevents.*')
    logging.info('event_file_pattern: %s', event_file_pattern)
    return event_file_pattern

  def _should_collect():
    """Return True if the current process is responsible for the metrics."""
    if 'TF_CONFIG' not in os.environ:
      # standalone mode
      return True
    tf_config = json.loads(os.environ['TF_CONFIG'])
    logging.info('tf_config = %s', json.dumps(tf_config))
    logging.info('model_dir = %s', model_dir)
    if has_evaluator:
      return tf_config['task']['type'] == 'evaluator'
    cluster = tf_config['cluster']
    if 'master' in cluster or 'chief' in cluster:
      return tf_config['task']['type'] in ('master', 'chief')
    raise AssertionError(
        'invalid cluster config, could not find master or chief or evaluator')

  all_eval_res = []
  if _should_collect():
    all_eval_res = get_all_eval_result(_get_eval_event_file_pattern())

  logging.info('all_eval_res num = %d', len(all_eval_res))
  if all_eval_res:
    with gfile.GFile(metric_save_path, 'w') as fout:
      for eval_res in all_eval_res:
        fout.write(json.dumps(eval_res) + '\n')
    logging.info('save all evaluation result to %s', metric_save_path)
85
+
86
+
87
def kill_old_proc(tmp_dir, platform='pai'):
  """Terminate leftover easy_rec HPO processes (and, on EMR, yarn jobs).

  Every process whose command line matches one of the platform patterns is
  terminated, except the current process. When platform == 'emr' the yarn
  applications named 'easy_rec_hpo' are killed as well.

  Args:
    tmp_dir: directory where the yarn job listing 'yarn_job.txt' is written;
      only used when platform == 'emr'.
    platform: 'pai' selects the PAI patterns; any other value selects the
      EMR patterns (the yarn cleanup runs only for exactly 'emr').
  """
  curr_pid = os.getpid()
  # A process matches a pattern when its command line contains both substrings.
  if platform == 'pai':
    patterns = [
        ('easy_rec.python.hpo.pai_hpo', 'python'),
        ('client/experiment_main.py', 'python'),
    ]
  else:
    patterns = [
        ('easy_rec.python.hpo.emr_hpo', 'python'),
        ('client/experiment_main.py', 'python'),
        ('el_submit', 'easy_rec_hpo'),
    ]

  for p in psutil.process_iter():
    try:
      cmd = ' '.join(p.cmdline())
      if p.pid == curr_pid:
        continue
      # Log and terminate at most once, even if several patterns match.
      if any(first in cmd and second in cmd for first, second in patterns):
        logging.info('will kill: [%d] %s', p.pid, cmd)
        p.terminate()
    except Exception:
      # Best effort: the process may have exited or be inaccessible by the
      # time we inspect it; skip it and keep scanning.
      pass

  if platform == 'emr':
    # clear easy_rec_hpo yarn jobs
    # NOTE(review): tmp_dir is interpolated into a shell command unquoted.
    yarn_job_file = os.path.join(tmp_dir, 'yarn_job.txt')
    os.system(
        "yarn application -list | awk '{ if ($2 == \"easy_rec_hpo\") print $1 }' > %s"
        % yarn_job_file)
    with open(yarn_job_file, 'r') as fin:
      # De-duplicate and drop blank lines so an empty listing never issues a
      # bare `yarn application -kill`; sort for a deterministic order.
      yarn_job_arr = sorted(
          set(line_str.strip() for line_str in fin if line_str.strip()))
    if yarn_job_arr:
      logging.info('will kill the easy_rec_hpo yarn jobs: %s',
                   ','.join(yarn_job_arr))
      os.system('yarn application -kill %s' % ' '.join(yarn_job_arr))
@@ -0,0 +1,56 @@
1
+ # -*- encoding: utf-8 -*-
2
+ import logging
3
+
4
+ import tensorflow as tf
5
+ from tensorflow.python.framework import ops
6
+ from tensorflow.python.training import session_run_hook
7
+
8
+ from easy_rec.python.utils import constant
9
+
10
+ # from horovod.tensorflow.compression import Compression
11
+ try:
12
+ from horovod.tensorflow.functions import broadcast_variables
13
+ except Exception:
14
+ pass
15
+
16
# On TF 2.x, expose the TF1-style API (tf.compat.v1) under the name `tf`.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
18
+
19
+
20
class BroadcastGlobalVariablesHook(session_run_hook.SessionRunHook):
  """SessionRunHook that will broadcast all global variables from root rank to all other processes during initialization.

  This is necessary to ensure consistent initialization of all workers when
  training is started with random weights or restored from a checkpoint.
  Variables whose name appears in the constant.EmbeddingParallel collection
  are excluded from the broadcast.
  """  # noqa: E501

  def __init__(self, root_rank, device=''):
    """Construct a new BroadcastGlobalVariablesHook that will broadcast all global variables from root rank to all other processes during initialization.

    Args:
      root_rank:
        Rank that will send data, other ranks will receive data.
      device:
        Device to be used for broadcasting. Uses GPU by default
        if Horovod was built with HOROVOD_GPU_OPERATIONS.
    """  # noqa: E501
    super(BroadcastGlobalVariablesHook, self).__init__()
    self.root_rank = root_rank
    self.bcast_op = None
    self.device = device

  def begin(self):
    """Build the broadcast op for the current default graph."""
    # NOTE(review): assumes the EmbeddingParallel collection holds variable
    # names (strings), since they are compared against x.name below.
    embed_para_vars = ops.get_collection(constant.EmbeddingParallel)
    bcast_vars = []
    for x in tf.global_variables():
      if x.name in embed_para_vars:
        continue
      bcast_vars.append(x)
      logging.info('will broadcast variable: name=%s shape=%s', x.name,
                   x.get_shape())
    # Rebuild the op when none exists yet or the hook is reused with another
    # graph. Compare with None explicitly rather than relying on the truthiness
    # of a TF op/tensor.
    if self.bcast_op is None or self.bcast_op.graph != tf.get_default_graph():
      with tf.device(self.device):
        # broadcast_variables is only bound if the horovod import succeeded.
        self.bcast_op = broadcast_variables(bcast_vars, self.root_rank)

  def after_create_session(self, session, coord):
    """Run the broadcast op once the session has been created."""
    session.run(self.bcast_op)
@@ -0,0 +1,108 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import numpy as np
4
+ import pandas as pd
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
8
+
9
# On TF 2.x, expose the TF1-style API (tf.compat.v1) under the name `tf`.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
11
+
12
+
13
def get_type_defaults(field_type, default_val=''):
  """Return the default value of a field converted to its DatasetConfig type.

  Args:
    field_type: a DatasetConfig field type (INT32, INT64, STRING, BOOL, FLOAT
      or DOUBLE).
    default_val: default value for the field; '' selects the built-in default
      of field_type. For BOOL, a string ('true' in any case means True) or a
      bool is accepted.

  Returns:
    default_val converted to int for INT32, np.int64 for INT64, unchanged for
    STRING, bool for BOOL, float for FLOAT and np.float64 for DOUBLE.

  Raises:
    AssertionError: if field_type is not one of the supported types.
  """
  type_defaults = {
      DatasetConfig.INT32: 0,
      DatasetConfig.INT64: 0,
      DatasetConfig.STRING: '',
      DatasetConfig.BOOL: False,
      DatasetConfig.FLOAT: 0.0,
      DatasetConfig.DOUBLE: 0.0
  }
  # Explicit raise instead of `assert` so the check survives `python -O`.
  if field_type not in type_defaults:
    raise AssertionError('invalid type: %s' % field_type)
  if default_val == '':
    default_val = type_defaults[field_type]

  if field_type == DatasetConfig.INT32:
    return int(default_val)
  if field_type == DatasetConfig.INT64:
    return np.int64(default_val)
  if field_type == DatasetConfig.STRING:
    return default_val
  if field_type == DatasetConfig.BOOL:
    # The built-in BOOL default is the bool False, which has no .lower();
    # calling .lower() on it used to raise AttributeError.
    if isinstance(default_val, bool):
      return default_val
    return default_val.lower() == 'true'
  if field_type == DatasetConfig.FLOAT:
    return float(default_val)
  # Only DOUBLE remains after the membership check above.
  return np.float64(default_val)
39
+
40
+
41
def string_to_number(field, ftype, default_value, name=''):
  """Type conversion for parsing rtp fg input format.

  Empty strings in `field` are first replaced by str(default_value).

  Args:
    field: string tensor to be converted.
    ftype: field dtype set in DatasetConfig.
    default_value: default value for this field, used for empty entries.
    name: field name, used to name the conversion ops
      ('field_as_flt_<name>').

  Returns:
    A tensor converted according to ftype: int32 / int64 for INT32 / INT64,
    float32 for FLOAT, float64 for DOUBLE, bool for BOOL ('True' or 'true'
    map to True, anything else to False), and the string tensor itself for
    STRING.

  Raises:
    AssertionError: if ftype is not a supported DatasetConfig type.
  """
  # NOTE(review): tiling a 1-element vector by tf.shape(field) assumes
  # `field` is rank-1 -- confirm with callers.
  default_vals = tf.tile(tf.constant([str(default_value)]), tf.shape(field))
  field = tf.where(tf.greater(tf.strings.length(field), 0), field, default_vals)

  if ftype in (DatasetConfig.INT32, DatasetConfig.INT64):
    # Int type is not supported in fg, so the value is parsed as double and
    # then cast. NOTE(review): going through float64 loses precision for
    # integers whose magnitude exceeds 2**53.
    tmp_field = tf.string_to_number(
        field, tf.double, name='field_as_flt_%s' % name)
    int_dtype = tf.int64 if ftype == DatasetConfig.INT64 else tf.int32
    return tf.cast(tmp_field, int_dtype)
  if ftype == DatasetConfig.FLOAT:
    return tf.string_to_number(
        field, tf.float32, name='field_as_flt_%s' % name)
  if ftype == DatasetConfig.DOUBLE:
    return tf.string_to_number(
        field, tf.float64, name='field_as_flt_%s' % name)
  if ftype == DatasetConfig.BOOL:
    return tf.logical_or(tf.equal(field, 'True'), tf.equal(field, 'true'))
  if ftype == DatasetConfig.STRING:
    return field
  # Explicit raise instead of `assert False` so it survives `python -O`.
  raise AssertionError('invalid types: %s' % str(ftype))
76
+
77
+
78
def np_to_tf_type(np_type):
  """Map a Python / numpy scalar type object to a TensorFlow dtype.

  Args:
    np_type: a type object such as int, float, str, np.int64 or np.float32.

  Returns:
    The matching tf dtype. Any type not listed in the table (e.g. bool)
    maps to tf.string.
  """
  # `np.float` was only an alias of the builtin float and was removed in
  # NumPy 1.24 (referencing it raises AttributeError there); the `float` key
  # below already covers it with the same mapping.
  types_map = {
      int: tf.int32,
      np.int32: tf.int32,
      np.int64: tf.int64,
      str: tf.string,
      np.float32: tf.float32,
      float: tf.float32,
      np.double: tf.float64,
  }
  return types_map.get(np_type, tf.string)
93
+
94
+
95
def get_tf_type_from_parquet_file(cols, parquet_file):
  """Infer the tf dtype of each column from the first row of a parquet file.

  For list / ndarray cells, the type of the cell's first element is used.

  Args:
    cols: names of the columns to inspect.
    parquet_file: path readable by pandas.read_parquet (not a gfile path).

  Returns:
    A list of tf dtypes (as given by np_to_tf_type), one per entry of cols.
  """
  # gfile not supported, read_parquet requires random access
  input_data = pd.read_parquet(parquet_file, columns=cols)
  tf_types = []
  for col in cols:
    # Positional access: label 0 need not exist when the file stores its own
    # index, so `input_data[col][0]` could raise KeyError.
    obj = input_data[col].iloc[0]
    if isinstance(obj, (list, np.ndarray)):
      data_type = type(obj[0])
    else:
      data_type = type(obj)
    tf_types.append(np_to_tf_type(data_type))
  return tf_types
@@ -0,0 +1,282 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """IO utils.
4
+
5
+ isort:skip_file
6
+ """
7
+ import logging
8
+ from future import standard_library
9
+ standard_library.install_aliases()
10
+
11
+ import os
12
+ import traceback
13
+ from subprocess import getstatusoutput
14
+
15
+ import six
16
+ import tensorflow as tf
17
+ from six.moves import http_client
18
+ from six.moves import urllib
19
+ import json
20
+ if six.PY2:
21
+ from urllib import quote
22
+ else:
23
+ from urllib.parse import quote
24
+
25
+ if tf.__version__ >= '2.0':
26
+ tf = tf.compat.v1
27
+
28
# Timeout in seconds: default for http_read and the urlopen timeout in download.
HTTP_MAX_TIMEOUT = 600
# Default number of attempts http_read makes when a response is truncated.
HTTP_MAX_NUM_RETRY = 5
# Default local directory that user resources are downloaded into.
EASY_REC_RES_DIR = 'easy_rec_user_resources'
31
+
32
+
33
def http_read(url, timeout=HTTP_MAX_TIMEOUT, max_retry=HTTP_MAX_NUM_RETRY):
  """Read data from url with maximum retry.

  Only truncated responses (``http_client.IncompleteRead``) are retried;
  any other exception is logged and stops further attempts.

  Args:
    url: http url to be read; it is percent-encoded (keeping ``%/:?=&``)
      before the request is sent.
    timeout: specifies a timeout in seconds for blocking operations.
    max_retry: maximum number of attempts made for truncated responses.

  Returns:
    The response body, or None if reading failed (an error is logged).
  """
  from contextlib import closing

  num_read_try = 0
  data = None
  while num_read_try < max_retry:
    try:
      if six.PY2:
        url = url.encode('utf-8')
      url = quote(url, safe='%/:?=&')
      # Close the response explicitly (also when read() raises) instead of
      # relying on garbage collection to release the connection.
      with closing(urllib.request.urlopen(url, timeout=timeout)) as response:
        data = response.read()
      break
    except http_client.IncompleteRead:
      tf.logging.warning('incomplete read exception, will retry: %s' % url)
      num_read_try += 1
    except Exception:
      tf.logging.error(traceback.format_exc())
      break

  if data is None:
    tf.logging.error('http read %s failed' % url)

  return data
61
+
62
+
63
def download(oss_or_url, dst_dir=''):
  """Download file.

  Args:
    oss_or_url: http or oss path
    dst_dir: destination directory, created if missing; '' writes into the
      current working directory.
  Return:
    dst_file: local path for the downloaded file. Paths starting with neither
      'oss' nor 'http' are treated as local files and returned unchanged
      without being copied.
  Raises:
    RuntimeError: if the http download fails.
  """
  from contextlib import closing

  _, basename = os.path.split(oss_or_url)
  if oss_or_url.startswith('oss'):
    with tf.gfile.GFile(oss_or_url, 'rb') as infile:
      file_content = infile.read()
  elif oss_or_url.startswith('http'):
    try:
      # Release the connection deterministically, even if read() fails.
      with closing(
          urllib.request.urlopen(oss_or_url,
                                 timeout=HTTP_MAX_TIMEOUT)) as response:
        file_content = response.read()
    except Exception as e:
      raise RuntimeError('Download %s failed: %s\n %s' %
                         (oss_or_url, str(e), traceback.format_exc()))
  else:
    tf.logging.warning('skip downloading %s, seems to be a local file' %
                       oss_or_url)
    return oss_or_url

  if dst_dir != '' and not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
  dst_file = os.path.join(dst_dir, basename)
  with tf.gfile.GFile(dst_file, 'wb') as ofile:
    ofile.write(file_content)

  return dst_file
95
+
96
+
97
def create_module_dir(dst_dir):
  """Make sure dst_dir exists and is an importable python package.

  Missing directories (including parents) are created, and an
  ``__init__.py`` holding a single newline is written, overwriting any
  existing one.
  """
  already_there = os.path.exists(dst_dir)
  if not already_there:
    os.makedirs(dst_dir)
  init_path = os.path.join(dst_dir, '__init__.py')
  with open(init_path, 'w') as init_file:
    init_file.write('\n')
+ ofile.write('\n')
102
+
103
+
104
def download_resource(resource_path, dst_dir=EASY_REC_RES_DIR):
  """Fetch a single user-supplied python file into a package directory.

  Args:
    resource_path: http or oss path of a ``.py`` file
    dst_dir: destination directory; it is turned into a python package
      (via create_module_dir) before anything else happens

  Returns:
    the path returned by ``download``.

  Raises:
    ValueError: if resource_path does not name a ``.py`` file.
  """
  create_module_dir(dst_dir)
  file_name = os.path.split(resource_path)[1]
  if file_name.endswith('.py'):
    return download(resource_path, dst_dir)
  raise ValueError('resource %s should be python file' % resource_path)
119
+
120
+
121
def download_and_uncompress_resource(resource_path, dst_dir=EASY_REC_RES_DIR):
  """Download user resource and uncompress it if necessary.

  ``.tar.gz`` archives are extracted with ``tar -zxf`` and ``.zip`` archives
  with ``unzip``, both run inside dst_dir; ``.py`` files are only downloaded.

  Args:
    resource_path: http or oss path
    dst_dir: download destination directory

  Returns:
    dst_dir

  Raises:
    ValueError: if the resource extension is not tar.gz / zip / py, or the
      extraction command exits with a non-zero status.
  """
  try:
    from shlex import quote as shell_quote
  except ImportError:  # python 2
    from pipes import quote as shell_quote

  create_module_dir(dst_dir)

  _, basename = os.path.split(resource_path)
  if not basename.endswith(('.tar.gz', '.zip', '.py')):
    raise ValueError('resource %s should be tar.gz or zip or py' %
                     resource_path)

  download(resource_path, dst_dir)

  if basename.endswith('tar.gz'):
    uncompress_cmd = 'tar -zxf %s'
  elif basename.endswith('zip'):
    uncompress_cmd = 'unzip %s'
  else:
    return dst_dir

  # Quote both paths so names with spaces or shell metacharacters can
  # neither break nor extend the command line.
  stat, output = getstatusoutput(
      'cd %s && %s' %
      (shell_quote(dst_dir), uncompress_cmd % shell_quote(basename)))
  if stat != 0:
    # Both values must be inside the %-tuple; otherwise formatting itself
    # raises TypeError and masks this error.
    raise ValueError('uncompress resource %s failed: %s' %
                     (resource_path, output))

  return dst_dir
149
+
150
+
151
def oss_has_t_mode(target_file):
  """Test if the current environment supports t-mode writes to oss.

  Always False on TensorFlow builds whose version string lacks 'PAI'.
  Otherwise a probe file ``target_file + '.tmp'`` is written in 't' mode
  and then removed.

  Args:
    target_file: path next to which the probe file is created.

  Returns:
    True if writing and removing the probe file succeeded, else False.
  """
  if 'PAI' not in tf.__version__:
    return False
  # test if running on cluster
  test_file = target_file + '.tmp'
  try:
    with tf.gfile.GFile(test_file, 't') as ofile:
      ofile.write('a')
    tf.gfile.Remove(test_file)
    return True
  except Exception:
    # Any failure just means t-mode is unavailable; unlike a bare except,
    # this does not swallow KeyboardInterrupt or SystemExit.
    return False
165
+
166
+
167
def fix_oss_dir(path):
  """Append a trailing '/' to an ``oss://`` path that lacks one.

  Non-oss paths, and oss paths already ending in '/', come back unchanged.
  """
  needs_slash = path.startswith('oss://') and not path.endswith('/')
  return path + '/' if needs_slash else path
172
+
173
+
174
def save_data_to_json_path(json_path, data):
  """Serialize ``data`` as JSON text and store it at ``json_path``.

  The text is written through tf.gfile.GFile; afterwards the file's
  existence is asserted.
  """
  with tf.gfile.GFile(json_path, 'w') as writer:
    serialized = json.dumps(data)
    writer.write(serialized)
  assert tf.gfile.Exists(json_path), 'in_save_data_to_json_path, save_failed'
178
+
179
+
180
def read_data_from_json_path(json_path):
  """Load and return the JSON content stored at ``json_path``.

  Returns None (after an info log) when json_path is empty/None or does not
  exist according to tf.gfile.
  """
  if not json_path or not tf.gfile.Exists(json_path):
    logging.info('json_path not exists, return None')
    return None
  with tf.gfile.GFile(json_path, 'r') as reader:
    return json.load(reader)
188
+
189
+
190
def convert_tf_flags_to_argparse(flags):
  """Convert tf.app.flags.FLAGS to argparse.ArgumentParser.

  Every flag becomes an ``--<name>`` option whose default is the flag's
  current value. The option type follows the type of that value:
    * bool: optional yes/no style value (bare ``--name`` means True);
    * str: kept as str, with the flag's choices if it has any;
    * list / dict: parsed with ``ast.literal_eval``;
    * int / float: converted with that type;
    * anything else: read as str.

  Args:
    flags: tf.app.flags.FLAGS
  Returns:
    argparse.ArgumentParser: configurate ArgumentParser object
  """
  import argparse
  import ast
  parser = argparse.ArgumentParser()

  # flag name -> [seen more than once, value type, default, help, choices]
  args = {}
  for flag in flags._flags().values():
    flag_name = flag.name
    if flag_name in args:
      # NOTE(review): a name repeats when _flags() lists the same flag more
      # than once (presumably under an alias); such str flags are registered
      # with action='append' below -- confirm this is the intended signal.
      args[flag_name][0] = True
      continue
    default = flag.value
    flag_type = type(default)
    help_str = flag.help or ''
    args[flag_name] = [
        False, flag_type, default, help_str,
        flag.choices if hasattr(flag, 'choices') else None
    ]

  def str2bool(v):
    if isinstance(v, bool):
      return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
      return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
      return False
    else:
      raise argparse.ArgumentTypeError('Boolean value expected.')

  for flag_name, (multi, flag_type, default, help_str, choices) in args.items():
    if flag_type == bool:
      # Use the flag's own default so a flag defaulting to True is not
      # silently turned off when it is absent from the command line.
      parser.add_argument(
          '--' + flag_name,
          type=str2bool,
          nargs='?',
          const=True,
          default=default,
          help=help_str)
    elif flag_type == str:
      if choices:
        parser.add_argument(
            '--' + flag_name,
            type=str,
            choices=choices,
            default=default,
            help=help_str)
      elif multi:
        parser.add_argument(
            '--' + flag_name,
            type=str,
            action='append',
            default=default,
            help=help_str)
      else:
        parser.add_argument(
            '--' + flag_name, type=str, default=default, help=help_str)
    elif flag_type in (list, dict):
      parser.add_argument(
          '--' + flag_name,
          type=ast.literal_eval,
          default=default,
          help=help_str)
    elif flag_type in (int, float):
      parser.add_argument(
          '--' + flag_name, type=flag_type, default=default, help=help_str)
    else:
      parser.add_argument(
          '--' + flag_name, type=str, default=default, help=help_str)
  return parser
266
+
267
+
268
def filter_unknown_args(flags, args):
  """Drop command-line arguments that are not defined in ``flags``.

  The arguments are parsed with the parser built by
  convert_tf_flags_to_argparse. Unrecognised entries are logged and
  discarded. Every defined option whose parsed value is neither None nor an
  empty list/dict is re-emitted as ``--name=value``.

  Args:
    flags: tf.app.flags.FLAGS
    args: argument list; its first entry (the program name) is kept as is.

  Returns:
    list of str: args[0] followed by the re-serialized known options.
  """
  program = args[0]
  parser = convert_tf_flags_to_argparse(flags)
  parsed, unknown = parser.parse_known_args(args)
  # args[0] is presumably reported by argparse as the first unknown entry,
  # hence only the entries after it are logged as undefined.
  if len(unknown) > 1:
    logging.info('undefined arguments: %s', ', '.join(unknown[1:]))
  options = []
  for key, value in vars(parsed).items():
    if value is None:
      continue
    if type(value) in (list, dict) and not value:
      continue
    options.append('--%s=%s' % (key, value))
  logging.info('defined arguments: %s', ', '.join(options))
  return [program] + options