easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,440 @@
1
+ #
2
+ # Copyright (c) 2022, NVIDIA CORPORATION.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ import tensorflow as tf
18
+ from tensorflow.python.eager import context
19
+ # from tensorflow.python.framework import dtypes
20
+ from tensorflow.python.framework import ops
21
+ # from tensorflow.python.ops import control_flow_ops
22
+ from tensorflow.python.ops import array_ops
23
+ from tensorflow.python.ops import gradients
24
+ from tensorflow.python.ops import resource_variable_ops
25
+ from tensorflow.python.ops import state_ops
26
+
27
+ from easy_rec.python.compat.dynamic_variable import DynamicVariable
28
+
29
+
30
def OptimizerWrapper(optimizer):
  """Wrap a TensorFlow optimizer so that it can also update DynamicVariables.

  Abbreviated as ``sok.experiment.OptimizerWrapper``.

  Keras optimizers (``tf.keras.optimizers.Optimizer`` and, on TF builds that
  provide it, ``tf.keras.optimizers.legacy.Optimizer``) are wrapped with
  ``OptimizerWrapperV2``; any other optimizer is wrapped with
  ``OptimizerWrapperV1``.

  Parameters
  ----------
  optimizer: tensorflow optimizer
      The original tensorflow optimizer.

  Returns
  -------
  OptimizerWrapperV1 or OptimizerWrapperV2
      The wrapper instance around ``optimizer``.

  Example
  -------
  .. code-block:: python

      import numpy as np
      import tensorflow as tf
      import horovod.tensorflow as hvd
      from sparse_operation_kit import experiment as sok

      v = dynamic_variable.DynamicVariable(dimension=3, initializer="13")

      indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)

      with tf.GradientTape() as tape:
          embedding = tf.nn.embedding_lookup(v, indices)
          print("embedding:", embedding)
          loss = tf.reduce_sum(embedding)

      grads = tape.gradient(loss, [v])

      optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
      optimizer = sok.OptimizerWrapper(optimizer)
      optimizer.apply_gradients(zip(grads, [v]))

      embedding = tf.nn.embedding_lookup(v, indices)
      print("embedding:", embedding)
  """
  # Code path specific to TF 2.11, where the classic Keras optimizer lives
  # under ``tf.keras.optimizers.legacy``. The probe is best-effort: on builds
  # without that module the lookup fails and we fall through.
  # NOTE(review): the broad except also swallows errors raised while
  # constructing OptimizerWrapperV2 here; kept to preserve behavior.
  try:
    legacy_base = tf.keras.optimizers.legacy.Optimizer
    if isinstance(optimizer, legacy_base):
      return OptimizerWrapperV2(optimizer)
  except Exception:
    pass

  is_keras = isinstance(optimizer, tf.keras.optimizers.Optimizer)
  wrapper_cls = OptimizerWrapperV2 if is_keras else OptimizerWrapperV1
  return wrapper_cls(optimizer)
79
+
80
+
81
class OptimizerWrapperV1(object):
  """Adapter that lets a TF1-style optimizer also update DynamicVariables.

  The wrapped optimizer is expected to expose the ``tf.compat.v1.train``
  optimizer internals used below: ``_create_slots``, ``get_slot_names``,
  ``get_slot``, ``_finish``, ``_name`` and a ``_slots`` table keyed
  ``slot_name -> var_key -> slot``.

  Dense variables are passed straight through to the wrapped optimizer. For
  each DynamicVariable the wrapper does three things:

  1. Calls ``to_static`` with the unique gradient indices.
  2. Runs the wrapped optimizer.
  3. Calls ``to_dynamic`` to write the buffer back.

  Optimizer slots for DynamicVariables are DynamicVariables too. They are
  registered in the wrapped optimizer's ``_slots`` table under
  ``self._var_key(var)``.
  """

  def __init__(self, optimizer):
    self._optimizer = optimizer
    # Let the wrapped optimizer create its slots for a throw-away variable so
    # the slot names and their initial slot variables can be captured; the
    # throw-away entries are then removed from the optimizer's slot table.
    unused = tf.Variable([0.0],
                         dtype=tf.float32,
                         name='unused',
                         trainable=False)
    self._optimizer._create_slots([unused])
    names, slots = [], []
    for slot_name in self._optimizer.get_slot_names():
      names.append(slot_name)
      slots.append(self._optimizer.get_slot(unused, slot_name))
    unused_key = self._var_key(unused)
    for slot_name in names:
      assert unused_key in self._optimizer._slots[slot_name]
      self._optimizer._slots[slot_name].pop(unused_key)
    # slot name -> slot variable created for `unused`; used as the
    # initializer of every DynamicVariable slot.
    self._initial_vals = dict(zip(names, slots))

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute ``(gradient, variable)`` pairs of ``loss`` w.r.t. ``var_list``.

    ``loss`` is remembered so that ``apply_gradients`` can order the sparse
    updates after it.

    Args:
      loss: the tensor to differentiate.
      var_list: variables to differentiate against. Defaults to the graph's
        ``GraphKeys.TRAINABLE_VARIABLES`` collection. Any iterable is
        accepted.
      aggregation_method: forwarded to ``gradients.gradients``.
      colocate_gradients_with_ops: forwarded to ``gradients.gradients``.
      grad_loss: optional initial gradient of ``loss``, passed as
        ``grad_ys``.

    Returns:
      A list of ``(gradient, variable)`` tuples.
    """
    self._loss = loss
    if var_list is None:
      var_list = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    # Materialize: var_list is consumed twice (gradients and zip below).
    var_list = list(var_list)
    # The wrapped optimizer's compute_gradients does not work with
    # DynamicVariable, so gradients are computed directly.
    tmp_grads = gradients.gradients(
        loss,
        var_list,
        grad_ys=grad_loss,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        aggregation_method=aggregation_method)
    return list(zip(tmp_grads, var_list))

  def _var_key(self, var):
    """Return the ``(graph, op name)`` key used in the ``_slots`` table."""
    if isinstance(var, DynamicVariable):
      return (var._tf_handle.op.graph, var._tf_handle.op.name)
    else:
      return (var.op.graph, var.op.name)

  def _new_dynamic_slot(self, var, slot_name, name):
    """Create a DynamicVariable slot named ``name`` for ``var``.

    The slot is colocated with ``var``. For non-``hbm`` backends it reuses
    ``var``'s backend type and ``config_dict``.
    """
    initializer = self._initial_vals[slot_name]
    with ops.colocate_with(var):
      if var.backend_type == 'hbm':
        return DynamicVariable(
            dimension=var.dimension,
            initializer=initializer,
            name=name,
            trainable=False)
      return DynamicVariable(
          dimension=var.dimension,
          initializer=initializer,
          var_type=var.backend_type,
          name=name,
          trainable=False,
          **var.config_dict)

  def _create_slots(self, vars):
    """Create optimizer slots for every variable in ``vars``."""
    for var in vars:
      if isinstance(var, DynamicVariable):
        self._create_slots_dynamic(var)
      else:
        # The wrapped optimizer's _create_slots takes a list of variables
        # (see the `[unused]` call in __init__), not a single variable.
        self._optimizer._create_slots([var])

  def _create_slots_dynamic(self, var):
    """Create any missing ``DynamicSlot`` slots for DynamicVariable ``var``."""
    key = self._var_key(var)
    for slot_name in self._initial_vals:
      if key not in self._optimizer._slots[slot_name]:
        self._optimizer._slots[slot_name][key] = self._new_dynamic_slot(
            var, slot_name, 'DynamicSlot')

  def get_slot_names(self):
    return self._optimizer.get_slot_names()

  def get_slot(self, var, slot_name):
    """Return slot ``slot_name`` of ``var``; raises KeyError if absent."""
    key = self._var_key(var)
    return self._optimizer._slots[slot_name][key]

  @property
  def _slots(self):
    return self._optimizer._slots

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply dense and DynamicVariable gradients, then bump ``global_step``.

    Args:
      grads_and_vars: iterable of ``(gradient, variable)`` pairs. A one-shot
        iterator such as ``zip(grads, vars)`` is fine.
      global_step: required step variable, incremented by one.
      name: optional op name forwarded to the update ops.

    Returns:
      The ``global_step`` increment op. It runs after all variable updates
      and the wrapped optimizer's ``_finish``.

    Raises:
      AssertionError: if ``global_step`` is None.
    """
    # Validate before any graph mutation or monkey-patching happens.
    assert global_step is not None
    # The pairs are scanned twice below; materialize so an iterator is not
    # exhausted by the first scan.
    grads_and_vars = list(grads_and_vars)
    sparse_vars = [
        x for x in grads_and_vars if 'DynamicVariable' in str(type(x[1]))
    ]
    dense_vars = [
        x for x in grads_and_vars if 'DynamicVariable' not in str(type(x[1]))
    ]

    def _dummy_finish(update_ops, name_scope):
      return update_ops

    # Order the sparse updates after the loss when compute_gradients was
    # used; otherwise add no extra dependency instead of failing.
    loss = getattr(self, '_loss', None)
    loss_deps = [] if loss is None else [array_ops.identity(loss)]

    # Suppress the wrapped optimizer's _finish for the two partial updates so
    # it runs exactly once, below, over all update ops. Always restore it so
    # later calls still see the real _finish.
    finish_func = self._optimizer._finish
    self._optimizer._finish = _dummy_finish
    try:
      with ops.control_dependencies(loss_deps):
        sparse_grad_updates = self.apply_sparse_gradients(sparse_vars, name=name)

      # The wrapped optimizer rejects an empty list, e.g. when every
      # variable is a DynamicVariable.
      dense_grad_updates = None
      if dense_vars:
        dense_grad_updates = self._optimizer.apply_gradients(
            dense_vars, global_step=None, name=name)
    finally:
      self._optimizer._finish = finish_func

    grad_updates = []
    if sparse_grad_updates is not None:
      grad_updates = grad_updates + sparse_grad_updates
    if dense_grad_updates is not None:
      grad_updates = grad_updates + dense_grad_updates

    with ops.control_dependencies([finish_func(grad_updates, 'update')]):
      with ops.colocate_with(global_step):
        if isinstance(global_step, resource_variable_ops.BaseResourceVariable):
          # TODO(apassos): the implicit read in assign_add is slow; consider
          # making it less so.
          apply_updates = resource_variable_ops.assign_add_variable_op(
              global_step.handle,
              ops.convert_to_tensor(1, dtype=global_step.dtype),
              name=name)
        else:
          apply_updates = state_ops.assign_add(global_step, 1, name=name)

    if not context.executing_eagerly():
      if isinstance(apply_updates, ops.Tensor):
        apply_updates = apply_updates.op
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      if apply_updates not in train_op:
        train_op.append(apply_updates)

    return apply_updates

  def apply_sparse_gradients(self, grads_and_vars, global_step=None, name=None):
    """Update DynamicVariables (and their slots) from sparse gradients.

    Each non-``None`` gradient is expected to be IndexedSlices-like
    (``indices``, ``values``, ``dense_shape``). Its indices are deduplicated
    with ``tf.unique``. The variable and its slots are staged with
    ``to_static``, updated by the wrapped optimizer, then written back with
    ``to_dynamic``.

    Returns:
      The list of write-back ops, or ``None`` when no gradient is given.
    """
    # 1. Create slots and stage the touched rows.
    to_static_ops = []
    grad_list, var_list = [], []
    for g, v in grads_and_vars:
      if g is None:
        continue
      unique, indices = tf.unique(g.indices)
      grad_list.append(ops.IndexedSlices(g.values, indices, g.dense_shape))
      # TODO: Check multi-thread safety of DET
      with tf.control_dependencies([g.values]):
        to_static_ops.append(v.to_static(unique, False))
      var_list.append(v)
      key = self._var_key(v)
      for slot_name in self._initial_vals:
        if key not in self._optimizer._slots[slot_name]:
          # NOTE(review): the name does not include slot_name, so every slot
          # of `v` shares this base name; kept as-is since it likely
          # determines checkpoint names.
          slot_var_name = v._dummy_handle.op.name + '/' + self._optimizer._name
          slot = self._new_dynamic_slot(v, slot_name, slot_var_name)
          self._optimizer._slots[slot_name][key] = slot
        else:
          slot = self._optimizer._slots[slot_name][key]
        to_static_ops.append(slot.to_static(unique))

    if not grad_list:
      return None

    # 2. Call the wrapped tf optimizer on the staged buffers.
    with ops.control_dependencies(to_static_ops):
      train_op = self._optimizer.apply_gradients(
          zip(grad_list, var_list), global_step=global_step, name=name)

    # 3. Write buffers back to the dynamic variables and their slots.
    if not isinstance(train_op, list):
      train_op = [train_op]
    to_dynamic_ops = []
    with ops.control_dependencies(train_op):
      for v in var_list:
        key = self._var_key(v)
        to_dynamic_ops.append(v.to_dynamic())
        for slot_name in self._initial_vals:
          slot = self._optimizer._slots[slot_name][key]
          to_dynamic_ops.append(slot.to_dynamic())

    return to_dynamic_ops
277
+
278
+
279
class OptimizerWrapperV2(object):
  """Adapter that lets a Keras (OptimizerV2-style) optimizer update DynamicVariables.

  The wrapped optimizer is expected to expose ``_create_slots``,
  ``get_slot_names``, ``get_slot``, ``lr``, ``_iterations`` and a ``_slots``
  table keyed ``var_key -> slot_name -> slot``.

  Slots for DynamicVariables are themselves ``DynamicSlot`` DynamicVariables.
  The wrapper keeps a private step counter (``self._iterations``), which it
  swaps into the wrapped optimizer while applying sparse updates.
  """

  def __init__(self, optimizer):
    self._optimizer = optimizer
    # Let the wrapped optimizer create its slots for a throw-away variable so
    # the slot names and initial slot variables can be captured; the
    # throw-away entry is then dropped from the optimizer's slot table.
    if tf.__version__[0] == '1':
      unused = tf.Variable([0.0],
                           name='unused',
                           trainable=False,
                           use_resource=True)
    else:
      unused = tf.Variable([0.0], name='unused', trainable=False)
    self._optimizer._create_slots([unused])
    names, slots = [], []
    for slot_name in self._optimizer.get_slot_names():
      names.append(slot_name)
      slots.append(self._optimizer.get_slot(unused, slot_name))
    unused_key = self._var_key(unused)
    if unused_key in self._optimizer._slots:
      self._optimizer._slots.pop(unused_key)
    # slot name -> slot variable created for `unused`; used as the
    # initializer of every DynamicVariable slot.
    self._initial_vals = dict(zip(names, slots))
    # Step counter swapped into the wrapped optimizer during apply_gradients.
    self._iterations = tf.Variable(0)

  @property
  def lr(self):
    return self._optimizer.lr

  def _new_dynamic_slot(self, var, slot_name):
    """Create a ``DynamicSlot`` DynamicVariable slot for ``var``.

    For non-``hbm`` backends the slot reuses ``var``'s backend type and
    ``config_dict``.
    """
    initializer = self._initial_vals[slot_name]
    if var.backend_type == 'hbm':
      return DynamicVariable(
          dimension=var.dimension,
          initializer=initializer,
          name='DynamicSlot',
          trainable=False,
      )
    return DynamicVariable(
        dimension=var.dimension,
        initializer=initializer,
        var_type=var.backend_type,
        name='DynamicSlot',
        trainable=False,
        **var.config_dict)

  def _create_slots(self, vars):
    """Create optimizer slots for every variable in ``vars``."""
    for tmp_var in vars:
      if isinstance(tmp_var, DynamicVariable):
        self._create_slots_dynamic(tmp_var)
      else:
        # The wrapped optimizer's _create_slots takes a list of variables
        # (see the `[unused]` call in __init__), not a single variable.
        self._optimizer._create_slots([tmp_var])

  def _create_slots_dynamic(self, var):
    """Create any missing slots for DynamicVariable ``var``."""
    key = self._var_key(var)
    if key not in self._optimizer._slots:
      self._optimizer._slots[key] = {}
    var_slots = self._optimizer._slots[key]
    for slot_name in self._initial_vals:
      if slot_name not in var_slots:
        var_slots[slot_name] = self._new_dynamic_slot(var, slot_name)

  def _var_key(self, var):
    """Return the key of ``var`` in the ``_slots`` table.

    In graph mode this is the shared name, otherwise the unique id.
    """
    if hasattr(var, '_distributed_container'):
      var = var._distributed_container()
    if var._in_graph_mode:
      return var._shared_name
    return var._unique_id

  def get_slot_names(self):
    return self._optimizer.get_slot_names()

  def get_slot(self, var, name):
    return self._optimizer.get_slot(var, name)

  @property
  def _slots(self):
    return self._optimizer._slots

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply sparse gradients to DynamicVariables and their slots.

    Each non-``None`` gradient is expected to be IndexedSlices-like
    (``indices``, ``values``, ``dense_shape``). ``global_step`` is accepted
    for interface compatibility but is not used.

    Returns:
      A ``tf.group`` of the write-back ops, or ``None`` when no gradient is
      given.
    """
    # 1. Create slots and stage the touched rows.
    to_static_ops = []
    grad_list, var_list = [], []
    for g, v in grads_and_vars:
      if g is None:
        continue
      unique, indices = tf.unique(g.indices)
      grad_list.append(ops.IndexedSlices(g.values, indices, g.dense_shape))
      # TODO: Check multi-thread safety of DET
      # with tf.control_dependencies([g.values]):
      to_static_ops.append(v.to_static(unique))
      var_list.append(v)
      key = self._var_key(v)
      if key not in self._optimizer._slots:
        self._optimizer._slots[key] = {}
      var_slots = self._optimizer._slots[key]
      for slot_name in self._initial_vals:
        if slot_name not in var_slots:
          var_slots[slot_name] = self._new_dynamic_slot(v, slot_name)
        to_static_ops.append(var_slots[slot_name].to_static(unique))

    if not grad_list:
      return None

    # 2. Swap in the wrapper's own step counter for the sparse update and
    # always restore the optimizer's counter, even if the update fails.
    iterations = self._optimizer._iterations
    self._optimizer._iterations = self._iterations
    try:
      # 3. Call the wrapped tf optimizer on the staged buffers.
      with tf.control_dependencies(to_static_ops):
        train_op = self._optimizer.apply_gradients(
            zip(grad_list, var_list), name=name)
    finally:
      self._optimizer._iterations = iterations

    # 4. Write buffers back to the dynamic variables and their slots.
    to_dynamic_ops = []
    with tf.control_dependencies([train_op]):
      for v in var_list:
        var_slots = self._optimizer._slots[self._var_key(v)]
        to_dynamic_ops.append(v.to_dynamic())
        for slot_name in self._initial_vals:
          to_dynamic_ops.append(var_slots[slot_name].to_dynamic())
    return tf.group(to_dynamic_ops)
422
+
423
+
424
class SGD(object):
  """Minimal SGD optimizer for sparse (IndexedSlices-like) gradients.

  Each update subtracts ``lr * gradient`` from the rows addressed by the
  gradient, using the variable's ``scatter_sub``.
  """

  def __init__(self, lr):
    # Held in a tf.Variable (presumably so the rate can be changed in-graph).
    self._lr = tf.Variable(lr)

  @property
  def lr(self):
    """The learning-rate ``tf.Variable``."""
    return self._lr

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply ``v -= lr * g`` for every pair whose gradient is not ``None``.

    Gradients must expose ``values``, ``indices`` and ``dense_shape``.
    ``global_step`` and ``name`` are accepted for interface compatibility
    but are not used.

    Returns:
      A ``tf.group`` over all scatter-sub update ops.
    """
    updates = [
        var.scatter_sub(
            ops.IndexedSlices(grad.values * self._lr, grad.indices,
                              grad.dense_shape))
        for grad, var in grads_and_vars
        if grad is not None
    ]
    return tf.group(updates)