easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the package's registry page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,427 @@
1
+ # -*-encoding:utf-8-*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import math
5
+ import sys
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from easy_rec.python.utils import config_util
11
+
12
# Configure the root logger when this module is loaded: INFO level and
# "[time][LEVEL] message" formatted records.
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
14
+
15
+
16
class ModelConfigConverter:
    """Convert a feature-description Excel workbook into an EasyRec config.

    The workbook is read with ``pandas.read_excel`` and must contain two
    sheets:

    * ``global``: shared settings, one row per shared name, with the columns
      ``name``, ``type``, ``hash_bucket_size``, ``embedding_dim`` and
      ``default_value``.
    * ``features``: one row per input column, with the columns ``name``,
      ``data_type``, ``type``, ``global`` and ``group``, plus the optional
      columns ``hash_bucket_size``, ``embedding_dim``, ``default_value``,
      ``weights`` and ``boundaries``.

    ``convert()`` writes the generated pipeline config to ``output_path`` and
    then loads and saves it again through ``config_util`` to reformat it.
    """

    # Excel `data_type` -> EasyRec input_type name.
    _TYPE_NAMES = {
        'bigint': 'INT64',
        'double': 'DOUBLE',
        'float': 'FLOAT',
        'string': 'STRING',
        'bool': 'BOOL',
    }
    # Excel `data_type` -> textual default value of that type.
    _TYPE_DEFAULTS = {
        'bigint': '0',
        'double': '0.0',
        'float': '0.0',
        'string': '',
        'bool': 'false',
    }
    # Every spelling of the feature `type` meaning "column not used by the
    # model". Shared by parsing and config writing so the two stay in sync.
    _NOT_NEEDED_TYPES = ('notneed', 'not_need', 'not_needed')
    # String cell values that count as an empty cell.
    _BLANK_STRINGS = ('', ' ', 'NaN', 'nan')
    # Per-feature columns copied from the `features` sheet when present.
    _OPTIONAL_COLUMNS = ('type', 'global', 'hash_bucket_size',
                         'embedding_dim', 'default_value', 'weights',
                         'boundaries')

    def __init__(self, excel_path, output_path, model_type, column_separator,
                 incol_separator, train_input_path, eval_input_path,
                 model_dir):
        """Store the conversion settings and parse the ``global`` sheet.

        Args:
          excel_path: path of the Excel workbook describing the features.
          output_path: path the generated pipeline config is written to.
          model_type: ``'deepfm'`` or ``'multi_tower'`` to also emit a
            model_config; for any other value ``convert`` only logs a warning.
          column_separator: separator between columns of the CSV input.
          incol_separator: separator inside a column (tags/indexes features).
          train_input_path: value written to ``train_input_path``.
          eval_input_path: value written to ``eval_input_path``.
          model_dir: value written to ``model_dir``; ``'experiments/demo'``
            is used when it is empty.
        """
        self._excel_path = excel_path
        self._output_path = output_path
        self._model_type = model_type
        self._column_separator = column_separator
        self._incol_separator = incol_separator
        # The `global` sheet is read eagerly, at construction time.
        self._dict_global = self._parse_global()
        self._tower_dicts = {}
        self._feature_names = []
        self._feature_details = {}
        self._label = ''
        self._train_input_path = train_input_path
        self._eval_input_path = eval_input_path
        self._model_dir = model_dir
        if not self._model_dir:
            self._model_dir = 'experiments/demo'
            logging.warning('model_dir is not specified, set to %s' %
                            self._model_dir)

    def _get_type_name(self, input_name):
        """Map an Excel ``data_type`` to its EasyRec input type.

        Raises:
          KeyError: ``input_name`` is not a supported data type.
        """
        return self._TYPE_NAMES[input_name]

    def _get_type_default(self, input_name):
        """Map an Excel ``data_type`` to its textual default value.

        Raises:
          KeyError: ``input_name`` is not a supported data type.
        """
        return self._TYPE_DEFAULTS[input_name]

    def _parse_global(self):
        """Read the ``global`` sheet into a ``{name: settings}`` dict."""
        df = pd.read_excel(self._excel_path, sheet_name='global')
        dict_global = {}
        for _, row in df.iterrows():
            name = row['name'].strip()
            dict_global[name] = {
                'name': name,
                'type_name': row['type'],
                'hash_bucket_size': row['hash_bucket_size'],
                'embedding_dim': row['embedding_dim'],
                'default_value': row['default_value'],
            }
        return dict_global

    def _add_to_tower(self, tower_name, field):
        """Append ``field`` to the feature group(s) named by ``tower_name``.

        ``'nan'`` (an empty ``group`` cell, in any case) and ``'label'`` are
        ignored. For deepfm, only ``deep``, ``wide`` and ``wide_and_deep`` are
        accepted; ``wide_and_deep`` adds the field to both groups.

        Raises:
          ValueError: an unsupported group name is used with deepfm.
        """
        if tower_name.lower() == 'nan' or tower_name == 'label':
            return
        if self._model_type == 'deepfm':
            if tower_name == 'wide_and_deep':
                targets = ['wide', 'deep']
            elif tower_name in ('deep', 'wide'):
                targets = [tower_name]
            else:
                raise ValueError(
                    'invalid tower_name[%s] for deepfm model, '
                    'only[label, deep, wide, wide_and_deep are supported]' %
                    tower_name)
        else:
            targets = [tower_name]
        for target in targets:
            self._tower_dicts.setdefault(target, []).append(field)

    def _is_str(self, v):
        """Return True for ``str`` (and Python 2 ``unicode``) values."""
        if isinstance(v, str):
            return True
        try:
            if isinstance(v, unicode):  # noqa: F821
                return True
        except NameError:
            return False
        return False

    def _is_blank(self, v):
        """Return True if an Excel cell value should be treated as empty.

        Blank strings (``''``, ``' '``, ``'NaN'``, ``'nan'``) and NaN numbers
        count as empty. NaN is detected with ``math.isnan`` instead of
        comparing against ``np.NaN``/``np.NAN``, which NumPy 2.0 removed.
        """
        if self._is_str(v):
            return v in self._BLANK_STRINGS
        return math.isnan(v)

    def _parse_features(self):
        """Read the ``features`` sheet and assign features to their groups.

        Fills ``self._feature_names`` (in sheet order) and
        ``self._feature_details`` (name -> settings dict). Sets
        ``self._label`` to the row whose type is ``label``, and fills
        ``self._tower_dicts`` via ``_add_to_tower``.

        Raises:
          ValueError: a ``tags`` feature references a ``weights`` column that
            is not one of the features, or ``_add_to_tower`` rejects a group.
        """
        df = pd.read_excel(self._excel_path, sheet_name='features')
        for _, row in df.iterrows():
            name = row['name'].strip()
            self._feature_names.append(name)
            field = {
                'name': name,
                'data_type': row['data_type'].strip(),
                'type': row['type'].strip(),
            }
            shared_name = str(row['global']).strip()
            if shared_name and shared_name != 'nan':
                field['global'] = shared_name
            field['field_name'] = name

            if field['type'] == 'label':
                self._label = name

            # Inherit every non-empty setting of the referenced `global` row
            # and share its embedding.
            if 'global' in field and field['global'] in self._dict_global:
                shared = self._dict_global[field['global']]
                for key in ('default_value', 'hash_bucket_size',
                            'embedding_dim'):
                    if str(shared[key]) not in ('nan', ''):
                        field[key] = shared[key]
                field['embedding_name'] = field['global']

            # Per-row cells override the inherited settings; empty cells are
            # skipped and numeric cells are truncated to int.
            for t in self._OPTIONAL_COLUMNS:
                if t not in row:
                    continue
                v = row[t]
                if not self._is_blank(v):
                    field[t] = v.strip() if self._is_str(v) else int(v)

                if t == 'default_value' and t not in field:
                    field[t] = 0.0 if field['type'] == 'dense' else ''

            if field['type'] == 'weights':
                field['default_value'] = '1'

            # A feature whose own name is a `global` row becomes a category
            # feature with that row's settings.
            if name in self._dict_global:
                shared = self._dict_global[name]
                field['type'] = 'category'
                field['hash_bucket_size'] = shared['hash_bucket_size']
                field['embedding_dim'] = shared['embedding_dim']
                field['default_value'] = shared['default_value']

            # Numeric inputs always get a zero default, overriding the above.
            if field['data_type'] == 'bigint':
                field['default_value'] = 0
            elif field['data_type'] == 'double':
                field['default_value'] = 0.0

            if field['type'] not in self._NOT_NEEDED_TYPES:
                self._add_to_tower(str(row['group']).strip(), field)
            self._feature_details[name] = field

        # Tag features may only take their weights from an existing column.
        for config in self._feature_details.values():
            if (config['type'] == 'tags' and 'weights' in config and
                    config['weights'] not in self._feature_details):
                raise ValueError('%s not in field names' % config['weights'])

    def _write_train_eval_config(self, fout):
        """Write the input paths, model_dir, train_config and eval_config.

        The train/eval settings are a fixed template: Adam with exponential
        learning-rate decay, 2000 steps, and the AUC metric.
        """
        fout.write('train_input_path: "%s"\n' % self._train_input_path)
        fout.write('eval_input_path: "%s"\n' % self._eval_input_path)
        fout.write("""
model_dir: "%s"

train_config {
  log_step_count_steps: 200
  # fine_tune_checkpoint: ""
  optimizer_config: {
    adam_optimizer: {
      learning_rate: {
        exponential_decay_learning_rate {
          initial_learning_rate: 0.0001
          decay_steps: 10000
          decay_factor: 0.5
          min_learning_rate: 0.0000001
        }
      }
    }
  }
  num_steps: 2000
  sync_replicas: true
}

eval_config {
  metrics_set: {
    auc {}
  }
}""" % self._model_dir)

    def _write_feature_groups(self, fout, wide_deep=None):
        """Write one ``feature_groups`` entry per tower, sorted by name.

        ``weights`` features are not listed. A tags feature refers to them
        through its own feature config instead.

        Args:
          fout: writable text stream.
          wide_deep: value written for every group's ``wide_deep``. When it
            is None, the upper-cased group name is used.

        Returns:
          The sorted list of tower names.
        """
        tower_names = sorted(self._tower_dicts)
        for tower_name in tower_names:
            fout.write('  feature_groups: {\n')
            fout.write('    group_name: "%s"\n' % tower_name)
            for fea in self._tower_dicts[tower_name]:
                if fea['type'] != 'weights':
                    fout.write('    feature_names: "%s"\n' % fea['name'])
            group_kind = tower_name.upper() if wide_deep is None else wide_deep
            fout.write('    wide_deep:%s\n' % group_kind)
            fout.write('  }\n')
        return tower_names

    def _write_deepfm_config(self, fout):
        """Write a DeepFM model_config with wide/deep feature groups."""
        fout.write('model_config:{\n')
        fout.write('  model_class: "DeepFM"\n')
        self._write_feature_groups(fout)
        fout.write("""
  deepfm {
    dnn {
      hidden_units: [128, 64, 32]
    }
    final_dnn {
      hidden_units: [128, 64]
    }
    wide_output_dim: 16
    l2_regularization: 1e-5
  }
  embedding_regularization: 1e-5
}
""")

    def _write_multi_tower_config(self, fout):
        """Write a MultiTower model_config with one DEEP tower per group."""
        fout.write('model_config:{\n')
        fout.write('  model_class: "MultiTower"\n')
        tower_names = self._write_feature_groups(fout, wide_deep='DEEP')

        fout.write('  multi_tower { \n')
        for tower_name in tower_names:
            fout.write("""
    towers {
      input: "%s"
      dnn {
        hidden_units: [256, 192, 128]
      }
    }""" % tower_name)

        fout.write("""
    final_dnn {
      hidden_units: [192, 128, 64]
    }
    l2_regularization: 1e-5
  }
  embedding_regularization: 1e-5
}""")

    def _write_data_config(self, fout):
        """Write data_config: one input field per sheet column, then the label."""
        fout.write('data_config {\n')
        fout.write('  separator: "%s"\n' % self._column_separator)
        for name in self._feature_names:
            details = self._feature_details[name]
            fout.write('  input_fields: {\n')
            fout.write('    input_name: "%s"\n' % name)
            fout.write('    input_type: %s\n' %
                       self._get_type_name(details['data_type']))
            if 'default_value' in details:
                fout.write('    default_val:"%s"\n' % details['default_value'])
            fout.write('  }\n')

        fout.write('  label_fields: "%s"\n' % self._label)
        fout.write("""
  batch_size: 1024
  prefetch_size: 32
  input_type: CSVInput
}""")

    def _write_feature_config(self, fout):
        """Write one ``feature_configs`` entry per model feature.

        The label, ``weights`` columns and not-needed columns (every spelling
        in ``_NOT_NEEDED_TYPES``) are skipped.

        Raises:
          ValueError: a dense feature has no boundaries under deepfm, an
            ``indexes`` feature has no hash_bucket_size, or a feature has an
            unsupported type.
        """
        skipped_types = ('weights', 'label') + self._NOT_NEEDED_TYPES
        for name in self._feature_names:
            feature = self._feature_details[name]
            ftype = feature['type']
            if ftype in skipped_types or name == self._label:
                continue
            fout.write('feature_configs: {\n')
            fout.write('  input_names: "%s"\n' % name)
            if ftype == 'category':
                fout.write('  feature_type: IdFeature\n')
                fout.write('  embedding_dim: %d\n' % feature['embedding_dim'])
                fout.write('  hash_bucket_size: %d\n' %
                           feature['hash_bucket_size'])
                if 'embedding_name' in feature:
                    fout.write('  embedding_name: "%s"\n' %
                               feature['embedding_name'])
            elif ftype == 'dense':
                fout.write('  feature_type: RawFeature\n')
                boundaries = feature.get('boundaries', '')
                if self._model_type == 'deepfm' and boundaries == '':
                    raise ValueError(
                        'raw features must be discretized by specifying '
                        'boundaries: %s' % name)
                if boundaries != '':
                    fout.write('  boundaries: [%s]\n' % str(boundaries).strip())
                    fout.write('  embedding_dim: %d\n' %
                               int(feature['embedding_dim']))
            elif ftype == 'tags':
                if 'weights' in feature:
                    fout.write('  input_names: "%s"\n' % feature['weights'])
                fout.write('  feature_type: TagFeature\n')
                fout.write('  hash_bucket_size: %d\n' %
                           feature['hash_bucket_size'])
                fout.write('  embedding_dim: %d\n' % feature['embedding_dim'])
                if 'embedding_name' in feature:
                    fout.write('  embedding_name: "%s"\n' %
                               feature['embedding_name'])
                fout.write('  separator: "%s"\n' % self._incol_separator)
            elif ftype == 'indexes':
                if 'hash_bucket_size' not in feature:
                    raise ValueError(
                        'indexes feature %s requires hash_bucket_size' % name)
                fout.write('  feature_type: TagFeature\n')
                fout.write('  num_buckets: %d\n' % feature['hash_bucket_size'])
                if 'embedding_dim' in feature:
                    fout.write('  embedding_dim: %d\n' %
                               feature['embedding_dim'])
                if 'embedding_name' in feature:
                    fout.write('  embedding_name: "%s"\n' %
                               feature['embedding_name'])
                fout.write('  separator: "%s"\n' % self._incol_separator)
            else:
                raise ValueError('invalid feature types: %s' % ftype)
            fout.write('}\n')

    def convert(self):
        """Parse the features sheet and write the full pipeline config."""
        self._parse_features()
        logging.info('TOWERS[%d]: %s' %
                     (len(self._tower_dicts), ','.join(self._tower_dicts)))
        with open(self._output_path, 'w') as fout:
            self._write_train_eval_config(fout)
            self._write_data_config(fout)
            self._write_feature_config(fout)
            if self._model_type == 'deepfm':
                self._write_deepfm_config(fout)
            elif self._model_type == 'multi_tower':
                self._write_multi_tower_config(fout)
            else:
                logging.warning(
                    'the model_config could not be generated automatically, you have to write the model_config manually.'
                )
        # Reformat the config by loading it and saving it back.
        pipeline_config = config_util.get_configs_from_pipeline_file(
            self._output_path)
        config_util.save_message(pipeline_config, self._output_path)
379
+
380
+
381
# Model types for which ModelConfigConverter.convert() also writes a
# model_config; omitting --model_type only produces a warning there.
model_types = ['deepfm', 'multi_tower']

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_type',
        type=str,
        choices=model_types,
        help='model type, currently support: %s' % ','.join(model_types))
    parser.add_argument('--excel_path', type=str, help='excel config path')
    parser.add_argument(
        '--output_path', type=str, help='generated config path')
    parser.add_argument(
        '--column_separator',
        type=str,
        default=',',
        help='column separator, separator between features')
    parser.add_argument(
        '--incol_separator',
        type=str,
        default='|',
        help='separator within features, such as tag features')
    parser.add_argument(
        '--train_input_path', type=str, default='', help='train input path')
    parser.add_argument(
        '--eval_input_path', type=str, default='', help='eval input path')
    parser.add_argument('--model_dir', type=str, default='', help='model dir')
    args = parser.parse_args()

    # Both the workbook and the output location are mandatory.
    if not args.excel_path or not args.output_path:
        parser.print_usage()
        sys.exit(1)

    logging.info('column_separator = %s in_column_separator = %s' %
                 (args.column_separator, args.incol_separator))

    converter = ModelConfigConverter(args.excel_path, args.output_path,
                                     args.model_type, args.column_separator,
                                     args.incol_separator,
                                     args.train_input_path,
                                     args.eval_input_path, args.model_dir)
    converter.convert()
    logging.info('Conversion done')
    logging.info('Tips:')
    if args.train_input_path == '' or args.eval_input_path == '':
        logging.info('*.you have to update train_input_path, eval_input_path')
    logging.info('*.you may need to adjust dnn config or final_dnn config')
File without changes
@@ -0,0 +1,157 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import gzip
5
+ import logging
6
+ import multiprocessing
7
+ import os
8
+ import traceback
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import six
13
+ from tensorflow.python.platform import gfile
14
+
15
# Root logger setup for this script: INFO level, with a timestamp and level
# prefix on every message.
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] %(message)s', level=logging.INFO)
17
+
18
+
19
def save_np_bin(labels, dense_arr, cate_arr, prefix):
  """Write one part in raw binary form as three files.

  Each file holds the raw bytes (`ndarray.tobytes()`) of one array:
    `<prefix>_label.bin`    -- labels as int32,
    `<prefix>_dense.bin`    -- dense features as float32,
    `<prefix>_category.bin` -- categorical ids as uint32.

  Args:
    labels: per-record labels.
    dense_arr: dense feature values.
    cate_arr: categorical feature ids.
    prefix: output path prefix; opened through gfile.
  """
  with gfile.GFile(prefix + '_label.bin', 'wb') as fout:
    fout.write(np.array(labels, dtype=np.int32).tobytes())
  with gfile.GFile(prefix + '_dense.bin', 'wb') as fout:
    fout.write(np.array(dense_arr, dtype=np.float32).tobytes())
  # Categorical ids can use the full 32-bit range. float32 has only a 24-bit
  # significand, so casting to it would silently merge distinct ids. Keep them
  # as uint32.
  # NOTE(review): confirm the reader of *_category.bin decodes it as uint32.
  with gfile.GFile(prefix + '_category.bin', 'wb') as fout:
    fout.write(np.array(cate_arr, dtype=np.uint32).tobytes())
26
+
27
+
28
def save_parquet(labels, dense_arr, cate_arr, prefix):
  """Write one part as a single `<prefix>.parquet` file.

  Columns, in order: `is_click` (labels), `f1`..`f13` (columns of
  dense_arr) and `c1`..`c26` (columns of cate_arr).
  """
  columns = {'is_click': labels}
  columns.update(('f%d' % (k + 1), dense_arr[:, k]) for k in range(13))
  columns.update(('c%d' % (k + 1), cate_arr[:, k]) for k in range(26))
  out_path = prefix + '.parquet'
  logging.info('save to %s', out_path)
  pd.DataFrame(columns).to_parquet(out_path)
38
+
39
+
40
def convert(input_path, prefix, part_record_num, save_format):
  """Convert one gzipped Criteo day file into output parts.

  Each non-empty input line must hold 40 tab-separated fields:
    - the label;
    - 13 dense values (empty -> 0.0);
    - 26 hex-encoded categorical values (empty -> 0).

  Records are buffered into arrays of `part_record_num` rows. Each full
  buffer, and the final partial one, is written to `<prefix>_<part_id>`:
  with `save_parquet` when `save_format == 'parquet'`, otherwise with
  `save_np_bin`.

  Any exception is logged together with its traceback, and the function then
  returns without raising.

  Args:
    input_path: path of the gzip-compressed input file, opened through gfile.
    prefix: prefix of the output files.
    part_record_num: maximal number of records in each output part.
    save_format: 'parquet' for parquet output; any other value selects the
      raw binary format.
  """
  logging.info('start to convert %s, part_record_num=%d, save_format=%s' %
               (input_path, part_record_num, save_format))
  save_func = save_parquet if save_format == 'parquet' else save_np_bin
  batch_size = part_record_num
  # Buffers are reused for every part; only the first `sid` rows are valid.
  labels = np.zeros([batch_size], dtype=np.int32)
  dense_arr = np.zeros([batch_size, 13], dtype=np.float32)
  cate_arr = np.zeros([batch_size, 26], dtype=np.uint32)
  part_id = 0
  total_line = 0
  try:
    sid = 0
    with gfile.GFile(input_path, 'rb') as gz_fin:
      # Use GzipFile as a context manager so it is closed deterministically.
      with gzip.GzipFile(fileobj=gz_fin, mode='rb') as fin:
        for line_str in fin:
          if six.PY3:
            line_str = str(line_str, 'utf-8')
          # Remove only the line terminator. A plain strip() also drops the
          # trailing tab separators of empty categorical fields at the end of
          # the line. That leaves fewer than 40 tokens and raises IndexError
          # below.
          line_str = line_str.rstrip('\r\n')
          if not line_str:
            # Skip blank lines instead of aborting the whole file on int('').
            continue
          line_toks = line_str.split('\t')
          labels[sid] = int(line_toks[0])

          for j in range(1, 14):
            x = line_toks[j]
            dense_arr[sid, j - 1] = float(x) if x != '' else 0.0

          for j in range(14, 40):
            x = line_toks[j]
            cate_arr[sid, j - 14] = int(x, 16) if x != '' else 0

          sid += 1
          if sid == batch_size:
            save_func(labels, dense_arr, cate_arr, prefix + '_' + str(part_id))
            logging.info('\t%s write part: %d' % (input_path, part_id))
            part_id += 1
            total_line += sid
            sid = 0
    if sid > 0:
      # Flush the last, partially filled buffer.
      save_func(labels[:sid], dense_arr[:sid], cate_arr[:sid],
                prefix + '_' + str(part_id))
      logging.info('\t%s write final part: %d' % (input_path, part_id))
      part_id += 1
      total_line += sid
  except Exception as ex:
    logging.error('convert %s failed: %s' % (input_path, str(ex)))
    logging.error(traceback.format_exc())
    return
  logging.info('done convert %s, total_line=%d, part_num=%d' %
               (input_path, total_line, part_id))
89
+
90
+
91
if __name__ == '__main__':
  # Convert Criteo 1T data to binary (or parquet) format.
  #
  # The outputs are stored in multiple parts, each with at most
  # part_record_num samples. With the default format each part consists of
  # 3 files:
  #   xxx_yyy_label.bin,
  #   xxx_yyy_dense.bin,
  #   xxx_yyy_category.bin,
  # where xxx is the day in range [0-23] and the range of yyy is determined by
  # part_record_num. With --save_format=parquet each part is a single
  # xxx_yyy.parquet file.
  #
  # If part_record_num is set to the default value 8M, there will be 535
  # parts. We convert the data on a machine with 64GB memory. If your memory
  # is limited, you can convert the .gz files one by one (see --dt), or set a
  # smaller part_record_num.

  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input_dir', type=str, default=None, help='criteo 1t data dir')
  parser.add_argument(
      '--save_dir',
      type=str,
      default=None,
      help='criteo binary data output dir ')
  parser.add_argument(
      '--save_format',
      type=str,
      default='npy',
      help='save format, choices: npy|parquet')
  parser.add_argument(
      '--part_record_num',
      type=int,
      default=1024 * 1024 * 8,
      help='the maximal number of samples in each binary file')
  parser.add_argument(
      '--dt',
      nargs='*',
      type=int,
      help='select days to convert, default to select all: 0-23')

  args = parser.parse_args()

  # parser.error() prints usage and exits. Unlike assert, it is not stripped
  # when Python runs with -O.
  if not args.input_dir:
    parser.error('input_dir is not set')
  if not args.save_dir:
    parser.error('save_dir is not set')

  save_dir = args.save_dir
  if not save_dir.endswith('/'):
    save_dir = save_dir + '/'
  if not gfile.IsDirectory(save_dir):
    gfile.MakeDirs(save_dir)

  days = list(args.dt) if args.dt else list(range(0, 24))

  proc_arr = []
  for d in days:
    input_path = os.path.join(args.input_dir, 'day_%d.gz' % d)
    prefix = os.path.join(save_dir, str(d))
    # Each day is converted exactly once, in its own worker process. Calling
    # convert() inline here as well would redo the work in the parent and
    # race the worker on the same output files.
    proc = multiprocessing.Process(
        target=convert,
        args=(input_path, prefix, args.part_record_num, args.save_format))
    proc.start()
    proc_arr.append(proc)
  for proc in proc_arr:
    proc.join()
    # convert() logs its own errors; a non-zero exit code means the worker
    # died abnormally.
    if proc.exitcode != 0:
      logging.error('worker %s exited with code %s' %
                    (proc.name, proc.exitcode))