easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,652 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Functions for reading and updating configuration files.
4
+
5
+ Such as Hyper parameter tuning or automatic feature expanding.
6
+ """
7
+
8
+ import datetime
9
+ import json
10
+ import logging
11
+ import os
12
+ import re
13
+ import sys
14
+
15
+ import numpy as np
16
+ import six
17
+ import tensorflow as tf
18
+ from google.protobuf import json_format
19
+ from google.protobuf import text_format
20
+ from tensorflow.python.lib.io import file_io
21
+
22
+ from easy_rec.python.protos import pipeline_pb2
23
+ from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
24
+ from easy_rec.python.utils import pai_util
25
+ from easy_rec.python.utils.hive_utils import HiveUtils
26
+
27
+ if tf.__version__ >= '2.0':
28
+ tf = tf.compat.v1
29
+
30
+
31
def search_pipeline_config(directory):
  """Locate the single pipeline config file below a directory.

  Args:
    directory: root directory that is walked with tf.gfile.Walk; every file
      whose extension is `.config` is considered a candidate.

  Returns:
    Path of the only `.config` file found.

  Raises:
    ValueError: if no `.config` file is found, or if more than one is found.
  """
  config_paths = []
  for root, _, files in tf.gfile.Walk(directory):
    for file_name in files:
      if os.path.splitext(file_name)[1] == '.config':
        config_paths.append(os.path.join(root, file_name))

  if not config_paths:
    raise ValueError('config is not found in directory %s' % directory)
  if len(config_paths) > 1:
    # the candidates are listed so the caller can see which ones collide
    raise ValueError('multiple configs found in directory %s: %s' %
                     (directory, ','.join(config_paths)))

  logging.info('use pipeline config: %s' % config_paths[0])
  return config_paths[0]
44
+
45
+
46
def get_configs_from_pipeline_file(pipeline_config_path, auto_expand=True):
  """Reads config from a file containing pipeline_pb2.EasyRecConfig.

  Args:
    pipeline_config_path: Path to a pipeline_pb2.EasyRecConfig stored as
      text proto (`.config`) or json (`.json`). If an EasyRecConfig object
      is passed instead of a path, it is returned unchanged.
    auto_expand: if True, the parsed config is passed through
      auto_expand_share_feature_configs before it is returned.

  Returns:
    A pipeline_pb2.EasyRecConfig.

  Raises:
    AssertionError: if the path does not exist or its extension is neither
      `.config` nor `.json`.
  """
  if isinstance(pipeline_config_path, pipeline_pb2.EasyRecConfig):
    return pipeline_config_path

  # Raised explicitly instead of using `assert`, so that the checks are not
  # stripped by `python -O`; the exception type is kept for existing callers.
  if not tf.gfile.Exists(pipeline_config_path):
    raise AssertionError(
        'pipeline_config_path [%s] not exists' % pipeline_config_path)

  is_text_proto = pipeline_config_path.endswith('.config')
  is_json = pipeline_config_path.endswith('.json')
  if not (is_text_proto or is_json):
    raise AssertionError(
        'invalid file format(%s), currently support formats: .config(prototxt) .json'
        % pipeline_config_path)

  pipeline_config = pipeline_pb2.EasyRecConfig()
  with tf.gfile.GFile(pipeline_config_path, 'r') as f:
    config_str = f.read()
  if is_text_proto:
    text_format.Merge(config_str, pipeline_config)
  else:
    json_format.Parse(config_str, pipeline_config)

  if auto_expand:
    return auto_expand_share_feature_configs(pipeline_config)
  return pipeline_config
79
+
80
+
81
def auto_expand_share_feature_configs(pipeline_config):
  """Turn `shared_names` of feature configs into separate feature configs.

  For every feature config with a non-empty `shared_names`, one new
  FeatureConfig per shared name is appended to the pipeline config. Each new
  config is a copy of the sharing config whose `input_names` holds only that
  name. The `shared_names` of the sharing config is cleared afterwards; its
  own `input_names` is left as is.

  Args:
    pipeline_config: config object exposing `feature_configs`,
      `feature_config.features` and `data_config.auto_expand_input_fields`;
      it is modified in place.

  Returns:
    The same pipeline_config object.
  """
  feature_configs = get_compatible_feature_configs(pipeline_config)
  # Iterate over a snapshot: the loop appends new configs to the very
  # container it reads from, and those new configs need no expansion.
  for share_config in list(feature_configs):
    if len(share_config.shared_names) == 0:
      continue

    # auto expand all shared_names
    input_names = []
    for shared_name in share_config.shared_names:
      if pipeline_config.data_config.auto_expand_input_fields:
        input_names.extend(auto_expand_names(shared_name))
      else:
        input_names.append(shared_name)

    # the names were collected above, so the field can be emptied now
    share_config.ClearField('shared_names')

    # template: the sharing config without any input name
    template = FeatureConfig()
    template.CopyFrom(share_config)
    template.ClearField('input_names')

    # generate one config for each collected name
    for input_name in input_names:
      new_config = FeatureConfig()
      new_config.CopyFrom(template)
      new_config.input_names.append(input_name)
      if pipeline_config.feature_configs:
        pipeline_config.feature_configs.append(new_config)
      else:
        pipeline_config.feature_config.features.append(new_config)
  return pipeline_config
114
+
115
+
116
+ # def auto_expand_names(input_name):
117
+ # """Auto expand field[1-3] to field1, field2, field3.
118
+ #
119
+ # Args:
120
+ # input_name: a string pattern like field[1-3]
121
+ #
122
+ # Returns:
123
+ # a string list of the expanded names
124
+ # """
125
+ # match_obj = re.match(r'([a-zA-Z_]+)\[([0-9]+)-([0-9]+)\]', input_name)
126
+ # if match_obj:
127
+ # prefix = match_obj.group(1)
128
+ # sid = int(match_obj.group(2))
129
+ # eid = int(match_obj.group(3)) + 1
130
+ # input_name = ['%s%d' % (prefix, tid) for tid in range(sid, eid)]
131
+ # else:
132
+ # input_name = [input_name]
133
+ # return input_name
134
+
135
+
136
def auto_expand_names(input_name):
  """Auto expand field[1-3] to field1, field2, field3.

  A suffix after the closing bracket is kept on every expanded name, e.g.
  field[1-2]_id expands to field1_id, field2_id.

  Args:
    input_name: a string pattern like field[1-3] or field[1-3]_suffix; the
      prefix may contain letters and underscores, the suffix letters,
      digits and underscores.

  Returns:
    a string list of the expanded names (both range ends included); if
    input_name does not match the pattern as a whole, it is returned
    unchanged as a single-element list.
  Todo:
    could be extended to support more complicated patterns
  """
  # \Z anchors the pattern to the end of the name, so a name is either
  # expanded completely or left untouched; it is never silently truncated.
  match_obj = re.match(
      r'([a-zA-Z_]+)\[([0-9]+)-([0-9]+)\]([a-zA-Z0-9_]*)\Z', input_name)
  if not match_obj:
    return [input_name]

  prefix, start, end, suffix = match_obj.groups()
  return [
      '%s%d%s' % (prefix, idx, suffix)
      for idx in range(int(start),
                       int(end) + 1)
  ]
167
+
168
+
169
def create_pipeline_proto_from_configs(configs):
  """Assemble a pipeline_pb2.EasyRecConfig out of a dictionary of configs.

  Args:
    configs: Dictionary of configs. The keys `model`, `train_config`,
      `train_input_config`, `eval_config` and `eval_input_config` are
      required; `graph_rewriter_config` is copied only when present.

  Returns:
    A pipeline_pb2.EasyRecConfig populated from the given configs.
  """
  # (field of the pipeline proto, key in configs), copied in this order
  required_parts = (
      ('model', 'model'),
      ('train_config', 'train_config'),
      ('train_input_reader', 'train_input_config'),
      ('eval_config', 'eval_config'),
      ('eval_input_reader', 'eval_input_config'),
  )
  result = pipeline_pb2.EasyRecConfig()
  for field_name, config_key in required_parts:
    getattr(result, field_name).CopyFrom(configs[config_key])
  if 'graph_rewriter_config' in configs:
    result.graph_rewriter.CopyFrom(configs['graph_rewriter_config'])
  return result
190
+
191
+
192
def save_pipeline_config(pipeline_config,
                         directory,
                         filename='pipeline.config'):
  """Write a pipeline config into `directory` as a text file.

  Args:
    pipeline_config: the protobuf message to save; it is handed to
      save_message unchanged.
    directory: The model directory into which the pipeline config file will
      be saved; created first if it does not exist.
    filename: name of the file inside `directory`.
  """
  directory_missing = not file_io.file_exists(directory)
  if directory_missing:
    file_io.recursive_create_dir(directory)
  target_path = os.path.join(directory, filename)
  save_message(pipeline_config, target_path)
208
+
209
+
210
+ def _get_basic_types():
211
+ dtypes = [
212
+ bool, int, str, float,
213
+ type(u''), np.float16, np.float32, np.float64, np.char, np.byte, np.uint8,
214
+ np.int8, np.int16, np.uint16, np.uint32, np.int32, np.uint64, np.int64,
215
+ bool, str
216
+ ]
217
+ if six.PY2:
218
+ dtypes.append(long) # noqa: F821
219
+
220
+ return dtypes
221
+
222
+
223
def edit_config(pipeline_config, edit_config_json):
  """Update params specified by automl.

  Each key of edit_config_json is a `.`-separated attribute path into
  pipeline_config (several paths may be joined by `;`), each value the new
  value (joined by `;` in the same order). A path token may select elements
  of a repeated field with `name[:]`, `name[1:10]`, `name[0]` or a
  condition such as `name[key=val]` (operators `>=`, `<=`, `>`, `<`, `=`).

  Args:
    pipeline_config: EasyRecConfig, modified in place
    edit_config_json: edit config json

  Returns:
    The same pipeline_config object.
  """

  def _type_convert(proto, val, parent=None):
    # Cast val to the type of the current value `proto`. When the cast
    # raises ValueError and a parent is given, val is looked up as an
    # attribute of the parent and must resolve to an int.
    if type(val) != type(proto):
      try:
        if isinstance(proto, bool):
          assert val in ['True', 'true', 'False', 'false']
          val = val in ['True', 'true']
        else:
          val = type(proto)(val)
      except ValueError as ex:
        if parent is None:
          raise ex
        assert isinstance(proto, int)
        val = getattr(parent, val)
        assert isinstance(val, int)
    return val

  def _get_attr(obj, attr, only_last=False):
    # Resolve an attribute path. Every result is a tuple
    #   (value, parent, attribute name or None, index or -1)
    # where name is None when the value is an element of a container and
    # index is -1 when the value is a plain attribute.
    # only_last means we only return the last element in paths array
    attr_toks = [x.strip() for x in attr.split('.') if x != '']
    paths = []
    objs = [obj]
    nobjs = []
    for key in attr_toks:
      # only the paths of the last token are kept
      paths = []
      for obj in objs:
        if '[' not in key:
          sub_obj = getattr(obj, key)
          paths.append((sub_obj, obj, key, -1))
          nobjs.append(sub_obj)
          continue

        pos = key.find('[')
        name, cond = key[:pos], key[pos + 1:]
        cond = cond[:-1]
        update_objs = getattr(obj, name)

        # select all update_objs
        if cond == ':':
          for tid, update_obj in enumerate(update_objs):
            # value first, then its container, like all other branches
            paths.append((update_obj, update_objs, None, tid))
            nobjs.append(update_obj)
          continue

        # select by range update_objs[1:10]
        if ':' in cond:
          colon_pos = cond.find(':')
          sid = cond[:colon_pos]
          sid = int(sid) if len(sid) > 0 else 0
          eid = cond[(colon_pos + 1):]
          eid = int(eid) if len(eid) > 0 else len(update_objs)
          for tid, update_obj in enumerate(update_objs[sid:eid]):
            paths.append((update_obj, update_objs, None, tid + sid))
            nobjs.append(update_obj)
          continue

        # for simple index update_objs[0]
        try:
          obj_id = int(cond)
        except ValueError:
          obj_id = None
        if obj_id is not None:
          selected = update_objs[obj_id]
          paths.append((selected, update_objs, None, obj_id))
          nobjs.append(selected)
          continue

        # for complex conditions a[optimizer.lr=20]
        op_func_map = {
            '>=': lambda x, y: x >= y,
            '<=': lambda x, y: x <= y,
            '<': lambda x, y: x < y,
            '>': lambda x, y: x > y,
            '=': lambda x, y: x == y
        }
        cond_key = None
        cond_val = None
        op_func = None
        # two-character operators are tried first so that `>=` is not
        # mistaken for `>` or `=`
        for op in ['>=', '<=', '>', '<', '=']:
          tmp_pos = cond.rfind(op)
          if tmp_pos != -1:
            cond_key = cond[:tmp_pos]
            cond_val = cond[(tmp_pos + len(op)):]
            op_func = op_func_map[op]
            break

        assert cond_key is not None, 'invalid cond: %s' % cond
        assert cond_val is not None, 'invalid cond: %s' % cond

        for tid, update_obj in enumerate(update_objs):
          tmp, tmp_parent, _, _ = _get_attr(
              update_obj, cond_key, only_last=True)

          typed_val = _type_convert(tmp, cond_val, tmp_parent)

          if op_func(tmp, typed_val):
            paths.append((update_obj, update_objs, None, tid))
            nobjs.append(update_obj)
      # exchange to prepare for parsing next token
      objs = nobjs
      nobjs = []
    if only_last:
      return paths[-1]
    return paths

  for param_keys in edit_config_json:
    # multiple keys/vals combination
    param_vals = edit_config_json[param_keys]
    param_vals = [x.strip() for x in str(param_vals).split(';')]
    param_keys = [x.strip() for x in str(param_keys).split(';')]
    for param_key, param_val in zip(param_keys, param_vals):
      tmp_paths = _get_attr(pipeline_config, param_key)
      # update a set of objs
      for tmp_val, tmp_obj, tmp_name, tmp_id in tmp_paths:
        # list and dict are not basic types, must be handle separately
        basic_types = _get_basic_types()
        if type(tmp_val) in basic_types:
          # simple type cast
          tmp_val = _type_convert(tmp_val, param_val, tmp_obj)
          if tmp_name is None:
            tmp_obj[tmp_id] = tmp_val
          else:
            setattr(tmp_obj, tmp_name, tmp_val)
        elif 'Scalar' in str(type(tmp_val)) and 'ClearField' in dir(tmp_obj):
          # repeated scalar field: replace its content by the parsed value
          tmp_obj.ClearField(tmp_name)
          text_format.Parse('%s:%s' % (tmp_name, param_val), tmp_obj)
        else:
          # anything else is reset and re-parsed from its text form
          tmp_val.Clear()
          param_val = param_val.strip()
          if param_val.startswith('{') and param_val.endswith('}'):
            param_val = param_val[1:-1]
          text_format.Parse(param_val, tmp_val)

  return pipeline_config
371
+
372
+
373
def save_message(protobuf_message, filename):
  """Saves a protobuf message as a text file.

  Args:
    protobuf_message: the protobuf message to serialize with
      text_format.MessageToString.
    filename: path of the output file; its parent directory is created
      when it does not exist.
  """
  directory = os.path.dirname(filename)
  # a bare file name has no directory part, so there is nothing to create
  if directory and not file_io.file_exists(directory):
    file_io.recursive_create_dir(directory)
  # as_utf8=True to make sure pbtxt is human readable when string contains chinese
  config_text = text_format.MessageToString(protobuf_message, as_utf8=True)
  with tf.gfile.Open(filename, 'wb') as f:
    logging.info('Writing protobuf message file to %s', filename)
    f.write(config_text)
388
+
389
+
390
def add_boundaries_to_config(pipeline_config, tables):
  """Fill feature boundaries of a pipeline config from a binning table.

  Every record of the table is read as (feature name, json string). The
  json is expected to hold a list under ['bin']['norm'] whose items have a
  'value' string; the split point is taken from the text after the first
  comma with the last character dropped (presumably an interval such as
  "(a,b]" - verify against the table producer). The last item is skipped.

  For each feature config whose first input name appears in the table, the
  boundaries are replaced by these split points and hash_bucket_size is set
  to 0; non-sequence features are switched to RawFeature as well.

  Args:
    pipeline_config: config whose feature configs are edited in place.
    tables: table(s) passed to common_io.table.TableReader.
  """
  import common_io

  feature_boundaries_info = {}
  reader = common_io.table.TableReader(tables, selected_cols='feature,json')
  try:
    while True:
      try:
        record = reader.read()
      except common_io.exception.OutOfRangeException:
        # raised by the reader once all records are consumed
        break
      raw_info = json.loads(record[0][1])
      bin_info = [
          float(info['value'].split(',')[1][:-1])
          for info in raw_info['bin']['norm'][:-1]
      ]
      feature_boundaries_info[record[0][0]] = bin_info
  finally:
    # close the reader on parse errors too, not only at the end of the table
    reader.close()

  logging.info('feature boundaries: %s' % feature_boundaries_info)
  feature_configs = get_compatible_feature_configs(pipeline_config)
  for feature_config in feature_configs:
    feature_name = feature_config.input_names[0]
    if feature_name not in feature_boundaries_info:
      continue
    if feature_config.feature_type != feature_config.SequenceFeature:
      logging.info(
          'feature = {0}, type = {1}, will turn to RawFeature.'.format(
              feature_name, feature_config.feature_type))
      feature_config.feature_type = feature_config.RawFeature
    feature_config.hash_bucket_size = 0
    feature_config.ClearField('boundaries')
    feature_config.boundaries.extend(feature_boundaries_info[feature_name])
    logging.info('edited %s' % feature_name)
422
+
423
+
424
def get_compatible_feature_configs(pipeline_config):
  """Returns the feature configs from whichever field holds them.

  Args:
    pipeline_config: config exposing `feature_configs` and
      `feature_config.features`.
  Return:
    `pipeline_config.feature_configs` when it is non-empty, otherwise
    `pipeline_config.feature_config.features`.
  """
  top_level_configs = pipeline_config.feature_configs
  if top_level_configs:
    return top_level_configs
  return pipeline_config.feature_config.features
430
+
431
+
432
def parse_time(time_data):
  """Parse time string to timestamp.

  Args:
    time_data: a '%Y%m%d %H:%M:%S' string (17 characters), a 10 character
      unix timestamp string, or any non-string value accepted by int().
  Return:
    timestamp: int
  Raises:
    ValueError: if time_data is a string in neither of the two formats.
  """
  import time

  if not isinstance(time_data, (str, type(u''))):
    return int(time_data)

  if len(time_data) == 17:
    parsed = datetime.datetime.strptime(time_data, '%Y%m%d %H:%M:%S')
    # mktime reads the naive datetime as local time, exactly like the
    # glibc-only strftime('%s') directive, but works on every platform.
    return int(time.mktime(parsed.timetuple()))
  if len(time_data) == 10:
    return int(time_data)
  # a plain `assert '<message>'` always passes, so fail explicitly instead
  raise ValueError('invalid time string: %s' % time_data)
451
+
452
+
453
def search_fg_json(directory):
  """Finds the single .json file below a directory.

  Args:
    directory: root directory, walked recursively with tf.gfile.Walk.
  Return:
    path of the only file with extension '.json', or None if there is none.
  Raises:
    ValueError: if more than one '.json' file is found.
  """
  json_paths = []
  for root, _, files in tf.gfile.Walk(directory):
    for file_name in files:
      if os.path.splitext(file_name)[1] == '.json':
        json_paths.append(os.path.join(root, file_name))

  if not json_paths:
    return None
  if len(json_paths) > 1:
    # say what is actually wrong and list the candidates
    raise ValueError('multiple fg.json found in directory %s: %s' %
                     (directory, ','.join(json_paths)))
  logging.info('use fg.json: %s' % json_paths[0])
  return json_paths[0]
466
+
467
+
468
def get_input_name_from_fg_json(fg_json):
  """Collects the input names declared in a parsed fg.json.

  Plain features contribute their 'feature_name'; features nested in a
  sequence contribute '<sequence_name>__<feature_name>'. Entries with a
  truthy 'stub_type' are left out.

  Args:
    fg_json: parsed fg.json dict with a 'features' list, may be None or empty.
  Return:
    list of input names in declaration order, [] when fg_json is falsy.
  """
  if not fg_json:
    return []

  names = []
  for entry in fg_json['features']:
    if 'feature_name' in entry:
      if not entry.get('stub_type'):
        names.append(entry['feature_name'])
      continue
    if 'sequence_name' not in entry:
      # neither a plain feature nor a sequence: nothing to collect
      continue
    prefix = entry['sequence_name']
    for sub_entry in entry['features']:
      assert 'feature_name' in sub_entry
      if sub_entry.get('stub_type'):
        continue
      names.append(prefix + '__' + sub_entry['feature_name'])
  return names
486
+
487
+
488
def get_train_input_path(pipeline_config):
  """Returns the value of the field currently set in the 'train_path' oneof."""
  return getattr(pipeline_config, pipeline_config.WhichOneof('train_path'))
491
+
492
+
493
def get_eval_input_path(pipeline_config):
  """Returns the value of the field currently set in the 'eval_path' oneof."""
  return getattr(pipeline_config, pipeline_config.WhichOneof('eval_path'))
496
+
497
+
498
def get_model_dir_path(pipeline_config):
  """Returns `pipeline_config.model_dir`."""
  return pipeline_config.model_dir
501
+
502
+
503
def set_train_input_path(pipeline_config, train_input_path):
  """Writes a new train input into the active 'train_path' oneof field.

  Args:
    pipeline_config: pipeline config, edited in place.
    train_input_path: a path string or a list of paths; a list is joined
      with ',' except for hive input, where only one table name is accepted.
  Return:
    the same pipeline_config object.
  """
  active_field = pipeline_config.WhichOneof('train_path')
  given_as_list = isinstance(train_input_path, list)

  if active_field == 'hive_train_input':
    if given_as_list:
      assert len(train_input_path) <= 1, \
          'only support one hive_train_input.table_name when hive input'
      table_name = train_input_path[0]
    else:
      assert len(train_input_path.split(',')) <= 1, \
          'only support one hive_train_input.table_name when hive input'
      table_name = train_input_path
    pipeline_config.hive_train_input.table_name = table_name
    logging.info('update hive_train_input.table_name to %s' %
                 pipeline_config.hive_train_input.table_name)
    return pipeline_config

  joined_path = ','.join(train_input_path) if given_as_list \
      else train_input_path
  if active_field == 'kafka_train_input':
    pipeline_config.kafka_train_input = joined_path
  elif active_field == 'parquet_train_input':
    pipeline_config.parquet_train_input = joined_path
  else:
    # any other case falls back to the plain train_input_path field
    pipeline_config.train_input_path = joined_path
    logging.info('update train_input_path to %s' %
                 pipeline_config.train_input_path)
  return pipeline_config
536
+
537
+
538
def set_eval_input_path(pipeline_config, eval_input_path):
  """Writes a new eval input into the active 'eval_path' oneof field.

  Args:
    pipeline_config: pipeline config, edited in place.
    eval_input_path: a path string or a list of paths; a list is joined
      with ',' except for hive input, where only one table name is accepted.
  Return:
    the same pipeline_config object.
  """
  active_field = pipeline_config.WhichOneof('eval_path')
  given_as_list = isinstance(eval_input_path, list)

  if active_field == 'hive_eval_input':
    if given_as_list:
      assert len(eval_input_path) <= 1, \
          'only support one hive_eval_input.table_name when hive input'
      table_name = eval_input_path[0]
    else:
      assert len(eval_input_path.split(',')) <= 1, \
          'only support one hive_eval_input.table_name when hive input'
      table_name = eval_input_path
    pipeline_config.hive_eval_input.table_name = table_name
    logging.info('update hive_eval_input.table_name to %s' %
                 pipeline_config.hive_eval_input.table_name)
    return pipeline_config

  joined_path = ','.join(eval_input_path) if given_as_list \
      else eval_input_path
  if active_field == 'parquet_eval_input':
    pipeline_config.parquet_eval_input = joined_path
  elif active_field == 'kafka_eval_input':
    pipeline_config.kafka_eval_input = joined_path
  else:
    # any other case falls back to the plain eval_input_path field
    pipeline_config.eval_input_path = joined_path
    logging.info('update eval_input_path to %s' %
                 pipeline_config.eval_input_path)
  return pipeline_config
570
+
571
+
572
def process_data_path(data_path, hive_util):
  """Resolves a data path that names a hive table to the table location.

  Args:
    data_path: an 'hdfs://' path, a name containing a '.' (looked up through
      hive_util), or any other path.
    hive_util: object providing get_table_location(data_path).
  Return:
    the location reported by hive_util when data_path contains a '.' and
    does not start with 'hdfs://', otherwise data_path unchanged.
  """
  if data_path.startswith('hdfs://'):
    return data_path

  has_dotted_form = re.match(r'(.*)\.(.*)', data_path) is not None
  if not has_dotted_form:
    return data_path

  location = hive_util.get_table_location(data_path)
  assert location, "Can't find hdfs path of %s" % data_path
  logging.info('update %s to %s' % (data_path, location))
  return location
581
+
582
+
583
def process_neg_sampler_data_path(pipeline_config):
  """Replaces the sampler's hive table inputs with their hdfs paths.

  Does nothing when running on PAI, when no sampler is configured, or when
  the train input is not hive_train_input.

  Args:
    pipeline_config: pipeline config, its sampler config is edited in place.
  """
  # replace neg_sampler hive table => hdfs path
  if pai_util.is_on_pai():
    return
  data_config = pipeline_config.data_config
  if not data_config.HasField('sampler'):
    return
  # not using hive, so not need to process it
  if pipeline_config.WhichOneof('train_path') != 'hive_train_input':
    return

  hive_util = HiveUtils(
      data_config=pipeline_config.data_config,
      hive_config=pipeline_config.hive_train_input)
  sampler_config = getattr(data_config, data_config.WhichOneof('sampler'))
  # only the path fields that the active sampler type defines are rewritten
  path_fields = ('input_path', 'user_input_path', 'item_input_path',
                 'pos_edge_input_path', 'hard_neg_edge_input_path')
  for field_name in path_fields:
    if hasattr(sampler_config, field_name):
      resolved = process_data_path(
          getattr(sampler_config, field_name), hive_util)
      setattr(sampler_config, field_name, resolved)
612
+
613
+
614
def parse_extra_config_param(extra_args, edit_config_json):
  """Moves '--<section>_config.<key>' command line args into a dict.

  Both '--key=value' and '--key value' are accepted; blanks and quotes
  surrounding the value are stripped. The process exits with status 1 on an
  argument with another prefix or on a trailing key without a value.

  Args:
    extra_args: list of command line argument strings.
    edit_config_json: dict updated in place, keyed by the argument name
      without the leading '--'.
  """
  known_prefixes = ('--data_config.', '--train_config.', '--feature_config.',
                    '--model_config.', '--export_config.', '--eval_config.')
  total = len(extra_args)
  pos = 0
  while pos < total:
    raw_arg = extra_args[pos]
    if not raw_arg.startswith(known_prefixes):
      logging.error('unknown args: %s' % raw_arg)
      sys.exit(1)

    key = raw_arg[2:]
    if '=' in key:
      # split on the first '=' only, the value may contain further '='
      key, _, value = key.partition('=')
      edit_config_json[key] = value.strip(' "\'')
      pos += 1
    elif pos + 1 < total:
      # the value is the next argument
      edit_config_json[key] = extra_args[pos + 1].strip(' "\'')
      pos += 2
    else:
      logging.error('missing value for arg: %s' % raw_arg)
      sys.exit(1)
641
+
642
+
643
def process_multi_file_input_path(sampler_config_input_path):
  """Expands wildcards in a comma separated input path.

  Args:
    sampler_config_input_path: comma separated path string.
  Return:
    the matches of tf.gfile.Glob over the comma separated patterns, joined
    with ',', when the input contains a '*'; otherwise the input unchanged.
  """
  if '*' not in sampler_config_input_path:
    return sampler_config_input_path
  patterns = sampler_config_input_path.split(',')
  return ','.join(tf.gfile.Glob(patterns))
@@ -0,0 +1,43 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import os
5
+
6
SAMPLE_WEIGHT = 'SAMPLE_WEIGHT'

DENSE_UPDATE_VARIABLES = 'DENSE_UPDATE_VARIABLES'

SPARSE_UPDATE_VARIABLES = 'SPARSE_UPDATE_VARIABLES'

# Name of the environment variable that switches avx string split on; it is
# considered enabled only while its value is '1'.
ENABLE_AVX_STR_SPLIT = 'ENABLE_AVX_STR_SPLIT'

# Environment variable that controls whether feature columns are sorted by
# name. Sorting is not enabled by default; the flag exists for backward
# compatibility.
SORT_COL_BY_NAME = 'SORT_COL_BY_NAME'

# arithmetic_optimization significantly slows down training of one test case:
#   train_eval_test.TrainEvalTest.test_train_parquet
NO_ARITHMETRIC_OPTI = 'NO_ARITHMETRIC_OPTI'

# shard embedding var_name collection
EmbeddingParallel = 'EmbeddingParallel'

# environ variable to force embedding placement on cpu
EmbeddingOnCPU = 'place_embedding_on_cpu'

# clear ps counter queue at start to reuse existing ps
ClearPsCounterQueue = 'CLEAR_PS_COUNTER_QUEUE'
31
+
32
+
33
def enable_avx_str_split():
  """Sets the ENABLE_AVX_STR_SPLIT environment variable to '1'."""
  os.environ[ENABLE_AVX_STR_SPLIT] = '1'
35
+
36
+
37
def has_avx_str_split():
  """Returns True iff the ENABLE_AVX_STR_SPLIT environment variable is '1'."""
  return os.environ.get(ENABLE_AVX_STR_SPLIT) == '1'
40
+
41
+
42
def disable_avx_str_split():
  """Removes the ENABLE_AVX_STR_SPLIT environment variable.

  Safe to call when the variable is not set, so that disabling twice (or
  without a preceding enable) does not raise KeyError.
  """
  os.environ.pop(ENABLE_AVX_STR_SPLIT, None)