easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,192 @@
1
+ import logging
2
+ from collections import OrderedDict
3
+ from collections import defaultdict
4
+ from copy import copy
5
+ from copy import deepcopy
6
+
7
+
8
class DAG(object):
  """Directed acyclic graph implementation.

  The graph lives in ``self.graph``: an ``OrderedDict`` mapping every node
  name to the ``set`` of node names it has an edge towards.  Most methods also
  accept an optional ``graph`` argument with the same layout; when it is
  ``None`` the instance's own graph is used.
  """

  def __init__(self):
    """Construct a new DAG with no nodes or edges."""
    self.reset_graph()

  def add_node(self, node_name, graph=None):
    """Add a node if it does not exist yet, or error out.

    Raises:
      KeyError: if ``node_name`` is already present.
    """
    # Compare with None (not truthiness) so that an explicitly passed, still
    # empty graph is modified instead of silently falling back to self.graph.
    if graph is None:
      graph = self.graph
    if node_name in graph:
      raise KeyError('node %s already exists' % node_name)
    graph[node_name] = set()

  def add_node_if_not_exists(self, node_name, graph=None):
    """Add a node, logging (instead of raising) when it already exists."""
    try:
      self.add_node(node_name, graph=graph)
    except KeyError:
      logging.info('node %s already exist', node_name)

  def delete_node(self, node_name, graph=None):
    """Delete this node and all edges referencing it.

    Raises:
      KeyError: if ``node_name`` is not present.
    """
    if graph is None:
      graph = self.graph
    if node_name not in graph:
      raise KeyError('node %s does not exist' % node_name)
    graph.pop(node_name)
    # Only the edge sets are mutated here, never the dict's keys, so iterating
    # the values view is safe.
    for edges in graph.values():
      edges.discard(node_name)

  def delete_node_if_exists(self, node_name, graph=None):
    """Delete a node, logging (instead of raising) when it does not exist."""
    try:
      self.delete_node(node_name, graph=graph)
    except KeyError:
      logging.info('node %s does not exist', node_name)

  def add_edge(self, ind_node, dep_node, graph=None):
    """Add an edge (dependency) ``ind_node -> dep_node``.

    The edge is first applied to a deep copy of the graph; it is only added to
    ``graph`` itself when the copy still validates.

    Raises:
      KeyError: if either node is missing from the graph.
      Exception: ('invalid DAG') if the edge would make the graph invalid.
    """
    if graph is None:
      graph = self.graph
    if ind_node not in graph or dep_node not in graph:
      raise KeyError('one or more nodes do not exist in graph')
    test_graph = deepcopy(graph)
    test_graph[ind_node].add(dep_node)
    is_valid, _ = self.validate(test_graph)
    if not is_valid:
      raise Exception('invalid DAG')
    graph[ind_node].add(dep_node)

  def delete_edge(self, ind_node, dep_node, graph=None):
    """Delete the edge ``ind_node -> dep_node`` from the graph.

    Raises:
      KeyError: if the edge does not exist.
    """
    if graph is None:
      graph = self.graph
    if dep_node not in graph.get(ind_node, []):
      raise KeyError('this edge does not exist in graph')
    graph[ind_node].remove(dep_node)

  def rename_edges(self, old_task_name, new_task_name, graph=None):
    """Rename node ``old_task_name`` to ``new_task_name``.

    The renamed node keeps a copy of its outgoing edges and is moved to the end
    of the graph's ordering; every edge pointing at the old name is redirected
    to the new name.
    """
    if graph is None:
      graph = self.graph
    # Inserting/deleting keys while iterating graph.items() raises
    # RuntimeError on an OrderedDict, so re-key first and only then rewrite
    # the edge sets (which does not change the dict's keys).
    if old_task_name in graph:
      graph[new_task_name] = copy(graph.pop(old_task_name))
    for edges in graph.values():
      if old_task_name in edges:
        edges.remove(old_task_name)
        edges.add(new_task_name)

  def predecessors(self, node, graph=None):
    """Return a list of all nodes that have an edge towards ``node``."""
    if graph is None:
      graph = self.graph
    return [key for key in graph if node in graph[key]]

  def downstream(self, node, graph=None):
    """Return a list of all nodes this node has edges towards.

    Raises:
      KeyError: if ``node`` is not in the graph.
    """
    if graph is None:
      graph = self.graph
    if node not in graph:
      raise KeyError('node %s is not in graph' % node)
    return list(graph[node])

  def all_downstreams(self, node, graph=None):
    """Return every node reachable from ``node``, in topological order.

    ``node`` itself is only included if it is reachable from itself, which a
    valid DAG rules out.

    Raises:
      KeyError: if ``node`` (or a reached node) is not a key of the graph.
      ValueError: if the graph cannot be topologically sorted.
    """
    if graph is None:
      graph = self.graph
    # Breadth-first walk; ``nodes`` doubles as the work queue.
    nodes = [node]
    nodes_seen = set()
    i = 0
    while i < len(nodes):
      for downstream_node in self.downstream(nodes[i], graph):
        if downstream_node not in nodes_seen:
          nodes_seen.add(downstream_node)
          nodes.append(downstream_node)
      i += 1
    return [n for n in self.topological_sort(graph=graph) if n in nodes_seen]

  def all_leaves(self, graph=None):
    """Return a list of all leaves (nodes with no downstreams)."""
    if graph is None:
      graph = self.graph
    return [key for key in graph if not graph[key]]

  def from_dict(self, graph_dict):
    """Reset the graph and build it from the passed dictionary.

    The dictionary takes the form ``{node_name: [directed edges]}``.  All nodes
    are added first, then every edge is added through ``add_edge``.

    Raises:
      TypeError: if a value of ``graph_dict`` is not a list.
    """
    self.reset_graph()
    for new_node in graph_dict.keys():
      self.add_node(new_node)
    for ind_node, dep_nodes in graph_dict.items():
      if not isinstance(dep_nodes, list):
        raise TypeError('dict values must be lists')
      for dep_node in dep_nodes:
        self.add_edge(ind_node, dep_node)

  def reset_graph(self):
    """Restore the graph to an empty state."""
    self.graph = OrderedDict()

  def independent_nodes(self, graph=None):
    """Return a list of all nodes in the graph that no edge points to."""
    if graph is None:
      graph = self.graph

    dependent_nodes = set(
        node for dependents in graph.values() for node in dependents)
    return [node for node in graph.keys() if node not in dependent_nodes]

  def validate(self, graph=None):
    """Return ``(is_valid, message)`` describing whether the graph is a DAG.

    An empty graph is reported invalid, because it has no independent nodes.
    """
    graph = graph if graph is not None else self.graph
    if len(self.independent_nodes(graph)) == 0:
      return False, 'no independent nodes detected'
    try:
      self.topological_sort(graph)
    except ValueError:
      return False, 'failed topological sort'
    return True, 'valid'

  def topological_sort(self, graph=None):
    """Return a topological ordering of the DAG (Kahn's algorithm).

    Raises:
      ValueError: if the graph contains a cycle.
    """
    if graph is None:
      graph = self.graph
    result = []
    in_degree = defaultdict(int)

    for u in graph:
      for v in graph[u]:
        in_degree[v] += 1
    ready = [node for node in graph if not in_degree[node]]

    while ready:
      u = ready.pop()
      result.append(u)
      for v in graph[u]:
        in_degree[v] -= 1
        if in_degree[v] == 0:
          ready.append(v)

    if len(result) == len(graph):
      return result
    else:
      raise ValueError('graph is not acyclic')

  def size(self):
    """Return the number of nodes in ``self.graph``."""
    return len(self.graph)
@@ -0,0 +1,268 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+
9
+ import tensorflow as tf
10
+
11
+ from easy_rec.python.protos.train_pb2 import DistributionStrategy
12
+ from easy_rec.python.utils import estimator_utils
13
+ from easy_rec.python.utils.estimator_utils import chief_to_master
14
+ from easy_rec.python.utils.estimator_utils import master_to_chief
15
+
16
# Lookup table from a short strategy name to its DistributionStrategy enum
# value; the empty string selects NoStrategy.
DistributionStrategyMap = dict([
    ('', DistributionStrategy.NoStrategy),
    ('ps', DistributionStrategy.PSStrategy),
    ('ess', DistributionStrategy.ExascaleStrategy),
    ('mirrored', DistributionStrategy.MirroredStrategy),
    ('collective', DistributionStrategy.CollectiveAllReduceStrategy),
])
23
+
24
+
25
def set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
                            distribute_strategy):
  """Write distribution settings into ``pipeline_config.train_config``.

  For the PS, Mirrored, CollectiveAllReduce and Exascale strategies this turns
  off ``sync_replicas`` and records the strategy and ``num_gpus_per_worker``;
  any other strategy leaves the config untouched.  The resulting train config
  is printed in every case.  ``num_worker`` is accepted but not used here.
  """
  configurable_strategies = (
      DistributionStrategy.PSStrategy,
      DistributionStrategy.MirroredStrategy,
      DistributionStrategy.CollectiveAllReduceStrategy,
      DistributionStrategy.ExascaleStrategy,
  )
  train_config = pipeline_config.train_config
  if distribute_strategy in configurable_strategies:
    train_config.sync_replicas = False
    train_config.train_distribute = distribute_strategy
    train_config.num_gpus_per_worker = num_gpus_per_worker
  print('Dump pipeline_config.train_config:')
  print(train_config)
37
+
38
+
39
def set_tf_config_and_get_train_worker_num(
    ps_hosts,
    worker_hosts,
    task_index,
    job_name,
    distribute_strategy=DistributionStrategy.NoStrategy,
    eval_method='none'):
  """Rewrite ``TF_CONFIG`` for this task and return the train worker count.

  Args:
    ps_hosts: comma-separated ps host list; a falsy value means no ps hosts.
    worker_hosts: comma-separated worker host list.
    task_index: index of this task within ``job_name``.
    job_name: 'ps' or 'worker'; only consulted on the code paths that build
      the cluster from the host lists.
    distribute_strategy: a ``DistributionStrategy`` enum value.
    eval_method: values checked here are 'separate', 'none' and 'master'.

  Returns:
    The number of workers that take part in training.

  Side effects:
    May overwrite ``os.environ['TF_CONFIG']``, may call ``master_to_chief`` or
    ``chief_to_master``, and removes ``TF_WRITE_WORKER_STATUS_FILE`` from the
    environment.
  """
  logging.info(
      'set_tf_config_and_get_train_worker_num: distribute_strategy = %d' %
      distribute_strategy)
  worker_hosts = worker_hosts.split(',')
  ps_hosts = ps_hosts.split(',') if ps_hosts else []

  total_worker_num = len(worker_hosts)
  train_worker_num = total_worker_num

  print('Original TF_CONFIG=%s' % os.environ.get('TF_CONFIG', ''))
  print('worker_hosts=%s ps_hosts=%s task_index=%d job_name=%s' %
        (','.join(worker_hosts), ','.join(ps_hosts), task_index, job_name))
  print('eval_method=%s' % eval_method)

  if distribute_strategy == DistributionStrategy.MirroredStrategy:
    # Mirrored strategy is only accepted with a single worker host; TF_CONFIG
    # is left unchanged and train_worker_num stays total_worker_num.
    assert total_worker_num == 1, 'mirrored distribute strategy only need 1 worker'
  elif distribute_strategy in [
      DistributionStrategy.NoStrategy, DistributionStrategy.PSStrategy,
      DistributionStrategy.CollectiveAllReduceStrategy,
      DistributionStrategy.ExascaleStrategy
  ]:
    # Start from the cluster already present in TF_CONFIG; the worker count is
    # recomputed from scratch on every branch below.
    cluster, task_type, task_index_ = estimator_utils.parse_tf_config()
    train_worker_num = 0
    if eval_method == 'separate':
      if 'evaluator' in cluster:
        # 'evaluator' in cluster indicates user use new-style cluster content
        if 'chief' in cluster:
          train_worker_num += len(cluster['chief'])
        elif 'master' in cluster:
          train_worker_num += len(cluster['master'])
        if 'worker' in cluster:
          train_worker_num += len(cluster['worker'])
        # drop evaluator to avoid hang
        if distribute_strategy == DistributionStrategy.NoStrategy:
          del cluster['evaluator']
        tf_config = {
            'cluster': cluster,
            'task': {
                'type': task_type,
                'index': task_index_
            }
        }
        os.environ['TF_CONFIG'] = json.dumps(tf_config)
      else:
        # backward compatibility, if user does not assign one evaluator in
        # -Dcluster, we use first worker for chief, second for evaluation
        train_worker_num = total_worker_num - 1
        assert train_worker_num > 0, 'in distribution mode worker num must be greater than 1, ' \
            'the second worker will be used as evaluator'
        # Always true once the assert above has passed (i.e. unless asserts
        # are disabled with -O).
        if len(worker_hosts) > 1:
          # worker_hosts[0] -> chief, worker_hosts[1] -> evaluator,
          # worker_hosts[2:] -> workers re-indexed from 0.
          cluster = {'chief': [worker_hosts[0]], 'worker': worker_hosts[2:]}
          # The evaluator host is only listed in the cluster spec for
          # strategies other than NoStrategy.
          if distribute_strategy != DistributionStrategy.NoStrategy:
            cluster['evaluator'] = [worker_hosts[1]]
          if len(ps_hosts) > 0:
            cluster['ps'] = ps_hosts
          if job_name == 'ps':
            os.environ['TF_CONFIG'] = json.dumps({
                'cluster': cluster,
                'task': {
                    'type': job_name,
                    'index': task_index
                }
            })
          elif job_name == 'worker':
            if task_index == 0:
              os.environ['TF_CONFIG'] = json.dumps({
                  'cluster': cluster,
                  'task': {
                      'type': 'chief',
                      'index': 0
                  }
              })
            elif task_index == 1:
              os.environ['TF_CONFIG'] = json.dumps({
                  'cluster': cluster,
                  'task': {
                      'type': 'evaluator',
                      'index': 0
                  }
              })
            else:
              # Shift past the chief and evaluator hosts.
              os.environ['TF_CONFIG'] = json.dumps({
                  'cluster': cluster,
                  'task': {
                      'type': job_name,
                      'index': task_index - 2
                  }
              })
    else:
      if 'evaluator' in cluster:
        evaluator = cluster['evaluator']
        del cluster['evaluator']
        # 'evaluator' in cluster indicates user use new-style cluster content
        # Without a separate evaluator the evaluator host is appended to the
        # worker list and trains too, hence the +1.
        train_worker_num += 1
        if 'chief' in cluster:
          train_worker_num += len(cluster['chief'])
        elif 'master' in cluster:
          train_worker_num += len(cluster['master'])
        if 'worker' in cluster:
          train_worker_num += len(cluster['worker'])
          cluster['worker'].append(evaluator[0])
        else:
          cluster['worker'] = [evaluator[0]]
        if task_type == 'evaluator':
          # The former evaluator is now the last entry of cluster['worker'].
          # NOTE(review): train_worker_num - 2 equals that index only when the
          # cluster has exactly one chief/master host -- confirm.
          tf_config = {
              'cluster': cluster,
              'task': {
                  'type': 'worker',
                  'index': train_worker_num - 2
              }
          }
        else:
          tf_config = {
              'cluster': cluster,
              'task': {
                  'type': task_type,
                  'index': task_index_
              }
          }
        os.environ['TF_CONFIG'] = json.dumps(tf_config)
      else:
        # Build the cluster from the host lists: first worker is the chief,
        # the rest are workers re-indexed from 0.
        cluster = {'chief': [worker_hosts[0]], 'worker': worker_hosts[1:]}
        train_worker_num = len(worker_hosts)
        if len(ps_hosts) > 0:
          cluster['ps'] = ps_hosts
        if job_name == 'ps':
          os.environ['TF_CONFIG'] = json.dumps({
              'cluster': cluster,
              'task': {
                  'type': job_name,
                  'index': task_index
              }
          })
        else:
          if task_index == 0:
            os.environ['TF_CONFIG'] = json.dumps({
                'cluster': cluster,
                'task': {
                    'type': 'chief',
                    'index': 0
                }
            })
          else:
            os.environ['TF_CONFIG'] = json.dumps({
                'cluster': cluster,
                'task': {
                    'type': 'worker',
                    'index': task_index - 1
                }
            })
      if eval_method == 'none':
        # change master to chief, will not evaluate
        master_to_chief()
      elif eval_method == 'master':
        # change chief to master, will evaluate on master
        chief_to_master()
  else:
    # NOTE(review): distribute_strategy is compared with '' here, while every
    # other branch compares it with DistributionStrategy enum members -- confirm
    # this is meant to reject all remaining strategies.
    assert distribute_strategy == '', 'invalid distribute_strategy %s'\
        % distribute_strategy
    cluster, task_type, task_index = estimator_utils.parse_tf_config()
  print('Final TF_CONFIG = %s' % os.environ.get('TF_CONFIG', ''))
  tf.logging.info('TF_CONFIG %s' % os.environ.get('TF_CONFIG', ''))
  tf.logging.info('distribute_stategy %s, train_worker_num: %d' %
                  (distribute_strategy, train_worker_num))

  # remove pai chief-worker waiting strategy
  # which is conflicted with worker waiting strategy in easyrec
  if 'TF_WRITE_WORKER_STATUS_FILE' in os.environ:
    del os.environ['TF_WRITE_WORKER_STATUS_FILE']
  return train_worker_num
216
+
217
+
218
def set_tf_config_and_get_train_worker_num_on_ds():
  """Rewrite a DataScience-style ``TF_CONFIG`` into EasyRec's layout.

  Nothing happens unless ``TF_CONFIG`` is set and its cluster has a 'ps' entry
  but no 'evaluator' entry.  In that case worker 0 becomes the chief,
  worker 1 becomes the evaluator (and is left out of the cluster spec), and
  workers 2.. become the remaining workers, re-indexed from 0.  Tasks of
  any other type keep their type and index.

  Returns None in every case.
  """
  if 'TF_CONFIG' not in os.environ:
    return
  tf_config = json.loads(os.environ['TF_CONFIG'])
  if 'cluster' not in tf_config:
    return
  old_cluster = tf_config['cluster']
  if 'ps' not in old_cluster or 'evaluator' in old_cluster:
    return

  new_cluster = {
      'ps': old_cluster['ps'],
      'chief': [old_cluster['worker'][0]],
      'worker': old_cluster['worker'][2:],
  }

  old_task = tf_config['task']
  if old_task['type'] != 'worker':
    new_task = {'type': old_task['type'], 'index': old_task['index']}
  elif old_task['index'] == 0:
    new_task = {'type': 'chief', 'index': 0}
  elif old_task['index'] == 1:
    new_task = {'type': 'evaluator', 'index': 0}
  else:
    # Skip the two hosts taken by the chief and the evaluator.
    new_task = {'type': old_task['type'], 'index': old_task['index'] - 2}

  os.environ['TF_CONFIG'] = json.dumps({'cluster': new_cluster, 'task': new_task})
245
+
246
+
247
def set_tf_config_and_get_distribute_eval_worker_num_on_ds():
  """Rewrite a DataScience-style ``TF_CONFIG`` for distributed evaluation.

  ``TF_CONFIG`` must be set (enforced with an assert).  The rewrite only
  happens when its cluster has a 'ps' entry but no 'evaluator' entry: worker 0
  becomes the chief and workers 1.. become the remaining workers, re-indexed
  from 0.  Tasks of any other type keep their type and index.

  Returns None in every case.
  """
  assert 'TF_CONFIG' in os.environ, "'TF_CONFIG' must in os.environ"
  tf_config = json.loads(os.environ['TF_CONFIG'])
  if 'cluster' not in tf_config:
    return
  old_cluster = tf_config['cluster']
  if 'ps' not in old_cluster or 'evaluator' in old_cluster:
    return

  new_cluster = {
      'ps': old_cluster['ps'],
      'chief': [old_cluster['worker'][0]],
      'worker': old_cluster['worker'][1:],
  }

  old_task = tf_config['task']
  if old_task['type'] == 'worker':
    if old_task['index'] == 0:
      new_task = {'type': 'chief', 'index': 0}
    else:
      # Skip the host taken by the chief.
      new_task = {'type': old_task['type'], 'index': old_task['index'] - 1}
  else:
    new_task = {'type': old_task['type'], 'index': old_task['index']}

  os.environ['TF_CONFIG'] = json.dumps({'cluster': new_cluster, 'task': new_task})
@@ -0,0 +1,65 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import os
5
+ import subprocess
6
+ import traceback
7
+
8
+ from tensorflow.python.platform import gfile
9
+
10
+ from easy_rec.python.utils import estimator_utils
11
+
12
+
13
def is_on_ds():
  """Return True if the IS_ON_DS environment variable is present.

  Only the presence of the variable is checked; its value (even '0' or '')
  is ignored.
  """
  # IS_ON_DS is set by set_on_ds(); per the original note, that is called
  # from train_eval, the entry point on the DataScience platform.
  return 'IS_ON_DS' in os.environ
17
+
18
+
19
def set_on_ds():
  """Mark this process as running on DS by exporting IS_ON_DS=1.

  The variable is written to os.environ, so child processes inherit it and
  is_on_ds() reports True afterwards.
  """
  logging.info('set on ds environment variable: IS_ON_DS')
  os.environ.update(IS_ON_DS='1')
22
+
23
+
24
def cache_ckpt(pipeline_config):
  """Copy an HDFS fine-tune checkpoint to local disk and repoint the config.

  Only acts when ``train_config.fine_tune_checkpoint`` starts with
  ``hdfs://``; any other path is left untouched.

  On ps / chief / master tasks, files matching ``<checkpoint>*`` are fetched
  with ``hadoop fs -get`` into ``/tmp/experiments/<hdfs dir>`` (recreated
  from scratch), and ``fine_tune_checkpoint`` is rewritten to point at the
  local copy. ps tasks fetch every matched file, with the ``.data-*`` shards
  rotated by ps id; chief/master tasks fetch only the non-``.data-`` files.
  On every other task the ``fine_tune_checkpoint`` field is cleared.

  A failed copy is logged as a warning and the remaining files are still
  attempted.

  Args:
    pipeline_config: pipeline config whose
      ``train_config.fine_tune_checkpoint`` is read and modified in place.
  """
  hdfs_prefix = 'hdfs://'
  fine_tune_ckpt_path = pipeline_config.train_config.fine_tune_checkpoint
  if not fine_tune_ckpt_path.startswith(hdfs_prefix):
    # there is no need to cache if remote directories are mounted
    return

  if not (estimator_utils.is_ps() or estimator_utils.is_chief() or
          estimator_utils.is_master()):
    # workers do not have to create the restore graph
    pipeline_config.train_config.ClearField('fine_tune_checkpoint')
    return

  # Strip the scheme and any leading '/' before joining. For
  # 'hdfs:///a/b/ckpt' the remainder's dirname is the absolute path '/a/b';
  # os.path.join would then discard '/tmp/experiments', and
  # DeleteRecursively() below would wipe a real local directory.
  remote_dir = os.path.dirname(fine_tune_ckpt_path[len(hdfs_prefix):])
  tmpdir = os.path.join('/tmp/experiments', remote_dir.lstrip('/'))
  logging.info('will cache fine_tune_ckpt to local dir: %s', tmpdir)
  if gfile.IsDirectory(tmpdir):
    gfile.DeleteRecursively(tmpdir)
  gfile.MakeDirs(tmpdir)

  src_files = sorted(gfile.Glob(fine_tune_ckpt_path + '*'))
  data_files = [x for x in src_files if '.data-' in x]
  meta_files = [x for x in src_files if '.data-' not in x]
  if estimator_utils.is_ps():
    _, _, ps_id = estimator_utils.parse_tf_config()
    # Rotate the shard list by ps id, presumably so that different ps tasks
    # start with different shards. Skip the rotation when no shards were
    # found, since modulo by zero would raise ZeroDivisionError.
    if data_files:
      ps_id = ps_id % len(data_files)
      data_files = data_files[ps_id:] + data_files[:ps_id]
    src_files = meta_files + data_files
  else:
    src_files = meta_files

  for src_path in src_files:
    dst_path = os.path.join(tmpdir, os.path.basename(src_path))
    logging.info('will copy %s to local path %s', src_path, dst_path)
    try:
      # argv list without a shell: paths containing spaces or shell
      # metacharacters are passed to hadoop verbatim.
      output = subprocess.check_output(
          ['hadoop', 'fs', '-get', src_path, dst_path])
      logging.info('copy succeed: %s', output)
    except Exception:
      # Best effort: log and continue with the remaining files.
      logging.warning('exception: %s', traceback.format_exc())

  ckpt_filename = os.path.basename(fine_tune_ckpt_path)
  fine_tune_ckpt_path = os.path.join(tmpdir, ckpt_filename)
  pipeline_config.train_config.fine_tune_checkpoint = fine_tune_ckpt_path
  logging.info('will restore from %s', fine_tune_ckpt_path)
@@ -0,0 +1,73 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import os
4
+
5
+ import tensorflow as tf
6
+ from tensorflow.python.framework import ops
7
+
8
+ from easy_rec.python.utils import constant
9
+ from easy_rec.python.utils import proto_util
10
+
11
# On TensorFlow 2.x, rebind `tf` to the v1 compatibility module.
# NOTE(review): this compares version strings lexicographically, not
# numerically.
tf = tf.compat.v1 if tf.__version__ >= '2.0' else tf
13
+
14
+
15
def get_norm_name_to_ids():
  """Get normalize embedding name(including kv variables) to ids.

  Walks the ``constant.SPARSE_UPDATE_VARIABLES`` collection, normalizes the
  name of the first element of each entry with
  ``proto_util.get_norm_embed_name``, and numbers the distinct normalized
  names '0', '1', ... in order of first appearance.

  Return:
    normalized names to ids mapping (ids are strings).
  """
  first_seen = {}
  for entry in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
    norm_name, _ = proto_util.get_norm_embed_name(entry[0].name)
    # Keep the position of the first occurrence; repeats do not renumber.
    first_seen.setdefault(norm_name, len(first_seen))
  return {name: str(position) for name, position in first_seen.items()}
29
+
30
+
31
def get_sparse_name_to_ids():
  """Get embedding variable(including kv variables) name to ids mapping.

  Each variable in ``constant.SPARSE_UPDATE_VARIABLES`` is mapped to the id
  of its normalized name from get_norm_name_to_ids(). Variables that share a
  normalized name (presumably partitions of one embedding) share an id.

  Return:
    variable names to ids mapping.
  """
  norm_ids = get_norm_name_to_ids()
  result = {}
  for entry in ops.get_collection(constant.SPARSE_UPDATE_VARIABLES):
    var_name = entry[0].name
    norm_name, _ = proto_util.get_norm_embed_name(var_name)
    result[var_name] = norm_ids[norm_name]
  return result
43
+
44
+
45
def get_dense_name_to_ids():
  """Map each dense update variable's op name to its position.

  Ids are ints (the sparse helpers return strings), assigned in the order of
  the ``constant.DENSE_UPDATE_VARIABLES`` collection. If two variables share
  an op name, the later position wins.

  Return:
    op names to int ids mapping.
  """
  dense_vars = ops.get_collection(constant.DENSE_UPDATE_VARIABLES)
  return {var.op.name: position for position, var in enumerate(dense_vars)}
51
+
52
+
53
# Module-wide flag: starts False and is switched on by set_embedding_parallel().
embedding_parallel = False


def set_embedding_parallel():
  """Turn on the module-wide embedding-parallel flag."""
  global embedding_parallel
  embedding_parallel = True


def is_embedding_parallel():
  """Return the current value of the module-wide embedding-parallel flag."""
  # Reading a module global needs no `global` declaration.
  return embedding_parallel
64
+
65
+
66
def sort_col_by_name():
  """Return True if the env var named by constant.SORT_COL_BY_NAME is set.

  Only presence matters; the value itself is not inspected.
  """
  return os.environ.get(constant.SORT_COL_BY_NAME) is not None
68
+
69
+
70
def embedding_on_cpu():
  """Return whether embeddings should be placed on CPU.

  Reads the environment variable named by ``constant.EmbeddingOnCPU``. An
  unset or empty value yields False. Any other value is parsed as a Python
  literal (e.g. 'True', 'False', '1', '0'), and the parsed literal is
  returned.

  Raises:
    ValueError or SyntaxError: if the value is not a valid Python literal.
  """
  import ast

  place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
  if not place_on_cpu:
    return False
  # Previously eval() was applied to the raw environment value, which
  # executes arbitrary code. ast.literal_eval accepts the same literal
  # spellings but rejects anything else. Surrounding whitespace is stripped
  # because older literal_eval versions reject leading spaces.
  return ast.literal_eval(place_on_cpu.strip())