easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,207 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.loss.pairwise_loss import pairwise_loss
7
+ from easy_rec.python.model.easy_rec_model import EasyRecModel
8
+ from easy_rec.python.protos.loss_pb2 import LossType
9
+ from easy_rec.python.utils.proto_util import copy_obj
10
+
11
+ from easy_rec.python.protos.dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA
12
+ from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
13
+ from easy_rec.python.protos.dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA
14
# Switch to the TF1-compatible API when running under TensorFlow 2.x or newer.
# The major version is compared as an integer: comparing the raw version
# strings orders them lexicographically (e.g. '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
losses = tf.losses
17
+
18
+
19
def cosine_similarity(user_emb, item_emb):
  """Return the dot product of `user_emb` and `item_emb` along axis 1.

  No normalization happens here, so the result is a true cosine similarity
  only when both inputs were L2-normalized by the caller beforehand.

  Args:
    user_emb: user-side embedding tensor.
    item_emb: item-side embedding tensor; multiplied element-wise with
      `user_emb`.

  Returns:
    The element-wise product summed over axis 1 (the op is named 'cosine').
  """
  elementwise_product = tf.multiply(user_emb, item_emb)
  return tf.reduce_sum(elementwise_product, axis=1, name='cosine')
23
+
24
+
25
def bernoulli_dropout(x, rate, training=False):
  """Zero out whole rows of `x` at random and rescale the surviving rows.

  A single Bernoulli keep/drop decision is sampled per row (the mask has
  shape `[tf.shape(x)[0], 1]` and is broadcast over the remaining axis), and
  kept rows are divided by the keep probability.

  Args:
    x: input tensor; its first dimension is used as the number of rows.
    rate: probability of dropping a row. Must lie in [0, 1) whenever dropout
      is actually applied.
    training: dropout is applied only when this is truthy.

  Returns:
    `x` unchanged when `rate == 0.0` or `training` is falsy; otherwise the
    masked and rescaled tensor.

  Raises:
    ValueError: if dropout is active and `rate` is outside [0, 1). A rate of
      1 would make the keep probability 0 and the rescaling `0 / 0` (NaN);
      other out-of-range rates give an invalid Bernoulli probability.
  """
  if rate == 0.0 or not training:
    return x
  if not 0.0 < rate < 1.0:
    raise ValueError(
        'dropout rate must be in the range [0, 1), got %s' % str(rate))
  keep_rate = 1.0 - rate
  dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype)
  # One sample per row, so a row is kept or dropped as a whole.
  mask = dist.sample(sample_shape=tf.stack([tf.shape(x)[0], 1]))
  return x * mask / keep_rate
32
+
33
+
34
class DropoutNet(EasyRecModel):
  """Two-tower DropoutNet model that scores user/item pairs by similarity.

  Each tower may combine a "content" feature group and a "preference"
  feature group; at least one of the two must exist per tower. The
  preference input of each tower goes through `bernoulli_dropout` with the
  configured dropout rate before entering its DNN.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and fetch the feature groups of both towers.

    All arguments are forwarded unchanged to `EasyRecModel.__init__`.

    Raises:
      AssertionError: if the config is not a dropoutnet config, or if a tower
        has neither a content nor a preference feature group.
    """
    super(DropoutNet, self).__init__(model_config, feature_configs, features,
                                     labels, is_training)
    # Read the loss list from the top-level model config before
    # `_model_config` is narrowed to the dropoutnet sub-message below.
    self._losses = self._model_config.losses
    assert self._model_config.WhichOneof(
        'model'
    ) == 'dropoutnet', 'invalid model config: %s' % self._model_config.WhichOneof(
        'model')
    self._model_config = self._model_config.dropoutnet
    assert isinstance(self._model_config, DropoutNetConfig)

    # copy_obj so that any modification will not affect original config
    self.user_content_layers = copy_obj(self._model_config.user_content)
    self.user_preference_layers = copy_obj(self._model_config.user_preference)
    self.user_tower_layers = copy_obj(self._model_config.user_tower)
    self.user_content_feature, self.user_preference_feature = None, None
    if self._input_layer.has_group('user_content'):
      self.user_content_feature, _ = self._input_layer(self._feature_dict,
                                                       'user_content')
    if self._input_layer.has_group('user_preference'):
      self.user_preference_feature, _ = self._input_layer(
          self._feature_dict, 'user_preference')
    assert self.user_content_feature is not None or self.user_preference_feature is not None, 'no user feature'

    # copy_obj so that any modification will not affect original config
    self.item_content_layers = copy_obj(self._model_config.item_content)
    self.item_preference_layers = copy_obj(self._model_config.item_preference)
    self.item_tower_layers = copy_obj(self._model_config.item_tower)
    self.item_content_feature, self.item_preference_feature = None, None
    if self._input_layer.has_group('item_content'):
      self.item_content_feature, _ = self._input_layer(self._feature_dict,
                                                       'item_content')
    if self._input_layer.has_group('item_preference'):
      self.item_preference_feature, _ = self._input_layer(
          self._feature_dict, 'item_preference')
    assert self.item_content_feature is not None or self.item_preference_feature is not None, 'no item feature'

  def build_predict_graph(self):
    """Build both towers and the similarity between their embeddings.

    Returns:
      The prediction dict, with these entries added:
        'similarity': sum over the last axis of the product of the two
          L2-normalized tower embeddings.
        'float_user_emb' / 'float_item_emb': the L2-normalized embeddings.
        'user_emb' / 'item_emb': the same embeddings rendered as
          comma-separated strings.

    Raises:
      AssertionError: if the last hidden sizes of the two towers differ.
    """
    # The last hidden size is popped off each (copied) tower config and built
    # below as a separate tf.layers.dense layer with no activation argument.
    num_user_dnn_layer = len(self.user_tower_layers.hidden_units)
    last_user_hidden = self.user_tower_layers.hidden_units.pop()
    num_item_dnn_layer = len(self.item_tower_layers.hidden_units)
    last_item_hidden = self.item_tower_layers.hidden_units.pop()
    assert last_item_hidden == last_user_hidden, 'the last hidden layer size of user tower and item tower must be equal'

    # --------------------------build user tower-----------------------------------
    # NOTE(review): the extent of this name_scope was reconstructed from a
    # listing without indentation -- confirm against the upstream file.
    with tf.name_scope('user_tower'):
      user_features = []
      if self.user_content_feature is not None:
        user_content_dnn = dnn.DNN(self.user_content_layers, self._l2_reg,
                                   'user_content', self._is_training)
        content_feature = user_content_dnn(self.user_content_feature)
        user_features.append(content_feature)
      if self.user_preference_feature is not None:
        user_prefer_feature = bernoulli_dropout(
            self.user_preference_feature, self._model_config.user_dropout_rate,
            self._is_training)
        user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg,
                                  'user_preference', self._is_training)
        prefer_feature = user_prefer_dnn(user_prefer_feature)
        user_features.append(prefer_feature)

      user_tower_feature = tf.concat(user_features, axis=-1)

      user_dnn = dnn.DNN(self.user_tower_layers, self._l2_reg, 'user_dnn',
                         self._is_training)
      user_hidden = user_dnn(user_tower_feature)
      # Presumably named to continue the DNN's own layer numbering -- verify
      # against dnn.DNN.
      user_tower_emb = tf.layers.dense(
          inputs=user_hidden,
          units=last_user_hidden,
          kernel_regularizer=self._l2_reg,
          name='user_dnn/dnn_%d' % (num_user_dnn_layer - 1))

    # --------------------------build item tower-----------------------------------
    # NOTE(review): same reconstruction caveat as for the user tower scope.
    with tf.name_scope('item_tower'):
      item_features = []
      if self.item_content_feature is not None:
        item_content_dnn = dnn.DNN(self.item_content_layers, self._l2_reg,
                                   'item_content', self._is_training)
        content_feature = item_content_dnn(self.item_content_feature)
        item_features.append(content_feature)
      if self.item_preference_feature is not None:
        item_prefer_feature = bernoulli_dropout(
            self.item_preference_feature, self._model_config.item_dropout_rate,
            self._is_training)
        item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg,
                                  'item_preference', self._is_training)
        prefer_feature = item_prefer_dnn(item_prefer_feature)
        item_features.append(prefer_feature)

      item_tower_feature = tf.concat(item_features, axis=-1)

      item_dnn = dnn.DNN(self.item_tower_layers, self._l2_reg, 'item_dnn',
                         self._is_training)
      item_hidden = item_dnn(item_tower_feature)
      item_tower_emb = tf.layers.dense(
          inputs=item_hidden,
          units=last_item_hidden,
          kernel_regularizer=self._l2_reg,
          name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))

    # Normalizing first makes the dot product below a cosine similarity.
    user_emb = tf.nn.l2_normalize(user_tower_emb, axis=-1)
    item_emb = tf.nn.l2_normalize(item_tower_emb, axis=-1)
    cosine = cosine_similarity(user_emb, item_emb)
    self._prediction_dict['similarity'] = cosine
    self._prediction_dict['float_user_emb'] = user_emb
    self._prediction_dict['float_item_emb'] = item_emb
    self._prediction_dict['user_emb'] = tf.reduce_join(
        tf.as_string(user_emb), axis=-1, separator=',')
    self._prediction_dict['item_emb'] = tf.reduce_join(
        tf.as_string(item_emb), axis=-1, separator=',')
    return self._prediction_dict

  def build_loss_graph(self):
    """Add one weighted loss term per supported entry of the loss config.

    Handled loss types are SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING,
    PAIR_WISE_LOSS and CLASSIFICATION; entries of any other type are skipped.
    The first value of the labels dict is used as the label.

    Returns:
      The loss dict, with 'softmax_loss', 'pairwise_loss' and/or
      'sigmoid_loss' set according to the configured losses.

    Raises:
      AssertionError: if the negative-mining loss is requested but
        `softmax_loss` is not set in the model config.
    """
    labels = list(self._labels.values())[0]
    logits = self._prediction_dict['similarity']
    for loss in self._losses:
      if loss.loss_type == LossType.SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING:
        assert self._model_config.HasField(
            'softmax_loss'), '`softmax_loss` must be configured'
        user_emb = self._prediction_dict['float_user_emb']
        item_emb = self._prediction_dict['float_item_emb']
        loss_value = softmax_loss_with_negative_mining(
            user_emb,
            item_emb,
            labels,
            self._model_config.softmax_loss.num_negative_samples,
            embed_normed=True,
            weights=self._sample_weight,
            margin=self._model_config.softmax_loss.margin,
            gamma=self._model_config.softmax_loss.gamma,
            t=self._model_config.softmax_loss.coefficient_of_support_vector)
        self._loss_dict['softmax_loss'] = loss_value * loss.weight
      elif loss.loss_type == LossType.PAIR_WISE_LOSS:
        loss_value = pairwise_loss(labels, logits)
        self._loss_dict['pairwise_loss'] = loss_value * loss.weight
      elif loss.loss_type == LossType.CLASSIFICATION:
        loss_value = tf.losses.sigmoid_cross_entropy(labels, logits,
                                                     self._sample_weight)
        self._loss_dict['sigmoid_loss'] = loss_value * loss.weight
    return self._loss_dict

  def build_metric_graph(self, eval_config):
    """Build the evaluation metrics listed in `eval_config.metrics_set`.

    The similarity score is passed through a sigmoid to get a probability,
    which is thresholded at 0.5 to get the predicted class.

    Args:
      eval_config: config object whose `metrics_set` entries select the
        metrics; 'auc', 'accuracy', 'precision' and 'recall' are supported.

    Returns:
      A dict mapping each requested metric name to its metric op(s).

    Raises:
      ValueError: if an entry of `metrics_set` is not a supported metric.
    """
    from easy_rec.python.core.easyrec_metrics import metrics_tf as metrics
    metric_dict = {}
    labels = list(self._labels.values())[0]
    sim_score = self._prediction_dict['similarity']
    prob = tf.nn.sigmoid(sim_score)
    predict = tf.greater(prob, 0.5)
    for metric in eval_config.metrics_set:
      if metric.WhichOneof('metric') == 'auc':
        metric_dict['auc'] = metrics.auc(
            labels, prob, weights=self._sample_weight)
      elif metric.WhichOneof('metric') == 'accuracy':
        metric_dict['accuracy'] = metrics.accuracy(
            tf.cast(labels, tf.bool), predict, weights=self._sample_weight)
      elif metric.WhichOneof('metric') == 'precision':
        metric_dict['precision'] = metrics.precision(
            labels, predict, weights=self._sample_weight)
      elif metric.WhichOneof('metric') == 'recall':
        metric_dict['recall'] = metrics.recall(
            labels, predict, weights=self._sample_weight)
      else:
        # Previously the exception was only constructed, never raised, so an
        # unsupported metric was silently dropped from the evaluation.
        raise ValueError('invalid metric type: %s' % str(metric))
    return metric_dict

  def get_outputs(self):
    """Return the prediction-dict keys that are exposed as model outputs."""
    return ['similarity', 'user_emb', 'item_emb']
@@ -0,0 +1,154 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.model.match_model import MatchModel
7
+ from easy_rec.python.protos.dssm_pb2 import DSSM as DSSMConfig
8
+ from easy_rec.python.protos.loss_pb2 import LossType
9
+ from easy_rec.python.protos.simi_pb2 import Similarity
10
+ from easy_rec.python.utils.proto_util import copy_obj
11
+
12
# Switch to the TF1-compatible API when running under TensorFlow 2.x or newer.
# The major version is compared as an integer: comparing the raw version
# strings orders them lexicographically (e.g. '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
losses = tf.losses
15
+
16
+
17
class DSSM(MatchModel):
  """Two-tower DSSM matching model over a 'user' and an 'item' feature group."""

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and fetch the inputs of both towers.

    All arguments are forwarded unchanged to `MatchModel.__init__`.

    Raises:
      AssertionError: if the model config is not a dssm config.
    """
    super(DSSM, self).__init__(model_config, feature_configs, features, labels,
                               is_training)
    assert self._model_config.WhichOneof('model') == 'dssm', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.dssm
    assert isinstance(self._model_config, DSSMConfig)

    # Each tower works on a private copy of its config, so later changes to
    # the copy (hidden_units is popped when the graph is built) never reach
    # the original config.
    self.user_tower = copy_obj(self._model_config.user_tower)
    self.user_tower_feature, _ = self._input_layer(self._feature_dict, 'user')
    self.item_tower = copy_obj(self._model_config.item_tower)
    self.item_tower_feature, _ = self._input_layer(self._feature_dict, 'item')
    self._user_tower_emb = None
    self._item_tower_emb = None

  def build_predict_graph(self):
    """Build both towers, their similarity and the prediction entries.

    Returns:
      The prediction dict. Depending on the loss type it holds
      'logits'/'probs' or 'y', plus 'user_tower_emb'/'item_tower_emb' and
      their comma-separated string forms 'user_emb'/'item_emb'.
    """

    def tower_embedding(tower_cfg, tower_input, dnn_name, last_layer_name):
      # The last hidden size is removed from the copied config and built as a
      # separate dense layer on top of the remaining DNN layers.
      hidden_units = tower_cfg.dnn.hidden_units
      layer_count = len(hidden_units)
      output_units = hidden_units.pop()
      tower_dnn = dnn.DNN(tower_cfg.dnn, self._l2_reg, dnn_name,
                          self._is_training)
      hidden = tower_dnn(tower_input)
      return tf.layers.dense(
          inputs=hidden,
          units=output_units,
          kernel_regularizer=self._l2_reg,
          name=last_layer_name % (layer_count - 1))

    user_vec = tower_embedding(self.user_tower, self.user_tower_feature,
                               'user_dnn', 'user_dnn/dnn_%d')
    item_vec = tower_embedding(self.item_tower, self.item_tower_feature,
                               'item_dnn', 'item_dnn/dnn_%d')

    # Only the cosine variant normalizes the embeddings and applies the
    # configured temperature; otherwise the similarity is left unscaled.
    temperature = 1.0
    if self._model_config.simi_func == Similarity.COSINE:
      user_vec = self.norm(user_vec)
      item_vec = self.norm(item_vec)
      temperature = self._model_config.temperature

    similarity = self.sim(user_vec, item_vec) / temperature
    scores = similarity
    if self._model_config.scale_simi:
      # Learned affine rescaling of the similarity; the weight enters through
      # its absolute value.
      sim_w = tf.get_variable(
          'sim_w',
          dtype=tf.float32,
          shape=(1),
          initializer=tf.ones_initializer())
      sim_b = tf.get_variable(
          'sim_b',
          dtype=tf.float32,
          shape=(1),
          initializer=tf.zeros_initializer())
      scores = similarity * tf.abs(sim_w) + sim_b

    if self._is_point_wise:
      scores = tf.reshape(scores, [-1])

    predictions = self._prediction_dict
    if self._loss_type == LossType.CLASSIFICATION:
      predictions['logits'] = scores
      predictions['probs'] = tf.nn.sigmoid(scores)
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      scores = self._mask_in_batch(scores)
      predictions['logits'] = scores
      predictions['probs'] = tf.nn.softmax(scores)
    else:
      predictions['y'] = scores

    predictions['user_tower_emb'] = user_vec
    predictions['item_tower_emb'] = item_vec
    predictions['user_emb'] = tf.reduce_join(
        tf.as_string(user_vec), axis=-1, separator=',')
    predictions['item_emb'] = tf.reduce_join(
        tf.as_string(item_vec), axis=-1, separator=',')
    return predictions

  def get_outputs(self):
    """Return the prediction-dict keys exposed as outputs for the loss type.

    For SOFTMAX_CROSS_ENTROPY the stored 'logits' are squeezed on the last
    axis and 'probs' is replaced by their sigmoid each time this is called.

    Raises:
      ValueError: if the loss type is not CLASSIFICATION,
        SOFTMAX_CROSS_ENTROPY or L2_LOSS.
    """
    loss_type = self._loss_type
    embedding_keys = [
        'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb'
    ]
    if loss_type == LossType.CLASSIFICATION:
      return ['logits', 'probs'] + embedding_keys
    if loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      squeezed_logits = tf.squeeze(self._prediction_dict['logits'], axis=-1)
      self._prediction_dict['logits'] = squeezed_logits
      self._prediction_dict['probs'] = tf.nn.sigmoid(
          self._prediction_dict['logits'])
      return ['logits', 'probs'] + embedding_keys
    if loss_type == LossType.L2_LOSS:
      return ['y'] + embedding_keys
    raise ValueError('invalid loss type: %s' % str(self._loss_type))

  def build_output_dict(self):
    """Extend the parent's output dict with both tower input features.

    The features are added as comma-separated strings under
    'user_tower_feature' and 'item_tower_feature'.
    """
    outputs = super(DSSM, self).build_output_dict()
    tower_inputs = (
        ('user_tower_feature', self.user_tower_feature),
        ('item_tower_feature', self.item_tower_feature),
    )
    for output_key, feature in tower_inputs:
      outputs[output_key] = tf.reduce_join(
          tf.as_string(feature), axis=-1, separator=',')
    return outputs

  def build_rtp_output_dict(self):
    """Extend the parent's RTP output dict with embeddings and rank score.

    Adds 'user_embedding_output' and 'item_embedding_output', plus
    'rank_predict' when the loss type is CLASSIFICATION.

    Raises:
      ValueError: if a required entry is missing from the prediction dict.
    """
    outputs = super(DSSM, self).build_rtp_output_dict()
    embedding_outputs = (
        ('user_tower_emb', 'user_embedding_output',
         'User tower embedding does not exist. Please checking predict graph.'),
        ('item_tower_emb', 'item_embedding_output',
         'Item tower embedding does not exist. Please checking predict graph.'),
    )
    for prediction_key, output_name, missing_message in embedding_outputs:
      if prediction_key not in self._prediction_dict:
        raise ValueError(missing_message)
      outputs[output_name] = tf.identity(
          self._prediction_dict[prediction_key], name=output_name)
    if self._loss_type == LossType.CLASSIFICATION:
      if 'probs' not in self._prediction_dict:
        raise ValueError(
            'Probs output does not exist. Please checking predict graph.')
      outputs['rank_predict'] = tf.identity(
          self._prediction_dict['probs'], name='rank_predict')
    return outputs
@@ -0,0 +1,143 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.layers import senet
7
+ from easy_rec.python.model.dssm import DSSM
8
+ from easy_rec.python.model.match_model import MatchModel
9
+ from easy_rec.python.protos.loss_pb2 import LossType
10
+ from easy_rec.python.protos.simi_pb2 import Similarity
11
+ from easy_rec.python.utils.proto_util import copy_obj
12
+
13
+ from easy_rec.python.protos.dssm_senet_pb2 import DSSM_SENet as DSSM_SENet_Config # NOQA
14
+
15
# Under TensorFlow 2.x, rebind `tf` to the TF1-compatible API (tf.compat.v1)
# so the TF1-style calls used below (e.g. tf.get_variable, tf.layers.dense)
# resolve.
# NOTE(review): this is a lexicographic string comparison, so a version string
# such as '10.0' would compare as lower than '2.0' -- confirm this is acceptable.
# NOTE(review): indentation was not recoverable in this view; `losses` is
# assumed to be assigned unconditionally, after the `if` -- confirm.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1
losses = tf.losses
18
+
19
+
20
class DSSM_SENet(DSSM):
  """DSSM variant whose user and item towers run a SENet layer before the DNN.

  Each tower feeds its list of per-field features through senet.SENet,
  concatenates the results, and passes them through dnn.DNN followed by a
  final dense layer. The two tower outputs are compared with `self.sim` to
  produce the prediction.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and build the per-field inputs of both towers.

    Args:
      model_config: model config whose `model` oneof must be 'dssm_senet'.
      feature_configs: forwarded unchanged to MatchModel.__init__.
      features: forwarded unchanged to MatchModel.__init__.
      labels: forwarded unchanged to MatchModel.__init__.
      is_training: forwarded unchanged to MatchModel.__init__.

    Raises:
      AssertionError: if the config's `model` oneof is not 'dssm_senet' or
        the selected sub-config is not a DSSM_SENet_Config.
    """
    # MatchModel.__init__ is called directly, so DSSM.__init__ is skipped.
    # NOTE(review): presumably because DSSM.__init__ expects a plain 'dssm'
    # config -- confirm against dssm.py.
    MatchModel.__init__(self, model_config, feature_configs, features, labels,
                        is_training)

    assert self._model_config.WhichOneof('model') == 'dssm_senet', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')
    self._model_config = self._model_config.dssm_senet
    assert isinstance(self._model_config, DSSM_SENet_Config)

    # copy_obj so that any modification will not affect original config
    # (build_predict_graph pops the last entry of dnn.hidden_units).
    self.user_tower = copy_obj(self._model_config.user_tower)

    # is_combine=False: the third returned value is kept as a list of
    # features, one entry per field, which is what the SENet layer consumes.
    (self.user_seq_features, self.user_plain_features,
     self.user_feature_list) = self._input_layer(
         self._feature_dict, 'user', is_combine=False)
    self.user_num_fields = len(self.user_feature_list)

    # copy_obj so that any modification will not affect original config
    self.item_tower = copy_obj(self._model_config.item_tower)

    (self.item_seq_features, self.item_plain_features,
     self.item_feature_list) = self._input_layer(
         self._feature_dict, 'item', is_combine=False)
    self.item_num_fields = len(self.item_feature_list)

    self._user_tower_emb = None
    self._item_tower_emb = None

  def _build_tower(self, tower_config, feature_list, num_fields, prefix):
    """Build one tower: SENet -> concat -> DNN -> final dense layer.

    Args:
      tower_config: tower config providing the `senet` and `dnn` sub-configs.
        Side effect: the last entry of `tower_config.dnn.hidden_units` is
        popped, so the DNN is built from the remaining entries only.
      feature_list: list of per-field features passed to the SENet layer.
      num_fields: number of fields, passed to senet.SENet as `num_fields`.
      prefix: 'user' or 'item'; used to build the layer names
        '<prefix>_senet', '<prefix>_dnn' and '<prefix>_dnn/dnn_<i>'.

    Returns:
      The output of the final dense layer, i.e. the tower embedding before
      any normalization.
    """
    senet_layer = senet.SENet(
        num_fields=num_fields,
        num_squeeze_group=tower_config.senet.num_squeeze_group,
        reduction_ratio=tower_config.senet.reduction_ratio,
        l2_reg=self._l2_reg,
        name='%s_senet' % prefix)
    senet_output = tf.concat(senet_layer(feature_list), axis=-1)

    # The layer count must be read before pop() so that the final dense layer
    # is named with the index of the last configured hidden layer.
    num_dnn_layer = len(tower_config.dnn.hidden_units)
    last_hidden = tower_config.dnn.hidden_units.pop()
    tower_dnn = dnn.DNN(tower_config.dnn, self._l2_reg, '%s_dnn' % prefix,
                        self._is_training)
    hidden = tower_dnn(senet_output)
    # The last layer is created here without an `activation` argument, rather
    # than inside dnn.DNN.
    return tf.layers.dense(
        inputs=hidden,
        units=last_hidden,
        kernel_regularizer=self._l2_reg,
        name='%s_dnn/dnn_%d' % (prefix, num_dnn_layer - 1))

  def build_predict_graph(self):
    """Build both towers, score user/item pairs and fill the prediction dict.

    Returns:
      self._prediction_dict, containing:
        - 'logits' and 'probs' for CLASSIFICATION (sigmoid) and
          SOFTMAX_CROSS_ENTROPY (softmax over in-batch-masked logits), or
          'y' for any other loss type;
        - 'user_tower_emb' / 'item_tower_emb': the tower embedding tensors;
        - 'user_emb' / 'item_emb': the same embeddings converted to strings
          and joined with ',' along the last axis.
    """
    # The user tower is built before the item tower, which keeps the order in
    # which graph ops are created unchanged.
    user_tower_emb = self._build_tower(self.user_tower, self.user_feature_list,
                                       self.user_num_fields, 'user')
    item_tower_emb = self._build_tower(self.item_tower, self.item_feature_list,
                                       self.item_num_fields, 'item')

    if self._model_config.simi_func == Similarity.COSINE:
      user_tower_emb = self.norm(user_tower_emb)
      item_tower_emb = self.norm(item_tower_emb)
      temperature = self._model_config.temperature
    else:
      # The temperature is only applied to cosine similarity.
      temperature = 1.0

    user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature
    if self._model_config.scale_simi:
      # Learnable affine rescaling of the similarity, initialized to identity
      # (scale 1, bias 0). An explicit one-element tuple is used for the shape.
      sim_w = tf.get_variable(
          'sim_w',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.ones_initializer())
      sim_b = tf.get_variable(
          'sim_b',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.zeros_initializer())
      # abs() keeps the scale non-negative.
      y_pred = user_item_sim * tf.abs(sim_w) + sim_b
    else:
      y_pred = user_item_sim

    if self._is_point_wise:
      y_pred = tf.reshape(y_pred, [-1])

    if self._loss_type == LossType.CLASSIFICATION:
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      y_pred = self._mask_in_batch(y_pred)
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
    else:
      self._prediction_dict['y'] = y_pred

    self._prediction_dict['user_tower_emb'] = user_tower_emb
    self._prediction_dict['item_tower_emb'] = item_tower_emb
    self._prediction_dict['user_emb'] = tf.reduce_join(
        tf.as_string(user_tower_emb), axis=-1, separator=',')
    self._prediction_dict['item_emb'] = tf.reduce_join(
        tf.as_string(item_tower_emb), axis=-1, separator=',')
    return self._prediction_dict

  def build_output_dict(self):
    """Return MatchModel's output dict, bypassing DSSM.build_output_dict.

    DSSM.build_output_dict reads self.user_tower_feature and
    self.item_tower_feature, which this class does not set, so the
    grandparent implementation is called directly.
    """
    output_dict = MatchModel.build_output_dict(self)

    return output_dict
@@ -0,0 +1,48 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import tensorflow as tf
5
+
6
+ from easy_rec.python.model.easy_rec_model import EasyRecModel
7
+
8
+
9
class DummyModel(EasyRecModel):
  """Minimal model that ignores its features and predicts from random input.

  The prediction is a single dense layer applied to uniform random values
  shaped like the first label. Input features are passed through to the
  prediction dict unchanged (sparse features as their `values`).
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Initialize the base model and normalize the labels.

    When labels are given, they are converted from a dict to a list of its
    values. If the first label is not float32, it is replaced by a float32
    tensor of ones with the same shape (the original values are dropped).
    """
    super(DummyModel, self).__init__(model_config, feature_configs, features,
                                     labels, is_training)

    if self._labels is None:
      return

    self._labels = list(self._labels.values())
    first_label = self._labels[0]
    if first_label.dtype != tf.float32:
      self._labels[0] = tf.ones_like(first_label, tf.float32)

  def build_predict_graph(self):
    """Fill and return the prediction dict.

    'output' is a 1-unit dense layer ('layer_0') over random uniform values
    reshaped to [-1, 1]; every entry of the feature dict is copied under its
    own key.

    NOTE(review): this indexes self._labels[0], so it appears to require
    labels to be present -- confirm how prediction/export paths call this.
    """
    label_shape = tf.shape(self._labels[0])
    random_input = tf.random_uniform(label_shape, dtype=tf.float32)
    random_input = tf.reshape(random_input, [-1, 1])
    self._prediction_dict['output'] = tf.layers.dense(
        inputs=random_input, units=1, name='layer_0')

    for feature_name in self._feature_dict:
      feature = self._feature_dict[feature_name]
      # Sparse features are exported as their flat values tensor.
      is_sparse = isinstance(feature, tf.sparse.SparseTensor)
      self._prediction_dict[feature_name] = (
          feature.values if is_sparse else feature)
    return self._prediction_dict

  def build_loss_graph(self):
    """Return the loss dict: sum of squared (output - first label).

    The value is stored under the key 'cross_ent' although it is a squared
    error, not a cross entropy.
    """
    residual = self._prediction_dict['output'] - self._labels[0]
    squared_error = tf.reduce_sum(tf.square(residual))
    return {'cross_ent': squared_error}

  def get_outputs(self):
    """Return the names of the prediction-dict entries to export."""
    return ['output']

  def build_metric_graph(self):
    """Return an empty dict: this model defines no evaluation metrics."""
    return {}