easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,941 @@
1
+ # -*- encoding:utf-8 -*-
2
+ import logging
3
+ import os
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ from google.protobuf import text_format
8
+ from tensorflow.python.framework import ops
9
+ from tensorflow.python.platform.gfile import GFile
10
+ # from tensorflow.python.saved_model import constants
11
+ from tensorflow.python.saved_model import signature_constants
12
+ from tensorflow.python.saved_model.loader_impl import SavedModelLoader
13
+
14
+ from easy_rec.python.utils import conditional
15
+ from easy_rec.python.utils import constant
16
+ from easy_rec.python.utils import embedding_utils
17
+ from easy_rec.python.utils import proto_util
18
+
19
# NOTE(review): not referenced in the visible part of this module; presumably
# the name of a graph collection holding embedding initializers -- confirm.
EMBEDDING_INITIALIZERS = 'embedding_initializers'
20
+
21
+
22
+ class MetaGraphEditor:
23
+
24
+ def __init__(self,
25
+ lookup_lib_path,
26
+ saved_model_dir,
27
+ redis_url=None,
28
+ redis_passwd=None,
29
+ redis_timeout=0,
30
+ redis_cache_names=[],
31
+ oss_path=None,
32
+ oss_endpoint=None,
33
+ oss_ak=None,
34
+ oss_sk=None,
35
+ oss_timeout=0,
36
+ meta_graph_def=None,
37
+ norm_name_to_ids=None,
38
+ incr_update_params=None,
39
+ debug_dir=''):
40
+ self._lookup_op = tf.load_op_library(lookup_lib_path)
41
+ self._debug_dir = debug_dir
42
+ self._verbose = debug_dir != ''
43
+ if saved_model_dir:
44
+ tags = ['serve']
45
+ loader = SavedModelLoader(saved_model_dir)
46
+ saver, _ = loader.load_graph(tf.get_default_graph(), tags, None)
47
+ meta_graph_def = loader.get_meta_graph_def_from_tags(tags)
48
+ else:
49
+ assert meta_graph_def, 'either saved_model_dir or meta_graph_def must be set'
50
+ tf.reset_default_graph()
51
+ from tensorflow.python.framework import meta_graph
52
+ meta_graph.import_scoped_meta_graph_with_return_elements(
53
+ meta_graph_def, clear_devices=True)
54
+ # tf.train.import_meta_graph(meta_graph_def)
55
+ self._meta_graph_version = meta_graph_def.meta_info_def.meta_graph_version
56
+ self._signature_def = meta_graph_def.signature_def[
57
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
58
+
59
+ if self._verbose:
60
+ debug_out_path = os.path.join(self._debug_dir, 'meta_graph_raw.txt')
61
+ with GFile(debug_out_path, 'w') as fout:
62
+ fout.write(text_format.MessageToString(meta_graph_def, as_utf8=True))
63
+ self._meta_graph_def = meta_graph_def
64
+ self._old_node_num = len(self._meta_graph_def.graph_def.node)
65
+ self._all_graph_nodes = None
66
+ self._all_graph_node_flags = None
67
+ self._restore_tensor_node = None
68
+ self._restore_shard_node = None
69
+ self._restore_all_node = []
70
+ self._lookup_outs = None
71
+ self._feature_names = None
72
+ self._embed_names = None
73
+ self._embed_name_to_ids = norm_name_to_ids
74
+ self._is_cache_from_redis = []
75
+ self._redis_cache_names = redis_cache_names
76
+ self._embed_ids = None
77
+ self._embed_dims = None
78
+ self._embed_sizes = None
79
+ self._embed_combiners = None
80
+ self._redis_url = redis_url
81
+ self._redis_passwd = redis_passwd
82
+ self._redis_timeout = redis_timeout
83
+ self._oss_path = oss_path
84
+ self._oss_endpoint = oss_endpoint
85
+ self._oss_ak = oss_ak
86
+ self._oss_sk = oss_sk
87
+ self._oss_timeout = oss_timeout
88
+
89
+ self._incr_update_params = incr_update_params
90
+
91
+ # increment update placeholders
92
+ self._embedding_update_inputs = {}
93
+ self._embedding_update_outputs = {}
94
+
95
+ self._dense_update_inputs = {}
96
+ self._dense_update_outputs = {}
97
+
98
+ @property
99
+ def sparse_update_inputs(self):
100
+ return self._embedding_update_inputs
101
+
102
+ @property
103
+ def sparse_update_outputs(self):
104
+ return self._embedding_update_outputs
105
+
106
+ @property
107
+ def dense_update_inputs(self):
108
+ return self._dense_update_inputs
109
+
110
+ @property
111
+ def dense_update_outputs(self):
112
+ return self._dense_update_outputs
113
+
114
+ @property
115
+ def graph_def(self):
116
+ return self._meta_graph_def.graph_def
117
+
118
+ @property
119
+ def signature_def(self):
120
+ return self._signature_def
121
+
122
+ @property
123
+ def meta_graph_version(self):
124
+ return self._meta_graph_version
125
+
126
+ def init_graph_node_clear_flags(self):
127
+ graph_def = self._meta_graph_def.graph_def
128
+ self._all_graph_nodes = [n for n in graph_def.node]
129
+ self._all_graph_node_flags = [True for n in graph_def.node]
130
+
131
+ def _get_share_embed_name(self, x, embed_names):
132
+ """Map share embedding tensor names to embed names.
133
+
134
+ Args:
135
+ x: string, embedding tensor names, such as:
136
+ input_layer_1/shared_embed_1/field16_shared_embedding
137
+ input_layer_1/shared_embed_2/field17_shared_embedding
138
+ input_layer/shared_embed_wide/field15_shared_embedding
139
+ input_layer/shared_embed_wide_1/field16_shared_embedding
140
+ embed_names: all the optional embedding_names
141
+ Return:
142
+ one element in embed_names, such as:
143
+ input_layer_1/shared_embed
144
+ input_layer_1/shared_embed
145
+ input_layer/shared_embed_wide
146
+ input_layer/shared_embed_wide
147
+ """
148
+ assert x.endswith('_shared_embedding')
149
+ name_toks = x.split('/')
150
+ name_toks = name_toks[:-1]
151
+ tmp = name_toks[-1]
152
+ tmp = tmp.split('_')
153
+ try:
154
+ int(tmp[-1])
155
+ name_toks[-1] = '_'.join(tmp[:-1])
156
+ except Exception:
157
+ pass
158
+ tmp_name = '/'.join(name_toks[1:])
159
+ sel_embed_name = ''
160
+ for embed_name in embed_names:
161
+ tmp_toks = embed_name.split('/')
162
+ tmp_toks = tmp_toks[1:]
163
+ embed_name_sub = '/'.join(tmp_toks)
164
+ if tmp_name == embed_name_sub:
165
+ assert not sel_embed_name, 'confusions encountered: %s %s' % (
166
+ x, ','.join(embed_names))
167
+ sel_embed_name = embed_name
168
+ assert sel_embed_name, '%s not find in shared_embeddings: %s' % (
169
+ tmp_name, ','.join(embed_names))
170
+ return sel_embed_name
171
+
172
+ def _find_embed_combiners(self, norm_embed_names):
173
+ """Find embedding lookup combiner methods.
174
+
175
+ Args:
176
+ norm_embed_names: normalized embedding names
177
+ Return:
178
+ list: combiner methods for each features: sum, mean, sqrtn
179
+ """
180
+ embed_combiners = {}
181
+ embed_combine_node_cts = {}
182
+ combiner_map = {
183
+ 'SparseSegmentSum': 'sum',
184
+ 'SparseSegmentMean': 'mean',
185
+ 'SparseSegmentSqrtN': 'sqrtn'
186
+ }
187
+ for node in self._meta_graph_def.graph_def.node:
188
+ if node.op in combiner_map:
189
+ norm_name, _ = proto_util.get_norm_embed_name(node.name)
190
+ embed_combiners[norm_name] = combiner_map[node.op]
191
+ embed_combine_node_cts[norm_name] = embed_combine_node_cts.get(
192
+ norm_name, 0) + 1
193
+ elif node.op == 'RealDiv' and len(node.input) == 2:
194
+ # for tag feature with weights, and combiner == mean
195
+ if 'SegmentSum' in node.input[0] and 'SegmentSum' in node.input[1]:
196
+ norm_name, _ = proto_util.get_norm_embed_name(node.name)
197
+ embed_combiners[norm_name] = 'mean'
198
+ embed_combine_node_cts[norm_name] = embed_combine_node_cts.get(
199
+ norm_name, 0) + 1
200
+ elif node.op == 'SegmentSum':
201
+ norm_name, _ = proto_util.get_norm_embed_name(node.name)
202
+ # avoid overwrite RealDiv results
203
+ if norm_name not in embed_combiners:
204
+ embed_combiners[norm_name] = 'sum'
205
+ embed_combine_node_cts[norm_name] = embed_combine_node_cts.get(
206
+ norm_name, 0) + 1
207
+ return [embed_combiners[x] for x in norm_embed_names]
208
+
209
+ def _find_lookup_indices_values_shapes(self):
210
+ # use the specific _embedding_weights/SparseReshape to find out
211
+ # lookup inputs: indices, values, dense_shape, weights
212
+ indices = {}
213
+ values = {}
214
+ shapes = {}
215
+
216
+ def _get_output_shape(graph_def, input_name):
217
+ out_id = 0
218
+ if ':' in input_name:
219
+ node_name, out_id = input_name.split(':')
220
+ out_id = int(out_id)
221
+ else:
222
+ node_name = input_name
223
+ for node in graph_def.node:
224
+ if node.name == node_name:
225
+ return node.attr['_output_shapes'].list.shape[out_id]
226
+ return None
227
+
228
+ for node in self._meta_graph_def.graph_def.node:
229
+ if '_embedding_weights/SparseReshape' in node.name:
230
+ if node.op == 'SparseReshape':
231
+ # embed_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
232
+ fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
233
+ for tmp_input in node.input:
234
+ tmp_shape = _get_output_shape(self._meta_graph_def.graph_def,
235
+ tmp_input)
236
+ if '_embedding_weights/Cast' in tmp_input:
237
+ continue
238
+ elif len(tmp_shape.dim) == 2:
239
+ indices[fea_name] = tmp_input
240
+ elif len(tmp_shape.dim) == 1:
241
+ shapes[fea_name] = tmp_input
242
+ elif node.op == 'Identity':
243
+ fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
244
+ values[fea_name] = node.input[0]
245
+ return indices, values, shapes
246
+
247
+ def _find_lookup_weights(self):
248
+ weights = {}
249
+ for node in self._meta_graph_def.graph_def.node:
250
+ if '_weighted_by_' in node.name and 'GatherV2' in node.name:
251
+ has_sparse_reshape = False
252
+ for tmp_input in node.input:
253
+ if 'SparseReshape' in tmp_input:
254
+ has_sparse_reshape = True
255
+ if has_sparse_reshape:
256
+ continue
257
+ if len(node.input) != 3:
258
+ continue
259
+ # try to find nodes with weights
260
+ # input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/GatherV2_[0-9]
261
+ # which has three inputs:
262
+ # input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/Reshape_1
263
+ # DeserializeSparse_1 (this is the weight)
264
+ # input_layer/xxx_weighted_by_yyy_embedding/xxx_weighted_by_yyy_embedding_weights/GatherV2_4/axis
265
+ fea_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
266
+ for tmp_input in node.input:
267
+ if '_weighted_by_' not in tmp_input:
268
+ weights[fea_name] = tmp_input
269
+ return weights
270
+
271
+ def _find_embed_names_and_dims(self, norm_embed_names):
272
+ # get embedding dimensions from Variables
273
+ embed_dims = {}
274
+ embed_sizes = {}
275
+ embed_is_kv = {}
276
+ for node in self._meta_graph_def.graph_def.node:
277
+ if 'embedding_weights' in node.name and node.op in [
278
+ 'VariableV2', 'KvVarHandleOp'
279
+ ]:
280
+ tmp = node.attr['shape'].shape.dim[-1].size
281
+ tmp2 = 1
282
+ for x in node.attr['shape'].shape.dim[:-1]:
283
+ tmp2 = tmp2 * x.size
284
+ embed_name, _ = proto_util.get_norm_embed_name(node.name, self._verbose)
285
+ assert embed_name is not None,\
286
+ 'fail to get_norm_embed_name(%s)' % node.name
287
+ embed_dims[embed_name] = tmp
288
+ embed_sizes[embed_name] = tmp2
289
+ embed_is_kv[embed_name] = 1 if node.op == 'KvVarHandleOp' else 0
290
+
291
+ # get all embedding dimensions, note that some embeddings
292
+ # are shared by multiple inputs, so the names should be
293
+ # transformed
294
+ all_embed_dims = []
295
+ all_embed_names = []
296
+ all_embed_sizes = []
297
+ all_embed_is_kv = []
298
+ for x in norm_embed_names:
299
+ if x in embed_dims:
300
+ all_embed_names.append(x)
301
+ all_embed_dims.append(embed_dims[x])
302
+ all_embed_sizes.append(embed_sizes[x])
303
+ all_embed_is_kv.append(embed_is_kv[x])
304
+ elif x.endswith('_shared_embedding'):
305
+ tmp_embed_name = self._get_share_embed_name(x, embed_dims.keys())
306
+ all_embed_names.append(tmp_embed_name)
307
+ all_embed_dims.append(embed_dims[tmp_embed_name])
308
+ all_embed_sizes.append(embed_sizes[tmp_embed_name])
309
+ all_embed_is_kv.append(embed_is_kv[tmp_embed_name])
310
+ return all_embed_names, all_embed_dims, all_embed_sizes, all_embed_is_kv
311
+
312
def find_lookup_inputs(self):
  """Collect the input tensors of every embedding lookup in the default graph.

  Side effects: sets ``self._embed_combiners``, ``self._embed_names``,
  ``self._embed_dims``, ``self._embed_sizes``, ``self._embed_is_kv``,
  ``self._embed_ids``, ``self._is_cache_from_redis`` and
  ``self._feature_names``, and builds ``self._embed_name_to_ids`` when it is
  empty.

  Returns:
    (indices, values, shapes, weights): four lists with one entry per
    feature, in ``values.keys()`` order. A weights entry is ``[]`` when the
    feature has no weight tensor.
  """
  logging.info('Extract embedding_lookup inputs')

  indices, values, shapes = self._find_lookup_indices_values_shapes()
  weights = self._find_lookup_weights()

  for fea_name in shapes.keys():
    logging.info('Lookup Input[%s]: indices=%s values=%s shapes=%s' %
                 (fea_name, indices[fea_name], values[fea_name],
                  shapes[fea_name]))

  graph = tf.get_default_graph()

  def _tensor(name):
    # A bare op name refers to the op's first output.
    full_name = name if ':' in name else name + ':0'
    return graph.get_tensor_by_name(full_name)

  out_indices, out_values, out_shapes, out_weights = [], [], [], []
  for key in values.keys():
    val_name, ind_name, shape_name = values[key], indices[key], shapes[key]
    out_values.append(_tensor(val_name))
    out_indices.append(_tensor(ind_name))
    out_shapes.append(_tensor(shape_name))
    # Unweighted features get an empty-list placeholder.
    out_weights.append(_tensor(weights[key]) if key in weights else [])

  # per-feature embedding combiners
  self._embed_combiners = self._find_embed_combiners(values.keys())

  # per-feature embedding names / dimensions / sizes / kv flags
  (self._embed_names, self._embed_dims, self._embed_sizes,
   self._embed_is_kv) = self._find_embed_names_and_dims(values.keys())

  if not self._embed_name_to_ids:
    # NOTE(review): iteration order of a set of str depends on hash
    # randomization in Python 3, so the ids assigned here may differ between
    # processes; confirm nothing persisted relies on them being stable.
    unique_names = list(set(self._embed_names))
    self._embed_name_to_ids = dict(
        (name, pos) for pos, name in enumerate(unique_names))
  self._embed_ids = [
      int(self._embed_name_to_ids[name]) for name in self._embed_names
  ]

  self._is_cache_from_redis = [
      proto_util.is_cache_from_redis(name, self._redis_cache_names)
      for name in self._embed_names
  ]

  # normalized feature names
  self._feature_names = list(values.keys())

  return out_indices, out_values, out_shapes, out_weights
370
+
371
def add_lookup_op(self, lookup_input_indices, lookup_input_values,
                  lookup_input_shapes, lookup_input_weights):
  """Add one redis-backed ``kv_lookup`` op per feature and export the graph.

  int32 id tensors in ``lookup_input_values`` are replaced in place by int64
  casts, so the caller's list is mutated. The first output of the lookup for
  feature ``k`` is stored in ``self._lookup_outs[k]``.

  Returns:
    the MetaGraphDef exported from the default graph after the ops are added.
  """
  logging.info('add custom lookup operation to lookup embeddings from redis')
  num_features = len(lookup_input_values)
  self._lookup_outs = [None] * num_features

  for pos, id_tensor in enumerate(lookup_input_values):
    if id_tensor.dtype == tf.int32:
      lookup_input_values[pos] = tf.to_int64(id_tensor)

  for pos in range(num_features):
    one = slice(pos, pos + 1)
    # NOTE(review): unlike every other per-feature argument, `cache` receives
    # the whole self._is_cache_from_redis list rather than the [pos] slice;
    # confirm this matches what the kv_lookup op expects.
    self._lookup_outs[pos] = self._lookup_op.kv_lookup(
        lookup_input_indices[one],
        lookup_input_values[one],
        lookup_input_shapes[one],
        lookup_input_weights[one],
        url=self._redis_url,
        password=self._redis_passwd,
        timeout=self._redis_timeout,
        combiners=self._embed_combiners[one],
        embedding_dims=self._embed_dims[one],
        embedding_names=self._embed_ids[one],
        cache=self._is_cache_from_redis,
        version=self._meta_graph_version)[0]

  exported = tf.train.export_meta_graph()

  if self._verbose:
    # NOTE(review): this dumps self._meta_graph_def (whatever was stored before
    # this call), not `exported`; confirm that is intended for graph_raw.txt.
    raw_dump_path = os.path.join(self._debug_dir, 'graph_raw.txt')
    with GFile(raw_dump_path, 'w') as fout:
      raw_text = text_format.MessageToString(
          self._meta_graph_def.graph_def, as_utf8=True)
      fout.write(raw_text)
  return exported
403
+
404
def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
                      lookup_input_shapes, lookup_input_weights):
  """Add the fused OSS-backed embedding lookup ops and export the meta graph.

  Creates a single fused ``oss_read_kv`` op covering all features (stored in
  ``self._lookup_outs``) and an ``oss_init`` op that is added to the
  EMBEDDING_INITIALIZERS collection. When ``self._incr_update_params`` is set,
  it also adds a message placeholder plus an ``embedding_update`` op for the
  sparse side, and a placeholder + assign op for every variable in the
  DENSE_UPDATE_VARIABLES collection. The lookup, init and update ops all use
  ``shared_name='embedding_lookup_res'``.

  int32 id tensors in ``lookup_input_values`` are replaced in place by int64
  casts, so the caller's list is mutated.

  Args:
    lookup_input_indices: per-feature index tensors, as returned by
      find_lookup_inputs.
    lookup_input_values: per-feature id tensors, as returned by
      find_lookup_inputs.
    lookup_input_shapes: per-feature shape tensors, as returned by
      find_lookup_inputs.
    lookup_input_weights: per-feature weight tensors, or [] for features
      without weights.

  Returns:
    the MetaGraphDef exported from the default graph after the ops are added.
  """
  logging.info('add custom lookup operation to lookup embeddings from oss')
  # Optional switch read from the environment; its value is evaluated as a
  # Python expression (e.g. 'True' / 'False').
  # NOTE(review): eval() executes arbitrary code taken from the environment;
  # ast.literal_eval would accept the same literal values without that risk.
  place_on_cpu = os.getenv('place_embedding_on_cpu')
  place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
  # NOTE(review): confirm which ops are meant to sit inside this device scope;
  # ops sharing 'embedding_lookup_res' presumably need to be on one device.
  with conditional(place_on_cpu, ops.device('/CPU:0')):
    for i in range(len(lookup_input_values)):
      if lookup_input_values[i].dtype == tf.int32:
        lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
    # A per-feature variant (one oss_read_kv call per feature, each taking
    # [i:i+1] slices) used to live here; all features are now looked up by
    # the single fused op below.
    self._lookup_outs = self._lookup_op.oss_read_kv(
        lookup_input_indices,
        lookup_input_values,
        lookup_input_shapes,
        lookup_input_weights,
        osspath=self._oss_path,
        endpoint=self._oss_endpoint,
        ak=self._oss_ak,
        sk=self._oss_sk,
        timeout=self._oss_timeout,
        combiners=self._embed_combiners,
        embedding_dims=self._embed_dims,
        embedding_ids=self._embed_ids,
        embedding_is_kv=self._embed_is_kv,
        shared_name='embedding_lookup_res',
        name='embedding_lookup_fused/lookup')

    # oss_init receives one entry per embedding id in [0, N). Ids that no
    # feature uses keep the defaults: dim 0, combiner 'mean', is_kv 0.
    N = np.max([int(x) for x in self._embed_ids]) + 1
    uniq_embed_ids = [x for x in range(N)]
    uniq_embed_dims = [0 for x in range(N)]
    uniq_embed_combiners = ['mean' for x in range(N)]
    uniq_embed_is_kvs = [0 for x in range(N)]
    for embed_id, embed_combiner, embed_is_kv, embed_dim in zip(
        self._embed_ids, self._embed_combiners, self._embed_is_kv,
        self._embed_dims):
      uniq_embed_combiners[embed_id] = embed_combiner
      uniq_embed_is_kvs[embed_id] = embed_is_kv
      uniq_embed_dims[embed_id] = embed_dim

    lookup_init_op = self._lookup_op.oss_init(
        osspath=self._oss_path,
        endpoint=self._oss_endpoint,
        ak=self._oss_ak,
        sk=self._oss_sk,
        combiners=uniq_embed_combiners,
        embedding_dims=uniq_embed_dims,
        embedding_ids=uniq_embed_ids,
        embedding_is_kv=uniq_embed_is_kvs,
        N=N,
        shared_name='embedding_lookup_res',
        name='embedding_lookup_fused/init')

  ops.add_to_collection(EMBEDDING_INITIALIZERS, lookup_init_op)

  if self._incr_update_params is not None:
    # all sparse variables are updated by a single custom operation
    message_ph = tf.placeholder(tf.int8, [None], name='incr_update/message')
    embedding_update = self._lookup_op.embedding_update(
        message=message_ph,
        shared_name='embedding_lookup_res',
        name='embedding_lookup_fused/embedding_update')
    self._embedding_update_inputs['incr_update/sparse/message'] = message_ph
    self._embedding_update_outputs[
        'incr_update/sparse/embedding_update'] = embedding_update

    # dense variables are updated one by one: a float32 placeholder shaped
    # like the variable feeds a tf.assign, both keyed by the variable's id
    dense_name_to_ids = embedding_utils.get_dense_name_to_ids()
    for x in ops.get_collection(constant.DENSE_UPDATE_VARIABLES):
      dense_var_id = dense_name_to_ids[x.op.name]
      dense_input_name = 'incr_update/dense/%d/input' % dense_var_id
      dense_output_name = 'incr_update/dense/%d/output' % dense_var_id
      dense_update_input = tf.placeholder(
          tf.float32, x.get_shape(), name=dense_input_name)
      self._dense_update_inputs[dense_input_name] = dense_update_input
      dense_assign_op = tf.assign(x, dense_update_input)
      self._dense_update_outputs[dense_output_name] = dense_assign_op

  meta_graph_def = tf.train.export_meta_graph()

  if self._verbose:
    # NOTE(review): this dumps self._meta_graph_def (the value stored before
    # this call), not the freshly exported meta_graph_def; confirm intended.
    debug_path = os.path.join(self._debug_dir, 'graph_raw.txt')
    with GFile(debug_path, 'w') as fout:
      fout.write(
          text_format.MessageToString(
              self._meta_graph_def.graph_def, as_utf8=True))
  return meta_graph_def
509
+
510
def bytes2str(self, x):
  """Return ``x`` as text.

  On Python 2, where ``bytes`` is ``str``, ``x`` is returned unchanged.
  Otherwise ``x`` is decoded as UTF-8. If that fails (``x`` is already a
  ``str``, or holds bytes that are not valid UTF-8, which can happen with
  special characters in protobuf fields), ``str(x)`` is returned instead.
  """
  if bytes != str:
    try:
      text = x.decode('utf-8')
    except Exception:
      text = str(x)
    return text
  return x
519
+
520
def clear_meta_graph_embeding(self, meta_graph_def):
  """Strip embedding_weights variables and Kv op definitions from the metadata.

  - Removes every entry whose text contains ``embedding_weights`` from the
    ``model_variables``, ``trainable_variables`` and ``variables`` bytes_list
    collections. Collections that are absent are skipped.
  - Removes the PAI embedding-variable (Kv*) OpDefs from
    ``meta_info_def.stripped_op_list``.
  - Removes the ``has_ev`` attr definition from the SaveV2 OpDef.

  Args:
    meta_graph_def: the MetaGraphDef to edit in place.
  """
  logging.info('clear meta graph embedding_weights')

  collections = meta_graph_def.collection_def

  def _clear_embedding_in_meta_collect(collect_name):
    # Indexing a protobuf message map inserts a default entry for a missing
    # key, so check membership first to avoid adding empty collections.
    if collect_name not in collections:
      return
    bytes_list = collections[collect_name].bytes_list
    kept_vals = [
        x for x in bytes_list.value
        if 'embedding_weights' not in self.bytes2str(x)
    ]
    bytes_list.ClearField('value')
    bytes_list.value.extend(kept_vals)

  for collect_name in ('model_variables', 'trainable_variables', 'variables'):
    _clear_embedding_in_meta_collect(collect_name)

  # clear Kv (pai embedding variable) ops in meta_info_def.stripped_op_list.op
  kv_op_names = ('InitializeKvVariableOp', 'KvResourceGather',
                 'KvResourceImportV2', 'KvVarHandleOp', 'KvVarIsInitializedOp',
                 'ReadKvVariableOp')
  op_list = meta_graph_def.meta_info_def.stripped_op_list
  kept_ops = [x for x in op_list.op if x.name not in kv_op_names]
  op_list.ClearField('op')
  op_list.op.extend(kept_ops)

  # SaveV2 may carry a 'has_ev' attr definition (present when the graph was
  # built with embedding variables); drop it.
  for op_def in op_list.op:
    if op_def.name != 'SaveV2':
      continue
    for attr_def in op_def.attr:
      if attr_def.name == 'has_ev':
        op_def.attr.remove(attr_def)
        break
554
+
555
def clear_meta_collect(self, meta_graph_def):
  """Remove collections that still reference dropped embedding_weights nodes.

  A node_list collection is removed when its first entry contains
  ``embedding_weights`` but not ``easy_rec``. Empty node_lists are kept.
  A non-node_list collection named ``saved_model_assets`` is removed as well.

  Args:
    meta_graph_def: the MetaGraphDef whose collection_def is edited in place.
  """
  drop_meta_collects = []
  for key in meta_graph_def.collection_def:
    val = meta_graph_def.collection_def[key]
    if val.HasField('node_list'):
      node_names = val.node_list.value
      # Guard against an empty node_list, which would make [0] raise.
      if (len(node_names) > 0 and 'embedding_weights' in node_names[0] and
          'easy_rec' not in node_names[0]):
        drop_meta_collects.append(key)
    elif key == 'saved_model_assets':
      drop_meta_collects.append(key)
  # Remove after iterating so the map is not mutated during the loop.
  for key in drop_meta_collects:
    meta_graph_def.collection_def.pop(key)
567
+
568
def remove_embedding_weights_and_update_lookup_outputs(self):
  """Drop ``*_embedding_weights`` nodes and re-wire their consumers.

  Every node whose name contains ``_embedding_weights`` is marked for removal
  in ``self._all_graph_node_flags``. For every other node, each input that
  references such a node is normalized to a feature name with
  ``proto_util.get_norm_embed_name`` and replaced by the name of that
  feature's output in ``self._lookup_outs``. A trailing ``:0`` is stripped,
  since output 0 may be written without an index.
  """

  def _should_drop(name):
    if '_embedding_weights' not in name:
      return False
    if self._verbose:
      logging.info('[SHOULD_DROP] %s' % name)
    return True

  logging.info('remove embedding_weights node in graph_def.node')
  logging.info(
      'and replace the old embedding_lookup outputs with new lookup_op outputs'
  )

  for tid, node in enumerate(self._all_graph_nodes):
    # drop the nodes
    if _should_drop(node.name):
      self._all_graph_node_flags[tid] = False
      continue
    for i in range(len(node.input)):
      old_input = node.input[i]
      if not _should_drop(old_input):
        continue
      feature_name, _ = proto_util.get_norm_embed_name(old_input,
                                                       self._verbose)
      logging.info('REPLACE:' + old_input + '=>' + feature_name)
      new_input = self._lookup_outs[self._feature_names.index(
          feature_name)].name
      if new_input.endswith(':0'):
        new_input = new_input[:-len(':0')]
      node.input[i] = new_input
596
+
597
def _drop_by_ids(self, tmp_obj, key, drop_ids):
  """Remove the elements at positions ``drop_ids`` from field ``key``.

  ``getattr(tmp_obj, key)`` is expected to be a repeated field of
  ``tmp_obj`` (anything supporting iteration, ``ClearField(key)`` on the
  owner, and ``extend``). The relative order of the kept elements is
  preserved. Positions outside the field are ignored.

  Args:
    tmp_obj: the message that owns the repeated field.
    key: name of the repeated field.
    drop_ids: iterable of positions to remove.
  """
  # A set keeps the membership test O(1) per element.
  drop_set = set(drop_ids)
  keep_vals = [
      x for i, x in enumerate(getattr(tmp_obj, key)) if i not in drop_set
  ]
  tmp_obj.ClearField(key)
  getattr(tmp_obj, key).extend(keep_vals)
604
+
605
def clear_save_restore(self):
  """Drop embedding_weights entries from the save/RestoreV2 ops.

  The restore path is wired as follows:
    save/restore_all needs save/restore_shard as input;
    save/restore_shard needs save/Assign_[0-N] as input;
    save/Assign_[0-N] needs save/RestoreV2 as input;
    save/RestoreV2 uses save/RestoreV2/tensor_names and
    save/RestoreV2/shape_and_slices as input.

  This edits save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices
  and save/RestoreV2. It also records save/restore_shard in
  ``self._restore_shard_node`` and appends save/restore_all* nodes to
  ``self._restore_all_node`` for clear_save_assign. Nodes already marked for
  removal in ``self._all_graph_node_flags`` are ignored.
  """
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if node.name == 'save/RestoreV2/tensor_names':
      self._restore_tensor_node = node
      break

  # Defaults for graphs without save/RestoreV2/tensor_names: nothing is
  # dropped and save/RestoreV2/shape_and_slices is left untouched.
  drop_ids = []
  keep_node_num = None
  if self._restore_tensor_node:
    for tmp_id, tmp_name in enumerate(
        self._restore_tensor_node.attr['value'].tensor.string_val):
      if 'embedding_weights' in self.bytes2str(tmp_name):
        drop_ids.append(tmp_id)

    self._drop_by_ids(self._restore_tensor_node.attr['value'].tensor,
                      'string_val', drop_ids)
    keep_node_num = len(
        self._restore_tensor_node.attr['value'].tensor.string_val)
    logging.info(
        'update self._restore_tensor_node: string_val keep_num = %d drop_num = %d'
        % (keep_node_num, len(drop_ids)))
    self._restore_tensor_node.attr['value'].tensor.tensor_shape.dim[
        0].size = keep_node_num
    self._restore_tensor_node.attr['_output_shapes'].list.shape[0].dim[
        0].size = keep_node_num

  logging.info(
      'update save/RestoreV2, drop tensor_shapes, _output_shapes, related to embedding_weights'
  )
  self._restore_shard_node = None
  for node_id, node in enumerate(self._all_graph_nodes):
    # Skip nodes already marked for removal (indexed by this loop's counter).
    if not self._all_graph_node_flags[node_id]:
      continue
    if node.name == 'save/RestoreV2/shape_and_slices':
      if keep_node_num is not None:
        node.attr['value'].tensor.tensor_shape.dim[0].size = keep_node_num
        node.attr['_output_shapes'].list.shape[0].dim[0].size = keep_node_num
        self._drop_by_ids(node.attr['value'].tensor, 'string_val', drop_ids)
    elif node.name == 'save/RestoreV2':
      self._drop_by_ids(node.attr['_output_shapes'].list, 'shape', drop_ids)
      self._drop_by_ids(node.attr['dtypes'].list, 'type', drop_ids)
    elif node.name == 'save/restore_shard':
      self._restore_shard_node = node
    elif node.name.startswith('save/restore_all'):
      self._restore_all_node.append(node)
659
+
660
def clear_save_assign(self):
  """Prune embedding-related restore nodes and re-wire surviving save/Assign ops.

  The following nodes are marked for removal in
  ``self._all_graph_node_flags``:
  - save/Assign ops that restore an embedding_weights variable;
  - embedding_weights ConcatPartitions/concat nodes;
  - ``*/embedding_weights`` Identity nodes;
  - KvResourceImportV2 nodes;
  - save/Const nodes whose ``_class`` names an embedding_weights variable;
  - ReadKvVariableOp nodes.

  Each remaining save/Assign gets ``input[1]`` pointed at the save/RestoreV2
  output whose position matches its target in save/RestoreV2/tensor_names.
  The control inputs on the dropped assigns/imports are then removed from
  save/restore_shard or, when there is no shard node, from the
  save/restore_all* nodes.

  Raises:
    ValueError: if a remaining save/Assign targets a name that is not listed
      in save/RestoreV2/tensor_names.
  """
  logging.info(
      'update save/Assign, drop tensor_shapes, _output_shapes, related to embedding_weights'
  )
  # edit save/Assign
  drop_save_assigns = []
  # restore tensor name -> first position in save/RestoreV2/tensor_names.
  # Built on first use so graphs without surviving save/Assign ops never
  # touch self._restore_tensor_node.
  restore_name_to_pos = None
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if (node.op == 'Assign' and 'save/Assign' in node.name and
        'embedding_weights' in node.input[0]):
      drop_save_assigns.append('^' + node.name)
      self._all_graph_node_flags[tid] = False
    elif 'embedding_weights/ConcatPartitions/concat' in node.name:
      self._all_graph_node_flags[tid] = False
    elif node.name.endswith('/embedding_weights') and node.op == 'Identity':
      self._all_graph_node_flags[tid] = False
    elif ('save/KvResourceImportV2' in node.name and
          node.op == 'KvResourceImportV2'):
      drop_save_assigns.append('^' + node.name)
      self._all_graph_node_flags[tid] = False
    elif 'KvResourceImportV2' in node.name:
      self._all_graph_node_flags[tid] = False
    elif 'save/Const' in node.name and node.op == 'Const':
      if '_class' in node.attr and len(node.attr['_class'].list.s) > 0:
        const_name = node.attr['_class'].list.s[0]
        if not isinstance(const_name, str):
          const_name = const_name.decode('utf-8')
        if 'embedding_weights' in const_name:
          self._all_graph_node_flags[tid] = False
    elif 'ReadKvVariableOp' in node.name and node.op == 'ReadKvVariableOp':
      self._all_graph_node_flags[tid] = False
    elif node.op == 'Assign' and 'save/Assign' in node.name:
      # The outputs of save/RestoreV2 are connected to save/Assign by
      # position: point input[1] at the output whose index is the position
      # of node.input[0] in save/RestoreV2/tensor_names.
      if restore_name_to_pos is None:
        restore_name_to_pos = {}
        tensor_names = (
            self._restore_tensor_node.attr['value'].tensor.string_val)
        for pos, raw_name in enumerate(tensor_names):
          restore_name_to_pos.setdefault(self.bytes2str(raw_name), pos)
      if node.input[0] not in restore_name_to_pos:
        raise ValueError('%s is not in save/RestoreV2/tensor_names' %
                         node.input[0])
      tmp_id = restore_name_to_pos[node.input[0]]
      if tmp_id != 0:
        tmp_input2 = 'save/RestoreV2:%d' % tmp_id
      else:
        tmp_input2 = 'save/RestoreV2'
      if tmp_input2 != node.input[1]:
        if self._verbose:
          logging.info("update save/Assign[%s]'s input from %s to %s" %
                       (node.name, node.input[1], tmp_input2))
        node.input[1] = tmp_input2

  # save/restore_all need save/restore_shard as input
  # save/restore_shard needs save/Assign_[0-N] as input
  # save/Assign_[0-N] needs save/RestoreV2 as input
  if self._restore_shard_node:
    for tmp_input in drop_save_assigns:
      self._restore_shard_node.input.remove(tmp_input)
      if self._verbose:
        logging.info('drop restore_shard input: %s' % tmp_input)
  elif len(self._restore_all_node) > 0:
    for tmp_input in drop_save_assigns:
      for tmp_node in self._restore_all_node:
        if tmp_input in tmp_node.input:
          tmp_node.input.remove(tmp_input)
          if self._verbose:
            logging.info('drop %s input: %s' % (tmp_node.name, tmp_input))
          break
727
+
728
def clear_save_v2(self):
  """Drop embedding_weights entries from the save/SaveV2 op and its inputs.

  The save path is wired as follows:
    save/Identity needs [save/MergeV2Checkpoints, save/control_dependency];
    save/MergeV2Checkpoints needs save/MergeV2Checkpoints/checkpoint_prefixes;
    save/MergeV2Checkpoints/checkpoint_prefixes needs
    [save/ShardedFilename, save/control_dependency];
    save/control_dependency needs save/SaveV2;
    save/SaveV2 takes save/SaveV2/tensor_names and
    save/SaveV2/shape_and_slices as inputs.

  This edits save/SaveV2 (inputs, dtypes, has_ev),
  save/SaveV2/shape_and_slices and save/SaveV2/tensor_names. A warning is
  logged when the entries dropped from tensor_names and from the SaveV2
  inputs do not line up.
  """
  logging.info('update save/SaveV2 input shape, _output_shapes, tensor_shape')
  save_drop_ids = []
  for tid, node in enumerate(self._all_graph_nodes):
    if not self._all_graph_node_flags[tid]:
      continue
    if node.name == 'save/SaveV2' and node.op == 'SaveV2':
      for tmp_id, tmp_input in enumerate(node.input):
        if '/embedding_weights' in tmp_input:
          save_drop_ids.append(tmp_id)
      # The leading inputs have no dtypes entry, so shift the ids to index
      # the saved tensors. This assumes node.input carries no trailing
      # control inputs; TODO confirm.
      diff_num = len(node.input) - len(node.attr['dtypes'].list.type)
      self._drop_by_ids(node, 'input', save_drop_ids)
      save_drop_ids = [x - diff_num for x in save_drop_ids]
      self._drop_by_ids(node.attr['dtypes'].list, 'type', save_drop_ids)
      if 'has_ev' in node.attr:
        del node.attr['has_ev']
  for node in self._all_graph_nodes:
    if node.name == 'save/SaveV2/shape_and_slices' and node.op == 'Const':
      # _output_shapes # size # string_val
      node.attr['_output_shapes'].list.shape[0].dim[0].size -= len(
          save_drop_ids)
      node.attr['value'].tensor.tensor_shape.dim[0].size -= len(save_drop_ids)
      self._drop_by_ids(node.attr['value'].tensor, 'string_val',
                        save_drop_ids)
    elif node.name == 'save/SaveV2/tensor_names':
      # tensor_names may not have the same order as save/SaveV2/shape_and_slices
      tmp_drop_ids = [
          tmp_id for tmp_id, tmp_val in enumerate(
              node.attr['value'].tensor.string_val)
          if 'embedding_weights' in self.bytes2str(tmp_val)
      ]
      # Both sides should drop the same number of entries; otherwise
      # tensor_names and the SaveV2 tensors no longer match up.
      if len(tmp_drop_ids) != len(save_drop_ids):
        logging.warning(
            'save/SaveV2/tensor_names drops %d entries but save/SaveV2 drops %d inputs'
            % (len(tmp_drop_ids), len(save_drop_ids)))
      # attr['value'].tensor.string_val # tensor_shape # size
      node.attr['_output_shapes'].list.shape[0].dim[0].size -= len(
          tmp_drop_ids)
      node.attr['value'].tensor.tensor_shape.dim[0].size -= len(tmp_drop_ids)
      self._drop_by_ids(node.attr['value'].tensor, 'string_val', tmp_drop_ids)
774
+
775
def clear_initialize(self):
  """Mark the initialization nodes of embedding_weights variables for removal.

  Dependencies involved:
    */read (Identity) depends on [* (VariableV2)];
    */Assign depends on [*/Initializer/*, * (VariableV2)].

  Among nodes whose name contains ``embedding_weights``, this drops:
  - names containing ``Initializer`` or ``Assign``;
  - ``VariableV2`` ops;
  - ``*/read`` Identity ops;
  - Identity ops whose last path token is an integer, optionally after an
    ``embedding_weights_`` prefix.

  Examples: */embedding_weights/part_x [,/Assign,/read] and
  */embedding_weights/part_1/Initializer/truncated_normal
  [,/shape,/mean,/stddev,/TruncatedNormal,/mul].
  """
  logging.info('Remove Initialization nodes for embedding_weights')
  numbered_prefix = 'embedding_weights_'
  flags = self._all_graph_node_flags
  for pos, node in enumerate(self._all_graph_nodes):
    if not flags[pos] or 'embedding_weights' not in node.name:
      continue
    name, op_type = node.name, node.op
    if 'Initializer' in name or 'Assign' in name or op_type == 'VariableV2':
      flags[pos] = False
      continue
    if op_type != 'Identity':
      continue
    if name.endswith('/read'):
      flags[pos] = False
      continue
    last_token = name.split('/')[-1]
    if numbered_prefix in last_token:
      last_token = last_token[len(numbered_prefix):]
    try:
      int(last_token)
    except ValueError:
      continue
    flags[pos] = False
807
+
808
def clear_embedding_variable(self):
  """Mark PAI embedding-variable ops (ReadKvVariableOp, KvVarIsInitializedOp,
  KvVarHandleOp) for removal in ``self._all_graph_node_flags``."""
  kv_op_types = ('ReadKvVariableOp', 'KvVarIsInitializedOp', 'KvVarHandleOp')
  flags = self._all_graph_node_flags
  for pos, node in enumerate(self._all_graph_nodes):
    if flags[pos] and node.op in kv_op_types:
      flags[pos] = False
817
+
818
def drop_dependent_nodes(self):
  """Transitively drop nodes whose first input is a dropped node.

  Starting from the nodes already marked False in
  ``self._all_graph_node_flags``, repeatedly marks every still-kept node whose
  ``input[0]`` names a node dropped in the previous round, until a round drops
  nothing. Only the first input is inspected.
  """
  flags = self._all_graph_node_flags
  # A set makes each membership test O(1) instead of a list scan.
  drop_names = set(
      tmp_node.name
      for tid, tmp_node in enumerate(self._all_graph_nodes)
      if not flags[tid])
  while drop_names:
    more_drop_names = set()
    for tid, tmp_node in enumerate(self._all_graph_nodes):
      if not flags[tid]:
        continue
      if len(tmp_node.input) > 0 and tmp_node.input[0] in drop_names:
        logging.info('drop dependent node: %s depend on %s' %
                     (tmp_node.name, tmp_node.input[0]))
        flags[tid] = False
        more_drop_names.add(tmp_node.name)
    drop_names = more_drop_names
838
+
839
def edit_graph(self):
  """Main entry: replace embedding_weights lookups with redis kv_lookup ops.

  Finds the lookup inputs, adds the redis lookup ops (which re-exports
  ``self._meta_graph_def``), then strips the original embedding_weights
  variables and their save/restore/initializer nodes.
  """
  (lookup_input_indices, lookup_input_values, lookup_input_shapes,
   lookup_input_weights) = self.find_lookup_inputs()

  # add lookup op to the graph
  self._meta_graph_def = self.add_lookup_op(lookup_input_indices,
                                            lookup_input_values,
                                            lookup_input_shapes,
                                            lookup_input_weights)
  self._strip_embedding_weights_from_meta_graph()


def edit_graph_for_oss(self):
  """Main entry: replace embedding_weights lookups with fused OSS lookup ops.

  Same as edit_graph, but adds the OSS lookup/init (and optional incremental
  update) ops instead of the redis ones.
  """
  (lookup_input_indices, lookup_input_values, lookup_input_shapes,
   lookup_input_weights) = self.find_lookup_inputs()

  # add lookup op to the graph
  self._meta_graph_def = self.add_oss_lookup_op(lookup_input_indices,
                                                lookup_input_values,
                                                lookup_input_shapes,
                                                lookup_input_weights)
  self._strip_embedding_weights_from_meta_graph()


def _strip_embedding_weights_from_meta_graph(self):
  """Shared tail of edit_graph / edit_graph_for_oss.

  Cleans the meta graph collections and op list, rewires lookup consumers to
  the new lookup outputs, prunes save/restore/initializer/Kv nodes, rebuilds
  ``graph_def.node`` from the surviving nodes, and (when verbose) dumps
  graph.txt and meta_graph.txt into the debug directory.
  """
  self.clear_meta_graph_embeding(self._meta_graph_def)

  self.clear_meta_collect(self._meta_graph_def)

  self.init_graph_node_clear_flags()

  self.remove_embedding_weights_and_update_lookup_outputs()

  # save/RestoreV2
  self.clear_save_restore()

  # save/Assign
  self.clear_save_assign()

  # save/SaveV2
  self.clear_save_v2()

  self.clear_initialize()

  self.clear_embedding_variable()

  self.drop_dependent_nodes()

  self._meta_graph_def.graph_def.ClearField('node')
  self._meta_graph_def.graph_def.node.extend([
      x for tid, x in enumerate(self._all_graph_nodes)
      if self._all_graph_node_flags[tid]
  ])

  logging.info('old node number = %d' % self._old_node_num)
  logging.info('node number = %d' % len(self._meta_graph_def.graph_def.node))

  if self._verbose:
    debug_dump_path = os.path.join(self._debug_dir, 'graph.txt')
    with GFile(debug_dump_path, 'w') as fout:
      fout.write(text_format.MessageToString(self.graph_def, as_utf8=True))
    debug_dump_path = os.path.join(self._debug_dir, 'meta_graph.txt')
    with GFile(debug_dump_path, 'w') as fout:
      fout.write(
          text_format.MessageToString(self._meta_graph_def, as_utf8=True))