easy-cs-rec-custommodel 0.8.6 (py2.py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,528 @@
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Synchronize replicas for training."""
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ from tensorflow.core.framework import types_pb2
+ from tensorflow.python.framework import errors_impl
+ from tensorflow.python.framework import ops
+ from tensorflow.python.ops import array_ops
+ from tensorflow.python.ops import control_flow_ops
+ from tensorflow.python.ops import data_flow_ops
+ from tensorflow.python.ops import state_ops
+ from tensorflow.python.ops import variable_scope
+ from tensorflow.python.ops import variables
+ from tensorflow.python.platform import tf_logging as logging
+ from tensorflow.python.training import optimizer
+ from tensorflow.python.training import queue_runner
+ from tensorflow.python.training import session_manager
+ from tensorflow.python.training import session_run_hook
+ from tensorflow.python.util.tf_export import tf_export
+
+
+ # Please note that the gradients from replicas are averaged instead of summed
+ # (as in the old sync_replicas_optimizer) so you need to increase the learning
+ # rate according to the number of replicas. This change is introduced to be
+ # consistent with how gradients are aggregated (averaged) within a batch in a
+ # replica.
+ @tf_export('train.SyncReplicasOptimizer')
+ class SyncReplicasOptimizer(optimizer.Optimizer):
+   """Class to synchronize, aggregate gradients and pass them to the optimizer.
+
+   In a typical asynchronous training environment, it's common to have some
+   stale gradients. For example, with an N-replica asynchronous training,
+   gradients will be applied to the variables N times independently. Depending
+   on each replica's training speed, some gradients might be calculated from
+   copies of the variable from several steps back (N-1 steps on average). This
+   optimizer avoids stale gradients by collecting gradients from all replicas,
+   averaging them, then applying them to the variables in one shot, after
+   which replicas can fetch the new variables and continue.
+
+   The following accumulators/queue are created:
+
+   * N `gradient accumulators`, one per variable to train. Gradients are pushed
+     to them and the chief worker will wait until enough gradients are collected
+     and then average them before applying to variables. The accumulator will
+     drop all stale gradients (more details in the accumulator op).
+   * 1 `token` queue where the optimizer pushes the new global_step value after
+     all variables are updated.
+
+   The following local variable is created:
+   * `sync_rep_local_step`, one per replica. Compared against the global_step in
+     each accumulator to check for staleness of the gradients.
+
+   The optimizer adds nodes to the graph to collect gradients and pause the
+   trainers until variables are updated.
+   For the Parameter Server job:
+
+   1. An accumulator is created for each variable, and each replica pushes the
+      gradients into the accumulators instead of directly applying them to the
+      variables.
+   2. Each accumulator averages once enough gradients (replicas_to_aggregate)
+      have been accumulated.
+   3. Apply the averaged gradients to the variables.
+   4. Only after all variables have been updated, increment the global step.
+   5. Only after step 4, pushes `global_step` in the `token_queue`, once for
+      each worker replica. The workers can now fetch the global step, use it to
+      update their local_step variables and start the next batch.
+
+   For the replicas:
+
+   1. Start a step: fetch variables and compute gradients.
+   2. Once the gradients have been computed, push them into gradient
+      accumulators. Each accumulator will check the staleness and drop the stale ones.
+   3. After pushing all the gradients, dequeue an updated value of global_step
+      from the token queue and record that step to its local_step variable. Note
+      that this is effectively a barrier.
+   4. Start the next batch.
+
+   ### Usage
+
+   ```python
+   # Create any optimizer to update the variables, say a simple SGD:
+   opt = GradientDescentOptimizer(learning_rate=0.1)
+
+   # Wrap the optimizer with sync_replicas_optimizer with 50 replicas: at each
+   # step the optimizer collects 50 gradients before applying to variables.
+   # Note that if you want to have 2 backup replicas, you can change
+   # total_num_replicas=52 and make sure this number matches how many physical
+   # replicas you started in your job.
+   opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
+                                        total_num_replicas=50)
+
+   # Some models have startup_delays to help stabilize the model but when using
+   # sync_replicas training, set it to 0.
+
+   # Now you can call `minimize()` or `compute_gradients()` and
+   # `apply_gradients()` normally
+   training_op = opt.minimize(total_loss, global_step=self.global_step)
+
+
+   # You can create the hook which handles initialization and queues.
+   sync_replicas_hook = opt.make_session_run_hook(is_chief)
+   ```
+
+   In the training program, every worker will run the train_op as if not
+   synchronized.
+
+   ```python
+   with training.MonitoredTrainingSession(
+       master=workers[worker_id].target, is_chief=is_chief,
+       hooks=[sync_replicas_hook]) as mon_sess:
+     while not mon_sess.should_stop():
+       mon_sess.run(training_op)
+   ```
+
+   To use SyncReplicasOptimizer with an `Estimator`, you need to pass
+   sync_replicas_hook when calling fit.
+   ```python
+   my_estimator = DNNClassifier(..., optimizer=opt)
+   my_estimator.fit(..., hooks=[sync_replicas_hook])
+   ```
+   """
+
+   sync_que_id = -1
+
+   def __init__(self,
+                opt,
+                replicas_to_aggregate,
+                total_num_replicas=None,
+                variable_averages=None,
+                variables_to_average=None,
+                use_locking=False,
+                name='sync_replicas',
+                **extra_args):
+     """Construct a sync_replicas optimizer.
+
+     Args:
+       opt: The actual optimizer that will be used to compute and apply the
+         gradients. Must be one of the Optimizer classes.
+       replicas_to_aggregate: number of replicas to aggregate for each variable
+         update.
+       total_num_replicas: Total number of tasks/workers/replicas, could be
+         different from replicas_to_aggregate.
+         If total_num_replicas > replicas_to_aggregate: it is backup_replicas +
+         replicas_to_aggregate.
+         If total_num_replicas < replicas_to_aggregate: Replicas compute
+         multiple batches per update to variables.
+       variable_averages: Optional `ExponentialMovingAverage` object, used to
+         maintain moving averages for the variables passed in
+         `variables_to_average`.
+       variables_to_average: a list of variables that need to be averaged. Only
+         needed if variable_averages is passed in.
+       use_locking: If True use locks for update operation.
+       name: string. Optional name of the returned operation.
+     """
+     if total_num_replicas is None:
+       total_num_replicas = replicas_to_aggregate
+
+     super(SyncReplicasOptimizer, self).__init__(use_locking, name)
+     logging.info(
+         'SyncReplicasV2: replicas_to_aggregate=%s; total_num_replicas=%s',
+         replicas_to_aggregate, total_num_replicas)
+     self._opt = opt
+     self._replicas_to_aggregate = replicas_to_aggregate
+     self._gradients_applied = False
+     self._variable_averages = variable_averages
+     self._variables_to_average = variables_to_average
+     self._total_num_replicas = total_num_replicas
+     self._tokens_per_step = max(total_num_replicas, replicas_to_aggregate)
+     self._global_step = None
+     self._sync_token_queue = None
+     self._is_sync_que_closed = None
+     self._close_sync_que = None
+
+     # The synchronization op will be executed in a queue runner which should
+     # only be executed by one of the replicas (usually the chief).
+     self._chief_queue_runner = None
+
+     # Remember which accumulator is on which device to set the initial step in
+     # the accumulator to be global step. This list contains tuples of the
+     # following format: (accumulator, device).
+     self._accumulator_list = []
+
+   def compute_gradients(self, *args, **kwargs):
+     """Compute gradients of "loss" for the variables in "var_list".
+
+     This simply wraps the compute_gradients() from the real optimizer. The
+     gradients will be aggregated in the apply_gradients() so that the user can
+     modify the gradients like clipping with per replica global norm if needed.
+     The global norm with aggregated gradients can be bad as one replica's huge
+     gradients can hurt the gradients from other replicas.
+
+     Args:
+       *args: Arguments for compute_gradients().
+       **kwargs: Keyword arguments for compute_gradients().
+
+     Returns:
+       A list of (gradient, variable) pairs.
+     """
+     return self._opt.compute_gradients(*args, **kwargs)
+
+   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+     """Apply gradients to variables.
+
+     This contains most of the synchronization implementation and also wraps the
+     apply_gradients() from the real optimizer.
+
+     Args:
+       grads_and_vars: List of (gradient, variable) pairs as returned by
+         compute_gradients().
+       global_step: Optional Variable to increment by one after the
+         variables have been updated.
+       name: Optional name for the returned operation. Defaults to the
+         name passed to the Optimizer constructor.
+
+     Returns:
+       train_op: The op to dequeue a token so the replicas can exit this batch
+         and start the next one. This is executed by each replica.
+
+     Raises:
+       ValueError: If grads_and_vars is empty.
+       ValueError: If global step is not provided, the staleness cannot be
+         checked.
+     """
+     if not grads_and_vars:
+       raise ValueError('Must supply at least one variable')
+
+     if global_step is None:
+       raise ValueError('Global step is required to check staleness')
+
+     self._global_step = global_step
+     train_ops = []
+     aggregated_grad = []
+     var_list = []
+
+     # local_anchor op will be placed on this worker task by default.
+     local_anchor = control_flow_ops.no_op()
+     # Colocating local_step variable prevents it being placed on the PS.
+     with ops.colocate_with(local_anchor):
+       self._local_step = variable_scope.variable(
+           initial_value=0,
+           trainable=False,
+           collections=[ops.GraphKeys.LOCAL_VARIABLES],
+           dtype=global_step.dtype.base_dtype,
+           name='sync_rep_local_step')
+
+     self.local_step_init_op = state_ops.assign(self._local_step, global_step)
+     chief_init_ops = [self.local_step_init_op]
+     self.ready_for_local_init_op = variables.report_uninitialized_variables(
+         variables.global_variables())
+
+     with ops.name_scope(None, self._name):
+       for grad, var in grads_and_vars:
+         var_list.append(var)
+         with ops.device(var.device):
+           # Dense gradients.
+           if grad is None:
+             aggregated_grad.append(None)  # pass-through.
+             continue
+           elif isinstance(grad, ops.Tensor):
+             grad_accum = data_flow_ops.ConditionalAccumulator(
+                 grad.dtype,
+                 shape=var.get_shape(),
+                 shared_name=var.name + '/grad_accum')
+             train_ops.append(
+                 grad_accum.apply_grad(grad, local_step=self._local_step))
+             aggregated_grad.append(
+                 grad_accum.take_grad(self._replicas_to_aggregate))
+           else:
+             if not isinstance(grad, ops.IndexedSlices):
+               raise ValueError('Unknown grad type!')
+             grad_accum = data_flow_ops.SparseConditionalAccumulator(
+                 grad.dtype, shape=(), shared_name=var.name + '/grad_accum')
+             train_ops.append(
+                 grad_accum.apply_indexed_slices_grad(
+                     grad, local_step=self._local_step))
+             aggregated_grad.append(
+                 grad_accum.take_indexed_slices_grad(
+                     self._replicas_to_aggregate))
+
+           self._accumulator_list.append((grad_accum, var.device))
+
+       aggregated_grads_and_vars = zip(aggregated_grad, var_list)
+
+       # sync_op will be assigned to the same device as the global step.
+       with ops.device(global_step.device), ops.name_scope(''):
+         update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
+                                               global_step)
+
+       def _get_token_qname():
+         SyncReplicasOptimizer.sync_que_id += 1
+         if SyncReplicasOptimizer.sync_que_id == 0:
+           return 'sync_token_q'
+         else:
+           return 'sync_token_q_' + str(SyncReplicasOptimizer.sync_que_id)
+
+       # Create token queue.
+       token_qname = _get_token_qname()
+       logging.info('create sync_token_queue[%s]' % token_qname)
+       with ops.device(global_step.device), ops.name_scope(''):
+         sync_token_queue = (
+             data_flow_ops.FIFOQueue(
+                 -1,
+                 global_step.dtype.base_dtype,
+                 shapes=(),
+                 name=token_qname,
+                 shared_name=token_qname))
+         self._sync_token_queue = sync_token_queue
+         self._is_sync_que_closed = sync_token_queue.is_closed()
+         self._close_sync_que = sync_token_queue.close(
+             cancel_pending_enqueues=True, name='close_sync_token_queue')
+
+         # dummy_queue is passed to the queue runner. Don't use the real queues
+         # because the queue runner doesn't automatically reopen it once it
+         # closed queues in PS devices.
+         dummy_queue = (
+             data_flow_ops.FIFOQueue(
+                 1,
+                 types_pb2.DT_INT32,
+                 shapes=(),
+                 name='dummy_queue',
+                 shared_name='dummy_queue'))
+
+       with ops.device(global_step.device), ops.name_scope(''):
+         # Replicas have to wait until they can get a token from the token queue.
+         with ops.control_dependencies(train_ops):
+           token = sync_token_queue.dequeue()
+         train_op = state_ops.assign(self._local_step, token)
+
+         with ops.control_dependencies([update_op]):
+           # Sync_op needs to insert tokens to the token queue at the end of the
+           # step so the replicas can fetch them to start the next step.
+           tokens = array_ops.fill([self._tokens_per_step], global_step)
+           sync_op = sync_token_queue.enqueue_many((tokens,))
+
+         if self._variable_averages is not None:
+           with ops.control_dependencies([sync_op]), ops.name_scope(''):
+             sync_op = self._variable_averages.apply(self._variables_to_average)
+
+         self._chief_queue_runner = queue_runner.QueueRunner(
+             dummy_queue, [sync_op])
+       ops.add_to_collection(ops.GraphKeys.QUEUE_RUNNERS,
+                             self._chief_queue_runner)
+       for accum, dev in self._accumulator_list:
+         with ops.device(dev):
+           chief_init_ops.append(
+               accum.set_global_step(global_step, name='SetGlobalStep'))
+       self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
+       self._gradients_applied = True
+       return train_op
+
+   def get_chief_queue_runner(self):
+     """Returns the QueueRunner for the chief to execute.
+
+     This includes the operations to synchronize replicas: aggregate gradients,
+     apply to variables, increment global step, insert tokens to token queue.
+
+     Note that this can only be called after calling apply_gradients() which
+     actually generates this queue runner.
+
+     Returns:
+       A `QueueRunner` for the chief to execute.
+
+     Raises:
+       ValueError: If this is called before apply_gradients().
+     """
+     if self._gradients_applied is False:
+       raise ValueError('Should be called after apply_gradients().')
+
+     return self._chief_queue_runner
+
+   def get_slot(self, *args, **kwargs):
+     """Return a slot named "name" created for "var" by the Optimizer.
+
+     This simply wraps the get_slot() from the actual optimizer.
+
+     Args:
+       *args: Arguments for get_slot().
+       **kwargs: Keyword arguments for get_slot().
+
+     Returns:
+       The `Variable` for the slot if it was created, `None` otherwise.
+     """
+     return self._opt.get_slot(*args, **kwargs)
+
+   def variables(self):
+     """Fetches a list of optimizer variables in the default graph.
+
+     This wraps `variables()` from the actual optimizer. It does not include
+     the `SyncReplicasOptimizer`'s local step.
+
+     Returns:
+       A list of variables.
+     """
+     return self._opt.variables()
+
+   def get_slot_names(self, *args, **kwargs):
+     """Return a list of the names of slots created by the `Optimizer`.
+
+     This simply wraps the get_slot_names() from the actual optimizer.
+
+     Args:
+       *args: Arguments for get_slot().
+       **kwargs: Keyword arguments for get_slot().
+
+     Returns:
+       A list of strings.
+     """
+     return self._opt.get_slot_names(*args, **kwargs)
+
+   def get_init_tokens_op(self, num_tokens=-1):
+     """Returns the op to fill the sync_token_queue with the tokens.
+
+     This is supposed to be executed at the beginning of the chief/sync thread
+     so that even if the total_num_replicas is less than replicas_to_aggregate,
+     the model can still proceed as the replicas can compute multiple steps per
+     variable update. Make sure:
+     `num_tokens >= replicas_to_aggregate - total_num_replicas`.
+
+     Args:
+       num_tokens: Number of tokens to add to the queue.
+
+     Returns:
+       An op for the chief/sync replica to fill the token queue.
+
+     Raises:
+       ValueError: If this is called before apply_gradients().
+       ValueError: If num_tokens is smaller than replicas_to_aggregate -
+         total_num_replicas.
+     """
+     if self._gradients_applied is False:
+       raise ValueError(
+           'get_init_tokens_op() should be called after apply_gradients().')
+
+     tokens_needed = self._replicas_to_aggregate - self._total_num_replicas
+     if num_tokens == -1:
+       num_tokens = self._replicas_to_aggregate
+     elif num_tokens < tokens_needed:
+       raise ValueError(
+           'Too few tokens to finish the first step: %d (given) vs %d (needed)' %
+           (num_tokens, tokens_needed))
+
+     if num_tokens > 0:
+       with ops.device(self._global_step.device), ops.name_scope(''):
+         tokens = array_ops.fill([num_tokens], self._global_step)
+         init_tokens = self._sync_token_queue.enqueue_many((tokens,))
+     else:
+       init_tokens = control_flow_ops.no_op(name='no_init_tokens')
+
+     return init_tokens
+
+   def make_session_run_hook(self, is_chief, num_tokens=-1):
+     """Creates a hook to handle SyncReplicasHook ops such as initialization."""
+     return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)
+
+
+ class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
+   """A SessionRunHook that handles ops related to SyncReplicasOptimizer."""
+
+   def __init__(self, sync_optimizer, is_chief, num_tokens):
+     """Creates a hook to handle SyncReplicasOptimizer initialization ops.
+
+     Args:
+       sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
+       is_chief: `Bool`, whether this is a chief replica or not.
+       num_tokens: Number of tokens to add to the queue.
+     """
+     self._sync_optimizer = sync_optimizer
+     self._is_chief = is_chief
+     self._num_tokens = num_tokens
+
+   def begin(self):
+     if self._sync_optimizer._gradients_applied is False:  # pylint: disable=protected-access
+       raise ValueError(
+           'SyncReplicasOptimizer.apply_gradients should be called before using '
+           'the hook.')
+     if self._is_chief:
+       self._local_init_op = self._sync_optimizer.chief_init_op
+       self._ready_for_local_init_op = (
+           self._sync_optimizer.ready_for_local_init_op)
+       self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
+           self._num_tokens)
+     else:
+       self._local_init_op = self._sync_optimizer.local_step_init_op
+       self._ready_for_local_init_op = (
+           self._sync_optimizer.ready_for_local_init_op)
+       self._init_tokens_op = None
+
+   def after_create_session(self, session, coord):
+     """Runs SyncReplicasOptimizer initialization ops."""
+     local_init_success, msg = session_manager._ready(  # pylint: disable=protected-access
+         self._ready_for_local_init_op, session,
+         'Model is not ready for SyncReplicasOptimizer local init.')
+     if not local_init_success:
+       raise RuntimeError(
+           'Init operations did not make model ready for SyncReplicasOptimizer '
+           'local_init. Init op: %s, error: %s' %
+           (self._local_init_op.name, msg))
+     session.run(self._local_init_op)
+     is_closed = session.run(self._sync_optimizer._is_sync_que_closed)
+     assert not is_closed, 'sync_que is closed'
+     if self._init_tokens_op is not None:
+       session.run(self._init_tokens_op)
+
+   def end(self, session):
+     try:
+       is_closed = session.run(self._sync_optimizer._is_sync_que_closed)
+       if not is_closed:
+         logging.info('will close sync token que')
+         session.run(self._sync_optimizer._close_sync_que)
+       else:
+         logging.info('sync token que is closed')
+     except errors_impl.CancelledError:
+       logging.info('sync token que is closed')
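
For orientation, the class docstring above already sketches the intended wiring; the snippet below restates it as one self-contained TF1-style sketch. It assumes `total_loss`, `global_step`, `is_chief` and `master` are defined by the surrounding training program (hypothetical names, not part of this package) and uses the stock `tf.train` entry points that this module also exports.

```python
import tensorflow as tf  # TensorFlow 1.x API

# Any base optimizer works; plain SGD keeps the sketch simple.
base_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)

# Aggregate 50 gradients per variable update. For 2 backup replicas you
# would start 52 workers and set total_num_replicas=52.
opt = tf.train.SyncReplicasOptimizer(
    base_opt, replicas_to_aggregate=50, total_num_replicas=50)

train_op = opt.minimize(total_loss, global_step=global_step)

# The hook runs the chief/worker init ops; in this patched version it also
# closes the sync token queue when the session ends (see end() above).
sync_replicas_hook = opt.make_session_run_hook(is_chief)

with tf.train.MonitoredTrainingSession(
    master=master, is_chief=is_chief,
    hooks=[sync_replicas_hook]) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)
```

On token seeding, per get_init_tokens_op(): with the default num_tokens=-1 the chief seeds replicas_to_aggregate tokens; any explicit value must satisfy num_tokens >= replicas_to_aggregate - total_num_replicas, otherwise the first step could never collect enough gradients to complete.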