spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- com/johnsnowlabs/ml/__init__.py +0 -0
- com/johnsnowlabs/ml/ai/__init__.py +10 -0
- com/johnsnowlabs/nlp/__init__.py +4 -2
- spark_nlp-6.2.1.dist-info/METADATA +362 -0
- spark_nlp-6.2.1.dist-info/RECORD +292 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +281 -27
- sparknlp/annotation.py +137 -6
- sparknlp/annotation_audio.py +61 -0
- sparknlp/annotation_image.py +82 -0
- sparknlp/annotator/__init__.py +93 -0
- sparknlp/annotator/audio/__init__.py +16 -0
- sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
- sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
- sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
- sparknlp/annotator/chunk2_doc.py +85 -0
- sparknlp/annotator/chunker.py +137 -0
- sparknlp/annotator/classifier_dl/__init__.py +61 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
- sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
- sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
- sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
- sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
- sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
- sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
- sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
- sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
- sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
- sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
- sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
- sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
- sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
- sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/coref/__init__.py +1 -0
- sparknlp/annotator/coref/spanbert_coref.py +221 -0
- sparknlp/annotator/cv/__init__.py +29 -0
- sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
- sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
- sparknlp/annotator/cv/florence2_transformer.py +180 -0
- sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
- sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
- sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
- sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
- sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
- sparknlp/annotator/dataframe_optimizer.py +216 -0
- sparknlp/annotator/date2_chunk.py +88 -0
- sparknlp/annotator/dependency/__init__.py +17 -0
- sparknlp/annotator/dependency/dependency_parser.py +294 -0
- sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
- sparknlp/annotator/document_character_text_splitter.py +228 -0
- sparknlp/annotator/document_normalizer.py +235 -0
- sparknlp/annotator/document_token_splitter.py +175 -0
- sparknlp/annotator/document_token_splitter_test.py +85 -0
- sparknlp/annotator/embeddings/__init__.py +45 -0
- sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
- sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
- sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
- sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
- sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
- sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
- sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
- sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
- sparknlp/annotator/embeddings/doc2vec.py +352 -0
- sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
- sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
- sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
- sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
- sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
- sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
- sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
- sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
- sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
- sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
- sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
- sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
- sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
- sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
- sparknlp/annotator/embeddings/word2vec.py +353 -0
- sparknlp/annotator/embeddings/word_embeddings.py +385 -0
- sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
- sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
- sparknlp/annotator/er/__init__.py +16 -0
- sparknlp/annotator/er/entity_ruler.py +267 -0
- sparknlp/annotator/graph_extraction.py +368 -0
- sparknlp/annotator/keyword_extraction/__init__.py +16 -0
- sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
- sparknlp/annotator/ld_dl/__init__.py +16 -0
- sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
- sparknlp/annotator/lemmatizer.py +250 -0
- sparknlp/annotator/matcher/__init__.py +20 -0
- sparknlp/annotator/matcher/big_text_matcher.py +272 -0
- sparknlp/annotator/matcher/date_matcher.py +303 -0
- sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
- sparknlp/annotator/matcher/regex_matcher.py +221 -0
- sparknlp/annotator/matcher/text_matcher.py +290 -0
- sparknlp/annotator/n_gram_generator.py +141 -0
- sparknlp/annotator/ner/__init__.py +21 -0
- sparknlp/annotator/ner/ner_approach.py +94 -0
- sparknlp/annotator/ner/ner_converter.py +148 -0
- sparknlp/annotator/ner/ner_crf.py +397 -0
- sparknlp/annotator/ner/ner_dl.py +591 -0
- sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
- sparknlp/annotator/ner/ner_overwriter.py +166 -0
- sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
- sparknlp/annotator/normalizer.py +230 -0
- sparknlp/annotator/openai/__init__.py +16 -0
- sparknlp/annotator/openai/openai_completion.py +349 -0
- sparknlp/annotator/openai/openai_embeddings.py +106 -0
- sparknlp/annotator/param/__init__.py +17 -0
- sparknlp/annotator/param/classifier_encoder.py +98 -0
- sparknlp/annotator/param/evaluation_dl_params.py +130 -0
- sparknlp/annotator/pos/__init__.py +16 -0
- sparknlp/annotator/pos/perceptron.py +263 -0
- sparknlp/annotator/sentence/__init__.py +17 -0
- sparknlp/annotator/sentence/sentence_detector.py +290 -0
- sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
- sparknlp/annotator/sentiment/__init__.py +17 -0
- sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
- sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
- sparknlp/annotator/seq2seq/__init__.py +35 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
- sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
- sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
- sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
- sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
- sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
- sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
- sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
- sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
- sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
- sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
- sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
- sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
- sparknlp/annotator/similarity/__init__.py +0 -0
- sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
- sparknlp/annotator/spell_check/__init__.py +18 -0
- sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
- sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
- sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
- sparknlp/annotator/stemmer.py +79 -0
- sparknlp/annotator/stop_words_cleaner.py +190 -0
- sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
- sparknlp/annotator/token/__init__.py +19 -0
- sparknlp/annotator/token/chunk_tokenizer.py +118 -0
- sparknlp/annotator/token/recursive_tokenizer.py +205 -0
- sparknlp/annotator/token/regex_tokenizer.py +208 -0
- sparknlp/annotator/token/tokenizer.py +561 -0
- sparknlp/annotator/token2_chunk.py +76 -0
- sparknlp/annotator/ws/__init__.py +16 -0
- sparknlp/annotator/ws/word_segmenter.py +429 -0
- sparknlp/base/__init__.py +30 -0
- sparknlp/base/audio_assembler.py +95 -0
- sparknlp/base/doc2_chunk.py +169 -0
- sparknlp/base/document_assembler.py +164 -0
- sparknlp/base/embeddings_finisher.py +201 -0
- sparknlp/base/finisher.py +217 -0
- sparknlp/base/gguf_ranking_finisher.py +234 -0
- sparknlp/base/graph_finisher.py +125 -0
- sparknlp/base/has_recursive_fit.py +24 -0
- sparknlp/base/has_recursive_transform.py +22 -0
- sparknlp/base/image_assembler.py +172 -0
- sparknlp/base/light_pipeline.py +429 -0
- sparknlp/base/multi_document_assembler.py +164 -0
- sparknlp/base/prompt_assembler.py +207 -0
- sparknlp/base/recursive_pipeline.py +107 -0
- sparknlp/base/table_assembler.py +145 -0
- sparknlp/base/token_assembler.py +124 -0
- sparknlp/common/__init__.py +26 -0
- sparknlp/common/annotator_approach.py +41 -0
- sparknlp/common/annotator_model.py +47 -0
- sparknlp/common/annotator_properties.py +114 -0
- sparknlp/common/annotator_type.py +38 -0
- sparknlp/common/completion_post_processing.py +37 -0
- sparknlp/common/coverage_result.py +22 -0
- sparknlp/common/match_strategy.py +33 -0
- sparknlp/common/properties.py +1298 -0
- sparknlp/common/read_as.py +33 -0
- sparknlp/common/recursive_annotator_approach.py +35 -0
- sparknlp/common/storage.py +149 -0
- sparknlp/common/utils.py +39 -0
- sparknlp/functions.py +315 -5
- sparknlp/internal/__init__.py +1199 -0
- sparknlp/internal/annotator_java_ml.py +32 -0
- sparknlp/internal/annotator_transformer.py +37 -0
- sparknlp/internal/extended_java_wrapper.py +63 -0
- sparknlp/internal/params_getters_setters.py +71 -0
- sparknlp/internal/recursive.py +70 -0
- sparknlp/logging/__init__.py +15 -0
- sparknlp/logging/comet.py +467 -0
- sparknlp/partition/__init__.py +16 -0
- sparknlp/partition/partition.py +244 -0
- sparknlp/partition/partition_properties.py +902 -0
- sparknlp/partition/partition_transformer.py +200 -0
- sparknlp/pretrained/__init__.py +17 -0
- sparknlp/pretrained/pretrained_pipeline.py +158 -0
- sparknlp/pretrained/resource_downloader.py +216 -0
- sparknlp/pretrained/utils.py +35 -0
- sparknlp/reader/__init__.py +15 -0
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +190 -0
- sparknlp/reader/reader2doc.py +124 -0
- sparknlp/reader/reader2image.py +136 -0
- sparknlp/reader/reader2table.py +44 -0
- sparknlp/reader/reader_assembler.py +159 -0
- sparknlp/reader/sparknlp_reader.py +461 -0
- sparknlp/training/__init__.py +20 -0
- sparknlp/training/_tf_graph_builders/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
- sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
- sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
- sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/conll.py +150 -0
- sparknlp/training/conllu.py +103 -0
- sparknlp/training/pos.py +103 -0
- sparknlp/training/pub_tator.py +76 -0
- sparknlp/training/spacy_to_annotation.py +57 -0
- sparknlp/training/tfgraphs.py +5 -0
- sparknlp/upload_to_hub.py +149 -0
- sparknlp/util.py +51 -5
- com/__init__.pyc +0 -0
- com/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/__init__.pyc +0 -0
- com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/nlp/__init__.pyc +0 -0
- com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
- spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
- spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
- sparknlp/__init__.pyc +0 -0
- sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
- sparknlp/__pycache__/base.cpython-36.pyc +0 -0
- sparknlp/__pycache__/common.cpython-36.pyc +0 -0
- sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
- sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
- sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
- sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
- sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
- sparknlp/__pycache__/training.cpython-36.pyc +0 -0
- sparknlp/__pycache__/util.cpython-36.pyc +0 -0
- sparknlp/annotation.pyc +0 -0
- sparknlp/annotator.py +0 -3006
- sparknlp/annotator.pyc +0 -0
- sparknlp/base.py +0 -347
- sparknlp/base.pyc +0 -0
- sparknlp/common.py +0 -193
- sparknlp/common.pyc +0 -0
- sparknlp/embeddings.py +0 -40
- sparknlp/embeddings.pyc +0 -0
- sparknlp/internal.py +0 -288
- sparknlp/internal.pyc +0 -0
- sparknlp/pretrained.py +0 -123
- sparknlp/pretrained.pyc +0 -0
- sparknlp/storage.py +0 -32
- sparknlp/storage.pyc +0 -0
- sparknlp/training.py +0 -62
- sparknlp/training.pyc +0 -0
- sparknlp/util.pyc +0 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Python wrapper for the Block GRU Op."""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
from __future__ import division
|
|
18
|
+
from __future__ import print_function
|
|
19
|
+
|
|
20
|
+
from tensorflow.python.framework import ops
|
|
21
|
+
from tensorflow.python.framework import tensor_shape
|
|
22
|
+
from tensorflow.python.keras.engine import input_spec
|
|
23
|
+
from tensorflow.python.ops import array_ops
|
|
24
|
+
from tensorflow.python.ops import gen_rnn_ops
|
|
25
|
+
from tensorflow.python.ops import init_ops
|
|
26
|
+
from tensorflow.python.ops import math_ops
|
|
27
|
+
from tensorflow.python.ops import nn_ops
|
|
28
|
+
from tensorflow.python.ops import rnn_cell_impl
|
|
29
|
+
from tensorflow.python.util.deprecation import deprecated_args
|
|
30
|
+
|
|
31
|
+
# Module-level alias for the base class that the GRU block cells below
# derive from.
LayerRNNCell = rnn_cell_impl.LayerRNNCell  # pylint: disable=invalid-name
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@ops.RegisterGradient("GRUBlockCell")
def _GRUBlockCellGrad(op, *grad):
  r"""Gradient for GRUBlockCell.

  Registered via `ops.RegisterGradient` for the op type "GRUBlockCell".

  Args:
    op: Op for which the gradient is defined.
    *grad: Gradients of the optimization function wrt each output of the Op
      (r, u, c, h). Only the gradient wrt the new hidden state `h` is used;
      the incoming gradients for the r, u and c outputs are discarded.

  Returns:
    d_x: Gradients wrt x
    d_h_prev: Gradients wrt h_prev
    d_w_ru: Gradients wrt w_ru
    d_w_c: Gradients wrt w_c
    d_b_ru: Gradients wrt b_ru
    d_b_c: Gradients wrt b_c

  Mathematics behind the Gradients below (derived from the forward pass
  h = (1-u) \circ c + u \circ h_prev, c = tanh(c_bar), u = sigmoid(u_bar),
  r = sigmoid(r_bar), h_prevr = h_prev \circ r):
  ```
  d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
  d_u_bar = d_h \circ (h_prev-c) \circ u \circ (1-u)

  [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T

  d_r_bar = d_h_prevr \circ h_prev \circ r \circ (1-r)

  d_r_bar_u_bar = [d_r_bar d_u_bar]

  [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T

  d_x = d_x_component_1 + d_x_component_2

  d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
  ```
  Below calculation is performed in the python wrapper for the Gradients
  (not in the gradient kernel.)
  ```
  d_w_ru = x_h_prev^T * d_r_bar_u_bar

  d_w_c = x_h_prevr^T * d_c_bar

  d_b_ru = sum of d_r_bar_u_bar along axis = 0

  d_b_c = sum of d_c_bar along axis = 0
  ```
  """
  x, h_prev, w_ru, w_c, b_ru, b_c = op.inputs
  r, u, c, _ = op.outputs
  # Only the gradient flowing into the new hidden state is back-propagated.
  _, _, _, d_h = grad

  # The fused gradient kernel yields the input/state gradients plus the
  # pre-activation gradients used for the weight/bias gradients below.
  d_x, d_h_prev, d_c_bar, d_r_bar_u_bar = gen_rnn_ops.gru_block_cell_grad(
      x, h_prev, w_ru, w_c, b_ru, b_c, r, u, c, d_h)

  # Gate weights act on [x, h_prev]; gate biases get the batch-summed grads.
  x_h_prev = array_ops.concat([x, h_prev], 1)
  d_w_ru = math_ops.matmul(x_h_prev, d_r_bar_u_bar, transpose_a=True)
  d_b_ru = nn_ops.bias_add_grad(d_r_bar_u_bar)

  # Candidate weights act on [x, h_prev * r] (the reset-gated state).
  x_h_prevr = array_ops.concat([x, h_prev * r], 1)
  d_w_c = math_ops.matmul(x_h_prevr, d_c_bar, transpose_a=True)
  d_b_c = nn_ops.bias_add_grad(d_c_bar)

  return d_x, d_h_prev, d_w_ru, d_w_c, d_b_ru, d_b_c
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class GRUBlockCell(LayerRNNCell):
  r"""Block GRU cell implementation.

  Deprecated: use GRUBlockCellV2 instead.

  The implementation is based on: http://arxiv.org/abs/1406.1078
  Computes the GRU cell forward propagation for 1 time step.

  This kernel op implements the following mathematical equations:

  ```
  x_h_prev = [x, h_prev]

  [r_bar u_bar] = x_h_prev * w_ru + b_ru

  r = sigmoid(r_bar)
  u = sigmoid(u_bar)

  h_prevr = h_prev \circ r

  x_h_prevr = [x h_prevr]

  c_bar = x_h_prevr * w_c + b_c
  c = tanh(c_bar)

  h = (1-u) \circ c + u \circ h_prev
  ```

  Biases are initialized with:

  * `b_ru` - constant_initializer(1.0)
  * `b_c` - constant_initializer(0.0)
  """

  @deprecated_args(None, "cell_size is deprecated, use num_units instead",
                   "cell_size")
  def __init__(self,
               num_units=None,
               cell_size=None,
               reuse=None,
               name="gru_cell"):
    """Initialize the Block GRU cell.

    Args:
      num_units: int, The number of units in the GRU cell.
      cell_size: int, The old (deprecated) name for `num_units`.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope. If not `True`, and the existing scope already has the
        given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases. By default this is "gru_cell", for variable-name compatibility
        with `tf.compat.v1.nn.rnn_cell.GRUCell`.

    Raises:
      ValueError: if both cell_size and num_units are not None;
        or both are None.
    """
    super(GRUBlockCell, self).__init__(_reuse=reuse, name=name)
    if (cell_size is None) == (num_units is None):
      raise ValueError(
          "Exactly one of num_units or cell_size must be provided.")
    if num_units is None:
      num_units = cell_size
    self._cell_size = num_units
    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

  @property
  def state_size(self):
    """Width of the recurrent state (equal to `num_units`)."""
    return self._cell_size

  @property
  def output_size(self):
    """Width of the cell output (equal to `num_units`)."""
    return self._cell_size

  def build(self, input_shape):
    """Create the gate (`w_ru`, `b_ru`) and candidate (`w_c`, `b_c`) weights.

    Args:
      input_shape: shape of the inputs; dimension 1 must be statically known.

    Raises:
      ValueError: if `input_shape[1]` is unknown.
    """
    # Check if the input size exist.
    input_size = tensor_shape.dimension_value(input_shape[1])
    if input_size is None:
      raise ValueError("Expecting input_size to be set.")

    # Both kernels act on the concatenation [inputs, state].
    self._gate_kernel = self.add_variable(
        "w_ru", [input_size + self._cell_size, self._cell_size * 2])
    self._gate_bias = self.add_variable(
        "b_ru", [self._cell_size * 2],
        initializer=init_ops.constant_initializer(1.0))
    self._candidate_kernel = self.add_variable(
        "w_c", [input_size + self._cell_size, self._cell_size])
    self._candidate_bias = self.add_variable(
        "b_c", [self._cell_size],
        initializer=init_ops.constant_initializer(0.0))

    self.built = True

  def call(self, inputs, h_prev):
    """Run one GRU step.

    Args:
      inputs: rank-2 input tensor for this time step.
      h_prev: rank-2 previous hidden state.

    Returns:
      A `(output, new_state)` pair; both elements are the new hidden state.

    Raises:
      ValueError: if `h_prev` is not rank 2, or if its statically known second
        dimension differs from `num_units`.
    """
    # Normalize the shape entry to an int-or-None so the check behaves the
    # same under TF1 (`Dimension`) and TF2 (plain int/None) shape semantics.
    # An unknown dimension cannot be validated statically, so it is skipped
    # instead of tripping a spurious `None != num_units` error.
    cell_size = tensor_shape.dimension_value(
        h_prev.get_shape().with_rank(2)[1])
    if cell_size is not None and cell_size != self._cell_size:
      raise ValueError("Shape of h_prev[1] incorrect: cell_size %i vs %s" %
                       (self._cell_size, cell_size))

    # The fused op returns (r, u, c, h); only the new hidden state is needed.
    _, _, _, new_h = gen_rnn_ops.gru_block_cell(
        x=inputs,
        h_prev=h_prev,
        w_ru=self._gate_kernel,
        w_c=self._candidate_kernel,
        b_ru=self._gate_bias,
        b_c=self._candidate_bias)

    return new_h, new_h
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class GRUBlockCellV2(GRUBlockCell):
  """GRUBlockCell variant that creates its variables under different names.

  The computation is inherited unchanged from `GRUBlockCell`; only `build`
  is overridden, so that the weights are named `gates/kernel`, `gates/bias`,
  `candidate/kernel` and `candidate/bias` rather than `w_ru`, `b_ru`, `w_c`
  and `b_c`.
  """

  def build(self, input_shape):
    """Create the gate and candidate variables using the V2 naming scheme.

    Args:
      input_shape: shape of the inputs; dimension 1 must be statically known.

    Raises:
      ValueError: if `input_shape[1]` is unknown.
    """
    feature_dim = tensor_shape.dimension_value(input_shape[1])
    if feature_dim is None:
      raise ValueError("Expecting input_size to be set.")

    units = self._cell_size
    # Both kernels consume the concatenation [inputs, state].
    fan_in = feature_dim + units
    ones_init = init_ops.constant_initializer(1.0)
    zeros_init = init_ops.constant_initializer(0.0)

    # Reset/update gates: one kernel producing both gates side by side.
    self._gate_kernel = self.add_variable("gates/kernel", [fan_in, units * 2])
    self._gate_bias = self.add_variable(
        "gates/bias", [units * 2], initializer=ones_init)

    # Candidate activation.
    self._candidate_kernel = self.add_variable(
        "candidate/kernel", [fan_in, units])
    self._candidate_bias = self.add_variable(
        "candidate/bias", [units], initializer=zeros_init)

    self.built = True
|