easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,144 @@
1
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Kafka Dataset."""
16
+
17
+ import logging
18
+ import traceback
19
+
20
+ from tensorflow.python.data.ops import dataset_ops
21
+ from tensorflow.python.framework import dtypes
22
+ from tensorflow.python.framework import ops
23
+ from tensorflow.python.framework import tensor_shape
24
+
25
+ try:
26
+ from easy_rec.python.ops import gen_kafka_ops
27
+ except ImportError:
28
+ logging.warning('failed to import gen_kafka_ops: %s' % traceback.format_exc())
29
+
30
+
31
class KafkaDataset(dataset_ops.Dataset):
  """A Kafka Dataset that consumes the message.

  Every element is built from scalar `tf.string` tensors:
    * by default, just the message value;
    * a 2-tuple when exactly one of `message_key` / `message_offset` is set;
    * a 3-tuple when both are set.
  NOTE(review): the component order is decided by the `io_kafka_dataset_v2`
  kernel; KafkaInput._parse_csv consumes it as (value, key, offset) -- confirm
  against the op implementation.
  """

  def __init__(self,
               topics,
               servers='localhost',
               group='',
               eof=False,
               timeout=1000,
               config_global=None,
               config_topic=None,
               message_key=False,
               message_offset=False):
    """Create a KafkaDataset.

    Args:
      topics: A `tf.string` tensor containing one or more subscriptions,
        in the format of [topic:partition:offset:length],
        by default length is -1 for unlimited.
      servers: The bootstrap server(s), converted to a `tf.string` tensor.
      group: The consumer group id.
      eof: If True, the kafka reader will stop on EOF.
      timeout: The timeout value for the Kafka Consumer to wait
        (in millisecond).
      config_global: A `tf.string` tensor containing global configuration
        properties in [Key=Value] format,
        eg. ["enable.auto.commit=false",
             "heartbeat.interval.ms=2000"],
        please refer to 'Global configuration properties'
        in librdkafka doc. None or empty means no extra properties.
      config_topic: A `tf.string` tensor containing topic configuration
        properties in [Key=Value] format,
        eg. ["auto.offset.reset=earliest"],
        please refer to 'Topic configuration properties'
        in librdkafka doc. None or empty means no extra properties.
      message_key: If True, each element also contains the message key.
        Any truthy value is treated as True.
      message_offset: If True, each element also contains the message offset.
        Any truthy value is treated as True.
    """
    self._topics = ops.convert_to_tensor(
        topics, dtype=dtypes.string, name='topics')
    self._servers = ops.convert_to_tensor(
        servers, dtype=dtypes.string, name='servers')
    self._group = ops.convert_to_tensor(
        group, dtype=dtypes.string, name='group')
    self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name='eof')
    self._timeout = ops.convert_to_tensor(
        timeout, dtype=dtypes.int64, name='timeout')
    self._config_global = ops.convert_to_tensor(
        config_global if config_global else [],
        dtype=dtypes.string,
        name='config_global')
    self._config_topic = ops.convert_to_tensor(
        config_topic if config_topic else [],
        dtype=dtypes.string,
        name='config_topic')
    # Normalize to real bools. The op attributes are booleans, and the
    # element-structure helpers below count outputs from these flags, so a
    # truthy non-bool (e.g. 2) must not change the reported arity.
    self._message_key = bool(message_key)
    self._message_offset = bool(message_offset)
    # Keep this call last: the attributes above must exist before the base
    # constructor runs, because it may build the variant tensor from them.
    super(KafkaDataset, self).__init__()

  def _inputs(self):
    """Return no upstream datasets; this dataset is a source."""
    return []

  def _as_variant_tensor(self):
    """Build the dataset variant via the custom `io_kafka_dataset_v2` op."""
    return gen_kafka_ops.io_kafka_dataset_v2(
        self._topics,
        self._servers,
        self._group,
        self._eof,
        self._timeout,
        self._config_global,
        self._config_topic,
        self._message_key,
        self._message_offset,
    )

  def _num_outputs(self):
    """Number of tensors per element: the value plus optional key/offset."""
    return 1 + int(self._message_key) + int(self._message_offset)

  def _nest(self, make_component):
    """Arrange `make_component()` results like one dataset element.

    Returns a single component when the element is just the message value,
    otherwise a tuple with one component per output tensor.
    """
    num_outputs = self._num_outputs()
    if num_outputs == 1:
      return make_component()
    return tuple(make_component() for _ in range(num_outputs))

  @property
  def output_classes(self):
    return self._nest(lambda: ops.Tensor)

  @property
  def output_shapes(self):
    return self._nest(lambda: tensor_shape.TensorShape([]))

  @property
  def output_types(self):
    return self._nest(lambda: dtypes.string)
128
+
129
+
130
def write_kafka_v2(message, topic, servers='localhost', name=None):
  """Write a single message to Kafka through the `io_write_kafka_v2` op.

  Args:
    message: A `Tensor` of type `string`. 0-D. The payload to send.
    topic: A `tf.string` tensor naming one subscription, formatted as
      topic:partition.
    servers: The bootstrap server(s) to connect to.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `string`. 0-D.
  """
  op_kwargs = dict(message=message, topic=topic, servers=servers, name=name)
  write_op = gen_kafka_ops.io_write_kafka_v2
  return write_op(**op_kwargs)
@@ -0,0 +1,235 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import json
4
+ import logging
5
+ import traceback
6
+
7
+ import six
8
+ import tensorflow as tf
9
+
10
+ from easy_rec.python.input.input import Input
11
+ from easy_rec.python.input.kafka_dataset import KafkaDataset
12
+ from easy_rec.python.utils.config_util import parse_time
13
+
14
+ if tf.__version__.startswith('1.'):
15
+ from tensorflow.python.platform import gfile
16
+ else:
17
+ import tensorflow.io.gfile as gfile
18
+
19
+ try:
20
+ from kafka import KafkaConsumer, TopicPartition
21
+ except ImportError:
22
+ logging.warning(
23
+ 'kafka-python is not installed[%s]. You can install it by: pip install kafka-python'
24
+ % traceback.format_exc())
25
+
26
# tf.data's ignore_errors transformation lives in different modules in TF1
# and TF2. Compare the numeric major version: a plain string comparison
# orders versions lexicographically (e.g. '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  ignore_errors = tf.data.experimental.ignore_errors()
  # The rest of this module uses the TF1-style API (decode_csv, py_func,
  # estimator.ModeKeys), so rebind `tf` to the v1 compatibility module.
  tf = tf.compat.v1
else:
  ignore_errors = tf.contrib.data.ignore_errors()
31
+
32
+
33
class KafkaInput(Input):
  """Input that reads CSV-formatted messages from a Kafka topic.

  Partitions of `kafka_config.topic` are assigned to workers by
  `partition_id % task_num == task_index`.

  Each partition's starting offset comes from one of:
    * the `offset_time` or `offset_info` field of kafka_config's `offset`
      oneof;
    * a checkpoint's `.offset` file (see `restore`).
  Partitions without a recorded offset start at offset 0.

  Every parsed batch carries a DATA_OFFSET field: a JSON string mapping
  partition id to the largest offset this task has consumed so far.
  """

  DATA_OFFSET = 'DATA_OFFSET'

  def __init__(self,
               data_config,
               feature_config,
               kafka_config,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    """Discover the topic's partitions and starting offsets.

    Args:
      data_config: data config, forwarded to `Input`.
      feature_config: feature config, forwarded to `Input`.
      kafka_config: Kafka source config (server, topic, group, `offset`
        oneof, config_global, config_topic), or None to skip Kafka setup.
      task_index: index of this worker.
      task_num: total number of workers.
      check_mode: forwarded to `Input`.
      pipeline_config: forwarded to `Input`.

    Raises:
      ValueError: if no partitions are found for the topic, or the `offset`
        oneof holds an unsupported field.
    """
    super(KafkaInput,
          self).__init__(data_config, feature_config, '', task_index, task_num,
                         check_mode, pipeline_config)
    self._kafka = kafka_config
    # partition id -> offset to start consuming from.
    self._offset_dict = {}
    # partition id -> largest offset consumed by this task (see _parse_csv).
    self._task_offset_dict = {}
    if self._kafka is None:
      return

    consumer = KafkaConsumer(
        group_id='kafka_dataset_consumer',
        bootstrap_servers=[self._kafka.server],
        api_version_auto_timeout_ms=60000)  # in milliseconds
    try:
      partitions = consumer.partitions_for_topic(self._kafka.topic)
      if not partitions:
        # partitions_for_topic may return None, e.g. for an unknown topic.
        raise ValueError('no partitions found for kafka topic[%s] on [%s]' %
                         (self._kafka.topic, self._kafka.server))
      self._num_partition = len(partitions)
      logging.info('all partitions[%d]: %s' %
                   (self._num_partition, partitions))
      self._init_offsets(consumer, partitions)
    finally:
      # The consumer is only used for metadata/offset lookups; release its
      # connections instead of leaking them.
      consumer.close()

  def _init_offsets(self, consumer, partitions):
    """Fill self._offset_dict according to kafka_config's `offset` oneof.

    Args:
      consumer: a kafka.KafkaConsumer used for offset lookups.
      partitions: partition ids of self._kafka.topic.

    Raises:
      ValueError: if the oneof holds an unsupported field.
    """
    offset_type = self._kafka.WhichOneof('offset')
    if offset_type is None:
      return

    if offset_type == 'offset_time':
      ts = parse_time(self._kafka.offset_time)
      input_map = {
          TopicPartition(partition=part_id, topic=self._kafka.topic):
          ts * 1000 for part_id in partitions
      }
      part_offsets = consumer.offsets_for_times(input_map)
      # part_offsets is a dictionary:
      # {
      #   TopicPartition(topic=u'kafka_data_20220408', partition=0):
      #       OffsetAndTimestamp(offset=2, timestamp=1650611437895)
      # }
      # A partition maps to None when it has no message at or after ts.
      # Those partitions start at their current end offset instead.
      missing = [part for part in part_offsets if part_offsets[part] is None]
      end_offsets = consumer.end_offsets(missing) if missing else {}
      for part in part_offsets:
        found = part_offsets[part]
        if found is None:
          self._offset_dict[part.partition] = end_offsets[part]
          logging.info(
              'No message after timestamp, topic[%s], partition[%d], '
              'timestamp[%ss], use end offset[%d]' %
              (self._kafka.topic, part.partition, ts, end_offsets[part]))
          continue
        self._offset_dict[part.partition] = found.offset
        logging.info(
            'Find offset by time, topic[%s], partition[%d], timestamp[%ss], offset[%d], offset_timestamp[%dms]'
            % (self._kafka.topic, part.partition, ts, found.offset,
               found.timestamp))
    elif offset_type == 'offset_info':
      # JSON object: partition id (as string) -> offset.
      offset_dict = json.loads(self._kafka.offset_info)
      for part in offset_dict:
        self._offset_dict[int(part)] = offset_dict[part]
    else:
      raise ValueError('invalid offset_type: %s' % offset_type)

  def _preprocess(self, field_dict):
    """Run the base preprocessing and carry DATA_OFFSET through."""
    output_dict = super(KafkaInput, self)._preprocess(field_dict)

    # append offset fields
    if Input.DATA_OFFSET in field_dict:
      output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]

    # for _get_features to include DATA_OFFSET
    if Input.DATA_OFFSET not in self._appended_fields:
      self._appended_fields.append(Input.DATA_OFFSET)

    return output_dict

  def _parse_csv(self, line, message_key, message_offset):
    """Decode a batch of CSV message values into an input dict.

    Args:
      line: batch of message values (CSV lines).
      message_key: batch of message keys (unused).
      message_offset: batch of 'partition:offset' strings.

    Returns:
      dict of field name -> tensor, plus DATA_OFFSET holding a JSON string of
      this task's consumed offsets.
    """
    record_defaults = [
        self.get_type_defaults(t, v)
        for t, v in zip(self._input_field_types, self._input_field_defaults)
    ]

    fields = tf.decode_csv(
        line,
        use_quote_delim=False,
        field_delim=self._data_config.separator,
        record_defaults=record_defaults,
        name='decode_csv')

    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}

    for x in self._label_fids:
      inputs[self._input_fields[x]] = fields[x]

    # record current offset
    def _parse_offset(message_offset):
      for kv in message_offset:
        if six.PY3:
          kv = kv.decode('utf-8')
        k, v = kv.split(':')
        k = int(k)
        v = int(v)
        if k not in self._task_offset_dict or v > self._task_offset_dict[k]:
          self._task_offset_dict[k] = v
      return json.dumps(self._task_offset_dict)

    inputs[Input.DATA_OFFSET] = tf.py_func(_parse_offset, [message_offset],
                                           tf.string)
    return inputs

  def restore(self, checkpoint_path):
    """Load starting offsets from `<checkpoint_path>.offset`, if present.

    The file holds a JSON object mapping partition id to offset. When it
    exists, it replaces any offsets determined in __init__.
    """
    if checkpoint_path is None:
      return

    offset_path = checkpoint_path + '.offset'
    if not gfile.Exists(offset_path):
      return

    logging.info('will restore kafka offset from %s' % offset_path)
    with gfile.GFile(offset_path, 'r') as fin:
      offset_dict = json.load(fin)
    self._offset_dict = {}
    for k in offset_dict:
      v = offset_dict[k]
      k = int(k)
      # Distinct string keys (e.g. '1' and '01') may map to one partition;
      # keep the larger offset.
      if k not in self._offset_dict or v > self._offset_dict[k]:
        self._offset_dict[k] = v

  def _get_topics(self):
    """Return the 'topic:partition:offset' subscriptions for this task."""
    task_num = self._task_num
    task_index = self._task_index
    if self._data_config.chief_redundant and self._mode == tf.estimator.ModeKeys.TRAIN:
      task_index = max(task_index - 1, 0)
      task_num = max(task_num - 1, 1)

    topics = []
    self._task_offset_dict = {}
    for part_id in range(self._num_partition):
      if (part_id % task_num) == task_index:
        offset = self._offset_dict.get(part_id, 0)
        topics.append('%s:%d:%d' % (self._kafka.topic, part_id, offset))
        self._task_offset_dict[part_id] = offset
    logging.info('assigned topic partitions: %s' % (','.join(topics)))
    assert len(
        topics) > 0, 'no partitions are assigned for this task(%d/%d)' % (
            self._task_index, self._task_num)
    return topics

  def _build(self, mode, params):
    """Build the tf.data pipeline over this task's Kafka partitions."""
    is_train = mode == tf.estimator.ModeKeys.TRAIN
    # Validate before _get_topics(), which dereferences self._kafka.
    if is_train:
      assert self._kafka is not None, 'kafka_train_input is not set.'
    else:
      assert self._kafka is not None, 'kafka_eval_input is not set.'

    num_parallel_calls = self._data_config.num_parallel_calls
    task_topics = self._get_topics()
    kafka = self._kafka
    logging.info(
        '%s kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
        % ('train' if is_train else 'eval', kafka.server, kafka.topic,
           self._task_num, self._task_index, task_topics))

    dataset = KafkaDataset(
        task_topics,
        servers=kafka.server,
        group=kafka.group,
        eof=False,
        config_global=list(kafka.config_global),
        config_topic=list(kafka.config_topic),
        message_key=True,
        message_offset=True)

    if is_train and self._data_config.shuffle:
      dataset = dataset.shuffle(
          self._data_config.shuffle_buffer_size,
          seed=2020,
          reshuffle_each_iteration=True)

    dataset = dataset.batch(self._data_config.batch_size)
    dataset = dataset.map(
        self._parse_csv, num_parallel_calls=num_parallel_calls)
    if self._data_config.ignore_error:
      dataset = dataset.apply(ignore_errors)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    dataset = dataset.map(
        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)

    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: (self._get_features(x)))
    return dataset
@@ -0,0 +1,317 @@
1
+ import logging
2
+ import multiprocessing
3
+ import queue
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+
9
+ def start_data_proc(task_index,
10
+ task_num,
11
+ num_proc,
12
+ file_que,
13
+ data_que,
14
+ proc_start_que,
15
+ proc_stop_que,
16
+ batch_size,
17
+ label_fields,
18
+ sparse_fea_names,
19
+ dense_fea_names,
20
+ dense_fea_cfgs,
21
+ reserve_fields,
22
+ drop_remainder,
23
+ need_pack=True):
24
+ mp_ctxt = multiprocessing.get_context('spawn')
25
+ proc_arr = []
26
+ for proc_id in range(num_proc):
27
+ proc = mp_ctxt.Process(
28
+ target=load_data_proc,
29
+ args=(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
30
+ batch_size, label_fields, sparse_fea_names, dense_fea_names,
31
+ dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
32
+ task_num, need_pack),
33
+ name='task_%d_data_proc_%d' % (task_index, proc_id))
34
+ proc.daemon = True
35
+ proc.start()
36
+ proc_arr.append(proc)
37
+ return proc_arr
38
+
39
+
40
+ def _should_stop(proc_stop_que):
41
+ try:
42
+ proc_stop_que.get(block=False)
43
+ logging.info('data_proc stop signal received')
44
+ proc_stop_que.close()
45
+ return True
46
+ except queue.Empty:
47
+ return False
48
+ except ValueError:
49
+ return True
50
+ except AssertionError:
51
+ return True
52
+
53
+
54
+ def _add_to_que(data_dict, data_que, proc_stop_que):
55
+ while True:
56
+ try:
57
+ data_que.put(data_dict, timeout=5)
58
+ return True
59
+ except queue.Full:
60
+ logging.warning('data_que is full')
61
+ if _should_stop(proc_stop_que):
62
+ return False
63
+ except ValueError:
64
+ logging.warning('data_que is closed')
65
+ return False
66
+ except AssertionError:
67
+ logging.warning('data_que is closed')
68
+ return False
69
+
70
+
71
+ def _get_one_file(file_que, proc_stop_que):
72
+ while True:
73
+ try:
74
+ input_file = file_que.get(timeout=1)
75
+ return input_file
76
+ except queue.Empty:
77
+ pass
78
+ return None
79
+
80
+
81
+ def _pack_sparse_feas(data_dict, sparse_fea_names):
82
+ fea_val_arr = []
83
+ fea_len_arr = []
84
+ for fea_name in sparse_fea_names:
85
+ fea_len_arr.append(data_dict[fea_name][0])
86
+ fea_val_arr.append(data_dict[fea_name][1])
87
+ del data_dict[fea_name]
88
+ fea_lens = np.concatenate(fea_len_arr, axis=0)
89
+ fea_vals = np.concatenate(fea_val_arr, axis=0)
90
+ data_dict['sparse_fea'] = (fea_lens, fea_vals)
91
+
92
+
93
+ def _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
94
+ fea_val_arr = []
95
+ for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
96
+ fea_val_arr.append(data_dict[fea_name].reshape([-1, fea_cfg.raw_input_dim]))
97
+ del data_dict[fea_name]
98
+ fea_vals = np.concatenate(fea_val_arr, axis=1)
99
+ data_dict['dense_fea'] = fea_vals
100
+
101
+
102
+ def _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
103
+ for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
104
+ data_dict[fea_name] = data_dict[fea_name].reshape(
105
+ [-1, fea_cfg.raw_input_dim])
106
+
107
+
108
+ def _load_dense(input_data, field_names, sid, eid, dense_dict):
109
+ for k in field_names:
110
+ if isinstance(input_data[k][0], np.ndarray):
111
+ np_dtype = type(input_data[k][sid][0])
112
+ dense_dict[k] = np.array([x[0] for x in input_data[k][sid:eid]],
113
+ dtype=np_dtype)
114
+ else:
115
+ dense_dict[k] = input_data[k][sid:eid].to_numpy()
116
+
117
+
118
+ def _load_and_pad_dense(input_data, field_names, sid, dense_dict,
119
+ part_dense_dict, part_dense_dict_n, batch_size):
120
+ for k in field_names:
121
+ if isinstance(input_data[k][0], np.ndarray):
122
+ np_dtype = type(input_data[k][sid][0])
123
+ tmp_lbls = np.array([x[0] for x in input_data[k][sid:]], dtype=np_dtype)
124
+ else:
125
+ tmp_lbls = input_data[k][sid:].to_numpy()
126
+ if part_dense_dict is not None and k in part_dense_dict:
127
+ tmp_lbls = np.concatenate([part_dense_dict[k], tmp_lbls], axis=0)
128
+ if len(tmp_lbls) > batch_size:
129
+ dense_dict[k] = tmp_lbls[:batch_size]
130
+ part_dense_dict_n[k] = tmp_lbls[batch_size:]
131
+ elif len(tmp_lbls) == batch_size:
132
+ dense_dict[k] = tmp_lbls
133
+ else:
134
+ part_dense_dict_n[k] = tmp_lbls
135
+ else:
136
+ part_dense_dict_n[k] = tmp_lbls
137
+
138
+
139
+ def load_data_proc(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
140
+ batch_size, label_fields, sparse_fea_names, dense_fea_names,
141
+ dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
142
+ task_num, need_pack):
143
+ logging.info('data proc %d start, proc_start_que=%s' %
144
+ (proc_id, proc_start_que.qsize()))
145
+ proc_start_que.get()
146
+ effective_fields = sparse_fea_names + dense_fea_names
147
+ all_fields = effective_fields
148
+ if label_fields is not None:
149
+ all_fields = all_fields + label_fields
150
+ if reserve_fields is not None:
151
+ for tmp in reserve_fields:
152
+ if tmp not in all_fields:
153
+ all_fields.append(tmp)
154
+ logging.info('data proc %d start, file_que.qsize=%d' %
155
+ (proc_id, file_que.qsize()))
156
+ num_files = 0
157
+ part_data_dict = {}
158
+
159
+ is_good = True
160
+ total_batch_cnt = 0
161
+ total_sample_cnt = 0
162
+ while is_good:
163
+ if _should_stop(proc_stop_que):
164
+ is_good = False
165
+ break
166
+ input_file = _get_one_file(file_que, proc_stop_que)
167
+ if input_file is None:
168
+ break
169
+ num_files += 1
170
+ input_data = pd.read_parquet(input_file, columns=all_fields)
171
+ data_len = len(input_data[all_fields[0]])
172
+ total_sample_cnt += data_len
173
+ batch_num = int(data_len / batch_size)
174
+ res_num = data_len % batch_size
175
+
176
+ sid = 0
177
+ for batch_id in range(batch_num):
178
+ eid = sid + batch_size
179
+ data_dict = {}
180
+
181
+ if label_fields is not None and len(label_fields) > 0:
182
+ _load_dense(input_data, label_fields, sid, eid, data_dict)
183
+
184
+ if reserve_fields is not None and len(reserve_fields) > 0:
185
+ data_dict['reserve'] = {}
186
+ _load_dense(input_data, reserve_fields, sid, eid, data_dict['reserve'])
187
+
188
+ if len(sparse_fea_names) > 0:
189
+ for k in sparse_fea_names:
190
+ val = input_data[k][sid:eid]
191
+ if isinstance(input_data[k][sid], np.ndarray):
192
+ all_lens = np.array([len(x) for x in val], dtype=np.int32)
193
+ all_vals = np.concatenate(val.to_numpy())
194
+ else:
195
+ all_lens = np.ones([len(val)], dtype=np.int32)
196
+ all_vals = val.to_numpy()
197
+ assert np.sum(all_lens) == len(
198
+ all_vals), 'len(all_vals)=%d np.sum(all_lens)=%d' % (
199
+ len(all_vals), np.sum(all_lens))
200
+ data_dict[k] = (all_lens, all_vals)
201
+
202
+ if len(dense_fea_names) > 0:
203
+ _load_dense(input_data, dense_fea_names, sid, eid, data_dict)
204
+
205
+ if need_pack:
206
+ if len(sparse_fea_names) > 0:
207
+ _pack_sparse_feas(data_dict, sparse_fea_names)
208
+ if len(dense_fea_names) > 0:
209
+ _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
210
+ else:
211
+ if len(dense_fea_names) > 0:
212
+ _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
213
+ # logging.info('task_index=%d sid=%d eid=%d total_len=%d' % (task_index, sid, eid,
214
+ # len(data_dict['sparse_fea'][1])))
215
+ if not _add_to_que(data_dict, data_que, proc_stop_que):
216
+ logging.info('add to que failed')
217
+ is_good = False
218
+ break
219
+ total_batch_cnt += 1
220
+ sid += batch_size
221
+
222
+ if res_num > 0 and is_good:
223
+ data_dict = {}
224
+ part_data_dict_n = {}
225
+
226
+ if label_fields is not None and len(label_fields) > 0:
227
+ _load_and_pad_dense(input_data, label_fields, sid, data_dict,
228
+ part_data_dict, part_data_dict_n, batch_size)
229
+
230
+ if reserve_fields is not None and len(reserve_fields) > 0:
231
+ data_dict['reserve'] = {}
232
+ part_data_dict_n['reserve'] = {}
233
+ _load_and_pad_dense(input_data, label_fields, sid, data_dict['reserve'],
234
+ part_data_dict['reserve'],
235
+ part_data_dict_n['reserve'], batch_size)
236
+
237
+ if len(dense_fea_names) > 0:
238
+ _load_and_pad_dense(input_data, dense_fea_names, sid, data_dict,
239
+ part_data_dict, part_data_dict_n, batch_size)
240
+
241
+ if len(sparse_fea_names) > 0:
242
+ for k in sparse_fea_names:
243
+ val = input_data[k][sid:]
244
+
245
+ if isinstance(input_data[k][sid], np.ndarray):
246
+ all_lens = np.array([len(x) for x in val], dtype=np.int32)
247
+ all_vals = np.concatenate(val.to_numpy())
248
+ else:
249
+ all_lens = np.ones([len(val)], dtype=np.int32)
250
+ all_vals = val.to_numpy()
251
+
252
+ if part_data_dict is not None and k in part_data_dict:
253
+ tmp_lens = np.concatenate([part_data_dict[k][0], all_lens], axis=0)
254
+ tmp_vals = np.concatenate([part_data_dict[k][1], all_vals], axis=0)
255
+ if len(tmp_lens) > batch_size:
256
+ tmp_res_lens = tmp_lens[batch_size:]
257
+ tmp_lens = tmp_lens[:batch_size]
258
+ tmp_num_elems = np.sum(tmp_lens)
259
+ tmp_res_vals = tmp_vals[tmp_num_elems:]
260
+ tmp_vals = tmp_vals[:tmp_num_elems]
261
+ part_data_dict_n[k] = (tmp_res_lens, tmp_res_vals)
262
+ data_dict[k] = (tmp_lens, tmp_vals)
263
+ elif len(tmp_lens) == batch_size:
264
+ data_dict[k] = (tmp_lens, tmp_vals)
265
+ else:
266
+ part_data_dict_n[k] = (tmp_lens, tmp_vals)
267
+ else:
268
+ part_data_dict_n[k] = (all_lens, all_vals)
269
+
270
+ if effective_fields[0] in data_dict:
271
+ if need_pack:
272
+ if len(sparse_fea_names) > 0:
273
+ _pack_sparse_feas(data_dict, sparse_fea_names)
274
+ if len(dense_fea_names) > 0:
275
+ _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
276
+ else:
277
+ if len(dense_fea_names) > 0:
278
+ _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
279
+ if not _add_to_que(data_dict, data_que, proc_stop_que):
280
+ logging.info('add to que failed')
281
+ is_good = False
282
+ break
283
+ total_batch_cnt += 1
284
+ part_data_dict = part_data_dict_n
285
+ if len(part_data_dict) > 0 and is_good:
286
+ batch_len = len(part_data_dict[effective_fields[0]][0])
287
+ if not drop_remainder:
288
+ if need_pack:
289
+ if len(sparse_fea_names) > 0:
290
+ _pack_sparse_feas(part_data_dict, sparse_fea_names)
291
+ if len(dense_fea_names) > 0:
292
+ _pack_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
293
+ else:
294
+ if len(dense_fea_names) > 0:
295
+ _reshape_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
296
+ logging.info('remainder batch: %s sample_num=%d' %
297
+ (','.join(part_data_dict.keys()), batch_len))
298
+ _add_to_que(part_data_dict, data_que, proc_stop_que)
299
+ total_batch_cnt += 1
300
+ else:
301
+ logging.warning('drop remain %d samples as drop_remainder is set' %
302
+ batch_len)
303
+ if is_good:
304
+ is_good = _add_to_que(None, data_que, proc_stop_que)
305
+ logging.info(
306
+ 'data_proc_id[%d]: is_good = %s, total_batch_cnt=%d, total_sample_cnt=%d'
307
+ % (proc_id, is_good, total_batch_cnt, total_sample_cnt))
308
+ data_que.close(wait_send_finish=is_good)
309
+
310
+ while not is_good:
311
+ try:
312
+ if file_que.get(timeout=1) is None:
313
+ break
314
+ except queue.Empty:
315
+ pass
316
+ file_que.close()
317
+ logging.info('data proc %d done, file_num=%d' % (proc_id, num_files))