easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,286 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ import tensorflow as tf
8
+ from tensorflow.core.framework import graph_pb2
9
+ from tensorflow.python.framework import importer
10
+ from tensorflow.python.framework import ops
11
+ from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
12
+ from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
13
+ from tensorflow.python.saved_model import signature_constants
14
+ from tensorflow.python.tools import saved_model_utils
15
+ from tensorflow.python.training import saver as tf_saver
16
+
17
+ from easy_rec.python.utils import io_util
18
+
19
+ if tf.__version__ >= '2.0':
20
+ tf = tf.compat.v1
21
+ from tensorflow.python.saved_model.path_helpers import get_variables_path
22
+ else:
23
+ from tensorflow.python.saved_model.utils_impl import get_variables_path
24
+
25
# Command-line flags. `model_dir` is searched for the source saved model; the
# `*_model_dir` flags are the export directories of the user/item parts and
# the `*_fg_json_path` flags are optional files copied next to each part.
FLAGS = tf.app.flags.FLAGS
for _flag_name in ('model_dir', 'user_model_dir', 'item_model_dir',
                   'user_fg_json_path', 'item_fg_json_path'):
  # Every flag is a string flag with an empty default and empty help text.
  tf.app.flags.DEFINE_string(_flag_name, '', '')
del _flag_name

logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] %(message)s', level=logging.INFO)
34
+
35
+
36
def search_pb(directory):
  """Locate the single saved model directory underneath `directory`.

  Walks `directory` recursively and collects every directory that holds at
  least one file whose extension is '.pb'.

  Args:
    directory: root path to search.

  Returns:
    The one directory that contains '.pb' file(s).

  Raises:
    ValueError: if no directory, or more than one distinct directory,
      contains a '.pb' file.
  """
  dir_list = []
  for root, _, files in tf.gfile.Walk(directory):
    # Record each directory once, however many '.pb' files it holds, so that a
    # single model directory is not misreported as multiple saved models.
    if any(os.path.splitext(f)[1] == '.pb' for f in files):
      dir_list.append(root)

  if not dir_list:
    raise ValueError('savedmodel is not found in directory %s' % directory)
  if len(dir_list) > 1:
    raise ValueError('multiple saved model found in directory %s' % directory)

  return dir_list[0]
49
+
50
+
51
+ def _node_name(name):
52
+ if name.startswith('^'):
53
+ return name[1:]
54
+ else:
55
+ return name.split(':')[0]
56
+
57
+
58
def extract_sub_graph(graph_def, dest_nodes, variable_protos):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Args:
    graph_def: graph_pb2.GraphDef
    dest_nodes: a list includes output node names
    variable_protos: dict from variable node name to a proto exposing
      variable_name, initial_value_name, initializer_name and snapshot_name.
      When a visited node is a key of this dict, those three nodes are
      followed as well and the variable is reported in variables_to_keep.

  Returns:
    out: the GraphDef of the sub-graph, nodes kept in their original order.
    variables_to_keep: set of variable names to be kept for saver.

  Raises:
    TypeError: if graph_def is not a graph_pb2.GraphDef.
    AssertionError: if a name in dest_nodes is not a node of graph_def.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError('graph_def must be a graph_pb2.GraphDef proto.')

  edges = {}
  name_to_node_map = {}
  node_seq = {}
  for seq, node in enumerate(graph_def.node):
    n = _node_name(node.name)
    name_to_node_map[n] = node
    edges[n] = [_node_name(item) for item in node.input]
    node_seq[n] = seq
  for d in dest_nodes:
    assert d in name_to_node_map, "'%s' is not in graph" % d

  nodes_to_keep = set()
  variables_to_keep = set()
  # Work list used as a stack: pop() is O(1), unlike deleting the head of a
  # list. Visiting order does not matter because only the set of reachable
  # nodes is used, and the output is re-sorted by original position below.
  pending = list(dest_nodes)
  while pending:
    n = pending.pop()
    if n in nodes_to_keep:
      # Already expanded, including its variable-related nodes.
      continue

    if n in variable_protos:
      proto = variable_protos[n]
      pending.append(_node_name(proto.initial_value_name))
      pending.append(_node_name(proto.initializer_name))
      pending.append(_node_name(proto.snapshot_name))
      variables_to_keep.add(proto.variable_name)

    # Names that are not nodes of the graph (e.g. an empty name) are ignored.
    if n in edges:
      nodes_to_keep.add(n)
      pending.extend(edges[n])

  out = graph_pb2.GraphDef()
  out.node.extend([
      copy.deepcopy(name_to_node_map[n])
      for n in sorted(nodes_to_keep, key=node_seq.__getitem__)
  ])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out, variables_to_keep
115
+
116
+
117
def load_meta_graph_def(model_dir):
  """Read the serving MetaGraphDef of a saved model and summarize it.

  Args:
    model_dir: saved model directory.

  Returns:
    meta_graph_def: the MetaGraphDef loaded with the serving tag.
    variable_protos: dict from variable node name to the first variable proto
      found for that name in the variable collections.
    input_tensor_names: dict from signature input key to tensor name.
    output_tensor_names: dict from signature output key to tensor name.
  """
  meta_graph_def = saved_model_utils.get_meta_graph_def(
      model_dir, tf.saved_model.tag_constants.SERVING)

  # Variable collections hold serialized protos: decode each one and index it
  # by node name, keeping the first proto seen for a given name.
  variable_protos = {}
  for key, col_def in meta_graph_def.collection_def.items():
    if key not in ops.GraphKeys._VARIABLE_COLLECTIONS:
      continue
    tf.logging.info('[Collection] %s:' % key)
    for serialized in col_def.bytes_list.value:
      proto = ops.get_collection_proto_type(key)()
      proto.ParseFromString(serialized)
      tf.logging.info('%s' % proto.variable_name)
      variable_protos.setdefault(_node_name(proto.variable_name), proto)

  def _record(tensor_infos, name_by_key):
    """Log dtype and shape of every tensor and store its name under its key."""
    for key in tensor_infos:
      info = tensor_infos[key]
      shape = [int(dim.size) for dim in info.tensor_shape.dim]
      tf.logging.info('"%s": %s; %s' %
                      (key, _TYPE_TO_STRING[info.dtype], shape))
      name_by_key[key] = info.name

  # Only signatures using the predict method contribute inputs and outputs.
  input_tensor_names = {}
  output_tensor_names = {}
  predict_method = tf.saved_model.signature_constants.PREDICT_METHOD_NAME
  signatures = meta_graph_def.signature_def
  for sig_name in signatures:
    signature = signatures[sig_name]
    if signature.method_name != predict_method:
      continue
    tf.logging.info('[Signature] inputs:')
    _record(signature.inputs, input_tensor_names)
    tf.logging.info('[Signature] outputs:')
    _record(signature.outputs, output_tensor_names)

  return meta_graph_def, variable_protos, input_tensor_names, output_tensor_names
177
+
178
+
179
def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
           output_tensor_names, part_name, part_dir):
  """Export one part (user or item) of a saved model as its own saved model.

  The graph is pruned to the nodes needed by the outputs whose signature key
  contains `part_name`, variables are restored from `model_dir`, and the
  result is written to `part_dir` together with the pipeline config.

  Args:
    model_dir: saved model directory.
    meta_graph_def: a MetaGraphDef.
    variable_protos: a dict of VariableDef.
    input_tensor_names: signature inputs in saved model.
    output_tensor_names: signature outputs in saved model.
    part_name: subpart model name, user or item.
    part_dir: subpart model export directory.
  """
  # Keep only the outputs that belong to this part.
  output_tensor_names = {
      key: tensor_name
      for key, tensor_name in output_tensor_names.items()
      if part_name in key
  }
  output_node_names = [
      _node_name(tensor_name) for tensor_name in output_tensor_names.values()
  ]

  inference_graph, variables_to_keep = extract_sub_graph(
      meta_graph_def.graph_def, output_node_names, variable_protos)

  tf.reset_default_graph()
  with tf.Session() as sess, sess.graph.as_default():
    graph = ops.get_default_graph()
    importer.import_graph_def(inference_graph, name='')
    # Re-create the kept variables from their protos so the saver picks them up.
    for name in variables_to_keep:
      variable = _from_proto_fn(variable_protos[name.partition(':')[0]])
      graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
    saver = tf_saver.Saver()
    saver.restore(sess, get_variables_path(model_dir))

    builder = tf.saved_model.builder.SavedModelBuilder(part_dir)

    # Best effort: an input whose tensor cannot be resolved in the pruned
    # graph is reported and left out of the signature.
    signature_inputs = {}
    for input_name, tensor_name in input_tensor_names.items():
      try:
        tensor = graph.get_tensor_by_name(tensor_name)
        signature_inputs[input_name] = (
            tf.saved_model.utils.build_tensor_info(tensor))
      except Exception:
        print('ignore input: %s' % input_name)

    signature_outputs = {
        output_name: tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name(tensor_name))
        for output_name, tensor_name in output_tensor_names.items()
    }

    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs=signature_inputs,
            outputs=signature_outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    signature_def_map = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            prediction_signature,
    }
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map=signature_def_map)
    builder.save()

  # Copy the pipeline config, and the part's fg.json when its flag is set,
  # into the assets directory of the exported part.
  config_path = os.path.join(model_dir, 'assets/pipeline.config')
  assert tf.gfile.Exists(config_path)
  dst_path = os.path.join(part_dir, 'assets')
  dst_config_path = os.path.join(dst_path, 'pipeline.config')
  tf.gfile.MkDir(dst_path)
  tf.gfile.Copy(config_path, dst_config_path)

  if part_name == 'user':
    fg_json_path = FLAGS.user_fg_json_path
  elif part_name == 'item':
    fg_json_path = FLAGS.item_fg_json_path
  else:
    fg_json_path = None
  if fg_json_path:
    tf.gfile.Copy(fg_json_path, os.path.join(dst_path, 'fg.json'))
257
+
258
+
259
def main(argv):
  """Split the saved model found under FLAGS.model_dir into user and item parts.

  Args:
    argv: command-line arguments passed by tf.app.run; not used here.
  """
  model_dir = search_pb(FLAGS.model_dir)
  tf.logging.info('Loading meta graph...')
  # Both exports share the model directory plus everything read from the
  # meta graph: (meta_graph_def, variable_protos, inputs, outputs).
  shared_args = (model_dir,) + tuple(load_meta_graph_def(model_dir))

  tf.logging.info('Exporting user part model...')
  export(*shared_args, part_name='user', part_dir=FLAGS.user_model_dir)

  tf.logging.info('Exporting item part model...')
  export(*shared_args, part_name='item', part_dir=FLAGS.item_model_dir)


if __name__ == '__main__':
  # Drop arguments that are not defined as flags before handing over to tf.
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  tf.app.run()
@@ -0,0 +1,272 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ import tensorflow as tf
8
+ from tensorflow.core.framework import graph_pb2
9
+ from tensorflow.python.framework import importer
10
+ from tensorflow.python.framework import ops
11
+ from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
12
+ from tensorflow.python.saved_model import signature_constants
13
+ from tensorflow.python.saved_model.utils_impl import get_variables_path
14
+ from tensorflow.python.tools import saved_model_utils
15
+ from tensorflow.python.training import saver as tf_saver
16
+
17
+ from easy_rec.python.utils import io_util
18
+
19
FLAGS = tf.app.flags.FLAGS
# Help texts describe how main() consumes each flag, so that --help is usable.
tf.app.flags.DEFINE_string(
    'model_dir', '',
    'directory that is searched for the saved model (*.pb) to split')
tf.app.flags.DEFINE_string(
    'trigger_model_dir', '',
    'export directory of the trigger part saved model')
tf.app.flags.DEFINE_string(
    'sim_model_dir', '', 'export directory of the sim part saved model')

logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
26
+
27
+
28
def search_pb(directory):
  """Locate the single directory under `directory` that holds a '.pb' file.

  Args:
    directory: root directory, walked recursively with tf.gfile.Walk.

  Returns:
    the path of the only directory that contains at least one '.pb' file.

  Raises:
    ValueError: if no directory, or more than one directory, contains a
      '.pb' file.
  """
  dir_list = []
  for root, _, files in tf.gfile.Walk(directory):
    # Record a directory once even when it holds several '.pb' files;
    # appending per file would misreport one model directory as
    # "multiple saved model".
    if any(os.path.splitext(f)[1] == '.pb' for f in files):
      dir_list.append(root)

  if not dir_list:
    raise ValueError('savedmodel is not found in directory %s' % directory)
  if len(dir_list) > 1:
    raise ValueError('multiple saved model found in directory %s' % directory)

  return dir_list[0]
41
+
42
+
43
+ def _node_name(name):
44
+ if name.startswith('^'):
45
+ return name[1:]
46
+ else:
47
+ return name.split(':')[0]
48
+
49
+
50
def extract_sub_graph(graph_def, dest_nodes, variable_protos):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Args:
    graph_def: graph_pb2.GraphDef
    dest_nodes: a list includes output node names
    variable_protos: a dict mapping variable node name to its variable proto;
      for every reached variable the nodes named by initial_value_name,
      initializer_name and snapshot_name are pulled into the subgraph as well.

  Returns:
    out: the GraphDef of the sub-graph.
    variables_to_keep: variables to be kept for saver.

  Raises:
    TypeError: if graph_def is not a graph_pb2.GraphDef.
    AssertionError: if a name in dest_nodes is not a node of graph_def.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError('graph_def must be a graph_pb2.GraphDef proto.')

  edges = {}
  name_to_node_map = {}
  node_seq = {}
  for seq, node in enumerate(graph_def.node):
    n = _node_name(node.name)
    name_to_node_map[n] = node
    edges[n] = [_node_name(item) for item in node.input]
    node_seq[n] = seq
  for d in dest_nodes:
    assert d in name_to_node_map, "'%s' is not in graph" % d

  nodes_to_keep = set()
  variables_to_keep = set()

  # The work list is consumed from the end (O(1) per pop) instead of deleting
  # index 0, which is quadratic on large graphs. Visiting order does not
  # matter: the reachable set is the same and the output is re-sorted by the
  # original node sequence below.
  pending = list(dest_nodes)
  while pending:
    n = pending.pop()
    if n in nodes_to_keep:
      # already expanded, including its variable proto if it has one
      continue

    if n in variable_protos:
      proto = variable_protos[n]
      pending.append(_node_name(proto.initial_value_name))
      pending.append(_node_name(proto.initializer_name))
      pending.append(_node_name(proto.snapshot_name))
      variables_to_keep.add(proto.variable_name)

    # make sure n is in edges; names without a graph node are dropped
    if n in edges:
      nodes_to_keep.add(n)
      pending.extend(edges[n])

  # keep the nodes in the same relative order as in the source graph
  nodes_to_keep_list = sorted(nodes_to_keep, key=lambda name: node_seq[name])

  out = graph_pb2.GraphDef()
  for n in nodes_to_keep_list:
    out.node.extend([copy.deepcopy(name_to_node_map[n])])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out, variables_to_keep
107
+
108
+
109
def load_meta_graph_def(model_dir):
  """Read the serving meta graph of a saved model and summarize it.

  Args:
    model_dir: saved model directory.

  Returns:
    meta_graph_def: the MetaGraphDef tagged with SERVING.
    variable_protos: dict from variable node name to the proto parsed from
      the variable collections; the first proto seen for a name is kept.
    input_tensor_names: dict from signature input key to tensor name,
      gathered from every signature using the predict method.
    output_tensor_names: dict from signature output key to tensor name,
      gathered from every signature using the predict method.
  """
  meta_graph_def = saved_model_utils.get_meta_graph_def(
      model_dir, tf.saved_model.tag_constants.SERVING)

  # parse the variable collections stored in collection_def
  variable_protos = {}
  for key, col_def in meta_graph_def.collection_def.items():
    if key not in ops.GraphKeys._VARIABLE_COLLECTIONS:
      continue
    tf.logging.info('[Collection] %s:' % key)
    for serialized in col_def.bytes_list.value:
      proto = ops.get_collection_proto_type(key)()
      proto.ParseFromString(serialized)
      tf.logging.info('%s' % proto.variable_name)
      variable_protos.setdefault(_node_name(proto.variable_name), proto)

  def _collect(title, tensor_infos, name_map):
    # log dtype and shape of every tensor info and record its tensor name
    tf.logging.info(title)
    for key in tensor_infos:
      info = tensor_infos[key]
      shape = [int(dim.size) for dim in info.tensor_shape.dim]
      tf.logging.info('"%s": %s; %s' %
                      (key, _TYPE_TO_STRING[info.dtype], shape))
      name_map[key] = info.name

  # parse the predict signatures
  input_tensor_names = {}
  output_tensor_names = {}
  signatures = meta_graph_def.signature_def
  for sig_name in signatures:
    signature = signatures[sig_name]
    if signature.method_name != tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
      continue
    _collect('[Signature] inputs:', signature.inputs, input_tensor_names)
    _collect('[Signature] outputs:', signature.outputs, output_tensor_names)

  return meta_graph_def, variable_protos, input_tensor_names, output_tensor_names
169
+
170
+
171
def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
           output_tensor_names, part_name, part_dir):
  """Write the sub model that produces the outputs matching part_name.

  Args:
    model_dir: source saved model directory; variables are restored from it
      and its assets/pipeline.config is copied to the exported model.
    meta_graph_def: a MetaGraphDef.
    variable_protos: a dict of variable protos keyed by variable node name.
    input_tensor_names: signature inputs in saved model.
    output_tensor_names: signature outputs in saved model; only the entries
      whose key contains part_name are exported.
    part_name: sub string selecting the signature outputs of this part.
    part_dir: subpart model export directory.
  """
  # restrict the signature outputs to the requested part
  part_outputs = {
      key: tensor_name
      for key, tensor_name in output_tensor_names.items()
      if part_name in key
  }
  output_node_names = [
      _node_name(tensor_name) for tensor_name in part_outputs.values()
  ]

  inference_graph, variables_to_keep = extract_sub_graph(
      meta_graph_def.graph_def, output_node_names, variable_protos)

  tf.reset_default_graph()
  with tf.Session() as sess, sess.graph.as_default():
    graph = ops.get_default_graph()
    importer.import_graph_def(inference_graph, name='')
    # register the kept variables so that the Saver below picks them up
    for var_name in variables_to_keep:
      graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                              graph.get_tensor_by_name(var_name))
    tf_saver.Saver().restore(sess, get_variables_path(model_dir))

    builder = tf.saved_model.builder.SavedModelBuilder(part_dir)

    # inputs that cannot be resolved in the sub graph are skipped
    signature_inputs = {}
    for input_name, tensor_name in input_tensor_names.items():
      try:
        signature_inputs[input_name] = tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name(tensor_name))
      except Exception:
        print('ignore input: %s' % input_name)

    signature_outputs = {
        output_name:
            tf.saved_model.utils.build_tensor_info(
                graph.get_tensor_by_name(tensor_name))
        for output_name, tensor_name in part_outputs.items()
    }

    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=signature_inputs,
        outputs=signature_outputs,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    serving_signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            prediction_signature,
    }
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map=serving_signatures)
    builder.save()

  # ship the pipeline config next to the exported sub model
  config_path = os.path.join(model_dir, 'assets/pipeline.config')
  assert tf.gfile.Exists(config_path)
  dst_path = os.path.join(part_dir, 'assets')
  tf.gfile.MkDir(dst_path)
  tf.gfile.Copy(config_path, os.path.join(dst_path, 'pipeline.config'))
243
+
244
+
245
def main(argv):
  """Split the saved model found under FLAGS.model_dir into two sub models.

  The meta graph is loaded once and then exported twice: the outputs matching
  'trigger_out' go to FLAGS.trigger_model_dir and the outputs matching
  'sim_out' go to FLAGS.sim_model_dir.

  Args:
    argv: command line arguments passed in by tf.app.run(); not used.
  """
  model_dir = search_pb(FLAGS.model_dir)
  tf.logging.info('Loading meta graph...')
  (meta_graph_def, variable_protos, input_tensor_names,
   output_tensor_names) = load_meta_graph_def(model_dir)

  # (log message, part_name, name of the flag holding the export directory)
  export_steps = (
      ('Exporting trigger part model...', 'trigger_out', 'trigger_model_dir'),
      ('Exporting sim part model...', 'sim_out', 'sim_model_dir'),
  )
  for message, part, dir_flag in export_steps:
    tf.logging.info(message)
    export(
        model_dir,
        meta_graph_def,
        variable_protos,
        input_tensor_names,
        output_tensor_names,
        part_name=part,
        part_dir=getattr(FLAGS, dir_flag))
268
+
269
+
270
if __name__ == '__main__':
  # NOTE(review): filter_unknown_args presumably removes command line
  # arguments that are not defined in FLAGS -- confirm in io_util.
  filtered_argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  sys.argv = filtered_argv
  tf.app.run()
@@ -0,0 +1,80 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import json
5
+ import logging
6
+ import os
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+
11
+ import easy_rec
12
+ from easy_rec.python.inference.predictor import Predictor
13
+
14
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

# Load the kv lookup op library from easy_rec.ops_dir at import time.
# NOTE(review): presumably required so that saved models using this custom op
# can be loaded by Predictor -- confirm.
lookup_op_path = os.path.join(easy_rec.ops_dir, 'libkv_lookup.so')
lookup_op = tf.load_op_library(lookup_op_path)

if __name__ == '__main__':
  # Test saved model, an example:
  #
  #   python -m easy_rec.python.tools.test_saved_model
  #     --saved_model_dir after_edit_save
  #     --input_path data/test/rtp/xys_cxr_fg_sample_test2_with_lbl.txt
  #     --with_lbl

  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--saved_model_dir', type=str, default=None, help='saved model dir')
  # The script always opens input_path, so fail early with a usage message
  # instead of a TypeError from open(None).
  parser.add_argument(
      '--input_path',
      type=str,
      default=None,
      required=True,
      help='input data path')
  parser.add_argument('--save_path', type=str, default=None, help='save path')
  parser.add_argument('--separator', type=str, default=',', help='separator')
  parser.add_argument(
      '--cmp_res_path', type=str, default=None, help='compare result path')
  parser.add_argument(
      '--cmp_key', type=str, default='probs', help='compare key')
  parser.add_argument('--tol', type=float, default=1e-5, help='tolerance')
  parser.add_argument(
      '--with_lbl',
      action='store_true',
      default=False,
      help='whether the test data has label field')
  args = parser.parse_args()

  logging.info('saved_model_dir: %s' % args.saved_model_dir)
  logging.info('test_data_path: %s' % args.input_path)
  logging.info('test_data has lbl: %s' % args.with_lbl)

  predictor = Predictor(args.saved_model_dir)
  feature_vals = []
  with open(args.input_path, 'r') as fin:
    for line_str in fin:
      line_toks = line_str.strip().split(args.separator)
      if args.with_lbl:
        # the first field is the label, it is not fed to the model
        line_toks = line_toks[1:]
      feature_vals.append(args.separator.join(line_toks))
  output = predictor.predict(feature_vals, batch_size=4096)

  if args.save_path:
    with open(args.save_path, 'w') as fout:
      for one in output:
        fout.write(str(one) + '\n')

  if args.cmp_res_path:
    logging.info('compare result path: ' + args.cmp_res_path)
    logging.info('compare key: ' + args.cmp_key)
    logging.info('tolerance: ' + str(args.tol))
    with open(args.cmp_res_path, 'r') as fin:
      for line_id, line_str in enumerate(fin):
        line_pred = json.loads(line_str.strip())
        diff = np.abs(line_pred[args.cmp_key] - output[line_id][args.cmp_key])
        # Raise explicitly instead of using `assert`, so the comparison is
        # not stripped when running with `python -O`. `not diff < tol` also
        # rejects NaN differences, as the original assert did.
        if not diff < args.tol:
          raise AssertionError('line[%d]: %.8f' % (line_id, diff))
@@ -0,0 +1,39 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import logging
5
+
6
+ from google.protobuf import text_format
7
+ from tensorflow.core.protobuf import saved_model_pb2
8
+ from tensorflow.python.platform.gfile import GFile
9
+
10
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s')

if __name__ == '__main__':
  # Convert a SavedModel proto between binary and text form: an input ending
  # with '.pb' is parsed as binary, anything else as text; an output ending
  # with '.pbtxt' is written as text, anything else as binary.
  arg_parser = argparse.ArgumentParser()
  for flag, help_text in (('--input', 'saved model path'),
                          ('--output', 'saved model save path')):
    arg_parser.add_argument(flag, type=str, default=None, help=help_text)
  args = arg_parser.parse_args()

  assert not (args.input is None or args.output is None)

  logging.info('saved_model_path: %s' % args.input)

  saved_model = saved_model_pb2.SavedModel()
  binary_input = args.input.endswith('.pb')
  with GFile(args.input, 'rb' if binary_input else 'r') as fin:
    if binary_input:
      saved_model.ParseFromString(fin.read())
    else:
      text_format.Merge(fin.read(), saved_model)

  text_output = args.output.endswith('.pbtxt')
  with GFile(args.output, 'w' if text_output else 'wb') as fout:
    if text_output:
      fout.write(text_format.MessageToString(saved_model, as_utf8=True))
    else:
      fout.write(saved_model.SerializeToString())