easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,373 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import json
5
+ import logging
6
+ import os
7
+ import threading
8
+ import time
9
+ import traceback
10
+ import unittest
11
+
12
+ import numpy as np
13
+ import six
14
+ import tensorflow as tf
15
+ from tensorflow.python.data.ops import iterator_ops
16
+ from tensorflow.python.platform import gfile
17
+
18
+ from easy_rec.python.inference.predictor import Predictor
19
+ from easy_rec.python.input.kafka_dataset import KafkaDataset
20
+ from easy_rec.python.utils import numpy_utils
21
+ from easy_rec.python.utils import test_utils
22
+
23
+ try:
24
+ import kafka
25
+ from kafka import KafkaProducer, KafkaAdminClient
26
+ from kafka.admin import NewTopic
27
+ except ImportError:
28
+ logging.warning('kafka-python is not installed: %s' % traceback.format_exc())
29
+
30
+
31
class KafkaTest(tf.test.TestCase):
  """Integration tests for the Kafka dataset op, Kafka training and export.

  When the ``kafka_install_dir`` environment variable is set, ``setUp`` starts
  a ZooKeeper process and a Kafka broker whose data directories are redirected
  into the per-test temporary directory, then creates the test topic.
  ``tearDown`` stops the producer thread and terminates both processes.
  """

  # Skip conditions shared by the test decorators below.
  _NO_KAFKA = 'kafka_install_dir' not in os.environ
  # Every OSS variable is read by _test_kafka_processor, so a single missing
  # one must be enough to skip (the conditions are joined with `or` only).
  _NO_OSS = ('oss_path' not in os.environ or
             'oss_endpoint' not in os.environ or
             'oss_ak' not in os.environ or 'oss_sk' not in os.environ)

  def setUp(self):
    """Prepare the test directory and, if configured, a local Kafka broker."""
    self._success = True
    self._test_dir = test_utils.get_tmp_dir()
    # Initialise everything tearDown reads before any early return, so that
    # tearDown never fails on a missing attribute.
    self._should_stop = False
    self._producer = None
    self._zookeeper_proc = None
    self._kafka_server_proc = None
    if self._testMethodName == 'test_session':
      return

    logging.info('Testing %s.%s, test_dir=%s' %
                 (type(self).__name__, self._testMethodName, self._test_dir))
    self._log_dir = os.path.join(self._test_dir, 'logs')
    if not gfile.IsDirectory(self._log_dir):
      gfile.MakeDirs(self._log_dir)

    self._kafka_servers = ['127.0.0.1:9092']
    self._test_topic = 'kafka_op_test_topic'

    kafka_install_dir = os.environ.get('kafka_install_dir')
    if kafka_install_dir is None:
      # No local Kafka installation: the Kafka tests are skipped anyway.
      return

    self._start_zookeeper(kafka_install_dir)
    self._start_kafka_server(kafka_install_dir)
    self._create_topic()

  @staticmethod
  def _copy_config(src_path, dst_path, key, value):
    """Copy a properties file, rewriting each line starting with ``key=``.

    Args:
      src_path: path of the properties file to read.
      dst_path: path of the properties file to write.
      key: property name whose line(s) are replaced.
      value: new value written for that property.
    """
    prefix = key + '='
    with open(dst_path, 'w') as fout:
      with open(src_path, 'r') as fin:
        for line_str in fin:
          if line_str.startswith(prefix):
            fout.write('%s%s\n' % (prefix, value))
          else:
            fout.write(line_str)

  def _start_zookeeper(self, kafka_install_dir):
    """Start ZooKeeper with its dataDir placed under the test directory."""
    zookeeper_config = os.path.join(self._test_dir, 'zookeeper.properties')
    self._copy_config('%s/config/zookeeper.properties' % kafka_install_dir,
                      zookeeper_config, 'dataDir',
                      '%s/zookeeper' % self._test_dir)
    cmd = 'bash %s/bin/zookeeper-server-start.sh %s' % (kafka_install_dir,
                                                        zookeeper_config)
    log_file = os.path.join(self._log_dir, 'zookeeper.log')
    self._zookeeper_proc = test_utils.run_cmd(cmd, log_file)

  def _start_kafka_server(self, kafka_install_dir):
    """Start the Kafka broker and block until an admin client can connect.

    If the broker process exits with a non-zero code it is restarted; while
    no broker is reachable the connection attempt is retried every 2 seconds.
    """
    kafka_config = os.path.join(self._test_dir, 'server.properties')
    self._copy_config('%s/config/server.properties' % kafka_install_dir,
                      kafka_config, 'log.dirs', '%s/kafka' % self._test_dir)
    cmd = 'bash %s/bin/kafka-server-start.sh %s' % (kafka_install_dir,
                                                    kafka_config)
    log_file = os.path.join(self._log_dir, 'kafka_server.log')
    self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)

    started = False
    while not started:
      # poll() returns None while running, otherwise the exit code.
      if self._kafka_server_proc.poll():
        logging.warning('start kafka server failed, will retry.')
        os.system('cat %s' % log_file)
        self._kafka_server_proc = test_utils.run_cmd(cmd, log_file)
        time.sleep(5)
        continue
      try:
        admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
      except kafka.errors.NoBrokersAvailable:
        time.sleep(2)
        continue
      try:
        logging.info('old topics: %s' % (','.join(admin_clt.list_topics())))
        started = True
      except kafka.errors.NoBrokersAvailable:
        time.sleep(2)
      finally:
        admin_clt.close()

  def _create_topic(self, num_partitions=2):
    """Create the test topic with ``num_partitions`` partitions."""
    admin_clt = KafkaAdminClient(bootstrap_servers=self._kafka_servers)
    try:
      logging.info('create topic: %s' % self._test_topic)
      topic_list = [
          NewTopic(
              name=self._test_topic,
              num_partitions=num_partitions,
              replication_factor=1)
      ]
      admin_clt.create_topics(new_topics=topic_list, validate_only=False)
      logging.info('all topics: %s' % (','.join(admin_clt.list_topics())))
    finally:
      # Release the client connection even if topic creation fails.
      admin_clt.close()

  def _create_producer(self, generate_func):
    """Run ``generate_func`` in a new thread and return the started thread."""
    prod = threading.Thread(target=generate_func)
    prod.start()
    return prod

  def _stop_producer(self):
    """Signal the producer thread (if any) to stop and wait for it."""
    if self._producer is not None:
      self._should_stop = True
      self._producer.join()

  def tearDown(self):
    """Stop the producer and helper processes; clean up on success."""
    try:
      self._stop_producer()
      if self._kafka_server_proc is not None:
        self._kafka_server_proc.terminate()
    except Exception as ex:
      logging.warning('exception terminate kafka proc: %s' % str(ex))

    try:
      if self._zookeeper_proc is not None:
        self._zookeeper_proc.terminate()
    except Exception as ex:
      logging.warning('exception terminate zookeeper proc: %s' % str(ex))

    test_utils.set_gpu_id(None)
    # Keep the directory of a failed test for inspection.
    if self._success:
      test_utils.clean_up(self._test_dir)

  @staticmethod
  def _as_text(value):
    """Return ``value`` decoded as utf-8 text when it is bytes on Python 3."""
    if six.PY3 and isinstance(value, bytes):
      return value.decode('utf-8')
    return value

  @unittest.skipIf(_NO_KAFKA, 'Only execute when kafka is available')
  def test_kafka_ops(self):
    """Read batches from KafkaDataset while a thread produces messages."""
    try:
      test_utils.set_gpu_id(None)

      def _generate():
        producer = KafkaProducer(
            bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
        try:
          i = 0
          while not self._should_stop:
            msg = 'user_id_%d' % i
            # Without a value_serializer the producer expects bytes.
            if six.PY3:
              msg = msg.encode('utf-8')
            producer.send(self._test_topic, msg)
            i += 1
        finally:
          producer.close()

      self._producer = self._create_producer(_generate)

      group = 'dataset_consumer'
      k = KafkaDataset(
          servers=self._kafka_servers[0],
          topics=[self._test_topic + ':0', self._test_topic + ':1'],
          group=group,
          eof=True,
          # control the maximal read of each partition
          config_global=['max.partition.fetch.bytes=1048576'],
          message_key=True,
          message_offset=True)

      batch_dataset = k.batch(5)

      iterator = iterator_ops.Iterator.from_structure(
          batch_dataset.output_types)
      init_batch_op = iterator.make_initializer(batch_dataset)
      get_next = iterator.get_next()

      with tf.Session() as sess:
        sess.run(init_batch_op)

        p = sess.run(get_next)

        self.assertEqual(len(p), 3)
        offset = p[2]
        self.assertEqual(self._as_text(offset[0]), '0:0')
        self.assertEqual(self._as_text(offset[1]), '0:1')

        p = sess.run(get_next)
        offset = p[2]
        self.assertEqual(self._as_text(offset[0]), '0:5')
        self.assertEqual(self._as_text(offset[1]), '0:6')

        max_iter = 300
        while max_iter > 0:
          sess.run(get_next)
          max_iter -= 1
    except tf.errors.OutOfRangeError:
      # The dataset is created with eof=True, so running out of data is
      # tolerated here.
      pass
    except Exception:
      self._success = False
      raise

  def _run_train_test(self, train_eval_fn, config_path, **kwargs):
    """Start the csv producer thread and run one train/eval helper.

    Args:
      train_eval_fn: callable invoked as
        ``train_eval_fn(config_path, self._test_dir, **kwargs)``; its result
        is stored in ``self._success`` and asserted to be true.
      config_path: pipeline config passed to ``train_eval_fn``.
      **kwargs: extra keyword arguments forwarded to ``train_eval_fn``.
    """
    try:
      # start produce thread
      self._producer = self._create_producer(self._generate)

      test_utils.set_gpu_id(None)

      self._success = train_eval_fn(config_path, self._test_dir, **kwargs)
      self.assertTrue(self._success)
    except Exception:
      self._success = False
      # Bare raise keeps the original traceback.
      raise

  @unittest.skipIf(_NO_KAFKA, 'Only execute when kafka is available')
  def test_kafka_train(self):
    """Single-process train/eval reading from Kafka."""
    self._run_train_test(
        test_utils.test_single_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka.config')

  def _generate(self):
    """Send the lines of the avazu csv file to the test topic until stopped.

    The file is re-read from the start each time its end is reached.
    """
    producer = KafkaProducer(
        bootstrap_servers=self._kafka_servers, api_version=(0, 10, 1))
    try:
      while not self._should_stop:
        with open('data/test/dwd_avazu_ctr_deepmodel_10w.csv', 'r') as fin:
          for line_str in fin:
            line_str = line_str.strip()
            if self._should_stop:
              break
            if six.PY3:
              line_str = line_str.encode('utf-8')
            producer.send(self._test_topic, line_str)
    finally:
      # Close the producer even if reading the data file fails.
      producer.close()
    logging.info('data generation thread done.')

  @unittest.skipIf(_NO_KAFKA, 'Only execute when kafka is available')
  def test_kafka_train_chief_redundant(self):
    """Distributed train/eval with the chief_redundant Kafka config."""
    self._run_train_test(
        test_utils.test_distributed_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka_chief_redundant.config',
        num_evaluator=1)

  @unittest.skipIf(_NO_KAFKA, 'Only execute when kafka is available')
  def test_kafka_train_v2(self):
    """Single-process train/eval with the time_offset Kafka config."""
    self._run_train_test(
        test_utils.test_single_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka_time_offset.config')

  @unittest.skipIf(_NO_KAFKA or _NO_OSS,
                   'Only execute when kafka and oss are available')
  def test_kafka_processor(self):
    """Export/processor consistency check for the incr_save config."""
    self._test_kafka_processor(
        'samples/model_config/taobao_fg_incr_save.config')

  @unittest.skipIf(_NO_KAFKA or _NO_OSS,
                   'Only execute when kafka and oss are available')
  def test_kafka_processor_ev(self):
    """Export/processor consistency check for the incr_save_ev config."""
    self._test_kafka_processor(
        'samples/model_config/taobao_fg_incr_save_ev.config')

  def _test_kafka_processor(self, config_path):
    """Train, export, run the processor test and compare with Predictor.

    Requires the ``oss_path``, ``oss_ak``, ``oss_sk`` and ``oss_endpoint``
    environment variables. ``self._success`` is only set to True once every
    step has passed.

    Args:
      config_path: pipeline config used for distributed training.
    """
    self._success = False
    success = test_utils.test_distributed_train_eval(
        config_path, self._test_dir, total_steps=500)
    self.assertTrue(success)
    export_cmd = """
       python -m easy_rec.python.export --pipeline_config_path %s/pipeline.config
           --export_dir %s/export/sep/ --oss_path=%s --oss_ak=%s --oss_sk=%s --oss_endpoint=%s
           --asset_files ./samples/rtp_fg/fg.json
           --checkpoint_path %s/train/model.ckpt-0
    """ % (self._test_dir, self._test_dir, os.environ['oss_path'],
           os.environ['oss_ak'], os.environ['oss_sk'],
           os.environ['oss_endpoint'], self._test_dir)
    proc = test_utils.run_cmd(export_cmd,
                              '%s/log_export_sep.txt' % self._test_dir)
    proc.wait()
    self.assertTrue(proc.returncode == 0)
    files = gfile.Glob(os.path.join(self._test_dir, 'export/sep/[1-9][0-9]*'))
    # Fail with a clear message instead of an IndexError below.
    self.assertTrue(files, 'no exported model under %s/export/sep/' %
                    self._test_dir)
    export_sep_dir = files[0]

    predict_cmd = """
      python -m easy_rec.python.inference.processor.test --saved_model_dir %s
         --input_path data/test/rtp/taobao_test_feature.txt
         --output_path %s/processor.out --test_dir %s
    """ % (export_sep_dir, self._test_dir, self._test_dir)
    envs = dict(os.environ)
    envs['PROCESSOR_TEST'] = '1'
    proc = test_utils.run_cmd(
        predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs)
    proc.wait()
    self.assertTrue(proc.returncode == 0)

    with open('%s/processor.out' % self._test_dir, 'r') as fin:
      processor_out = [json.loads(line_str.strip()) for line_str in fin]

    predictor = Predictor(os.path.join(self._test_dir, 'train/export/final/'))
    with open('data/test/rtp/taobao_test_feature.txt', 'r') as fin:
      # Only the last ';'-separated field is used, split on chr(2).
      inputs = [
          line_str.strip().split(';')[-1].split(chr(2)) for line_str in fin
      ]
    output_res = predictor.predict(inputs, batch_size=1024)

    with open('%s/predictor.out' % self._test_dir, 'w') as fout:
      for res in output_res:
        fout.write(json.dumps(res, cls=numpy_utils.NumpyEncoder) + '\n')

    for i, res in enumerate(output_res):
      val0 = res['probs']
      val1 = processor_out[i]['probs']
      diff = np.abs(val0 - val1)
      # unittest assertion: a bare assert is stripped under `python -O`.
      self.assertLess(diff, 1e-4,
                      'too much difference[%.6f] >= 1e-4' % diff)
    self._success = True

  @unittest.skipIf(_NO_KAFKA, 'Only execute when kafka is available')
  def test_kafka_train_v3(self):
    """Single-process train/eval with the time_offset2 Kafka config."""
    self._run_train_test(
        test_utils.test_single_train_eval,
        'samples/model_config/deepfm_combo_avazu_kafka_time_offset2.config')
370
+
371
+
372
+ # Script entry point: hand control to the TensorFlow test runner.
+ if __name__ == '__main__':
373
+ tf.test.main()
@@ -0,0 +1,122 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import json
5
+ import logging
6
+ import os
7
+ import unittest
8
+
9
+ import numpy as np
10
+ import tensorflow as tf
11
+ from tensorflow.python.platform import gfile
12
+
13
+ from easy_rec.python.inference.predictor import Predictor
14
+ from easy_rec.python.utils import numpy_utils
15
+ from easy_rec.python.utils import test_utils
16
+
17
+
18
class LocalIncrTest(tf.test.TestCase):
  """Incremental-save tests that compare processor and Predictor outputs.

  Each test trains with a sample config, exports the model, runs the
  processor test module on the export, and checks that its 'probs' agree
  with those produced by Predictor within 1e-4.
  """

  # Environment variables read by _test_incr_save to build the export command.
  _OSS_ENV_KEYS = ('oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk')

  # Skip when ANY of the settings is missing. The previous condition mixed
  # `or` and `and` without parentheses, so a missing oss_endpoint with
  # oss_ak present slipped through and failed later with a KeyError.
  _requires_oss = unittest.skipIf(
      any(key not in os.environ for key in _OSS_ENV_KEYS),
      'Only execute when oss config is available')

  def setUp(self):
    """Allocate a temporary test directory and make sure logs/ exists."""
    self._success = True
    self._test_dir = test_utils.get_tmp_dir()

    logging.info('Testing %s.%s, test_dir=%s' %
                 (type(self).__name__, self._testMethodName, self._test_dir))
    self._log_dir = os.path.join(self._test_dir, 'logs')
    if not gfile.IsDirectory(self._log_dir):
      gfile.MakeDirs(self._log_dir)

  @_requires_oss
  def test_incr_save(self):
    """Run the incremental-save check with the plain local config."""
    self._test_incr_save(
        'samples/model_config/taobao_fg_incr_save_local.config')

  @_requires_oss
  def test_incr_save_ev(self):
    """Run the incremental-save check with the `ev` local config."""
    self._test_incr_save(
        'samples/model_config/taobao_fg_incr_save_ev_local.config')

  @_requires_oss
  def test_incr_save_share_ev(self):
    """Run the incremental-save check with the `share_ev` local config."""
    self._test_incr_save(
        'samples/model_config/taobao_fg_incr_save_share_ev_local.config')

  def _test_incr_save(self, config_path):
    """Train, export, and compare processor output with Predictor output.

    Args:
      config_path: path of the pipeline config used for training.

    Sets self._success to True only when every step passed.
    """
    self._success = False
    success = test_utils.test_distributed_train_eval(
        config_path,
        self._test_dir,
        total_steps=100,
        edit_config_json={
            'train_config.incr_save_config.fs.mount_path':
                os.path.join(self._test_dir, 'train/incr_save/')
        })
    self.assertTrue(success)

    # Export the model from checkpoint 0 into <test_dir>/export/sep/.
    export_cmd = """
      python -m easy_rec.python.export --pipeline_config_path %s/pipeline.config
      --export_dir %s/export/sep/ --oss_path=%s --oss_ak=%s --oss_sk=%s --oss_endpoint=%s
      --asset_files ./samples/rtp_fg/fg.json
      --checkpoint_path %s/train/model.ckpt-0
    """ % (self._test_dir, self._test_dir, os.environ['oss_path'],
           os.environ['oss_ak'], os.environ['oss_sk'],
           os.environ['oss_endpoint'], self._test_dir)
    proc = test_utils.run_cmd(export_cmd,
                              '%s/log_export_sep.txt' % self._test_dir)
    proc.wait()
    self.assertEqual(proc.returncode, 0)

    files = gfile.Glob(os.path.join(self._test_dir, 'export/sep/[1-9][0-9]*'))
    # Fail with a readable message instead of an IndexError on files[0].
    self.assertTrue(
        files, 'no exported model found under %s/export/sep/' % self._test_dir)
    export_sep_dir = files[0]

    # Run the processor test module on the exported model.
    predict_cmd = """
      python -m easy_rec.python.inference.processor.test --saved_model_dir %s
      --input_path data/test/rtp/taobao_test_feature.txt
      --output_path %s/processor.out --test_dir %s
    """ % (export_sep_dir, self._test_dir, self._test_dir)
    envs = dict(os.environ)
    envs['PROCESSOR_TEST'] = '1'
    proc = test_utils.run_cmd(
        predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs)
    proc.wait()
    self.assertEqual(proc.returncode, 0)

    # processor.out holds one JSON document per line.
    with open('%s/processor.out' % self._test_dir, 'r') as fin:
      processor_out = [json.loads(line_str.strip()) for line_str in fin]

    predictor = Predictor(os.path.join(self._test_dir, 'train/export/final/'))
    with open('data/test/rtp/taobao_test_feature.txt', 'r') as fin:
      # Keep only the last ';'-separated field and split it on chr(2).
      inputs = [
          line_str.strip().split(';')[-1].split(chr(2)) for line_str in fin
      ]
    output_res = predictor.predict(inputs, batch_size=1024)

    with open('%s/predictor.out' % self._test_dir, 'w') as fout:
      for res in output_res:
        fout.write(json.dumps(res, cls=numpy_utils.NumpyEncoder) + '\n')

    for i, res in enumerate(output_res):
      val0 = res['probs']
      val1 = processor_out[i]['probs']
      diff = np.abs(val0 - val1)
      # assertLess (unlike a bare assert) still runs under `python -O`.
      self.assertLess(diff, 1e-4,
                      'too much difference[%.6f] >= 1e-4' % diff)
    self._success = True
119
+
120
+
121
# Script entry point: run this module's test cases with tf.test.main().
if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import tensorflow as tf
3
+
4
+ from easy_rec.python.loss.circle_loss import circle_loss
5
+ from easy_rec.python.loss.circle_loss import get_anchor_positive_triplet_mask
6
+
7
+ from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
8
+
9
+ from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
10
+
11
# Use the TF1-compatible API when running on TensorFlow 2 or newer.
# Compare the numeric major version: a plain string comparison against
# '2.0' is lexicographic and would misorder a two-digit major version
# such as '10.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
13
+
14
+
15
class LossTest(tf.test.TestCase):
  """Regression tests pinning the values of the custom loss functions."""

  def _evaluate(self, tensor):
    """Evaluate `tensor` inside self.test_session() and return the result."""
    with self.test_session() as sess:
      return sess.run(tensor)

  def test_f1_reweighted_loss(self):
    """f1_reweight_sigmoid_cross_entropy matches the recorded value."""
    print('test_f1_reweighted_loss')
    y_logits = tf.constant([0.1, 0.5, 0.3, 0.8, -0.1, 0.3])
    y_true = tf.constant([1, 1, 0, 0, 1, 1])
    loss_op = f1_reweight_sigmoid_cross_entropy(
        labels=y_true, logits=y_logits, beta_square=4)
    self.assertAlmostEqual(self._evaluate(loss_op), 0.47844395, delta=1e-5)

  def test_softmax_loss_with_negative_mining(self):
    """softmax_loss_with_negative_mining matches the recorded value."""
    print('test_softmax_loss_with_negative_mining')
    user_rows = [[0.1, 0.5, 0.3], [0.8, -0.1, 0.3], [0.28, 0.3, 0.9],
                 [0.37, 0.45, 0.93], [-0.7, 0.15, 0.03], [0.18, 0.9, -0.3]]
    item_rows = [[0.1, -0.5, 0.3], [0.8, -0.31, 0.3], [0.7, -0.45, 0.15],
                 [0.08, -0.31, -0.9], [-0.7, 0.85, 0.03], [0.18, 0.89, -0.3]]
    user_vecs = tf.constant(user_rows)
    item_vecs = tf.constant(item_rows)

    y_true = tf.constant([1, 1, 0, 0, 1, 1])
    loss_op = softmax_loss_with_negative_mining(
        user_vecs, item_vecs, y_true, num_negative_samples=2, seed=1)
    self.assertAlmostEqual(self._evaluate(loss_op), 0.48577175, delta=1e-5)

  def test_circle_loss(self):
    """circle_loss matches the recorded value."""
    print('test_circle_loss')
    emb_rows = [[0.1, 0.2, 0.15, 0.1], [0.3, 0.6, 0.45, 0.3],
                [0.13, 0.6, 0.45, 0.3], [0.3, 0.26, 0.45, 0.3],
                [0.3, 0.6, 0.5, 0.13], [0.08, 0.43, 0.21, 0.6]]
    embeddings = tf.constant(emb_rows, dtype=tf.float32)
    y_true = tf.constant([1, 1, 2, 2, 3, 3])
    # The same tensor is passed for both the second and third argument.
    loss_op = circle_loss(embeddings, y_true, y_true, margin=0.25, gamma=64)
    self.assertAlmostEqual(self._evaluate(loss_op), 52.75707, delta=1e-5)

  def test_triplet_mask(self):
    """Positive and negative triplet masks equal the hand-written ones."""
    print('test_triplet_mask')
    y_true = tf.constant([1, 1, 2, 2, 3, 3, 4, 5])
    expected_pos_rows = [
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
    ]
    expected_pos = tf.constant(expected_pos_rows, dtype=tf.float32)
    expected_neg_rows = [
        [0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 0., 1., 1., 1., 1.],
        [1., 1., 0., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 0., 1., 1.],
        [1., 1., 1., 1., 0., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
    ]
    expected_neg = tf.constant(expected_neg_rows, dtype=tf.float32)
    with self.test_session():
      actual_pos = get_anchor_positive_triplet_mask(y_true, y_true)
      self.assertAllEqual(expected_pos, actual_pos)

      actual_neg = _get_anchor_negative_triplet_mask(y_true, y_true)
      self.assertAllEqual(expected_neg, actual_neg)

      # The negative mask must also equal 1 - positive mask - identity.
      num_rows = y_true.shape.as_list()[0]
      derived_neg = 1 - actual_pos - tf.eye(num_rows)
      self.assertAllEqual(actual_neg, derived_neg)
80
+
81
+
82
def _get_anchor_negative_triplet_mask(labels, sessions):
  """Build the anchor-negative mask: 1.0 where session or label differs.

  Args:
    labels: a `Tensor` with shape [batch_size]
    sessions: a `Tensor` with shape [batch_size]

  Returns:
    mask: tf.float32 `Tensor` with shape [batch_size, batch_size];
      mask[a, n] is 1.0 iff a and n have distinct session or label.
  """
  # Pairwise comparison through broadcasting: a (1, batch_size) view
  # against a (batch_size, 1) view.
  sessions_as_row = tf.expand_dims(sessions, 0)
  sessions_as_col = tf.expand_dims(sessions, 1)
  differs = tf.not_equal(sessions_as_row, sessions_as_col)

  # When both arguments are the very same tensor the label comparison
  # would repeat the session comparison, so it is skipped.
  if labels is not sessions:
    labels_as_row = tf.expand_dims(labels, 0)
    labels_as_col = tf.expand_dims(labels, 1)
    differs = tf.logical_or(differs,
                            tf.not_equal(labels_as_row, labels_as_col))

  return tf.cast(differs, tf.float32)
107
+
108
+
109
# Script entry point: run this module's test cases with tf.test.main().
if __name__ == '__main__':
  tf.test.main()
@@ -0,0 +1,61 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+ import os
6
+ import subprocess
7
+
8
+ from easy_rec.python.test.odps_test_util import get_oss_bucket
9
+
10
+
11
class OdpsCommand:
  """Runs sql script files through the odpscmd command line tool."""

  def __init__(self, odps_oss_config):
    """Wrapper for running odps command.

    Args:
      odps_oss_config: instance of easy_rec.python.utils.odps_test_util.OdpsOSSConfig
    """
    self.bucket = get_oss_bucket(odps_oss_config.oss_key,
                                 odps_oss_config.oss_secret,
                                 odps_oss_config.endpoint,
                                 odps_oss_config.bucket_name)
    self.bucket_name = odps_oss_config.bucket_name
    # Directory that script file names are resolved against.
    self.temp_dir = odps_oss_config.temp_dir
    # Directory that receives one <script name>.log file per run.
    self.log_path = odps_oss_config.log_dir
    self.odpscmd = odps_oss_config.odpscmd_path
    # Optional; when None no --config option is passed to odpscmd.
    self.odps_config_path = odps_oss_config.odps_config_path
    self.algo_project = odps_oss_config.algo_project
    self.algo_res_project = odps_oss_config.algo_res_project
    self.algo_version = odps_oss_config.algo_version

  def run_odps_cmd(self, script_file):
    """Run a sql script with odpscmd and wait for it to finish.

    Output (stdout and stderr) is redirected to
    <log_path>/<script file name>.log.

    Args:
      script_file: xxx.sql file, relative to temp_dir, to be run by odpscmd

    Raises:
      ValueError: if the command exits with a non-zero return code
    """
    try:
      from shlex import quote
    except ImportError:  # Python 2 has no shlex.quote
      from pipes import quote

    exec_file_path = os.path.join(self.temp_dir, script_file)
    file_name = os.path.split(script_file)[1]
    log_file = os.path.join(self.log_path, file_name)

    # Quote every path so that spaces or shell metacharacters cannot split
    # or redirect the command. For ordinary paths quote() returns its input
    # unchanged, so the command text is the same as before. self.odpscmd is
    # left as given because it may carry its own arguments.
    cmd_parts = ['nohup', self.odpscmd]
    if self.odps_config_path is not None:
      cmd_parts.append('--config=%s' % quote(self.odps_config_path))
    cmd_parts.extend(
        ['-f', quote(exec_file_path), '>',
         quote('%s.log' % log_file), '2>&1'])
    cmd = ' '.join(cmd_parts)

    logging.info('will run cmd: %s', cmd)
    proc = subprocess.Popen(cmd, shell=True)
    proc.wait()
    if proc.returncode != 0:
      raise ValueError('%s run FAILED: please check log file:%s.log' %
                       (exec_file_path, log_file))
    logging.info('%s run succeed', script_file)

  def run_list(self, files):
    """Run each script in `files`, in order, via run_odps_cmd.

    Stops at the first failure, because run_odps_cmd raises ValueError.
    """
    for f in files:
      self.run_odps_cmd(f)