easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/__init__.py ADDED
@@ -0,0 +1,114 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+ import os
6
+ import platform
7
+ import sys
8
+
9
+ from easy_rec.version import __version__
10
+
11
+ curr_dir, _ = os.path.split(__file__)  # directory containing this __init__.py
12
+ parent_dir = os.path.dirname(curr_dir)  # the directory that contains the easy_rec package
13
+ sys.path.insert(0, parent_dir)  # prepend so this directory is searched first on import
14
+
15
+ logging.basicConfig(  # root logger: INFO level, '[time][level] message' format, set at import time
16
+ level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
17
+
18
+ # Skip importing tensorflow when PROCESSOR_TEST is set: it may conflict with the TF version used by EasyRecProcessor
19
+ if 'PROCESSOR_TEST' not in os.environ:
20
+ from tensorflow.python.platform import tf_logging
21
+ # DeepRec sets tf_logging's logger.propagate to False; re-enable it so TF log records reach the root logger configured above
22
+ tf_logging._logger.propagate = True  # NOTE: relies on a private tf_logging attribute
23
+
24
def get_ops_dir():
  """Locate the prebuilt custom-op library directory for the installed TF.

  Returns:
    On Linux, ``<curr_dir>/python/ops/<sub_dir>``, where ``sub_dir`` is
    derived from ``tf.__version__``:
      - any version containing 'PAI'      -> '1.12_pai'
      - '1.12*'                            -> '1.12'
      - '1.15*' with IS_ON_PAI in env      -> 'DeepRec'
      - '1.15*' otherwise                  -> '1.15'
      - anything else                      -> '<major>.<minor>'
    On any other platform, ``None``. Existence of the path is not checked.
  """
  # TF is imported first, regardless of platform.
  import tensorflow as tf
  if platform.system() != 'Linux':
    return None

  version = tf.__version__
  base_dir = os.path.join(curr_dir, 'python/ops')
  if 'PAI' in version:
    sub_dir = '1.12_pai'
  elif version.startswith('1.12'):
    sub_dir = '1.12'
  elif version.startswith('1.15'):
    # With IS_ON_PAI set, TF 1.15 maps to the 'DeepRec' ops directory.
    sub_dir = 'DeepRec' if 'IS_ON_PAI' in os.environ else '1.15'
  else:
    # Keep only "<major>.<minor>", e.g. '2.12.0' -> '2.12'.
    sub_dir = '.'.join(version.split('.')[:2])
  return os.path.join(base_dir, sub_dir)
44
+
45
+ ops_dir = get_ops_dir()  # None unless running on Linux
46
+ if ops_dir is not None and not os.path.exists(ops_dir):
47
+ logging.warning('ops_dir[%s] does not exist' % ops_dir)
48
+ ops_dir = None  # treat a missing directory the same as "no ops directory"
49
+
50
+ from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
51
+ from easy_rec.python.main import evaluate # isort:skip # noqa: E402
52
+ from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
53
+ from easy_rec.python.main import export # isort:skip # noqa: E402
54
+ from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
55
+ from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402
56
+
57
+ try:  # optional dependency, imported best-effort
58
+ import tensorflow_io.oss  # presumably registers OSS filesystem support — verify
59
+ except Exception:  # deliberately broad: any failure of the optional import is ignored
60
+ pass
61
+
62
+ print('easy_rec version: %s' % __version__)  # printed to stdout on import
63
+ print('Usage: easy_rec.help()')
64
+
65
+ _global_config = {}  # module-level shared dict; its readers/writers are not visible in this chunk
66
+
67
+
68
def help():
  """Print a quick-start guide for common EasyRec tasks to stdout.

  The guide covers training, evaluation, export, building a config from an
  Excel sheet, and running inference with ``Predictor``.
  - Multi-line shell commands end each line with a backslash, so they can be
    pasted as-is (including the ``CUDA_VISIBLE_DEVICES`` prefix for export).
  - The inference snippets call the ``predictor`` object they construct.

  Returns:
    None.
  """
  print("""
  1 Train
    1.1 Train 1gpu
      CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.train_eval \\
        --pipeline_config_path deepfm_combo_on_avazu_ctr.config
    1.2 Train 2gpu
      sh scripts/train_2gpu.sh deepfm_combo_on_avazu_ctr.config
  2 Eval
    CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.eval \\
      --pipeline_config_path deepfm_combo_on_avazu_ctr.config
  3 Export
    CUDA_VISIBLE_DEVICES="" python -m easy_rec.python.export \\
      --pipeline_config_path deepfm_combo_on_avazu_ctr.config \\
      --export_dir models/export
  4 Create config from excel
    python -m easy_rec.python.tools.create_config_from_excel \\
      --excel_path dwd_avazu_ctr_multi_tower.xls \\
      --output_path dwd_avazu_ctr_multi_tower.config
  5 Inference
    # use list input
    import csv
    from easy_rec.python.inference.predictor import Predictor
    predictor = Predictor(SAVED_MODEL_DIR)
    with open(INPUT_CSV, 'r') as fin:
      reader = csv.reader(fin)
      inputs = []
      for row in reader:
        inputs.append(row[1:])
    output_res = predictor.predict(inputs, batch_size=32)

    # use dict input
    import csv
    from easy_rec.python.inference.predictor import Predictor
    predictor = Predictor(SAVED_MODEL_DIR)
    field_keys = [ "field1", "field2", "field3", "field4", "field5",
                   "field6", "field7", "field8", "field9", "field10",
                   "field11", "field12", "field13", "field14", "field15",
                   "field16", "field17", "field18", "field19", "field20" ]
    with open(INPUT_CSV, 'r') as fin:
      reader = csv.reader(fin)
      inputs = []
      for row in reader:
        inputs.append({ f : row[fid+1] for fid, f in enumerate(field_keys) })
    output_res = predictor.predict(inputs, batch_size=32)
  """)
File without changes
File without changes
@@ -0,0 +1,78 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Builder function to construct tf-slim arg_scope for convolution, fc ops."""
17
+ import tensorflow as tf
18
+
19
+ from easy_rec.python.compat import regularizers
20
+
21
+ if tf.__version__ >= '2.0':  # NOTE(review): lexicographic string compare; fine for 1.x/2.x but would misorder a '10.x' release
22
+ tf = tf.compat.v1  # under TF 2.x use the TF1-compatible API (e.g. tf.truncated_normal_initializer below)
23
+
24
+
25
def build_regularizer(regularizer):
  """Create a TF regularizer from a regularizer config proto.

  Args:
    regularizer: hyperparams_pb2.Hyperparams.regularizer proto. The active
      member of its ``regularizer_oneof`` picks the regularizer kind.

  Returns:
    The object returned by the matching ``regularizers.*_regularizer``
    factory. Every scale is converted to ``float`` first.

  Raises:
    ValueError: if the oneof is unset or names an unsupported regularizer.
  """
  which = regularizer.WhichOneof('regularizer_oneof')

  if which == 'l1_l2_regularizer':
    cfg = regularizer.l1_l2_regularizer
    return regularizers.l1_l2_regularizer(
        scale_l1=float(cfg.scale_l1), scale_l2=float(cfg.scale_l2))

  if which == 'l1_regularizer':
    return regularizers.l1_regularizer(
        scale=float(regularizer.l1_regularizer.scale))

  if which == 'l2_regularizer':
    return regularizers.l2_regularizer(
        scale=float(regularizer.l2_regularizer.scale))

  raise ValueError('Unknown regularizer function: {}'.format(which))
50
+
51
+
52
def build_initializer(initializer):
  """Build a tf initializer from config.

  Args:
    initializer: initializer config proto whose ``initializer_oneof`` field
      selects the initializer. Presumably hyperparams_pb2.Initializer;
      confirm against the proto definitions.

  Returns:
    A tf initializer, one of:
      - truncated normal, using the configured ``mean``/``stddev``;
      - random normal, using the configured ``mean``/``stddev``;
      - Glorot normal;
      - constant, built from the ``consts`` values.

  Raises:
    ValueError: On unknown (or unset) initializer.
  """
  initializer_oneof = initializer.WhichOneof('initializer_oneof')
  if initializer_oneof == 'truncated_normal_initializer':
    return tf.truncated_normal_initializer(
        mean=initializer.truncated_normal_initializer.mean,
        stddev=initializer.truncated_normal_initializer.stddev)
  if initializer_oneof == 'random_normal_initializer':
    return tf.random_normal_initializer(
        mean=initializer.random_normal_initializer.mean,
        stddev=initializer.random_normal_initializer.stddev)
  if initializer_oneof == 'glorot_normal_initializer':
    return tf.glorot_normal_initializer()
  if initializer_oneof == 'constant_initializer':
    # Materialize the proto's ``consts`` values as a plain Python list.
    return tf.constant_initializer(
        list(initializer.constant_initializer.consts))
  raise ValueError('Unknown initializer function: {}'.format(initializer_oneof))
@@ -0,0 +1,333 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+
8
+ from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
9
+ from easy_rec.python.loss.jrc_loss import jrc_loss
10
+ from easy_rec.python.loss.listwise_loss import listwise_distill_loss
11
+ from easy_rec.python.loss.listwise_loss import listwise_rank_loss
12
+ from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss
13
+ from easy_rec.python.loss.pairwise_loss import pairwise_hinge_loss
14
+ from easy_rec.python.loss.pairwise_loss import pairwise_logistic_loss
15
+ from easy_rec.python.loss.pairwise_loss import pairwise_loss
16
+ from easy_rec.python.protos.loss_pb2 import LossType
17
+
18
+ from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA
19
+
20
+ from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
21
+
22
+ if tf.__version__ >= '2.0':  # NOTE(review): lexicographic string compare; fine for 1.x/2.x but would misorder a '10.x' release
23
+ tf = tf.compat.v1  # under TF 2.x use the TF1-compatible API (e.g. tf.losses below)
24
+
25
+
26
def build(loss_type,
          label,
          pred,
          loss_weight=1.0,
          num_class=1,
          loss_param=None,
          **kwargs):
  """Build the loss for a single loss configuration.

  Args:
    loss_type: a ``LossType`` enum value selecting the loss.
    label: label tensor.
    pred: prediction tensor. It is passed as ``logits`` to the
      CLASSIFICATION and BINARY_CROSS_ENTROPY_LOSS branches (the latter uses
      ``from_logits=True``). For CROSS_ENTROPY_LOSS it is passed to
      ``tf.losses.log_loss``, so presumably probabilities there; confirm
      with the caller.
    loss_weight: sample weight(s), forwarded to the underlying loss. For
      ZILN_LOSS only a scalar weight is applied (see below).
    num_class: CLASSIFICATION only. 1 selects sigmoid cross entropy; any
      other value selects sparse softmax cross entropy, which requires an
      int32/int64 label.
    loss_param: optional loss-specific config proto. When None, each branch
      falls back to its own hard-coded defaults.
    **kwargs:
      - ``loss_name`` is popped (default 'unknown') and passed as ``name``
        to the losses that accept one.
      - ``session_ids`` is read (not popped) by the JRC, pairwise and
        listwise branches.
      - All remaining kwargs, including ``session_ids`` if present, are
        forwarded to the ``tf.losses`` calls of the CLASSIFICATION,
        CROSS_ENTROPY_LOSS and L2/SIGMOID_L2 branches.

  Returns:
    The loss tensor produced by the selected branch.

  Raises:
    ValueError: for an unsupported ``loss_type``.
  """
  loss_name = kwargs.pop('loss_name', 'unknown')
  if loss_type == LossType.CLASSIFICATION:
    if num_class == 1:
      return tf.losses.sigmoid_cross_entropy(
          label, logits=pred, weights=loss_weight, **kwargs)
    else:
      assert label.dtype in [tf.int32, tf.int64], \
          'label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy.'
      return tf.losses.sparse_softmax_cross_entropy(
          labels=label, logits=pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.CROSS_ENTROPY_LOSS:
    return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
    losses = tf.keras.backend.binary_crossentropy(label, pred, from_logits=True)
    # Apply loss_weight like the other weighted branches instead of silently
    # ignoring it. With the default scalar weight of 1.0 this equals the plain
    # mean. loss_collection=None keeps the result out of the LOSSES
    # collection, as the previous tf.reduce_mean did.
    return tf.losses.compute_weighted_loss(
        losses, weights=loss_weight, loss_collection=None)
  elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
    logging.info('%s is used' % LossType.Name(loss_type))
    return tf.losses.mean_squared_error(
        labels=label, predictions=pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.ZILN_LOSS:
    loss = zero_inflated_lognormal_loss(label, pred)
    # np.isscalar is False for tensors, so only a non-unit Python/NumPy scalar
    # weight is applied here; a tensor loss_weight is not used.
    if np.isscalar(loss_weight) and loss_weight != 1.0:
      return loss * loss_weight
    return loss
  elif loss_type == LossType.JRC_LOSS:
    session = kwargs.get('session_ids', None)
    if loss_param is None:
      return jrc_loss(label, pred, session, name=loss_name)
    return jrc_loss(
        label,
        pred,
        session,
        loss_param.alpha,
        loss_weight_strategy=loss_param.loss_weight_strategy,
        sample_weights=loss_weight,
        same_label_loss=loss_param.same_label_loss,
        name=loss_name)
  elif loss_type == LossType.PAIR_WISE_LOSS:
    session = kwargs.get('session_ids', None)
    margin = 0 if loss_param is None else loss_param.margin
    temp = 1.0 if loss_param is None else loss_param.temperature
    return pairwise_loss(
        label,
        pred,
        session_ids=session,
        margin=margin,
        temperature=temp,
        weights=loss_weight,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_LOGISTIC_LOSS:
    session = kwargs.get('session_ids', None)
    temp = 1.0 if loss_param is None else loss_param.temperature
    ohem_ratio = 1.0 if loss_param is None else loss_param.ohem_ratio
    hinge_margin = None
    if loss_param is not None and loss_param.HasField('hinge_margin'):
      hinge_margin = loss_param.hinge_margin
    lbl_margin = False if loss_param is None else loss_param.use_label_margin
    return pairwise_logistic_loss(
        label,
        pred,
        session_ids=session,
        temperature=temp,
        hinge_margin=hinge_margin,
        ohem_ratio=ohem_ratio,
        weights=loss_weight,
        use_label_margin=lbl_margin,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_HINGE_LOSS:
    session = kwargs.get('session_ids', None)
    temp, ohem_ratio, margin = 1.0, 1.0, 1.0
    label_is_logits, use_label_margin, use_exponent = True, True, False
    if loss_param is not None:
      temp = loss_param.temperature
      ohem_ratio = loss_param.ohem_ratio
      margin = loss_param.margin
      label_is_logits = loss_param.label_is_logits
      use_label_margin = loss_param.use_label_margin
      use_exponent = loss_param.use_exponent
    return pairwise_hinge_loss(
        label,
        pred,
        session_ids=session,
        temperature=temp,
        margin=margin,
        ohem_ratio=ohem_ratio,
        weights=loss_weight,
        label_is_logits=label_is_logits,
        use_label_margin=use_label_margin,
        use_exponent=use_exponent,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_FOCAL_LOSS:
    session = kwargs.get('session_ids', None)
    if loss_param is None:
      return pairwise_focal_loss(
          label, pred, session_ids=session, weights=loss_weight, name=loss_name)
    hinge_margin = None
    if loss_param.HasField('hinge_margin'):
      hinge_margin = loss_param.hinge_margin
    return pairwise_focal_loss(
        label,
        pred,
        session_ids=session,
        gamma=loss_param.gamma,
        alpha=loss_param.alpha if loss_param.HasField('alpha') else None,
        hinge_margin=hinge_margin,
        ohem_ratio=loss_param.ohem_ratio,
        temperature=loss_param.temperature,
        weights=loss_weight,
        name=loss_name)
  elif loss_type == LossType.LISTWISE_RANK_LOSS:
    session = kwargs.get('session_ids', None)
    trans_fn, temp, label_is_logits, scale = None, 1.0, False, False
    if loss_param is not None:
      temp = loss_param.temperature
      label_is_logits = loss_param.label_is_logits
      scale = loss_param.scale_logits
      if loss_param.HasField('transform_fn'):
        trans_fn = loss_param.transform_fn
    return listwise_rank_loss(
        label,
        pred,
        session,
        temperature=temp,
        label_is_logits=label_is_logits,
        transform_fn=trans_fn,
        scale_logits=scale,
        weights=loss_weight)
  elif loss_type == LossType.LISTWISE_DISTILL_LOSS:
    session = kwargs.get('session_ids', None)
    trans_fn, temp, label_clip_max_value, scale = None, 1.0, 512.0, False
    if loss_param is not None:
      temp = loss_param.temperature
      label_clip_max_value = loss_param.label_clip_max_value
      scale = loss_param.scale_logits
      if loss_param.HasField('transform_fn'):
        trans_fn = loss_param.transform_fn
    return listwise_distill_loss(
        label,
        pred,
        session,
        temperature=temp,
        label_clip_max_value=label_clip_max_value,
        transform_fn=trans_fn,
        scale_logits=scale,
        weights=loss_weight)
  elif loss_type == LossType.F1_REWEIGHTED_LOSS:
    f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square
    label_smoothing = 0 if loss_param is None else loss_param.label_smoothing
    return f1_reweight_sigmoid_cross_entropy(
        label,
        pred,
        f1_beta_square,
        weights=loss_weight,
        label_smoothing=label_smoothing)
  elif loss_type == LossType.BINARY_FOCAL_LOSS:
    if loss_param is None:
      return sigmoid_focal_loss_with_logits(
          label, pred, sample_weights=loss_weight, name=loss_name)
    gamma = loss_param.gamma
    alpha = None
    if loss_param.HasField('alpha'):
      alpha = loss_param.alpha
    return sigmoid_focal_loss_with_logits(
        label,
        pred,
        gamma=gamma,
        alpha=alpha,
        ohem_ratio=loss_param.ohem_ratio,
        sample_weights=loss_weight,
        label_smoothing=loss_param.label_smoothing,
        name=loss_name)
  else:
    raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))
206
+
207
+
208
def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
  """Build knowledge distillation losses.

  Each KD config compares a prediction (`kd.pred_name` in `prediction_dict`)
  against a soft label (`kd.soft_label_name` in `label_dict`).
  KL_DIVERGENCE_LOSS, BINARY_CROSS_ENTROPY_LOSS, CROSS_ENTROPY_LOSS and
  L2_LOSS are computed here; any other loss type is delegated to `build`.

  Args:
    kds: list of knowledge distillation objects of type KD.
    prediction_dict: dict of predict_name to predict tensors.
    label_dict: ordered dict of label_name to label tensors.
    feature_dict: dict of feature name to feature value; read for the task
      space indicator feature and for `session_name` of delegated loss params.

  Returns:
    dict mapping loss name to loss tensor, one entry per KD config. The key is
    `kd.loss_name` when set, otherwise
    'kd_loss_<pred_name>_<soft_label_name>' with '/' replaced by '_'.

  Raises:
    AssertionError: (via `assert`) if `kd.pred_name` is not in
      `prediction_dict`.
  """
  loss_dict = {}
  for kd in kds:
    assert kd.pred_name in prediction_dict, \
        'invalid predict_name: %s available ones: %s' % (
            kd.pred_name, ','.join(prediction_dict.keys()))

    loss_name = kd.loss_name
    if not loss_name:
      loss_name = 'kd_loss_' + kd.pred_name.replace('/', '_')
      loss_name += '_' + kd.soft_label_name.replace('/', '_')

    loss_weight = kd.loss_weight
    if kd.HasField('task_space_indicator_name') and kd.HasField(
        'task_space_indicator_value'):
      # Per-sample weight: in_task_space_weight where the indicator feature
      # equals the configured value, out_task_space_weight elsewhere. From here
      # on loss_weight is a tensor, not a Python scalar.
      in_task_space = tf.to_float(
          tf.equal(feature_dict[kd.task_space_indicator_name],
                   kd.task_space_indicator_value))
      loss_weight = loss_weight * (
          kd.in_task_space_weight * in_task_space + kd.out_task_space_weight *
          (1 - in_task_space))

    label = label_dict[kd.soft_label_name]
    pred = prediction_dict[kd.pred_name]
    epsilon = tf.keras.backend.epsilon()
    num_class = 1 if len(pred.get_shape()) < 2 else pred.get_shape()[-1]

    if kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
      if not kd.label_is_logits:  # label is prob
        label = _binary_logit(label, epsilon)
      if not kd.pred_is_logits:
        pred = _binary_logit(pred, epsilon)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      # Target becomes a probability; pred stays logits (from_logits=True).
      label = tf.nn.sigmoid(label)
      losses = tf.keras.backend.binary_crossentropy(
          label, pred, from_logits=True)
      loss_dict[loss_name] = _weighted_mean(losses, loss_weight, loss_name)
    elif kd.loss_type == LossType.KL_DIVERGENCE_LOSS:
      if not kd.label_is_logits:  # label is prob
        label = _prob_to_logit(label, num_class, epsilon)
      if not kd.pred_is_logits:
        pred = _prob_to_logit(pred, num_class, epsilon)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      if num_class > 1:
        labels = tf.nn.softmax(label)
        preds = tf.nn.softmax(pred)
      else:
        # Binary case: build explicit two-class distributions [1 - p, p].
        label = tf.expand_dims(tf.nn.sigmoid(label), 1)  # [B, 1]
        labels = tf.concat([1 - label, label], axis=1)  # [B, 2]
        pred = tf.expand_dims(tf.nn.sigmoid(pred), 1)  # [B, 1]
        preds = tf.concat([1 - pred, pred], axis=1)  # [B, 2]
      losses = tf.keras.losses.KLD(labels, preds)
      loss_dict[loss_name] = _weighted_mean(losses, loss_weight, loss_name)
    elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
      if not kd.label_is_logits:
        label = tf.math.log(label + epsilon)
      if not kd.pred_is_logits:
        pred = tf.math.log(pred + epsilon)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      if num_class > 1:
        label = tf.nn.softmax(label)
        pred = tf.nn.softmax(pred)
      elif num_class == 1:
        label = tf.nn.sigmoid(label)
        pred = tf.nn.sigmoid(pred)
      loss_dict[loss_name] = tf.losses.log_loss(
          label, pred, weights=loss_weight)
    elif kd.loss_type == LossType.L2_LOSS:
      loss_dict[loss_name] = tf.losses.mean_squared_error(
          labels=label, predictions=pred, weights=loss_weight)
    else:
      loss_param = kd.WhichOneof('loss_param')
      kwargs = {}
      if loss_param is not None:
        loss_param = getattr(kd, loss_param)
        if hasattr(loss_param, 'session_name'):
          kwargs['session_ids'] = feature_dict[loss_param.session_name]
      loss_dict[loss_name] = build(
          kd.loss_type,
          label,
          pred,
          loss_weight=loss_weight,
          loss_param=loss_param,
          **kwargs)
  return loss_dict


def _binary_logit(prob, epsilon):
  """Convert a binary probability into a logit, clipping to avoid log(0)."""
  prob = tf.clip_by_value(prob, epsilon, 1 - epsilon)
  return tf.math.log(prob / (1 - prob))


def _prob_to_logit(prob, num_class, epsilon):
  """Convert probabilities to logits: binary logit or shifted log-prob.

  For num_class == 1 this is the binary logit. Otherwise it is log(p + eps)
  shifted by its global max; the shift does not change a later softmax.
  """
  if num_class == 1:
    return _binary_logit(prob, epsilon)
  log_prob = tf.math.log(prob + epsilon)
  return log_prob - tf.reduce_max(log_prob)


def _weighted_mean(losses, loss_weight, name):
  """Mean of `losses` after multiplying element-wise by `loss_weight`.

  `loss_weight` may be a Python scalar, which gives the same result as scaling
  the plain mean. It may also be a per-sample tensor, e.g. the task-space
  weight. A lower-rank weight tensor gets trailing axes so it broadcasts along
  the leading axes of `losses`.
  """
  weight = tf.cast(loss_weight, losses.dtype)
  weight_rank = weight.get_shape().ndims
  losses_rank = losses.get_shape().ndims
  if weight_rank and losses_rank:
    for _ in range(losses_rank - weight_rank):
      weight = tf.expand_dims(weight, -1)
  return tf.reduce_mean(losses * weight, name=name)