easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the package's registry page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,134 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import logging
5
+ import os
6
+ import sys
7
+
8
+ import tensorflow as tf
9
+ from tensorflow.core.protobuf import saved_model_pb2
10
+ from tensorflow.python.lib.io.file_io import file_exists
11
+ from tensorflow.python.lib.io.file_io import recursive_create_dir
12
+ from tensorflow.python.platform.gfile import GFile
13
+
14
+ import easy_rec
15
+ from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor
16
+
17
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

if __name__ == '__main__':
  # Replace the default embedding_lookup ops with self-defined embedding
  # lookup ops.
  #
  # The embedding data is stored in redis; lookups retrieve embedding vectors
  # by the key {version}_{embed_name}_{embed_id}.
  # Example:
  #   python -m easy_rec.python.tools.edit_lookup_graph
  #     --saved_model_dir rtp_large_embedding_export/1604304644
  #     --output_dir ./after_edit_save
  #     --test_data_path data/test/rtp/xys_cxr_fg_sample_test2_with_lbl.txt

  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--saved_model_dir', type=str, default=None, help='saved model dir')
  parser.add_argument('--output_dir', type=str, default=None, help='output dir')
  parser.add_argument(
      '--redis_url', type=str, default='127.0.0.1:6379', help='redis url')
  parser.add_argument(
      '--redis_passwd', type=str, default='', help='redis password')
  parser.add_argument('--time_out', type=int, default=1500, help='timeout')
  parser.add_argument(
      '--test_data_path', type=str, default='', help='test data path')
  # NOTE(review): the separator in the published source was rendered as an
  # empty string (split('') always raises ValueError). '\x02' is assumed to
  # be the RTP field separator -- confirm against the test data format.
  parser.add_argument(
      '--test_data_separator',
      type=str,
      default='\x02',
      help='field separator of the test data; the first field is dropped '
      'and the rest is fed to the "features" input')
  parser.add_argument('--verbose', action='store_true', default=False)

  args = parser.parse_args()
  logging.info('saved_model_dir: %s' % args.saved_model_dir)

  if not args.saved_model_dir or not args.output_dir:
    logging.error('both --saved_model_dir and --output_dir are required')
    sys.exit(1)

  # gfile's file_exists also handles non-local paths, consistent with the
  # output_dir handling below.
  if not file_exists(os.path.join(args.saved_model_dir, 'saved_model.pb')):
    logging.error('saved_model.pb does not exist in %s' % args.saved_model_dir)
    sys.exit(1)

  logging.info('output_dir: %s' % args.output_dir)
  logging.info('redis_url: %s' % args.redis_url)
  lookup_lib_path = os.path.join(easy_rec.ops_dir, 'libkv_lookup.so')
  logging.info('lookup_lib_path: %s' % lookup_lib_path)

  if not file_exists(args.output_dir):
    recursive_create_dir(args.output_dir)

  meta_graph_editor = MetaGraphEditor(
      lookup_lib_path,
      args.saved_model_dir,
      args.redis_url,
      args.redis_passwd,
      args.time_out,
      meta_graph_def=None,
      debug_dir=args.output_dir if args.verbose else '')
  meta_graph_editor.edit_graph()

  # When the graph carries no version, fall back to the last non-empty path
  # component of saved_model_dir (typically the export timestamp).
  meta_graph_version = meta_graph_editor.meta_graph_version
  if meta_graph_version == '':
    meta_graph_version = [x for x in args.saved_model_dir.split('/') if x][-1]

  # Import the edited graph into a fresh default graph.
  tf.reset_default_graph()
  saver = tf.train.import_meta_graph(meta_graph_editor._meta_graph_def)

  # Dump the editor's embedding name -> ids table and register the file in
  # ASSET_FILEPATHS, which simple_save below exports as a SavedModel asset.
  embed_name_to_id_file = os.path.join(args.output_dir, 'embed_name_to_ids.txt')
  with GFile(embed_name_to_id_file, 'w') as fout:
    for tmp_norm_name in meta_graph_editor._embed_name_to_ids:
      fout.write(
          '%s\t%s\n' %
          (tmp_norm_name, meta_graph_editor._embed_name_to_ids[tmp_norm_name]))
  tf.add_to_collection(
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.constant(
          embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))

  # Resolve the serving signature's tensors in the imported graph.
  graph = tf.get_default_graph()
  inputs_map = {}
  for name, tensor in meta_graph_editor.signature_def.inputs.items():
    logging.info('model inputs: %s => %s' % (name, tensor.name))
    inputs_map[name] = graph.get_tensor_by_name(tensor.name)

  outputs_map = {}
  for name, tensor in meta_graph_editor.signature_def.outputs.items():
    logging.info('model outputs: %s => %s' % (name, tensor.name))
    outputs_map[name] = graph.get_tensor_by_name(tensor.name)

  with tf.Session() as sess:
    saver.restore(sess,
                  os.path.join(args.saved_model_dir, 'variables', 'variables'))
    output_dir = os.path.join(args.output_dir, meta_graph_version)
    tf.saved_model.simple_save(
        sess, output_dir, inputs=inputs_map, outputs=outputs_map)

    # The meta_graph_version cannot be passed via the existing interfaces,
    # so it is patched into the serialized SavedModel proto directly.
    saved_model_path = os.path.join(output_dir, 'saved_model.pb')
    saved_model = saved_model_pb2.SavedModel()
    with GFile(saved_model_path, 'rb') as fin:
      saved_model.ParseFromString(fin.read())
    # Use the resolved version (including the directory-name fallback) so the
    # stored version always matches the output directory name.
    saved_model.meta_graphs[0].meta_info_def.meta_graph_version = (
        meta_graph_version)
    with GFile(saved_model_path, 'wb') as fout:
      fout.write(saved_model.SerializeToString())

    logging.info('save output to %s' % output_dir)
    if args.test_data_path:
      sep = args.test_data_separator
      feature_vals = []
      with GFile(args.test_data_path, 'r') as fin:
        for line_str in fin:
          # Drop the leading field (the label); the remainder after the
          # first separator is the feature string. This equals
          # sep.join(line.split(sep)[1:]).
          feature_vals.append(line_str.strip().partition(sep)[2])
          if len(feature_vals) >= 32:
            break
      out_vals = sess.run(
          outputs_map, feed_dict={inputs_map['features']: feature_vals})
      logging.info('test_data probs:' + str(out_vals))
@@ -0,0 +1,116 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from __future__ import print_function
4
+
5
+ import logging
6
+ import os
7
+ import sys
8
+
9
+ import faiss
10
+ import numpy as np
11
+ import tensorflow as tf
12
+
13
+ from easy_rec.python.utils import io_util
14
+
15
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] %(message)s', level=logging.INFO)

# All flags are registered through a short alias of tf.app.flags.
_flags = tf.app.flags
# Input table and read batching.
_flags.DEFINE_string('tables', '', 'tables passed by pai command')
_flags.DEFINE_integer('batch_size', 1024, 'batch size')
_flags.DEFINE_integer('embedding_dim', 32, 'embedding dimension')
# Output location and index type (main() handles 'IVFFlat' and 'HNSWFlat').
_flags.DEFINE_string('index_output_dir', '', 'index output directory')
_flags.DEFINE_string('index_type', 'IVFFlat', 'index type')
# Index-specific construction parameters.
_flags.DEFINE_integer('ivf_nlist', 1000, 'nlist')
_flags.DEFINE_integer('hnsw_M', 32, 'hnsw M')
_flags.DEFINE_integer('hnsw_efConstruction', 200, 'hnsw efConstruction')
# Non-zero additionally builds indices over a grid of parameters.
_flags.DEFINE_integer('debug', 0, 'debug index')

FLAGS = _flags.FLAGS
29
+
30
+
31
def _to_text(value):
  """Return `value` decoded as UTF-8 if it is bytes, otherwise unchanged."""
  if isinstance(value, bytes):
    return value.decode('utf-8')
  return value


def _write_index(index, index_name):
  """Write faiss `index` to FLAGS.index_output_dir/<index_name>.

  The index is first serialized to a local file named `index_name` in the
  current working directory, then copied through tf.gfile so that the output
  directory may be any path tf.gfile supports.
  """
  faiss.write_index(index, index_name)
  with tf.gfile.GFile(
      os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
    with open(index_name, 'rb') as f_in:
      f_out.write(f_in.read())


def main(argv):
  """Read (id, embedding) rows from FLAGS.tables and build a faiss index.

  Column 0 of each record is the item id; column 1 is a comma-separated list
  of floats. Ids are written one per line to <index_output_dir>/id_mapping in
  the same order their vectors are added to the index. The inner-product
  index (FLAGS.index_type: IVFFlat or HNSWFlat) is written to
  <index_output_dir>/faiss_index. When FLAGS.debug is non-zero, extra indices
  are built over a grid of IVF nlist / HNSW (M, efConstruction) values.

  Raises:
    ValueError: if no embeddings were read or their width differs from
      FLAGS.embedding_dim.
    NotImplementedError: if FLAGS.index_type is not IVFFlat or HNSWFlat.
  """
  del argv  # Unused.
  reader = tf.python_io.TableReader(
      FLAGS.tables, slice_id=0, slice_count=1, capacity=FLAGS.batch_size * 2)
  embeddings = []
  num_batches = 0
  try:
    # `with` closes id_mapping even if reading or parsing fails.
    with tf.gfile.GFile(
        os.path.join(FLAGS.index_output_dir, 'id_mapping'), 'w') as id_map_f:
      while True:
        try:
          records = reader.read(FLAGS.batch_size)
        except tf.python_io.OutOfRangeException:
          break
        for record in records:
          # Every record gets its own id line, whether the column arrives as
          # bytes or not, so id_mapping stays aligned with the vectors.
          id_map_f.write('%s\n' % _to_text(record[0]))
          embeddings.append(
              [float(x) for x in _to_text(record[1]).split(',')])
        num_batches += 1
        if num_batches % 100 == 0:
          logging.info('read %d embeddings.' % len(embeddings))
  finally:
    reader.close()

  # faiss operates on float32 [N, dim] matrices.
  embeddings = np.asarray(embeddings, dtype=np.float32)
  if embeddings.ndim != 2 or embeddings.shape[1] != FLAGS.embedding_dim:
    raise ValueError(
        'expected a non-empty [N, %d] embedding matrix, got shape %s' %
        (FLAGS.embedding_dim, embeddings.shape))

  logging.info('Building faiss index..')
  # NOTE: `quantizer` is kept bound in this scope while its IVF index is in
  # use; some faiss Python bindings may not hold their own reference to it.
  if FLAGS.index_type == 'IVFFlat':
    quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
    index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, FLAGS.ivf_nlist,
                               faiss.METRIC_INNER_PRODUCT)
    logging.info('train embeddings...')
    index.train(embeddings)
  elif FLAGS.index_type == 'HNSWFlat':
    index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, FLAGS.hnsw_M,
                                faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = FLAGS.hnsw_efConstruction
  else:
    raise NotImplementedError('unsupported index_type: %s' % FLAGS.index_type)

  logging.info('build embeddings...')
  index.add(embeddings)
  _write_index(index, 'faiss_index')

  if FLAGS.debug != 0:
    # IVFFlat: sweep the number of inverted lists.
    for ivf_nlist in [100, 500, 1000, 2000]:
      quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
      index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, ivf_nlist,
                                 faiss.METRIC_INNER_PRODUCT)
      index.train(embeddings)
      index.add(embeddings)
      _write_index(index, 'faiss_index_ivfflat_nlist%d' % ivf_nlist)

    # HNSWFlat: sweep (M, efConstruction), skipping efConstruction < 2 * M.
    for hnsw_M in [16, 32, 64, 128]:
      for hnsw_efConstruction in [64, 128, 256, 512, 1024, 2048, 4096, 8192]:
        if hnsw_efConstruction < hnsw_M * 2:
          continue
        index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, hnsw_M,
                                    faiss.METRIC_INNER_PRODUCT)
        index.hnsw.efConstruction = hnsw_efConstruction
        index.add(embeddings)
        _write_index(index,
                     'faiss_index_hnsw_M%d_ef%d' % (hnsw_M, hnsw_efConstruction))
112
+
113
+
114
if __name__ == '__main__':
  # Keep only the command-line arguments FLAGS knows about (presumably so that
  # extra launcher-supplied arguments do not break flag parsing), then hand
  # control to tf.app.run, which invokes main().
  known_argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  sys.argv = known_argv
  tf.app.run()
@@ -0,0 +1,316 @@
1
+ from __future__ import division
2
+ from __future__ import print_function
3
+
4
+ import json
5
+ import os
6
+ import sys
7
+ from collections import OrderedDict
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import tensorflow as tf
12
+ from tensorflow.python.framework.meta_graph import read_meta_graph_file
13
+
14
+ from easy_rec.python.utils import config_util
15
+ from easy_rec.python.utils import io_util
16
+
17
+ if tf.__version__ >= '2.0':
18
+ tf = tf.compat.v1
19
+
20
+ import matplotlib # NOQA
21
+ matplotlib.use('Agg') # NOQA
22
+ import matplotlib.pyplot as plt # NOQA
23
+
24
# Command-line flags of the feature selection tool.
# Only 'variational_dropout' is accepted by the __main__ dispatch below.
tf.app.flags.DEFINE_string('model_type', 'variational_dropout',
                           'feature selection model type')
# Pipeline config the model was trained with; its model_config must contain
# a `variational_dropout` section.
tf.app.flags.DEFINE_string('config_path', '',
                           'feature selection model config path')
# When None/empty, the latest checkpoint under the config's model_dir is used.
tf.app.flags.DEFINE_string('checkpoint_path', None,
                           'feature selection model checkpoint path')
# Receives csv rankings, optional pictures and the pruned config files;
# created if it does not exist.
tf.app.flags.DEFINE_string('output_dir', '',
                           'feature selection result directory')
# Features ranked after the first `topk` of a group are dropped from the
# written configs.
tf.app.flags.DEFINE_integer(
    'topk', 100, 'select topk importance features for each feature group')
# Optional fg json config; when set, a pruned copy is written to output_dir.
tf.app.flags.DEFINE_string('fg_path', '', 'fg config path')
tf.app.flags.DEFINE_bool('visualize', False,
                         'visualization feature selection result or not')
FLAGS = tf.app.flags.FLAGS
38
+
39
+
40
class VariationalDropoutFS:
  """Feature selection based on a trained VariationalDropout layer.

  For every feature group, the learned dropout logits (`logit_p`) are read
  from a checkpoint and turned into dropout probabilities. Features are
  ranked by mean dropout probability in ascending order, so the features
  least likely to be dropped come first.

  The ranking is written to csv files and can optionally be plotted. It is
  then used to write pruned copies of the pipeline config (and of an
  optional fg config). These copies drop every feature that is ranked after
  the first `topk` features of some group.
  """

  def __init__(self,
               config_path,
               output_dir,
               topk,
               checkpoint_path=None,
               fg_path=None,
               visualize=False):
    """Store the options and make sure `output_dir` exists.

    Args:
      config_path: pipeline config path; its model_config must contain a
        `variational_dropout` section.
      output_dir: directory receiving csv files, pictures and pruned configs.
      topk: number of top-ranked features kept per feature group.
      checkpoint_path: checkpoint to read. When None or empty, the latest
        checkpoint under the config's `model_dir` is used.
      fg_path: optional fg json config to prune alongside the pipeline config.
      visualize: whether to draw dropout-ratio pictures.
    """
    self._config_path = config_path
    self._output_dir = output_dir
    self._topk = topk
    if not tf.gfile.Exists(self._output_dir):
      tf.gfile.MakeDirs(self._output_dir)
    self._checkpoint_path = checkpoint_path
    self._fg_path = fg_path
    self._visualize = visualize

  def process(self):
    """Rank features per group, dump/plot the ranking and prune the configs."""
    tf.logging.info('Loading logit_p of VariationalDropout layer ...')
    (feature_dim_dropout_p_map,
     embedding_wise_variational_dropout) = self._feature_dim_dropout_ratio()

    feature_importance_map = {}
    for group_name, feature_dim_dropout_p in feature_dim_dropout_p_map.items():
      tf.logging.info('Calculating %s feature importance ...' % group_name)
      feature_importance = self._get_feature_importance(
          feature_dim_dropout_p, embedding_wise_variational_dropout)
      feature_importance_map[group_name] = feature_importance

      tf.logging.info('Dump %s feature importance to csv ...' % group_name)
      self._dump_to_csv(feature_importance, group_name)

      if self._visualize:
        tf.logging.info('Visualizing %s feature importance ...' % group_name)
        if embedding_wise_variational_dropout:
          self._visualize_embedding_dim_importance(feature_dim_dropout_p)
        self._visualize_feature_importance(feature_importance, group_name)

    tf.logging.info('Processing model config ...')
    self._process_config(feature_importance_map)

  @staticmethod
  def _sigmoid(logits):
    """Return the logistic sigmoid of `logits`, computed with numpy.

    It is evaluated as exp(-softplus(-x)), so large-magnitude logits neither
    overflow nor produce runtime warnings. Using numpy instead of building a
    TF op and evaluating it in a Session also keeps this working when eager
    execution is enabled.
    """
    logits = np.asarray(logits)
    return np.exp(-np.logaddexp(0.0, -logits))

  def _feature_dim_dropout_ratio(self):
    """Get dropout ratio of embedding-wise or feature-wise.

    Returns:
      A tuple `(feature_dim_dropout_p_map, embedding_wise_variational_dropout)`.
      The map has the form `{group_name: {feature_name: dropout_p}}`. Each
      `dropout_p` is an array slice with one value per embedding dimension
      when embedding-wise dropout is enabled; otherwise it is a single value
      per feature.

    Raises:
      ValueError: if no checkpoint path is given and no checkpoint is found
        under the config's `model_dir`.
    """
    config = config_util.get_configs_from_pipeline_file(self._config_path)
    assert config.model_config.HasField(
        'variational_dropout'), 'variational_dropout must be in model_config'

    embedding_wise_variational_dropout = (
        config.model_config.variational_dropout
        .embedding_wise_variational_dropout)

    if self._checkpoint_path:
      checkpoint_path = self._checkpoint_path
    else:
      checkpoint_path = tf.train.latest_checkpoint(config.model_dir)
      if checkpoint_path is None:
        raise ValueError('no checkpoint found in model_dir: %s' %
                         config.model_dir)

    # The meta graph stores, per feature group, the ordered list of
    # (feature_name, dimension) pairs the dropout layer was built with.
    meta_graph_def = read_meta_graph_file(checkpoint_path + '.meta')
    features_dimension_map = {}
    for col_def in meta_graph_def.collection_def[
        'variational_dropout'].bytes_list.value:
      name, features_dimension = json.loads(col_def)
      # An empty group name is mapped to 'all'.
      name = 'all' if name == '' else name
      features_dimension_map[name] = OrderedDict(features_dimension)

    tf.logging.info('Reading checkpoint from %s ...' % checkpoint_path)
    reader = tf.train.NewCheckpointReader(checkpoint_path)

    feature_dim_dropout_p_map = {}
    for feature_group in config.model_config.feature_groups:
      group_name = feature_group.group_name

      logit_p_name = ('logit_p'
                      if group_name == 'all' else 'logit_p_%s' % group_name)
      try:
        logit_p = reader.get_tensor(logit_p_name)
      except Exception:
        # Fall back to the `backbone/`-scoped variable name.
        tf.logging.warning(
            'get `logit_p` failed, try to get `backbone/logit_p`')
        logit_p = reader.get_tensor('backbone/' + logit_p_name)
      feature_dims_importance = self._sigmoid(logit_p)

      features_dimension = features_dimension_map[group_name]
      if embedding_wise_variational_dropout:
        # One entry per embedding dimension, laid out feature by feature in
        # the order recorded in the meta graph collection.
        feature_dim_dropout_p = {}
        index_end = 0
        for feature_name, feature_dim in features_dimension.items():
          index_start = index_end
          index_end = index_start + feature_dim
          feature_dim_dropout_p[feature_name] = feature_dims_importance[
              index_start:index_end]
      else:
        # One entry per feature.
        feature_dim_dropout_p = {
            feature_name: feature_dims_importance[index]
            for index, feature_name in enumerate(features_dimension)
        }

      feature_dim_dropout_p_map[group_name] = feature_dim_dropout_p
    return feature_dim_dropout_p_map, embedding_wise_variational_dropout

  def _get_feature_importance(self, feature_dim_dropout_p,
                              embedding_wise_variational_dropout):
    """Calculate feature importance.

    Returns:
      An OrderedDict `feature_name -> mean dropout probability`, sorted in
      ascending order (lowest dropout probability, i.e. most important,
      first). Ties keep their input order.
    """
    if embedding_wise_variational_dropout:
      feature_importance = {
          feature_name: np.mean(dropout_p)
          for feature_name, dropout_p in feature_dim_dropout_p.items()
      }
    else:
      feature_importance = feature_dim_dropout_p
    return OrderedDict(
        sorted(feature_importance.items(), key=lambda kv: kv[1]))

  @staticmethod
  def _collect_sequence_features(config):
    """Return the key/hist_seq feature names used by sequence attention.

    Both `feature_groups[*].sequence_features` and the DIN-style
    `model_config.seq_att_groups` are scanned.
    """
    seq_att_groups = [
        sequence_feature
        for feature_group in config.model_config.feature_groups
        for sequence_feature in feature_group.sequence_features
    ]
    # compat with din
    seq_att_groups.extend(config.model_config.seq_att_groups)

    sequence_features = set()
    for sequence_feature in seq_att_groups:
      for seq_att_map in sequence_feature.seq_att_map:
        sequence_features.update(seq_att_map.key)
        sequence_features.update(seq_att_map.hist_seq)
    return sequence_features

  def _process_config(self, feature_importance_map):
    """Process model config and fg config with feature selection.

    A feature is excluded when it ranks at or beyond `topk` in any group.
    Sequence-attention features are always kept. The pruned pipeline config
    (and fg config, when `fg_path` is set) is written to `output_dir` under
    its original basename.
    """
    excluded_features = set()
    for feature_importance in feature_importance_map.values():
      for rank, feature_name in enumerate(feature_importance):
        if rank >= self._topk:
          excluded_features.add(feature_name)

    config = config_util.get_configs_from_pipeline_file(self._config_path)
    # keep sequence features and side-infos
    excluded_features -= self._collect_sequence_features(config)

    feature_configs = []
    for feature_config in config_util.get_compatible_feature_configs(config):
      if feature_config.HasField('feature_name'):
        feature_name = feature_config.feature_name
      else:
        feature_name = feature_config.input_names[0]
      if feature_name not in excluded_features:
        feature_configs.append(feature_config)

    if config.feature_configs:
      config.ClearField('feature_configs')
      config.feature_configs.extend(feature_configs)
    else:
      config.feature_config.ClearField('features')
      config.feature_config.features.extend(feature_configs)

    for feature_group in config.model_config.feature_groups:
      kept_names = [
          feature_name for feature_name in feature_group.feature_names
          if feature_name not in excluded_features
      ]
      feature_group.ClearField('feature_names')
      feature_group.feature_names.extend(kept_names)
    config_util.save_message(
        config,
        os.path.join(self._output_dir, os.path.basename(self._config_path)))

    if self._fg_path is not None and len(self._fg_path) > 0:
      with tf.gfile.Open(self._fg_path) as f:
        fg_json = json.load(f, object_pairs_hook=OrderedDict)
      # Entries without a `feature_name` are always kept.
      fg_json['features'] = [
          feature for feature in fg_json['features']
          if not ('feature_name' in feature and
                  feature['feature_name'] in excluded_features)
      ]
      with tf.gfile.Open(
          os.path.join(self._output_dir, os.path.basename(self._fg_path)),
          'w') as f:
        json.dump(fg_json, f, indent=4)

  def _dump_to_csv(self, feature_importance, group_name):
    """Dump feature importance data to a csv file."""
    with tf.gfile.Open(
        os.path.join(self._output_dir,
                     'feature_dropout_ratio_%s.csv' % group_name), 'w') as f:
      df = pd.DataFrame(
          columns=['feature_name', 'mean_drop_p'],
          data=[list(kv) for kv in feature_importance.items()])
      df.to_csv(f, encoding='gbk')

  def _visualize_embedding_dim_importance(self, feature_dim_dropout_p):
    """Draw one bar chart of per-dimension dropout ratios for every feature.

    Pictures go to `<output_dir>/feature_dims_importance_pics/<feature>.png`.
    Each figure is closed after saving so memory does not grow with the
    number of features.
    """
    output_dir = os.path.join(self._output_dir, 'feature_dims_importance_pics')
    if not tf.gfile.Exists(output_dir):
      tf.gfile.MakeDirs(output_dir)

    plt.rcdefaults()
    for feature_name, feature_dropout_p in feature_dim_dropout_p.items():
      embedding_dims = [
          'dim_' + str(i + 1) for i in range(len(feature_dropout_p))
      ]
      y_pos = np.arange(len(embedding_dims))
      fig, ax = plt.subplots()
      try:
        bars = ax.barh(
            y_pos,
            list(feature_dropout_p),
            align='center',
            alpha=0.4,
            label='dropout_rate',
            lw=1)
        for rect in bars:
          width = rect.get_width()
          ax.text(
              width,
              rect.get_y() + rect.get_height() / 2,
              '%.4f' % width,
              ha='left',
              va='center')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(embedding_dims)
        ax.set_xlabel(feature_name)
        ax.set_title('Dropout ratio')
        img_path = os.path.join(output_dir, feature_name + '.png')
        with tf.gfile.GFile(img_path, 'wb') as f:
          fig.savefig(f, format='png')
      finally:
        plt.close(fig)

  def _visualize_feature_importance(self, feature_importance, group_name):
    """Draw feature importance histogram.

    Features are plotted in descending order of mean dropout probability.
    The (very large) figure is closed after saving to release its memory.
    """
    df = pd.DataFrame(
        columns=['feature_name', 'mean_drop_p'],
        data=[list(kv) for kv in feature_importance.items()])
    df.sort_values('mean_drop_p', inplace=True, ascending=False)
    df.reset_index(inplace=True)
    # Draw plot
    fig = plt.figure(figsize=(90, 200), dpi=100)
    try:
      plt.hlines(y=df.index, xmin=0, xmax=df.mean_drop_p)
      for x, y in zip(df.mean_drop_p, df.index):
        plt.text(
            x,
            y,
            round(x, 2),
            horizontalalignment='right' if x < 0 else 'left',
            verticalalignment='center',
            fontdict={
                'color': 'red' if x < 0 else 'green',
                'size': 14
            })
      # Decorations
      plt.yticks(df.index, df.feature_name, fontsize=20)
      plt.title('Dropout Ratio', fontdict={'size': 30})
      plt.grid(linestyle='--', alpha=0.5)
      plt.xlim(0, 1)
      with tf.gfile.GFile(
          os.path.join(self._output_dir,
                       'feature_dropout_pic_%s.png' % group_name), 'wb') as f:
        plt.savefig(f, format='png')
    finally:
      plt.close(fig)
301
+
302
+
303
if __name__ == '__main__':
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  # Only the variational-dropout selector is implemented; reject anything else
  # up front.
  if FLAGS.model_type != 'variational_dropout':
    raise ValueError('Unknown feature selection model type %s' %
                     FLAGS.model_type)
  VariationalDropoutFS(
      FLAGS.config_path,
      FLAGS.output_dir,
      FLAGS.topk,
      checkpoint_path=FLAGS.checkpoint_path,
      fg_path=FLAGS.fg_path,
      visualize=FLAGS.visualize).process()