spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- com/johnsnowlabs/ml/__init__.py +0 -0
- com/johnsnowlabs/ml/ai/__init__.py +10 -0
- com/johnsnowlabs/nlp/__init__.py +4 -2
- spark_nlp-6.2.1.dist-info/METADATA +362 -0
- spark_nlp-6.2.1.dist-info/RECORD +292 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +281 -27
- sparknlp/annotation.py +137 -6
- sparknlp/annotation_audio.py +61 -0
- sparknlp/annotation_image.py +82 -0
- sparknlp/annotator/__init__.py +93 -0
- sparknlp/annotator/audio/__init__.py +16 -0
- sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
- sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
- sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
- sparknlp/annotator/chunk2_doc.py +85 -0
- sparknlp/annotator/chunker.py +137 -0
- sparknlp/annotator/classifier_dl/__init__.py +61 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
- sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
- sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
- sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
- sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
- sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
- sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
- sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
- sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
- sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
- sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
- sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
- sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
- sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
- sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/coref/__init__.py +1 -0
- sparknlp/annotator/coref/spanbert_coref.py +221 -0
- sparknlp/annotator/cv/__init__.py +29 -0
- sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
- sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
- sparknlp/annotator/cv/florence2_transformer.py +180 -0
- sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
- sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
- sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
- sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
- sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
- sparknlp/annotator/dataframe_optimizer.py +216 -0
- sparknlp/annotator/date2_chunk.py +88 -0
- sparknlp/annotator/dependency/__init__.py +17 -0
- sparknlp/annotator/dependency/dependency_parser.py +294 -0
- sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
- sparknlp/annotator/document_character_text_splitter.py +228 -0
- sparknlp/annotator/document_normalizer.py +235 -0
- sparknlp/annotator/document_token_splitter.py +175 -0
- sparknlp/annotator/document_token_splitter_test.py +85 -0
- sparknlp/annotator/embeddings/__init__.py +45 -0
- sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
- sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
- sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
- sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
- sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
- sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
- sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
- sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
- sparknlp/annotator/embeddings/doc2vec.py +352 -0
- sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
- sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
- sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
- sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
- sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
- sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
- sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
- sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
- sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
- sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
- sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
- sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
- sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
- sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
- sparknlp/annotator/embeddings/word2vec.py +353 -0
- sparknlp/annotator/embeddings/word_embeddings.py +385 -0
- sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
- sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
- sparknlp/annotator/er/__init__.py +16 -0
- sparknlp/annotator/er/entity_ruler.py +267 -0
- sparknlp/annotator/graph_extraction.py +368 -0
- sparknlp/annotator/keyword_extraction/__init__.py +16 -0
- sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
- sparknlp/annotator/ld_dl/__init__.py +16 -0
- sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
- sparknlp/annotator/lemmatizer.py +250 -0
- sparknlp/annotator/matcher/__init__.py +20 -0
- sparknlp/annotator/matcher/big_text_matcher.py +272 -0
- sparknlp/annotator/matcher/date_matcher.py +303 -0
- sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
- sparknlp/annotator/matcher/regex_matcher.py +221 -0
- sparknlp/annotator/matcher/text_matcher.py +290 -0
- sparknlp/annotator/n_gram_generator.py +141 -0
- sparknlp/annotator/ner/__init__.py +21 -0
- sparknlp/annotator/ner/ner_approach.py +94 -0
- sparknlp/annotator/ner/ner_converter.py +148 -0
- sparknlp/annotator/ner/ner_crf.py +397 -0
- sparknlp/annotator/ner/ner_dl.py +591 -0
- sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
- sparknlp/annotator/ner/ner_overwriter.py +166 -0
- sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
- sparknlp/annotator/normalizer.py +230 -0
- sparknlp/annotator/openai/__init__.py +16 -0
- sparknlp/annotator/openai/openai_completion.py +349 -0
- sparknlp/annotator/openai/openai_embeddings.py +106 -0
- sparknlp/annotator/param/__init__.py +17 -0
- sparknlp/annotator/param/classifier_encoder.py +98 -0
- sparknlp/annotator/param/evaluation_dl_params.py +130 -0
- sparknlp/annotator/pos/__init__.py +16 -0
- sparknlp/annotator/pos/perceptron.py +263 -0
- sparknlp/annotator/sentence/__init__.py +17 -0
- sparknlp/annotator/sentence/sentence_detector.py +290 -0
- sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
- sparknlp/annotator/sentiment/__init__.py +17 -0
- sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
- sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
- sparknlp/annotator/seq2seq/__init__.py +35 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
- sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
- sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
- sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
- sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
- sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
- sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
- sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
- sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
- sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
- sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
- sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
- sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
- sparknlp/annotator/similarity/__init__.py +0 -0
- sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
- sparknlp/annotator/spell_check/__init__.py +18 -0
- sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
- sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
- sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
- sparknlp/annotator/stemmer.py +79 -0
- sparknlp/annotator/stop_words_cleaner.py +190 -0
- sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
- sparknlp/annotator/token/__init__.py +19 -0
- sparknlp/annotator/token/chunk_tokenizer.py +118 -0
- sparknlp/annotator/token/recursive_tokenizer.py +205 -0
- sparknlp/annotator/token/regex_tokenizer.py +208 -0
- sparknlp/annotator/token/tokenizer.py +561 -0
- sparknlp/annotator/token2_chunk.py +76 -0
- sparknlp/annotator/ws/__init__.py +16 -0
- sparknlp/annotator/ws/word_segmenter.py +429 -0
- sparknlp/base/__init__.py +30 -0
- sparknlp/base/audio_assembler.py +95 -0
- sparknlp/base/doc2_chunk.py +169 -0
- sparknlp/base/document_assembler.py +164 -0
- sparknlp/base/embeddings_finisher.py +201 -0
- sparknlp/base/finisher.py +217 -0
- sparknlp/base/gguf_ranking_finisher.py +234 -0
- sparknlp/base/graph_finisher.py +125 -0
- sparknlp/base/has_recursive_fit.py +24 -0
- sparknlp/base/has_recursive_transform.py +22 -0
- sparknlp/base/image_assembler.py +172 -0
- sparknlp/base/light_pipeline.py +429 -0
- sparknlp/base/multi_document_assembler.py +164 -0
- sparknlp/base/prompt_assembler.py +207 -0
- sparknlp/base/recursive_pipeline.py +107 -0
- sparknlp/base/table_assembler.py +145 -0
- sparknlp/base/token_assembler.py +124 -0
- sparknlp/common/__init__.py +26 -0
- sparknlp/common/annotator_approach.py +41 -0
- sparknlp/common/annotator_model.py +47 -0
- sparknlp/common/annotator_properties.py +114 -0
- sparknlp/common/annotator_type.py +38 -0
- sparknlp/common/completion_post_processing.py +37 -0
- sparknlp/common/coverage_result.py +22 -0
- sparknlp/common/match_strategy.py +33 -0
- sparknlp/common/properties.py +1298 -0
- sparknlp/common/read_as.py +33 -0
- sparknlp/common/recursive_annotator_approach.py +35 -0
- sparknlp/common/storage.py +149 -0
- sparknlp/common/utils.py +39 -0
- sparknlp/functions.py +315 -5
- sparknlp/internal/__init__.py +1199 -0
- sparknlp/internal/annotator_java_ml.py +32 -0
- sparknlp/internal/annotator_transformer.py +37 -0
- sparknlp/internal/extended_java_wrapper.py +63 -0
- sparknlp/internal/params_getters_setters.py +71 -0
- sparknlp/internal/recursive.py +70 -0
- sparknlp/logging/__init__.py +15 -0
- sparknlp/logging/comet.py +467 -0
- sparknlp/partition/__init__.py +16 -0
- sparknlp/partition/partition.py +244 -0
- sparknlp/partition/partition_properties.py +902 -0
- sparknlp/partition/partition_transformer.py +200 -0
- sparknlp/pretrained/__init__.py +17 -0
- sparknlp/pretrained/pretrained_pipeline.py +158 -0
- sparknlp/pretrained/resource_downloader.py +216 -0
- sparknlp/pretrained/utils.py +35 -0
- sparknlp/reader/__init__.py +15 -0
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +190 -0
- sparknlp/reader/reader2doc.py +124 -0
- sparknlp/reader/reader2image.py +136 -0
- sparknlp/reader/reader2table.py +44 -0
- sparknlp/reader/reader_assembler.py +159 -0
- sparknlp/reader/sparknlp_reader.py +461 -0
- sparknlp/training/__init__.py +20 -0
- sparknlp/training/_tf_graph_builders/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
- sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
- sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
- sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/conll.py +150 -0
- sparknlp/training/conllu.py +103 -0
- sparknlp/training/pos.py +103 -0
- sparknlp/training/pub_tator.py +76 -0
- sparknlp/training/spacy_to_annotation.py +57 -0
- sparknlp/training/tfgraphs.py +5 -0
- sparknlp/upload_to_hub.py +149 -0
- sparknlp/util.py +51 -5
- com/__init__.pyc +0 -0
- com/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/__init__.pyc +0 -0
- com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/nlp/__init__.pyc +0 -0
- com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
- spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
- spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
- sparknlp/__init__.pyc +0 -0
- sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
- sparknlp/__pycache__/base.cpython-36.pyc +0 -0
- sparknlp/__pycache__/common.cpython-36.pyc +0 -0
- sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
- sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
- sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
- sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
- sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
- sparknlp/__pycache__/training.cpython-36.pyc +0 -0
- sparknlp/__pycache__/util.cpython-36.pyc +0 -0
- sparknlp/annotation.pyc +0 -0
- sparknlp/annotator.py +0 -3006
- sparknlp/annotator.pyc +0 -0
- sparknlp/base.py +0 -347
- sparknlp/base.pyc +0 -0
- sparknlp/common.py +0 -193
- sparknlp/common.pyc +0 -0
- sparknlp/embeddings.py +0 -40
- sparknlp/embeddings.pyc +0 -0
- sparknlp/internal.py +0 -288
- sparknlp/internal.pyc +0 -0
- sparknlp/pretrained.py +0 -123
- sparknlp/pretrained.pyc +0 -0
- sparknlp/storage.py +0 -32
- sparknlp/storage.pyc +0 -0
- sparknlp/training.py +0 -62
- sparknlp/training.pyc +0 -0
- sparknlp/util.pyc +0 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
@@ -0,0 +1,665 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""LSTM Block Cell ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import input_spec
+from tensorflow.python.layers import base as base_layer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_rnn_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell_impl
+
+LayerRNNCell = rnn_cell_impl.LayerRNNCell  # pylint: disable=invalid-name
+
+
+# pylint: disable=invalid-name
+def _lstm_block_cell(x,
+                     cs_prev,
+                     h_prev,
+                     w,
+                     b,
+                     wci=None,
+                     wcf=None,
+                     wco=None,
+                     forget_bias=None,
+                     cell_clip=None,
+                     use_peephole=None,
+                     name=None):
+  r"""Computes the LSTM cell forward propagation for 1 time step.
+
+  This implementation uses 1 weight matrix and 1 bias vector, and there's an
+  optional peephole connection.
+
+  This kernel op implements the following mathematical equations:
+
+  ```python
+  xh = [x, h_prev]
+  [i, ci, f, o] = xh * w + b
+  f = f + forget_bias
+
+  if not use_peephole:
+    wci = wcf = wco = 0
+
+  i = sigmoid(cs_prev * wci + i)
+  f = sigmoid(cs_prev * wcf + f)
+  ci = tanh(ci)
+
+  cs = ci .* i + cs_prev .* f
+  cs = clip(cs, cell_clip)
+
+  o = sigmoid(cs * wco + o)
+  co = tanh(cs)
+  h = co .* o
+  ```
+
+  Args:
+    x: A `Tensor`. Must be one of the following types: `float32`.
+      The input to the LSTM cell, shape (batch_size, num_inputs).
+    cs_prev: A `Tensor`. Must have the same type as `x`.
+      Value of the cell state at previous time step.
+    h_prev: A `Tensor`. Must have the same type as `x`.
+      Output of the previous cell at previous time step.
+    w: A `Tensor`. Must have the same type as `x`. The weight matrix.
+    b: A `Tensor`. Must have the same type as `x`. The bias vector.
+    wci: A `Tensor`. Must have the same type as `x`.
+      The weight matrix for input gate peephole connection.
+    wcf: A `Tensor`. Must have the same type as `x`.
+      The weight matrix for forget gate peephole connection.
+    wco: A `Tensor`. Must have the same type as `x`.
+      The weight matrix for output gate peephole connection.
+    forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
+    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
+      Value to clip the 'cs' value to. Disable by setting to negative value.
+    use_peephole: An optional `bool`. Defaults to `False`.
+      Whether to use peephole weights.
+    name: A name for the operation (optional).
+
+  Returns:
+    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
+    i: A `Tensor`. Has the same type as `x`. The input gate.
+    cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
+    f: A `Tensor`. Has the same type as `x`. The forget gate.
+    o: A `Tensor`. Has the same type as `x`. The output gate.
+    ci: A `Tensor`. Has the same type as `x`. The cell input.
+    co: A `Tensor`. Has the same type as `x`. The cell after the tanh.
+    h: A `Tensor`. Has the same type as `x`. The output h vector.
+
+  Raises:
+    ValueError: If cell_size is None.
+  """
+  if wci is None:
+    cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
+    if cell_size is None:
+      raise ValueError("cell_size from `cs_prev` should not be None.")
+    wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
+    wcf = wci
+    wco = wci
+
+  # pylint: disable=protected-access
+  return gen_rnn_ops.lstm_block_cell(
+      x=x,
+      cs_prev=cs_prev,
+      h_prev=h_prev,
+      w=w,
+      wci=wci,
+      wcf=wcf,
+      wco=wco,
+      b=b,
+      forget_bias=forget_bias,
+      cell_clip=cell_clip if cell_clip is not None else -1,
+      use_peephole=use_peephole,
+      name=name)
+  # pylint: enable=protected-access
+
+
+def _block_lstm(seq_len_max,
+                x,
+                w,
+                b,
+                cs_prev=None,
+                h_prev=None,
+                wci=None,
+                wcf=None,
+                wco=None,
+                forget_bias=None,
+                cell_clip=None,
+                use_peephole=None,
+                name=None):
+  r"""TODO(williamchan): add doc.
+
+  Args:
+    seq_len_max: A `Tensor` of type `int64`.
+    x: A list of at least 1 `Tensor` objects of the same type.
+    w: A `Tensor`. Must have the same type as `x`.
+    b: A `Tensor`. Must have the same type as `x`.
+    cs_prev: A `Tensor`. Must have the same type as `x`.
+    h_prev: A `Tensor`. Must have the same type as `x`.
+    wci: A `Tensor`. Must have the same type as `x`.
+    wcf: A `Tensor`. Must have the same type as `x`.
+    wco: A `Tensor`. Must have the same type as `x`.
+    forget_bias: An optional `float`. Defaults to `1`.
+    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
+    use_peephole: An optional `bool`. Defaults to `False`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
+    i: A list with the same number of `Tensor` objects as `x` of `Tensor`
+      objects of the same type as x.
+    cs: A list with the same number of `Tensor` objects as `x` of `Tensor`
+      objects of the same type as x.
+    f: A list with the same number of `Tensor` objects as `x` of `Tensor`
+      objects of the same type as x.
+    o: A list with the same number of `Tensor` objects as `x` of `Tensor`
+      objects of the same type as x.
+    ci: A list with the same number of `Tensor` objects as `x` of `Tensor`
+      objects of the same type as x.
+    co: A list with the same number of `Tensor` objects as `x` of `Tensor`
+      objects of the same type as x.
+    h: A list with the same number of `Tensor` objects as `x` of `Tensor`
+      objects of the same type as x.
+
+  Raises:
+    ValueError: If `b` does not have a valid shape.
+  """
+  dtype = x[0].dtype
+  batch_size = x[0].get_shape().with_rank(2).dims[0].value
+  cell_size4 = b.get_shape().with_rank(1).dims[0].value
+  if cell_size4 is None:
+    raise ValueError("`b` shape must not be None.")
+  cell_size = cell_size4 / 4
+  zero_state = None
+  if cs_prev is None or h_prev is None:
+    zero_state = array_ops.constant(
+        0, dtype=dtype, shape=[batch_size, cell_size])
+  if cs_prev is None:
+    cs_prev = zero_state
+  if h_prev is None:
+    h_prev = zero_state
+  if wci is None:
+    wci = array_ops.constant(0, dtype=dtype, shape=[cell_size])
+    wcf = wci
+    wco = wci
+
+  # pylint: disable=protected-access
+  i, cs, f, o, ci, co, h = gen_rnn_ops.block_lstm(
+      seq_len_max=seq_len_max,
+      x=array_ops.stack(x),
+      cs_prev=cs_prev,
+      h_prev=h_prev,
+      w=w,
+      wci=wci,
+      wcf=wcf,
+      wco=wco,
+      b=b,
+      forget_bias=forget_bias,
+      cell_clip=cell_clip if cell_clip is not None else -1,
+      name=name,
+      use_peephole=use_peephole)
+
+  return array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(
+      f), array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(
+          co), array_ops.unstack(h)
+  # pylint: enable=protected-access
+  # pylint: enable=invalid-name
+
+
+@ops.RegisterGradient("LSTMBlockCell")
+def _LSTMBlockCellGrad(op, *grad):
+  """Gradient for LSTMBlockCell."""
+  (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
+  (i, cs, f, o, ci, co, _) = op.outputs
+  (_, cs_grad, _, _, _, _, h_grad) = grad
+
+  batch_size = x.get_shape().with_rank(2).dims[0].value
+  if batch_size is None:
+    batch_size = -1
+  input_size = x.get_shape().with_rank(2).dims[1].value
+  if input_size is None:
+    raise ValueError("input_size from `x` should not be None.")
+  cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
+  if cell_size is None:
+    raise ValueError("cell_size from `cs_prev` should not be None.")
+
+  (cs_prev_grad, dgates, wci_grad, wcf_grad,
+   wco_grad) = gen_rnn_ops.lstm_block_cell_grad(
+       x=x,
+       cs_prev=cs_prev,
+       h_prev=h_prev,
+       w=w,
+       wci=wci,
+       wcf=wcf,
+       wco=wco,
+       b=b,
+       i=i,
+       cs=cs,
+       f=f,
+       o=o,
+       ci=ci,
+       co=co,
+       cs_grad=cs_grad,
+       h_grad=h_grad,
+       use_peephole=op.get_attr("use_peephole"))
+
+  # Backprop from dgates to xh.
+  xh_grad = math_ops.matmul(dgates, w, transpose_b=True)
+
+  x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
+  x_grad.get_shape().merge_with(x.get_shape())
+
+  h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
+                                (batch_size, cell_size))
+  h_prev_grad.get_shape().merge_with(h_prev.get_shape())
+
+  # Backprop from dgates to w.
+  xh = array_ops.concat([x, h_prev], 1)
+  w_grad = math_ops.matmul(xh, dgates, transpose_a=True)
+  w_grad.get_shape().merge_with(w.get_shape())
+
+  # Backprop from dgates to b.
+  b_grad = nn_ops.bias_add_grad(dgates)
+  b_grad.get_shape().merge_with(b.get_shape())
+
+  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
+          wco_grad, b_grad)
+
+
+class LSTMBlockCell(LayerRNNCell):
+  """Basic LSTM recurrent network cell.
+
+  The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+  We add `forget_bias` (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting in the beginning of the training.
+
+  Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be much
+  faster. The weight and bias matrices should be compatible as long as the
+  variable scope matches.
+  """
+
+  def __init__(self,
+               num_units,
+               forget_bias=1.0,
+               cell_clip=None,
+               use_peephole=False,
+               dtype=None,
+               reuse=None,
+               name="lstm_cell"):
+    """Initialize the basic LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      forget_bias: float, The bias added to forget gates (see above).
+      cell_clip: An optional `float`. Defaults to `-1` (no clipping).
+      use_peephole: Whether to use peephole connections or not.
+      dtype: the variable dtype of this layer. Default to tf.float32.
+      reuse: (optional) boolean describing whether to reuse variables in an
+        existing scope. If not `True`, and the existing scope already has the
+        given variables, an error is raised.
+      name: String, the name of the layer. Layers with the same name will
+        share weights, but to avoid mistakes we require reuse=True in such
+        cases. By default this is "lstm_cell", for variable-name compatibility
+        with `tf.compat.v1.nn.rnn_cell.LSTMCell`.
+
+      When restoring from CudnnLSTM-trained checkpoints, must use
+      CudnnCompatibleLSTMBlockCell instead.
+    """
+    super(LSTMBlockCell, self).__init__(_reuse=reuse, dtype=dtype, name=name)
+    self._num_units = num_units
+    self._forget_bias = forget_bias
+    self._use_peephole = use_peephole
+    self._cell_clip = cell_clip if cell_clip is not None else -1
+    self._names = {
+        "W": "kernel",
+        "b": "bias",
+        "wci": "w_i_diag",
+        "wcf": "w_f_diag",
+        "wco": "w_o_diag",
+        "scope": "lstm_cell"
+    }
+    # Inputs must be 2-dimensional.
+    self.input_spec = input_spec.InputSpec(ndim=2)
+
+  @property
+  def state_size(self):
+    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  def build(self, inputs_shape):
+    if not inputs_shape.dims[1].value:
+      raise ValueError(
+          "Expecting inputs_shape[1] to be set: %s" % str(inputs_shape))
+    input_size = inputs_shape.dims[1].value
+    self._kernel = self.add_variable(
+        self._names["W"], [input_size + self._num_units, self._num_units * 4])
+    self._bias = self.add_variable(
+        self._names["b"], [self._num_units * 4],
+        initializer=init_ops.constant_initializer(0.0))
+    if self._use_peephole:
+      self._w_i_diag = self.add_variable(self._names["wci"], [self._num_units])
+      self._w_f_diag = self.add_variable(self._names["wcf"], [self._num_units])
+      self._w_o_diag = self.add_variable(self._names["wco"], [self._num_units])
+
+    self.built = True
+
+  def call(self, inputs, state):
+    """Long short-term memory cell (LSTM)."""
+    if len(state) != 2:
+      raise ValueError("Expecting state to be a tuple with length 2.")
+
+    if self._use_peephole:
+      wci = self._w_i_diag
+      wcf = self._w_f_diag
+      wco = self._w_o_diag
+    else:
+      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=self.dtype)
+
+    (cs_prev, h_prev) = state
+    (_, cs, _, _, _, _, h) = _lstm_block_cell(
+        inputs,
+        cs_prev,
+        h_prev,
+        self._kernel,
+        self._bias,
+        wci=wci,
+        wcf=wcf,
+        wco=wco,
+        forget_bias=self._forget_bias,
+        cell_clip=self._cell_clip,
+        use_peephole=self._use_peephole)
+
+    new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
+    return h, new_state
+
+
+@six.add_metaclass(abc.ABCMeta)
+class LSTMBlockWrapper(base_layer.Layer):
+  """This is a helper class that provides housekeeping for LSTM cells.
+
+  This may be useful for alternative LSTM and similar type of cells.
+  The subclasses must implement `_call_cell` method and `num_units` property.
+  """
+
+  @abc.abstractproperty
+  def num_units(self):
+    """Number of units in this cell (output dimension)."""
+
+  @abc.abstractmethod
+  def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
+                 sequence_length):
+    """Run this LSTM on inputs, starting from the given state.
+
+    This method must be implemented by subclasses and does the actual work
+    of calling the cell.
+
+    Args:
+      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
+      initial_cell_state: initial value for cell state, shape `[batch_size,
+        self._num_units]`
+      initial_output: initial value of cell output, shape `[batch_size,
+        self._num_units]`
+      dtype: The data type for the initial state and expected output.
+      sequence_length: Specifies the length of each sequence in inputs. An int32
+        or int64 vector (tensor) size [batch_size], values in [0, time_len) or
+        None.
+
+    Returns:
+      A pair containing:
+
+      - State: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
+      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
+    """
+    pass
+
+  def call(self, inputs, initial_state=None, dtype=None, sequence_length=None):
+    """Run this LSTM on inputs, starting from the given state.
+
+    Args:
+      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
+      initial_state: a tuple `(initial_cell_state, initial_output)` with tensors
+        of shape `[batch_size, self._num_units]`. If this is not provided, the
+        cell is expected to create a zero initial state of type `dtype`.
+      dtype: The data type for the initial state and expected output. Required
+        if `initial_state` is not provided or RNN state has a heterogeneous
+        dtype.
+      sequence_length: Specifies the length of each sequence in inputs. An
+        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
+        time_len).`
+        Defaults to `time_len` for each element.
+
+    Returns:
+      A pair containing:
+
+      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
+        or a list of time_len tensors of shape `[batch_size, output_size]`,
+        to match the type of the `inputs`.
+      - Final state: a tuple `(cell_state, output)` matching `initial_state`.
+
+    Raises:
+      ValueError: in case of shape mismatches
+    """
+    is_list = isinstance(inputs, list)
+    if is_list:
+      inputs = array_ops.stack(inputs)
+    inputs_shape = inputs.get_shape().with_rank(3)
+    if not inputs_shape[2]:
+      raise ValueError("Expecting inputs_shape[2] to be set: %s" % inputs_shape)
+    batch_size = inputs_shape.dims[1].value
+    if batch_size is None:
+      batch_size = array_ops.shape(inputs)[1]
+    time_len = inputs_shape.dims[0].value
+    if time_len is None:
+      time_len = array_ops.shape(inputs)[0]
+
+    # Provide default values for initial_state and dtype
+    if initial_state is None:
+      if dtype is None:
+        raise ValueError("Either initial_state or dtype needs to be specified")
+      z = array_ops.zeros(
+          array_ops.stack([batch_size, self.num_units]), dtype=dtype)
+      initial_state = z, z
+    else:
+      if len(initial_state) != 2:
+        raise ValueError(
+            "Expecting initial_state to be a tuple with length 2 or None")
+      if dtype is None:
+        dtype = initial_state[0].dtype
+
+    # create the actual cell
+    if sequence_length is not None:
+      sequence_length = ops.convert_to_tensor(sequence_length)
+    initial_cell_state, initial_output = initial_state  # pylint: disable=unpacking-non-sequence
+    cell_states, outputs = self._call_cell(
+        inputs, initial_cell_state, initial_output, dtype, sequence_length)
+
+    if sequence_length is not None:
+      # Mask out the part beyond sequence_length
+      mask = array_ops.transpose(
+          array_ops.sequence_mask(sequence_length, time_len, dtype=dtype),
+          [1, 0])
+      mask = array_ops.tile(
+          array_ops.expand_dims(mask, [-1]), [1, 1, self.num_units])
+      outputs *= mask
+      # Prepend initial states to cell_states and outputs for indexing to work
+      # correctly,since we want to access the last valid state at
+      # sequence_length - 1, which can even be -1, corresponding to the
+      # initial state.
+      mod_cell_states = array_ops.concat(
+          [array_ops.expand_dims(initial_cell_state, [0]), cell_states], 0)
+      mod_outputs = array_ops.concat(
+          [array_ops.expand_dims(initial_output, [0]), outputs], 0)
+      final_cell_state = self._gather_states(mod_cell_states, sequence_length,
+                                             batch_size)
+      final_output = self._gather_states(mod_outputs, sequence_length,
+                                         batch_size)
+    else:
+      # No sequence_lengths used: final state is the last state
+      final_cell_state = cell_states[-1]
+      final_output = outputs[-1]
+
+    if is_list:
+      # Input was a list, so return a list
+      outputs = array_ops.unstack(outputs)
+
+    final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
+    return outputs, final_state
+
+  def _gather_states(self, data, indices, batch_size):
+    """Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
+    return array_ops.gather_nd(
+        data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))
+
+
+class LSTMBlockFusedCell(LSTMBlockWrapper):
+  """FusedRNNCell implementation of LSTM.
+
+  This is an extremely efficient LSTM implementation, that uses a single TF op
+  for the entire LSTM. It should be both faster and more memory-efficient than
+  LSTMBlockCell defined above.
+
+  The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+  We add forget_bias (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting in the beginning of the training.
+
+  The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
+  """
+
+  def __init__(self,
+               num_units,
+               forget_bias=1.0,
+               cell_clip=None,
+               use_peephole=False,
+               reuse=None,
+               dtype=None,
+               name="lstm_fused_cell"):
+    """Initialize the LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      forget_bias: float, The bias added to forget gates (see above).
+      cell_clip: clip the cell to this value. Defaults is no cell clipping.
+      use_peephole: Whether to use peephole connections or not.
+      reuse: (optional) boolean describing whether to reuse variables in an
+        existing scope. If not `True`, and the existing scope already has the
+        given variables, an error is raised.
+      dtype: the dtype of variables of this layer.
+      name: String, the name of the layer. Layers with the same name will
+        share weights, but to avoid mistakes we require reuse=True in such
+        cases. By default this is "lstm_cell", for variable-name compatibility
+        with `tf.compat.v1.nn.rnn_cell.LSTMCell`.
+    """
+    super(LSTMBlockFusedCell, self).__init__(
+        _reuse=reuse, name=name, dtype=dtype)
+    self._num_units = num_units
+    self._forget_bias = forget_bias
+    self._cell_clip = cell_clip if cell_clip is not None else -1
+    self._use_peephole = use_peephole
+
+    # Inputs must be 3-dimensional.
+    self.input_spec = input_spec.InputSpec(ndim=3)
+
+  @property
+  def num_units(self):
+    """Number of units in this cell (output dimension)."""
+    return self._num_units
+
+  def build(self, input_shape):
+    input_size = input_shape.dims[2].value
+    self._kernel = self.add_variable(
+        "kernel", [input_size + self._num_units, self._num_units * 4])
+    self._bias = self.add_variable(
+        "bias", [self._num_units * 4],
+        initializer=init_ops.constant_initializer(0.0))
+    if self._use_peephole:
+      self._w_i_diag = self.add_variable("w_i_diag", [self._num_units])
+      self._w_f_diag = self.add_variable("w_f_diag", [self._num_units])
+      self._w_o_diag = self.add_variable("w_o_diag", [self._num_units])
+
+    self.built = True
+
+  def _call_cell(self,
+                 inputs,
+                 initial_cell_state=None,
+                 initial_output=None,
+                 dtype=None,
+                 sequence_length=None):
+    """Run this LSTM on inputs, starting from the given state.
+
+    Args:
+      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
+      initial_cell_state: initial value for cell state, shape `[batch_size,
+        self._num_units]`
+      initial_output: initial value of cell output, shape `[batch_size,
+        self._num_units]`
+      dtype: The data type for the initial state and expected output.
+      sequence_length: Specifies the length of each sequence in inputs. An
+        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
+        time_len)` or None.
+
+    Returns:
+      A pair containing:
+
+      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
+        output_size]`
+      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
+        output_size]`
+    """
+
+    inputs_shape = inputs.get_shape().with_rank(3)
+    time_len = inputs_shape.dims[0].value
+    if time_len is None:
+      time_len = array_ops.shape(inputs)[0]
+
+    if self._use_peephole:
+      wci = self._w_i_diag
+      wco = self._w_o_diag
+      wcf = self._w_f_diag
+    else:
+      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype)
+
+    if sequence_length is None:
+      max_seq_len = math_ops.cast(time_len, dtypes.int64)
+    else:
+      max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length),
+                                  dtypes.int64)
+
+    _, cs, _, _, _, _, h = gen_rnn_ops.block_lstm(
+        seq_len_max=max_seq_len,
+        x=inputs,
+        cs_prev=initial_cell_state,
+        h_prev=initial_output,
+        w=self._kernel,
+        wci=wci,
+        wcf=wcf,
+        wco=wco,
+        b=self._bias,
+        forget_bias=self._forget_bias,
+        cell_clip=self._cell_clip,
+        use_peephole=self._use_peephole)
+    return cs, h
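For orientation, here is a minimal usage sketch of the fused cell defined in this vendored module, assuming TF 1.x-style graph execution and a TensorFlow build that still exposes the compat API and the fused block_lstm kernel this file relies on. The import path follows the file layout listed above; the shapes and session usage are illustrative and not an API documented by spark-nlp itself.

    # Hypothetical usage sketch; module path and TF version support are assumptions.
    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()  # the module targets graph-mode, TF1-style layers

    from sparknlp.training._tf_graph_builders.tf2contrib.lstm_ops import LSTMBlockFusedCell

    time_len, batch_size, input_size, num_units = 5, 2, 3, 4
    # Inputs are time-major: [time_len, batch_size, input_size]
    inputs = tf.constant(np.random.randn(time_len, batch_size, input_size).astype(np.float32))

    cell = LSTMBlockFusedCell(num_units)  # a single fused op runs the whole sequence
    outputs, (final_c, final_h) = cell(inputs, dtype=tf.float32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(outputs).shape)  # (5, 2, 4): [time_len, batch_size, num_units]

The per-step LSTMBlockCell defined earlier in the same file is an ordinary RNNCell subclass, so it should drop into TF1-style dynamic RNN loops the same way a plain LSTMCell would.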