easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,65 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import logging
5
+ import sys
6
+
7
+ # from kafka import KafkaConsumer
8
+ from kafka import KafkaAdminClient
9
+ from kafka import KafkaProducer
10
+ from kafka.admin import NewTopic
11
+
12
+ # from kafka.structs import TopicPartition
13
+
14
logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

if __name__ == '__main__':
    # Command-line tool: create the target topic if it does not exist yet and
    # publish every line of --input_path to it as one UTF-8 encoded message.
    parser = argparse.ArgumentParser()
    parser.add_argument('--servers', type=str, default='localhost:9092')
    parser.add_argument('--topic', type=str, default=None)
    # --group and --partitions are parsed but not used by this script; they
    # are kept so existing command lines keep working.
    parser.add_argument('--group', type=str, default='consumer')
    parser.add_argument('--partitions', type=str, default=None)
    parser.add_argument('--timeout', type=float, default=float('inf'))
    # file to send
    parser.add_argument('--input_path', type=str, default=None)
    args = parser.parse_args()

    if args.input_path is None:
        logging.error('input_path is not set')
        sys.exit(1)

    if args.topic is None:
        logging.error('topic is not set')
        sys.exit(1)

    servers = args.servers.split(',')

    admin_clt = KafkaAdminClient(bootstrap_servers=servers)
    try:
        if args.topic not in admin_clt.list_topics():
            admin_clt.create_topics(
                new_topics=[
                    NewTopic(
                        name=args.topic,
                        num_partitions=1,
                        replication_factor=1,
                        topic_configs={
                            'max.message.bytes': 1024 * 1024 * 1024
                        })
                ],
                validate_only=False)
            logging.info('create increment save topic: %s' % args.topic)
    finally:
        # Release the admin connection even if listing/creating topics fails.
        admin_clt.close()

    # NOTE(review): with the default --timeout of inf this passes a float
    # infinity as request_timeout_ms -- confirm the client accepts it.
    producer = KafkaProducer(
        bootstrap_servers=servers,
        request_timeout_ms=args.timeout * 1000,
        api_version=(0, 10, 1))

    try:
        with open(args.input_path, 'r') as fin:
            # The previous loop contained a stray `break` right after the
            # first send, so only one line was ever published and the progress
            # log below was unreachable. Counting now starts at 1 for the
            # first message actually sent.
            for sent, line_str in enumerate(fin, 1):
                producer.send(args.topic, line_str.encode('utf-8'))
                if sent % 100 == 0:
                    logging.info('progress: %d' % sent)
    finally:
        # Always close the producer, also when reading the file fails.
        producer.close()
@@ -0,0 +1,325 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import argparse
4
+ import json
5
+ import logging
6
+ import os
7
import sys

# Directory three levels above this file, added to sys.path so that the
# `easy_rec` imports below resolve when the file is run as a script.
# abspath() is applied BEFORE the dirname() calls: with a relative __file__
# such as 'train_eval.py', repeated dirname() collapses to '' and abspath('')
# would silently return the current working directory instead.
py_root_dir_path = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(f"py_root_dir_path:{py_root_dir_path}")
sys.path.append(py_root_dir_path)
11
+ import warnings
12
+
13
+ import tensorflow as tf
14
+
15
+
16
+ from easy_rec.python.main import _train_and_evaluate_impl
17
+ from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
18
+ from easy_rec.python.protos.train_pb2 import DistributionStrategy
19
+ from easy_rec.python.utils import config_util
20
+ from easy_rec.python.utils import ds_util
21
+ from easy_rec.python.utils import estimator_utils
22
+ from easy_rec.python.utils import fg_util
23
+ from easy_rec.python.utils import hpo_util
24
+ from easy_rec.python.utils.config_util import process_neg_sampler_data_path
25
+ from easy_rec.python.utils.config_util import set_eval_input_path
26
+ from easy_rec.python.utils.config_util import set_train_input_path
27
+
28
# Configure the root logger once. logging.basicConfig() does nothing when the
# root logger already has handlers, so the earlier pair of calls (a bare one
# followed by one carrying this format) left the format unused.
# NOTE(review): if an import above already configures logging, this call is a
# no-op as well -- confirm.
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)
warnings.filterwarnings('ignore')

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# gfile is imported from a different module depending on the TF version.
if tf.__version__.startswith('1.'):
    from tensorflow.python.platform import gfile
else:
    import tensorflow.io.gfile as gfile

from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds  # NOQA

# Compare the numeric major version; comparing the raw version strings is
# lexicographic (e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
48
+
49
+
50
def _get_file_path(root_path, file_list):
    """Recursively collect the files below ``root_path`` into ``file_list``.

    Args:
        root_path: directory to walk.
        file_list: list that is extended in place with the path of every
            regular entry found; nothing is returned.

    Paths containing ``'_SUCCESS'`` anywhere (not only in the file name) are
    skipped. Entries are visited in sorted order so the result does not depend
    on the arbitrary order returned by ``os.listdir``.
    """
    for entry in sorted(os.listdir(root_path)):
        entry_path = os.path.join(root_path, entry)
        if os.path.isdir(entry_path):
            # Descend into sub-directories and collect their files as well.
            _get_file_path(entry_path, file_list)
        elif '_SUCCESS' not in entry_path:
            file_list.append(entry_path)
63
+
64
+
65
def get_vocab_list(vocab_path):
    """Read ``vocab_path`` through ``gfile`` and return its stripped lines.

    Every line of the file becomes one list entry with surrounding whitespace
    removed; the order of the file is preserved.
    """
    tokens = []
    with gfile.GFile(vocab_path, 'r') as vocab_file:
        for raw_line in vocab_file:
            tokens.append(str(raw_line).strip())
    return tokens
69
+
70
+
71
def get_file_path_list(root_path):
    """Return the file paths that ``_get_file_path`` collects for ``root_path``."""
    collected = []
    _get_file_path(root_path, collected)
    return collected
75
+
76
+
77
def change_pipeline_config(pipeline_config: EasyRecConfig):
    """Rewrite ``pipeline_config`` in place from the command-line settings.

    Reads the module-level values ``data_root_path``, ``batch_size``,
    ``num_epochs``, ``train_sample_cnt`` and ``initial_learning_rate``, which
    are assigned in the ``__main__`` section of this file.

    * For each feature with a non-empty ``vocab_list``, the last entry is
      popped and treated as a path relative to ``data_root_path``; the first
      file found below it is read and its lines are appended as vocabulary.
    * ``model_dir`` is prefixed with ``data_root_path``.
    * ``train_input_path`` / ``eval_input_path`` are replaced by the
      comma-joined list of files found below them.
    * Batch size, epochs, logging/checkpoint step counts and the initial
      learning rate of the first optimizer config are overwritten.
    """
    for feature in pipeline_config.feature_config.features:
        vocab_entries = feature.vocab_list
        if not vocab_entries:
            continue
        vocab_root = f"{data_root_path}/{vocab_entries.pop()}"
        vocab_path = get_file_path_list(vocab_root)[0]
        feature.vocab_list.extend(get_vocab_list(vocab_path))

    pipeline_config.model_dir = f"{data_root_path}/{pipeline_config.model_dir}"

    train_files = get_file_path_list(
        f"{data_root_path}/{pipeline_config.train_input_path}")
    pipeline_config.train_input_path = ','.join(train_files)

    eval_files = get_file_path_list(
        f"{data_root_path}/{pipeline_config.eval_input_path}")
    pipeline_config.eval_input_path = ','.join(eval_files)

    data_config = pipeline_config.data_config
    data_config.batch_size = batch_size
    data_config.num_epochs = num_epochs

    # Both step settings use the same value: sample count divided by batch
    # size, truncated to an int.
    train_config = pipeline_config.train_config
    train_config.log_step_count_steps = int(train_sample_cnt / batch_size)
    train_config.save_checkpoints_steps = int(train_sample_cnt / batch_size)

    decay_config = train_config.optimizer_config[
        0].adam_optimizer.learning_rate.exponential_decay_learning_rate
    decay_config.initial_learning_rate = initial_learning_rate
109
+
110
+
111
if __name__ == '__main__':
    # Script entry point: parse the command line, load and patch the pipeline
    # config, then run training/evaluation (optionally in HPO mode).
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pipeline_config_path',
        type=str,
        default='/Users/chensheng/PycharmProjects/EasyRec/samples/model_config/custom_model.config',
        help='Path to pipeline config file.')

    parser.add_argument(
        '--data_root_path',
        type=str,
        default='/Users/chensheng/PycharmProjects/EasyRec/data/test/cs_data')

    parser.add_argument(
        '--train_sample_cnt',
        type=int,
        default=27000,
        help='训练集合的样本数,该数与save_checkpoints_steps 数值相等')

    parser.add_argument(
        '--batch_size',
        type=int,
        default=3000,
    )

    parser.add_argument(
        '--num_epochs',
        type=int,
        default=10,
    )
    parser.add_argument(
        '--initial_learning_rate',
        type=float,
        default=0.001,
    )

    parser.add_argument(
        '--continue_train',
        action='store_true',
        default=False,
        help='continue train using existing model_dir')
    parser.add_argument(
        '--hpo_param_path',
        type=str,
        default=None,
        help='hyperparam tuning param path')
    parser.add_argument(
        '--hpo_metric_save_path',
        type=str,
        default=None,
        help='hyperparameter save metric path')
    parser.add_argument(
        '--model_dir',
        type=str,
        default=None,
        help='will update the model_dir in pipeline_config')
    parser.add_argument(
        '--train_input_path',
        type=str,
        nargs='*',
        default=None,
        help='train data input path')
    parser.add_argument(
        '--eval_input_path',
        type=str,
        nargs='*',
        default=None,
        help='eval data input path')
    parser.add_argument(
        '--fit_on_eval',
        action='store_true',
        default=False,
        help='Fit evaluation data after fitting and evaluating train data')
    parser.add_argument(
        '--fit_on_eval_steps',
        type=int,
        default=None,
        help='Fit evaluation data steps')
    parser.add_argument(
        '--fine_tune_checkpoint',
        type=str,
        default=None,
        help='will update the train_config.fine_tune_checkpoint in pipeline_config'
    )
    parser.add_argument(
        '--edit_config_json',
        type=str,
        default=None,
        help='edit pipeline config str, example: {"model_dir":"experiments/",'
        '"feature_config.feature[0].boundaries":[4,5,6,7]}')
    parser.add_argument(
        '--ignore_finetune_ckpt_error',
        action='store_true',
        default=False,
        help='During incremental training, ignore the problem of missing fine_tune_checkpoint files'
    )
    parser.add_argument(
        '--odps_config', type=str, default=None, help='odps config path')
    parser.add_argument(
        '--is_on_ds', action='store_true', default=False, help='is on ds')
    parser.add_argument(
        '--check_mode',
        action='store_true',
        default=False,
        help='is use check mode')
    parser.add_argument(
        '--selected_cols', type=str, default=None, help='select input columns')
    parser.add_argument('--gpu', type=str, default=None, help='gpu id')
    # Unknown arguments are kept and later interpreted as config overrides.
    args, extra_args = parser.parse_known_args()

    # Module-level values read by change_pipeline_config().
    data_root_path = args.data_root_path
    train_sample_cnt = args.train_sample_cnt
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    initial_learning_rate = args.initial_learning_rate

    if args.gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    edit_config_json = {}
    if args.edit_config_json:
        edit_config_json = json.loads(args.edit_config_json)

    if extra_args:
        config_util.parse_extra_config_param(extra_args, edit_config_json)

    if args.pipeline_config_path is not None:
        pipeline_config = config_util.get_configs_from_pipeline_file(
            args.pipeline_config_path, False)
        if args.selected_cols:
            pipeline_config.data_config.selected_cols = args.selected_cols
        if args.model_dir:
            pipeline_config.model_dir = args.model_dir
            logging.info('update model_dir to %s' % pipeline_config.model_dir)
        if args.train_input_path:
            set_train_input_path(pipeline_config, args.train_input_path)
        if args.eval_input_path:
            set_eval_input_path(pipeline_config, args.eval_input_path)

        if args.fine_tune_checkpoint:
            ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
                args.fine_tune_checkpoint, args.ignore_finetune_ckpt_error)

            if ckpt_path:
                pipeline_config.train_config.fine_tune_checkpoint = ckpt_path

        if pipeline_config.fg_json_path:
            fg_util.load_fg_json_to_config(pipeline_config)

        if args.odps_config:
            os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config

        if len(edit_config_json) > 0:
            fine_tune_checkpoint = edit_config_json.get('train_config', {}).get(
                'fine_tune_checkpoint', None)
            if fine_tune_checkpoint:
                # Resolve the checkpoint named in the edit config itself.
                # The earlier code passed args.fine_tune_checkpoint here,
                # which is unrelated to this value and may be None.
                ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
                    fine_tune_checkpoint, args.ignore_finetune_ckpt_error)
                edit_config_json['train_config'][
                    'fine_tune_checkpoint'] = ckpt_path
            config_util.edit_config(pipeline_config, edit_config_json)

        process_neg_sampler_data_path(pipeline_config)

        if args.is_on_ds:
            ds_util.set_on_ds()
            set_tf_config_and_get_train_worker_num_on_ds()
            if pipeline_config.train_config.fine_tune_checkpoint:
                ds_util.cache_ckpt(pipeline_config)

        if pipeline_config.train_config.train_distribute in [
                DistributionStrategy.HorovodStrategy,
        ]:
            estimator_utils.init_hvd()
        elif pipeline_config.train_config.train_distribute in [
                DistributionStrategy.EmbeddingParallelStrategy,
                DistributionStrategy.SokStrategy
        ]:
            estimator_utils.init_hvd()
            estimator_utils.init_sok()

        if args.hpo_param_path:
            # HPO mode: apply the tuned parameters, train, then save metrics.
            with gfile.GFile(args.hpo_param_path, 'r') as fin:
                hpo_config = json.load(fin)
                hpo_params = hpo_config['param']
                config_util.edit_config(pipeline_config, hpo_params)
            config_util.auto_expand_share_feature_configs(pipeline_config)
            _train_and_evaluate_impl(pipeline_config, args.continue_train,
                                     args.check_mode)
            hpo_util.save_eval_metrics(
                pipeline_config.model_dir,
                metric_save_path=args.hpo_metric_save_path,
                has_evaluator=False)
        else:
            change_pipeline_config(pipeline_config)

            if not args.continue_train:
                # Start from scratch: drop any previous model directory.
                # shutil.rmtree replaces `os.system('rm -rf ...')`, which
                # built a shell command from an unquoted path.
                model_dir = pipeline_config.model_dir
                print(f'model_dir:{model_dir}')
                shutil.rmtree(model_dir, ignore_errors=True)

            config_util.auto_expand_share_feature_configs(pipeline_config)
            _train_and_evaluate_impl(
                pipeline_config,
                args.continue_train,
                args.check_mode,
                fit_on_eval=args.fit_on_eval,
                fit_on_eval_steps=args.fit_on_eval_steps)
    else:
        raise ValueError(
            'pipeline_config_path should not be empty when training!')
@@ -0,0 +1,15 @@
1
class conditional(object):
    """Context manager that delegates to another one only when enabled.

    When ``condition`` is falsy both ``__enter__`` and ``__exit__`` do nothing
    and return ``None``, so the wrapped manager is never touched and any
    exception raised in the ``with`` body propagates normally.
    """

    def __init__(self, condition, contextmanager):
        # Both values are stored as-is; the condition is evaluated again on
        # every enter/exit call.
        self.condition = condition
        self.contextmanager = contextmanager

    def __enter__(self):
        """Enter the wrapped manager if enabled and return its result."""
        if not self.condition:
            return None
        return self.contextmanager.__enter__()

    def __exit__(self, *args):
        """Exit the wrapped manager if enabled, forwarding the exception info."""
        if not self.condition:
            return None
        return self.contextmanager.__exit__(*args)
@@ -0,0 +1,120 @@
1
+ # -*- encoding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import numpy as np
5
+ import six
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.utils.load_class import load_by_path
9
+
10
+ if tf.__version__ >= '2.0':
11
+ tf = tf.compat.v1
12
+
13
+
14
+ def dice(_x, axis=-1, epsilon=1e-9, name='dice', training=True):
15
+ """The Data Adaptive Activation Function in DIN.
16
+
17
+ Which can be viewed as a generalization of PReLu,
18
+ and can adaptively adjust the rectified point according to distribution of input data.
19
+
20
+ Arguments
21
+ - **axis** : Integer, the axis that should be used to compute data distribution (typically the features axis).
22
+ - **epsilon** : Small float added to variance to avoid dividing by zero.
23
+
24
+ References
25
+ - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]
26
+ Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining.
27
+ ACM, 2018: 1059-1068.] (https://arxiv.org/pdf/1706.06978.pdf)
28
+ """
29
+ alphas = tf.get_variable(
30
+ 'alpha_' + name,
31
+ _x.get_shape()[-1],
32
+ initializer=tf.constant_initializer(0.0),
33
+ dtype=tf.float32)
34
+ inputs_normed = tf.layers.batch_normalization(
35
+ inputs=_x,
36
+ axis=axis,
37
+ epsilon=epsilon,
38
+ center=False,
39
+ scale=False,
40
+ training=training)
41
+ x_p = tf.sigmoid(inputs_normed)
42
+ return alphas * (1.0 - x_p) * _x + x_p * _x
43
+
44
+
45
+ def gelu(x, name='gelu'):
46
+ """Gaussian Error Linear Unit.
47
+
48
+ This is a smoother version of the RELU.
49
+ Original paper: https://arxiv.org/abs/1606.08415
50
+
51
+ Args:
52
+ x: float Tensor to perform activation.
53
+ name: name for this activation
54
+
55
+ Returns:
56
+ `x` with the GELU activation applied.
57
+ """
58
+ with tf.name_scope(name):
59
+ cdf = 0.5 * (1.0 + tf.tanh(
60
+ (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
61
+ return x * cdf
62
+
63
+
64
+ def swish(x, name='swish'):
65
+ with tf.name_scope(name):
66
+ return x * tf.sigmoid(x)
67
+
68
+
69
+ def get_activation(activation_string, **kwargs):
70
+ """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
71
+
72
+ Args:
73
+ activation_string: String name of the activation function.
74
+
75
+ Returns:
76
+ A Python function corresponding to the activation function. If
77
+ `activation_string` is None, empty, or "linear", this will return None.
78
+ If `activation_string` is not a string, it will return `activation_string`.
79
+
80
+ Raises:
81
+ ValueError: The `activation_string` does not correspond to a known
82
+ activation.
83
+ """
84
+ # We assume that anything that's not a string is already an activation
85
+ # function, so we just return it.
86
+ if not isinstance(activation_string, six.string_types):
87
+ return activation_string
88
+
89
+ if not activation_string:
90
+ return None
91
+
92
+ act = activation_string.lower()
93
+ if act == 'linear':
94
+ return None
95
+ elif act == 'relu':
96
+ return tf.nn.relu
97
+ elif act == 'gelu':
98
+ return gelu
99
+ elif act == 'leaky_relu':
100
+ return tf.nn.leaky_relu
101
+ elif act == 'prelu':
102
+ if len(kwargs) == 0:
103
+ return tf.nn.leaky_relu
104
+ return tf.keras.layers.PReLU(**kwargs)
105
+ elif act == 'dice':
106
+ return lambda x, name='dice': dice(x, name=name, **kwargs)
107
+ elif act == 'elu':
108
+ return tf.nn.elu
109
+ elif act == 'selu':
110
+ return tf.nn.selu
111
+ elif act == 'tanh':
112
+ return tf.tanh
113
+ elif act == 'swish':
114
+ if tf.__version__ < '1.13.0':
115
+ return swish
116
+ return tf.nn.swish
117
+ elif act == 'sigmoid':
118
+ return tf.nn.sigmoid
119
+ else:
120
+ return load_by_path(activation_string)
@@ -0,0 +1,87 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import tensorflow as tf
5
+
6
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
7
+
8
+ if tf.__version__ >= '2.0':
9
+ tf = tf.compat.v1
10
+
11
+
12
+ def check_split(line, sep, requried_field_num, field_name=''):
13
+ assert sep, 'must have separator.' + (' field: %s.' %
14
+ field_name) if field_name else ''
15
+
16
+ for one_line in line:
17
+ field_num = len(one_line.split(sep))
18
+ if field_name:
19
+ assert_info = 'sep[%s] maybe invalid. field_num=%d, required_num=%d, field: %s, value: %s, ' \
20
+ 'please check separator and data.' % \
21
+ (sep, field_num, requried_field_num, field_name, one_line)
22
+ else:
23
+ assert_info = 'sep[%s] maybe invalid. field_num=%d, required_num=%d, current line is: %s, ' \
24
+ 'please check separator and data.' % \
25
+ (sep, field_num, requried_field_num, one_line)
26
+ assert field_num == requried_field_num, assert_info
27
+ return True
28
+
29
+
30
+ def check_string_to_number(field_vals, field_name):
31
+ for val in field_vals:
32
+ try:
33
+ float(val)
34
+ except: # noqa: E722
35
+ assert False, 'StringToNumber ERROR: cannot convert string_to_number, field: %s, value: %s. ' \
36
+ 'please check data.' % (field_name, val)
37
+ return True
38
+
39
+
40
+ def check_sequence(pipeline_config_path, features):
41
+ seq_att_groups = pipeline_config_path.model_config.seq_att_groups
42
+ if not seq_att_groups:
43
+ return
44
+ for seq_att_group in seq_att_groups:
45
+ seq_att_maps = seq_att_group.seq_att_map
46
+ if not seq_att_maps:
47
+ return
48
+ for seq_att_map in seq_att_maps:
49
+ assert len(seq_att_map.key) == len(seq_att_map.hist_seq), \
50
+ 'The size of hist_seq must equal to the size of key in one seq_att_map.'
51
+ size_list = []
52
+ for hist_seq in seq_att_map.hist_seq:
53
+ cur_seq_size = len(features[hist_seq].values)
54
+ size_list.append(cur_seq_size)
55
+ hist_seqs = ' '.join(seq_att_map.hist_seq)
56
+ assert len(set(size_list)) == 1, \
57
+ 'SequenceFeature Error: The size in [%s] should be consistent. Please check input: [%s].' % \
58
+ (hist_seqs, hist_seqs)
59
+
60
+
61
+ def check_env_and_input_path(pipeline_config, input_path):
62
+ input_type = pipeline_config.data_config.input_type
63
+ input_type_name = DatasetConfig.InputType.Name(input_type)
64
+ ignore_input_list = [
65
+ DatasetConfig.InputType.TFRecordInput,
66
+ DatasetConfig.InputType.BatchTFRecordInput,
67
+ DatasetConfig.InputType.KafkaInput,
68
+ DatasetConfig.InputType.DataHubInput,
69
+ DatasetConfig.InputType.HiveInput,
70
+ DatasetConfig.InputType.DummyInput,
71
+ ]
72
+ if input_type in ignore_input_list:
73
+ return True
74
+ assert_info = 'Current InputType is %s, InputPath is %s. Please check InputType and InputPath.' % \
75
+ (input_type_name, input_path)
76
+ if input_type_name.startswith('Odps'):
77
+ # is on pai
78
+ for path in input_path.split(','):
79
+ if not path.startswith('odps://'):
80
+ assert False, assert_info
81
+ return True
82
+ else:
83
+ # local or ds
84
+ for path in input_path.split(','):
85
+ if path.startswith('odps://'):
86
+ assert False, assert_info
87
+ return True
@@ -0,0 +1,14 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # Date: 2019-10-12
4
+ # util to handle python2 python3 compatibility
5
+
6
+ import sys
7
+
8
+
9
+ def in_python2():
10
+ return sys.version_info[0] == 2
11
+
12
+
13
+ def in_python3():
14
+ return sys.version_info[0] == 3