easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic; see the registry's release page for more details.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
@@ -0,0 +1,154 @@
1
+ # -*- encoding:utf-8 -*-
2
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Defines functions common to multiple feature column files."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import six
23
+ from tensorflow.python.framework import dtypes
24
+ from tensorflow.python.framework import ops
25
+ from tensorflow.python.ops import array_ops
26
+ from tensorflow.python.ops import math_ops
27
+ from tensorflow.python.util import nest
28
+
29
+
30
def sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
  """Returns a [batch_size] Tensor with per-example sequence length.

  Args:
    sp_tensor: A `SparseTensor`; column 0 of its indices is used as the row
      (example) id and column 1 as the position within that row.
    num_elements: How many raw values make up one sequence entry. The per-row
      element count is divided by this and rounded up.

  Returns:
    An int64 Tensor with one length per row of `sp_tensor`; trailing rows that
    have no ids get length 0.
  """
  with ops.name_scope(None, 'sequence_length') as name_scope:
    indices = sp_tensor.indices
    row_ids = indices[:, 0]
    column_ids = indices[:, 1]
    # A value at column index c means its row holds at least c + 1 elements.
    element_counts = column_ids + array_ops.ones_like(column_ids)
    # Per-row element count is the largest (column index + 1) in that row.
    # NOTE(review): segment_max expects sorted row_ids (canonically ordered
    # sparse indices) -- confirm callers never pass unordered tensors.
    max_per_row = math_ops.segment_max(element_counts, segment_ids=row_ids)

    # Raw values are grouped num_elements at a time, rounding up.
    # Example: orig tensor [[1, 2], [3]] has col_ids = (0, 1, 0) and
    # row_ids = (0, 0, 1), so the per-row count is [2, 1]. With
    # num_elements = 2 this becomes ceil([2, 1] / 2) = [1, 1].
    grouped = math_ops.ceil(max_per_row / num_elements)
    seq_length = math_ops.cast(grouped, dtypes.int64)

    # If the final n rows have no ids, segment_max yields only
    # batch_size - n entries; append n zeros to cover the full batch.
    batch_size = array_ops.shape(sp_tensor)[:1]
    n_pad = batch_size - array_ops.shape(seq_length)[:1]
    padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
    return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
53
+
54
+
55
def assert_string_or_int(dtype, prefix):
  """Raises ValueError unless `dtype` is `dtypes.string` or an integer type.

  Args:
    dtype: The `DType` to validate.
    prefix: Text placed at the start of the error message.

  Raises:
    ValueError: if `dtype` is neither string nor integer.
  """
  not_string = dtype != dtypes.string
  if not_string and not dtype.is_integer:
    message = '{} dtype must be string or integer. dtype: {}.'.format(
        prefix, dtype)
    raise ValueError(message)
59
+
60
+
61
def assert_key_is_string(key):
  """Raises ValueError if `key` is not an instance of `six.string_types`.

  Args:
    key: The value to check.

  Raises:
    ValueError: if `key` is not a string.
  """
  if isinstance(key, six.string_types):
    return
  message = 'key must be a string. Got: type {}. Given key: {}.'.format(
      type(key), key)
  raise ValueError(message)
66
+
67
+
68
def check_default_value(shape, default_value, dtype, key):
  """Returns default value as tuple if it's valid, otherwise raises errors.

  This function verifies that `default_value` is compatible with both `shape`
  and `dtype`. If it is not compatible, it raises an error. If it is
  compatible, it casts default_value to a (nested) tuple and returns it.
  `key` is used only for error messages.

  Args:
    shape: An iterable of integers specifying the shape of the `Tensor`.
    default_value: If a single value is provided, the same value will be
      applied as the default value for every item. If an iterable of values
      is provided, its shape must be equal to the given `shape`. Objects with
      a callable `tolist()` (numpy arrays and numpy scalars) are converted to
      plain Python values first.
    dtype: The `DType` of the values. Integer values are always accepted;
      float values are only accepted when `dtype.is_floating` is True.
    key: Column name, used only for error messages.

  Returns:
    None if `default_value` is None. Otherwise a nested tuple of shape
    `shape` used as the default value; for an empty `shape` and a scalar
    `default_value`, the scalar itself is returned.

  Raises:
    ValueError: if `default_value` is an iterable but not compatible with
      `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
  """
  if default_value is None:
    return None

  # Convert numpy arrays *and* numpy scalars to Python values before the
  # scalar checks. Previously this ran after them, so e.g. np.float32(0.5)
  # became a Python float, skipped the scalar branches, was not a sequence,
  # and fell through to the TypeError below.
  if callable(getattr(default_value, 'tolist', None)):
    default_value = default_value.tolist()

  if isinstance(default_value, int):
    return _create_tuple(shape, default_value)

  if isinstance(default_value, float) and dtype.is_floating:
    return _create_tuple(shape, default_value)

  if nest.is_sequence(default_value):
    if not _is_shape_and_default_value_compatible(default_value, shape):
      raise ValueError(
          'The shape of default_value must be equal to given shape. '
          'default_value: {}, shape: {}, key: {}'.format(
              default_value, shape, key))
    # Check if the values in the list are all integers or are convertible to
    # floats.
    flat_values = nest.flatten(default_value)
    is_list_all_int = all(isinstance(v, int) for v in flat_values)
    is_list_has_float = any(isinstance(v, float) for v in flat_values)
    if is_list_all_int:
      return _as_tuple(default_value)
    if is_list_has_float and dtype.is_floating:
      return _as_tuple(default_value)
  raise TypeError('default_value must be compatible with dtype. '
                  'default_value: {}, dtype: {}, key: {}'.format(
                      default_value, dtype, key))
125
+
126
+
127
+ def _create_tuple(shape, value):
128
+ """Returns a tuple with given shape and filled with value."""
129
+ if shape:
130
+ return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
131
+ return value
132
+
133
+
134
def _as_tuple(value):
  """Recursively converts every nested sequence in `value` into a tuple.

  Non-sequence values (as judged by `nest.is_sequence`) are returned as is.
  """
  if nest.is_sequence(value):
    return tuple(_as_tuple(item) for item in value)
  return value
138
+
139
+
140
def _is_shape_and_default_value_compatible(default_value, shape):
  """Returns True if the nesting of `default_value` matches `shape` exactly.

  A non-sequence value matches only an empty shape. A sequence matches a
  non-empty shape whose first dimension equals its length and whose remaining
  dimensions match each of its elements recursively.
  """
  # Mismatch: a sequence with an empty shape, or a scalar with a non-empty
  # shape.
  if nest.is_sequence(default_value) != bool(shape):
    return False
  if not shape:
    return True
  if len(default_value) != shape[0]:
    return False
  inner_shape = shape[1:]
  return all(
      _is_shape_and_default_value_compatible(default_value[i], inner_shape)
      for i in range(shape[0]))
@@ -0,0 +1,329 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Higher level ops for building layers."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import functools
23
+
24
+ from tensorflow.python.framework import dtypes
25
+ from tensorflow.python.framework import ops
26
+ from tensorflow.python.ops import init_ops
27
+ from tensorflow.python.ops import nn
28
+ from tensorflow.python.ops import variable_scope
29
+
30
+
31
def layer_norm(inputs,
               center=True,
               scale=True,
               activation_fn=None,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               begin_norm_axis=1,
               begin_params_axis=-1,
               scope=None):
  """Adds a Layer Normalization layer.

  Based on the paper:

    "Layer Normalization"

    Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

    https://arxiv.org/abs/1607.06450.

  Can be used as a normalizer function for conv2d and fully_connected.

  Given a tensor `inputs` of rank `R`, moments are calculated and normalization
  is performed over axes `begin_norm_axis ... R - 1`. Scaling and centering,
  if requested, is performed over axes `begin_params_axis .. R - 1`.

  By default, `begin_norm_axis = 1` and `begin_params_axis = -1`,
  meaning that normalization is performed over all but the first axis
  (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable
  parameters are calculated for the rightmost axis (the `C` if `inputs` is
  `NHWC`). Scaling and recentering is performed via broadcast of the
  `beta` and `gamma` parameters with the normalized tensor.

  The shapes of `beta` and `gamma` are `inputs.shape[begin_params_axis:]`,
  and this part of the inputs' shape must be fully defined.

  The variance epsilon is 1e-12, except for float16 inputs where 1e-3 is used
  because 1e-12 is not representable in float16.

  Args:
    inputs: A tensor having rank `R`. The normalization is performed over
      axes `begin_norm_axis ... R - 1` and centering and scaling parameters
      are calculated over `begin_params_axis ... R - 1`.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables, either a
      single collection spec or a dict keyed by 'beta' / 'gamma'.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    begin_norm_axis: The first normalization dimension: normalization will be
      performed along dimensions `begin_norm_axis : rank(inputs)`. Negative
      values count from the end.
    begin_params_axis: The first parameter (beta, gamma) dimension: scale
      and centering parameters will have dimensions
      `begin_params_axis : rank(inputs)` and will be broadcast with the
      normalized inputs accordingly.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation, having the same
    shape and dtype as `inputs`.

  Raises:
    ValueError: If the rank of `inputs` is not known at graph build time,
      if `begin_params_axis` or `begin_norm_axis` is not < rank(inputs),
      or if `inputs.shape[begin_params_axis:]` is not fully defined at
      graph build time.
  """
  with variable_scope.variable_scope(
      scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.shape
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    if begin_norm_axis < 0:
      begin_norm_axis = inputs_rank + begin_norm_axis
    if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
      raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                       'must be < rank(inputs) (%d)' %
                       (begin_params_axis, begin_norm_axis, inputs_rank))
    params_shape = inputs_shape[begin_params_axis:]
    if not params_shape.is_fully_defined():
      raise ValueError(
          'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
          (inputs.name, begin_params_axis, inputs_shape))
    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = get_variable_collections(variables_collections, 'beta')
      beta = model_variable(
          'beta',
          shape=params_shape,
          dtype=dtype,
          initializer=init_ops.zeros_initializer(),
          collections=beta_collections,
          trainable=trainable)
    if scale:
      gamma_collections = get_variable_collections(variables_collections,
                                                   'gamma')
      gamma = model_variable(
          'gamma',
          shape=params_shape,
          dtype=dtype,
          initializer=init_ops.ones_initializer(),
          collections=gamma_collections,
          trainable=trainable)
    # Calculate the moments over axes begin_norm_axis .. R - 1; keep_dims so
    # mean/variance broadcast back against `inputs`.
    norm_axes = list(range(begin_norm_axis, inputs_rank))
    mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
    # Compute layer normalization using the batch_normalization function.
    # The epsilon is added to `variance` in the input dtype: 1e-12 underflows
    # to 0 in float16 (smallest positive float16 is ~6e-8), which would give
    # inf/NaN for zero-variance rows, so float16 uses a larger epsilon.
    variance_epsilon = 1e-12 if dtype != dtypes.float16 else 1e-3
    outputs = nn.batch_normalization(
        inputs,
        mean,
        variance,
        offset=beta,
        scale=gamma,
        variance_epsilon=variance_epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return collect_named_outputs(outputs_collections, sc.name, outputs)
156
+
157
+
158
def get_variable_collections(variables_collections, name):
  """Resolves which collections the variable called `name` should join.

  Args:
    variables_collections: Either a dict mapping variable names to
      collections, or a single collection spec shared by all variables.
    name: The variable name to look up when a dict is given.

  Returns:
    `variables_collections.get(name)` (None if absent) for a dict input;
    otherwise `variables_collections` unchanged.
  """
  if not isinstance(variables_collections, dict):
    return variables_collections
  return variables_collections.get(name)
164
+
165
+
166
def collect_named_outputs(collections, alias, outputs):
  """Tag `outputs` with `alias` and register it in `collections`.

  Handy for gathering end-points or summary targets inline, e.g.:

    logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
    assert 'inception_v3/logits' in logits.aliases

  Args:
    collections: A collection name or list of collection names. When falsy
      (None, empty list, empty string) nothing is recorded.
    alias: String appended to `outputs.aliases`, for example
      'inception_v3/conv1'.
    outputs: The tensor to tag and collect.

  Returns:
    `outputs` itself, so the call can be used inline.
  """
  # Guard clause: with no target collections there is nothing to do.
  if not collections:
    return outputs
  append_tensor_alias(outputs, alias)
  ops.add_to_collections(collections, outputs)
  return outputs
186
+
187
+
188
def append_tensor_alias(tensor, alias):
  """Append an alias to the list of aliases of the tensor.

  The aliases are kept in a plain Python attribute, `tensor.aliases`, which
  is created on first use. A single trailing '/' is removed from `alias`
  before it is recorded.

  Args:
    tensor: A `Tensor` (the code only needs an object that supports
      attribute assignment).
    alias: String to add to the list of aliases of the tensor. May be empty.

  Returns:
    The same `tensor` object, with `alias` appended to `tensor.aliases`.
  """
  # Remove one ending '/' if present. `endswith` also copes with an empty
  # alias, which indexing with `alias[-1]` would reject with an IndexError.
  if alias.endswith('/'):
    alias = alias[:-1]
  if hasattr(tensor, 'aliases'):
    tensor.aliases.append(alias)
  else:
    tensor.aliases = [alias]
  return tensor
206
+
207
+
208
def variable(name,
             shape=None,
             dtype=None,
             initializer=None,
             regularizer=None,
             trainable=True,
             collections=None,
             caching_device=None,
             device=None,
             partitioner=None,
             custom_getter=None,
             use_resource=None):
  """Return an existing variable with these parameters, or create a new one.

  Args:
    name: Name of the new or existing variable.
    shape: Shape of the new or existing variable.
    dtype: Type of the new or existing variable; forwarded unchanged to the
      getter (the original documentation states it defaults to `DT_FLOAT`).
    initializer: Initializer used if the variable has to be created.
    regularizer: A (Tensor -> Tensor or None) function, forwarded to the
      getter; per the original documentation its result on a newly created
      variable is added to `GraphKeys.REGULARIZATION_LOSSES`.
    trainable: Forwarded to the getter; per the original documentation,
      `True` also adds the variable to `GraphKeys.TRAINABLE_VARIABLES`.
    collections: Collection names the variable is added to. None means
      `[ops.GraphKeys.GLOBAL_VARIABLES]`. Duplicates are dropped; the order
      of the de-duplicated list is not specified.
    caching_device: Optional device string or function describing where the
      variable should be cached for reading; forwarded to the getter.
    device: Optional device (string or function). The getter is invoked
      inside `ops.device(device or '')`; it is not passed to the getter.
    partitioner: Optional callable accepting a fully defined `TensorShape`
      and dtype and returning per-axis partition counts; forwarded to the
      getter.
    custom_getter: Optional callable used in place of
      `variable_scope.get_variable`. It is called with the same arguments,
      plus `reuse` bound to the current variable scope's `reuse` flag.
    use_resource: If `True`, request a ResourceVariable; forwarded to the
      getter.

  Returns:
    Whatever the getter returns: the created or existing variable.
  """
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  # Drop duplicate collection names.
  unique_collections = list(set(collections))

  if custom_getter is None:
    getter = variable_scope.get_variable
  else:
    # Bind the enclosing scope's reuse setting so the custom getter sees it.
    scope_reuse = variable_scope.get_variable_scope().reuse
    getter = functools.partial(custom_getter, reuse=scope_reuse)

  with ops.device(device or ''):
    return getter(
        name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable,
        collections=unique_collections,
        caching_device=caching_device,
        partitioner=partitioner,
        use_resource=use_resource)
270
+
271
+
272
def model_variable(name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   trainable=True,
                   collections=None,
                   caching_device=None,
                   device=None,
                   partitioner=None,
                   custom_getter=None,
                   use_resource=None):
  """Return an existing model variable, or create a new one.

  This is a thin wrapper around `variable` that always registers the
  variable in `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES`,
  in addition to any caller-supplied collections.

  Args:
    name: Name of the new or existing variable.
    shape: Shape of the new or existing variable.
    dtype: Type of the new or existing variable (defaults to
      `dtypes.float32`).
    initializer: Initializer used if the variable has to be created.
    regularizer: A (Tensor -> Tensor or None) function, forwarded to
      `variable`.
    trainable: Forwarded to `variable`; per the original documentation,
      `True` also adds the variable to `GraphKeys.TRAINABLE_VARIABLES`.
    collections: Extra collection names (None or empty means none). The
      global-variables and model-variables collections are always appended.
    caching_device: Optional device string or function describing where the
      variable should be cached for reading.
    device: Optional device (string or function) to place the variable.
    partitioner: Optional callable accepting a fully defined `TensorShape`
      and dtype and returning per-axis partition counts.
    custom_getter: Optional callable replacing the internal `get_variable`;
      must accept the same signature.
    use_resource: If `True`, request a ResourceVariable.

  Returns:
    The created or existing variable, as returned by `variable`.
  """
  # Caller collections first, then the two mandatory ones.
  all_collections = list(collections or []) + [
      ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
  ]
  return variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=all_collections,
      caching_device=caching_device,
      device=device,
      partitioner=partitioner,
      custom_getter=custom_getter,
      use_resource=use_resource)
@@ -0,0 +1,14 @@
1
+ from tensorflow.python.framework import ops
2
+
3
+
4
class GraphKeys(ops.GraphKeys):
  """TensorFlow's standard graph collection keys plus EasyRec-specific ones.

  Every key from `ops.GraphKeys` is inherited unchanged; the attributes below
  add extra collection names.
  """

  # Collection names used for the rank service.
  RANK_SERVICE_FG_CONF = '__rank_service_fg_conf'
  RANK_SERVICE_INPUT = '__rank_service_input'
  RANK_SERVICE_OUTPUT = '__rank_service_output'
  RANK_SERVICE_EMBEDDING = '__rank_service_embedding'
  RANK_SERVICE_INPUT_SRC = '__rank_service_input_src'
  RANK_SERVICE_REPLACE_OP = '__rank_service_replace'
  RANK_SERVICE_SHAPE_OPT_FLAG = '__rank_service_shape_opt_flag'

  # For compatibility between RTP and EasyRec.
  RANK_SERVICE_FEATURE_NODE = '__rank_service_feature_node'