easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic. Click here for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,542 @@
1
+ #
2
+ # Copyright (c) 2022, NVIDIA CORPORATION.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ import json
18
+
19
+ import tensorflow as tf
20
+ from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
21
+ from sparse_operation_kit.experiment.communication import num_gpus
22
+ from tensorflow.python.eager import context
23
+ from tensorflow.python.framework import ops
24
+ # from tensorflow.python.ops import array_ops
25
+ from tensorflow.python.ops import resource_variable_ops
26
+ from tensorflow.python.ops.resource_variable_ops import ResourceVariable
27
+ from tensorflow.python.ops.resource_variable_ops import variable_accessed
28
+
29
+ # from tensorflow.python.util import object_identity
30
+
31
# Running counter used by DynamicVariable.__init__ to build default names
# ('sok_dynamic_Variable_<n>') for variables created without a name.
dynamic_variable_count = 0

# The original ResourceVariable.from_proto, captured when this module is
# imported, so DynamicVariable.from_proto can fall back to it for variables
# that are not SOK dummy variables.
_resource_var_from_proto = ResourceVariable.from_proto
34
+
35
+
36
+ class DynamicVariable(ResourceVariable):
37
+ """Abbreviated as ``sok.experiment.DynamicVariable``.
38
+
39
+ A variable that allocates memory dynamically.
40
+
41
+ Parameters
42
+ ----------
43
+ dimension: int
44
+ The last dimension of this variable (that is, the embedding vector
45
+ size of embedding table).
46
+
47
+ initializer: string
48
+ a string to specify how to initialize this variable.
49
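+ Currently, only "random" or a string holding a float
50
+ value (meaning a constant initializer) is supported. Default value is "random". A variable may also be passed; its value is read after its own initializer has run.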
51
+
52
+ var_type: string
53
+ a string selecting the DET or HKV backend. Defaults to "hbm"; HKV is selected with "hybrid".
54
+ If HKV is used as the backend, only tf.int64 is supported as key_type.
55
+ If HKV is used as the backend, set init_capacity and max_capacity to powers of 2.
56
+
57
+ key_type: dtype
58
+ specify the data type of indices. Unlike the static variable of
59
+ tensorflow, this variable is dynamically allocated and contains
60
+ a hash table inside it. So the data type of indices must be
61
+ specified to construct the hash table. Default value is tf.int64.
62
+
63
+ dtype: dtype
64
+ specify the data type of values. Default value is tf.float32.
65
+
66
+ Example
67
+ -------
68
+ .. code-block:: python
69
+
70
+ import numpy as np
71
+ import tensorflow as tf
72
+ import horovod.tensorflow as hvd
73
+ from sparse_operation_kit import experiment as sok
74
+
75
+ v = sok.DynamicVariable(dimension=3, initializer="13")
76
+ print("v.shape:", v.shape)
77
+ print("v.size:", v.size)
78
+
79
+ indices = tf.convert_to_tensor([0, 1, 2**40], dtype=tf.int64)
80
+
81
+ embedding = tf.nn.embedding_lookup(v, indices)
82
+ print("embedding:", embedding)
83
+ print("v.shape:", v.shape)
84
+ print("v.size:", v.size)
85
+ """
86
+
87
+ def __init__(self,
88
+ dimension,
89
+ initializer=None,
90
+ var_type=None,
91
+ name=None,
92
+ constraint=None,
93
+ trainable=True,
94
+ key_type=None,
95
+ dtype=None,
96
+ mode=None,
97
+ variable_def=None,
98
+ import_scope=None,
99
+ **kwargs):
100
+ self._indices = None
101
+ if variable_def is not None:
102
+ super(DynamicVariable, self)._init_from_proto(
103
+ variable_def, import_scope=import_scope, validate_shape=False)
104
+ g = ops.get_default_graph()
105
+ handle = g.as_graph_element(
106
+ ops.prepend_name_scope(
107
+ variable_def.variable_name, import_scope=import_scope),
108
+ allow_operation=False)
109
+ self._dimension = handle.op.get_attr('shape').dim[-1].size
110
+ self._key_type = handle.op.get_attr('key_type')
111
+ self._handle_type = handle.op.get_attr('dtype')
112
+ self._mode = None
113
+ self._config = {}
114
+ self._name = variable_def.variable_name.split(':')[0]
115
+ self._trainable = variable_def.trainable
116
+ self._dummy_handle = handle
117
+ self._handle = handle
118
+
119
+ # init op
120
+ init_op = g.as_graph_element(variable_def.initializer_name)
121
+ self._initializer_op = init_op
122
+
123
+ init_tf = init_op.control_inputs[0]
124
+ # init_dummy = init_op.control_inputs[1]
125
+
126
+ self._tf_handle = init_tf.inputs[0]
127
+ return
128
+
129
+ self._key_type = key_type if key_type is not None else tf.int64
130
+ self._handle_dtype = dtype if dtype is not None else tf.float32
131
+ self._dimension = dimension
132
+ self._mode = mode
133
+ self._config = json.dumps(kwargs)
134
+ self._config_dict = kwargs
135
+ if var_type == 'hybrid' and self._key_type != tf.int64:
136
+ raise NotImplementedError(
137
+ 'only key_type tf.int64 is supported in HKV backend')
138
+ if name is None:
139
+ global dynamic_variable_count
140
+ name = 'sok_dynamic_Variable_' + str(dynamic_variable_count)
141
+ dynamic_variable_count += 1
142
+ var_type = 'hbm' if var_type is None else var_type
143
+ self._var_type = var_type
144
+ self._base = super(DynamicVariable, self)
145
+ self._base.__init__(
146
+ initial_value=[[0.0] * dimension],
147
+ trainable=trainable,
148
+ name=name + '/proxy',
149
+ dtype=self._handle_dtype,
150
+ constraint=constraint,
151
+ distribute_strategy=None,
152
+ synchronization=None,
153
+ aggregation=None,
154
+ shape=[None, dimension],
155
+ )
156
+
157
+ with ops.init_scope():
158
+ # name = "DynamicVariable" if name is None else name
159
+ with ops.name_scope(name) as name_scope:
160
+ self._dummy_name = ops.name_from_scope_name(name_scope)
161
+ if context.executing_eagerly():
162
+ self._dummy_name = '%s_%d' % (name, ops.uid())
163
+ with ops.NullContextmanager():
164
+ shape = [None, dimension]
165
+ initializer = '' if initializer is None else initializer
166
+ self._initializer = initializer
167
+ handle = dynamic_variable_ops.dummy_var_handle(
168
+ container='DummyVariableContainer',
169
+ shared_name=self._dummy_name,
170
+ key_type=self._key_type,
171
+ dtype=self._handle_dtype,
172
+ shape=shape,
173
+ )
174
+ if type(initializer) is str:
175
+ init_op = dynamic_variable_ops.dummy_var_initialize(
176
+ handle,
177
+ initializer=initializer,
178
+ var_type=var_type,
179
+ unique_name=self._dummy_name,
180
+ key_type=self._key_type,
181
+ dtype=self._handle_dtype,
182
+ config=self._config,
183
+ )
184
+ else:
185
+ with tf.control_dependencies([initializer._initializer_op]):
186
+ initial_val = initializer.read_value()
187
+ init_op = dynamic_variable_ops.dummy_var_initialize(
188
+ handle,
189
+ initializer=initial_val,
190
+ var_type=var_type,
191
+ unique_name=self._dummy_name,
192
+ key_type=self._key_type,
193
+ dtype=self._handle_dtype,
194
+ config=self._config,
195
+ )
196
+ # TODO: Add is_initialized_op
197
+ # is_initialized_op = ops.convert_to_tensor(True)
198
+
199
+ self._tf_handle = self._handle
200
+ self._dummy_handle = handle
201
+ # Note that the default handle will be sok's handle
202
+ self._handle = self._dummy_handle
203
+ self._initializer_op = tf.group([self._initializer_op, init_op])
204
+ # self._is_initialized_op = tf.group([self._is_initialized_op, is_initialized_op])
205
+
206
+ handle_data = (
207
+ resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
208
+ .HandleData())
209
+ handle_data.is_set = True
210
+ handle_data.shape_and_type.append(
211
+ resource_variable_ops.cpp_shape_inference_pb2.CppShapeInferenceResult
212
+ .HandleShapeAndType(
213
+ shape=self.shape.as_proto(), dtype=self.dtype.as_datatype_enum))
214
+ resource_variable_ops._set_handle_shapes_and_types(
215
+ self._handle,
216
+ handle_data,
217
+ graph_mode=False if context.executing_eagerly() else True)
218
+
219
+ def is_static(self):
220
+ return self._handle is self._tf_handle
221
+
222
+ def to_static(self, indices, lookup_only=False):
223
+ if not self.is_static() and self._indices is None:
224
+ buffer = self.sparse_read(indices, lookup_only)
225
+ self._indices = indices
226
+ self._handle = self._tf_handle
227
+ return self.assign(buffer)
228
+ else:
229
+ raise RuntimeError('to_static() must be called in dynamic mode.')
230
+
231
+ def to_dynamic(self):
232
+ if self.is_static():
233
+ buffer = self.read_value()
234
+ sparse_delta = ops.IndexedSlices(buffer, self._indices, self.shape)
235
+ self._indices = None
236
+ self._handle = self._dummy_handle
237
+ return self.scatter_update(sparse_delta)
238
+ else:
239
+ raise RuntimeError('to_dynamic() must be called in static mode.')
240
+
241
+ @property
242
+ def name(self):
243
+ return self._dummy_handle.name
244
+
245
+ def __repr__(self):
246
+ if self.is_static():
247
+ return self._base.__repr__()
248
+ return "<sok.DynamicVariable '%s' shape=%s dtype=%s>" % (
249
+ self._dummy_name,
250
+ self.shape,
251
+ self.dtype.name,
252
+ )
253
+
254
+ @property
255
+ def size(self):
256
+ return dynamic_variable_ops.dummy_var_shape(
257
+ self._dummy_handle, key_type=self._key_type, dtype=self._handle_dtype)
258
+
259
+ @property
260
+ def indices(self):
261
+ return self._indices
262
+
263
+ @property
264
+ def dimension(self):
265
+ return self._dimension
266
+
267
+ def get_shape(self):
268
+ return [self._dimension]
269
+
270
+ @property
271
+ def key_type(self):
272
+ return self._key_type
273
+
274
+ @property
275
+ def handle_dtype(self):
276
+ return self._handle_dtype
277
+
278
+ @property
279
+ def backend_type(self):
280
+ return self._var_type
281
+
282
+ @property
283
+ def config_dict(self):
284
+ return self._config_dict
285
+
286
+ @property
287
+ def mode(self):
288
+ return self._mode
289
+
290
+ @property
291
+ def num_gpus(self):
292
+ return num_gpus()
293
+
294
+ @property
295
+ def initializer_str(self):
296
+ return self._initializer
297
+
298
+ def key_map(self, indices):
299
+ return indices
300
+
301
+ # -------------------------------------------------------------------------
302
+ # Methods supported in both static mode and dynamic mode
303
+ # -------------------------------------------------------------------------
304
+
305
+ def sparse_read(self, indices, name=None, lookup_only=False):
306
+ if self.is_static():
307
+ return self._base.sparse_read(indices, name)
308
+
309
+ variable_accessed(self)
310
+ if indices.dtype == tf.int32:
311
+ indices = tf.cast(indices, tf.int64)
312
+ return dynamic_variable_ops.dummy_var_sparse_read(
313
+ self._dummy_handle,
314
+ indices,
315
+ dtype=self._handle_dtype,
316
+ lookup_only=lookup_only)
317
+
318
+ def scatter_sub(self, sparse_delta, use_locking=False, name=None):
319
+ if self.is_static():
320
+ return self._base.scatter_sub(sparse_delta, use_locking, name)
321
+ if not isinstance(sparse_delta, ops.IndexedSlices):
322
+ raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
323
+ return dynamic_variable_ops.dummy_var_scatter_add(
324
+ self._dummy_handle,
325
+ sparse_delta.indices,
326
+ ops.convert_to_tensor(-sparse_delta.values, self.dtype),
327
+ )
328
+
329
+ def scatter_add(self, sparse_delta, use_locking=False, name=None):
330
+ if self.is_static():
331
+ return self._base.scatter_add(sparse_delta, use_locking, name)
332
+ if not isinstance(sparse_delta, ops.IndexedSlices):
333
+ raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
334
+ return dynamic_variable_ops.dummy_var_scatter_add(
335
+ self._dummy_handle,
336
+ sparse_delta.indices,
337
+ ops.convert_to_tensor(sparse_delta.values, self.dtype),
338
+ )
339
+
340
+ def scatter_update(self, sparse_delta, use_locking=False, name=None):
341
+ if self.is_static():
342
+ return self._base.scatter_update(sparse_delta, use_locking, name)
343
+ if not isinstance(sparse_delta, ops.IndexedSlices):
344
+ raise TypeError('sparse_delta is not IndexedSlices: %s' % sparse_delta)
345
+ return dynamic_variable_ops.dummy_var_scatter_update(
346
+ self._dummy_handle,
347
+ sparse_delta.indices,
348
+ ops.convert_to_tensor(sparse_delta.values, self.dtype),
349
+ )
350
+
351
+ # -------------------------------------------------------------------------
352
+ # Copy, pickling, serialization and shape-setting methods (several are unsupported in either mode)
353
+ # -------------------------------------------------------------------------
354
+
355
+ def __deepcopy__(self, *args, **kwargs):
356
+ raise NotImplementedError('__deepcopy__() is not supported.')
357
+
358
+ def __reduce__(self, *args, **kwargs):
359
+ raise NotImplementedError('__reduce__() is not supported.')
360
+
361
+ def to_proto(self, *args, **kwargs):
362
+ return super(DynamicVariable, self).to_proto(*args, **kwargs)
363
+ # raise NotImplementedError("to_proto() is not supported.")
364
+
365
+ @staticmethod
366
+ def from_proto(variable_def, import_scope=None):
367
+ if '/DummyVarHandle' in variable_def.variable_name:
368
+ return DynamicVariable(
369
+ dimension=0, variable_def=variable_def, import_scope=import_scope)
370
+ else:
371
+ return _resource_var_from_proto(variable_def, import_scope)
372
+ # raise NotImplementedError("from_proto() is not supported.")
373
+
374
+ def set_shape(self, *args, **kwargs):
375
+ raise NotImplementedError('set_shape() is not supported.')
376
+
377
+ # -------------------------------------------------------------------------
378
+ # Methods only supported in static mode
379
+ # -------------------------------------------------------------------------
380
+
381
def is_initialized(self, name=None):
  """Report this variable as initialized, in both static and dynamic mode.

  Always returns the Python constant ``True``. The static/dynamic branches
  that used to follow an unconditional ``return True`` could never run and
  have been removed.
  NOTE(review): presumably this exists so uninitialized-variable checks never
  flag a DynamicVariable; confirm before changing it.

  Args:
    name: unused. It is optional for signature compatibility with TF's
      ``is_initialized(name=None)``, and it is still accepted positionally or
      by keyword as before.

  Returns:
    True.
  """
  del name  # Unused; kept for interface compatibility.
  return True
387
+
388
def _read_variable_op(self):
  """Return the read op of the wrapped static variable ``self._base``.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError(
        '_read_variable_op() is not supported in dynamic mode.')
  return self._base._read_variable_op()
393
+
394
def value(self):
  """Return the value of the wrapped static variable ``self._base``.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('value() is not supported in dynamic mode.')
  base = self._base
  return base.value()
398
+
399
def _dense_var_to_tensor(self, *args, **kwargs):
  """Convert the wrapped static variable ``self._base`` to a dense tensor.

  All arguments are forwarded unchanged.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError(
        '_dense_var_to_tensor() is not supported in dynamic mode.')
  return self._base._dense_var_to_tensor(*args, **kwargs)
404
+
405
def _gather_saveables_for_checkpoint(self):
  """Return the checkpoint saveables of the wrapped static variable.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError(
        '_gather_saveables_for_checkpoint() is not supported in dynamic mode.')
  return self._base._gather_saveables_for_checkpoint()
410
+
411
def gather_nd(self, *args, **kwargs):
  """Run ``gather_nd`` on the wrapped static variable, forwarding all args.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('gather_nd() is not supported in dynamic mode.')
  return self._base.gather_nd(*args, **kwargs)
415
+
416
def assign_add(self, *args, **kwargs):
  """Run ``assign_add`` on the wrapped static variable, forwarding all args.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('assign_add() is not supported in dynamic mode.')
  return self._base.assign_add(*args, **kwargs)
420
+
421
def assign(self, *args, **kwargs):
  """Run ``assign`` on the wrapped static variable, forwarding all args.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('assign() is not supported in dynamic mode.')
  return self._base.assign(*args, **kwargs)
425
+
426
def scatter_max(self, *args, **kwargs):
  """Run ``scatter_max`` on the wrapped static variable, forwarding all args.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('scatter_max() is not supported in dynamic mode.')
  return self._base.scatter_max(*args, **kwargs)
430
+
431
def scatter_min(self, *args, **kwargs):
  """Run ``scatter_min`` on the wrapped static variable, forwarding all args.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('scatter_min() is not supported in dynamic mode.')
  return self._base.scatter_min(*args, **kwargs)
435
+
436
def scatter_mul(self, *args, **kwargs):
  """Run ``scatter_mul`` on the wrapped static variable, forwarding all args.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('scatter_mul() is not supported in dynamic mode.')
  return self._base.scatter_mul(*args, **kwargs)
440
+
441
def scatter_div(self, *args, **kwargs):
  """Run ``scatter_div`` on the wrapped static variable, forwarding all args.

  This method was previously published under the misspelled name
  ``scatter_dim``. As a result, the real ``scatter_div`` was never
  intercepted, and in static mode the call looked up a non-existent
  ``self._base.scatter_dim``.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('scatter_div() is not supported in dynamic mode.')
  return self._base.scatter_div(*args, **kwargs)


# Backward-compatible alias for the old, misspelled method name.
scatter_dim = scatter_div
445
+
446
def batch_scatter_update(self, *args, **kwargs):
  """Run ``batch_scatter_update`` on the wrapped static variable.

  All arguments are forwarded unchanged.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError(
        'batch_scatter_update() is not supported in dynamic mode.')
  return self._base.batch_scatter_update(*args, **kwargs)
451
+
452
def scatter_nd_sub(self, *args, **kwargs):
  """Run ``scatter_nd_sub`` on the wrapped static variable.

  All arguments are forwarded unchanged.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError(
        'scatter_nd_sub() is not supported in dynamic mode.')
  return self._base.scatter_nd_sub(*args, **kwargs)
457
+
458
def scatter_nd_update(self, *args, **kwargs):
  """Run ``scatter_nd_update`` on the wrapped static variable.

  All arguments are forwarded unchanged.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError(
        'scatter_nd_update() is not supported in dynamic mode.')
  return self._base.scatter_nd_update(*args, **kwargs)
463
+
464
def _strided_slice_assign(self, *args, **kwargs):
  """Run ``_strided_slice_assign`` on the wrapped static variable.

  All arguments are forwarded unchanged.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError(
        '_strided_slice_assign() is not supported in dynamic mode.')
  return self._base._strided_slice_assign(*args, **kwargs)
469
+
470
def __int__(self, *args, **kwargs):
  """Convert via the wrapped static variable's ``__int__``.

  Raises:
    NotImplementedError: if the variable is not in static mode.
  """
  if not self.is_static():
    raise NotImplementedError('__int__() is not supported in dynamic mode.')
  return self._base.__int__(*args, **kwargs)
474
+
475
+
476
# Route ResourceVariable.from_proto through DynamicVariable.from_proto. That
# function rebuilds protos whose name contains '/DummyVarHandle' as
# DynamicVariables and hands every other proto to _resource_var_from_proto.
# Accessing the static method through the class yields a plain function, so
# it is re-wrapped in staticmethod. Assigned bare, it would be an unbound
# method under Python 2 and would receive `self` when looked up through an
# instance under Python 3.
# NOTE(review): TF declares the original ResourceVariable.from_proto as a
# @staticmethod; confirm for the pinned TF version.
ResourceVariable.from_proto = staticmethod(DynamicVariable.from_proto)
477
+
478
+ # @tf.RegisterGradient("DummyVarSparseRead")
479
+ # def _SparseReadGrad(op, grad):
480
+ # """Gradient for sparse_read."""
481
+ # handle = op.inputs[0]
482
+ # indices = op.inputs[1]
483
+ # key_type = op.get_attr("key_type")
484
+ # dtype = op.get_attr("dtype")
485
+ # variable_shape = dynamic_variable_ops.dummy_var_shape(handle, key_type=key_type, dtype=dtype)
486
+ # size = array_ops.expand_dims(array_ops.size(indices), 0)
487
+ # values_shape = array_ops.concat([size, variable_shape[1:]], 0)
488
+ # grad = array_ops.reshape(grad, values_shape)
489
+ # indices = array_ops.reshape(indices, size)
490
+ # return (ops.IndexedSlices(grad, indices, variable_shape), None)
491
+
492
+
493
def export(var):
  """Abbreviated as ``sok.experiment.export``.

  Export the indices and value tensor from the given variable.

  Parameters
  ----------
  var: sok.DynamicVariable
      The variable to extract indices and values.

  Returns
  -------
  indices: tf.Tensor
      The indices of the given variable.

  values: tf.Tensor
      The values of the given variable.

  Raises
  ------
  TypeError
      If ``var`` is not a ``DynamicVariable``. Such inputs previously fell
      through the type guard without producing any tensors.
  """
  if not isinstance(var, DynamicVariable):
    raise TypeError('export() expects a DynamicVariable, got %s' % type(var))
  indices, values = dynamic_variable_ops.dummy_var_export(
      var.handle, key_type=var.key_type, dtype=var.handle_dtype)
  # The identity ops are placed on CPU, so the returned tensors are produced
  # on the host rather than wherever the export op ran.
  with tf.device('CPU'):
    indices = tf.identity(indices)
    values = tf.identity(values)
  return indices, values
518
+
519
+
520
def assign(var, indices, values):
  """Abbreviated as ``sok.experiment.assign``.

  Assign the indices and value tensor to the target variable.

  Parameters
  ----------
  var: sok.DynamicVariable
      The target variable of assign.

  indices: tf.Tensor
      Indices to be assigned to the variable; cast to the variable's key
      dtype (``var._key_type``) before assignment.

  values: tf.Tensor
      Values to be assigned to the variable.

  Returns
  -------
  The op/tensor returned by ``dynamic_variable_ops.dummy_var_assign``.

  Raises
  ------
  TypeError
      If ``var`` is not a ``DynamicVariable`` (previously such inputs were
      silently ignored and ``None`` was returned).
  """
  if not isinstance(var, DynamicVariable):
    raise TypeError('assign() expects a DynamicVariable, got %s' % type(var))
  # tf.cast returns a new tensor; the result must be bound, otherwise the
  # cast has no effect on what is passed to the assign op.
  indices = tf.cast(indices, var._key_type)
  return dynamic_variable_ops.dummy_var_assign(var.handle, indices, values)