easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/eval.py
@@ -0,0 +1,102 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+ import os
6
+
7
+ import six
8
+ import tensorflow as tf
9
+ from tensorflow.python.lib.io import file_io
10
+
11
+ from easy_rec.python.main import distribute_evaluate
12
+ from easy_rec.python.main import evaluate
13
+ from easy_rec.python.protos.train_pb2 import DistributionStrategy
14
+ from easy_rec.python.utils import config_util
15
+ from easy_rec.python.utils import ds_util
16
+ from easy_rec.python.utils import estimator_utils
17
+
18
+ from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_distribute_eval_worker_num_on_ds # NOQA
19
# tf.app.flags / tf.app.run are v1 APIs; on TF2 they are only reachable via
# tf.compat.v1, so `tf` is rebound before any of them is used below.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

tf.app.flags.DEFINE_string('pipeline_config_path', None,
                           'Path to pipeline config file.')
# model_dir lives on pipeline_config (main() overrides
# pipeline_config.model_dir), not on train_config.
tf.app.flags.DEFINE_string(
    'checkpoint_path', None, 'checkpoint to be evaled; if not specified, use '
    'the latest checkpoint in pipeline_config.model_dir')
tf.app.flags.DEFINE_multi_string(
    'eval_input_path', None, 'eval data path, if specified will '
    'override pipeline_config.eval_input_path')
tf.app.flags.DEFINE_string(
    'model_dir', None,
    help='model dir; if it contains pipeline.config that config is used, '
    'and pipeline_config.model_dir is overridden with this value')
tf.app.flags.DEFINE_string('odps_config', None, help='odps config path')
tf.app.flags.DEFINE_string('eval_result_path', 'eval_result.txt',
                           'eval result metric file')
tf.app.flags.DEFINE_bool('distribute_eval', False,
                         'use distribute parameter server for eval.')
tf.app.flags.DEFINE_bool('is_on_ds', False, help='is on ds')
FLAGS = tf.app.flags.FLAGS
44
+
45
+
46
+ def main(argv):
47
+ if FLAGS.odps_config:
48
+ os.environ['ODPS_CONFIG_FILE_PATH'] = FLAGS.odps_config
49
+
50
+ if FLAGS.is_on_ds:
51
+ ds_util.set_on_ds()
52
+ if FLAGS.distribute_eval:
53
+ set_tf_config_and_get_distribute_eval_worker_num_on_ds()
54
+
55
+ assert FLAGS.model_dir or FLAGS.pipeline_config_path, 'At least one of model_dir and pipeline_config_path exists.'
56
+ if FLAGS.model_dir:
57
+ pipeline_config_path = os.path.join(FLAGS.model_dir, 'pipeline.config')
58
+ if file_io.file_exists(pipeline_config_path):
59
+ logging.info('update pipeline_config_path to %s' % pipeline_config_path)
60
+ else:
61
+ pipeline_config_path = FLAGS.pipeline_config_path
62
+ else:
63
+ pipeline_config_path = FLAGS.pipeline_config_path
64
+
65
+ pipeline_config = config_util.get_configs_from_pipeline_file(
66
+ pipeline_config_path)
67
+ if FLAGS.model_dir:
68
+ pipeline_config.model_dir = FLAGS.model_dir
69
+
70
+ if pipeline_config.train_config.train_distribute in [
71
+ DistributionStrategy.HorovodStrategy,
72
+ ]:
73
+ estimator_utils.init_hvd()
74
+ elif pipeline_config.train_config.train_distribute in [
75
+ DistributionStrategy.EmbeddingParallelStrategy,
76
+ DistributionStrategy.SokStrategy
77
+ ]:
78
+ estimator_utils.init_hvd()
79
+ estimator_utils.init_sok()
80
+
81
+ if FLAGS.distribute_eval:
82
+ os.environ['distribute_eval'] = 'True'
83
+ eval_result = distribute_evaluate(pipeline_config, FLAGS.checkpoint_path,
84
+ FLAGS.eval_input_path,
85
+ FLAGS.eval_result_path)
86
+ else:
87
+ os.environ['distribute_eval'] = 'False'
88
+ eval_result = evaluate(pipeline_config, FLAGS.checkpoint_path,
89
+ FLAGS.eval_input_path, FLAGS.eval_result_path)
90
+ if eval_result is not None:
91
+ # when distribute evaluate, only master has eval_result.
92
+ for key in sorted(eval_result):
93
+ # skip logging binary data
94
+ if isinstance(eval_result[key], six.binary_type):
95
+ continue
96
+ logging.info('%s: %s' % (key, str(eval_result[key])))
97
+ else:
98
+ logging.info('Eval result in master worker.')
99
+
100
+
101
if __name__ == '__main__':
  # Parse the command-line flags, then hand control to main(argv).
  tf.app.run(main=main)
easy_rec/python/export.py
@@ -0,0 +1,150 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import os
5
+
6
+ import tensorflow as tf
7
+ from tensorflow.python.lib.io import file_io
8
+
9
+ from easy_rec.python.main import export
10
+ from easy_rec.python.protos.train_pb2 import DistributionStrategy
11
+ from easy_rec.python.utils import config_util
12
+ from easy_rec.python.utils import estimator_utils
13
+
14
+ if tf.__version__.startswith('1.'):
15
+ from tensorflow.python.platform import gfile
16
+ else:
17
+ import tensorflow.io.gfile as gfile
18
+
19
# On TF2 the v1-style tf.app API used below is only available under
# tf.compat.v1, so `tf` must be rebound before any flag is defined.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

# ---- what to export, and where ----
tf.app.flags.DEFINE_string(
    'pipeline_config_path', None, help='Path to pipeline config file.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', '', help='checkpoint to be exported')
tf.app.flags.DEFINE_string(
    'export_dir', None, help='directory where model should be exported to')

# ---- optional redis sink for embeddings ----
tf.app.flags.DEFINE_string(
    'redis_url', None, help='export to redis url, host:port')
tf.app.flags.DEFINE_string(
    'redis_passwd', None, help='export to redis passwd')
tf.app.flags.DEFINE_integer(
    'redis_threads', 0, help='export to redis threads')
tf.app.flags.DEFINE_integer(
    'redis_batch_size', 256, help='export to redis batch_size')
# NOTE(review): redis_timeout is defined but main() never reads it.
tf.app.flags.DEFINE_integer(
    'redis_timeout', 600, help='export to redis time_out in seconds')
tf.app.flags.DEFINE_integer(
    'redis_expire', 24, help='export to redis expire time in hour')
tf.app.flags.DEFINE_string(
    'redis_embedding_version', '', help='redis embedding version')
tf.app.flags.DEFINE_integer(
    'redis_write_kv', 1, help='whether to write embedding to redis')

# ---- optional oss sink for embeddings ----
tf.app.flags.DEFINE_string(
    'oss_path', None,
    help='write embed objects to oss folder, oss://bucket/folder')
tf.app.flags.DEFINE_string('oss_endpoint', None, help='oss endpoint')
tf.app.flags.DEFINE_string('oss_ak', None, help='oss ak')
tf.app.flags.DEFINE_string('oss_sk', None, help='oss sk')
tf.app.flags.DEFINE_integer(
    'oss_threads', 10, help='# threads access oss at the same time')
tf.app.flags.DEFINE_integer(
    'oss_timeout', 10, help='connect to oss, time_out in seconds')
tf.app.flags.DEFINE_integer(
    'oss_expire', 24, help='oss expire time in hours')
tf.app.flags.DEFINE_integer(
    'oss_write_kv', 1, help='whether to write embedding to oss')
tf.app.flags.DEFINE_string(
    'oss_embedding_version', '', help='oss embedding version')

# ---- misc ----
tf.app.flags.DEFINE_string(
    'asset_files', '', help='more files to add to asset')
tf.app.flags.DEFINE_bool(
    'verbose', False, help='print more debug information')

tf.app.flags.DEFINE_string(
    'model_dir', None, help='will update the model_dir')
tf.app.flags.mark_flag_as_required('export_dir')

tf.app.flags.DEFINE_bool(
    'clear_export', False, help='remove export_dir if exists')
tf.app.flags.DEFINE_string(
    'export_done_file', '',
    help='a flag file to signal that export model is done')
FLAGS = tf.app.flags.FLAGS
70
+
71
+
72
+ def main(argv):
73
+ extra_params = {}
74
+ if FLAGS.redis_url:
75
+ extra_params['redis_url'] = FLAGS.redis_url
76
+ if FLAGS.redis_passwd:
77
+ extra_params['redis_passwd'] = FLAGS.redis_passwd
78
+ if FLAGS.redis_threads > 0:
79
+ extra_params['redis_threads'] = FLAGS.redis_threads
80
+ if FLAGS.redis_batch_size > 0:
81
+ extra_params['redis_batch_size'] = FLAGS.redis_batch_size
82
+ if FLAGS.redis_expire > 0:
83
+ extra_params['redis_expire'] = FLAGS.redis_expire
84
+ if FLAGS.redis_embedding_version:
85
+ extra_params['redis_embedding_version'] = FLAGS.redis_embedding_version
86
+ if FLAGS.redis_write_kv == 0:
87
+ extra_params['redis_write_kv'] = False
88
+ else:
89
+ extra_params['redis_write_kv'] = True
90
+
91
+ assert FLAGS.model_dir or FLAGS.pipeline_config_path, 'At least one of model_dir and pipeline_config_path exists.'
92
+ if FLAGS.model_dir:
93
+ pipeline_config_path = os.path.join(FLAGS.model_dir, 'pipeline.config')
94
+ if file_io.file_exists(pipeline_config_path):
95
+ logging.info('update pipeline_config_path to %s' % pipeline_config_path)
96
+ else:
97
+ pipeline_config_path = FLAGS.pipeline_config_path
98
+ else:
99
+ pipeline_config_path = FLAGS.pipeline_config_path
100
+
101
+ if FLAGS.oss_path:
102
+ extra_params['oss_path'] = FLAGS.oss_path
103
+ if FLAGS.oss_endpoint:
104
+ extra_params['oss_endpoint'] = FLAGS.oss_endpoint
105
+ if FLAGS.oss_ak:
106
+ extra_params['oss_ak'] = FLAGS.oss_ak
107
+ if FLAGS.oss_sk:
108
+ extra_params['oss_sk'] = FLAGS.oss_sk
109
+ if FLAGS.oss_timeout > 0:
110
+ extra_params['oss_timeout'] = FLAGS.oss_timeout
111
+ if FLAGS.oss_expire > 0:
112
+ extra_params['oss_expire'] = FLAGS.oss_expire
113
+ if FLAGS.oss_threads > 0:
114
+ extra_params['oss_threads'] = FLAGS.oss_threads
115
+ if FLAGS.oss_write_kv:
116
+ extra_params['oss_write_kv'] = True if FLAGS.oss_write_kv == 1 else False
117
+ if FLAGS.oss_embedding_version:
118
+ extra_params['oss_embedding_version'] = FLAGS.oss_embedding_version
119
+
120
+ pipeline_config = config_util.get_configs_from_pipeline_file(
121
+ pipeline_config_path)
122
+ if pipeline_config.train_config.train_distribute in [
123
+ DistributionStrategy.HorovodStrategy,
124
+ ]:
125
+ estimator_utils.init_hvd()
126
+ elif pipeline_config.train_config.train_distribute in [
127
+ DistributionStrategy.EmbeddingParallelStrategy,
128
+ DistributionStrategy.SokStrategy
129
+ ]:
130
+ estimator_utils.init_hvd()
131
+ estimator_utils.init_sok()
132
+
133
+ if FLAGS.clear_export:
134
+ logging.info('will clear export_dir=%s' % FLAGS.export_dir)
135
+ if gfile.IsDirectory(FLAGS.export_dir):
136
+ gfile.DeleteRecursively(FLAGS.export_dir)
137
+
138
+ export_out_dir = export(FLAGS.export_dir, pipeline_config_path,
139
+ FLAGS.checkpoint_path, FLAGS.asset_files,
140
+ FLAGS.verbose, **extra_params)
141
+
142
+ if FLAGS.export_done_file:
143
+ flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
144
+ logging.info('create export done file: %s' % flag_file)
145
+ with gfile.GFile(flag_file, 'w') as fout:
146
+ fout.write('ExportDone')
147
+
148
+
149
if __name__ == '__main__':
  # Parse the command-line flags, then hand control to main(argv).
  tf.app.run(main=main)
+ tf.app.run()
File without changes