easy_cs_rec_custommodel-0.8.6-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of easy-cs-rec-custommodel might be problematic.
- easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
- easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
- easy_rec/__init__.py +114 -0
- easy_rec/python/__init__.py +0 -0
- easy_rec/python/builders/__init__.py +0 -0
- easy_rec/python/builders/hyperparams_builder.py +78 -0
- easy_rec/python/builders/loss_builder.py +333 -0
- easy_rec/python/builders/optimizer_builder.py +211 -0
- easy_rec/python/builders/strategy_builder.py +44 -0
- easy_rec/python/compat/__init__.py +0 -0
- easy_rec/python/compat/adam_s.py +245 -0
- easy_rec/python/compat/array_ops.py +229 -0
- easy_rec/python/compat/dynamic_variable.py +542 -0
- easy_rec/python/compat/early_stopping.py +653 -0
- easy_rec/python/compat/embedding_ops.py +162 -0
- easy_rec/python/compat/embedding_parallel_saver.py +316 -0
- easy_rec/python/compat/estimator_train.py +116 -0
- easy_rec/python/compat/exporter.py +473 -0
- easy_rec/python/compat/feature_column/__init__.py +0 -0
- easy_rec/python/compat/feature_column/feature_column.py +3675 -0
- easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
- easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
- easy_rec/python/compat/feature_column/utils.py +154 -0
- easy_rec/python/compat/layers.py +329 -0
- easy_rec/python/compat/ops.py +14 -0
- easy_rec/python/compat/optimizers.py +619 -0
- easy_rec/python/compat/queues.py +311 -0
- easy_rec/python/compat/regularizers.py +208 -0
- easy_rec/python/compat/sok_optimizer.py +440 -0
- easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
- easy_rec/python/compat/weight_decay_optimizers.py +475 -0
- easy_rec/python/core/__init__.py +0 -0
- easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
- easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
- easy_rec/python/core/learning_schedules.py +228 -0
- easy_rec/python/core/metrics.py +402 -0
- easy_rec/python/core/sampler.py +844 -0
- easy_rec/python/eval.py +102 -0
- easy_rec/python/export.py +150 -0
- easy_rec/python/feature_column/__init__.py +0 -0
- easy_rec/python/feature_column/feature_column.py +664 -0
- easy_rec/python/feature_column/feature_group.py +89 -0
- easy_rec/python/hpo/__init__.py +0 -0
- easy_rec/python/hpo/emr_hpo.py +140 -0
- easy_rec/python/hpo/generate_hpo_sql.py +71 -0
- easy_rec/python/hpo/pai_hpo.py +297 -0
- easy_rec/python/inference/__init__.py +0 -0
- easy_rec/python/inference/csv_predictor.py +189 -0
- easy_rec/python/inference/hive_parquet_predictor.py +200 -0
- easy_rec/python/inference/hive_predictor.py +166 -0
- easy_rec/python/inference/odps_predictor.py +70 -0
- easy_rec/python/inference/parquet_predictor.py +147 -0
- easy_rec/python/inference/parquet_predictor_v2.py +147 -0
- easy_rec/python/inference/predictor.py +621 -0
- easy_rec/python/inference/processor/__init__.py +0 -0
- easy_rec/python/inference/processor/test.py +170 -0
- easy_rec/python/inference/vector_retrieve.py +124 -0
- easy_rec/python/input/__init__.py +0 -0
- easy_rec/python/input/batch_tfrecord_input.py +117 -0
- easy_rec/python/input/criteo_binary_reader.py +259 -0
- easy_rec/python/input/criteo_input.py +107 -0
- easy_rec/python/input/csv_input.py +175 -0
- easy_rec/python/input/csv_input_ex.py +72 -0
- easy_rec/python/input/csv_input_v2.py +68 -0
- easy_rec/python/input/datahub_input.py +320 -0
- easy_rec/python/input/dummy_input.py +58 -0
- easy_rec/python/input/hive_input.py +123 -0
- easy_rec/python/input/hive_parquet_input.py +140 -0
- easy_rec/python/input/hive_rtp_input.py +174 -0
- easy_rec/python/input/input.py +1064 -0
- easy_rec/python/input/kafka_dataset.py +144 -0
- easy_rec/python/input/kafka_input.py +235 -0
- easy_rec/python/input/load_parquet.py +317 -0
- easy_rec/python/input/odps_input.py +101 -0
- easy_rec/python/input/odps_input_v2.py +110 -0
- easy_rec/python/input/odps_input_v3.py +132 -0
- easy_rec/python/input/odps_rtp_input.py +187 -0
- easy_rec/python/input/odps_rtp_input_v2.py +104 -0
- easy_rec/python/input/parquet_input.py +397 -0
- easy_rec/python/input/parquet_input_v2.py +180 -0
- easy_rec/python/input/parquet_input_v3.py +203 -0
- easy_rec/python/input/rtp_input.py +225 -0
- easy_rec/python/input/rtp_input_v2.py +145 -0
- easy_rec/python/input/tfrecord_input.py +100 -0
- easy_rec/python/layers/__init__.py +0 -0
- easy_rec/python/layers/backbone.py +571 -0
- easy_rec/python/layers/capsule_layer.py +176 -0
- easy_rec/python/layers/cmbf.py +390 -0
- easy_rec/python/layers/common_layers.py +192 -0
- easy_rec/python/layers/dnn.py +87 -0
- easy_rec/python/layers/embed_input_layer.py +25 -0
- easy_rec/python/layers/fm.py +26 -0
- easy_rec/python/layers/input_layer.py +396 -0
- easy_rec/python/layers/keras/__init__.py +34 -0
- easy_rec/python/layers/keras/activation.py +114 -0
- easy_rec/python/layers/keras/attention.py +267 -0
- easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
- easy_rec/python/layers/keras/blocks.py +262 -0
- easy_rec/python/layers/keras/bst.py +119 -0
- easy_rec/python/layers/keras/custom_ops.py +250 -0
- easy_rec/python/layers/keras/data_augment.py +133 -0
- easy_rec/python/layers/keras/din.py +67 -0
- easy_rec/python/layers/keras/einsum_dense.py +598 -0
- easy_rec/python/layers/keras/embedding.py +81 -0
- easy_rec/python/layers/keras/fibinet.py +251 -0
- easy_rec/python/layers/keras/interaction.py +416 -0
- easy_rec/python/layers/keras/layer_norm.py +364 -0
- easy_rec/python/layers/keras/mask_net.py +166 -0
- easy_rec/python/layers/keras/multi_head_attention.py +717 -0
- easy_rec/python/layers/keras/multi_task.py +125 -0
- easy_rec/python/layers/keras/numerical_embedding.py +376 -0
- easy_rec/python/layers/keras/ppnet.py +194 -0
- easy_rec/python/layers/keras/transformer.py +192 -0
- easy_rec/python/layers/layer_norm.py +51 -0
- easy_rec/python/layers/mmoe.py +83 -0
- easy_rec/python/layers/multihead_attention.py +162 -0
- easy_rec/python/layers/multihead_cross_attention.py +749 -0
- easy_rec/python/layers/senet.py +73 -0
- easy_rec/python/layers/seq_input_layer.py +134 -0
- easy_rec/python/layers/sequence_feature_layer.py +249 -0
- easy_rec/python/layers/uniter.py +301 -0
- easy_rec/python/layers/utils.py +248 -0
- easy_rec/python/layers/variational_dropout_layer.py +130 -0
- easy_rec/python/loss/__init__.py +0 -0
- easy_rec/python/loss/circle_loss.py +82 -0
- easy_rec/python/loss/contrastive_loss.py +79 -0
- easy_rec/python/loss/f1_reweight_loss.py +38 -0
- easy_rec/python/loss/focal_loss.py +93 -0
- easy_rec/python/loss/jrc_loss.py +128 -0
- easy_rec/python/loss/listwise_loss.py +161 -0
- easy_rec/python/loss/multi_similarity.py +68 -0
- easy_rec/python/loss/pairwise_loss.py +307 -0
- easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
- easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
- easy_rec/python/main.py +878 -0
- easy_rec/python/model/__init__.py +0 -0
- easy_rec/python/model/autoint.py +73 -0
- easy_rec/python/model/cmbf.py +47 -0
- easy_rec/python/model/collaborative_metric_learning.py +182 -0
- easy_rec/python/model/custom_model.py +323 -0
- easy_rec/python/model/dat.py +138 -0
- easy_rec/python/model/dbmtl.py +116 -0
- easy_rec/python/model/dcn.py +70 -0
- easy_rec/python/model/deepfm.py +106 -0
- easy_rec/python/model/dlrm.py +73 -0
- easy_rec/python/model/dropoutnet.py +207 -0
- easy_rec/python/model/dssm.py +154 -0
- easy_rec/python/model/dssm_senet.py +143 -0
- easy_rec/python/model/dummy_model.py +48 -0
- easy_rec/python/model/easy_rec_estimator.py +739 -0
- easy_rec/python/model/easy_rec_model.py +467 -0
- easy_rec/python/model/esmm.py +242 -0
- easy_rec/python/model/fm.py +63 -0
- easy_rec/python/model/match_model.py +357 -0
- easy_rec/python/model/mind.py +445 -0
- easy_rec/python/model/mmoe.py +70 -0
- easy_rec/python/model/multi_task_model.py +303 -0
- easy_rec/python/model/multi_tower.py +62 -0
- easy_rec/python/model/multi_tower_bst.py +190 -0
- easy_rec/python/model/multi_tower_din.py +130 -0
- easy_rec/python/model/multi_tower_recall.py +68 -0
- easy_rec/python/model/pdn.py +203 -0
- easy_rec/python/model/ple.py +120 -0
- easy_rec/python/model/rank_model.py +485 -0
- easy_rec/python/model/rocket_launching.py +203 -0
- easy_rec/python/model/simple_multi_task.py +54 -0
- easy_rec/python/model/uniter.py +46 -0
- easy_rec/python/model/wide_and_deep.py +121 -0
- easy_rec/python/ops/1.12/incr_record.so +0 -0
- easy_rec/python/ops/1.12/kafka.so +0 -0
- easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.12/libembed_op.so +0 -0
- easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
- easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
- easy_rec/python/ops/1.15/incr_record.so +0 -0
- easy_rec/python/ops/1.15/kafka.so +0 -0
- easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
- easy_rec/python/ops/1.15/libembed_op.so +0 -0
- easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so +0 -0
- easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
- easy_rec/python/ops/1.15/librdkafka.so +0 -0
- easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
- easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
- easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
- easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
- easy_rec/python/ops/2.12/libload_embed.so +0 -0
- easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
- easy_rec/python/ops/__init__.py +0 -0
- easy_rec/python/ops/gen_kafka_ops.py +193 -0
- easy_rec/python/ops/gen_str_avx_op.py +28 -0
- easy_rec/python/ops/incr_record.py +30 -0
- easy_rec/python/predict.py +170 -0
- easy_rec/python/protos/__init__.py +0 -0
- easy_rec/python/protos/autoint_pb2.py +122 -0
- easy_rec/python/protos/backbone_pb2.py +1416 -0
- easy_rec/python/protos/cmbf_pb2.py +435 -0
- easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
- easy_rec/python/protos/custom_model_pb2.py +57 -0
- easy_rec/python/protos/dat_pb2.py +262 -0
- easy_rec/python/protos/data_source_pb2.py +422 -0
- easy_rec/python/protos/dataset_pb2.py +1920 -0
- easy_rec/python/protos/dbmtl_pb2.py +191 -0
- easy_rec/python/protos/dcn_pb2.py +197 -0
- easy_rec/python/protos/deepfm_pb2.py +163 -0
- easy_rec/python/protos/dlrm_pb2.py +163 -0
- easy_rec/python/protos/dnn_pb2.py +329 -0
- easy_rec/python/protos/dropoutnet_pb2.py +239 -0
- easy_rec/python/protos/dssm_pb2.py +262 -0
- easy_rec/python/protos/dssm_senet_pb2.py +282 -0
- easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
- easy_rec/python/protos/esmm_pb2.py +133 -0
- easy_rec/python/protos/eval_pb2.py +930 -0
- easy_rec/python/protos/export_pb2.py +379 -0
- easy_rec/python/protos/feature_config_pb2.py +1359 -0
- easy_rec/python/protos/fm_pb2.py +90 -0
- easy_rec/python/protos/hive_config_pb2.py +138 -0
- easy_rec/python/protos/hyperparams_pb2.py +624 -0
- easy_rec/python/protos/keras_layer_pb2.py +692 -0
- easy_rec/python/protos/layer_pb2.py +1936 -0
- easy_rec/python/protos/loss_pb2.py +1713 -0
- easy_rec/python/protos/mind_pb2.py +497 -0
- easy_rec/python/protos/mmoe_pb2.py +215 -0
- easy_rec/python/protos/multi_tower_pb2.py +295 -0
- easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
- easy_rec/python/protos/optimizer_pb2.py +2017 -0
- easy_rec/python/protos/pdn_pb2.py +293 -0
- easy_rec/python/protos/pipeline_pb2.py +516 -0
- easy_rec/python/protos/ple_pb2.py +231 -0
- easy_rec/python/protos/predict_pb2.py +1140 -0
- easy_rec/python/protos/rocket_launching_pb2.py +169 -0
- easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
- easy_rec/python/protos/simi_pb2.py +54 -0
- easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
- easy_rec/python/protos/tf_predict_pb2.py +630 -0
- easy_rec/python/protos/tower_pb2.py +661 -0
- easy_rec/python/protos/train_pb2.py +1197 -0
- easy_rec/python/protos/uniter_pb2.py +307 -0
- easy_rec/python/protos/variational_dropout_pb2.py +91 -0
- easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
- easy_rec/python/test/__init__.py +0 -0
- easy_rec/python/test/csv_input_test.py +340 -0
- easy_rec/python/test/custom_early_stop_func.py +19 -0
- easy_rec/python/test/dh_local_run.py +104 -0
- easy_rec/python/test/embed_test.py +155 -0
- easy_rec/python/test/emr_run.py +119 -0
- easy_rec/python/test/eval_metric_test.py +107 -0
- easy_rec/python/test/excel_convert_test.py +64 -0
- easy_rec/python/test/export_test.py +513 -0
- easy_rec/python/test/fg_test.py +70 -0
- easy_rec/python/test/hive_input_test.py +311 -0
- easy_rec/python/test/hpo_test.py +235 -0
- easy_rec/python/test/kafka_test.py +373 -0
- easy_rec/python/test/local_incr_test.py +122 -0
- easy_rec/python/test/loss_test.py +110 -0
- easy_rec/python/test/odps_command.py +61 -0
- easy_rec/python/test/odps_local_run.py +86 -0
- easy_rec/python/test/odps_run.py +254 -0
- easy_rec/python/test/odps_test_cls.py +39 -0
- easy_rec/python/test/odps_test_prepare.py +198 -0
- easy_rec/python/test/odps_test_util.py +237 -0
- easy_rec/python/test/pre_check_test.py +54 -0
- easy_rec/python/test/predictor_test.py +394 -0
- easy_rec/python/test/rtp_convert_test.py +133 -0
- easy_rec/python/test/run.py +138 -0
- easy_rec/python/test/train_eval_test.py +1299 -0
- easy_rec/python/test/util_test.py +85 -0
- easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
- easy_rec/python/tools/__init__.py +0 -0
- easy_rec/python/tools/add_boundaries_to_config.py +67 -0
- easy_rec/python/tools/add_feature_info_to_config.py +145 -0
- easy_rec/python/tools/convert_config_format.py +48 -0
- easy_rec/python/tools/convert_rtp_data.py +79 -0
- easy_rec/python/tools/convert_rtp_fg.py +106 -0
- easy_rec/python/tools/create_config_from_excel.py +427 -0
- easy_rec/python/tools/criteo/__init__.py +0 -0
- easy_rec/python/tools/criteo/convert_data.py +157 -0
- easy_rec/python/tools/edit_lookup_graph.py +134 -0
- easy_rec/python/tools/faiss_index_pai.py +116 -0
- easy_rec/python/tools/feature_selection.py +316 -0
- easy_rec/python/tools/hit_rate_ds.py +223 -0
- easy_rec/python/tools/hit_rate_pai.py +138 -0
- easy_rec/python/tools/pre_check.py +120 -0
- easy_rec/python/tools/predict_and_chk.py +111 -0
- easy_rec/python/tools/read_kafka.py +55 -0
- easy_rec/python/tools/split_model_pai.py +286 -0
- easy_rec/python/tools/split_pdn_model_pai.py +272 -0
- easy_rec/python/tools/test_saved_model.py +80 -0
- easy_rec/python/tools/view_saved_model.py +39 -0
- easy_rec/python/tools/write_kafka.py +65 -0
- easy_rec/python/train_eval.py +325 -0
- easy_rec/python/utils/__init__.py +15 -0
- easy_rec/python/utils/activation.py +120 -0
- easy_rec/python/utils/check_utils.py +87 -0
- easy_rec/python/utils/compat.py +14 -0
- easy_rec/python/utils/config_util.py +652 -0
- easy_rec/python/utils/constant.py +43 -0
- easy_rec/python/utils/convert_rtp_fg.py +616 -0
- easy_rec/python/utils/dag.py +192 -0
- easy_rec/python/utils/distribution_utils.py +268 -0
- easy_rec/python/utils/ds_util.py +65 -0
- easy_rec/python/utils/embedding_utils.py +73 -0
- easy_rec/python/utils/estimator_utils.py +1036 -0
- easy_rec/python/utils/export_big_model.py +630 -0
- easy_rec/python/utils/expr_util.py +118 -0
- easy_rec/python/utils/fg_util.py +53 -0
- easy_rec/python/utils/hit_rate_utils.py +220 -0
- easy_rec/python/utils/hive_utils.py +183 -0
- easy_rec/python/utils/hpo_util.py +137 -0
- easy_rec/python/utils/hvd_utils.py +56 -0
- easy_rec/python/utils/input_utils.py +108 -0
- easy_rec/python/utils/io_util.py +282 -0
- easy_rec/python/utils/load_class.py +249 -0
- easy_rec/python/utils/meta_graph_editor.py +941 -0
- easy_rec/python/utils/multi_optimizer.py +62 -0
- easy_rec/python/utils/numpy_utils.py +18 -0
- easy_rec/python/utils/odps_util.py +79 -0
- easy_rec/python/utils/pai_util.py +86 -0
- easy_rec/python/utils/proto_util.py +90 -0
- easy_rec/python/utils/restore_filter.py +89 -0
- easy_rec/python/utils/shape_utils.py +432 -0
- easy_rec/python/utils/static_shape.py +71 -0
- easy_rec/python/utils/test_utils.py +866 -0
- easy_rec/python/utils/tf_utils.py +56 -0
- easy_rec/version.py +4 -0
- test/__init__.py +0 -0
easy_rec/python/layers/keras/einsum_dense.py
@@ -0,0 +1,598 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
import string

import tensorflow as tf
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.layers import Layer


class EinsumDense(Layer):
  """A layer that uses `einsum` as the backing computation.

  This layer can perform einsum calculations of arbitrary dimensionality.

  Args:
    equation: An equation describing the einsum to perform.
      This equation must be a valid einsum string of the form
      `ab,bc->ac`, `...ab,bc->...ac`, or
      `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum
      axis expression sequence.
    output_shape: The expected shape of the output tensor
      (excluding the batch dimension and any dimensions
      represented by ellipses). You can specify `None` for any dimension
      that is unknown or can be inferred from the input shape.
    activation: Activation function to use. If you don't specify anything,
      no activation is applied
      (that is, a "linear" activation: `a(x) = x`).
    bias_axes: A string containing the output dimension(s)
      to apply a bias to. Each character in the `bias_axes` string
      should correspond to a character in the output portion
      of the `equation` string.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    bias_constraint: Constraint function applied to the bias vector.
    lora_rank: Optional integer. If set, the layer's forward pass
      will implement LoRA (Low-Rank Adaptation)
      with the provided rank. LoRA sets the layer's kernel
      to non-trainable and replaces it with a delta over the
      original kernel, obtained via multiplying two lower-rank
      trainable matrices
      (the factorization happens on the last dimension).
      This can be useful to reduce the
      computation cost of fine-tuning large dense layers.
      You can also enable LoRA on an existing
      `EinsumDense` layer by calling `layer.enable_lora(rank)`.
    **kwargs: Base layer keyword arguments, such as `name` and `dtype`.

  Examples:
  **Biased dense layer with einsums**

  This example shows how to instantiate a standard Keras dense layer using
  einsum operations. This example is equivalent to
  `keras.layers.Dense(64, use_bias=True)`.

  >>> layer = tf.keras.layers.EinsumDense("ab,bc->ac",
  ...                                     output_shape=64,
  ...                                     bias_axes="c")
  >>> input_tensor = tf.keras.Input(shape=[32])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor.shape
  (None, 64)

  **Applying a dense layer to a sequence**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence. Here, the `output_shape` has two
  values (since there are two non-batch dimensions in the output); the first
  dimension in the `output_shape` is `None`, because the sequence dimension
  `b` has an unknown shape.

  >>> layer = tf.keras.layers.EinsumDense("abc,cd->abd",
  ...                                     output_shape=(None, 64),
  ...                                     bias_axes="d")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor.shape
  (None, 32, 64)

  **Applying a dense layer to a sequence using ellipses**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence, but uses the ellipsis notation
  instead of specifying the batch and sequence dimensions.

  Because we are using ellipsis notation and have specified only one axis, the
  `output_shape` arg is a single value. When instantiated in this way, the
  layer can handle any number of sequence dimensions - including the case
  where no sequence dimension exists.

  >>> layer = tf.keras.layers.EinsumDense("...x,xy->...y",
  ...                                     output_shape=64,
  ...                                     bias_axes="y")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor.shape
  (None, 32, 64)
  """

  def __init__(self,
               equation,
               output_shape,
               activation=None,
               bias_axes=None,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               lora_rank=None,
               **kwargs):
    super(EinsumDense, self).__init__(**kwargs)
    self.equation = equation
    if isinstance(output_shape, int):
      self.partial_output_shape = (output_shape,)
    else:
      self.partial_output_shape = tuple(output_shape)
    self.bias_axes = bias_axes
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.lora_rank = lora_rank
    self.lora_enabled = False

  def build(self, input_shape):
    shape_data = _analyze_einsum_string(
        self.equation,
        self.bias_axes,
        input_shape,
        self.partial_output_shape,
    )
    kernel_shape, bias_shape, full_output_shape = shape_data
    for i in range(len(kernel_shape)):
      dim = kernel_shape[i]
      if isinstance(dim, tf.Dimension):
        kernel_shape[i] = dim.value
    if bias_shape is not None:
      # bias_shape is None when no bias_axes were requested.
      for i in range(len(bias_shape)):
        dim = bias_shape[i]
        if isinstance(dim, tf.Dimension):
          bias_shape[i] = dim.value
    for i in range(len(full_output_shape)):
      dim = full_output_shape[i]
      if isinstance(dim, tf.Dimension):
        full_output_shape[i] = dim.value
    self.full_output_shape = tuple(full_output_shape)
    self._kernel = self.add_weight(
        name='kernel',
        shape=tuple(kernel_shape),
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True,
    )
    if bias_shape is not None:
      self.bias = self.add_weight(
          name='bias',
          shape=tuple(bias_shape),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True,
      )
    else:
      self.bias = None
    self.built = True
    if self.lora_rank:
      self.enable_lora(self.lora_rank)

  @property
  def kernel(self):
    if not self.built:
      raise AttributeError(
          'You must build the layer before accessing `kernel`.')
    if self.lora_enabled:
      return self._kernel + tf.matmul(self.lora_kernel_a, self.lora_kernel_b)
    return self._kernel

  def compute_output_shape(self, _):
    return self.full_output_shape

  def call(self, inputs, training=None):
    x = tf.einsum(self.equation, inputs, self.kernel)
    if self.bias is not None:
      x += self.bias
    if self.activation is not None:
      x = self.activation(x)
    return x

  def enable_lora(self,
                  rank,
                  a_initializer='he_uniform',
                  b_initializer='zeros'):
    if self.kernel_constraint:
      raise ValueError('Lora is incompatible with kernel constraints. '
                       'In order to enable lora on this layer, remove the '
                       '`kernel_constraint` argument.')
    if not self.built:
      raise ValueError("Cannot enable lora on a layer that isn't yet built.")
    if self.lora_enabled:
      raise ValueError('lora is already enabled. '
                       'This can only be done once per layer.')
    self._tracker.unlock()
    self.lora_kernel_a = self.add_weight(
        name='lora_kernel_a',
        shape=(self.kernel.shape[:-1] + (rank,)),
        initializer=initializers.get(a_initializer),
        regularizer=self.kernel_regularizer,
    )
    self.lora_kernel_b = self.add_weight(
        name='lora_kernel_b',
        shape=(rank, self.kernel.shape[-1]),
        initializer=initializers.get(b_initializer),
        regularizer=self.kernel_regularizer,
    )
    self._kernel.trainable = False
    self._tracker.lock()
    self.lora_enabled = True
    self.lora_rank = rank

  def save_own_variables(self, store):
    # Do nothing if the layer isn't yet built
    if not self.built:
      return
    # The keys of the `store` will be saved as determined because the
    # default ordering will change after quantization
    kernel_value, kernel_scale = self._get_kernel_with_merged_lora()
    target_variables = [kernel_value]
    if self.bias is not None:
      target_variables.append(self.bias)
    for i, variable in enumerate(target_variables):
      store[str(i)] = variable

  def load_own_variables(self, store):
    if not self.lora_enabled:
      self._check_load_own_variables(store)
    # Do nothing if the layer isn't yet built
    if not self.built:
      return
    # The keys of the `store` will be saved as determined because the
    # default ordering will change after quantization
    target_variables = [self._kernel]
    if self.bias is not None:
      target_variables.append(self.bias)
    for i, variable in enumerate(target_variables):
      variable.assign(store[str(i)])
    if self.lora_enabled:
      self.lora_kernel_a.assign(tf.zeros(self.lora_kernel_a.shape))
      self.lora_kernel_b.assign(tf.zeros(self.lora_kernel_b.shape))

  def get_config(self):
    base_config = super(EinsumDense, self).get_config()
    config = {
        'output_shape': self.partial_output_shape,
        'equation': self.equation,
        'activation': activations.serialize(self.activation),
        'bias_axes': self.bias_axes,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer': regularizers.serialize(
            self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint),
    }
    if self.lora_rank:
      config['lora_rank'] = self.lora_rank
    config.update(base_config)
    return config

  def _check_load_own_variables(self, store):
    all_vars = self._trainable_variables + self._non_trainable_variables
    if len(store.keys()) != len(all_vars):
      if len(all_vars) == 0 and not self.built:
        raise ValueError(
            "Layer '{name}' was never built "
            "and thus it doesn't have any variables. "
            'However the weights file lists {num_keys} '
            'variables for this layer.\n'
            'In most cases, this error indicates that either:\n\n'
            '1. The layer is owned by a parent layer that '
            'implements a `build()` method, but calling the '
            "parent's `build()` method did NOT create the state of "
            "the child layer '{name}'. A `build()` method "
            'must create ALL state for the layer, including '
            'the state of any children layers.\n\n'
            '2. You need to implement '
            'the `def build_from_config(self, config)` method '
            "on layer '{name}', to specify how to rebuild "
            'it during loading. '
            'In this case, you might also want to implement the '
            'method that generates the build config at saving time, '
            '`def get_build_config(self)`. '
            'The method `build_from_config()` is meant '
            'to create the state '
            'of the layer (i.e. its variables) upon deserialization.'.format(
                name=self.name, num_keys=len(store.keys())))
      raise ValueError(
          "Layer '{name}' expected {num_var} variables, but received "
          '{num_key} variables during loading. '
          'Expected: {names}'.format(
              name=self.name,
              num_var=len(all_vars),
              num_key=len(store.keys()),
              names=[v.name for v in all_vars]))

  def _get_kernel_with_merged_lora(self):
    kernel_value = self.kernel
    kernel_scale = None
    return kernel_value, kernel_scale


def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
  """Analyzes an einsum string to determine the required weight shape."""
  dot_replaced_string = re.sub(r'\.\.\.', '0', equation)

  # This is the case where no ellipses are present in the string.
  split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  # This is the case where ellipses are present on the left.
  split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(
        split_string, bias_axes, input_shape, output_shape, left_elided=True)

  # This is the case where ellipses are present on the right.
  split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  raise ValueError(
      "Invalid einsum equation '{equation}'. Equations must be in the form "
      '[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
          equation=equation))


def _analyze_split_string(split_string,
                          bias_axes,
                          input_shape,
                          output_shape,
                          left_elided=False):
  """Analyzes a pre-split einsum string to find the weight shape."""
  input_spec = split_string.group(1)
  weight_spec = split_string.group(2)
  output_spec = split_string.group(3)
  elided = len(input_shape) - len(input_spec)
  if isinstance(output_shape, int):
    output_shape = [output_shape]
  else:
    output_shape = list(output_shape)

  output_shape.insert(0, input_shape[0])

  if elided > 0 and left_elided:
    for i in range(1, elided):
      # We already inserted the 0th input dimension at dim 0, so we need
      # to start at location 1 here.
      output_shape.insert(1, input_shape[i])
  elif elided > 0 and not left_elided:
    for i in range(len(input_shape) - elided, len(input_shape)):
      output_shape.append(input_shape[i])

  if left_elided:
    # If we have beginning dimensions elided, we need to use negative
    # indexing to determine where in the input dimension our values are.
    input_dim_map = {
        dim: (i + elided) - len(input_shape)
        for i, dim in enumerate(input_spec)
    }
    # Because we've constructed the full output shape already, we don't need
    # to do negative indexing.
    output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
  else:
    input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
    output_dim_map = {dim: i for i, dim in enumerate(output_spec)}

  for dim in input_spec:
    input_shape_at_dim = input_shape[input_dim_map[dim]]
    if dim in output_dim_map:
      output_shape_at_dim = output_shape[output_dim_map[dim]]
      if (output_shape_at_dim is not None and
          output_shape_at_dim != input_shape_at_dim):
        raise ValueError(
            'Input shape and output shape do not match at shared '
            "dimension '{dim}'. Input shape is {input_shape_at_dim}, "
            'and output shape is {output_shape}.'.format(
                dim=dim,
                input_shape_at_dim=input_shape_at_dim,
                output_shape=output_shape[output_dim_map[dim]]))

  for dim in output_spec:
    if dim not in input_spec and dim not in weight_spec:
      raise ValueError(
          "Dimension '{dim}' was specified in the output "
          "'{output_spec}' but has no corresponding dim in the input "
          "spec '{input_spec}' or weight spec '{weight_spec}'".format(
              dim=dim,
              output_spec=output_spec,
              input_spec=input_spec,
              weight_spec=weight_spec))

  weight_shape = []
  for dim in weight_spec:
    if dim in input_dim_map:
      weight_shape.append(input_shape[input_dim_map[dim]])
    elif dim in output_dim_map:
      weight_shape.append(output_shape[output_dim_map[dim]])
    else:
      raise ValueError(
          "Weight dimension '{dim}' did not have a match in either "
          "the input spec '{input_spec}' or the output "
          "spec '{output_spec}'. For this layer, the weight must "
          'be fully specified.'.format(
              dim=dim, input_spec=input_spec, output_spec=output_spec))

  if bias_axes is not None:
    num_left_elided = elided if left_elided else 0
    idx_map = {
        char: output_shape[i + num_left_elided]
        for i, char in enumerate(output_spec)
    }

    for char in bias_axes:
      if char not in output_spec:
        raise ValueError(
            "Bias dimension '{char}' was requested, but is not part "
            "of the output spec '{output_spec}'".format(
                char=char, output_spec=output_spec))

    first_bias_location = min([output_spec.find(char) for char in bias_axes])
    bias_output_spec = output_spec[first_bias_location:]

    bias_shape = [
        idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
    ]

    if not left_elided:
      for _ in range(elided):
        bias_shape.append(1)
  else:
    bias_shape = None

  return weight_shape, bias_shape, output_shape


def _analyze_quantization_info(equation, input_shape):
  """Analyzes the einsum equation to derive axis bookkeeping for quantization."""

  def get_specs(equation, input_shape):
    possible_labels = string.ascii_letters
    dot_replaced_string = re.sub(r'\.\.\.', '0', equation)

    # This is the case where no ellipses are present in the string.
    split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
                            dot_replaced_string)
    if split_string is not None:
      input_spec = split_string.group(1)
      weight_spec = split_string.group(2)
      output_spec = split_string.group(3)
      return input_spec, weight_spec, output_spec

    # This is the case where ellipses are present on the left.
    split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
                            dot_replaced_string)
    if split_string is not None:
      input_spec = split_string.group(1)
      weight_spec = split_string.group(2)
      output_spec = split_string.group(3)
      elided = len(input_shape) - len(input_spec)
      possible_labels = sorted(
          set(possible_labels) - set(input_spec) - set(weight_spec) -
          set(output_spec))
      # Pad labels on the left to `input_spec` and `output_spec`
      for i in range(elided):
        input_spec = possible_labels[i] + input_spec
        output_spec = possible_labels[i] + output_spec
      return input_spec, weight_spec, output_spec

    # This is the case where ellipses are present on the right.
    split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
                            dot_replaced_string)
    if split_string is not None:
      input_spec = split_string.group(1)
      weight_spec = split_string.group(2)
      output_spec = split_string.group(3)
      elided = len(input_shape) - len(input_spec)
      possible_labels = sorted(
          set(possible_labels) - set(input_spec) - set(weight_spec) -
          set(output_spec))
      # Pad labels on the right to `input_spec` and `output_spec`
      for i in range(elided):
        input_spec = input_spec + possible_labels[i]
        output_spec = output_spec + possible_labels[i]
      return input_spec, weight_spec, output_spec

    raise ValueError(
        "Invalid einsum equation '{equation}'. Equations must be in the "
        'form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
            equation=equation))

  input_spec, weight_spec, output_spec = get_specs(equation, input_shape)

  # Determine the axes that should be reduced by the quantizer
  input_reduced_axes = []
  weight_reduced_axes = []
  for i, label in enumerate(input_spec):
    index = output_spec.find(label)
    if index == -1:
      input_reduced_axes.append(i)
  for i, label in enumerate(weight_spec):
    index = output_spec.find(label)
    if index == -1:
      weight_reduced_axes.append(i)

  # Determine the axes of `ops.expand_dims`
  input_expand_axes = []
  weight_expand_axes = []
  for i, label in enumerate(output_spec):
    index_input = input_spec.find(label)
    index_weight = weight_spec.find(label)
    if index_input == -1:
      input_expand_axes.append(i)
    if index_weight == -1:
      weight_expand_axes.append(i)

  # Determine the axes of `ops.transpose`
  input_transpose_axes = []
  weight_transpose_axes = []
  for i, label in enumerate(output_spec):
    index_input = input_spec.find(label)
    index_weight = weight_spec.find(label)
    if index_input != -1:
      input_transpose_axes.append(index_input)
    if index_weight != -1:
      weight_transpose_axes.append(index_weight)
  # Postprocess the information:
  # 1. Add dummy axes (1) to transpose_axes
  # 2. Add axis to squeeze_axes if 1. failed
  input_squeeze_axes = []
  weight_squeeze_axes = []
  for ori_index in input_reduced_axes:
    try:
      index = input_expand_axes.pop(0)
    except IndexError:
      input_squeeze_axes.append(ori_index)
    input_transpose_axes.insert(index, ori_index)
  for ori_index in weight_reduced_axes:
    try:
      index = weight_expand_axes.pop(0)
    except IndexError:
      weight_squeeze_axes.append(ori_index)
    weight_transpose_axes.insert(index, ori_index)
  # Prepare equation for `einsum_with_inputs_gradient`
  custom_gradient_equation = '{output_spec},{weight_spec}->{input_spec}'.format(
      output_spec=output_spec, input_spec=input_spec, weight_spec=weight_spec)
  weight_reverse_transpose_axes = [
      i for (_, i) in sorted((v, i)
                             for (i, v) in enumerate(weight_transpose_axes))
  ]
  return (
      input_reduced_axes,
      weight_reduced_axes,
      input_transpose_axes,
      weight_transpose_axes,
      input_expand_axes,
      weight_expand_axes,
      input_squeeze_axes,
      weight_squeeze_axes,
      custom_gradient_equation,
      weight_reverse_transpose_axes,
  )
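For orientation, the helper `_analyze_einsum_string` above is what turns the equation into concrete weight shapes at build time. A minimal sketch of its behavior, using the docstring's sequence example (illustrative only; this assumes the wheel is installed so the module path resolves, and `_analyze_einsum_string` is a private helper rather than a documented API):

    from easy_rec.python.layers.keras.einsum_dense import _analyze_einsum_string

    # 'abc,cd->abd' on a [None, 32, 128] input with output_shape=(None, 64)
    # and a bias over output axis 'd'.
    kernel_shape, bias_shape, full_output_shape = _analyze_einsum_string(
        'abc,cd->abd', 'd', (None, 32, 128), (None, 64))
    print(kernel_shape)       # [128, 64]  - contracts the last input axis
    print(bias_shape)         # [64]       - one bias per output channel 'd'
    print(full_output_shape)  # [None, None, 64] - axis 'b' stays None as passed

Note that the shared axis 'b' is not backfilled from the input shape: whatever is passed in `output_shape` (here `None`) is what ends up in the full output shape, after the batch dimension is prepended.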
easy_rec/python/layers/keras/embedding.py
@@ -0,0 +1,81 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Fused embedding layer."""
import tensorflow as tf
from tensorflow.python.keras.layers import Embedding
from tensorflow.python.keras.layers import Layer


def _combine(embeddings, weights, comb_fn):
  # embeddings shape: [B, N, D]
  if callable(comb_fn):
    return comb_fn(embeddings, axis=1)
  if weights is None:
    return tf.reduce_mean(embeddings, axis=1)
  if isinstance(weights, tf.SparseTensor):
    if weights.dtype == tf.string:
      weights = tf.sparse.to_dense(weights, default_value='0')
      weights = tf.string_to_number(weights)
    else:
      weights = tf.sparse.to_dense(weights, default_value=0.0)
  sum_weights = tf.reduce_sum(weights, axis=1, keepdims=True)
  weights = tf.expand_dims(weights / sum_weights, axis=-1)
  return tf.reduce_sum(embeddings * weights, axis=1)


class EmbeddingLayer(Layer):

  def __init__(self, params, name='embedding_layer', reuse=None, **kwargs):
    super(EmbeddingLayer, self).__init__(name=name, **kwargs)
    params.check_required(['vocab_size', 'embedding_dim'])
    vocab_size = int(params.vocab_size)
    combiner = params.get_or_default('combiner', 'weight')
    if combiner == 'mean':
      self.combine_fn = tf.reduce_mean
    elif combiner == 'sum':
      self.combine_fn = tf.reduce_sum
    elif combiner == 'max':
      self.combine_fn = tf.reduce_max
    elif combiner == 'min':
      self.combine_fn = tf.reduce_min
    elif combiner == 'weight':
      self.combine_fn = 'weight'
    else:
      raise ValueError('unsupported embedding combiner: ' + combiner)
    self.embed_dim = int(params.embedding_dim)
    self.embedding = Embedding(vocab_size, self.embed_dim)
    self.do_concat = params.get_or_default('concat', True)

  def call(self, inputs, training=None, **kwargs):
    inputs, weights = inputs
    # Merge the inputs of all features into a single index tensor
    flat_inputs = [tf.reshape(input_field, [-1]) for input_field in inputs]
    all_indices = tf.concat(flat_inputs, axis=0)
    # Perform a single embedding lookup against the shared embedding table
    all_embeddings = self.embedding(all_indices)
    is_multi = []
    # Compute the embedding of each feature
    split_sizes = []
    for input_field in inputs:
      assert input_field.shape.ndims <= 2, 'dims of embedding layer input must be <= 2'
      input_shape = tf.shape(input_field)
      size = input_shape[0]
      if input_field.shape.ndims > 1:
        size *= input_shape[-1]
        is_multi.append(True)
      else:
        is_multi.append(False)
      split_sizes.append(size)
    embeddings = tf.split(all_embeddings, split_sizes, axis=0)
    for i in range(len(embeddings)):
      if is_multi[i]:
        batch_size = tf.shape(inputs[i])[0]
        embeddings[i] = tf.cond(
            tf.equal(tf.size(embeddings[i]), 0),
            lambda: tf.zeros([batch_size, self.embed_dim]),
            lambda: _combine(
                tf.reshape(embeddings[i], [batch_size, -1, self.embed_dim]),
                weights[i], self.combine_fn))
    if self.do_concat:
      embeddings = tf.concat(embeddings, axis=-1)
    print('Embedding layer:', self.name, embeddings)
    return embeddings
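To make the default 'weight' combiner concrete, here is a small standalone sketch (illustrative values; TensorFlow 2 eager semantics assumed for readability) of the weighted average `_combine` computes once sparse weights have been densified: weights are normalized to sum to 1 per row, then used to average the N embeddings of each example.

    import tensorflow as tf

    # B=2 examples, N=3 tag embeddings each, embedding dim D=2.
    embeddings = tf.reshape(tf.range(12, dtype=tf.float32), [2, 3, 2])
    weights = tf.constant([[1.0, 1.0, 2.0],
                           [0.5, 0.5, 0.0]])

    # Same math as the dense tail of _combine.
    sum_weights = tf.reduce_sum(weights, axis=1, keepdims=True)
    w = tf.expand_dims(weights / sum_weights, axis=-1)  # [B, N, 1], rows sum to 1
    combined = tf.reduce_sum(embeddings * w, axis=1)    # [B, D]
    print(combined.numpy())  # [[2.5, 3.5], [7.0, 8.0]]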