easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/layers/keras/einsum_dense.py
@@ -0,0 +1,598 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import re
+ import string
+
+ import tensorflow as tf
+ from tensorflow.python.keras import activations
+ from tensorflow.python.keras import constraints
+ from tensorflow.python.keras import initializers
+ from tensorflow.python.keras import regularizers
+ from tensorflow.python.keras.layers import Layer
+
+
+ class EinsumDense(Layer):
+   """A layer that uses `einsum` as the backing computation.
+
+   This layer can perform einsum calculations of arbitrary dimensionality.
+
+   Args:
+     equation: An equation describing the einsum to perform.
+       This equation must be a valid einsum string of the form
+       `ab,bc->ac`, `...ab,bc->...ac`, or
+       `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum
+       axis expression sequence.
+     output_shape: The expected shape of the output tensor
+       (excluding the batch dimension and any dimensions
+       represented by ellipses). You can specify `None` for any dimension
+       that is unknown or can be inferred from the input shape.
+     activation: Activation function to use. If you don't specify anything,
+       no activation is applied
+       (that is, a "linear" activation: `a(x) = x`).
+     bias_axes: A string containing the output dimension(s)
+       to apply a bias to. Each character in the `bias_axes` string
+       should correspond to a character in the output portion
+       of the `equation` string.
+     kernel_initializer: Initializer for the `kernel` weights matrix.
+     bias_initializer: Initializer for the bias vector.
+     kernel_regularizer: Regularizer function applied to the `kernel` weights
+       matrix.
+     bias_regularizer: Regularizer function applied to the bias vector.
+     kernel_constraint: Constraint function applied to the `kernel` weights
+       matrix.
+     bias_constraint: Constraint function applied to the bias vector.
+     lora_rank: Optional integer. If set, the layer's forward pass
+       will implement LoRA (Low-Rank Adaptation)
+       with the provided rank. LoRA sets the layer's kernel
+       to non-trainable and replaces it with a delta over the
+       original kernel, obtained via multiplying two lower-rank
+       trainable matrices
+       (the factorization happens on the last dimension).
+       This can be useful to reduce the
+       computation cost of fine-tuning large dense layers.
+       You can also enable LoRA on an existing
+       `EinsumDense` layer by calling `layer.enable_lora(rank)`.
+     **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
+
+   Examples:
+   **Biased dense layer with einsums**
+
+   This example shows how to instantiate a standard Keras dense layer using
+   einsum operations. This example is equivalent to
+   `keras.layers.Dense(64, use_bias=True)`.
+
+   >>> layer = tf.keras.layers.EinsumDense("ab,bc->ac",
+   ...                                     output_shape=64,
+   ...                                     bias_axes="c")
+   >>> input_tensor = tf.keras.Input(shape=[32])
+   >>> output_tensor = layer(input_tensor)
+   >>> output_tensor.shape
+   (None, 64)
+
+   **Applying a dense layer to a sequence**
+
+   This example shows how to instantiate a layer that applies the same dense
+   operation to every element in a sequence. Here, the `output_shape` has two
+   values (since there are two non-batch dimensions in the output); the first
+   dimension in the `output_shape` is `None`, because the sequence dimension
+   `b` has an unknown shape.
+
+   >>> layer = tf.keras.layers.EinsumDense("abc,cd->abd",
+   ...                                     output_shape=(None, 64),
+   ...                                     bias_axes="d")
+   >>> input_tensor = tf.keras.Input(shape=[32, 128])
+   >>> output_tensor = layer(input_tensor)
+   >>> output_tensor.shape
+   (None, 32, 64)
+
+   **Applying a dense layer to a sequence using ellipses**
+
+   This example shows how to instantiate a layer that applies the same dense
+   operation to every element in a sequence, but uses the ellipsis notation
+   instead of specifying the batch and sequence dimensions.
+
+   Because we are using ellipsis notation and have specified only one axis, the
+   `output_shape` arg is a single value. When instantiated in this way, the
+   layer can handle any number of sequence dimensions - including the case
+   where no sequence dimension exists.
+
+   >>> layer = tf.keras.layers.EinsumDense("...x,xy->...y",
+   ...                                     output_shape=64,
+   ...                                     bias_axes="y")
+   >>> input_tensor = tf.keras.Input(shape=[32, 128])
+   >>> output_tensor = layer(input_tensor)
+   >>> output_tensor.shape
+   (None, 32, 64)
+   """
+
+   def __init__(self,
+                equation,
+                output_shape,
+                activation=None,
+                bias_axes=None,
+                kernel_initializer='glorot_uniform',
+                bias_initializer='zeros',
+                kernel_regularizer=None,
+                bias_regularizer=None,
+                kernel_constraint=None,
+                bias_constraint=None,
+                lora_rank=None,
+                **kwargs):
+     super(EinsumDense, self).__init__(**kwargs)
+     self.equation = equation
+     if isinstance(output_shape, int):
+       self.partial_output_shape = (output_shape,)
+     else:
+       self.partial_output_shape = tuple(output_shape)
+     self.bias_axes = bias_axes
+     self.activation = activations.get(activation)
+     self.kernel_initializer = initializers.get(kernel_initializer)
+     self.bias_initializer = initializers.get(bias_initializer)
+     self.kernel_regularizer = regularizers.get(kernel_regularizer)
+     self.bias_regularizer = regularizers.get(bias_regularizer)
+     self.kernel_constraint = constraints.get(kernel_constraint)
+     self.bias_constraint = constraints.get(bias_constraint)
+     self.lora_rank = lora_rank
+     self.lora_enabled = False
+
+   def build(self, input_shape):
+     shape_data = _analyze_einsum_string(
+         self.equation,
+         self.bias_axes,
+         input_shape,
+         self.partial_output_shape,
+     )
+     kernel_shape, bias_shape, full_output_shape = shape_data
+     for i in range(len(kernel_shape)):
+       dim = kernel_shape[i]
+       if isinstance(dim, tf.Dimension):
+         kernel_shape[i] = dim.value
+     for i in range(len(bias_shape) if bias_shape is not None else 0):
+       dim = bias_shape[i]
+       if isinstance(dim, tf.Dimension):
+         bias_shape[i] = dim.value
+     for i in range(len(full_output_shape)):
+       dim = full_output_shape[i]
+       if isinstance(dim, tf.Dimension):
+         full_output_shape[i] = dim.value
+     self.full_output_shape = tuple(full_output_shape)
+     self._kernel = self.add_weight(
+         name='kernel',
+         shape=tuple(kernel_shape),
+         initializer=self.kernel_initializer,
+         regularizer=self.kernel_regularizer,
+         constraint=self.kernel_constraint,
+         dtype=self.dtype,
+         trainable=True,
+     )
+     if bias_shape is not None:
+       self.bias = self.add_weight(
+           name='bias',
+           shape=tuple(bias_shape),
+           initializer=self.bias_initializer,
+           regularizer=self.bias_regularizer,
+           constraint=self.bias_constraint,
+           dtype=self.dtype,
+           trainable=True,
+       )
+     else:
+       self.bias = None
+     self.built = True
+     if self.lora_rank:
+       self.enable_lora(self.lora_rank)
+
+   @property
+   def kernel(self):
+     if not self.built:
+       raise AttributeError(
+           'You must build the layer before accessing `kernel`.')
+     if self.lora_enabled:
+       return self._kernel + tf.matmul(self.lora_kernel_a, self.lora_kernel_b)
+     return self._kernel
+
+   def compute_output_shape(self, _):
+     return self.full_output_shape
+
+   def call(self, inputs, training=None):
+     x = tf.einsum(self.equation, inputs, self.kernel)
+     if self.bias is not None:
+       x += self.bias
+     if self.activation is not None:
+       x = self.activation(x)
+     return x
+
+   def enable_lora(self,
+                   rank,
+                   a_initializer='he_uniform',
+                   b_initializer='zeros'):
+     if self.kernel_constraint:
+       raise ValueError('Lora is incompatible with kernel constraints. '
+                        'In order to enable lora on this layer, remove the '
+                        '`kernel_constraint` argument.')
+     if not self.built:
+       raise ValueError("Cannot enable lora on a layer that isn't yet built.")
+     if self.lora_enabled:
+       raise ValueError('lora is already enabled. '
+                        'This can only be done once per layer.')
+     self._tracker.unlock()
+     self.lora_kernel_a = self.add_weight(
+         name='lora_kernel_a',
+         shape=(self.kernel.shape[:-1] + (rank,)),
+         initializer=initializers.get(a_initializer),
+         regularizer=self.kernel_regularizer,
+     )
+     self.lora_kernel_b = self.add_weight(
+         name='lora_kernel_b',
+         shape=(rank, self.kernel.shape[-1]),
+         initializer=initializers.get(b_initializer),
+         regularizer=self.kernel_regularizer,
+     )
+     self._kernel.trainable = False
+     self._tracker.lock()
+     self.lora_enabled = True
+     self.lora_rank = rank
+
+   def save_own_variables(self, store):
+     # Do nothing if the layer isn't yet built
+     if not self.built:
+       return
+     # The keys of the `store` will be saved as determined because the
+     # default ordering will change after quantization
+     kernel_value, kernel_scale = self._get_kernel_with_merged_lora()
+     target_variables = [kernel_value]
+     if self.bias is not None:
+       target_variables.append(self.bias)
+     for i, variable in enumerate(target_variables):
+       store[str(i)] = variable
+
+   def load_own_variables(self, store):
+     if not self.lora_enabled:
+       self._check_load_own_variables(store)
+     # Do nothing if the layer isn't yet built
+     if not self.built:
+       return
+     # The keys of the `store` will be saved as determined because the
+     # default ordering will change after quantization
+     target_variables = [self._kernel]
+     if self.bias is not None:
+       target_variables.append(self.bias)
+     for i, variable in enumerate(target_variables):
+       variable.assign(store[str(i)])
+     if self.lora_enabled:
+       self.lora_kernel_a.assign(tf.zeros(self.lora_kernel_a.shape))
+       self.lora_kernel_b.assign(tf.zeros(self.lora_kernel_b.shape))
+
+   def get_config(self):
+     base_config = super(EinsumDense, self).get_config()
+     config = {
+         'output_shape':
+             self.partial_output_shape,
+         'equation':
+             self.equation,
+         'activation':
+             activations.serialize(self.activation),
+         'bias_axes':
+             self.bias_axes,
+         'kernel_initializer':
+             initializers.serialize(self.kernel_initializer),
+         'bias_initializer':
+             initializers.serialize(self.bias_initializer),
+         'kernel_regularizer':
+             regularizers.serialize(self.kernel_regularizer),
+         'bias_regularizer':
+             regularizers.serialize(self.bias_regularizer),
+         'activity_regularizer':
+             regularizers.serialize(self.activity_regularizer),
+         'kernel_constraint':
+             constraints.serialize(self.kernel_constraint),
+         'bias_constraint':
+             constraints.serialize(self.bias_constraint),
+     }
+     if self.lora_rank:
+       config['lora_rank'] = self.lora_rank
+     config.update(base_config)
+     return config
+
+   def _check_load_own_variables(self, store):
+     all_vars = self._trainable_variables + self._non_trainable_variables
+     if len(store.keys()) != len(all_vars):
+       if len(all_vars) == 0 and not self.built:
+         raise ValueError(
+             "Layer '{name}' was never built "
+             "and thus it doesn't have any variables. "
+             'However the weights file lists {num_keys} '
+             'variables for this layer.\n'
+             'In most cases, this error indicates that either:\n\n'
+             '1. The layer is owned by a parent layer that '
+             'implements a `build()` method, but calling the '
+             "parent's `build()` method did NOT create the state of "
+             "the child layer '{name}'. A `build()` method "
+             'must create ALL state for the layer, including '
+             'the state of any children layers.\n\n'
+             '2. You need to implement '
+             'the `def build_from_config(self, config)` method '
+             "on layer '{name}', to specify how to rebuild "
+             'it during loading. '
+             'In this case, you might also want to implement the '
+             'method that generates the build config at saving time, '
+             '`def get_build_config(self)`. '
+             'The method `build_from_config()` is meant '
+             'to create the state '
+             'of the layer (i.e. its variables) upon deserialization.'.format(
+                 name=self.name, num_keys=len(store.keys())))
+       raise ValueError(
+           "Layer '{name}' expected {num_var} variables, but received "
+           '{num_key} variables during loading. '
+           'Expected: {names}'.format(
+               name=self.name,
+               num_var=len(all_vars),
+               num_key=len(store.keys()),
+               names=[v.name for v in all_vars]))
+
+   def _get_kernel_with_merged_lora(self):
+     kernel_value = self.kernel
+     kernel_scale = None
+     return kernel_value, kernel_scale
+
+
+ def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
+   """Analyzes an einsum string to determine the required weight shape."""
+   dot_replaced_string = re.sub(r'\.\.\.', '0', equation)
+
+   # This is the case where no ellipses are present in the string.
+   split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
+                           dot_replaced_string)
+   if split_string:
+     return _analyze_split_string(split_string, bias_axes, input_shape,
+                                  output_shape)
+
+   # This is the case where ellipses are present on the left.
+   split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
+                           dot_replaced_string)
+   if split_string:
+     return _analyze_split_string(
+         split_string, bias_axes, input_shape, output_shape, left_elided=True)
+
+   # This is the case where ellipses are present on the right.
+   split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
+                           dot_replaced_string)
+   if split_string:
+     return _analyze_split_string(split_string, bias_axes, input_shape,
+                                  output_shape)
+
+   raise ValueError(
+       "Invalid einsum equation '{equation}'. Equations must be in the form "
+       '[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
+           equation=equation))
+
+
+ def _analyze_split_string(split_string,
+                           bias_axes,
+                           input_shape,
+                           output_shape,
+                           left_elided=False):
+   """Analyze a pre-split einsum string to find the weight shape."""
+   input_spec = split_string.group(1)
+   weight_spec = split_string.group(2)
+   output_spec = split_string.group(3)
+   elided = len(input_shape) - len(input_spec)
+   if isinstance(output_shape, int):
+     output_shape = [output_shape]
+   else:
+     output_shape = list(output_shape)
+
+   output_shape.insert(0, input_shape[0])
+
+   if elided > 0 and left_elided:
+     for i in range(1, elided):
+       # We already inserted the 0th input dimension at dim 0, so we need
+       # to start at location 1 here.
+       output_shape.insert(1, input_shape[i])
+   elif elided > 0 and not left_elided:
+     for i in range(len(input_shape) - elided, len(input_shape)):
+       output_shape.append(input_shape[i])
+
+   if left_elided:
+     # If we have beginning dimensions elided, we need to use negative
+     # indexing to determine where in the input dimension our values are.
+     input_dim_map = {
+         dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
+     }
+     # Because we've constructed the full output shape already, we don't need
+     # to do negative indexing.
+     output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
+   else:
+     input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
+     output_dim_map = {dim: i for i, dim in enumerate(output_spec)}
+
+   for dim in input_spec:
+     input_shape_at_dim = input_shape[input_dim_map[dim]]
+     if dim in output_dim_map:
+       output_shape_at_dim = output_shape[output_dim_map[dim]]
+       if (output_shape_at_dim is not None and
+           output_shape_at_dim != input_shape_at_dim):
+         raise ValueError(
+             'Input shape and output shape do not match at shared '
+             "dimension '{dim}'. Input shape is {input_shape_at_dim}, "
+             'and output shape is {output_shape}.'.format(
+                 dim=dim,
+                 input_shape_at_dim=input_shape_at_dim,
+                 output_shape=output_shape[output_dim_map[dim]]))
+
+   for dim in output_spec:
+     if dim not in input_spec and dim not in weight_spec:
+       raise ValueError(
+           "Dimension '{dim}' was specified in the output "
+           "'{output_spec}' but has no corresponding dim in the input "
+           "spec '{input_spec}' or weight spec '{output_spec}'".format(
+               dim=dim, output_spec=output_spec, input_spec=input_spec))
+
+   weight_shape = []
+   for dim in weight_spec:
+     if dim in input_dim_map:
+       weight_shape.append(input_shape[input_dim_map[dim]])
+     elif dim in output_dim_map:
+       weight_shape.append(output_shape[output_dim_map[dim]])
+     else:
+       raise ValueError(
+           "Weight dimension '{dim}' did not have a match in either "
+           "the input spec '{input_spec}' or the output "
+           "spec '{output_spec}'. For this layer, the weight must "
+           'be fully specified.'.format(
+               dim=dim, input_spec=input_spec, output_spec=output_spec))
+
+   if bias_axes is not None:
+     num_left_elided = elided if left_elided else 0
+     idx_map = {
+         char: output_shape[i + num_left_elided]
+         for i, char in enumerate(output_spec)
+     }
+
+     for char in bias_axes:
+       if char not in output_spec:
+         raise ValueError(
+             "Bias dimension '{char}' was requested, but is not part "
+             "of the output spec '{output_spec}'".format(
+                 char=char, output_spec=output_spec))
+
+     first_bias_location = min([output_spec.find(char) for char in bias_axes])
+     bias_output_spec = output_spec[first_bias_location:]
+
+     bias_shape = [
+         idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
+     ]
+
+     if not left_elided:
+       for _ in range(elided):
+         bias_shape.append(1)
+   else:
+     bias_shape = None
+
+   return weight_shape, bias_shape, output_shape
+
+
+ def _analyze_quantization_info(equation, input_shape):
+
+   def get_specs(equation, input_shape):
+     possible_labels = string.ascii_letters
+     dot_replaced_string = re.sub(r'\.\.\.', '0', equation)
+
+     # This is the case where no ellipses are present in the string.
+     split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
+                             dot_replaced_string)
+     if split_string is not None:
+       input_spec = split_string.group(1)
+       weight_spec = split_string.group(2)
+       output_spec = split_string.group(3)
+       return input_spec, weight_spec, output_spec
+
+     # This is the case where ellipses are present on the left.
+     split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
+                             dot_replaced_string)
+     if split_string is not None:
+       input_spec = split_string.group(1)
+       weight_spec = split_string.group(2)
+       output_spec = split_string.group(3)
+       elided = len(input_shape) - len(input_spec)
+       possible_labels = sorted(
+           set(possible_labels) - set(input_spec) - set(weight_spec) -
+           set(output_spec))
+       # Pad labels on the left to `input_spec` and `output_spec`
+       for i in range(elided):
+         input_spec = possible_labels[i] + input_spec
+         output_spec = possible_labels[i] + output_spec
+       return input_spec, weight_spec, output_spec
+
+     # This is the case where ellipses are present on the right.
+     split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
+                             dot_replaced_string)
+     if split_string is not None:
+       input_spec = split_string.group(1)
+       weight_spec = split_string.group(2)
+       output_spec = split_string.group(3)
+       elided = len(input_shape) - len(input_spec)
+       possible_labels = sorted(
+           set(possible_labels) - set(input_spec) - set(weight_spec) -
+           set(output_spec))
+       # Pad labels on the right to `input_spec` and `output_spec`
+       for i in range(elided):
+         input_spec = input_spec + possible_labels[i]
+         output_spec = output_spec + possible_labels[i]
+       return input_spec, weight_spec, output_spec
+
+     raise ValueError(
+         "Invalid einsum equation '{equation}'. Equations must be in the "
+         'form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
+             equation=equation))
+
+   input_spec, weight_spec, output_spec = get_specs(equation, input_shape)
+
+   # Determine the axes that should be reduced by the quantizer
+   input_reduced_axes = []
+   weight_reduced_axes = []
+   for i, label in enumerate(input_spec):
+     index = output_spec.find(label)
+     if index == -1:
+       input_reduced_axes.append(i)
+   for i, label in enumerate(weight_spec):
+     index = output_spec.find(label)
+     if index == -1:
+       weight_reduced_axes.append(i)
+
+   # Determine the axes of `ops.expand_dims`
+   input_expand_axes = []
+   weight_expand_axes = []
+   for i, label in enumerate(output_spec):
+     index_input = input_spec.find(label)
+     index_weight = weight_spec.find(label)
+     if index_input == -1:
+       input_expand_axes.append(i)
+     if index_weight == -1:
+       weight_expand_axes.append(i)
+
+   # Determine the axes of `ops.transpose`
+   input_transpose_axes = []
+   weight_transpose_axes = []
+   for i, label in enumerate(output_spec):
+     index_input = input_spec.find(label)
+     index_weight = weight_spec.find(label)
+     if index_input != -1:
+       input_transpose_axes.append(index_input)
+     if index_weight != -1:
+       weight_transpose_axes.append(index_weight)
+   # Postprocess the information:
+   # 1. Add dummy axes (1) to transpose_axes
+   # 2. Add axis to squeeze_axes if 1. failed
+   input_squeeze_axes = []
+   weight_squeeze_axes = []
+   for ori_index in input_reduced_axes:
+     try:
+       index = input_expand_axes.pop(0)
+     except IndexError:
+       input_squeeze_axes.append(ori_index)
+     input_transpose_axes.insert(index, ori_index)
+   for ori_index in weight_reduced_axes:
+     try:
+       index = weight_expand_axes.pop(0)
+     except IndexError:
+       weight_squeeze_axes.append(ori_index)
+     weight_transpose_axes.insert(index, ori_index)
+   # Prepare equation for `einsum_with_inputs_gradient`
+   custom_gradient_equation = '{output_spec},{weight_spec}->{input_spec}'.format(
+       output_spec=output_spec, input_spec=input_spec, weight_spec=weight_spec)
+   weight_reverse_transpose_axes = [
+       i for (_, i) in sorted((v, i)
+                              for (i, v) in enumerate(weight_transpose_axes))
+   ]
+   return (
+       input_reduced_axes,
+       weight_reduced_axes,
+       input_transpose_axes,
+       weight_transpose_axes,
+       input_expand_axes,
+       weight_expand_axes,
+       input_squeeze_axes,
+       weight_squeeze_axes,
+       custom_gradient_equation,
+       weight_reverse_transpose_axes,
+   )
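
The shape inference above is easiest to see on a concrete case. Below is a small editorial sketch, not part of the package diff: it reuses the sequence example from the EinsumDense docstring and calls the private helper `_analyze_einsum_string` directly, assuming a TensorFlow build that still provides `tensorflow.python.keras` (which this module's imports require) and that the module is importable as `easy_rec.python.layers.keras.einsum_dense`.

from easy_rec.python.layers.keras.einsum_dense import _analyze_einsum_string

# Equation "abc,cd->abd": project the last axis of a [batch, 32, 128] input
# to 64, with a bias on the projected axis 'd'.
kernel_shape, bias_shape, full_output_shape = _analyze_einsum_string(
    equation='abc,cd->abd',
    bias_axes='d',
    input_shape=[None, 32, 128],  # batch, sequence, feature
    output_shape=(None, 64))      # sequence length unknown, projected dim is 64
# kernel_shape      -> [128, 64]         ('c' comes from the input, 'd' from the output)
# bias_shape        -> [64]              (broadcast along the output axis 'd')
# full_output_shape -> [None, None, 64]  (batch and sequence dims stay unknown)

The kernel and bias created in `build()` take exactly these shapes; `call()` then evaluates `tf.einsum('abc,cd->abd', inputs, kernel)` and adds the bias.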
easy_rec/python/layers/keras/embedding.py
@@ -0,0 +1,81 @@
+ # -*- encoding:utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ """Fused embedding layer."""
+ import tensorflow as tf
+ from tensorflow.python.keras.layers import Embedding
+ from tensorflow.python.keras.layers import Layer
+
+
+ def _combine(embeddings, weights, comb_fn):
+   # embeddings shape: [B, N, D]
+   if callable(comb_fn):
+     return comb_fn(embeddings, axis=1)
+   if weights is None:
+     return tf.reduce_mean(embeddings, axis=1)
+   if isinstance(weights, tf.SparseTensor):
+     if weights.dtype == tf.string:
+       weights = tf.sparse.to_dense(weights, default_value='0')
+       weights = tf.string_to_number(weights)
+     else:
+       weights = tf.sparse.to_dense(weights, default_value=0.0)
+   sum_weights = tf.reduce_sum(weights, axis=1, keepdims=True)
+   weights = tf.expand_dims(weights / sum_weights, axis=-1)
+   return tf.reduce_sum(embeddings * weights, axis=1)
+
+
+ class EmbeddingLayer(Layer):
+
+   def __init__(self, params, name='embedding_layer', reuse=None, **kwargs):
+     super(EmbeddingLayer, self).__init__(name=name, **kwargs)
+     params.check_required(['vocab_size', 'embedding_dim'])
+     vocab_size = int(params.vocab_size)
+     combiner = params.get_or_default('combiner', 'weight')
+     if combiner == 'mean':
+       self.combine_fn = tf.reduce_mean
+     elif combiner == 'sum':
+       self.combine_fn = tf.reduce_sum
+     elif combiner == 'max':
+       self.combine_fn = tf.reduce_max
+     elif combiner == 'min':
+       self.combine_fn = tf.reduce_min
+     elif combiner == 'weight':
+       self.combine_fn = 'weight'
+     else:
+       raise ValueError('unsupported embedding combiner: ' + combiner)
+     self.embed_dim = int(params.embedding_dim)
+     self.embedding = Embedding(vocab_size, self.embed_dim)
+     self.do_concat = params.get_or_default('concat', True)
+
+   def call(self, inputs, training=None, **kwargs):
+     inputs, weights = inputs
+     # Merge the inputs of all features into a single index tensor.
+     flat_inputs = [tf.reshape(input_field, [-1]) for input_field in inputs]
+     all_indices = tf.concat(flat_inputs, axis=0)
+     # Do one embedding lookup on the shared embedding table.
+     all_embeddings = self.embedding(all_indices)
+     is_multi = []
+     # Compute the embedding of each feature.
+     split_sizes = []
+     for input_field in inputs:
+       assert input_field.shape.ndims <= 2, 'dims of embedding layer input must be <= 2'
+       input_shape = tf.shape(input_field)
+       size = input_shape[0]
+       if input_field.shape.ndims > 1:
+         size *= input_shape[-1]
+         is_multi.append(True)
+       else:
+         is_multi.append(False)
+       split_sizes.append(size)
+     embeddings = tf.split(all_embeddings, split_sizes, axis=0)
+     for i in range(len(embeddings)):
+       if is_multi[i]:
+         batch_size = tf.shape(inputs[i])[0]
+         embeddings[i] = tf.cond(
+             tf.equal(tf.size(embeddings[i]), 0),
+             lambda: tf.zeros([batch_size, self.embed_dim]), lambda: _combine(
+                 tf.reshape(embeddings[i], [batch_size, -1, self.embed_dim]),
+                 weights[i], self.combine_fn))
+     if self.do_concat:
+       embeddings = tf.concat(embeddings, axis=-1)
+     print('Embedding layer:', self.name, embeddings)
+     return embeddings
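
The 'weight' combiner path of `_combine` normalizes the per-value weights and takes a weighted sum over the value axis; `EmbeddingLayer.call` applies it to each multi-valued feature after splitting the shared lookup result back per feature. A minimal editorial sketch of that pooling, not part of the package diff, assuming TF 2.x eager execution and that the module is importable as `easy_rec.python.layers.keras.embedding`:

import tensorflow as tf

from easy_rec.python.layers.keras.embedding import _combine

# One example with 3 embedded values of dimension 4 and per-value weights.
emb = tf.reshape(tf.range(12, dtype=tf.float32), [1, 3, 4])  # [batch, N, D]
w = tf.sparse.from_dense(tf.constant([[3.0, 1.0, 0.0]]))     # weights for the N values
pooled = _combine(emb, w, 'weight')                          # weighted mean over axis N
# pooled == [[1., 2., 3., 4.]], i.e. 0.75 * emb[:, 0] + 0.25 * emb[:, 1]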