easy-cs-rec-custommodel 0.8.6 (easy_cs_rec_custommodel-0.8.6-py2.py3-none-any.whl)

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/model/rank_model.py
@@ -0,0 +1,485 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf
from tensorflow.python.ops import math_ops

from easy_rec.python.builders import loss_builder
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType

from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_pred  # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class RankModel(EasyRecModel):

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(RankModel, self).__init__(model_config, feature_configs, features,
                                    labels, is_training)
    self._loss_type = self._model_config.loss_type
    self._num_class = self._model_config.num_class
    self._losses = self._model_config.losses
    if self._labels is not None:
      if model_config.HasField('label_name'):
        self._label_name = model_config.label_name
      else:
        self._label_name = list(self._labels.keys())[0]
    self._outputs = []

  def build_predict_graph(self):
    if not self.has_backbone:
      raise NotImplementedError(
          'method `build_predict_graph` must be implemented when backbone network does not exist'
      )
    model = self._model_config.WhichOneof('model')
    assert model == 'model_params', '`model_params` must be configured'
    config = self._model_config.model_params
    for out in config.outputs:
      self._outputs.append(out)

    output = self.backbone
    if int(output.shape[-1]) != self._num_class:
      logging.info('add head logits layer for rank model')
      output = tf.layers.dense(output, self._num_class, name='output')
    self._add_to_prediction_dict(output)
    return self._prediction_dict

  def _output_to_prediction_impl(self,
                                 output,
                                 loss_type,
                                 num_class=1,
                                 suffix=''):
    prediction_dict = {}
    binary_loss_type = {
        LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
        LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
        LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.BINARY_CROSS_ENTROPY_LOSS,
        LossType.LISTWISE_DISTILL_LOSS
    }
    if loss_type in binary_loss_type:
      assert num_class == 1, 'num_class must be 1 when loss type is %s' % loss_type.name
      output = tf.squeeze(output, axis=1)
      probs = tf.sigmoid(output)
      tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
      prediction_dict['logits' + suffix] = output
      prediction_dict['probs' + suffix] = probs
    elif loss_type == LossType.JRC_LOSS:
      assert num_class == 2, 'num_class must be 2 when loss type is JRC_LOSS'
      probs = tf.nn.softmax(output, axis=1)
      tf.summary.scalar('prediction/probs', tf.reduce_mean(probs[:, 1]))
      prediction_dict['logits' + suffix] = output
      prediction_dict['pos_logits' + suffix] = output[:, 1]
      prediction_dict['probs' + suffix] = probs[:, 1]
    elif loss_type == LossType.ZILN_LOSS:
      assert num_class == 3, 'num_class must be 3 when loss type is ZILN_LOSS'
      probs, preds = zero_inflated_lognormal_pred(output)
      tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
      tf.summary.scalar('prediction/y', tf.reduce_mean(preds))
      prediction_dict['logits' + suffix] = output
      prediction_dict['probs' + suffix] = probs
      prediction_dict['y' + suffix] = preds
    elif loss_type == LossType.CLASSIFICATION:
      if num_class == 1:
        output = tf.squeeze(output, axis=1)
        probs = tf.sigmoid(output)
        tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
        prediction_dict['logits' + suffix] = output
        prediction_dict['probs' + suffix] = probs
      else:
        probs = tf.nn.softmax(output, axis=1)
        prediction_dict['logits' + suffix] = output
        prediction_dict['logits' + suffix + '_1'] = output[:, 1]
        prediction_dict['probs' + suffix] = probs
        prediction_dict['probs' + suffix + '_1'] = probs[:, 1]
        prediction_dict['logits' + suffix + '_y'] = math_ops.reduce_max(
            output, axis=1)
        prediction_dict['probs' + suffix + '_y'] = math_ops.reduce_max(
            probs, axis=1)
        prediction_dict['y' + suffix] = tf.argmax(output, axis=1)
    elif loss_type == LossType.L2_LOSS:
      output = tf.squeeze(output, axis=1)
      prediction_dict['y' + suffix] = output
    elif loss_type == LossType.SIGMOID_L2_LOSS:
      output = tf.squeeze(output, axis=1)
      prediction_dict['y' + suffix] = tf.sigmoid(output)
    return prediction_dict

  def _add_to_prediction_dict(self, output):
    if len(self._losses) == 0:
      prediction_dict = self._output_to_prediction_impl(
          output, loss_type=self._loss_type, num_class=self._num_class)
      self._prediction_dict.update(prediction_dict)
    else:
      for loss in self._losses:
        prediction_dict = self._output_to_prediction_impl(
            output, loss_type=loss.loss_type, num_class=self._num_class)
        self._prediction_dict.update(prediction_dict)

  def build_rtp_output_dict(self):
    """Forward tensor as `rank_predict`, which is a special node for RTP."""
    outputs = {}
    outputs.update(super(RankModel, self).build_rtp_output_dict())
    rank_predict = None
    try:
      op = tf.get_default_graph().get_operation_by_name('rank_predict')
      if len(op.outputs) != 1:
        raise ValueError(
            ('failed to build RTP rank_predict output: op {}[{}] has output ' +
             'size {}, however 1 is expected.').format(op.name, op.type,
                                                       len(op.outputs)))
      rank_predict = op.outputs[0]
    except KeyError:
      forwarded = None
      loss_types = {self._loss_type}
      if len(self._losses) > 0:
        loss_types = {loss.loss_type for loss in self._losses}
      binary_loss_set = {
          LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
          LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
          LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
          LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
          LossType.LISTWISE_RANK_LOSS
      }
      if loss_types & binary_loss_set:
        if 'probs' in self._prediction_dict:
          forwarded = self._prediction_dict['probs']
        else:
          raise ValueError(
              'failed to build RTP rank_predict output: classification model '
              "expects 'probs' prediction, which is not found. Please check "
              'if build_predict_graph() is called.')
      elif loss_types & {
          LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
      }:
        if 'y' in self._prediction_dict:
          forwarded = self._prediction_dict['y']
        else:
          raise ValueError(
              'failed to build RTP rank_predict output: regression model '
              "expects 'y' prediction, which is not found. Please check "
              'if build_predict_graph() is called.')
      else:
        logging.warning(
            'failed to build RTP rank_predict: unsupported loss type {}'.format(
                loss_types))
      if forwarded is not None:
        rank_predict = tf.identity(forwarded, name='rank_predict')
    if rank_predict is not None:
      outputs['rank_predict'] = rank_predict
    return outputs

  def _build_loss_impl(self,
                       loss_type,
                       label_name,
                       loss_weight=1.0,
                       num_class=1,
                       suffix='',
                       loss_name='',
                       loss_param=None):
    loss_dict = {}
    binary_loss_type = {
        LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
        LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
        LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS,
        LossType.LISTWISE_DISTILL_LOSS, LossType.ZILN_LOSS
    }
    if loss_type in {
        LossType.CLASSIFICATION, LossType.BINARY_CROSS_ENTROPY_LOSS
    }:
      loss_name = loss_name if loss_name else 'cross_entropy_loss' + suffix
      pred = self._prediction_dict['logits' + suffix]
    elif loss_type in binary_loss_type:
      if not loss_name:
        loss_name = LossType.Name(loss_type).lower() + suffix
      else:
        loss_name = loss_name + suffix
      pred = self._prediction_dict['logits' + suffix]
    elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
      loss_name = loss_name if loss_name else 'l2_loss' + suffix
      pred = self._prediction_dict['y' + suffix]
    else:
      raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))

    tf.summary.scalar('labels/%s' % label_name,
                      tf.reduce_mean(tf.to_float(self._labels[label_name])))
    kwargs = {'loss_name': loss_name}
    if loss_param is not None:
      if hasattr(loss_param, 'session_name'):
        kwargs['session_ids'] = self._feature_dict[loss_param.session_name]
    loss_dict[loss_name] = loss_builder.build(
        loss_type,
        self._labels[label_name],
        pred,
        loss_weight,
        num_class,
        loss_param=loss_param,
        **kwargs)
    return loss_dict

  def build_loss_graph(self):
    loss_dict = {}
    with tf.name_scope('loss'):
      if len(self._losses) == 0:
        loss_dict = self._build_loss_impl(
            self._loss_type,
            label_name=self._label_name,
            loss_weight=self._sample_weight,
            num_class=self._num_class)
      else:
        strategy = self._base_model_config.loss_weight_strategy
        loss_weight = [1.0]
        if strategy == self._base_model_config.Random and len(self._losses) > 1:
          weights = tf.random_normal([len(self._losses)])
          loss_weight = tf.nn.softmax(weights)
        for i, loss in enumerate(self._losses):
          loss_param = loss.WhichOneof('loss_param')
          if loss_param is not None:
            loss_param = getattr(loss, loss_param)
          loss_ops = self._build_loss_impl(
              loss.loss_type,
              label_name=self._label_name,
              loss_weight=self._sample_weight,
              num_class=self._num_class,
              loss_name=loss.loss_name,
              loss_param=loss_param)
          for loss_name, loss_value in loss_ops.items():
            if strategy == self._base_model_config.Fixed:
              loss_dict[loss_name] = loss_value * loss.weight
            elif strategy == self._base_model_config.Uncertainty:
              # learned uncertainty weighting: exp(-s) * loss + 0.5 * s,
              # with an extra 0.5 factor on the loss term for regression losses
              if loss.learn_loss_weight:
                uncertainty = tf.Variable(
                    0, name='%s_loss_weight' % loss_name, dtype=tf.float32)
                tf.summary.scalar('%s_uncertainty' % loss_name, uncertainty)
                if loss.loss_type in {
                    LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS
                }:
                  loss_dict[loss_name] = 0.5 * tf.exp(
                      -uncertainty) * loss_value + 0.5 * uncertainty
                else:
                  loss_dict[loss_name] = tf.exp(
                      -uncertainty) * loss_value + 0.5 * uncertainty
              else:
                loss_dict[loss_name] = loss_value * loss.weight
            elif strategy == self._base_model_config.Random:
              loss_dict[loss_name] = loss_value * loss_weight[i]
            else:
              raise ValueError('Unsupported loss weight strategy: ' +
                               strategy.Name)
    self._loss_dict.update(loss_dict)
    # build kd loss
    kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
                                              self._labels,
                                              self._feature_dict)
    self._loss_dict.update(kd_loss_dict)
    return self._loss_dict

  def _build_metric_impl(self,
                         metric,
                         loss_type,
                         label_name,
                         num_class=1,
                         suffix=''):
    if not isinstance(loss_type, set):
      loss_type = {loss_type}
    from easy_rec.python.core.easyrec_metrics import metrics_tf
    from easy_rec.python.core import metrics as metrics_lib
    binary_loss_set = {
        LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
        LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
        LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
        LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
        LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS
    }
    metric_dict = {}
    if metric.WhichOneof('metric') == 'auc':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['auc' + suffix] = metrics_tf.auc(
            label,
            self._prediction_dict['probs' + suffix],
            num_thresholds=metric.auc.num_thresholds)
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['auc' + suffix] = metrics_tf.auc(
            label,
            self._prediction_dict['probs' + suffix][:, 1],
            num_thresholds=metric.auc.num_thresholds)
      else:
        raise ValueError('Wrong class number')
    elif metric.WhichOneof('metric') == 'gauc':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        uids = self._feature_dict[metric.gauc.uid_field]
        if isinstance(uids, tf.sparse.SparseTensor):
          uids = tf.sparse_to_dense(
              uids.indices, uids.dense_shape, uids.values, default_value='')
          uids = tf.reshape(uids, [-1])
        metric_dict['gauc' + suffix] = metrics_lib.gauc(
            label,
            self._prediction_dict['probs' + suffix],
            uids=uids,
            reduction=metric.gauc.reduction)
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['gauc' + suffix] = metrics_lib.gauc(
            label,
            self._prediction_dict['probs' + suffix][:, 1],
            uids=self._feature_dict[metric.gauc.uid_field],
            reduction=metric.gauc.reduction)
      else:
        raise ValueError('Wrong class number')
    elif metric.WhichOneof('metric') == 'session_auc':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
            label,
            self._prediction_dict['probs' + suffix],
            session_ids=self._feature_dict[metric.session_auc.session_id_field],
            reduction=metric.session_auc.reduction)
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
            label,
            self._prediction_dict['probs' + suffix][:, 1],
            session_ids=self._feature_dict[metric.session_auc.session_id_field],
            reduction=metric.session_auc.reduction)
      else:
        raise ValueError('Wrong class number')
    elif metric.WhichOneof('metric') == 'max_f1':
      assert loss_type & binary_loss_set
      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
            label, self._prediction_dict['logits' + suffix])
      elif num_class == 2:
        label = tf.to_int64(self._labels[label_name])
        metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
            label, self._prediction_dict['logits' + suffix][:, 1])
      else:
        raise ValueError('Wrong class number')
    elif metric.WhichOneof('metric') == 'recall_at_topk':
      assert loss_type & binary_loss_set
      assert num_class > 1
      label = tf.to_int64(self._labels[label_name])
      metric_dict['recall_at_topk' + suffix] = metrics_tf.recall_at_k(
          label, self._prediction_dict['logits' + suffix],
          metric.recall_at_topk.topk)
    elif metric.WhichOneof('metric') == 'mean_absolute_error':
      label = tf.to_float(self._labels[label_name])
      if loss_type & {
          LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
      }:
        metric_dict['mean_absolute_error' +
                    suffix] = metrics_tf.mean_absolute_error(
                        label, self._prediction_dict['y' + suffix])
      elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
        metric_dict['mean_absolute_error' +
                    suffix] = metrics_tf.mean_absolute_error(
                        label, self._prediction_dict['probs' + suffix])
      else:
        assert False, 'mean_absolute_error is not supported for this model'
    elif metric.WhichOneof('metric') == 'mean_squared_error':
      label = tf.to_float(self._labels[label_name])
      if loss_type & {
          LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
      }:
        metric_dict['mean_squared_error' +
                    suffix] = metrics_tf.mean_squared_error(
                        label, self._prediction_dict['y' + suffix])
      elif num_class == 1 and loss_type & binary_loss_set:
        metric_dict['mean_squared_error' +
                    suffix] = metrics_tf.mean_squared_error(
                        label, self._prediction_dict['probs' + suffix])
      else:
        assert False, 'mean_squared_error is not supported for this model'
    elif metric.WhichOneof('metric') == 'root_mean_squared_error':
      label = tf.to_float(self._labels[label_name])
      if loss_type & {
          LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
      }:
        metric_dict['root_mean_squared_error' +
                    suffix] = metrics_tf.root_mean_squared_error(
                        label, self._prediction_dict['y' + suffix])
      elif loss_type & {LossType.CLASSIFICATION} and num_class == 1:
        metric_dict['root_mean_squared_error' +
                    suffix] = metrics_tf.root_mean_squared_error(
                        label, self._prediction_dict['probs' + suffix])
      else:
        assert False, 'root_mean_squared_error is not supported for this model'
    elif metric.WhichOneof('metric') == 'accuracy':
      assert loss_type & {LossType.CLASSIFICATION}
      assert num_class > 1
      label = tf.to_int64(self._labels[label_name])
      metric_dict['accuracy' + suffix] = metrics_tf.accuracy(
          label, self._prediction_dict['y' + suffix])
    return metric_dict

  def build_metric_graph(self, eval_config):
    loss_types = {self._loss_type}
    if len(self._losses) > 0:
      loss_types = {loss.loss_type for loss in self._losses}
    for metric in eval_config.metrics_set:
      self._metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=loss_types,
              label_name=self._label_name,
              num_class=self._num_class))
    return self._metric_dict

  def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
    binary_loss_set = {
        LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
        LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
        LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS
    }
    if loss_type in binary_loss_set:
      return ['probs' + suffix, 'logits' + suffix]
    if loss_type == LossType.JRC_LOSS:
      return ['probs' + suffix, 'pos_logits' + suffix]
    if loss_type == LossType.ZILN_LOSS:
      return ['probs' + suffix, 'y' + suffix, 'logits' + suffix]
    if loss_type == LossType.CLASSIFICATION:
      if num_class == 1:
        return ['probs' + suffix, 'logits' + suffix]
      else:
        return [
            'y' + suffix, 'probs' + suffix, 'logits' + suffix,
            'probs' + suffix + '_y', 'logits' + suffix + '_y',
            'probs' + suffix + '_1', 'logits' + suffix + '_1'
        ]
    elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
      return ['y' + suffix]
    else:
      raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))

  def get_outputs(self):
    if len(self._losses) == 0:
      outputs = self._get_outputs_impl(self._loss_type, self._num_class)
      if self._outputs:
        outputs.extend(self._outputs)
      return list(set(outputs))

    all_outputs = []
    if self._outputs:
      all_outputs.extend(self._outputs)
    for loss in self._losses:
      outputs = self._get_outputs_impl(loss.loss_type, self._num_class)
      all_outputs.extend(outputs)
    return list(set(all_outputs))
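
A note on the Uncertainty loss-weight strategy in build_loss_graph above (this paragraph and the formula are editorial commentary, not text from the package): with a learnable scalar s_i per loss (the `%s_loss_weight` variable), the two branches appear to follow the homoscedastic-uncertainty multi-task weighting of Kendall et al. (2018), where s_i plays the role of log sigma_i^2:

$$
\mathcal{L}_{\text{total}} \;=\; \sum_{i \in \text{regression}} \Big( \tfrac{1}{2} e^{-s_i} \mathcal{L}_i + \tfrac{1}{2} s_i \Big) \;+\; \sum_{i \in \text{classification}} \Big( e^{-s_i} \mathcal{L}_i + \tfrac{1}{2} s_i \Big)
$$

A loss the model treats as noisy (large s_i) is automatically down-weighted, while the 0.5 s_i term keeps s_i from growing without bound.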
easy_rec/python/model/rocket_launching.py
@@ -0,0 +1,203 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.builders import loss_builder
from easy_rec.python.layers import dnn
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.simi_pb2 import Similarity

from easy_rec.python.protos.rocket_launching_pb2 import RocketLaunching as RocketLaunchingConfig  # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class RocketLaunching(RankModel):

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(RocketLaunching, self).__init__(model_config, feature_configs,
                                          features, labels, is_training)
    assert self._model_config.WhichOneof('model') == 'rocket_launching', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.rocket_launching
    assert isinstance(self._model_config, RocketLaunchingConfig)
    if self._labels is not None:
      self._label_name = list(self._labels.keys())[0]

    self._features, _ = self._input_layer(self._feature_dict, 'all')

  def sim(self, feature_emb1, feature_emb2):
    emb1_emb2_sim = tf.reduce_sum(
        tf.multiply(feature_emb1, feature_emb2), axis=1, keepdims=True)
    return emb1_emb2_sim

  def norm(self, fea):
    fea_norm = tf.nn.l2_normalize(fea, axis=1)
    return fea_norm

  def feature_based_sim(self, feature_based_distillation, i, j):
    booster_feature_no_gradient = tf.stop_gradient(
        self.booster_feature['hidden_layer' + str(j)])
    if feature_based_distillation == Similarity.COSINE:
      booster_feature_no_gradient_norm = self.norm(booster_feature_no_gradient)
      light_feature_norm = self.norm(self.light_feature['hidden_layer' +
                                                        str(i)])
      sim_middle_layer = tf.reduce_mean(
          self.sim(booster_feature_no_gradient_norm, light_feature_norm))
      return sim_middle_layer
    else:
      return tf.sqrt(
          tf.reduce_sum(
              tf.square(booster_feature_no_gradient -
                        self.light_feature['hidden_layer' + str(i)])))

  def build_predict_graph(self):
    self.hidden_layer_feature_output = self._model_config.feature_based_distillation
    if self._model_config.HasField('share_dnn'):
      share_dnn_layer = dnn.DNN(self._model_config.share_dnn, self._l2_reg,
                                'share_dnn', self._is_training)
      share_feature = share_dnn_layer(self._features)
    booster_dnn_layer = dnn.DNN(self._model_config.booster_dnn, self._l2_reg,
                                'booster_dnn', self._is_training)
    light_dnn_layer = dnn.DNN(self._model_config.light_dnn, self._l2_reg,
                              'light_dnn', self._is_training)
    if self._model_config.HasField('share_dnn'):
      self.booster_feature = booster_dnn_layer(share_feature,
                                               self.hidden_layer_feature_output)
      input_embedding_stop_gradient = tf.stop_gradient(share_feature)
      self.light_feature = light_dnn_layer(input_embedding_stop_gradient,
                                           self.hidden_layer_feature_output)
    else:
      self.booster_feature = booster_dnn_layer(self._features,
                                               self.hidden_layer_feature_output)
      input_embedding_stop_gradient = tf.stop_gradient(self._features)
      self.light_feature = light_dnn_layer(input_embedding_stop_gradient,
                                           self.hidden_layer_feature_output)

    if self._model_config.feature_based_distillation:
      booster_out = tf.layers.dense(
          self.booster_feature['hidden_layer_end'],
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='booster_output')

      light_out = tf.layers.dense(
          self.light_feature['hidden_layer_end'],
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='light_output')
    else:
      booster_out = tf.layers.dense(
          self.booster_feature,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='booster_output')

      light_out = tf.layers.dense(
          self.light_feature,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='light_output')

    self._prediction_dict.update(
        self._output_to_prediction_impl(
            booster_out,
            self._loss_type,
            num_class=self._num_class,
            suffix='_booster'))
    self._prediction_dict.update(
        self._output_to_prediction_impl(
            light_out,
            self._loss_type,
            num_class=self._num_class,
            suffix='_light'))

    return self._prediction_dict

  def build_loss_graph(self):
    logits_booster = self._prediction_dict['logits_booster']
    logits_light = self._prediction_dict['logits_light']
    self.feature_distillation_function = self._model_config.feature_distillation_function

    # feature_based_distillation loss
    if self._model_config.feature_based_distillation:
      booster_hidden_units = self._model_config.booster_dnn.hidden_units
      light_hidden_units = self._model_config.light_dnn.hidden_units
      count = 0

      for i, unit_i in enumerate(light_hidden_units):
        for j, unit_j in enumerate(booster_hidden_units):
          if light_hidden_units[i] == booster_hidden_units[j]:
            self._prediction_dict[
                'similarity_' + str(count)] = self.feature_based_sim(
                    self._model_config.feature_based_distillation, i, j)
            count += 1
            break

    self._loss_dict.update(
        self._build_loss_impl(
            LossType.CLASSIFICATION,
            label_name=self._label_name,
            loss_weight=self._sample_weight,
            num_class=self._num_class,
            suffix='_booster'))

    self._loss_dict.update(
        self._build_loss_impl(
            LossType.CLASSIFICATION,
            label_name=self._label_name,
            loss_weight=self._sample_weight,
            num_class=self._num_class,
            suffix='_light'))

    booster_logits_no_grad = tf.stop_gradient(logits_booster)

    self._loss_dict['hint_loss'] = loss_builder.build(
        LossType.L2_LOSS,
        label=booster_logits_no_grad,
        pred=logits_light,
        loss_weight=self._sample_weight)

    if self._model_config.feature_based_distillation:
      for key, value in self._prediction_dict.items():
        if key.startswith('similarity_'):
          self._loss_dict[key] = -0.1 * value
      return self._loss_dict
    else:
      return self._loss_dict

  def build_metric_graph(self, eval_config):
    metric_dict = {}
    for metric in eval_config.metrics_set:
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=LossType.CLASSIFICATION,
              label_name=self._label_name,
              num_class=self._num_class,
              suffix='_light'))
      metric_dict.update(
          self._build_metric_impl(
              metric,
              loss_type=LossType.CLASSIFICATION,
              label_name=self._label_name,
              num_class=self._num_class,
              suffix='_booster'))
    return metric_dict

  def get_outputs(self):
    outputs = []
    outputs.extend(
        self._get_outputs_impl(
            self._loss_type, self._num_class, suffix='_light'))
    outputs.extend(
        self._get_outputs_impl(
            self._loss_type, self._num_class, suffix='_booster'))
    return outputs
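
To make the distillation objective in RocketLaunching above concrete, here is a minimal self-contained sketch in plain NumPy. It is editorial illustration only, not the EasyRec API: the function name rocket_losses, the shapes, and the random inputs are assumptions; only the structure of the two terms and the -0.1 coefficient mirror build_loss_graph (hint loss between the light logits and the stop-gradient booster logits, plus the cosine feature similarity in the COSINE case).

import numpy as np


def rocket_losses(booster_logits, light_logits, booster_hidden, light_hidden):
  """Illustrative NumPy re-computation of the RocketLaunching auxiliary losses.

  booster_logits / light_logits: [batch] pre-sigmoid scores.
  booster_hidden / light_hidden: [batch, dim] hidden layers of matching width.
  """
  # hint loss: squared error between the light logits and the booster logits;
  # in the model the booster side is wrapped in tf.stop_gradient, so only the
  # light network is pulled toward the booster.
  hint_loss = np.mean((light_logits - booster_logits) ** 2)

  # feature-based distillation (COSINE case): mean cosine similarity between
  # matching hidden layers, added to the loss dict with a -0.1 weight so that
  # minimizing the total loss maximizes the similarity.
  def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

  cos_sim = np.mean(
      np.sum(l2_normalize(booster_hidden) * l2_normalize(light_hidden), axis=1))
  similarity_term = -0.1 * cos_sim
  return hint_loss, similarity_term


rng = np.random.default_rng(0)
print(rocket_losses(rng.normal(size=8), rng.normal(size=8),
                    rng.normal(size=(8, 16)), rng.normal(size=(8, 16))))

In the actual model the hint loss is produced through loss_builder.build(LossType.L2_LOSS, ...), and both the booster and the light network also receive their own classification loss on the label, so the light network learns from the data and from the booster at the same time.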