easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the package registry page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,62 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # Date: 2018-09-13
4
+ import tensorflow as tf
5
+ from tensorflow.python.training import optimizer
6
+
7
+
8
class MultiOptimizer(optimizer.Optimizer):
  """Optimizer that routes each group of variables to its own optimizer."""

  def __init__(self, opts, grouped_vars, use_locking=False):
    """Combine multiple optimizers for optimization, such as WideAndDeep.

    Args:
      opts: list of optimizer instance.
      grouped_vars: list of list of vars, each list of vars are
        optimized by each of the optimizers.
      use_locking: be compatible, currently not used.
    """
    super(MultiOptimizer, self).__init__(use_locking, 'MultiOptimizer')
    self._opts = opts
    self._grouped_vars = grouped_vars

  def compute_gradients(self, loss, variables, **kwargs):
    """Compute gradients of loss, asking each optimizer about its own group.

    Args:
      loss: the loss to differentiate.
      variables: ignored; the variable groups given to __init__ are used
        instead (kept for signature compatibility).
      **kwargs: forwarded unchanged to every wrapped compute_gradients call.

    Returns:
      list of (gradient, variable) pairs, concatenated in optimizer order.
    """
    grad_and_vars = []
    for gid, opt in enumerate(self._opts):
      grad_and_vars.extend(
          opt.compute_gradients(loss, self._grouped_vars[gid], **kwargs))
    return grad_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply each (gradient, variable) pair with the optimizer owning the var.

    Only the first optimizer receives global_step; the others get None.

    Args:
      grads_and_vars: iterable of (gradient, variable) pairs.
      global_step: optional step variable, passed to the first optimizer only.
      name: optional name of the returned grouping op.

    Returns:
      an op grouping the update ops of all wrapped optimizers.
    """
    update_ops = []
    grads_and_vars = list(grads_and_vars)
    for gid, opt in enumerate(self._opts):
      # Match variables by identity. `var in list_of_vars` is O(n) per lookup
      # and may fall back to Variable.__eq__, which returns a tensor instead
      # of a bool when TF2 tensor equality is enabled.
      group_ids = set(id(v) for v in self._grouped_vars[gid])
      tmp = [x for x in grads_and_vars if id(x[1]) in group_ids]
      step = global_step if gid == 0 else None
      update_ops.append(opt.apply_gradients(tmp, step))
    return tf.group(update_ops, name=name)

  def open_auto_record(self, flag=True):
    """Forward open_auto_record to the base optimizer."""
    # NOTE(review): open_auto_record is not part of the stock TF Optimizer
    # API; presumably provided by the PAI TF build - confirm.
    super(MultiOptimizer, self).open_auto_record(flag)

  def get_slot(self, var, name):
    """Slot lookup is not supported for combined optimizers.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError('not implemented')

  def variables(self):
    """Return the variables of all wrapped optimizers, in optimizer order."""
    all_vars = []
    for opt in self._opts:
      all_vars.extend(opt.variables())
    return all_vars

  def get_slot_names(self):
    """Return the slot names of all wrapped optimizers, concatenated."""
    slot_names = []
    for opt in self._opts:
      slot_names.extend(opt.get_slot_names())
    return slot_names
@@ -0,0 +1,18 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import json
4
+
5
+ import numpy as np
6
+
7
+
8
class NumpyEncoder(json.JSONEncoder):
  """JSON encoder that also serializes numpy scalars and arrays.

  Usage: json.dumps(obj, cls=NumpyEncoder)
  """

  def default(self, obj):
    """Convert numpy values to native python values.

    Anything else is deferred to json.JSONEncoder.default, which raises
    TypeError.
    """
    if isinstance(obj, np.integer):
      return int(obj)
    elif isinstance(obj, np.floating):
      return float(obj)
    elif isinstance(obj, np.bool_):
      # np.bool_ is neither np.integer nor a python bool, so json rejects it.
      return bool(obj)
    elif isinstance(obj, np.ndarray):
      return obj.tolist()
    return json.JSONEncoder.default(self, obj)
@@ -0,0 +1,79 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Common functions used for odps input."""
4
+ from tensorflow.python.framework import dtypes
5
+
6
+ from easy_rec.python.protos.dataset_pb2 import DatasetConfig
7
+
8
+
9
def is_type_compatiable(odps_type, input_type):
  """Check whether an odps column type is compatible with an input_type.

  Types are compatible when they map to the same DatasetConfig type, or when
  both are floating point (FLOAT/DOUBLE), or both are integer (INT32/INT64).

  Args:
    odps_type: odps column type name: 'bigint', 'string' or 'double'.
    input_type: a DatasetConfig field type enum value.

  Returns:
    True if compatible; False otherwise, including for odps types outside
    the supported set.
  """
  type_map = {
      'bigint': DatasetConfig.INT64,
      'string': DatasetConfig.STRING,
      'double': DatasetConfig.DOUBLE
  }
  tmp_type = type_map.get(odps_type)
  if tmp_type is None:
    # Unsupported odps type: report incompatibility so the caller's assertion
    # names the offending column instead of surfacing a bare KeyError.
    return False
  if tmp_type == input_type:
    return True
  float_types = [DatasetConfig.FLOAT, DatasetConfig.DOUBLE]
  int_types = [DatasetConfig.INT32, DatasetConfig.INT64]
  both_float = tmp_type in float_types and input_type in float_types
  both_int = tmp_type in int_types and input_type in int_types
  return both_float or both_int
28
+
29
+
30
def odps_type_to_input_type(odps_type):
  """Convert an odps column type name to a DatasetConfig field type.

  Args:
    odps_type: odps column type name; one of 'bigint', 'string', 'double'.

  Returns:
    the matching DatasetConfig field type enum value.

  Raises:
    AssertionError: if odps_type is not one of the supported types.
  """
  odps_type_map = {
      'bigint': DatasetConfig.INT64,
      'string': DatasetConfig.STRING,
      'double': DatasetConfig.DOUBLE
  }
  assert odps_type in odps_type_map, \
      'only support [bigint, string, double], got %s' % (odps_type,)
  input_type = odps_type_map[odps_type]
  return input_type
40
+
41
+
42
def check_input_field_and_types(data_config):
  """Check compatibility of input in data_config.

  Check that data_config.input_fields are compatible with
  data_config.selected_cols and data_config.selected_col_types. Both are
  comma separated lists; whitespace around each entry is ignored, so
  'a, b' and 'a,b' are equivalent. Nothing is checked when selected_cols
  is empty.

  Args:
    data_config: instance of DatasetConfig

  Raises:
    AssertionError: if an input field is missing from selected_cols, has no
      entry in selected_col_types, or its odps type is not compatible with
      its input_type.
  """
  input_fields = [x.input_name for x in data_config.input_fields]
  input_field_types = [x.input_type for x in data_config.input_fields]
  selected_cols = data_config.selected_cols or None
  selected_col_types = data_config.selected_col_types or None
  if not selected_cols:
    return

  selected_cols = [x.strip() for x in selected_cols.split(',')]
  for x in input_fields:
    assert x in selected_cols, 'column %s is not in table' % x
  if selected_col_types:
    selected_types = [x.strip() for x in selected_col_types.split(',')]
    type_map = dict(zip(selected_cols, selected_types))
    for x, y in zip(input_fields, input_field_types):
      # selected_col_types may list fewer entries than selected_cols.
      assert x in type_map, 'column %s has no type in selected_col_types' % x
      tmp_type = type_map[x]
      assert is_type_compatiable(tmp_type, y), \
          'feature[%s] type error: odps %s is not compatible with input_type %s' % (
              x, tmp_type, DatasetConfig.FieldType.Name(y))
69
+
70
+
71
def odps_type_2_tf_type(odps_type):
  """Map an odps column type name to the tensorflow dtype used to read it.

  'bigint' maps to int64, 'double' and 'float' both map to float32, and
  'string' as well as every unrecognized type maps to string.
  """
  if odps_type == 'bigint':
    return dtypes.int64
  if odps_type in ('double', 'float'):
    return dtypes.float32
  # 'string' and any other type are read as strings.
  return dtypes.string
@@ -0,0 +1,86 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import logging
5
+ import os
6
+ import sys
7
+ import traceback
8
+
9
+ import tensorflow as tf
10
+
11
+ if sys.version_info.major == 2:
12
+ from urllib2 import urlopen, Request, HTTPError
13
+ else:
14
+ from urllib.request import urlopen, Request
15
+ from urllib.error import HTTPError
16
+
17
+
18
def is_on_pai():
  """Return True when the IS_ON_PAI environment variable is present.

  The variable is set in pai_jobs/run.py, the entry point on the pai
  platform, or by set_on_pai().
  """
  return os.environ.get('IS_ON_PAI') is not None
22
+
23
+
24
def set_on_pai():
  """Export IS_ON_PAI=1 into os.environ so is_on_pai() reports True."""
  logging.info('set on pai environment variable: IS_ON_PAI')
  os.environ.update({'IS_ON_PAI': '1'})
27
+
28
+
29
def download(url, timeout=10):
  """Download url into the current working directory.

  The local file is named after the last path component of url.

  Args:
    url: address of the file to fetch (any scheme urlopen supports).
    timeout: seconds to wait when opening the url, defaults to 10.

  Returns:
    the local file name on success, None if the download failed (the
    failure is logged).
  """
  _, fname = os.path.split(url)
  request = Request(url=url)
  try:
    response = urlopen(request, timeout=timeout)
    try:
      data = response.read()
    finally:
      response.close()
    # read() returns bytes on python3, so the file must be opened in binary
    # mode; writing bytes to a text-mode file raises TypeError.
    with open(fname, 'wb') as ofile:
      ofile.write(data)
    return fname
  except HTTPError as e:
    logging.error('http error: %s', e.code)
    logging.error('body: %s', e.read())
    return None
  except Exception as e:
    logging.error('failed to download %s: %s', url, e)
    logging.error(traceback.format_exc())
    return None
45
+
46
+
47
def process_config(configs, task_index=0, worker_num=1):
  """Select this worker's config and make it usable locally.

  Args:
    configs: config paths, separated by ','; when more than one is given
      there must be exactly one per worker.
    task_index: worker index
    worker_num: total number of workers

  Returns:
    for an 'http' path, the local file name from download() (None on
    failure); for an 'oss' path, the path with '/##/' replaced by '\\x02'
    and then '/#/' replaced by '\\x01'; any other path unchanged.
  """
  config_list = configs.split(',')
  if len(config_list) == 1:
    config = config_list[0]
  else:
    assert len(config_list) == worker_num, \
        'number of configs must be equal to number of workers,' + \
        ' when number of configs > 1'
    config = config_list[task_index]

  if config.startswith('http'):
    return download(config)
  if config.startswith('oss'):
    return config.replace('/##/', '\x02').replace('/#/', '\x01')
  # Local paths are returned as-is, so experiments can run from a local
  # env without uploading the sample file.
  return config
76
+
77
+
78
def test():
  """Smoke test: fetch a public sample config and check the local name."""
  sample_url = ('https://easy-rec.oss-cn-hangzhou.aliyuncs.com/config/'
                'MultiTower/dwd_avazu_ctr_deepmodel.config')
  local_name = download(sample_url)
  assert local_name == 'dwd_avazu_ctr_deepmodel.config'


if __name__ == '__main__':
  test()
@@ -0,0 +1,90 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import logging
4
+
5
+
6
def copy_obj(proto_obj):
  """Return an independent copy of a protobuf message.

  Later modifications of the returned message do not affect proto_obj.

  Args:
    proto_obj: a protobuf message

  Return:
    a new message of the same type holding the same content as proto_obj
  """
  message_cls = type(proto_obj)
  duplicate = message_cls()
  duplicate.CopyFrom(proto_obj)
  return duplicate
17
+
18
+
19
def get_norm_embed_name(name, verbose=False):
  """Normalize an embedding variable/op name for embedding export to redis.

  The name is split on '/' and three patterns are tried in order; the first
  match wins:
    1. a component starting with 'embedding_weights:' such as
       '<scope>/embedding_weights:<k>' -> ('<scope>', 0), with '_<k>'
       appended to '<scope>' when k != '0'.
       Also, within the same scan, '<scope>/embedding_weights/part_<p>:<k>'
       -> ('<scope>', int(p)), with '_<k>' appended when k != '0'; this form
       only matches when 'embedding_weights' is the third component or later.
    2. a component ending in '_embedding_weights' or containing
       '_embedding_weights_' -> (components before it joined by '/', 0).
    3. a component equal to 'embedding_weights' -> (components before it
       joined by '/', 0).

  Args:
    name: variable name
    verbose: whether to dump the embed_names
  Return:
    embedding_name: normalized embedding_name
    embedding_part_id: normalized embedding part_id
    (None, None), after logging a warning, if no pattern matches
  Raises:
    ValueError: if a matched 'part_<p>' component has a non-integer p.
  """
  name_toks = name.split('/')
  # Pattern 1: scan adjacent component pairs (name_toks[i], name_toks[i + 1]).
  for i in range(0, len(name_toks) - 1):
    if name_toks[i + 1].startswith('embedding_weights:'):
      # Suffix after 'embedding_weights:' is appended unless it is '0'.
      var_id = name_toks[i + 1].replace('embedding_weights:', '')
      tmp_name = '/'.join(name_toks[:i + 1])
      if var_id != '0':
        tmp_name = tmp_name + '_' + var_id
      if verbose:
        logging.info('norm %s to %s' % (name, tmp_name))
      return tmp_name, 0
    # Partitioned form: '.../embedding_weights/part_<p>:<k>'.
    # NOTE(review): `i > 1` skips names where 'embedding_weights' is the
    # first or second component (e.g. 'x/embedding_weights/part_0:0'); those
    # fall through to pattern 3 and lose their part id - confirm intended.
    if i > 1 and name_toks[i + 1].startswith('part_') and \
        name_toks[i] == 'embedding_weights':
      tmp_name = '/'.join(name_toks[:i])
      part_id = name_toks[i + 1].replace('part_', '')
      part_toks = part_id.split(':')
      if len(part_toks) >= 2 and part_toks[1] != '0':
        tmp_name = tmp_name + '_' + part_toks[1]
      if verbose:
        logging.info('norm %s to %s' % (name, tmp_name))
      # Part number <p> becomes the returned part id.
      return tmp_name, int(part_toks[0])

  # Pattern 2:
  # input_layer/app_category_embedding/app_category_embedding_weights/SparseReshape
  # => input_layer/app_category_embedding
  for i in range(0, len(name_toks) - 1):
    if name_toks[i + 1].endswith('_embedding_weights') or \
        '_embedding_weights_' in name_toks[i + 1]:
      tmp_name = '/'.join(name_toks[:i + 1])
      if verbose:
        logging.info('norm %s to %s' % (name, tmp_name))
      return tmp_name, 0
  # Pattern 3:
  # input_layer/app_category_embedding/embedding_weights
  # => input_layer/app_category_embedding
  for i in range(0, len(name_toks) - 1):
    if name_toks[i + 1] == 'embedding_weights':
      tmp_name = '/'.join(name_toks[:i + 1])
      if verbose:
        logging.info('norm %s to %s' % (name, tmp_name))
      return tmp_name, 0
  logging.warning('Failed to norm: %s' % name)
  return None, None
70
+
71
+
72
def is_cache_from_redis(name, redis_cache_names):
  """Tell whether the variable `name` is selected for redis caching.

  A leading '/'-separated component starting with 'input_layer' is ignored.
  The variable is selected when any remaining component starts with one of
  redis_cache_names; the first such cache name (in list order) is logged.

  Args:
    name: string, the variable name to be checked
    redis_cache_names: list of string, names which should be cached.

  Return:
    True if need to be cached
  """
  scopes = name.split('/')
  if scopes[0].startswith('input_layer'):
    scopes = scopes[1:]
  matched = next(
      (prefix for prefix in redis_cache_names
       if any(scope.startswith(prefix) for scope in scopes)), None)
  if matched is None:
    return False
  logging.info('embedding %s will be cached[specified by %s]' % (name, matched))
  return True
@@ -0,0 +1,89 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Define filters for restore."""
4
+
5
+ from abc import ABCMeta
6
+ from abc import abstractmethod
7
+ from enum import Enum
8
+
9
+
10
class Logical(Enum):
  """How CombineFilter joins the verdicts of its filters."""
  # keep a variable only if every filter keeps it
  AND = 1
  # keep a variable if at least one filter keeps it
  OR = 2
+
14
+
15
class Filter:
  """Base class of variable-name filters used for restore."""
  # NOTE(review): the `__metaclass__` attribute is the python2 way of setting
  # a metaclass; on python3 it has no effect, so @abstractmethod is not
  # enforced and Filter() can be instantiated.
  __metaclass__ = ABCMeta

  def __init__(self):
    """The base filter holds no state."""

  @abstractmethod
  def keep(self, var_name):
    """Decide whether the variable named var_name is kept.

    Args:
      var_name: input name of the var

    Returns:
      True if the var will be kept, else False. The base implementation
      keeps every variable.
    """
    return True
32
+
33
+
34
class KeywordFilter(Filter):
  """Filter variables by whether their name contains a keyword."""

  def __init__(self, pattern, exclusive=False):
    """Init KeywordFilter.

    Args:
      pattern: keyword to be matched
      exclusive: if False (default), keep only var_names that include the
        pattern; if True, keep only var_names that do not include it.
    """
    self._pattern = pattern
    self._exclusive = exclusive

  def keep(self, var_name):
    """Return True if var_name is kept under this filter."""
    contains = self._pattern in var_name
    if self._exclusive:
      return not contains
    return contains
52
+
53
+
54
class CombineFilter(Filter):
  """Combine several filters with a logical AND or OR."""

  def __init__(self, filters, logical=Logical.AND):
    """Init CombineFilter.

    Args:
      filters: a set of filters to be combined
      logical: Logical.AND keeps a var only if every filter keeps it;
        Logical.OR keeps a var if any filter keeps it.
    """
    self._filters = filters
    self._logical = logical

  def keep(self, var_name):
    """Return True if var_name is kept by the combined filters.

    Filters are evaluated in order and evaluation stops as soon as the
    result is decided.

    Raises:
      ValueError: if logical is neither Logical.AND nor Logical.OR.
    """
    if self._logical == Logical.AND:
      return all(one_filter.keep(var_name) for one_filter in self._filters)
    if self._logical == Logical.OR:
      return any(one_filter.keep(var_name) for one_filter in self._filters)
    # Fail loudly: falling through would return None (falsy) and silently
    # drop every variable.
    raise ValueError('unsupported logical: %s' % (self._logical,))
77
+
78
+
79
class ScopeDrop:
  """For drop out scope prefix when restore variables from checkpoint."""

  def __init__(self, scope_name):
    """Init ScopeDrop.

    Args:
      scope_name: scope to drop; a trailing '/' is appended when missing.
        An empty scope_name drops nothing.
    """
    self._scope_name = scope_name
    # Only normalize non-empty scopes: indexing [-1] of '' raises IndexError.
    if len(self._scope_name) > 0 and self._scope_name[-1] != '/':
      self._scope_name += '/'

  def update(self, var_name):
    """Return var_name with the scope removed."""
    # NOTE(review): str.replace removes the scope wherever it occurs in
    # var_name, not only as a leading prefix - confirm this is intended.
    return var_name.replace(self._scope_name, '')