easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336) hide show
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,54 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.model.multi_task_model import MultiTaskModel
7
+
8
+ from easy_rec.python.protos.simple_multi_task_pb2 import SimpleMultiTask as SimpleMultiTaskConfig # NOQA
9
+
10
# Alias ``tf`` to the TF1 compatibility API when running under TensorFlow 2+.
# The major version is compared numerically: a plain string comparison against
# '2.0' would misorder two-digit major versions (e.g. '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
12
+
13
+
14
class SimpleMultiTask(MultiTaskModel):
  """Multi-task model whose task towers all read the same input features.

  Every configured task tower is a DNN followed by a dense output layer. The
  shared input is ``self.backbone`` when ``self.has_backbone`` is set, and the
  'all' feature group of the input layer otherwise.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Pick the simple_multi_task sub-config and prepare the shared input.

    Args:
      model_config: model config whose 'model' oneof must be
        'simple_multi_task'.
      feature_configs: feature configs, forwarded to the parent class.
      features: input features, forwarded to the parent class.
      labels: optional labels, forwarded to the parent class.
      is_training: training-mode flag, forwarded to the parent class.
    """
    super(SimpleMultiTask, self).__init__(model_config, feature_configs,
                                          features, labels, is_training)

    model_kind = self._model_config.WhichOneof('model')
    assert model_kind == 'simple_multi_task', \
        'invalid model config: %s' % model_kind
    # From here on self._model_config is the model-specific sub-message.
    self._model_config = self._model_config.simple_multi_task
    assert isinstance(self._model_config, SimpleMultiTaskConfig)

    if not self.has_backbone:
      self._features, _ = self._input_layer(self._feature_dict, 'all')
    else:
      self._features = self.backbone
    self._init_towers(self._model_config.task_towers)

  def build_predict_graph(self):
    """Build one DNN tower plus a dense output layer for every task.

    Returns:
      self._prediction_dict, after the per-tower outputs (keyed by tower name)
      have been handed to self._add_to_prediction_dict.
    """
    outputs_by_tower = {}
    for tower_idx, tower_cfg in enumerate(self._task_towers):
      tower_dnn = dnn.DNN(
          tower_cfg.dnn,
          self._l2_reg,
          name=tower_cfg.tower_name,
          is_training=self._is_training)
      tower_fea = tower_dnn(self._features)
      # The output layer is named by tower index, not by tower name.
      outputs_by_tower[tower_cfg.tower_name] = tf.layers.dense(
          inputs=tower_fea,
          units=tower_cfg.num_class,
          kernel_regularizer=self._l2_reg,
          name='dnn_output_%d' % tower_idx)

    self._add_to_prediction_dict(outputs_by_tower)
    return self._prediction_dict
@@ -0,0 +1,46 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import tensorflow as tf
4
+
5
+ from easy_rec.python.layers import dnn
6
+ from easy_rec.python.layers import uniter
7
+ from easy_rec.python.model.rank_model import RankModel
8
+
9
+ from easy_rec.python.protos.uniter_pb2 import Uniter as UNITERConfig # NOQA
10
+
11
# Alias ``tf`` to the TF1 compatibility API when running under TensorFlow 2+.
# The major version is compared numerically: a plain string comparison against
# '2.0' would misorder two-digit major versions (e.g. '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
13
+
14
+
15
class Uniter(RankModel):
  """UNITER: UNiversal Image-TExt Representation Learning.

  See the original paper:
  https://arxiv.org/abs/1909.11740
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and create the UNITER layer.

    Args:
      model_config: model config whose 'model' oneof must be 'uniter'.
      feature_configs: feature configs, forwarded to the parent class and to
        the UNITER layer.
      features: input features, forwarded to the parent class and to the
        UNITER layer.
      labels: optional labels, forwarded to the parent class.
      is_training: training-mode flag, forwarded to the parent class.
    """
    super(Uniter, self).__init__(model_config, feature_configs, features,
                                 labels, is_training)
    model_kind = self._model_config.WhichOneof('model')
    assert model_kind == 'uniter', ('invalid model config: %s' % model_kind)

    uniter_cfg = self._model_config.uniter
    self._uniter_layer = uniter.Uniter(model_config, feature_configs, features,
                                       uniter_cfg.config, self._input_layer)
    # From here on self._model_config is the model-specific sub-message.
    self._model_config = uniter_cfg

  def build_predict_graph(self):
    """Stack the UNITER layer, the final DNN and the dense 'output' layer.

    Returns:
      self._prediction_dict, after the dense output has been handed to
      self._add_to_prediction_dict.
    """
    encoded = self._uniter_layer(self._is_training, l2_reg=self._l2_reg)
    top_dnn = dnn.DNN(self._model_config.final_dnn, self._l2_reg, 'final_dnn',
                      self._is_training)
    top_fea = top_dnn(encoded)

    model_output = tf.layers.dense(top_fea, self._num_class, name='output')
    self._add_to_prediction_dict(model_output)
    return self._prediction_dict
@@ -0,0 +1,121 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+ import tensorflow as tf
6
+
7
+ from easy_rec.python.layers import dnn
8
+ from easy_rec.python.model.rank_model import RankModel
9
+
10
+ from easy_rec.python.protos.wide_and_deep_pb2 import WideAndDeep as WideAndDeepConfig # NOQA
11
+
12
# Alias ``tf`` to the TF1 compatibility API when running under TensorFlow 2+.
# The major version is compared numerically: a plain string comparison against
# '2.0' would misorder two-digit major versions (e.g. '10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
14
+
15
+
16
class WideAndDeep(RankModel):
  """Wide & Deep ranking model built from a 'wide' and a 'deep' feature group.

  The wide features are summed, the deep features are concatenated and passed
  through a DNN. The two parts are then either merged by an optional final DNN
  or, when no final DNN is configured, added together directly.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Validate the config and fetch the wide and deep feature groups.

    Args:
      model_config: model config whose 'model' oneof must be 'wide_and_deep'.
      feature_configs: feature configs, forwarded to the parent class.
      features: input features, forwarded to the parent class.
      labels: optional labels, forwarded to the parent class.
      is_training: training-mode flag, forwarded to the parent class.
    """
    super(WideAndDeep, self).__init__(model_config, feature_configs, features,
                                      labels, is_training)
    assert model_config.WhichOneof('model') == 'wide_and_deep', \
        'invalid model config: %s' % model_config.WhichOneof('model')
    self._model_config = model_config.wide_and_deep
    assert isinstance(self._model_config, WideAndDeepConfig)
    # Only the second value returned by the input layer is kept for each
    # group; it is later consumed by tf.add_n / tf.concat as a list of tensors.
    assert self._input_layer.has_group('wide')
    _, self._wide_features = self._input_layer(self._feature_dict, 'wide')
    assert self._input_layer.has_group('deep')
    _, self._deep_features = self._input_layer(self._feature_dict, 'deep')

  def build_input_layer(self, model_config, feature_configs):
    """Create the input layer, fixing up wide_output_dim first.

    Without a final DNN the wide output is added directly to the deep output
    of num_class units (see build_predict_graph), so wide_output_dim is
    overwritten with num_class. Note that this mutates model_config in place.

    Args:
      model_config: model config holding the wide_and_deep sub-message.
      feature_configs: feature configs, forwarded to the parent class.
    """
    # overwrite create input_layer to support wide_output_dim
    has_final = len(model_config.wide_and_deep.final_dnn.hidden_units) > 0
    self._wide_output_dim = model_config.wide_and_deep.wide_output_dim
    if not has_final:
      model_config.wide_and_deep.wide_output_dim = model_config.num_class
      self._wide_output_dim = model_config.num_class
    super(WideAndDeep, self).build_input_layer(model_config, feature_configs)

  def build_predict_graph(self):
    """Build the wide part, the deep part and the combined output.

    Side effect: self._deep_features is replaced by the concatenated deep
    feature tensor.

    Returns:
      self._prediction_dict, after the model output has been handed to
      self._add_to_prediction_dict.
    """
    wide_fea = tf.add_n(self._wide_features)
    logging.info('wide features dimension: %d' % wide_fea.get_shape()[-1])

    self._deep_features = tf.concat(self._deep_features, axis=1)
    logging.info('input deep features dimension: %d' %
                 self._deep_features.get_shape()[-1])

    deep_layer = dnn.DNN(self._model_config.dnn, self._l2_reg, 'deep_feature',
                         self._is_training)
    deep_fea = deep_layer(self._deep_features)
    logging.info('output deep features dimension: %d' %
                 deep_fea.get_shape()[-1])

    has_final = len(self._model_config.final_dnn.hidden_units) > 0
    # Reported through logging (not print) like every other diagnostic here.
    logging.info('wide_deep has_final_dnn layers = %d' % has_final)
    if has_final:
      all_fea = tf.concat([wide_fea, deep_fea], axis=1)
      final_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
                            'final_dnn', self._is_training)
      all_fea = final_layer(all_fea)
      output = tf.layers.dense(
          all_fea,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='output')
    else:
      deep_out = tf.layers.dense(
          deep_fea,
          self._num_class,
          kernel_regularizer=self._l2_reg,
          name='deep_out')
      # build_input_layer forced wide_output_dim to num_class for this case.
      output = deep_out + wide_fea

    self._add_to_prediction_dict(output)

    return self._prediction_dict

  @staticmethod
  def _is_wide_var(var_name):
    """Tell whether a variable name starts with 'input_layer' but not 'input_layer_1'."""
    return var_name.startswith('input_layer') and \
        (not var_name.startswith('input_layer_1'))

  def get_grouped_vars(self, opt_num):
    """Group the vars into different optimization groups.

    Each group will be optimized by a separate optimizer.

    Args:
      opt_num: number of optimizers from easyrec config.

    Return:
      list of list of variables: [wide, others] when opt_num == 2 and
      [wide, embeddings, others] when opt_num == 3.
      NOTE(review): any other opt_num that passes the assertion (i.e. < 2)
      falls through and returns None - confirm callers never rely on that.
    """
    assert opt_num <= 3, 'could only support 2 or 3 optimizers, ' + \
        'if opt_num = 2, one for the wide , and one for the others, ' + \
        'if opt_num = 3, one for the wide, second for the deep embeddings, ' + \
        'and third for the other layers.'

    if opt_num == 2:
      wide_vars = []
      deep_vars = []
      for tmp_var in tf.trainable_variables():
        if self._is_wide_var(tmp_var.name):
          wide_vars.append(tmp_var)
        else:
          deep_vars.append(tmp_var)
      return [wide_vars, deep_vars]
    elif opt_num == 3:
      wide_vars = []
      embedding_vars = []
      deep_vars = []
      for tmp_var in tf.trainable_variables():
        if self._is_wide_var(tmp_var.name):
          wide_vars.append(tmp_var)
        elif tmp_var.name.startswith(
            'input_layer') or '/embedding_weights' in tmp_var.name:
          # Remaining input-layer variables and any embedding weights.
          embedding_vars.append(tmp_var)
        else:
          deep_vars.append(tmp_var)
      return [wide_vars, embedding_vars, deep_vars]
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
File without changes
@@ -0,0 +1,193 @@
1
+ """Python wrappers around TensorFlow ops.
2
+
3
+ This file is MACHINE GENERATED! Do not edit.
4
+ Original C++ source file: kafka_ops_deprecated.cc
5
+ """
6
+
7
+ import logging
8
+ import os
9
+ import traceback
10
+
11
+ import six as _six
12
+ import tensorflow as tf
13
+ from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
14
+ from tensorflow.python.eager import context as _context
15
+ from tensorflow.python.eager import core as _core
16
+ from tensorflow.python.eager import execute as _execute
17
+ # Needed to trigger the call to _set_call_cpp_shape_fn.
18
+ from tensorflow.python.framework import dtypes as _dtypes
19
+ from tensorflow.python.framework import ops as _ops
20
+ from tensorflow.python.util.tf_export import tf_export
21
+
22
+ import easy_rec
23
+
24
# Handle to the custom Kafka op library. It stays None when easy_rec.ops_dir
# is unset, when kafka.so does not exist there, or when loading it fails.
kafka_module = None
if easy_rec.ops_dir is not None:
  kafka_ops_path = os.path.join(easy_rec.ops_dir, 'kafka.so')
  if os.path.exists(kafka_ops_path):
    try:
      kafka_module = tf.load_op_library(kafka_ops_path)
    except Exception:
      # Best effort: the failure is logged with its traceback and the module
      # import continues with kafka_module left as None.
      logging.warning('load %s failed: %s' %
                      (kafka_ops_path, traceback.format_exc()))
33
+
34
+
35
@tf_export('io_kafka_dataset_v2')
def io_kafka_dataset_v2(topics,
                        servers,
                        group,
                        eof,
                        timeout,
                        config_global,
                        config_topic,
                        message_key,
                        message_offset,
                        name=None):
  """Creates a dataset that emits the messages of one or more Kafka topics.

  Thin wrapper that forwards every argument, by keyword, to the op of the
  same name in the loaded kafka op library.

  Args:
    topics: A `Tensor` of type `string`.
      A `tf.string` tensor containing one or more subscriptions,
      in the format of [topic:partition:offset].
    servers: A `Tensor` of type `string`. A list of bootstrap servers.
    group: A `Tensor` of type `string`. The consumer group id.
    eof: A `Tensor` of type `bool`.
      If True, the kafka reader will stop on EOF.
    timeout: A `Tensor` of type `int64`.
      The timeout value for the Kafka Consumer to wait
      (in millisecond).
    config_global: A `Tensor` of type `string`.
      A `tf.string` tensor containing global configuration
      properties in [Key=Value] format,
      eg. ["enable.auto.commit=false", "heartbeat.interval.ms=2000"],
      please refer to 'Global configuration properties' in librdkafka doc.
    config_topic: A `Tensor` of type `string`.
      A `tf.string` tensor containing topic configuration
      properties in [Key=Value] format, eg. ["auto.offset.reset=earliest"],
      please refer to 'Topic configuration properties' in librdkafka doc.
    message_key: A `Tensor` of type `bool`.
    message_offset: A `Tensor` of type `bool`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `variant`.
  """
  op_kwargs = {
      'topics': topics,
      'servers': servers,
      'group': group,
      'eof': eof,
      'timeout': timeout,
      'config_global': config_global,
      'config_topic': config_topic,
      'message_key': message_key,
      'message_offset': message_offset,
      'name': name,
  }
  return kafka_module.io_kafka_dataset_v2(**op_kwargs)
86
+
87
+
88
def io_kafka_dataset_eager_fallback(topics,
                                    servers,
                                    group,
                                    eof,
                                    timeout,
                                    config_global,
                                    config_topic,
                                    message_key,
                                    message_offset,
                                    name=None,
                                    ctx=None):
  """This is the slowpath function for Eager mode.

  Converts every input to a tensor of the expected dtype, executes the
  'IOKafkaDataset' op through the eager executor and returns its single
  output.

  This is for function io_kafka_dataset
  """
  eager_ctx = ctx or _context.context()
  as_tensor = _ops.convert_to_tensor
  # Conversions run in the same order as the op's inputs.
  flat_inputs = [
      as_tensor(topics, _dtypes.string),
      as_tensor(servers, _dtypes.string),
      as_tensor(group, _dtypes.string),
      as_tensor(eof, _dtypes.bool),
      as_tensor(timeout, _dtypes.int64),
      as_tensor(config_global, _dtypes.string),
      as_tensor(config_topic, _dtypes.string),
      as_tensor(message_key, _dtypes.bool),
      as_tensor(message_offset, _dtypes.bool),
  ]
  op_attrs = None
  outputs = _execute.execute(
      b'IOKafkaDataset',
      1,
      inputs=flat_inputs,
      attrs=op_attrs,
      ctx=eager_ctx,
      name=name)
  _execute.record_gradient('IOKafkaDataset', flat_inputs, op_attrs, outputs,
                           name)
  # Exactly one output is expected; unpacking fails otherwise.
  (single_output,) = outputs
  return single_output
129
+
130
+
131
@tf_export('io_write_kafka_v2')
def io_write_kafka_v2(message, topic, servers, name=None):
  r"""Invoke the 'IOWriteKafka' custom op in graph or eager mode.

  NOTE(review): presumably publishes `message` to the Kafka `topic` on
  `servers`; the op's semantics live in the C++ kernel and are not visible
  here - confirm against kafka_ops_deprecated.cc.

  Args:
    message: A `Tensor` of type `string`.
    topic: A `Tensor` of type `string`.
    servers: A `Tensor` of type `string`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `string`.
  """
  _ctx = _context._context
  if _ctx is None or not _ctx._eager_context.is_eager:
    # Graph mode: add the op through the loaded kafka library and return its
    # single output.
    _op = kafka_module.io_write_kafka_v2(
        message=message, topic=topic, servers=servers, name=name)
    _result = _op.outputs[:]
    _inputs_flat = _op.inputs
    _attrs = None
    _execute.record_gradient('IOWriteKafka', _inputs_flat, _attrs, _result,
                             name)
    _result, = _result
    return _result

  else:
    # Eager mode: try the fast path first, then the Python slow path.
    try:
      _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
          _ctx._context_handle, _ctx._eager_context.device_name, 'IOWriteKafka',
          name, _ctx._post_execution_callbacks, message, topic, servers)
      return _result
    except _core._FallbackException:
      return io_write_kafka_eager_fallback(
          message, topic, servers, name=name, ctx=_ctx)
    except _core._NotOkStatusException as e:
      # `message` is rebound here to the error text (it no longer refers to
      # the input argument); the op name is appended when one was given.
      if name is not None:
        message = e.message + ' name: ' + name
      else:
        message = e.message
      _six.raise_from(_core._status_to_exception(e.code, message), None)
171
+
172
+
173
def io_write_kafka_eager_fallback(message, topic, servers, name=None, ctx=None):
  """This is the slowpath function for Eager mode.

  Converts the three inputs to string tensors, executes the 'IOWriteKafka' op
  through the eager executor and returns its single output.

  This is for function io_write_kafka
  """
  eager_ctx = ctx or _context.context()
  # Conversions run in the same order as the op's inputs.
  flat_inputs = [
      _ops.convert_to_tensor(message, _dtypes.string),
      _ops.convert_to_tensor(topic, _dtypes.string),
      _ops.convert_to_tensor(servers, _dtypes.string),
  ]
  op_attrs = None
  outputs = _execute.execute(
      b'IOWriteKafka',
      1,
      inputs=flat_inputs,
      attrs=op_attrs,
      ctx=eager_ctx,
      name=name)
  _execute.record_gradient('IOWriteKafka', flat_inputs, op_attrs, outputs, name)
  # Exactly one output is expected; unpacking fails otherwise.
  (single_output,) = outputs
  return single_output
@@ -0,0 +1,28 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import os
5
+
6
+ import tensorflow as tf
7
+ from tensorflow.python.ops import string_ops
8
+
9
+ import easy_rec
10
+ from easy_rec.python.utils import constant
11
+
12
# Try to load the custom string-split op library from easy_rec.ops_dir.
# On any failure str_avx_op is set to None, which makes str_split_by_chr use
# the stock TensorFlow string_split instead.
try:
  str_avx_op_path = os.path.join(easy_rec.ops_dir, 'libstr_avx_op.so')
  str_avx_op = tf.load_op_library(str_avx_op_path)
  logging.info('load avx string_split op from %s succeed' % str_avx_op_path)
except Exception as ex:
  # Best effort: only a warning is logged, the import itself never fails.
  logging.warning('load avx string_split op failed: %s' % str(ex))
  str_avx_op = None
19
+
20
+
21
def str_split_by_chr(input_str, sep, skip_empty):
  """Split strings on a separator, using the custom AVX op when available.

  The custom op is used only when constant.has_avx_str_split() is true and
  the op library was loaded; it requires a single-character separator.
  Otherwise the stock TensorFlow string_split op is used.

  Args:
    input_str: the strings to split, passed straight to the split op.
    sep: separator; must have length 1 on the AVX path.
    skip_empty: forwarded to the split op as `skip_empty`.

  Returns:
    The result of the selected split op.
  """
  avx_usable = constant.has_avx_str_split() and str_avx_op is not None
  if not avx_usable:
    return string_ops.string_split(input_str, sep, skip_empty=skip_empty)

  sep_len = len(sep)
  assert sep_len == 1, \
      'invalid data_config.separator(%s) len(%d) != 1' % (
          sep, sep_len)
  return str_avx_op.avx512_string_split(input_str, sep, skip_empty=skip_empty)
@@ -0,0 +1,30 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+ import os
5
+
6
+ import tensorflow as tf
7
+
8
+ import easy_rec
9
+
10
# Load the incr_record op library from easy_rec.ops_dir and expose its ops as
# module-level names. When anything goes wrong all three names are set to
# None and a warning is logged, so importing this module never fails.
try:
  op_path = os.path.join(easy_rec.ops_dir, 'incr_record.so')
  op = tf.load_op_library(op_path)
  get_sparse_indices = op.get_sparse_indices
  set_sparse_indices = op.set_sparse_indices
  # kv_resource_incr_gather is optional: not every build of the library
  # provides it.
  kv_resource_incr_gather = getattr(op, 'kv_resource_incr_gather', None)
except Exception as ex:
  # A single handler is enough: ImportError is a subclass of Exception and
  # the two original handlers were identical.
  get_sparse_indices = None
  set_sparse_indices = None
  kv_resource_incr_gather = None
  logging.warning('failed to import gen_io_ops.collect_sparse_indices: %s' %
                  str(ex))