spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- com/johnsnowlabs/ml/__init__.py +0 -0
- com/johnsnowlabs/ml/ai/__init__.py +10 -0
- com/johnsnowlabs/nlp/__init__.py +4 -2
- spark_nlp-6.2.1.dist-info/METADATA +362 -0
- spark_nlp-6.2.1.dist-info/RECORD +292 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +281 -27
- sparknlp/annotation.py +137 -6
- sparknlp/annotation_audio.py +61 -0
- sparknlp/annotation_image.py +82 -0
- sparknlp/annotator/__init__.py +93 -0
- sparknlp/annotator/audio/__init__.py +16 -0
- sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
- sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
- sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
- sparknlp/annotator/chunk2_doc.py +85 -0
- sparknlp/annotator/chunker.py +137 -0
- sparknlp/annotator/classifier_dl/__init__.py +61 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
- sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
- sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
- sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
- sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
- sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
- sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
- sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
- sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
- sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
- sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
- sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
- sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
- sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
- sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/coref/__init__.py +1 -0
- sparknlp/annotator/coref/spanbert_coref.py +221 -0
- sparknlp/annotator/cv/__init__.py +29 -0
- sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
- sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
- sparknlp/annotator/cv/florence2_transformer.py +180 -0
- sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
- sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
- sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
- sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
- sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
- sparknlp/annotator/dataframe_optimizer.py +216 -0
- sparknlp/annotator/date2_chunk.py +88 -0
- sparknlp/annotator/dependency/__init__.py +17 -0
- sparknlp/annotator/dependency/dependency_parser.py +294 -0
- sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
- sparknlp/annotator/document_character_text_splitter.py +228 -0
- sparknlp/annotator/document_normalizer.py +235 -0
- sparknlp/annotator/document_token_splitter.py +175 -0
- sparknlp/annotator/document_token_splitter_test.py +85 -0
- sparknlp/annotator/embeddings/__init__.py +45 -0
- sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
- sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
- sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
- sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
- sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
- sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
- sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
- sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
- sparknlp/annotator/embeddings/doc2vec.py +352 -0
- sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
- sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
- sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
- sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
- sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
- sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
- sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
- sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
- sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
- sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
- sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
- sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
- sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
- sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
- sparknlp/annotator/embeddings/word2vec.py +353 -0
- sparknlp/annotator/embeddings/word_embeddings.py +385 -0
- sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
- sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
- sparknlp/annotator/er/__init__.py +16 -0
- sparknlp/annotator/er/entity_ruler.py +267 -0
- sparknlp/annotator/graph_extraction.py +368 -0
- sparknlp/annotator/keyword_extraction/__init__.py +16 -0
- sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
- sparknlp/annotator/ld_dl/__init__.py +16 -0
- sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
- sparknlp/annotator/lemmatizer.py +250 -0
- sparknlp/annotator/matcher/__init__.py +20 -0
- sparknlp/annotator/matcher/big_text_matcher.py +272 -0
- sparknlp/annotator/matcher/date_matcher.py +303 -0
- sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
- sparknlp/annotator/matcher/regex_matcher.py +221 -0
- sparknlp/annotator/matcher/text_matcher.py +290 -0
- sparknlp/annotator/n_gram_generator.py +141 -0
- sparknlp/annotator/ner/__init__.py +21 -0
- sparknlp/annotator/ner/ner_approach.py +94 -0
- sparknlp/annotator/ner/ner_converter.py +148 -0
- sparknlp/annotator/ner/ner_crf.py +397 -0
- sparknlp/annotator/ner/ner_dl.py +591 -0
- sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
- sparknlp/annotator/ner/ner_overwriter.py +166 -0
- sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
- sparknlp/annotator/normalizer.py +230 -0
- sparknlp/annotator/openai/__init__.py +16 -0
- sparknlp/annotator/openai/openai_completion.py +349 -0
- sparknlp/annotator/openai/openai_embeddings.py +106 -0
- sparknlp/annotator/param/__init__.py +17 -0
- sparknlp/annotator/param/classifier_encoder.py +98 -0
- sparknlp/annotator/param/evaluation_dl_params.py +130 -0
- sparknlp/annotator/pos/__init__.py +16 -0
- sparknlp/annotator/pos/perceptron.py +263 -0
- sparknlp/annotator/sentence/__init__.py +17 -0
- sparknlp/annotator/sentence/sentence_detector.py +290 -0
- sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
- sparknlp/annotator/sentiment/__init__.py +17 -0
- sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
- sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
- sparknlp/annotator/seq2seq/__init__.py +35 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
- sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
- sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
- sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
- sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
- sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
- sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
- sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
- sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
- sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
- sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
- sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
- sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
- sparknlp/annotator/similarity/__init__.py +0 -0
- sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
- sparknlp/annotator/spell_check/__init__.py +18 -0
- sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
- sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
- sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
- sparknlp/annotator/stemmer.py +79 -0
- sparknlp/annotator/stop_words_cleaner.py +190 -0
- sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
- sparknlp/annotator/token/__init__.py +19 -0
- sparknlp/annotator/token/chunk_tokenizer.py +118 -0
- sparknlp/annotator/token/recursive_tokenizer.py +205 -0
- sparknlp/annotator/token/regex_tokenizer.py +208 -0
- sparknlp/annotator/token/tokenizer.py +561 -0
- sparknlp/annotator/token2_chunk.py +76 -0
- sparknlp/annotator/ws/__init__.py +16 -0
- sparknlp/annotator/ws/word_segmenter.py +429 -0
- sparknlp/base/__init__.py +30 -0
- sparknlp/base/audio_assembler.py +95 -0
- sparknlp/base/doc2_chunk.py +169 -0
- sparknlp/base/document_assembler.py +164 -0
- sparknlp/base/embeddings_finisher.py +201 -0
- sparknlp/base/finisher.py +217 -0
- sparknlp/base/gguf_ranking_finisher.py +234 -0
- sparknlp/base/graph_finisher.py +125 -0
- sparknlp/base/has_recursive_fit.py +24 -0
- sparknlp/base/has_recursive_transform.py +22 -0
- sparknlp/base/image_assembler.py +172 -0
- sparknlp/base/light_pipeline.py +429 -0
- sparknlp/base/multi_document_assembler.py +164 -0
- sparknlp/base/prompt_assembler.py +207 -0
- sparknlp/base/recursive_pipeline.py +107 -0
- sparknlp/base/table_assembler.py +145 -0
- sparknlp/base/token_assembler.py +124 -0
- sparknlp/common/__init__.py +26 -0
- sparknlp/common/annotator_approach.py +41 -0
- sparknlp/common/annotator_model.py +47 -0
- sparknlp/common/annotator_properties.py +114 -0
- sparknlp/common/annotator_type.py +38 -0
- sparknlp/common/completion_post_processing.py +37 -0
- sparknlp/common/coverage_result.py +22 -0
- sparknlp/common/match_strategy.py +33 -0
- sparknlp/common/properties.py +1298 -0
- sparknlp/common/read_as.py +33 -0
- sparknlp/common/recursive_annotator_approach.py +35 -0
- sparknlp/common/storage.py +149 -0
- sparknlp/common/utils.py +39 -0
- sparknlp/functions.py +315 -5
- sparknlp/internal/__init__.py +1199 -0
- sparknlp/internal/annotator_java_ml.py +32 -0
- sparknlp/internal/annotator_transformer.py +37 -0
- sparknlp/internal/extended_java_wrapper.py +63 -0
- sparknlp/internal/params_getters_setters.py +71 -0
- sparknlp/internal/recursive.py +70 -0
- sparknlp/logging/__init__.py +15 -0
- sparknlp/logging/comet.py +467 -0
- sparknlp/partition/__init__.py +16 -0
- sparknlp/partition/partition.py +244 -0
- sparknlp/partition/partition_properties.py +902 -0
- sparknlp/partition/partition_transformer.py +200 -0
- sparknlp/pretrained/__init__.py +17 -0
- sparknlp/pretrained/pretrained_pipeline.py +158 -0
- sparknlp/pretrained/resource_downloader.py +216 -0
- sparknlp/pretrained/utils.py +35 -0
- sparknlp/reader/__init__.py +15 -0
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +190 -0
- sparknlp/reader/reader2doc.py +124 -0
- sparknlp/reader/reader2image.py +136 -0
- sparknlp/reader/reader2table.py +44 -0
- sparknlp/reader/reader_assembler.py +159 -0
- sparknlp/reader/sparknlp_reader.py +461 -0
- sparknlp/training/__init__.py +20 -0
- sparknlp/training/_tf_graph_builders/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
- sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
- sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
- sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/conll.py +150 -0
- sparknlp/training/conllu.py +103 -0
- sparknlp/training/pos.py +103 -0
- sparknlp/training/pub_tator.py +76 -0
- sparknlp/training/spacy_to_annotation.py +57 -0
- sparknlp/training/tfgraphs.py +5 -0
- sparknlp/upload_to_hub.py +149 -0
- sparknlp/util.py +51 -5
- com/__init__.pyc +0 -0
- com/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/__init__.pyc +0 -0
- com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/nlp/__init__.pyc +0 -0
- com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
- spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
- spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
- sparknlp/__init__.pyc +0 -0
- sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
- sparknlp/__pycache__/base.cpython-36.pyc +0 -0
- sparknlp/__pycache__/common.cpython-36.pyc +0 -0
- sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
- sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
- sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
- sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
- sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
- sparknlp/__pycache__/training.cpython-36.pyc +0 -0
- sparknlp/__pycache__/util.cpython-36.pyc +0 -0
- sparknlp/annotation.pyc +0 -0
- sparknlp/annotator.py +0 -3006
- sparknlp/annotator.pyc +0 -0
- sparknlp/base.py +0 -347
- sparknlp/base.pyc +0 -0
- sparknlp/common.py +0 -193
- sparknlp/common.pyc +0 -0
- sparknlp/embeddings.py +0 -40
- sparknlp/embeddings.pyc +0 -0
- sparknlp/internal.py +0 -288
- sparknlp/internal.pyc +0 -0
- sparknlp/pretrained.py +0 -123
- sparknlp/pretrained.pyc +0 -0
- sparknlp/storage.py +0 -32
- sparknlp/storage.pyc +0 -0
- sparknlp/training.py +0 -62
- sparknlp/training.pyc +0 -0
- sparknlp/util.pyc +0 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,4006 @@
|
|
|
1
|
+
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Module for constructing RNN Cells."""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
from __future__ import division
|
|
18
|
+
from __future__ import print_function
|
|
19
|
+
|
|
20
|
+
import collections
|
|
21
|
+
import math
|
|
22
|
+
|
|
23
|
+
from tensorflow.contrib.compiler import jit
|
|
24
|
+
from tensorflow.contrib.layers.python.layers import layers
|
|
25
|
+
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
|
26
|
+
from tensorflow.python.framework import constant_op
|
|
27
|
+
from tensorflow.python.framework import dtypes
|
|
28
|
+
from tensorflow.python.framework import op_def_registry
|
|
29
|
+
from tensorflow.python.framework import ops
|
|
30
|
+
from tensorflow.python.framework import tensor_shape
|
|
31
|
+
from tensorflow.python.keras import activations
|
|
32
|
+
from tensorflow.python.keras import initializers
|
|
33
|
+
from tensorflow.python.keras.engine import input_spec
|
|
34
|
+
from tensorflow.python.ops import array_ops
|
|
35
|
+
from tensorflow.python.ops import clip_ops
|
|
36
|
+
from tensorflow.python.ops import control_flow_ops
|
|
37
|
+
from tensorflow.python.ops import gen_array_ops
|
|
38
|
+
from tensorflow.python.ops import init_ops
|
|
39
|
+
from tensorflow.python.ops import math_ops
|
|
40
|
+
from tensorflow.python.ops import nn_impl # pylint: disable=unused-import
|
|
41
|
+
from tensorflow.python.ops import nn_ops
|
|
42
|
+
from tensorflow.python.ops import random_ops
|
|
43
|
+
from tensorflow.python.ops import rnn_cell_impl
|
|
44
|
+
from tensorflow.python.ops import variable_scope as vs
|
|
45
|
+
from tensorflow.python.platform import tf_logging as logging
|
|
46
|
+
from tensorflow.python.util import nest
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_concat_variable(name, shape, dtype, num_shards):
  """Return a sharded variable joined into a single tensor.

  The underlying variable is stored as `num_shards` shards (see
  `_get_sharded_variable`). If there is only one shard it is returned
  directly; otherwise the shards are concatenated along axis 0 and the
  result is cached in the CONCATENATED_VARIABLES graph collection so
  that repeated lookups reuse the same concat op.
  """
  shards = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(shards) == 1:
    return shards[0]

  concat_name = name + "/concat"
  full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  # Reuse an already-registered concatenation for this scope, if any.
  for existing in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if existing.name == full_name:
      return existing

  concatenated = array_ops.concat(shards, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concatenated)
  return concatenated
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _get_sharded_variable(name, shape, dtype, num_shards):
  """Create (or fetch) `num_shards` variables that partition `shape`.

  The first dimension of `shape` is divided as evenly as possible across
  the shards; when it does not divide exactly, the leading shards each
  receive one extra row.

  Raises:
    ValueError: if `num_shards` exceeds the number of rows `shape[0]`.
  """
  if num_shards > shape[0]:
    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
                                                                   num_shards))
  unit_shard_size = int(math.floor(shape[0] / num_shards))
  remaining_rows = shape[0] - unit_shard_size * num_shards

  shards = []
  for idx in range(num_shards):
    # Leading shards absorb the remainder rows, one each.
    rows = unit_shard_size + (1 if idx < remaining_rows else 0)
    shards.append(
        vs.get_variable(name + "_%d" % idx, [rows] + shape[1:], dtype=dtype))
  return shards
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _norm(g, b, inp, scope):
  """Layer-normalize `inp` with gain `g` and shift `b` under `scope`."""
  shape = inp.get_shape()[-1:]
  with vs.variable_scope(scope):
    # Pre-create "gamma" and "beta" with the requested initial values so
    # that layer_norm(reuse=True) below picks them up instead of making
    # its own default-initialized variables.
    vs.get_variable(
        "gamma", shape=shape, initializer=init_ops.constant_initializer(g))
    vs.get_variable(
        "beta", shape=shape, initializer=init_ops.constant_initializer(b))
  return layers.layer_norm(inp, reuse=True, scope=scope)
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf

  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
  large scale acoustic modeling." INTERSPEECH, 2014.

  The coupling of input and forget gate is based on:

    http://arxiv.org/pdf/1503.04069.pdf

  Greff et al. "LSTM: A Search Space Odyssey"

  The class uses optional peep-hole connections, and an optional projection
  layer.
  Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=1,
               num_proj_shards=1,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=math_ops.tanh,
               reuse=None,
               layer_norm=False,
               norm_gain=1.0,
               norm_shift=0.0):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      num_proj_shards: How to split the projection matrix.  If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  By default (False), they are concatenated
        along the column axis.  This default behavior will soon be deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value.  If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value.  If
        `layer_norm` has been set to `False`, this argument will be ignored.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    # Plain attribute copies of the constructor arguments; all gate logic
    # lives in call().
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift

    # With a projection layer the emitted state half is num_proj wide;
    # without it, both halves of the state are num_units wide.
    if num_proj:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    # Either an LSTMStateTuple or an int, depending on state_is_tuple.
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    # Recover (c, m) from the previous state, slicing when the state is a
    # single concatenated tensor.
    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    # TF1-style static shape: dims[1] is a Dimension whose .value may be None.
    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    # One fused weight matrix for the three coupled gates (j, f, o);
    # the coupled input gate is derived as 1 - f, so no fourth column block.
    concat_w = _get_concat_variable(
        "W",
        [input_size.value + num_proj, 3 * self._num_units],
        dtype,
        self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[3 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([inputs, m_prev], 1)
    lstm_matrix = math_ops.matmul(cell_inputs, concat_w)

    # If layer normalization is applied, do not add bias (the norm's beta
    # plays that role).
    if not self._layer_norm:
      lstm_matrix = nn_ops.bias_add(lstm_matrix, b)

    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)

    # Apply layer normalization to each gate pre-activation separately.
    if self._layer_norm:
      j = _norm(self._norm_gain, self._norm_shift, j, "transform")
      f = _norm(self._norm_gain, self._norm_shift, f, "forget")
      o = _norm(self._norm_gain, self._norm_shift, o, "output")

    # Diagonal (peephole) connections: per-unit weights on the cell state.
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
    else:
      f_act = sigmoid(f + self._forget_bias)
    # Coupled gates: the input gate is (1 - f_act), tying input and forget.
    c = (f_act * c_prev + (1 - f_act) * self._activation(j))

    # Apply layer normalization to the new cell state.
    if self._layer_norm:
      c = _norm(self._norm_gain, self._norm_shift, c, "state")

    if self._use_peepholes:
      m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    # Optional output projection, with optional elementwise clipping.
    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable("W_P",
                                           [self._num_units, self._num_proj],
                                           dtype, self._num_proj_shards)

      m = math_ops.matmul(m, concat_w_proj)
      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.

  This implementation is based on:

    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  It uses peep-hole connections and optional cell clipping.

  The cell slides an LSTM filter over the frequency axis of each input frame
  (see `_make_tf_features`) and chains the per-frequency cells through
  `m_prev_freq`, so state flows in both time and frequency.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=1,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_unit_shards: int, How to split the weight matrix. If >1, the weight
        matrix is stored across num_unit_shards.
      forget_bias: float, Biases of the forget gate are initialized by default
        to 1 in order to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: int, The size of the input feature the LSTM spans over.
      frequency_skip: int, The amount the LSTM filter is shifted by in
        frequency.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    # Each frequency block keeps a (c, m) pair, concatenated along axis 1.
    # NOTE(review): this is sized for a single frequency block; `call` slices
    # `state` per block, so callers must supply 2 * num_units per block.
    self._state_size = 2 * num_units
    self._output_size = num_units
    self._reuse = reuse

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: state Tensor, 2D, batch x state_size.

    Returns:
      A tuple containing:
      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, batch x state_size, Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    freq_inputs = self._make_tf_features(inputs)
    dtype = inputs.dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    # One shared weight matrix for all frequency blocks; the 2 * num_units
    # extra input columns carry the recurrent m from time and frequency.
    concat_w = _get_concat_variable(
        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
        dtype, self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros(
        [inputs.shape.dims[0].value or inputs.get_shape()[0], self._num_units],
        dtype)
    for fq in range(len(freq_inputs)):
      # Per-block (c, m) pairs are stored back-to-back along axis 1.
      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
                               [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
                               [-1, self._num_units])
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      if self._use_peepholes:
        c = (
            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
            sigmoid(i + w_i_diag * c_prev) * tanh(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
      else:
        m = sigmoid(o) * tanh(c)
      # Feed this block's output into the next frequency block.
      m_prev_freq = m
      if fq == 0:
        state_out = array_ops.concat([c, m], 1)
        m_out = m
      else:
        state_out = array_ops.concat([state_out, c, m], 1)
        m_out = array_ops.concat([m_out, m], 1)
    return m_out, state_out

  def _make_tf_features(self, input_feat):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, batch x num_units.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
        for that frequency index. Here output_dim is feature_size.
    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2).dims[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    # Number of filter positions when sliding a feature_size-wide window by
    # frequency_skip across the input.
    num_feats = int(
        (input_size - self._feature_size) / (self._frequency_skip)) + 1
    freq_inputs = []
    for f in range(num_feats):
      cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
                                  [-1, self._feature_size])
      freq_inputs.append(cur_input)
    return freq_inputs
class GridLSTMCell(rnn_cell_impl.RNNCell):
  """Grid Long short-term memory unit (LSTM) recurrent network cell.

  The default is based on:
    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
    "Grid Long Short-Term Memory," Proc. ICLR 2016.
    http://arxiv.org/abs/1507.01526

  When peephole connections are used, the implementation is based on:
    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  The code uses optional peephole connections, shared_weights and cell clipping.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               state_is_tuple=True,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None, The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`. By default (False), they are concatenated
        along the column axis. This default behavior will soon be deprecated.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    Raises:
      ValueError: if the num_frequency_blocks list is not specified
    """
    super(GridLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._share_time_frequency_weights = share_time_frequency_weights
    self._couple_input_forget_gates = couple_input_forget_gates
    self._state_is_tuple = state_is_tuple
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    self._start_freqindex_list = start_freqindex_list
    self._end_freqindex_list = end_freqindex_list
    self._num_frequency_blocks = num_frequency_blocks
    self._total_blocks = 0
    self._reuse = reuse
    if self._num_frequency_blocks is None:
      raise ValueError("Must specify num_frequency_blocks")

    for block_index in range(len(self._num_frequency_blocks)):
      self._total_blocks += int(self._num_frequency_blocks[block_index])
    if state_is_tuple:
      # Build a namedtuple with one (c, m) pair of fields per frequency cell,
      # e.g. "state_f00_b00_c, state_f00_b00_m, ...".
      state_names = ""
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
                                                      state_names.strip(","))
      self._state_size = self._state_tuple_type(*(
          [num_units, num_units] * self._total_blocks))
    else:
      self._state_tuple_type = None
      self._state_size = num_units * self._total_blocks * 2
    # Each frequency cell emits both its m_time and m_freq outputs.
    self._output_size = num_units * self._total_blocks * 2

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  @property
  def state_tuple_type(self):
    return self._state_tuple_type

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, feature_size].
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
        flag self._state_is_tuple.

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    batch_size = tensor_shape.dimension_value(
        inputs.shape[0]) or array_ops.shape(inputs)[0]
    freq_inputs = self._make_tf_features(inputs)
    m_out_lst = []
    state_out_lst = []
    # Each block is processed independently; outputs are concatenated below.
    for block in range(len(freq_inputs)):
      m_out_lst_current, state_out_lst_current = self._compute(
          freq_inputs[block],
          block,
          state,
          batch_size,
          state_is_tuple=self._state_is_tuple)
      m_out_lst.extend(m_out_lst_current)
      state_out_lst.extend(state_out_lst_current)
    if self._state_is_tuple:
      state_out = self._state_tuple_type(*state_out_lst)
    else:
      state_out = array_ops.concat(state_out_lst, 1)
    m_out = array_ops.concat(m_out_lst, 1)
    return m_out, state_out

  def _compute(self,
               freq_inputs,
               block,
               state,
               batch_size,
               state_prefix="state",
               state_is_tuple=True):
    """Run the actual computation of one step LSTM.

    Args:
      freq_inputs: list of Tensors, 2D, [batch, feature_size].
      block: int, current frequency block index to process.
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
        the flag state_is_tuple.
      batch_size: int32, batch size.
      state_prefix: (optional) string, name prefix for states, defaults to
        "state".
      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.

    Returns:
      A tuple, containing:
      - A list of [batch, output_dim] Tensors, representing the output of the
        LSTM given the inputs and state.
      - A list of [batch, state_size] Tensors, representing the LSTM state
        values given the inputs and previous state.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    # With coupled gates the forget gate is derived from the input gate, so
    # only 3 gate blocks need to be computed.
    num_gates = 3 if self._couple_input_forget_gates else 4
    dtype = freq_inputs[0].dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    concat_w_f = _get_concat_variable(
        "W_f_%d" % block,
        [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
        dtype, self._num_unit_shards)
    b_f = vs.get_variable(
        "B_f_%d" % block,
        shape=[num_gates * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)
    if not self._share_time_frequency_weights:
      concat_w_t = _get_concat_variable("W_t_%d" % block, [
          actual_input_size + 2 * self._num_units, num_gates * self._num_units
      ], dtype, self._num_unit_shards)
      b_t = vs.get_variable(
          "B_t_%d" % block,
          shape=[num_gates * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)

    if self._use_peepholes:
      # Diagonal connections
      if not self._couple_input_forget_gates:
        w_f_diag_freqf = vs.get_variable(
            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
        w_f_diag_freqt = vs.get_variable(
            "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqf = vs.get_variable(
          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqt = vs.get_variable(
          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqf = vs.get_variable(
          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqt = vs.get_variable(
          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      if not self._share_time_frequency_weights:
        if not self._couple_input_forget_gates:
          w_f_diag_timef = vs.get_variable(
              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
          w_f_diag_timet = vs.get_variable(
              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timef = vs.get_variable(
            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timet = vs.get_variable(
            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timef = vs.get_variable(
            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timet = vs.get_variable(
            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    for freq_index in range(len(freq_inputs)):
      if state_is_tuple:
        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
        c_prev_time = getattr(state, name_prefix + "_c")
        m_prev_time = getattr(state, name_prefix + "_m")
      else:
        c_prev_time = array_ops.slice(
            state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
        m_prev_time = array_ops.slice(
            state, [0, (2 * freq_index + 1) * self._num_units],
            [-1, self._num_units])

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat(
          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)

      # F-LSTM
      lstm_matrix_freq = nn_ops.bias_add(
          math_ops.matmul(cell_inputs, concat_w_f), b_f)
      if self._couple_input_forget_gates:
        i_freq, j_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
        f_freq = None
      else:
        i_freq, j_freq, f_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
      # T-LSTM
      if self._share_time_frequency_weights:
        i_time = i_freq
        j_time = j_freq
        f_time = f_freq
        o_time = o_freq
      else:
        lstm_matrix_time = nn_ops.bias_add(
            math_ops.matmul(cell_inputs, concat_w_t), b_t)
        if self._couple_input_forget_gates:
          i_time, j_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
          f_time = None
        else:
          i_time, j_time, f_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)

      # F-LSTM c_freq
      # input gate activations
      if self._use_peepholes:
        i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
                           w_i_diag_freqt * c_prev_time)
      else:
        i_freq_g = sigmoid(i_freq)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_freq_g = 1.0 - i_freq_g
      else:
        if self._use_peepholes:
          f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
                             c_prev_freq + w_f_diag_freqt * c_prev_time)
        else:
          f_freq_g = sigmoid(f_freq + self._forget_bias)
      # cell state
      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # T-LSTM c_freq
      # input gate activations
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
                             w_i_diag_freqt * c_prev_time)
        else:
          i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
                             w_i_diag_timet * c_prev_time)
      else:
        i_time_g = sigmoid(i_time)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_time_g = 1.0 - i_time_g
      else:
        if self._use_peepholes:
          if self._share_time_frequency_weights:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
                               c_prev_freq + w_f_diag_freqt * c_prev_time)
          else:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
                               c_prev_freq + w_f_diag_timet * c_prev_time)
        else:
          f_time_g = sigmoid(f_time + self._forget_bias)
      # cell state
      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # F-LSTM m_freq
      if self._use_peepholes:
        m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
                         w_o_diag_freqt * c_time) * tanh(c_freq)
      else:
        m_freq = sigmoid(o_freq) * tanh(c_freq)

      # T-LSTM m_time
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
                           w_o_diag_freqt * c_time) * tanh(c_time)
        else:
          m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
                           w_o_diag_timet * c_time) * tanh(c_time)
      else:
        m_time = sigmoid(o_time) * tanh(c_time)

      m_prev_freq = m_freq
      c_prev_freq = c_freq
      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
      if freq_index == 0:
        state_out_lst = [c_time, m_time]
        m_out_lst = [m_time, m_freq]
      else:
        state_out_lst.extend([c_time, m_time])
        m_out_lst.extend([m_time, m_freq])

    return m_out_lst, state_out_lst

  def _make_tf_features(self, input_feat, slice_offset=0):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, [batch, num_units].
      slice_offset: (optional) Python int, default 0, the slicing offset is only
        used for the backward processing in the BidirectionalGridLSTMCell. It
        specifies a different starting point instead of always 0 to enable the
        forward and backward processing look at different frequency blocks.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, [batch, output_dim], Tensor representing the time-frequency
        feature for that frequency index. Here output_dim is feature_size.
    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2).dims[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    if slice_offset > 0:
      # Padding to the end
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, 0, slice_offset],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
    elif slice_offset < 0:
      # Padding to the front
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, -slice_offset, 0],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
      slice_offset = 0
    else:
      inputs = input_feat
    freq_inputs = []
    if not self._start_freqindex_list:
      if len(self._num_frequency_blocks) != 1:
        raise ValueError("Length of num_frequency_blocks"
                         " is not 1, but instead is %d" %
                         len(self._num_frequency_blocks))
      num_feats = int(
          (input_size - self._feature_size) / (self._frequency_skip)) + 1
      if num_feats != self._num_frequency_blocks[0]:
        raise ValueError(
            "Invalid num_frequency_blocks, requires %d but gets %d, please"
            " check the input size and filter config are correct." %
            (self._num_frequency_blocks[0], num_feats))
      block_inputs = []
      for f in range(num_feats):
        cur_input = array_ops.slice(
            inputs, [0, slice_offset + f * self._frequency_skip],
            [-1, self._feature_size])
        block_inputs.append(cur_input)
      freq_inputs.append(block_inputs)
    else:
      # Fix: the original passed the lengths as extra exception arguments
      # (logging-style), so the "%d %d" placeholders were never interpolated.
      # Format the message explicitly with the % operator instead.
      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
        raise ValueError("Length of start and end freqindex_list"
                         " does not match %d %d" %
                         (len(self._start_freqindex_list),
                          len(self._end_freqindex_list)))
      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
        raise ValueError("Length of num_frequency_blocks"
                         " is not equal to start_freqindex_list %d %d" %
                         (len(self._num_frequency_blocks),
                          len(self._start_freqindex_list)))
      for b in range(len(self._start_freqindex_list)):
        start_index = self._start_freqindex_list[b]
        end_index = self._end_freqindex_list[b]
        cur_size = end_index - start_index
        block_feats = int(
            (cur_size - self._feature_size) / (self._frequency_skip)) + 1
        if block_feats != self._num_frequency_blocks[b]:
          raise ValueError(
              "Invalid num_frequency_blocks, requires %d but gets %d, please"
              " check the input size and filter config are correct." %
              (self._num_frequency_blocks[b], block_feats))
        block_inputs = []
        for f in range(block_feats):
          cur_input = array_ops.slice(
              inputs,
              [0, start_index + slice_offset + f * self._frequency_skip],
              [-1, self._feature_size])
          block_inputs.append(cur_input)
        freq_inputs.append(block_inputs)
    return freq_inputs
class BidirectionalGridLSTMCell(GridLSTMCell):
  """Bidirectional GridLstm cell.

  The bidirection connection is only used in the frequency direction, which
  hence doesn't affect the time direction's real-time processing that is
  required for online recognition systems.
  The current implementation uses different weights for the two directions.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               backward_slice_offset=0,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None, The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      backward_slice_offset: (optional) int32, default 0, the starting offset to
        slice the feature for backward processing.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    # state_is_tuple is forced to True for the bidirectional variant.
    super(BidirectionalGridLSTMCell, self).__init__(
        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
        couple_input_forget_gates, True, reuse)
    self._backward_slice_offset = int(backward_slice_offset)
    # Rebuild the state namedtuple with separate fwd_/bwd_ field sets,
    # doubling the state and output sizes computed by the base class.
    state_names = ""
    for direction in ["fwd", "bwd"]:
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
                                                  block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
    self._state_tuple_type = collections.namedtuple(
        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
    self._state_size = self._state_tuple_type(*(
        [num_units, num_units] * self._total_blocks * 2))
    self._output_size = 2 * num_units * self._total_blocks * 2
def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, num_units].
      state: tuple of Tensors, 2D, [batch, state_size].

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    # Prefer the statically-known batch size; fall back to the dynamic shape op
    # when the first dimension is unknown at graph-construction time.
    batch_size = tensor_shape.dimension_value(
        inputs.shape[0]) or array_ops.shape(inputs)[0]
    fwd_inputs = self._make_tf_features(inputs)
    if self._backward_slice_offset:
      # The backward pass may start from a shifted frequency offset; only
      # recompute the feature blocks when an offset is actually configured.
      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
    else:
      bwd_inputs = fwd_inputs

    # Forward processing
    with vs.variable_scope("fwd"):
      fwd_m_out_lst = []
      fwd_state_out_lst = []
      for block in range(len(fwd_inputs)):
        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
            fwd_inputs[block],
            block,
            state,
            batch_size,
            state_prefix="fwd_state",
            state_is_tuple=True)
        fwd_m_out_lst.extend(fwd_m_out_lst_current)
        fwd_state_out_lst.extend(fwd_state_out_lst_current)
    # Backward processing
    bwd_m_out_lst = []
    bwd_state_out_lst = []
    with vs.variable_scope("bwd"):
      for block in range(len(bwd_inputs)):
        # Reverse the blocks so the "bwd" direction walks the frequency
        # features in the opposite order of the forward pass.
        bwd_inputs_reverse = bwd_inputs[block][::-1]
        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
            bwd_inputs_reverse,
            block,
            state,
            batch_size,
            state_prefix="bwd_state",
            state_is_tuple=True)
        bwd_m_out_lst.extend(bwd_m_out_lst_current)
        bwd_state_out_lst.extend(bwd_state_out_lst_current)
    # fwd states first, then bwd — this must match the field order of the
    # namedtuple built in __init__ (direction-major, then block, then freq).
    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
    # Outputs are always concated as it is never used separately.
    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
    return m_out, state_out
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
# Module-level alias for the private fully-connected helper from
# core_rnn_cell, used by the wrapper cells below for input/attention/output
# projections. The pylint pragmas acknowledge the deliberate use of a
# protected name.
# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name
# pylint: enable=protected-access
|
|
1113
|
+
|
|
1114
|
+
|
|
1115
|
+
class AttentionCellWrapper(rnn_cell_impl.RNNCell):
  """Basic attention cell wrapper.

  Implementation based on https://arxiv.org/abs/1601.06733.
  """

  def __init__(self,
               cell,
               attn_length,
               attn_size=None,
               attn_vec_size=None,
               input_size=None,
               state_is_tuple=True,
               reuse=None):
    """Create a cell with attention.

    Args:
      cell: an RNNCell, an attention is added to it.
      attn_length: integer, the size of an attention window.
      attn_size: integer, the size of an attention vector. Equal to
        cell.output_size by default.
      attn_vec_size: integer, the number of convolutional features calculated
        on attention state and a size of the hidden layer built from
        base cell state. Equal to attn_size by default.
      input_size: integer, the size of a hidden linear layer,
        built from inputs and attention. Derived from the input tensor
        by default.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`. By default (False), the states are all
        concatenated along the column axis.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if cell returns a state tuple but the flag
        `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    # NOTE(review): the docstring above says state_is_tuple defaults to False,
    # but the signature default is True — the signature is authoritative.
    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
      raise ValueError(
          "Cell returns tuple of states, but the flag "
          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
    if attn_length <= 0:
      raise ValueError(
          "attn_length should be greater than zero, got %s" % str(attn_length))
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    if attn_size is None:
      attn_size = cell.output_size
    if attn_vec_size is None:
      attn_vec_size = attn_size
    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._attn_vec_size = attn_vec_size
    self._input_size = input_size
    self._attn_size = attn_size
    self._attn_length = attn_length
    self._reuse = reuse
    # Projection layers are created lazily on first call so their variable
    # shapes can be derived from the actual inputs.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    # (wrapped cell state, attention vector, flattened attention window).
    size = (self._cell.state_size, self._attn_size,
            self._attn_size * self._attn_length)
    if self._state_is_tuple:
      return size
    else:
      return sum(list(size))

  @property
  def output_size(self):
    return self._attn_size

  def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA)."""
    if self._state_is_tuple:
      state, attns, attn_states = state
    else:
      # Concatenated layout: [cell_state | attns | attn_states] along axis 1.
      states = state
      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
      attns = array_ops.slice(states, [0, self._cell.state_size],
                              [-1, self._attn_size])
      attn_states = array_ops.slice(
          states, [0, self._cell.state_size + self._attn_size],
          [-1, self._attn_size * self._attn_length])
    attn_states = array_ops.reshape(attn_states,
                                    [-1, self._attn_length, self._attn_size])
    input_size = self._input_size
    if input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    if self._linear1 is None:
      self._linear1 = _Linear([inputs, attns], input_size, True)
    # Mix the raw inputs with the previous attention vector before feeding
    # the wrapped cell.
    inputs = self._linear1([inputs, attns])
    cell_output, new_state = self._cell(inputs, state)
    if self._state_is_tuple:
      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
    else:
      new_state_cat = new_state
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
    with vs.variable_scope("attn_output_projection"):
      if self._linear2 is None:
        self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True)
      output = self._linear2([cell_output, new_attns])
    # Append the new output to the attention window (oldest entry was dropped
    # inside _attention), then flatten back to 2D.
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])
    new_state = (new_state, new_attns, new_attn_states)
    if not self._state_is_tuple:
      new_state = array_ops.concat(list(new_state), 1)
    return output, new_state

  def _attention(self, query, attn_states):
    # Bahdanau-style additive attention over the stored window, implemented
    # with a 1x1 convolution over attn_states plus a linear map of the query.
    conv2d = nn_ops.conv2d
    reduce_sum = math_ops.reduce_sum
    softmax = nn_ops.softmax
    tanh = math_ops.tanh

    with vs.variable_scope("attention"):
      k = vs.get_variable("attn_w",
                          [1, 1, self._attn_size, self._attn_vec_size])
      v = vs.get_variable("attn_v", [self._attn_vec_size])
      hidden = array_ops.reshape(attn_states,
                                 [-1, self._attn_length, 1, self._attn_size])
      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
      if self._linear3 is None:
        self._linear3 = _Linear(query, self._attn_vec_size, True)
      y = self._linear3(query)
      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
      # Unnormalized scores, then softmax over the window positions.
      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
      a = softmax(s)
      # Weighted sum of the window entries -> new attention vector.
      d = reduce_sum(
          array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
      new_attns = array_ops.reshape(d, [-1, self._attn_size])
      # Drop the oldest window entry; the caller appends the new output.
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
|
|
1257
|
+
|
|
1258
|
+
|
|
1259
|
+
class HighwayWrapper(rnn_cell_impl.RNNCell):
  """RNNCell wrapper that adds highway connection on cell input and output.

  Based on:
    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
    arXiv preprint arXiv:1505.00387, 2015.
    https://arxiv.org/abs/1505.00387
  """

  def __init__(self,
               cell,
               couple_carry_transform_gates=True,
               carry_bias_init=1.0):
    """Constructs a `HighwayWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean, should the Carry and Transform gate
        be coupled.
      carry_bias_init: float, carry gates bias initialization.
    """
    # NOTE(review): no super().__init__() call here, unlike the other cells in
    # this file — confirm this is intentional before changing it.
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    # Pure pass-through: state is entirely owned by the wrapped cell.
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    # Computes carry * inp + transform * out for one (input, output) pair.
    input_size = inp.get_shape().with_rank(2).dims[1].value
    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
    carry_bias = vs.get_variable(
        "carry_b", [input_size],
        initializer=init_ops.constant_initializer(self._carry_bias_init))
    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
    if self._couple_carry_transform_gates:
      # Coupled gates: transform = 1 - carry (fewer parameters).
      transform = 1 - carry
    else:
      # Independent transform gate; bias initialized to -carry_bias_init so
      # the wrapper initially favors carrying the input through.
      transform_weight = vs.get_variable("transform_w",
                                         [input_size, input_size])
      transform_bias = vs.get_variable(
          "transform_b", [input_size],
          initializer=init_ops.constant_initializer(-self._carry_bias_init))
      transform = math_ops.sigmoid(
          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
    return inp * carry + out * transform

  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, outputs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.map_structure(assert_shape_match, inputs, outputs)
    # Apply the highway connection leaf-by-leaf over the nested structure.
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)
|
|
1340
|
+
|
|
1341
|
+
|
|
1342
|
+
class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
  """LSTM unit with layer normalization and recurrent dropout.

  This class adds layer normalization and recurrent dropout to a
  basic LSTM unit. Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  Recurrent dropout is based on:

    https://arxiv.org/abs/1603.05118

  "Recurrent Dropout without Memory Loss"
  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               input_size=None,
               activation=math_ops.tanh,
               layer_norm=True,
               norm_gain=1.0,
               norm_shift=0.0,
               dropout_keep_prob=1.0,
               dropout_prob_seed=None,
               reuse=None):
    """Initializes the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      activation: Activation function of the inner states.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
        recurrent dropout probability value. If float and 1.0, no dropout will
        be applied.
      dropout_prob_seed: (optional) integer, the randomness seed.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)

    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)

    self._num_units = num_units
    self._activation = activation
    self._forget_bias = forget_bias
    self._keep_prob = dropout_keep_prob
    self._seed = dropout_prob_seed
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift
    self._reuse = reuse

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def _norm(self, inp, scope, dtype=dtypes.float32):
    # Layer-normalizes `inp` under `scope`, creating gamma/beta with the
    # configured gain/shift initializers the first time through.
    shape = inp.get_shape()[-1:]
    gamma_init = init_ops.constant_initializer(self._norm_gain)
    beta_init = init_ops.constant_initializer(self._norm_shift)
    with vs.variable_scope(scope):
      # Initialize beta and gamma for use by layer_norm.
      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
    # reuse=True makes layer_norm pick up the variables created just above.
    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
    return normalized

  def _linear(self, args):
    # Single matmul producing all four gate pre-activations at once.
    out_size = 4 * self._num_units
    proj_size = args.get_shape()[-1]
    dtype = args.dtype
    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
    out = math_ops.matmul(args, weights)
    if not self._layer_norm:
      # With layer norm, beta already provides a per-gate shift, so the
      # explicit bias is only added in the non-normalized path.
      bias = vs.get_variable("bias", [out_size], dtype=dtype)
      out = nn_ops.bias_add(out, bias)
    return out

  def call(self, inputs, state):
    """LSTM cell with layer normalization and recurrent dropout."""
    c, h = state
    args = array_ops.concat([inputs, h], 1)
    concat = self._linear(args)
    dtype = args.dtype

    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    if self._layer_norm:
      # Normalization is applied to the pre-activations, before the
      # nonlinearities (see class docstring).
      i = self._norm(i, "input", dtype=dtype)
      j = self._norm(j, "transform", dtype=dtype)
      f = self._norm(f, "forget", dtype=dtype)
      o = self._norm(o, "output", dtype=dtype)

    g = self._activation(j)
    # Skip the dropout op entirely when keep_prob is a plain float of 1.0;
    # Tensor-valued keep_prob always goes through dropout.
    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)

    new_c = (
        c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
    if self._layer_norm:
      new_c = self._norm(new_c, "state", dtype=dtype)
    new_h = self._activation(new_c) * math_ops.sigmoid(o)

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
|
|
1464
|
+
|
|
1465
|
+
|
|
1466
|
+
class NASCell(rnn_cell_impl.LayerRNNCell):
  """Neural Architecture Search (NAS) recurrent network cell.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.01578

  Barret Zoph and Quoc V. Le.
  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

  The class uses an optional projection layer.
  """

  # NAS cell's architecture base: number of branches the hidden state and
  # input projections are split into.
  _NAS_BASE = 8

  def __init__(self, num_units, num_proj=None, use_bias=False, reuse=None,
               **kwargs):
    """Initialize the parameters for a NAS cell.

    Args:
      num_units: int, The number of units in the NAS cell.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      use_bias: (optional) bool, If True then use biases within the cell. This
        is False by default.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      **kwargs: Additional keyword arguments.
    """
    super(NASCell, self).__init__(_reuse=reuse, **kwargs)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_bias = use_bias
    self._reuse = reuse

    # With a projection, the m (output) part of the state shrinks to num_proj.
    if num_proj is not None:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def build(self, inputs_shape):
    input_size = tensor_shape.dimension_value(
        tensor_shape.TensorShape(inputs_shape).with_rank(2)[1])
    if input_size is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    # Variables for the NAS cell. `recurrent_kernel` is all matrices multiplying
    # the hiddenstate and `kernel` is all matrices multiplying the inputs.
    self.recurrent_kernel = self.add_variable(
        "recurrent_kernel", [num_proj, self._NAS_BASE * self._num_units])
    self.kernel = self.add_variable(
        "kernel", [input_size, self._NAS_BASE * self._num_units])

    if self._use_bias:
      self.bias = self.add_variable("bias",
                                    shape=[self._NAS_BASE * self._num_units],
                                    initializer=init_ops.zeros_initializer)

    # Projection layer if specified
    if self._num_proj is not None:
      self.projection_weights = self.add_variable(
          "projection_weights", [self._num_units, self._num_proj])

    self.built = True

  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        NAS Cell after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
        when the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    relu = nn_ops.relu

    (c_prev, m_prev) = state

    m_matrix = math_ops.matmul(m_prev, self.recurrent_kernel)
    inputs_matrix = math_ops.matmul(inputs, self.kernel)

    if self._use_bias:
      m_matrix = nn_ops.bias_add(m_matrix, self.bias)

    # The NAS cell branches into 8 different splits for both the hiddenstate
    # and the input
    m_matrix_splits = array_ops.split(
        axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix)
    inputs_matrix_splits = array_ops.split(
        axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix)

    # First layer: the per-branch op (sigmoid/relu/tanh, add/multiply) pattern
    # is the fixed architecture discovered in the paper.
    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

    # Second layer
    l2_0 = tanh(layer1_0 * layer1_1)
    l2_1 = tanh(layer1_2 + layer1_3)
    l2_2 = tanh(layer1_4 * layer1_5)
    l2_3 = sigmoid(layer1_6 + layer1_7)

    # Inject the cell
    l2_0 = tanh(l2_0 + c_prev)

    # Third layer
    l3_0_pre = l2_0 * l2_1
    new_c = l3_0_pre  # create new cell
    l3_0 = l3_0_pre
    l3_1 = tanh(l2_2 + l2_3)

    # Final layer
    new_m = tanh(l3_0 * l3_1)

    # Projection layer if specified
    if self._num_proj is not None:
      new_m = math_ops.matmul(new_m, self.projection_weights)

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
    return new_m, new_state
|
|
1620
|
+
|
|
1621
|
+
|
|
1622
|
+
class UGRNNCell(rnn_cell_impl.RNNCell):
  """Update Gate Recurrent Neural Network (UGRNN) cell.

  Compromise between a LSTM/GRU and a vanilla RNN.  There is only one
  gate, and that is to determine whether the unit should be
  integrating or computing instantaneously.  This is the recurrent
  idea of the feedforward Highway Network.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
  """

  def __init__(self,
               num_units,
               initializer=None,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters for an UGRNN cell.

    Args:
      num_units: int, The number of units in the UGRNN cell
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gate, used to reduce the scale of forgetting at the beginning
        of the training.
      activation: (optional) Activation function of the inner states.
        Default is `tf.tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._activation = activation
    self._reuse = reuse
    # The gate/candidate projection is built lazily on the first call, once
    # the input size is known.
    self._linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Run one step of UGRNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_output: batch x num units, Tensor representing the output of the UGRNN
        after reading `inputs` when previous state was `state`. Identical to
        `new_state`.
      new_state: batch x num units, Tensor representing the state of the UGRNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear is None:
        # One projection produces both the gate (g) and candidate (c)
        # pre-activations.
        self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
      rnn_matrix = self._linear(cell_inputs)

      [g_act, c_act] = array_ops.split(
          axis=1, num_or_size_splits=2, value=rnn_matrix)

      c = self._activation(c_act)
      # forget_bias pushes the gate toward "integrate" (keep old state) early
      # in training.
      g = sigmoid(g_act + self._forget_bias)
      new_state = g * state + (1.0 - g) * c
      new_output = new_state

    return new_output, new_state
|
|
1714
|
+
|
|
1715
|
+
|
|
1716
|
+
class IntersectionRNNCell(rnn_cell_impl.RNNCell):
  """Intersection Recurrent Neural Network (+RNN) cell.

  Architecture with coupled recurrent gate as well as coupled depth
  gate, designed to improve information flow through stacked RNNs. As the
  architecture uses depth gating, the dimensionality of the depth
  output (y) also should not change through depth (input size == output size).
  To achieve this, the first layer of a stacked Intersection RNN projects
  the inputs to N (num units) dimensions. Therefore when initializing an
  IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
  and use default settings for subsequent layers.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.

  The Intersection RNN is built for use in deeply stacked
  RNNs so it may not achieve best performance with depth 1.
  """

  def __init__(self,
               num_units,
               num_in_proj=None,
               initializer=None,
               forget_bias=1.0,
               y_activation=nn_ops.relu,
               reuse=None):
    """Initialize the parameters for an +RNN cell.

    Args:
      num_units: int, The number of units in the +RNN cell
      num_in_proj: (optional) int, The input dimensionality for the RNN.
        If creating the first layer of an +RNN, this should be set to
        `num_units`. Otherwise, this should be set to `None` (default).
        If `None`, dimensionality of `inputs` should be equal to `num_units`,
        otherwise ValueError is thrown.
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      y_activation: (optional) Activation function of the states passed
        through depth. Default is `tf.nn.relu`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._num_input_proj = num_in_proj
    self._y_activation = y_activation
    self._reuse = reuse
    # Both _Linear layers are built lazily on the first call(), once the
    # actual input width is known.
    self._linear1 = None  # optional input projection (I --> N)
    self._linear2 = None  # joint recurrent-gate / depth-gate transform
    self._linear2 = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Run one step of the Intersection RNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_y: batch x num units, Tensor representing the output of the +RNN
        after reading `inputs` when previous state was `state`.
      new_state: batch x num units, Tensor representing the state of the +RNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from `inputs` via
        static shape inference.
      ValueError: If input size != output size (these must be equal when
        using the Intersection RNN).
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      # read-in projections (should be used for first layer in deep +RNN
      # to transform size of inputs from I --> N)
      if input_size.value != self._num_units:
        if self._num_input_proj:
          with vs.variable_scope("in_projection"):
            if self._linear1 is None:
              self._linear1 = _Linear(inputs, self._num_units, True)
            inputs = self._linear1(inputs)
        else:
          raise ValueError("Must have input size == output size for "
                           "Intersection RNN. To fix, num_in_proj should "
                           "be set to num_units at cell init.")

      # One fused linear transform produces all four pre-activations:
      # two gates (gh, gy) and two candidates (h, y).
      n_dim = i_dim = self._num_units
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear2 is None:
        self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
      rnn_matrix = self._linear2(cell_inputs)

      gh_act = rnn_matrix[:, :n_dim]  # b x n
      h_act = rnn_matrix[:, n_dim:2 * n_dim]  # b x n
      gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim]  # b x i
      y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim]  # b x i

      h = tanh(h_act)
      y = self._y_activation(y_act)
      gh = sigmoid(gh_act + self._forget_bias)
      gy = sigmoid(gy_act + self._forget_bias)

      new_state = gh * state + (1.0 - gh) * h  # passed thru time
      new_y = gy * inputs + (1.0 - gy) * y  # passed thru depth

    return new_y, new_state
# Lazily-populated module-level cache of registered op definitions, keyed by
# op name. Filled in on first use by the stateless-op predicate inside
# CompiledWrapper.__call__.
_REGISTERED_OPS = None
class CompiledWrapper(rnn_cell_impl.RNNCell):
  """Wraps step execution in an XLA JIT scope."""

  def __init__(self, cell, compile_stateful=False):
    """Create CompiledWrapper cell.

    Args:
      cell: Instance of `RNNCell`.
      compile_stateful: Whether to compile stateful ops like initializers
        and random number generators (default: False).
    """
    self._cell = cell
    self._compile_stateful = compile_stateful

  @property
  def state_size(self):
    # Pure delegation: the wrapper does not alter the cell's state layout.
    return self._cell.state_size

  @property
  def output_size(self):
    # Pure delegation: the wrapper does not alter the cell's output shape.
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    scope_name = type(self).__name__ + "ZeroState"
    with ops.name_scope(scope_name, values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
    if not self._compile_stateful:
      # Restrict compilation to stateless ops: consult the (lazily cached)
      # op-def registry and reject any op flagged as stateful.
      def compile_ops(node_def):
        global _REGISTERED_OPS
        if _REGISTERED_OPS is None:
          _REGISTERED_OPS = op_def_registry.get_registered_ops()
        return not _REGISTERED_OPS[node_def.op].is_stateful
    else:
      compile_ops = True

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
  """Returns an exponential distribution initializer.

  Args:
    minval: float or a scalar float Tensor. With value > 0. Lower bound of the
      range of random values to generate.
    maxval: float or a scalar float Tensor. With value > minval. Upper bound of
      the range of random values to generate.
    seed: An integer. Used to create random seeds.
    dtype: The data type.

  Returns:
    An initializer that generates tensors with an exponential distribution.
  """

  def _initializer(shape, dtype=dtype, partition_info=None):
    del partition_info  # Unused.
    # Drawing uniformly in log-space and exponentiating yields samples that
    # are log-uniform over [minval, maxval].
    log_low = math_ops.log(minval)
    log_high = math_ops.log(maxval)
    uniform_sample = random_ops.random_uniform(
        shape, log_low, log_high, dtype, seed=seed)
    return math_ops.exp(uniform_sample)

  return _initializer
class PhasedLSTMCell(rnn_cell_impl.RNNCell):
  """Phased LSTM recurrent network cell.

  https://arxiv.org/pdf/1610.09513v1.pdf
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               leak=0.001,
               ratio_on=0.1,
               trainable_ratio_on=True,
               period_init_min=1.0,
               period_init_max=1000.0,
               reuse=None):
    """Initialize the Phased LSTM cell.

    Args:
      num_units: int, The number of units in the Phased LSTM cell.
      use_peepholes: bool, set True to enable peephole connections.
      leak: float or scalar float Tensor with value in [0, 1]. Leak applied
        during training.
      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
        period during which the gates are open.
      trainable_ratio_on: bool, whether ratio_on is trainable.
      period_init_min: float or scalar float Tensor. With value > 0.
        Minimum value of the initialized period.
        The period values are initialized by drawing from the distribution:
        e^U(log(period_init_min), log(period_init_max))
        Where U(.,.) is the uniform distribution.
      period_init_max: float or scalar float Tensor.
        With value > period_init_min. Maximum value of the initialized period.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    # We pass autocast=False because this layer can accept inputs of different
    # dtypes, so we do not want to automatically cast them to the same dtype.
    super(PhasedLSTMCell, self).__init__(_reuse=reuse, autocast=False)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._leak = leak
    self._ratio_on = ratio_on
    self._trainable_ratio_on = trainable_ratio_on
    self._period_init_min = period_init_min
    self._period_init_max = period_init_max
    self._reuse = reuse
    # The three _Linear layers are created lazily in call() once input sizes
    # are known: mask gates, candidate input, and output gate respectively.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def _mod(self, x, y):
    """Modulo function that propagates x gradients."""
    # stop_gradient on (mod(x, y) - x) makes the whole expression equal
    # mod(x, y) in value while its gradient w.r.t. x is the identity.
    return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x

  def _get_cycle_ratio(self, time, phase, period):
    """Compute the cycle ratio in the dtype of the time."""
    # Cast the learned phase/period to the (possibly float64) time dtype so
    # the modulo arithmetic is done at full time precision.
    phase_casted = math_ops.cast(phase, dtype=time.dtype)
    period_casted = math_ops.cast(period, dtype=time.dtype)
    shifted_time = time - phase_casted
    cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
    return math_ops.cast(cycle_ratio, dtype=dtypes.float32)

  def call(self, inputs, state):
    """Phased LSTM Cell.

    Args:
      inputs: A tuple of 2 Tensor.
        The first Tensor has shape [batch, 1], and type float32 or float64.
        It stores the time.
        The second Tensor has shape [batch, features_size], and type float32.
        It stores the features.
      state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.

    Returns:
      A tuple containing:
      - A Tensor of float32, and shape [batch_size, num_units], representing the
        output of the cell.
      - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
        [batch_size, num_units], representing the new state and the output.
    """
    (c_prev, h_prev) = state
    (time, x) = inputs

    # Optional peepholes feed the previous cell state into the input/forget
    # gates (and the new cell state into the output gate below).
    in_mask_gates = [x, h_prev]
    if self._use_peepholes:
      in_mask_gates.append(c_prev)

    with vs.variable_scope("mask_gates"):
      if self._linear1 is None:
        self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)

      mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
      [input_gate, forget_gate] = array_ops.split(
          axis=1, num_or_size_splits=2, value=mask_gates)

    with vs.variable_scope("new_input"):
      if self._linear2 is None:
        self._linear2 = _Linear([x, h_prev], self._num_units, True)
      new_input = math_ops.tanh(self._linear2([x, h_prev]))

    # Standard LSTM cell update before the time gate is applied.
    new_c = (c_prev * forget_gate + input_gate * new_input)

    in_out_gate = [x, h_prev]
    if self._use_peepholes:
      in_out_gate.append(new_c)

    with vs.variable_scope("output_gate"):
      if self._linear3 is None:
        self._linear3 = _Linear(in_out_gate, self._num_units, True)
      output_gate = math_ops.sigmoid(self._linear3(in_out_gate))

    new_h = math_ops.tanh(new_c) * output_gate

    # Learned per-unit oscillation parameters for the time gate.
    period = vs.get_variable(
        "period", [self._num_units],
        initializer=_random_exp_initializer(self._period_init_min,
                                            self._period_init_max))
    phase = vs.get_variable(
        "phase", [self._num_units],
        initializer=init_ops.random_uniform_initializer(0.,
                                                        period.initial_value))
    ratio_on = vs.get_variable(
        "ratio_on", [self._num_units],
        initializer=init_ops.constant_initializer(self._ratio_on),
        trainable=self._trainable_ratio_on)

    cycle_ratio = self._get_cycle_ratio(time, phase, period)

    # Piecewise time gate k: rises linearly (k_up) in the first half of the
    # open phase, falls (k_down) in the second half, and leaks (k_closed)
    # while the gate is closed.
    k_up = 2 * cycle_ratio / ratio_on
    k_down = 2 - k_up
    k_closed = self._leak * cycle_ratio

    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)

    # Blend the proposed update with the previous state according to k.
    new_c = k * new_c + (1 - k) * c_prev
    new_h = k * new_h + (1 - k) * h_prev

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)

    return new_h, new_state
class ConvLSTMCell(rnn_cell_impl.RNNCell):
  """Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self,
               conv_ndims,
               input_shape,
               output_channels,
               kernel_shape,
               use_bias=True,
               skip_connection=False,
               forget_bias=1.0,
               initializers=None,
               name="conv_lstm_cell"):
    """Construct ConvLSTMCell.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as int tuple, excluding the batch size.
      output_channels: int, number of output channels of the conv LSTM.
      kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3).
      use_bias: (bool) Use bias in convolutions.
      skip_connection: If set to `True`, concatenate the input to the
        output of the conv LSTM. Default: `False`.
      forget_bias: Forget bias.
      initializers: Unused.
      name: Name of the module.

    Raises:
      ValueError: If `skip_connection` is `True` and stride is different from 1
        or if `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTMCell, self).__init__(name=name)

    # input_shape excludes batch but includes the channel dimension, so a
    # valid shape has exactly conv_ndims spatial dims plus one channel dim.
    if conv_ndims != len(input_shape) - 1:
      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
          input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = input_shape
    self._output_channels = output_channels
    self._kernel_shape = list(kernel_shape)
    self._use_bias = use_bias
    self._forget_bias = forget_bias
    self._skip_connection = skip_connection

    # With a skip connection, the input channels are concatenated onto the
    # output, enlarging the output channel count.
    self._total_output_channels = output_channels
    if self._skip_connection:
      self._total_output_channels += self._input_shape[-1]

    state_size = tensor_shape.TensorShape(
        self._input_shape[:-1] + [self._output_channels])
    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
    self._output_size = tensor_shape.TensorShape(
        self._input_shape[:-1] + [self._total_output_channels])

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state, scope=None):
    cell, hidden = state
    # A single convolution over [inputs, hidden] produces all four gate
    # pre-activations at once (4 * output_channels along the channel axis).
    new_hidden = _conv([inputs, hidden], self._kernel_shape,
                       4 * self._output_channels, self._use_bias)
    gates = array_ops.split(
        value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)

    input_gate, new_input, forget_gate, output_gate = gates
    # Standard LSTM gate arithmetic, applied elementwise over the feature map.
    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)

    if self._skip_connection:
      output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state
class Conv1DLSTMCell(ConvLSTMCell):
  """Convolutional LSTM cell over 1D (sequence-shaped) inputs.

  Thin specialization of `ConvLSTMCell` with `conv_ndims` fixed at 1.
  Reference: https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_1d_lstm_cell", **kwargs):
    """Construct a Conv1DLSTM cell; see `ConvLSTMCell` for the arguments."""
    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs)
class Conv2DLSTMCell(ConvLSTMCell):
  """Convolutional LSTM cell over 2D (image-shaped) inputs.

  Thin specialization of `ConvLSTMCell` with `conv_ndims` fixed at 2.
  Reference: https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
    """Construct a Conv2DLSTM cell; see `ConvLSTMCell` for the arguments."""
    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)
class Conv3DLSTMCell(ConvLSTMCell):
  """Convolutional LSTM cell over 3D (volumetric) inputs.

  Thin specialization of `ConvLSTMCell` with `conv_ndims` fixed at 3.
  Reference: https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
    """Construct a Conv3DLSTM cell; see `ConvLSTMCell` for the arguments."""
    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)
def _conv(args, filter_size, num_features, bias, bias_start=0.0):
  """Convolution.

  Args:
    args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
    batch x n, Tensors.
    filter_size: int tuple of filter shape (of size 1, 2 or 3).
    num_features: int, number of features.
    bias: Whether to use biases in the convolution layer.
    bias_start: starting value to initialize the bias; 0 by default.

  Returns:
    A 3D, 4D, or 5D Tensor with shape [batch ... num_features]

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """
  # Validate ranks and accumulate the total depth (last dimension) across
  # all arguments; the inputs are concatenated along that dimension.
  shapes = [tensor.get_shape().as_list() for tensor in args]
  ndims = len(shapes[0])
  total_arg_size_depth = 0
  for shape in shapes:
    if len(shape) not in [3, 4, 5]:
      raise ValueError("Conv Linear expects 3D, 4D "
                       "or 5D arguments: %s" % str(shapes))
    if len(shape) != ndims:
      raise ValueError("Conv Linear expects all args "
                       "to be of same Dimension: %s" % str(shapes))
    total_arg_size_depth += shape[-1]
  dtype = args[0].dtype

  # Pick the convolution op matching the rank of the inputs.
  if ndims == 3:
    conv_op = nn_ops.conv1d
    strides = 1
  elif ndims == 4:
    conv_op = nn_ops.conv2d
    strides = [1] * ndims
  else:  # ndims == 5, guaranteed by the validation loop above.
    conv_op = nn_ops.conv3d
    strides = [1] * ndims

  # Now the computation.
  kernel = vs.get_variable(
      "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
  if len(args) == 1:
    conv_input = args[0]
  else:
    conv_input = array_ops.concat(axis=ndims - 1, values=args)
  res = conv_op(conv_input, kernel, strides, padding="SAME")
  if not bias:
    return res

  bias_term = vs.get_variable(
      "biases", [num_features],
      dtype=dtype,
      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
  return res + bias_term
class GLSTMCell(rnn_cell_impl.RNNCell):
  """Group LSTM cell (G-LSTM).

  The implementation is based on:

    https://arxiv.org/abs/1703.10722

  O. Kuchaiev and B. Ginsburg
  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.

  In brief, a G-LSTM cell consists of one LSTM sub-cell per group, where each
  sub-cell operates on an evenly-sized sub-vector of the input and produces an
  evenly-sized sub-vector of the output. For example, a G-LSTM cell with 128
  units and 4 groups consists of 4 LSTMs sub-cells with 32 units each. If that
  G-LSTM cell is fed a 200-dim input, then each sub-cell receives a 50-dim part
  of the input and produces a 32-dim part of the output.
  """

  def __init__(self,
               num_units,
               initializer=None,
               num_proj=None,
               number_of_groups=1,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters of G-LSTM cell.

    Args:
      num_units: int, The number of units in the G-LSTM cell
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, then it should be equivalent to LSTM cell
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      ValueError: If `num_units` or `num_proj` is not divisible by
        `number_of_groups`.
    """
    super(GLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._num_proj = num_proj
    self._forget_bias = forget_bias
    self._activation = activation
    self._number_of_groups = number_of_groups

    # Units (and the projection, if present) must split evenly across groups.
    if self._num_units % self._number_of_groups != 0:
      raise ValueError("num_units must be divisible by number_of_groups")
    if self._num_proj:
      if self._num_proj % self._number_of_groups != 0:
        raise ValueError("num_proj must be divisible by number_of_groups")
      # _group_shape = [per-group output width, per-group state width].
      self._group_shape = [
          int(self._num_proj / self._number_of_groups),
          int(self._num_units / self._number_of_groups)
      ]
    else:
      self._group_shape = [
          int(self._num_units / self._number_of_groups),
          int(self._num_units / self._number_of_groups)
      ]

    if num_proj:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units
    # One lazily-built _Linear per group, plus an optional projection layer.
    self._linear1 = [None] * number_of_groups
    self._linear2 = None

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices inputs into groups to prepare for processing by cell's groups.

    NOTE(review): this relies on `self._batch_size`, which is assigned by
    `call()` — the method is only meaningful during a `call()` invocation.

    Args:
      inputs: cell input or its previous state,
        a Tensor, 2D, [batch x num_units]
      group_id: group id, a Scalar, for which to prepare input
      group_size: size of the group

    Returns:
      subset of inputs corresponding to group "group_id",
      a Tensor, 2D, [batch x num_units/number_of_groups]
    """
    return array_ops.slice(
        input_=inputs,
        begin=[0, group_id * group_size],
        size=[self._batch_size, group_size],
        name=("GLSTM_group%d_input_generation" % group_id))

  def call(self, inputs, state):
    """Run one step of G-LSTM.

    Args:
      inputs: input Tensor, 2D, [batch x num_inputs]. num_inputs must be
        statically-known and evenly divisible into groups. The innermost
        vectors of the inputs are split into evenly-sized sub-vectors and fed
        into the per-group LSTM sub-cells.
      state: this must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        G-LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
          num_proj if num_proj was set,
          num_units otherwise.
      - LSTMStateTuple representing the new state of G-LSTM cell
        after reading `inputs` when the previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference, or if the input shape is incompatible
        with the number of groups.
    """
    (c_prev, m_prev) = state

    # Batch size may be dynamic; fall back to a runtime shape op when the
    # static dimension is unknown.
    self._batch_size = tensor_shape.dimension_value(
        inputs.shape[0]) or array_ops.shape(inputs)[0]

    # If the input size is statically-known, calculate and validate its group
    # size. Otherwise, use the output group size.
    input_size = tensor_shape.dimension_value(inputs.shape[1])
    if input_size is None:
      raise ValueError("input size must be statically known")
    if input_size % self._number_of_groups != 0:
      raise ValueError(
          "input size (%d) must be divisible by number_of_groups (%d)" %
          (input_size, self._number_of_groups))
    input_group_size = int(input_size / self._number_of_groups)

    dtype = inputs.dtype
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer):
      i_parts = []
      j_parts = []
      f_parts = []
      o_parts = []

      for group_id in range(self._number_of_groups):
        with vs.variable_scope("group%d" % group_id):
          # Each group sees its own slice of the input concatenated with its
          # slice of the previous output (m_prev).
          x_g_id = array_ops.concat(
              [
                  self._get_input_for_group(inputs, group_id, input_group_size),
                  self._get_input_for_group(m_prev, group_id,
                                            self._group_shape[0])
              ],
              axis=1)
          linear = self._linear1[group_id]
          if linear is None:
            linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
            self._linear1[group_id] = linear
          # One fused transform per group yields all four gate pre-activations.
          R_k = linear(x_g_id)  # pylint: disable=invalid-name
          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)

        i_parts.append(i_k)
        j_parts.append(j_k)
        f_parts.append(f_k)
        o_parts.append(o_k)

      # Biases are full-width (num_units) and shared across groups; they are
      # added after the per-group parts are concatenated back together.
      bi = vs.get_variable(
          name="bias_i",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bj = vs.get_variable(
          name="bias_j",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bf = vs.get_variable(
          name="bias_f",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bo = vs.get_variable(
          name="bias_o",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))

      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)

    # Standard LSTM update over the recombined full-width gates.
    c = (
        math_ops.sigmoid(f + self._forget_bias) * c_prev +
        math_ops.sigmoid(i) * math_ops.tanh(j))
    m = math_ops.sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      with vs.variable_scope("projection"):
        if self._linear2 is None:
          self._linear2 = _Linear(m, self._num_proj, False)
        m = self._linear2(m)

    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
    return m, new_state
class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
|
|
2469
|
+
"""Long short-term memory unit (LSTM) recurrent network cell.
|
|
2470
|
+
|
|
2471
|
+
The default non-peephole implementation is based on:
|
|
2472
|
+
|
|
2473
|
+
https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
|
|
2474
|
+
|
|
2475
|
+
Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
|
|
2476
|
+
"Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
|
|
2477
|
+
|
|
2478
|
+
The peephole implementation is based on:
|
|
2479
|
+
|
|
2480
|
+
https://research.google.com/pubs/archive/43905.pdf
|
|
2481
|
+
|
|
2482
|
+
Hasim Sak, Andrew Senior, and Francoise Beaufays.
|
|
2483
|
+
"Long short-term memory recurrent neural network architectures for
|
|
2484
|
+
large scale acoustic modeling." INTERSPEECH, 2014.
|
|
2485
|
+
|
|
2486
|
+
The class uses optional peep-hole connections, optional cell clipping, and
|
|
2487
|
+
an optional projection layer.
|
|
2488
|
+
|
|
2489
|
+
Layer normalization implementation is based on:
|
|
2490
|
+
|
|
2491
|
+
https://arxiv.org/abs/1607.06450.
|
|
2492
|
+
|
|
2493
|
+
"Layer Normalization"
|
|
2494
|
+
Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
|
|
2495
|
+
|
|
2496
|
+
and is applied before the internal nonlinearities.
|
|
2497
|
+
|
|
2498
|
+
"""
|
|
2499
|
+
|
|
2500
|
+
def __init__(self,
|
|
2501
|
+
num_units,
|
|
2502
|
+
use_peepholes=False,
|
|
2503
|
+
cell_clip=None,
|
|
2504
|
+
initializer=None,
|
|
2505
|
+
num_proj=None,
|
|
2506
|
+
proj_clip=None,
|
|
2507
|
+
forget_bias=1.0,
|
|
2508
|
+
activation=None,
|
|
2509
|
+
layer_norm=False,
|
|
2510
|
+
norm_gain=1.0,
|
|
2511
|
+
norm_shift=0.0,
|
|
2512
|
+
reuse=None):
|
|
2513
|
+
"""Initialize the parameters for an LSTM cell.
|
|
2514
|
+
|
|
2515
|
+
Args:
|
|
2516
|
+
num_units: int, The number of units in the LSTM cell
|
|
2517
|
+
use_peepholes: bool, set True to enable diagonal/peephole connections.
|
|
2518
|
+
cell_clip: (optional) A float value, if provided the cell state is clipped
|
|
2519
|
+
by this value prior to the cell output activation.
|
|
2520
|
+
initializer: (optional) The initializer to use for the weight and
|
|
2521
|
+
projection matrices.
|
|
2522
|
+
num_proj: (optional) int, The output dimensionality for the projection
|
|
2523
|
+
matrices. If None, no projection is performed.
|
|
2524
|
+
proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
|
|
2525
|
+
provided, then the projected values are clipped elementwise to within
|
|
2526
|
+
`[-proj_clip, proj_clip]`.
|
|
2527
|
+
forget_bias: Biases of the forget gate are initialized by default to 1
|
|
2528
|
+
in order to reduce the scale of forgetting at the beginning of
|
|
2529
|
+
the training. Must set it manually to `0.0` when restoring from
|
|
2530
|
+
CudnnLSTM trained checkpoints.
|
|
2531
|
+
activation: Activation function of the inner states. Default: `tanh`.
|
|
2532
|
+
layer_norm: If `True`, layer normalization will be applied.
|
|
2533
|
+
norm_gain: float, The layer normalization gain initial value. If
|
|
2534
|
+
`layer_norm` has been set to `False`, this argument will be ignored.
|
|
2535
|
+
norm_shift: float, The layer normalization shift initial value. If
|
|
2536
|
+
`layer_norm` has been set to `False`, this argument will be ignored.
|
|
2537
|
+
reuse: (optional) Python boolean describing whether to reuse variables
|
|
2538
|
+
in an existing scope. If not `True`, and the existing scope already has
|
|
2539
|
+
the given variables, an error is raised.
|
|
2540
|
+
|
|
2541
|
+
When restoring from CudnnLSTM-trained checkpoints, must use
|
|
2542
|
+
CudnnCompatibleLSTMCell instead.
|
|
2543
|
+
"""
|
|
2544
|
+
super(LayerNormLSTMCell, self).__init__(_reuse=reuse)
|
|
2545
|
+
|
|
2546
|
+
self._num_units = num_units
|
|
2547
|
+
self._use_peepholes = use_peepholes
|
|
2548
|
+
self._cell_clip = cell_clip
|
|
2549
|
+
self._initializer = initializer
|
|
2550
|
+
self._num_proj = num_proj
|
|
2551
|
+
self._proj_clip = proj_clip
|
|
2552
|
+
self._forget_bias = forget_bias
|
|
2553
|
+
self._activation = activation or math_ops.tanh
|
|
2554
|
+
self._layer_norm = layer_norm
|
|
2555
|
+
self._norm_gain = norm_gain
|
|
2556
|
+
self._norm_shift = norm_shift
|
|
2557
|
+
|
|
2558
|
+
if num_proj:
|
|
2559
|
+
self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj))
|
|
2560
|
+
self._output_size = num_proj
|
|
2561
|
+
else:
|
|
2562
|
+
self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units))
|
|
2563
|
+
self._output_size = num_units
|
|
2564
|
+
|
|
2565
|
+
@property
|
|
2566
|
+
def state_size(self):
|
|
2567
|
+
return self._state_size
|
|
2568
|
+
|
|
2569
|
+
@property
|
|
2570
|
+
def output_size(self):
|
|
2571
|
+
return self._output_size
|
|
2572
|
+
|
|
2573
|
+
def _linear(self,
|
|
2574
|
+
args,
|
|
2575
|
+
output_size,
|
|
2576
|
+
bias,
|
|
2577
|
+
bias_initializer=None,
|
|
2578
|
+
kernel_initializer=None,
|
|
2579
|
+
layer_norm=False):
|
|
2580
|
+
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable.
|
|
2581
|
+
|
|
2582
|
+
Args:
|
|
2583
|
+
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
|
|
2584
|
+
output_size: int, second dimension of W[i].
|
|
2585
|
+
bias: boolean, whether to add a bias term or not.
|
|
2586
|
+
bias_initializer: starting value to initialize the bias
|
|
2587
|
+
(default is all zeros).
|
|
2588
|
+
kernel_initializer: starting value to initialize the weight.
|
|
2589
|
+
layer_norm: boolean, whether to apply layer normalization.
|
|
2590
|
+
|
|
2591
|
+
|
|
2592
|
+
Returns:
|
|
2593
|
+
A 2D Tensor with shape [batch x output_size] taking value
|
|
2594
|
+
sum_i(args[i] * W[i]), where each W[i] is a newly created Variable.
|
|
2595
|
+
|
|
2596
|
+
Raises:
|
|
2597
|
+
ValueError: if some of the arguments has unspecified or wrong shape.
|
|
2598
|
+
"""
|
|
2599
|
+
if args is None or (nest.is_sequence(args) and not args):
|
|
2600
|
+
raise ValueError("`args` must be specified")
|
|
2601
|
+
if not nest.is_sequence(args):
|
|
2602
|
+
args = [args]
|
|
2603
|
+
|
|
2604
|
+
# Calculate the total size of arguments on dimension 1.
|
|
2605
|
+
total_arg_size = 0
|
|
2606
|
+
shapes = [a.get_shape() for a in args]
|
|
2607
|
+
for shape in shapes:
|
|
2608
|
+
if shape.ndims != 2:
|
|
2609
|
+
raise ValueError("linear is expecting 2D arguments: %s" % shapes)
|
|
2610
|
+
if tensor_shape.dimension_value(shape[1]) is None:
|
|
2611
|
+
raise ValueError("linear expects shape[1] to be provided for shape %s, "
|
|
2612
|
+
"but saw %s" % (shape, shape[1]))
|
|
2613
|
+
else:
|
|
2614
|
+
total_arg_size += tensor_shape.dimension_value(shape[1])
|
|
2615
|
+
|
|
2616
|
+
dtype = [a.dtype for a in args][0]
|
|
2617
|
+
|
|
2618
|
+
# Now the computation.
|
|
2619
|
+
scope = vs.get_variable_scope()
|
|
2620
|
+
with vs.variable_scope(scope) as outer_scope:
|
|
2621
|
+
weights = vs.get_variable(
|
|
2622
|
+
"kernel", [total_arg_size, output_size],
|
|
2623
|
+
dtype=dtype,
|
|
2624
|
+
initializer=kernel_initializer)
|
|
2625
|
+
if len(args) == 1:
|
|
2626
|
+
res = math_ops.matmul(args[0], weights)
|
|
2627
|
+
else:
|
|
2628
|
+
res = math_ops.matmul(array_ops.concat(args, 1), weights)
|
|
2629
|
+
if not bias:
|
|
2630
|
+
return res
|
|
2631
|
+
with vs.variable_scope(outer_scope) as inner_scope:
|
|
2632
|
+
inner_scope.set_partitioner(None)
|
|
2633
|
+
if bias_initializer is None:
|
|
2634
|
+
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
|
|
2635
|
+
biases = vs.get_variable(
|
|
2636
|
+
"bias", [output_size], dtype=dtype, initializer=bias_initializer)
|
|
2637
|
+
|
|
2638
|
+
if not layer_norm:
|
|
2639
|
+
res = nn_ops.bias_add(res, biases)
|
|
2640
|
+
|
|
2641
|
+
return res
|
|
2642
|
+
|
|
2643
|
+
def call(self, inputs, state):
|
|
2644
|
+
"""Run one step of LSTM.
|
|
2645
|
+
|
|
2646
|
+
Args:
|
|
2647
|
+
inputs: input Tensor, 2D, batch x num_units.
|
|
2648
|
+
state: this must be a tuple of state Tensors,
|
|
2649
|
+
both `2-D`, with column sizes `c_state` and
|
|
2650
|
+
`m_state`.
|
|
2651
|
+
|
|
2652
|
+
Returns:
|
|
2653
|
+
A tuple containing:
|
|
2654
|
+
|
|
2655
|
+
- A `2-D, [batch x output_dim]`, Tensor representing the output of the
|
|
2656
|
+
LSTM after reading `inputs` when previous state was `state`.
|
|
2657
|
+
Here output_dim is:
|
|
2658
|
+
num_proj if num_proj was set,
|
|
2659
|
+
num_units otherwise.
|
|
2660
|
+
- Tensor(s) representing the new state of LSTM after reading `inputs` when
|
|
2661
|
+
the previous state was `state`. Same type and shape(s) as `state`.
|
|
2662
|
+
|
|
2663
|
+
Raises:
|
|
2664
|
+
ValueError: If input size cannot be inferred from inputs via
|
|
2665
|
+
static shape inference.
|
|
2666
|
+
"""
|
|
2667
|
+
sigmoid = math_ops.sigmoid
|
|
2668
|
+
|
|
2669
|
+
(c_prev, m_prev) = state
|
|
2670
|
+
|
|
2671
|
+
dtype = inputs.dtype
|
|
2672
|
+
input_size = inputs.get_shape().with_rank(2).dims[1]
|
|
2673
|
+
if input_size.value is None:
|
|
2674
|
+
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
|
|
2675
|
+
scope = vs.get_variable_scope()
|
|
2676
|
+
with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
|
|
2677
|
+
|
|
2678
|
+
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
|
|
2679
|
+
lstm_matrix = self._linear(
|
|
2680
|
+
[inputs, m_prev],
|
|
2681
|
+
4 * self._num_units,
|
|
2682
|
+
bias=True,
|
|
2683
|
+
bias_initializer=None,
|
|
2684
|
+
layer_norm=self._layer_norm)
|
|
2685
|
+
i, j, f, o = array_ops.split(
|
|
2686
|
+
value=lstm_matrix, num_or_size_splits=4, axis=1)
|
|
2687
|
+
|
|
2688
|
+
if self._layer_norm:
|
|
2689
|
+
i = _norm(self._norm_gain, self._norm_shift, i, "input")
|
|
2690
|
+
j = _norm(self._norm_gain, self._norm_shift, j, "transform")
|
|
2691
|
+
f = _norm(self._norm_gain, self._norm_shift, f, "forget")
|
|
2692
|
+
o = _norm(self._norm_gain, self._norm_shift, o, "output")
|
|
2693
|
+
|
|
2694
|
+
# Diagonal connections
|
|
2695
|
+
if self._use_peepholes:
|
|
2696
|
+
with vs.variable_scope(unit_scope):
|
|
2697
|
+
w_f_diag = vs.get_variable(
|
|
2698
|
+
"w_f_diag", shape=[self._num_units], dtype=dtype)
|
|
2699
|
+
w_i_diag = vs.get_variable(
|
|
2700
|
+
"w_i_diag", shape=[self._num_units], dtype=dtype)
|
|
2701
|
+
w_o_diag = vs.get_variable(
|
|
2702
|
+
"w_o_diag", shape=[self._num_units], dtype=dtype)
|
|
2703
|
+
|
|
2704
|
+
if self._use_peepholes:
|
|
2705
|
+
c = (
|
|
2706
|
+
sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
|
|
2707
|
+
sigmoid(i + w_i_diag * c_prev) * self._activation(j))
|
|
2708
|
+
else:
|
|
2709
|
+
c = (
|
|
2710
|
+
sigmoid(f + self._forget_bias) * c_prev +
|
|
2711
|
+
sigmoid(i) * self._activation(j))
|
|
2712
|
+
|
|
2713
|
+
if self._layer_norm:
|
|
2714
|
+
c = _norm(self._norm_gain, self._norm_shift, c, "state")
|
|
2715
|
+
|
|
2716
|
+
if self._cell_clip is not None:
|
|
2717
|
+
# pylint: disable=invalid-unary-operand-type
|
|
2718
|
+
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
|
|
2719
|
+
# pylint: enable=invalid-unary-operand-type
|
|
2720
|
+
if self._use_peepholes:
|
|
2721
|
+
m = sigmoid(o + w_o_diag * c) * self._activation(c)
|
|
2722
|
+
else:
|
|
2723
|
+
m = sigmoid(o) * self._activation(c)
|
|
2724
|
+
|
|
2725
|
+
if self._num_proj is not None:
|
|
2726
|
+
with vs.variable_scope("projection"):
|
|
2727
|
+
m = self._linear(m, self._num_proj, bias=False)
|
|
2728
|
+
|
|
2729
|
+
if self._proj_clip is not None:
|
|
2730
|
+
# pylint: disable=invalid-unary-operand-type
|
|
2731
|
+
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
|
|
2732
|
+
# pylint: enable=invalid-unary-operand-type
|
|
2733
|
+
|
|
2734
|
+
new_state = (rnn_cell_impl.LSTMStateTuple(c, m))
|
|
2735
|
+
return m, new_state
|
|
2736
|
+
|
|
2737
|
+
|
|
2738
|
+
class SRUCell(rnn_cell_impl.LayerRNNCell):
|
|
2739
|
+
"""SRU, Simple Recurrent Unit.
|
|
2740
|
+
|
|
2741
|
+
Implementation based on
|
|
2742
|
+
Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).
|
|
2743
|
+
|
|
2744
|
+
This variation of RNN cell is characterized by the simplified data
|
|
2745
|
+
dependence
|
|
2746
|
+
between hidden states of two consecutive time steps. Traditionally, hidden
|
|
2747
|
+
states from a cell at time step t-1 needs to be multiplied with a matrix
|
|
2748
|
+
W_hh before being fed into the ensuing cell at time step t.
|
|
2749
|
+
This flavor of RNN replaces the matrix multiplication between h_{t-1}
|
|
2750
|
+
and W_hh with a pointwise multiplication, resulting in performance
|
|
2751
|
+
gain.
|
|
2752
|
+
|
|
2753
|
+
Args:
|
|
2754
|
+
num_units: int, The number of units in the SRU cell.
|
|
2755
|
+
activation: Nonlinearity to use. Default: `tanh`.
|
|
2756
|
+
reuse: (optional) Python boolean describing whether to reuse variables
|
|
2757
|
+
in an existing scope. If not `True`, and the existing scope already has
|
|
2758
|
+
the given variables, an error is raised.
|
|
2759
|
+
name: (optional) String, the name of the layer. Layers with the same name
|
|
2760
|
+
will share weights, but to avoid mistakes we require reuse=True in such
|
|
2761
|
+
cases.
|
|
2762
|
+
**kwargs: Additional keyword arguments.
|
|
2763
|
+
"""
|
|
2764
|
+
|
|
2765
|
+
def __init__(self, num_units, activation=None, reuse=None, name=None,
|
|
2766
|
+
**kwargs):
|
|
2767
|
+
super(SRUCell, self).__init__(_reuse=reuse, name=name, **kwargs)
|
|
2768
|
+
self._num_units = num_units
|
|
2769
|
+
self._activation = activation or math_ops.tanh
|
|
2770
|
+
|
|
2771
|
+
# Restrict inputs to be 2-dimensional matrices
|
|
2772
|
+
self.input_spec = input_spec.InputSpec(ndim=2)
|
|
2773
|
+
|
|
2774
|
+
@property
|
|
2775
|
+
def state_size(self):
|
|
2776
|
+
return self._num_units
|
|
2777
|
+
|
|
2778
|
+
@property
|
|
2779
|
+
def output_size(self):
|
|
2780
|
+
return self._num_units
|
|
2781
|
+
|
|
2782
|
+
def build(self, inputs_shape):
|
|
2783
|
+
if tensor_shape.dimension_value(inputs_shape[1]) is None:
|
|
2784
|
+
raise ValueError(
|
|
2785
|
+
"Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
|
|
2786
|
+
|
|
2787
|
+
input_depth = tensor_shape.dimension_value(inputs_shape[1])
|
|
2788
|
+
|
|
2789
|
+
# pylint: disable=protected-access
|
|
2790
|
+
self._kernel = self.add_variable(
|
|
2791
|
+
rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
|
2792
|
+
shape=[input_depth, 4 * self._num_units])
|
|
2793
|
+
# pylint: enable=protected-access
|
|
2794
|
+
self._bias = self.add_variable(
|
|
2795
|
+
rnn_cell_impl._BIAS_VARIABLE_NAME, # pylint: disable=protected-access
|
|
2796
|
+
shape=[2 * self._num_units],
|
|
2797
|
+
initializer=init_ops.zeros_initializer)
|
|
2798
|
+
|
|
2799
|
+
self._built = True
|
|
2800
|
+
|
|
2801
|
+
def call(self, inputs, state):
|
|
2802
|
+
"""Simple recurrent unit (SRU) with num_units cells."""
|
|
2803
|
+
|
|
2804
|
+
U = math_ops.matmul(inputs, self._kernel) # pylint: disable=invalid-name
|
|
2805
|
+
x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split(
|
|
2806
|
+
value=U, num_or_size_splits=4, axis=1)
|
|
2807
|
+
|
|
2808
|
+
f_r = math_ops.sigmoid(
|
|
2809
|
+
nn_ops.bias_add(
|
|
2810
|
+
array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
|
|
2811
|
+
f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
|
|
2812
|
+
|
|
2813
|
+
c = f * state + (1.0 - f) * x_bar
|
|
2814
|
+
h = r * self._activation(c) + (1.0 - r) * x_tx
|
|
2815
|
+
|
|
2816
|
+
return h, c
|
|
2817
|
+
|
|
2818
|
+
|
|
2819
|
+
class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
|
|
2820
|
+
"""Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.
|
|
2821
|
+
|
|
2822
|
+
The weight-norm implementation is based on:
|
|
2823
|
+
https://arxiv.org/abs/1602.07868
|
|
2824
|
+
Tim Salimans, Diederik P. Kingma.
|
|
2825
|
+
Weight Normalization: A Simple Reparameterization to Accelerate
|
|
2826
|
+
Training of Deep Neural Networks
|
|
2827
|
+
|
|
2828
|
+
The default LSTM implementation based on:
|
|
2829
|
+
|
|
2830
|
+
https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
|
|
2831
|
+
|
|
2832
|
+
Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
|
|
2833
|
+
"Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
|
|
2834
|
+
|
|
2835
|
+
The class uses optional peephole connections, optional cell clipping
|
|
2836
|
+
and an optional projection layer.
|
|
2837
|
+
|
|
2838
|
+
The optional peephole implementation is based on:
|
|
2839
|
+
https://research.google.com/pubs/archive/43905.pdf
|
|
2840
|
+
Hasim Sak, Andrew Senior, and Francoise Beaufays.
|
|
2841
|
+
"Long short-term memory recurrent neural network architectures for
|
|
2842
|
+
large scale acoustic modeling." INTERSPEECH, 2014.
|
|
2843
|
+
"""
|
|
2844
|
+
|
|
2845
|
+
def __init__(self,
|
|
2846
|
+
num_units,
|
|
2847
|
+
norm=True,
|
|
2848
|
+
use_peepholes=False,
|
|
2849
|
+
cell_clip=None,
|
|
2850
|
+
initializer=None,
|
|
2851
|
+
num_proj=None,
|
|
2852
|
+
proj_clip=None,
|
|
2853
|
+
forget_bias=1,
|
|
2854
|
+
activation=None,
|
|
2855
|
+
reuse=None):
|
|
2856
|
+
"""Initialize the parameters of a weight-normalized LSTM cell.
|
|
2857
|
+
|
|
2858
|
+
Args:
|
|
2859
|
+
num_units: int, The number of units in the LSTM cell
|
|
2860
|
+
norm: If `True`, apply normalization to the weight matrices. If False,
|
|
2861
|
+
the result is identical to that obtained from `rnn_cell_impl.LSTMCell`
|
|
2862
|
+
use_peepholes: bool, set `True` to enable diagonal/peephole connections.
|
|
2863
|
+
cell_clip: (optional) A float value, if provided the cell state is clipped
|
|
2864
|
+
by this value prior to the cell output activation.
|
|
2865
|
+
initializer: (optional) The initializer to use for the weight matrices.
|
|
2866
|
+
num_proj: (optional) int, The output dimensionality for the projection
|
|
2867
|
+
matrices. If None, no projection is performed.
|
|
2868
|
+
proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
|
|
2869
|
+
provided, then the projected values are clipped elementwise to within
|
|
2870
|
+
`[-proj_clip, proj_clip]`.
|
|
2871
|
+
forget_bias: Biases of the forget gate are initialized by default to 1
|
|
2872
|
+
in order to reduce the scale of forgetting at the beginning of
|
|
2873
|
+
the training.
|
|
2874
|
+
activation: Activation function of the inner states. Default: `tanh`.
|
|
2875
|
+
reuse: (optional) Python boolean describing whether to reuse variables
|
|
2876
|
+
in an existing scope. If not `True`, and the existing scope already has
|
|
2877
|
+
the given variables, an error is raised.
|
|
2878
|
+
"""
|
|
2879
|
+
super(WeightNormLSTMCell, self).__init__(_reuse=reuse)
|
|
2880
|
+
|
|
2881
|
+
self._scope = "wn_lstm_cell"
|
|
2882
|
+
self._num_units = num_units
|
|
2883
|
+
self._norm = norm
|
|
2884
|
+
self._initializer = initializer
|
|
2885
|
+
self._use_peepholes = use_peepholes
|
|
2886
|
+
self._cell_clip = cell_clip
|
|
2887
|
+
self._num_proj = num_proj
|
|
2888
|
+
self._proj_clip = proj_clip
|
|
2889
|
+
self._activation = activation or math_ops.tanh
|
|
2890
|
+
self._forget_bias = forget_bias
|
|
2891
|
+
|
|
2892
|
+
self._weights_variable_name = "kernel"
|
|
2893
|
+
self._bias_variable_name = "bias"
|
|
2894
|
+
|
|
2895
|
+
if num_proj:
|
|
2896
|
+
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
|
|
2897
|
+
self._output_size = num_proj
|
|
2898
|
+
else:
|
|
2899
|
+
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
|
|
2900
|
+
self._output_size = num_units
|
|
2901
|
+
|
|
2902
|
+
@property
|
|
2903
|
+
def state_size(self):
|
|
2904
|
+
return self._state_size
|
|
2905
|
+
|
|
2906
|
+
@property
|
|
2907
|
+
def output_size(self):
|
|
2908
|
+
return self._output_size
|
|
2909
|
+
|
|
2910
|
+
def _normalize(self, weight, name):
|
|
2911
|
+
"""Apply weight normalization.
|
|
2912
|
+
|
|
2913
|
+
Args:
|
|
2914
|
+
weight: a 2D tensor with known number of columns.
|
|
2915
|
+
name: string, variable name for the normalizer.
|
|
2916
|
+
Returns:
|
|
2917
|
+
A tensor with the same shape as `weight`.
|
|
2918
|
+
"""
|
|
2919
|
+
|
|
2920
|
+
output_size = weight.get_shape().as_list()[1]
|
|
2921
|
+
g = vs.get_variable(name, [output_size], dtype=weight.dtype)
|
|
2922
|
+
return nn_impl.l2_normalize(weight, axis=0) * g
|
|
2923
|
+
|
|
2924
|
+
def _linear(self,
|
|
2925
|
+
args,
|
|
2926
|
+
output_size,
|
|
2927
|
+
norm,
|
|
2928
|
+
bias,
|
|
2929
|
+
bias_initializer=None,
|
|
2930
|
+
kernel_initializer=None):
|
|
2931
|
+
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
|
|
2932
|
+
|
|
2933
|
+
Args:
|
|
2934
|
+
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
|
|
2935
|
+
output_size: int, second dimension of W[i].
|
|
2936
|
+
norm: bool, whether to normalize the weights.
|
|
2937
|
+
bias: boolean, whether to add a bias term or not.
|
|
2938
|
+
bias_initializer: starting value to initialize the bias
|
|
2939
|
+
(default is all zeros).
|
|
2940
|
+
kernel_initializer: starting value to initialize the weight.
|
|
2941
|
+
|
|
2942
|
+
Returns:
|
|
2943
|
+
A 2D Tensor with shape [batch x output_size] equal to
|
|
2944
|
+
sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
|
|
2945
|
+
|
|
2946
|
+
Raises:
|
|
2947
|
+
ValueError: if some of the arguments has unspecified or wrong shape.
|
|
2948
|
+
"""
|
|
2949
|
+
if args is None or (nest.is_sequence(args) and not args):
|
|
2950
|
+
raise ValueError("`args` must be specified")
|
|
2951
|
+
if not nest.is_sequence(args):
|
|
2952
|
+
args = [args]
|
|
2953
|
+
|
|
2954
|
+
# Calculate the total size of arguments on dimension 1.
|
|
2955
|
+
total_arg_size = 0
|
|
2956
|
+
shapes = [a.get_shape() for a in args]
|
|
2957
|
+
for shape in shapes:
|
|
2958
|
+
if shape.ndims != 2:
|
|
2959
|
+
raise ValueError("linear is expecting 2D arguments: %s" % shapes)
|
|
2960
|
+
if tensor_shape.dimension_value(shape[1]) is None:
|
|
2961
|
+
raise ValueError("linear expects shape[1] to be provided for shape %s, "
|
|
2962
|
+
"but saw %s" % (shape, shape[1]))
|
|
2963
|
+
else:
|
|
2964
|
+
total_arg_size += tensor_shape.dimension_value(shape[1])
|
|
2965
|
+
|
|
2966
|
+
dtype = [a.dtype for a in args][0]
|
|
2967
|
+
|
|
2968
|
+
# Now the computation.
|
|
2969
|
+
scope = vs.get_variable_scope()
|
|
2970
|
+
with vs.variable_scope(scope) as outer_scope:
|
|
2971
|
+
weights = vs.get_variable(
|
|
2972
|
+
self._weights_variable_name, [total_arg_size, output_size],
|
|
2973
|
+
dtype=dtype,
|
|
2974
|
+
initializer=kernel_initializer)
|
|
2975
|
+
if norm:
|
|
2976
|
+
wn = []
|
|
2977
|
+
st = 0
|
|
2978
|
+
with ops.control_dependencies(None):
|
|
2979
|
+
for i in range(len(args)):
|
|
2980
|
+
en = st + tensor_shape.dimension_value(shapes[i][1])
|
|
2981
|
+
wn.append(
|
|
2982
|
+
self._normalize(weights[st:en, :], name="norm_{}".format(i)))
|
|
2983
|
+
st = en
|
|
2984
|
+
|
|
2985
|
+
weights = array_ops.concat(wn, axis=0)
|
|
2986
|
+
|
|
2987
|
+
if len(args) == 1:
|
|
2988
|
+
res = math_ops.matmul(args[0], weights)
|
|
2989
|
+
else:
|
|
2990
|
+
res = math_ops.matmul(array_ops.concat(args, 1), weights)
|
|
2991
|
+
if not bias:
|
|
2992
|
+
return res
|
|
2993
|
+
|
|
2994
|
+
with vs.variable_scope(outer_scope) as inner_scope:
|
|
2995
|
+
inner_scope.set_partitioner(None)
|
|
2996
|
+
if bias_initializer is None:
|
|
2997
|
+
bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
|
|
2998
|
+
|
|
2999
|
+
biases = vs.get_variable(
|
|
3000
|
+
self._bias_variable_name, [output_size],
|
|
3001
|
+
dtype=dtype,
|
|
3002
|
+
initializer=bias_initializer)
|
|
3003
|
+
|
|
3004
|
+
return nn_ops.bias_add(res, biases)
|
|
3005
|
+
|
|
3006
|
+
def call(self, inputs, state):
|
|
3007
|
+
"""Run one step of LSTM.
|
|
3008
|
+
|
|
3009
|
+
Args:
|
|
3010
|
+
inputs: input Tensor, 2D, batch x num_units.
|
|
3011
|
+
state: A tuple of state Tensors, both `2-D`, with column sizes
|
|
3012
|
+
`c_state` and `m_state`.
|
|
3013
|
+
|
|
3014
|
+
Returns:
|
|
3015
|
+
A tuple containing:
|
|
3016
|
+
|
|
3017
|
+
- A `2-D, [batch x output_dim]`, Tensor representing the output of the
|
|
3018
|
+
LSTM after reading `inputs` when previous state was `state`.
|
|
3019
|
+
Here output_dim is:
|
|
3020
|
+
num_proj if num_proj was set,
|
|
3021
|
+
num_units otherwise.
|
|
3022
|
+
- Tensor(s) representing the new state of LSTM after reading `inputs` when
|
|
3023
|
+
the previous state was `state`. Same type and shape(s) as `state`.
|
|
3024
|
+
|
|
3025
|
+
Raises:
|
|
3026
|
+
ValueError: If input size cannot be inferred from inputs via
|
|
3027
|
+
static shape inference.
|
|
3028
|
+
"""
|
|
3029
|
+
dtype = inputs.dtype
|
|
3030
|
+
num_units = self._num_units
|
|
3031
|
+
sigmoid = math_ops.sigmoid
|
|
3032
|
+
c, h = state
|
|
3033
|
+
|
|
3034
|
+
input_size = inputs.get_shape().with_rank(2).dims[1]
|
|
3035
|
+
if input_size.value is None:
|
|
3036
|
+
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
|
|
3037
|
+
|
|
3038
|
+
with vs.variable_scope(self._scope, initializer=self._initializer):
|
|
3039
|
+
|
|
3040
|
+
concat = self._linear(
|
|
3041
|
+
[inputs, h], 4 * num_units, norm=self._norm, bias=True)
|
|
3042
|
+
|
|
3043
|
+
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
|
|
3044
|
+
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
|
|
3045
|
+
|
|
3046
|
+
if self._use_peepholes:
|
|
3047
|
+
w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
|
|
3048
|
+
w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
|
|
3049
|
+
w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
|
|
3050
|
+
|
|
3051
|
+
new_c = (
|
|
3052
|
+
c * sigmoid(f + self._forget_bias + w_f_diag * c) +
|
|
3053
|
+
sigmoid(i + w_i_diag * c) * self._activation(j))
|
|
3054
|
+
else:
|
|
3055
|
+
new_c = (
|
|
3056
|
+
c * sigmoid(f + self._forget_bias) +
|
|
3057
|
+
sigmoid(i) * self._activation(j))
|
|
3058
|
+
|
|
3059
|
+
if self._cell_clip is not None:
|
|
3060
|
+
# pylint: disable=invalid-unary-operand-type
|
|
3061
|
+
new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
|
|
3062
|
+
# pylint: enable=invalid-unary-operand-type
|
|
3063
|
+
if self._use_peepholes:
|
|
3064
|
+
new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
|
|
3065
|
+
else:
|
|
3066
|
+
new_h = sigmoid(o) * self._activation(new_c)
|
|
3067
|
+
|
|
3068
|
+
if self._num_proj is not None:
|
|
3069
|
+
with vs.variable_scope("projection"):
|
|
3070
|
+
new_h = self._linear(
|
|
3071
|
+
new_h, self._num_proj, norm=self._norm, bias=False)
|
|
3072
|
+
|
|
3073
|
+
if self._proj_clip is not None:
|
|
3074
|
+
# pylint: disable=invalid-unary-operand-type
|
|
3075
|
+
new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
|
|
3076
|
+
self._proj_clip)
|
|
3077
|
+
# pylint: enable=invalid-unary-operand-type
|
|
3078
|
+
|
|
3079
|
+
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
|
|
3080
|
+
return new_h, new_state
|
|
3081
|
+
|
|
3082
|
+
|
|
3083
|
+
class IndRNNCell(rnn_cell_impl.LayerRNNCell):
|
|
3084
|
+
"""Independently Recurrent Neural Network (IndRNN) cell
|
|
3085
|
+
(cf. https://arxiv.org/abs/1803.04831).
|
|
3086
|
+
|
|
3087
|
+
Args:
|
|
3088
|
+
num_units: int, The number of units in the RNN cell.
|
|
3089
|
+
activation: Nonlinearity to use. Default: `tanh`.
|
|
3090
|
+
reuse: (optional) Python boolean describing whether to reuse variables
|
|
3091
|
+
in an existing scope. If not `True`, and the existing scope already has
|
|
3092
|
+
the given variables, an error is raised.
|
|
3093
|
+
name: String, the name of the layer. Layers with the same name will
|
|
3094
|
+
share weights, but to avoid mistakes we require reuse=True in such
|
|
3095
|
+
cases.
|
|
3096
|
+
dtype: Default dtype of the layer (default of `None` means use the type
|
|
3097
|
+
of the first input). Required when `build` is called before `call`.
|
|
3098
|
+
"""
|
|
3099
|
+
|
|
3100
|
+
def __init__(self,
|
|
3101
|
+
num_units,
|
|
3102
|
+
activation=None,
|
|
3103
|
+
reuse=None,
|
|
3104
|
+
name=None,
|
|
3105
|
+
dtype=None):
|
|
3106
|
+
super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
|
|
3107
|
+
|
|
3108
|
+
# Inputs must be 2-dimensional.
|
|
3109
|
+
self.input_spec = input_spec.InputSpec(ndim=2)
|
|
3110
|
+
|
|
3111
|
+
self._num_units = num_units
|
|
3112
|
+
self._activation = activation or math_ops.tanh
|
|
3113
|
+
|
|
3114
|
+
@property
|
|
3115
|
+
def state_size(self):
|
|
3116
|
+
return self._num_units
|
|
3117
|
+
|
|
3118
|
+
@property
|
|
3119
|
+
def output_size(self):
|
|
3120
|
+
return self._num_units
|
|
3121
|
+
|
|
3122
|
+
def build(self, inputs_shape):
|
|
3123
|
+
if tensor_shape.dimension_value(inputs_shape[1]) is None:
|
|
3124
|
+
raise ValueError(
|
|
3125
|
+
"Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
|
|
3126
|
+
|
|
3127
|
+
input_depth = tensor_shape.dimension_value(inputs_shape[1])
|
|
3128
|
+
# pylint: disable=protected-access
|
|
3129
|
+
self._kernel_w = self.add_variable(
|
|
3130
|
+
"%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
|
3131
|
+
shape=[input_depth, self._num_units])
|
|
3132
|
+
self._kernel_u = self.add_variable(
|
|
3133
|
+
"%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
|
|
3134
|
+
shape=[1, self._num_units],
|
|
3135
|
+
initializer=init_ops.random_uniform_initializer(
|
|
3136
|
+
minval=-1, maxval=1, dtype=self.dtype))
|
|
3137
|
+
self._bias = self.add_variable(
|
|
3138
|
+
rnn_cell_impl._BIAS_VARIABLE_NAME,
|
|
3139
|
+
shape=[self._num_units],
|
|
3140
|
+
initializer=init_ops.zeros_initializer(dtype=self.dtype))
|
|
3141
|
+
# pylint: enable=protected-access
|
|
3142
|
+
|
|
3143
|
+
self.built = True
|
|
3144
|
+
|
|
3145
|
+
def call(self, inputs, state):
|
|
3146
|
+
"""IndRNN: output = new_state = act(W * input + u * state + B)."""
|
|
3147
|
+
|
|
3148
|
+
gate_inputs = math_ops.matmul(inputs, self._kernel_w) + (
|
|
3149
|
+
state * self._kernel_u)
|
|
3150
|
+
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
|
|
3151
|
+
output = self._activation(gate_inputs)
|
|
3152
|
+
return output, output
|
|
3153
|
+
|
|
3154
|
+
|
|
3155
|
+
class IndyGRUCell(rnn_cell_impl.LayerRNNCell):
  r"""Independently Gated Recurrent Unit cell.

  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell,
  yet with the \\(U_r\\), \\(U_z\\), and \\(U\\) matrices in equations 5, 6, and
  8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal
  matrices, i.e. a Hadamard product with a single vector:

  $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j +
    [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$
  $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j +
    [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$
  $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j +
    [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$

  where \\(\circ\\) denotes the Hadamard operator. This means that each IndyGRU
  node sees only its own state, as opposed to seeing all states in the same
  layer.

  Args:
    num_units: int, The number of units in the GRU cell.
    activation: Nonlinearity to use. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope. If not `True`, and the existing scope already has
      the given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight
      matrices applied to the input.
    bias_initializer: (optional) The initializer to use for the bias.
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such
      cases.
    dtype: Default dtype of the layer (default of `None` means use the type
      of the first input). Required when `build` is called before `call`.
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None,
               dtype=None):
    super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    # Fall back to tanh when no activation is supplied (matches GRUCell).
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    # State is a single vector of size num_units (no cell/hidden split).
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Creates the input kernels, diagonal recurrent vectors, and biases."""
    if tensor_shape.dimension_value(inputs_shape[1]) is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = tensor_shape.dimension_value(inputs_shape[1])
    # pylint: disable=protected-access
    # Input-to-gate weights: columns pack the r and z gates side by side.
    self._gate_kernel_w = self.add_variable(
        "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 2 * self._num_units],
        initializer=self._kernel_initializer)
    # Diagonal recurrent "matrices" u_r and u_z, stored as a single
    # [1, 2 * num_units] row vector and applied via elementwise multiply.
    # Initialized uniformly in [-1, 1].
    self._gate_kernel_u = self.add_variable(
        "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, 2 * self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    # Gate bias defaults to 1.0 when no bias_initializer is given.
    self._gate_bias = self.add_variable(
        "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[2 * self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.constant_initializer(1.0, dtype=self.dtype)))
    # Input-to-candidate weights.
    # NOTE(review): unlike the gate kernel above, this variable name has no
    # "_w" suffix — presumably intentional for checkpoint compatibility;
    # confirm before renaming.
    self._candidate_kernel_w = self.add_variable(
        "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, self._num_units],
        initializer=self._kernel_initializer)
    # Diagonal recurrent vector u for the candidate computation.
    self._candidate_kernel_u = self.add_variable(
        "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    # Candidate bias defaults to zeros (vs. ones for the gates).
    self._candidate_bias = self.add_variable(
        "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Recurrently independent Gated Recurrent Unit (GRU) with nunits cells."""
    # Both gates are computed in one pass; the state is tiled to line up with
    # the packed [r, z] columns of the gate kernel.
    gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + (
        gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    # r = reset gate, u = update gate (z in the docstring equations).
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state

    # Candidate uses the reset-gated state via the diagonal recurrent vector.
    candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + (
        r_state * self._candidate_kernel_u)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)
    # Convex combination of old state and candidate, gated by u.
    new_h = u * state + (1 - u) * c
    # Output and new state are the same tensor, as for a standard GRU.
    return new_h, new_h
|
|
3276
|
+
|
|
3277
|
+
|
|
3278
|
+
class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
  r"""Basic IndyLSTM recurrent network cell.

  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to
  BasicLSTMCell, yet with the \\(U_f\\), \\(U_i\\), \\(U_o\\) and \\(U_c\\)
  matrices in the regular LSTM equations replaced by diagonal matrices, i.e. a
  Hadamard product with a single vector:

  $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$
  $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$
  $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$
  $$c_t = f_t \circ c_{t-1} +
          i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$

  where \\(\circ\\) denotes the Hadamard operator. This means that each IndyLSTM
  node sees only its own state \\(h\\) and \\(c\\), as opposed to seeing all
  states in the same layer.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For a detailed analysis of IndyLSTMs, see https://arxiv.org/abs/1903.08023.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None,
               dtype=None):
    """Initialize the IndyLSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      activation: Activation function of the inner states. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      kernel_initializer: (optional) The initializer to use for the weight
        matrix applied to the inputs.
      bias_initializer: (optional) The initializer to use for the bias.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
    """
    super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    # Fall back to tanh when no activation is supplied.
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    # Standard LSTM state: a (c, h) tuple, each of size num_units.
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Creates the packed input kernel, diagonal recurrent vectors, and bias."""
    if tensor_shape.dimension_value(inputs_shape[1]) is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = tensor_shape.dimension_value(inputs_shape[1])
    # pylint: disable=protected-access
    # Input-to-hidden weights for all four gates (i, j, f, o), packed
    # column-wise into one [input_depth, 4 * num_units] kernel.
    self._kernel_w = self.add_variable(
        "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 4 * self._num_units],
        initializer=self._kernel_initializer)
    # Diagonal recurrent "matrices" u_i, u_j, u_f, u_o stored as a single
    # [1, 4 * num_units] row vector, initialized uniformly in [-1, 1].
    self._kernel_u = self.add_variable(
        "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, 4 * self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    # Bias for all four gates; defaults to zeros (the extra forget bias is
    # added separately in call()).
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Independent Long short-term memory cell (IndyLSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (a
      `LSTMStateTuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    c, h = state

    # All four gate pre-activations in one pass; h is tiled to line up with
    # the packed [i, j, f, o] columns of the kernel.
    gate_inputs = math_ops.matmul(inputs, self._kernel_w)
    gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    # c_t = f_t * c_{t-1} + i_t * activation(j_t)
    new_c = add(
        multiply(c, sigmoid(add(f, forget_bias_tensor))),
        multiply(sigmoid(i), self._activation(j)))
    # h_t = o_t * activation(c_t)
    new_h = multiply(self._activation(new_c), sigmoid(o))

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
|
|
3415
|
+
|
|
3416
|
+
|
|
3417
|
+
# State carried between NTMCell time steps:
#   controller_state: state of the wrapped RNN controller cell.
#   read_vector_list: vectors most recently read from memory (one per read
#     head).
#   w_list: addressing weight vectors, read heads followed by write heads.
#   M: the external memory tensor.
#   time: scalar step counter; NTMCell.call substitutes learned initial
#     values for the other fields when time == 0.
NTMControllerState = collections.namedtuple(
    "NTMControllerState",
    ("controller_state", "read_vector_list", "w_list", "M", "time"))
|
|
3420
|
+
|
|
3421
|
+
|
|
3422
|
+
class NTMCell(rnn_cell_impl.LayerRNNCell):
  """Neural Turing Machine Cell with RNN controller.

  Implementation based on:
  https://arxiv.org/abs/1807.08518
  Mark Collier, Joeran Beel

  which is in turn based on the source code of:
  https://github.com/snowkylin/ntm

  and of course the original NTM paper:
  Neural Turing Machines
  https://arxiv.org/abs/1410.5401
  A Graves, G Wayne, I Danihelka
  """

  def __init__(self,
               controller,
               memory_size,
               memory_vector_dim,
               read_head_num,
               write_head_num,
               shift_range=1,
               output_dim=None,
               clip_value=20,
               dtype=dtypes.float32,
               name=None):
    """Initialize the NTM Cell.

    Args:
      controller: an RNNCell, the RNN controller.
      memory_size: int, The number of memory locations in the NTM memory
        matrix
      memory_vector_dim: int, The dimensionality of each location in the NTM
        memory matrix
      read_head_num: int, The number of read heads from the controller into
        memory
      write_head_num: int, The number of write heads from the controller into
        memory
      shift_range: int, The number of places to the left/right it is possible
        to iterate the previous address to in a single step
      output_dim: int, The number of dimensions to make a linear projection of
        the NTM controller outputs to. If None, no linear projection is
        applied
      clip_value: float, The maximum absolute value the controller parameters
        are clipped to
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
    """
    super(NTMCell, self).__init__(dtype=dtype, name=name)

    rnn_cell_impl.assert_like_rnncell("NTM RNN controller cell", controller)

    self.controller = controller
    self.memory_size = memory_size
    self.memory_vector_dim = memory_vector_dim
    self.read_head_num = read_head_num
    self.write_head_num = write_head_num
    self.clip_value = clip_value

    self.output_dim = output_dim
    self.shift_range = shift_range

    # Per head: a key vector (memory_vector_dim values), the scalars beta, g
    # and gamma (3 values), and a shift distribution over
    # 2 * shift_range + 1 positions.
    self.num_parameters_per_head = (
        self.memory_vector_dim + 2 * self.shift_range + 4)
    self.num_heads = self.read_head_num + self.write_head_num
    # Each write head additionally emits an erase and an add vector, both of
    # size memory_vector_dim.
    self.total_parameter_num = (
        self.num_parameters_per_head * self.num_heads +
        self.memory_vector_dim * 2 * self.write_head_num)

  @property
  def state_size(self):
    return NTMControllerState(
        controller_state=self.controller.state_size,
        read_vector_list=[
            self.memory_vector_dim for _ in range(self.read_head_num)
        ],
        w_list=[
            self.memory_size
            for _ in range(self.read_head_num + self.write_head_num)
        ],
        M=tensor_shape.TensorShape([self.memory_size * self.memory_vector_dim]),
        time=tensor_shape.TensorShape([]))

  @property
  def output_size(self):
    return self.output_dim

  def build(self, inputs_shape):
    """Creates projection weights, learned initial state, and memory."""
    if self.output_dim is None:
      # Fix: use tensor_shape.dimension_value(...) as every other cell in
      # this file does, instead of the `.value` attribute, which is absent
      # when static shapes are plain ints (TF 2.x style shapes).
      if tensor_shape.dimension_value(inputs_shape[1]) is None:
        raise ValueError(
            "Expected inputs.shape[-1] to be known, saw shape: %s" %
            inputs_shape)
      else:
        self.output_dim = tensor_shape.dimension_value(inputs_shape[1])

    def _create_linear_initializer(input_size, dtype=dtypes.float32):
      # Truncated normal scaled by 1/sqrt(fan_in).
      stddev = 1.0 / math.sqrt(input_size)
      return init_ops.truncated_normal_initializer(stddev=stddev, dtype=dtype)

    # Projects controller output to the per-head addressing parameters.
    self._params_kernel = self.add_variable(
        "parameters_kernel",
        shape=[self.controller.output_size, self.total_parameter_num],
        initializer=_create_linear_initializer(self.controller.output_size))

    self._params_bias = self.add_variable(
        "parameters_bias",
        shape=[self.total_parameter_num],
        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))

    # Projects [controller output; read vectors] to the cell output.
    self._output_kernel = self.add_variable(
        "output_kernel",
        shape=[
            self.controller.output_size +
            self.memory_vector_dim * self.read_head_num, self.output_dim
        ],
        initializer=_create_linear_initializer(self.controller.output_size +
                                               self.memory_vector_dim *
                                               self.read_head_num))

    self._output_bias = self.add_variable(
        "output_bias",
        shape=[self.output_dim],
        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))

    # Learned initial read vectors and addressing weights, substituted for
    # the state fields at time == 0 (see call()).
    self._init_read_vectors = [
        self.add_variable(
            "initial_read_vector_%d" % i,
            shape=[1, self.memory_vector_dim],
            initializer=initializers.glorot_uniform())
        for i in range(self.read_head_num)
    ]

    self._init_address_weights = [
        self.add_variable(
            "initial_address_weights_%d" % i,
            shape=[1, self.memory_size],
            initializer=initializers.glorot_uniform())
        for i in range(self.read_head_num + self.write_head_num)
    ]

    # Learned initial memory contents, near-zero.
    self._M = self.add_variable(
        "memory",
        shape=[self.memory_size, self.memory_vector_dim],
        initializer=init_ops.constant_initializer(1e-6, dtype=self.dtype))

    self.built = True

  def call(self, x, prev_state):
    """Runs one NTM step: address, read, write, and project to output."""
    # Addressing Mechanisms (Sec 3.3)

    def _prev_read_vector_list_initial_value():
      # Broadcast the learned initial read vectors across the batch.
      return [
          self._expand(
              math_ops.tanh(
                  array_ops.squeeze(
                      math_ops.matmul(
                          array_ops.ones([1, 1]), self._init_read_vectors[i]))),
              dim=0,
              # Fix: dimension_value(...) instead of `.value` (see build()).
              N=tensor_shape.dimension_value(x.shape[0]) or
              array_ops.shape(x)[0])
          for i in range(self.read_head_num)
      ]

    prev_read_vector_list = control_flow_ops.cond(
        math_ops.equal(prev_state.time,
                       0), _prev_read_vector_list_initial_value, lambda:
        prev_state.read_vector_list)
    # cond() unwraps a single-element list; re-wrap for uniform handling.
    if self.read_head_num == 1:
      prev_read_vector_list = [prev_read_vector_list]

    controller_input = array_ops.concat([x] + prev_read_vector_list, axis=1)
    controller_output, controller_state = self.controller(
        controller_input, prev_state.controller_state)

    parameters = math_ops.matmul(controller_output, self._params_kernel)
    parameters = nn_ops.bias_add(parameters, self._params_bias)
    parameters = clip_ops.clip_by_value(parameters, -self.clip_value,
                                        self.clip_value)
    # First num_parameters_per_head * num_heads columns: per-head addressing
    # parameters; remainder: erase/add vectors for the write heads.
    head_parameter_list = array_ops.split(
        parameters[:, :self.num_parameters_per_head * self.num_heads],
        self.num_heads,
        axis=1)
    erase_add_list = array_ops.split(
        parameters[:, self.num_parameters_per_head * self.num_heads:],
        2 * self.write_head_num,
        axis=1)

    def _prev_w_list_initial_value():
      # Broadcast the learned initial addressing weights across the batch.
      return [
          self._expand(
              nn_ops.softmax(
                  array_ops.squeeze(
                      math_ops.matmul(
                          array_ops.ones([1, 1]),
                          self._init_address_weights[i]))),
              dim=0,
              # Fix: dimension_value(...) instead of `.value`.
              N=tensor_shape.dimension_value(x.shape[0]) or
              array_ops.shape(x)[0])
          for i in range(self.read_head_num + self.write_head_num)
      ]

    prev_w_list = control_flow_ops.cond(
        math_ops.equal(prev_state.time, 0),
        _prev_w_list_initial_value, lambda: prev_state.w_list)
    if (self.read_head_num + self.write_head_num) == 1:
      prev_w_list = [prev_w_list]

    prev_M = control_flow_ops.cond(
        math_ops.equal(prev_state.time, 0), lambda: self._expand(
            self._M,
            dim=0,
            # Fix: dimension_value(...) instead of `.value`.
            N=tensor_shape.dimension_value(x.shape[0]) or
            array_ops.shape(x)[0]),
        lambda: prev_state.M)

    # Slice each head's flat parameter vector into k, beta, g, s, gamma.
    w_list = []
    for i, head_parameter in enumerate(head_parameter_list):
      k = math_ops.tanh(head_parameter[:, 0:self.memory_vector_dim])
      beta = nn_ops.softplus(head_parameter[:, self.memory_vector_dim])
      g = math_ops.sigmoid(head_parameter[:, self.memory_vector_dim + 1])
      s = nn_ops.softmax(head_parameter[:, self.memory_vector_dim +
                                        2:(self.memory_vector_dim + 2 +
                                           (self.shift_range * 2 + 1))])
      # gamma >= 1 guarantees sharpening (eq 9) never flattens the weights.
      gamma = nn_ops.softplus(head_parameter[:, -1]) + 1
      w = self._addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i])
      w_list.append(w)

    # Reading (Sec 3.1)

    read_w_list = w_list[:self.read_head_num]
    read_vector_list = []
    for i in range(self.read_head_num):
      # Fix: expand_dims takes `axis`, not the deprecated `dim` kwarg
      # (consistent with every other expand_dims call in this class).
      read_vector = math_ops.reduce_sum(
          array_ops.expand_dims(read_w_list[i], axis=2) * prev_M, axis=1)
      read_vector_list.append(read_vector)

    # Writing (Sec 3.2)

    write_w_list = w_list[self.read_head_num:]
    M = prev_M
    for i in range(self.write_head_num):
      w = array_ops.expand_dims(write_w_list[i], axis=2)
      erase_vector = array_ops.expand_dims(
          math_ops.sigmoid(erase_add_list[i * 2]), axis=1)
      add_vector = array_ops.expand_dims(
          math_ops.tanh(erase_add_list[i * 2 + 1]), axis=1)
      # M <- M * (1 - w e^T) + w a^T  (erase then add, eqs 3 and 4).
      erase_M = array_ops.ones_like(M) - math_ops.matmul(w, erase_vector)
      M = M * erase_M + math_ops.matmul(w, add_vector)

    output = math_ops.matmul(
        array_ops.concat([controller_output] + read_vector_list, axis=1),
        self._output_kernel)
    output = nn_ops.bias_add(output, self._output_bias)
    output = clip_ops.clip_by_value(output, -self.clip_value, self.clip_value)

    return output, NTMControllerState(
        controller_state=controller_state,
        read_vector_list=read_vector_list,
        w_list=w_list,
        M=M,
        time=prev_state.time + 1)

  def _expand(self, x, dim, N):
    """Tiles `x` N times along a new axis `dim` (broadcast over the batch)."""
    return array_ops.concat([array_ops.expand_dims(x, dim) for _ in range(N)],
                            axis=dim)

  def _addressing(self, k, beta, g, s, gamma, prev_M, prev_w):
    """Computes an addressing weight vector (NTM paper Sec 3.3)."""
    # Sec 3.3.1 Focusing by Content

    k = array_ops.expand_dims(k, axis=2)
    inner_product = math_ops.matmul(prev_M, k)
    k_norm = math_ops.sqrt(
        math_ops.reduce_sum(math_ops.square(k), axis=1, keepdims=True))
    M_norm = math_ops.sqrt(
        math_ops.reduce_sum(math_ops.square(prev_M), axis=2, keepdims=True))
    norm_product = M_norm * k_norm

    # eq (6): cosine similarity, with epsilon guarding division by zero.
    K = array_ops.squeeze(inner_product / (norm_product + 1e-8))

    K_amplified = math_ops.exp(array_ops.expand_dims(beta, axis=1) * K)

    # eq (5): softmax over memory locations.
    w_c = K_amplified / math_ops.reduce_sum(K_amplified, axis=1, keepdims=True)

    # Sec 3.3.2 Focusing by Location

    g = array_ops.expand_dims(g, axis=1)

    # eq (7): interpolate between content weighting and previous weighting.
    w_g = g * w_c + (1 - g) * prev_w

    # Pad the shift distribution out to memory_size, keeping the
    # (2 * shift_range + 1) entries split around position 0.
    s = array_ops.concat([
        s[:, :self.shift_range + 1],
        array_ops.zeros([
            # Fix: dimension_value(...) instead of `.value`.
            tensor_shape.dimension_value(s.shape[0]) or
            array_ops.shape(s)[0], self.memory_size -
            (self.shift_range * 2 + 1)
        ]), s[:, -self.shift_range:]
    ],
                         axis=1)
    # Build a circulant matrix of shifted copies of s for the circular
    # convolution of eq (8).
    t = array_ops.concat(
        [array_ops.reverse(s, axis=[1]),
         array_ops.reverse(s, axis=[1])],
        axis=1)
    s_matrix = array_ops.stack([
        t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1]
        for i in range(self.memory_size)
    ],
                               axis=1)

    # eq (8): circular convolution of w_g with the shift distribution.
    w_ = math_ops.reduce_sum(
        array_ops.expand_dims(w_g, axis=1) * s_matrix, axis=2)
    w_sharpen = math_ops.pow(w_, array_ops.expand_dims(gamma, axis=1))

    # eq (9): sharpen and renormalize.
    w = w_sharpen / math_ops.reduce_sum(w_sharpen, axis=1, keepdims=True)

    return w

  def zero_state(self, batch_size, dtype):
    """Returns an all-zeros state; call() swaps in learned values at t=0."""
    read_vector_list = [
        array_ops.zeros([batch_size, self.memory_vector_dim])
        for _ in range(self.read_head_num)
    ]

    w_list = [
        array_ops.zeros([batch_size, self.memory_size])
        for _ in range(self.read_head_num + self.write_head_num)
    ]

    controller_init_state = self.controller.zero_state(batch_size, dtype)

    M = array_ops.zeros([batch_size, self.memory_size, self.memory_vector_dim])

    return NTMControllerState(
        controller_state=controller_init_state,
        read_vector_list=read_vector_list,
        w_list=w_list,
        M=M,
        time=0)
|
|
3763
|
+
|
|
3764
|
+
|
|
3765
|
+
class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
  """MinimalRNN cell.

  The implementation is based on:

  https://arxiv.org/pdf/1806.05394v2.pdf

  Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz.
  "Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal
  Propagation in Recurrent Neural Networks." ICML, 2018.

  A MinimalRNN cell first projects the input to the hidden space. The new
  hidden state is then calculated as a weighted sum of the projected input and
  the previous hidden state, using a single update gate.
  """

  def __init__(self,
               units,
               activation="tanh",
               kernel_initializer="glorot_uniform",
               bias_initializer="ones",
               name=None,
               dtype=None,
               **kwargs):
    """Initialize the parameters for a MinimalRNN cell.

    Args:
      units: int, The number of units in the MinimalRNN cell.
      activation: Nonlinearity to use in the feedforward network. Default:
        `tanh`.
      kernel_initializer: The initializer to use for the weight in the update
        gate and feedforward network. Default: `glorot_uniform`.
      bias_initializer: The initializer to use for the bias in the update
        gate. Default: `ones`.
      name: String, the name of the cell.
      dtype: Default dtype of the cell.
      **kwargs: Dict, keyword named properties for common cell attributes.
    """
    super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self.units = units
    # String identifiers are resolved through the Keras registries, so
    # e.g. activation="tanh" and kernel_initializer="glorot_uniform" work.
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

  @property
  def state_size(self):
    return self.units

  @property
  def output_size(self):
    return self.units

  def build(self, inputs_shape):
    """Creates the packed kernel and the update-gate bias."""
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % str(inputs_shape))

    input_size = inputs_shape[-1]
    # pylint: disable=protected-access
    # self.kernel packs W_x (input projection, input_size rows) followed by
    # the update-gate weights over [feedforward, state] (2 * units rows);
    # call() splits it back apart along axis 0.
    self.kernel = self.add_weight(
        name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_size + 2 * self.units, self.units],
        initializer=self.kernel_initializer)
    # Update-gate bias only; the feedforward projection has no bias.
    self.bias = self.add_weight(
        name=rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[self.units],
        initializer=self.bias_initializer)
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Run one step of MinimalRNN.

    Args:
      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
      state: state Tensor, must be 2-D, `[batch, state_size]`.

    Returns:
      A tuple containing:

      - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
      - New state: A `2-D` tensor with shape `[batch_size, state_size]`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    input_size = inputs.get_shape()[1]
    if tensor_shape.dimension_value(input_size) is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # Unpack the kernel: first input_size rows project the input, the
    # remaining 2 * units rows form the update gate over [feedforward, state].
    feedforward_weight, gate_weight = array_ops.split(
        value=self.kernel,
        num_or_size_splits=[tensor_shape.dimension_value(input_size),
                            2 * self.units],
        axis=0)

    # Phi(W_x x): project the input into the hidden space.
    feedforward = math_ops.matmul(inputs, feedforward_weight)
    feedforward = self.activation(feedforward)

    # Single update gate u over the projected input and previous state.
    gate_inputs = math_ops.matmul(
        array_ops.concat([feedforward, state], 1), gate_weight)
    gate_inputs = nn_ops.bias_add(gate_inputs, self.bias)
    u = math_ops.sigmoid(gate_inputs)

    # Convex combination of old state and projected input.
    new_h = u * state + (1 - u) * feedforward
    return new_h, new_h
|
|
3878
|
+
|
|
3879
|
+
|
|
3880
|
+
class CFNCell(rnn_cell_impl.LayerRNNCell):
|
|
3881
|
+
"""Chaos Free Network cell.
|
|
3882
|
+
|
|
3883
|
+
The implementation is based on:
|
|
3884
|
+
|
|
3885
|
+
https://openreview.net/pdf?id=S1dIzvclg
|
|
3886
|
+
|
|
3887
|
+
Thomas Laurent, James von Brecht.
|
|
3888
|
+
"A recurrent neural network without chaos." ICLR, 2017.
|
|
3889
|
+
|
|
3890
|
+
A CFN cell first projects the input to the hidden space. The hidden state
|
|
3891
|
+
goes through a contractive mapping. The new hidden state is then calculated
|
|
3892
|
+
as a linear combination of the projected input and the contracted previous
|
|
3893
|
+
hidden state, using decoupled input and forget gates.
|
|
3894
|
+
"""
|
|
3895
|
+
|
|
3896
|
+
def __init__(self,
|
|
3897
|
+
units,
|
|
3898
|
+
activation="tanh",
|
|
3899
|
+
kernel_initializer="glorot_uniform",
|
|
3900
|
+
bias_initializer="ones",
|
|
3901
|
+
name=None,
|
|
3902
|
+
dtype=None,
|
|
3903
|
+
**kwargs):
|
|
3904
|
+
"""Initialize the parameters for a CFN cell.
|
|
3905
|
+
|
|
3906
|
+
Args:
|
|
3907
|
+
units: int, The number of units in the CFN cell.
|
|
3908
|
+
activation: Nonlinearity to use. Default: `tanh`.
|
|
3909
|
+
kernel_initializer: Initializer for the `kernel` weights
|
|
3910
|
+
matrix. Default: `glorot_uniform`.
|
|
3911
|
+
bias_initializer: The initializer to use for the bias in the
|
|
3912
|
+
gates. Default: `ones`.
|
|
3913
|
+
name: String, the name of the cell.
|
|
3914
|
+
dtype: Default dtype of the cell.
|
|
3915
|
+
**kwargs: Dict, keyword named properties for common cell attributes.
|
|
3916
|
+
"""
|
|
3917
|
+
super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs)
|
|
3918
|
+
|
|
3919
|
+
# Inputs must be 2-dimensional.
|
|
3920
|
+
self.input_spec = input_spec.InputSpec(ndim=2)
|
|
3921
|
+
|
|
3922
|
+
self.units = units
|
|
3923
|
+
self.activation = activations.get(activation)
|
|
3924
|
+
self.kernel_initializer = initializers.get(kernel_initializer)
|
|
3925
|
+
self.bias_initializer = initializers.get(bias_initializer)
|
|
3926
|
+
|
|
3927
|
+
@property
def state_size(self):
  """Size of the recurrent state: a single vector of `units` elements."""
  return self.units
@property
def output_size(self):
  """Size of the cell output; identical to the state size (`units`)."""
  return self.units
def build(self, inputs_shape):
  """Create the cell's weight variables.

  Args:
    inputs_shape: Shape of the input; the last dimension must be statically
      known.

  Raises:
    ValueError: If the last input dimension is unknown.
  """
  if inputs_shape[-1] is None:
    raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                     % str(inputs_shape))

  feature_dim = inputs_shape[-1]
  # pylint: disable=protected-access
  # Fused variable layout (column-wise):
  #   kernel           -> [V_{\theta} | V_{\eta} | W]  input projections
  #   recurrent_kernel -> [U_{\theta} | U_{\eta}]      state projections
  #   bias             -> [b_{\theta} | b_{\eta}]      gate biases
  self.kernel = self.add_weight(
      name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
      shape=[feature_dim, 3 * self.units],
      initializer=self.kernel_initializer)
  self.recurrent_kernel = self.add_weight(
      name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
      shape=[self.units, 2 * self.units],
      initializer=self.kernel_initializer)
  self.bias = self.add_weight(
      name=rnn_cell_impl._BIAS_VARIABLE_NAME,
      shape=[2 * self.units],
      initializer=self.bias_initializer)
  # pylint: enable=protected-access

  self.built = True
def call(self, inputs, state):
  """Run one step of CFN.

  Args:
    inputs: input Tensor, must be 2-D, `[batch, input_size]`.
    state: state Tensor, must be 2-D, `[batch, state_size]`.

  Returns:
    A tuple containing:

    - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
    - New state: A `2-D` tensor with shape `[batch_size, state_size]`.

  Raises:
    ValueError: If input size cannot be inferred from inputs via
      static shape inference.
  """
  input_size = inputs.get_shape()[-1]
  if tensor_shape.dimension_value(input_size) is None:
    raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

  # Split the fused kernel into the gate projections V = [V_theta | V_eta]
  # (first 2*units columns) and the candidate projection W (last column
  # block), matching the layout laid down in `build`.
  gate_kernel, candidate_kernel = array_ops.split(
      value=self.kernel,
      num_or_size_splits=[2 * self.units, self.units],
      axis=1)

  # Both gates in one fused affine transform followed by a sigmoid:
  # sigmoid(U * h + V * x + b), then split into theta (forget) and eta.
  gates = math_ops.sigmoid(
      nn_ops.bias_add(
          math_ops.matmul(state, self.recurrent_kernel) +
          math_ops.matmul(inputs, gate_kernel),
          self.bias))
  theta, eta = array_ops.split(value=gates,
                               num_or_size_splits=2,
                               axis=1)

  projected_input = math_ops.matmul(inputs, candidate_kernel)

  # The input gate is (1 - eta), which differs from the original paper.
  # With the default `ones` bias initializer this makes the input gate
  # start out as a small number, which helps initialization.
  new_h = theta * self.activation(state) + (1 - eta) * self.activation(
      projected_input)

  return new_h, new_h