spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- com/johnsnowlabs/ml/__init__.py +0 -0
- com/johnsnowlabs/ml/ai/__init__.py +10 -0
- com/johnsnowlabs/nlp/__init__.py +4 -2
- spark_nlp-6.2.1.dist-info/METADATA +362 -0
- spark_nlp-6.2.1.dist-info/RECORD +292 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +281 -27
- sparknlp/annotation.py +137 -6
- sparknlp/annotation_audio.py +61 -0
- sparknlp/annotation_image.py +82 -0
- sparknlp/annotator/__init__.py +93 -0
- sparknlp/annotator/audio/__init__.py +16 -0
- sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
- sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
- sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
- sparknlp/annotator/chunk2_doc.py +85 -0
- sparknlp/annotator/chunker.py +137 -0
- sparknlp/annotator/classifier_dl/__init__.py +61 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
- sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
- sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
- sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
- sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
- sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
- sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
- sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
- sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
- sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
- sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
- sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
- sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
- sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
- sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/coref/__init__.py +1 -0
- sparknlp/annotator/coref/spanbert_coref.py +221 -0
- sparknlp/annotator/cv/__init__.py +29 -0
- sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
- sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
- sparknlp/annotator/cv/florence2_transformer.py +180 -0
- sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
- sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
- sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
- sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
- sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
- sparknlp/annotator/dataframe_optimizer.py +216 -0
- sparknlp/annotator/date2_chunk.py +88 -0
- sparknlp/annotator/dependency/__init__.py +17 -0
- sparknlp/annotator/dependency/dependency_parser.py +294 -0
- sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
- sparknlp/annotator/document_character_text_splitter.py +228 -0
- sparknlp/annotator/document_normalizer.py +235 -0
- sparknlp/annotator/document_token_splitter.py +175 -0
- sparknlp/annotator/document_token_splitter_test.py +85 -0
- sparknlp/annotator/embeddings/__init__.py +45 -0
- sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
- sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
- sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
- sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
- sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
- sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
- sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
- sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
- sparknlp/annotator/embeddings/doc2vec.py +352 -0
- sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
- sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
- sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
- sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
- sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
- sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
- sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
- sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
- sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
- sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
- sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
- sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
- sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
- sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
- sparknlp/annotator/embeddings/word2vec.py +353 -0
- sparknlp/annotator/embeddings/word_embeddings.py +385 -0
- sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
- sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
- sparknlp/annotator/er/__init__.py +16 -0
- sparknlp/annotator/er/entity_ruler.py +267 -0
- sparknlp/annotator/graph_extraction.py +368 -0
- sparknlp/annotator/keyword_extraction/__init__.py +16 -0
- sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
- sparknlp/annotator/ld_dl/__init__.py +16 -0
- sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
- sparknlp/annotator/lemmatizer.py +250 -0
- sparknlp/annotator/matcher/__init__.py +20 -0
- sparknlp/annotator/matcher/big_text_matcher.py +272 -0
- sparknlp/annotator/matcher/date_matcher.py +303 -0
- sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
- sparknlp/annotator/matcher/regex_matcher.py +221 -0
- sparknlp/annotator/matcher/text_matcher.py +290 -0
- sparknlp/annotator/n_gram_generator.py +141 -0
- sparknlp/annotator/ner/__init__.py +21 -0
- sparknlp/annotator/ner/ner_approach.py +94 -0
- sparknlp/annotator/ner/ner_converter.py +148 -0
- sparknlp/annotator/ner/ner_crf.py +397 -0
- sparknlp/annotator/ner/ner_dl.py +591 -0
- sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
- sparknlp/annotator/ner/ner_overwriter.py +166 -0
- sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
- sparknlp/annotator/normalizer.py +230 -0
- sparknlp/annotator/openai/__init__.py +16 -0
- sparknlp/annotator/openai/openai_completion.py +349 -0
- sparknlp/annotator/openai/openai_embeddings.py +106 -0
- sparknlp/annotator/param/__init__.py +17 -0
- sparknlp/annotator/param/classifier_encoder.py +98 -0
- sparknlp/annotator/param/evaluation_dl_params.py +130 -0
- sparknlp/annotator/pos/__init__.py +16 -0
- sparknlp/annotator/pos/perceptron.py +263 -0
- sparknlp/annotator/sentence/__init__.py +17 -0
- sparknlp/annotator/sentence/sentence_detector.py +290 -0
- sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
- sparknlp/annotator/sentiment/__init__.py +17 -0
- sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
- sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
- sparknlp/annotator/seq2seq/__init__.py +35 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
- sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
- sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
- sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
- sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
- sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
- sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
- sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
- sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
- sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
- sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
- sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
- sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
- sparknlp/annotator/similarity/__init__.py +0 -0
- sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
- sparknlp/annotator/spell_check/__init__.py +18 -0
- sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
- sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
- sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
- sparknlp/annotator/stemmer.py +79 -0
- sparknlp/annotator/stop_words_cleaner.py +190 -0
- sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
- sparknlp/annotator/token/__init__.py +19 -0
- sparknlp/annotator/token/chunk_tokenizer.py +118 -0
- sparknlp/annotator/token/recursive_tokenizer.py +205 -0
- sparknlp/annotator/token/regex_tokenizer.py +208 -0
- sparknlp/annotator/token/tokenizer.py +561 -0
- sparknlp/annotator/token2_chunk.py +76 -0
- sparknlp/annotator/ws/__init__.py +16 -0
- sparknlp/annotator/ws/word_segmenter.py +429 -0
- sparknlp/base/__init__.py +30 -0
- sparknlp/base/audio_assembler.py +95 -0
- sparknlp/base/doc2_chunk.py +169 -0
- sparknlp/base/document_assembler.py +164 -0
- sparknlp/base/embeddings_finisher.py +201 -0
- sparknlp/base/finisher.py +217 -0
- sparknlp/base/gguf_ranking_finisher.py +234 -0
- sparknlp/base/graph_finisher.py +125 -0
- sparknlp/base/has_recursive_fit.py +24 -0
- sparknlp/base/has_recursive_transform.py +22 -0
- sparknlp/base/image_assembler.py +172 -0
- sparknlp/base/light_pipeline.py +429 -0
- sparknlp/base/multi_document_assembler.py +164 -0
- sparknlp/base/prompt_assembler.py +207 -0
- sparknlp/base/recursive_pipeline.py +107 -0
- sparknlp/base/table_assembler.py +145 -0
- sparknlp/base/token_assembler.py +124 -0
- sparknlp/common/__init__.py +26 -0
- sparknlp/common/annotator_approach.py +41 -0
- sparknlp/common/annotator_model.py +47 -0
- sparknlp/common/annotator_properties.py +114 -0
- sparknlp/common/annotator_type.py +38 -0
- sparknlp/common/completion_post_processing.py +37 -0
- sparknlp/common/coverage_result.py +22 -0
- sparknlp/common/match_strategy.py +33 -0
- sparknlp/common/properties.py +1298 -0
- sparknlp/common/read_as.py +33 -0
- sparknlp/common/recursive_annotator_approach.py +35 -0
- sparknlp/common/storage.py +149 -0
- sparknlp/common/utils.py +39 -0
- sparknlp/functions.py +315 -5
- sparknlp/internal/__init__.py +1199 -0
- sparknlp/internal/annotator_java_ml.py +32 -0
- sparknlp/internal/annotator_transformer.py +37 -0
- sparknlp/internal/extended_java_wrapper.py +63 -0
- sparknlp/internal/params_getters_setters.py +71 -0
- sparknlp/internal/recursive.py +70 -0
- sparknlp/logging/__init__.py +15 -0
- sparknlp/logging/comet.py +467 -0
- sparknlp/partition/__init__.py +16 -0
- sparknlp/partition/partition.py +244 -0
- sparknlp/partition/partition_properties.py +902 -0
- sparknlp/partition/partition_transformer.py +200 -0
- sparknlp/pretrained/__init__.py +17 -0
- sparknlp/pretrained/pretrained_pipeline.py +158 -0
- sparknlp/pretrained/resource_downloader.py +216 -0
- sparknlp/pretrained/utils.py +35 -0
- sparknlp/reader/__init__.py +15 -0
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +190 -0
- sparknlp/reader/reader2doc.py +124 -0
- sparknlp/reader/reader2image.py +136 -0
- sparknlp/reader/reader2table.py +44 -0
- sparknlp/reader/reader_assembler.py +159 -0
- sparknlp/reader/sparknlp_reader.py +461 -0
- sparknlp/training/__init__.py +20 -0
- sparknlp/training/_tf_graph_builders/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
- sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
- sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
- sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/conll.py +150 -0
- sparknlp/training/conllu.py +103 -0
- sparknlp/training/pos.py +103 -0
- sparknlp/training/pub_tator.py +76 -0
- sparknlp/training/spacy_to_annotation.py +57 -0
- sparknlp/training/tfgraphs.py +5 -0
- sparknlp/upload_to_hub.py +149 -0
- sparknlp/util.py +51 -5
- com/__init__.pyc +0 -0
- com/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/__init__.pyc +0 -0
- com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/nlp/__init__.pyc +0 -0
- com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
- spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
- spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
- sparknlp/__init__.pyc +0 -0
- sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
- sparknlp/__pycache__/base.cpython-36.pyc +0 -0
- sparknlp/__pycache__/common.cpython-36.pyc +0 -0
- sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
- sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
- sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
- sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
- sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
- sparknlp/__pycache__/training.cpython-36.pyc +0 -0
- sparknlp/__pycache__/util.cpython-36.pyc +0 -0
- sparknlp/annotation.pyc +0 -0
- sparknlp/annotator.py +0 -3006
- sparknlp/annotator.pyc +0 -0
- sparknlp/base.py +0 -347
- sparknlp/base.pyc +0 -0
- sparknlp/common.py +0 -193
- sparknlp/common.pyc +0 -0
- sparknlp/embeddings.py +0 -40
- sparknlp/embeddings.pyc +0 -0
- sparknlp/internal.py +0 -288
- sparknlp/internal.pyc +0 -0
- sparknlp/pretrained.py +0 -123
- sparknlp/pretrained.pyc +0 -0
- sparknlp/storage.py +0 -32
- sparknlp/storage.pyc +0 -0
- sparknlp/training.py +0 -62
- sparknlp/training.pyc +0 -0
- sparknlp/util.pyc +0 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,532 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import random
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
|
|
8
|
+
from .sentence_grouper import SentenceGrouper
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NerModel:
|
|
12
|
+
# If session is not defined than default session will be used
|
|
13
|
+
def __init__(self, session=None, dummy_tags=None, use_contrib=True, use_gpu_device=0):
|
|
14
|
+
tf.disable_v2_behavior()
|
|
15
|
+
|
|
16
|
+
self.word_repr = None
|
|
17
|
+
self.word_embeddings = None
|
|
18
|
+
self.session = session
|
|
19
|
+
self.session_created = False
|
|
20
|
+
self.dummy_tags = dummy_tags or []
|
|
21
|
+
self.use_contrib = use_contrib
|
|
22
|
+
self.use_gpu_device = use_gpu_device
|
|
23
|
+
|
|
24
|
+
if self.session is None:
|
|
25
|
+
self.session_created = True
|
|
26
|
+
self.session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
|
|
27
|
+
allow_soft_placement=True,
|
|
28
|
+
log_device_placement=False))
|
|
29
|
+
with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
|
|
30
|
+
with tf.compat.v1.variable_scope("char_repr"):
|
|
31
|
+
# shape = (batch size, sentence, word)
|
|
32
|
+
self.char_ids = tf.compat.v1.placeholder(tf.int32, shape=[None, None, None], name="char_ids")
|
|
33
|
+
|
|
34
|
+
# shape = (batch_size, sentence)
|
|
35
|
+
self.word_lengths = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="word_lengths")
|
|
36
|
+
|
|
37
|
+
with tf.compat.v1.variable_scope("word_repr"):
|
|
38
|
+
# shape = (batch size)
|
|
39
|
+
self.sentence_lengths = tf.compat.v1.placeholder(tf.int32, shape=[None], name="sentence_lengths")
|
|
40
|
+
|
|
41
|
+
with tf.compat.v1.variable_scope("training", reuse=None) as scope:
|
|
42
|
+
# shape = (batch, sentence)
|
|
43
|
+
self.labels = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="labels")
|
|
44
|
+
|
|
45
|
+
self.lr = tf.compat.v1.placeholder_with_default(0.005, shape=(), name="lr")
|
|
46
|
+
self.dropout = tf.compat.v1.placeholder(tf.float32, shape=(), name="dropout")
|
|
47
|
+
|
|
48
|
+
self._char_bilstm_added = False
|
|
49
|
+
self._char_cnn_added = False
|
|
50
|
+
self._word_embeddings_added = False
|
|
51
|
+
self._context_added = False
|
|
52
|
+
self._encode_added = False
|
|
53
|
+
|
|
54
|
+
    def add_bilstm_char_repr(self, nchars=101, dim=25, hidden=25):
        """Add a character-level BiLSTM word representation, concatenated into ``word_repr``.

        Characters of each word are embedded via a ``[nchars, dim]`` lookup table and
        run through a bidirectional LSTM; the final (merged) state of shape
        ``2 * hidden`` becomes that word's character representation.

        :param nchars: size of the character vocabulary (embedding rows)
        :param dim: character embedding dimension
        :param hidden: LSTM units per direction (output is ``2 * hidden`` after concat)
        """
        self._char_bilstm_added = True

        with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):

            with tf.compat.v1.variable_scope("char_repr_lstm"):
                # 1. Lookup for character embeddings
                # Uniform init in [-sqrt(3/dim), sqrt(3/dim)].
                char_range = math.sqrt(3 / dim)
                embeddings = tf.compat.v1.get_variable(name="char_embeddings",
                                                       dtype=tf.float32,
                                                       shape=[nchars, dim],
                                                       initializer=tf.compat.v1.random_uniform_initializer(
                                                           -char_range,
                                                           char_range
                                                       ),
                                                       use_resource=False)

                # shape = (batch, sentence, word, char embeddings dim)
                char_embeddings = tf.nn.embedding_lookup(params=embeddings, ids=self.char_ids)
                # char_embeddings = tf.nn.dropout(char_embeddings, self.dropout)
                # Dynamic shape kept so s[1] (sentence length) can restore the
                # batch/sentence split after the flattening below.
                s = tf.shape(input=char_embeddings)

                # shape = (batch x sentence, word, char embeddings dim)
                char_embeddings_seq = tf.reshape(char_embeddings, shape=[-1, s[-2], dim])

                # shape = (batch x sentence)
                word_lengths_seq = tf.reshape(self.word_lengths, shape=[-1])

                # 2. Add Bidirectional LSTM
                # return_sequences=False: only the final state per word is kept.
                model = tf.keras.Sequential([
                    tf.keras.layers.Bidirectional(
                        layer=tf.keras.layers.LSTM(hidden, return_sequences=False),
                        merge_mode="concat"
                    )
                ])

                inputs = char_embeddings_seq
                # Mask out padding characters beyond each word's real length.
                mask = tf.expand_dims(tf.sequence_mask(word_lengths_seq, dtype=tf.float32), axis=-1)

                # shape = (batch x sentence, 2 x hidden)
                output = model(inputs, mask=mask)

                # shape = (batch, sentence, 2 x hidden)
                char_repr = tf.reshape(output, shape=[-1, s[1], 2 * hidden])

                # Concatenate onto any representation added earlier, else start one.
                if self.word_repr is not None:
                    self.word_repr = tf.concat([self.word_repr, char_repr], axis=-1)
                else:
                    self.word_repr = char_repr
    def add_cnn_char_repr(self, nchars=101, dim=25, nfilters=25, pad=2):
        """Add a character-level CNN word representation, concatenated into ``word_repr``.

        Characters are embedded via a ``[nchars, dim]`` lookup table, a width-3
        Conv1D with ``nfilters`` filters is applied along each word, and the
        per-filter maximum over time becomes the word's representation.

        :param nchars: size of the character vocabulary (embedding rows)
        :param dim: character embedding dimension
        :param nfilters: number of convolution filters (features per word)
        :param pad: not referenced in this body — NOTE(review): confirm whether
            callers expect it to control convolution padding
        """
        self._char_cnn_added = True

        with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):

            with tf.compat.v1.variable_scope("char_repr_cnn") as scope:
                # 1. Lookup for character embeddings
                # Uniform init in [-sqrt(3/dim), sqrt(3/dim)].
                char_range = math.sqrt(3 / dim)
                embeddings = tf.compat.v1.get_variable(name="char_embeddings", dtype=tf.float32,
                                                       shape=[nchars, dim],
                                                       initializer=tf.compat.v1.random_uniform_initializer(-char_range,
                                                                                                           char_range),
                                                       use_resource=False)

                # shape = (batch, sentence, word_len, embeddings dim)
                char_embeddings = tf.nn.embedding_lookup(params=embeddings, ids=self.char_ids)
                # char_embeddings = tf.nn.dropout(char_embeddings, self.dropout)
                # Dynamic shape kept so s[0]/s[1] restore batch/sentence dims below.
                s = tf.shape(input=char_embeddings)

                # shape = (batch x sentence, word_len, embeddings dim)
                char_embeddings = tf.reshape(char_embeddings, shape=[-1, s[-2], dim])

                # batch x sentence, word_len, nfilters
                conv1d = tf.keras.layers.Conv1D(
                    filters=nfilters,
                    kernel_size=[3],
                    padding='same',
                    activation=tf.nn.relu
                )(char_embeddings)

                # Max across each filter, shape = (batch x sentence, nfilters)
                char_repr = tf.reduce_max(input_tensor=conv1d, axis=1, keepdims=True)
                char_repr = tf.squeeze(char_repr, axis=[1])

                # (batch, sentence, nfilters)
                char_repr = tf.reshape(char_repr, shape=[s[0], s[1], nfilters])

                # Concatenate onto any representation added earlier, else start one.
                if self.word_repr is not None:
                    self.word_repr = tf.concat([self.word_repr, char_repr], axis=-1)
                else:
                    self.word_repr = char_repr
def add_pretrained_word_embeddings(self, dim=100):
|
|
147
|
+
self._word_embeddings_added = True
|
|
148
|
+
|
|
149
|
+
with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
|
|
150
|
+
with tf.compat.v1.variable_scope("word_repr") as scope:
|
|
151
|
+
# shape = (batch size, sentence, dim)
|
|
152
|
+
self.word_embeddings = tf.compat.v1.placeholder(tf.float32, shape=[None, None, dim],
|
|
153
|
+
name="word_embeddings")
|
|
154
|
+
|
|
155
|
+
if self.word_repr is not None:
|
|
156
|
+
self.word_repr = tf.concat([self.word_repr, self.word_embeddings], axis=-1)
|
|
157
|
+
else:
|
|
158
|
+
self.word_repr = self.word_embeddings
|
|
159
|
+
|
|
160
|
+
    def _create_lstm_layer(self, inputs, hidden_size, lengths):
        """Build one bidirectional LSTM layer over ``inputs``.

        Two implementations, selected by ``self.use_contrib``:
        the Keras path (``use_contrib`` False) and the ``tf.contrib`` fused-cell
        path, which only exists on TF 1.x — NOTE(review): ``tf.contrib`` was
        removed in TF 2.x, so the contrib branch cannot run there.

        :param inputs: tensor of shape ``(batch, sentence, features)``
        :param hidden_size: LSTM units per direction
        :param lengths: per-sentence valid lengths, shape ``(batch,)``
        :return: tensor reshaped to ``(batch, ?, 2 * hidden_size)``
        """
        with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
            if not self.use_contrib:
                # NOTE(review): return_sequences=False emits only the final
                # state (batch, 2*hidden), yet the result is reshaped below to a
                # per-timestep tensor — the contrib branch returns per-timestep
                # outputs. Looks like return_sequences=True may have been
                # intended; confirm against add_context_repr's residual add.
                model = tf.keras.Sequential([
                    tf.keras.layers.Bidirectional(
                        layer=tf.keras.layers.LSTM(hidden_size, return_sequences=False),
                        merge_mode="concat"
                    )
                ])

                # Mask out padded positions beyond each sentence's real length.
                mask = tf.expand_dims(tf.sequence_mask(lengths, dtype=tf.float32), axis=-1)
                # shape = (batch x sentence, 2 x hidden)
                output = model(inputs, mask=mask)
                # inputs shape = (batch, sentence, inp)
                batch = tf.shape(input=lengths)[0]

                return tf.reshape(output, shape=[batch, -1, 2 * hidden_size])

            # Fused cells expect time-major input: (sentence, batch, features).
            time_based = tf.transpose(a=inputs, perm=[1, 0, 2])

            cell_fw = tf.contrib.rnn.LSTMBlockFusedCell(hidden_size, use_peephole=True)
            cell_bw = tf.contrib.rnn.LSTMBlockFusedCell(hidden_size, use_peephole=True)
            # Backward direction = same fused cell run over the reversed sequence.
            cell_bw = tf.contrib.rnn.TimeReversedFusedRNN(cell_bw)

            output_fw, _ = cell_fw(time_based, dtype=tf.float32, sequence_length=lengths)
            output_bw, _ = cell_bw(time_based, dtype=tf.float32, sequence_length=lengths)

            result = tf.concat([output_fw, output_bw], axis=-1)
            # Back to batch-major: (batch, sentence, 2 * hidden_size).
            return tf.transpose(a=result, perm=[1, 0, 2])
def _multiply_layer(self, source, result_size, activation=tf.nn.relu):
|
|
192
|
+
|
|
193
|
+
with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
|
|
194
|
+
ntime_steps = tf.shape(input=source)[1]
|
|
195
|
+
source_size = source.shape[2]
|
|
196
|
+
|
|
197
|
+
W = tf.compat.v1.get_variable("W", shape=[source_size, result_size],
|
|
198
|
+
dtype=tf.float32,
|
|
199
|
+
initializer=tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0,
|
|
200
|
+
mode="fan_avg",
|
|
201
|
+
distribution="uniform"),
|
|
202
|
+
use_resource=False)
|
|
203
|
+
|
|
204
|
+
b = tf.compat.v1.get_variable("b", shape=[result_size], dtype=tf.float32, use_resource=False)
|
|
205
|
+
|
|
206
|
+
# batch x time, source_size
|
|
207
|
+
source = tf.reshape(source, [-1, source_size])
|
|
208
|
+
# batch x time, result_size
|
|
209
|
+
result = tf.matmul(source, W) + b
|
|
210
|
+
|
|
211
|
+
result = tf.reshape(result, [-1, ntime_steps, result_size])
|
|
212
|
+
if activation:
|
|
213
|
+
result = activation(result)
|
|
214
|
+
|
|
215
|
+
return result
|
|
216
|
+
|
|
217
|
+
# Adds Bi LSTM with size of each cell hidden_size
|
|
218
|
+
def add_context_repr(self, ntags, hidden_size=100, height=1, residual=True):
|
|
219
|
+
assert (self._word_embeddings_added or self._char_cnn_added or self._char_bilstm_added,
|
|
220
|
+
"Add word embeddings by method add_word_embeddings " +
|
|
221
|
+
"or add char representation by method add_bilstm_char_repr " +
|
|
222
|
+
"or add_bilstm_char_repr before adding context layer")
|
|
223
|
+
|
|
224
|
+
self._context_added = True
|
|
225
|
+
self.ntags = ntags
|
|
226
|
+
|
|
227
|
+
with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
|
|
228
|
+
context_repr = self._multiply_layer(self.word_repr, 2 * hidden_size)
|
|
229
|
+
# Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`
|
|
230
|
+
context_repr = tf.nn.dropout(x=context_repr, rate=1 - self.dropout)
|
|
231
|
+
|
|
232
|
+
with tf.compat.v1.variable_scope("context_repr"):
|
|
233
|
+
for i in range(height):
|
|
234
|
+
with tf.compat.v1.variable_scope('lstm-{}'.format(i)):
|
|
235
|
+
new_repr = self._create_lstm_layer(context_repr, hidden_size,
|
|
236
|
+
lengths=self.sentence_lengths)
|
|
237
|
+
|
|
238
|
+
context_repr = new_repr + context_repr if residual else new_repr
|
|
239
|
+
|
|
240
|
+
context_repr = tf.nn.dropout(x=context_repr, rate=1 - self.dropout)
|
|
241
|
+
|
|
242
|
+
# batch, sentence, ntags
|
|
243
|
+
self.scores = self._multiply_layer(context_repr, ntags, activation=None)
|
|
244
|
+
|
|
245
|
+
tf.identity(self.scores, "scores")
|
|
246
|
+
|
|
247
|
+
self.predicted_labels = tf.argmax(input=self.scores, axis=-1)
|
|
248
|
+
tf.identity(self.predicted_labels, "predicted_labels")
|
|
249
|
+
|
|
250
|
+
def add_inference_layer(self, crf=False, predictions_op_name=None):
|
|
251
|
+
assert (self._context_added,
|
|
252
|
+
"Add context representation layer by method add_context_repr before adding inference layer")
|
|
253
|
+
self._inference_added = True
|
|
254
|
+
|
|
255
|
+
with tf.device('/gpu:{}'.format(self.use_gpu_device)):
|
|
256
|
+
|
|
257
|
+
with tf.compat.v1.variable_scope("inference", reuse=None) as scope:
|
|
258
|
+
|
|
259
|
+
self.crf = tf.constant(crf, dtype=tf.bool, name="crf")
|
|
260
|
+
|
|
261
|
+
if crf:
|
|
262
|
+
transition_params = tf.compat.v1.get_variable("transition_params",
|
|
263
|
+
shape=[self.ntags, self.ntags],
|
|
264
|
+
initializer=tf.compat.v1.keras.initializers.VarianceScaling(
|
|
265
|
+
scale=1.0, mode="fan_avg",
|
|
266
|
+
distribution="uniform"),
|
|
267
|
+
use_resource=False)
|
|
268
|
+
|
|
269
|
+
# CRF shape = (batch, sentence)
|
|
270
|
+
log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
|
|
271
|
+
self.scores,
|
|
272
|
+
self.labels,
|
|
273
|
+
self.sentence_lengths,
|
|
274
|
+
transition_params
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
tf.identity(log_likelihood, "log_likelihood")
|
|
278
|
+
tf.identity(self.transition_params, "transition_params")
|
|
279
|
+
|
|
280
|
+
self.loss = tf.reduce_mean(input_tensor=-log_likelihood)
|
|
281
|
+
if predictions_op_name:
|
|
282
|
+
with tf.compat.v1.variable_scope("inference_tmp", reuse=None):
|
|
283
|
+
tmp_prediction, _ = tf.contrib.crf.crf_decode(self.scores, self.transition_params,
|
|
284
|
+
self.sentence_lengths)
|
|
285
|
+
|
|
286
|
+
self.prediction = tf.identity(tmp_prediction, name=predictions_op_name)
|
|
287
|
+
else:
|
|
288
|
+
self.prediction, _ = tf.contrib.crf.crf_decode(self.scores, self.transition_params,
|
|
289
|
+
self.sentence_lengths)
|
|
290
|
+
|
|
291
|
+
else:
|
|
292
|
+
# Softmax
|
|
293
|
+
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.scores, labels=self.labels)
|
|
294
|
+
# shape = (batch, sentence, ntags)
|
|
295
|
+
mask = tf.sequence_mask(self.sentence_lengths)
|
|
296
|
+
# apply mask
|
|
297
|
+
losses = tf.boolean_mask(tensor=losses, mask=mask)
|
|
298
|
+
|
|
299
|
+
self.loss = tf.reduce_mean(input_tensor=losses)
|
|
300
|
+
|
|
301
|
+
self.prediction = tf.math.argmax(input=self.scores, axis=-1, name=predictions_op_name)
|
|
302
|
+
|
|
303
|
+
tf.identity(self.loss, "loss")
|
|
304
|
+
|
|
305
|
+
# clip_gradient <= 0 - no gradient clipping (original comment said "< 0",
# but the check below also disables clipping for exactly 0)
def add_training_op(self, clip_gradient=2.0, train_op_name=None):
    """Attach an Adam training op that minimizes ``self.loss``.

    Must be called after ``add_inference_layer`` (which defines ``self.loss``).
    Also builds ``self.init_op`` that initializes all global variables.

    Parameters
    ----------
    clip_gradient : float
        Gradients are clipped element-wise to [-clip_gradient, clip_gradient];
        any value <= 0 disables clipping.
    train_op_name : str or None
        Optional explicit name for the optimizer op.
    """
    # Bug fix: the original asserted a (condition, message) tuple, which is
    # always truthy, so the precondition was never actually checked.
    assert self._inference_added, \
        "Add inference layer by method add_inference_layer before adding training layer"
    self._training_added = True

    with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):

        with tf.compat.v1.variable_scope("training", reuse=None):
            if train_op_name:
                optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr, name=train_op_name)
            else:
                optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr)
            if clip_gradient > 0:
                gvs = optimizer.compute_gradients(self.loss)
                # Drop None gradients (variables not reachable from the loss).
                capped_gvs = [(tf.clip_by_value(grad, -clip_gradient, clip_gradient), var)
                              for grad, var in gvs if grad is not None]
                self.train_op = optimizer.apply_gradients(capped_gvs)
            else:
                self.train_op = optimizer.minimize(self.loss)

            # NOTE(review): exact scope nesting of init_op is ambiguous in the
            # mangled source; it is kept inside the device scope as upstream does.
            self.init_op = tf.compat.v1.variables_initializer(tf.compat.v1.global_variables(), name="init")
|
|
327
|
+
|
|
328
|
+
@staticmethod
def num_trues(array):
    """Count elements equal to True.

    Uses ``==`` comparison on purpose, so 1 and 1.0 count as True as well.
    """
    return sum(1 for item in array if item == True)
|
|
336
|
+
|
|
337
|
+
@staticmethod
def fill(array, l, val):
    """Return a copy of ``array`` right-padded with ``val`` up to length ``l``.

    If ``array`` is already at least ``l`` long, an unmodified copy is returned.
    Note: padded slots all reference the same ``val`` object, matching the
    original append-in-a-loop behavior.
    """
    padding = [val] * max(l - len(array), 0)
    return array[:] + padding
|
|
343
|
+
|
|
344
|
+
@staticmethod
def get_sentence_lengths(batch, idx="word_embeddings"):
    """Length (token count) of column ``idx`` for every row in the batch."""
    lengths = []
    for row in batch:
        lengths.append(len(row[idx]))
    return lengths
|
|
347
|
+
|
|
348
|
+
@staticmethod
def get_sentence_token_lengths(batch, idx="tag_ids"):
    """Number of tag ids per sentence in the batch."""
    return list(map(lambda row: len(row[idx]), batch))
|
|
351
|
+
|
|
352
|
+
@staticmethod
def get_word_lengths(batch, idx="char_ids"):
    """Per-sentence word lengths (in characters), zero-padded to the widest sentence."""
    max_words = max(len(row[idx]) for row in batch)
    padded = []
    for row in batch:
        word_lengths = [len(chars) for chars in row[idx]]
        padded.append(NerModel.fill(word_lengths, max_words, 0))
    return padded
|
|
357
|
+
|
|
358
|
+
@staticmethod
def get_char_ids(batch, idx="char_ids"):
    """Char-id matrix per sentence, zero-padded on both the word and char axes.

    Result shape is (len(batch), max_words, max_chars), where the maxima are
    taken over the whole batch.
    """
    max_chars = max(max(len(char_ids) for char_ids in sentence[idx]) for sentence in batch)
    max_words = max(len(sentence[idx]) for sentence in batch)

    result = []
    for sentence in batch:
        rows = [NerModel.fill(char_ids, max_chars, 0) for char_ids in sentence[idx]]
        result.append(NerModel.fill(rows, max_words, [0] * max_chars))
    return result
|
|
369
|
+
|
|
370
|
+
@staticmethod
def get_from_batch(batch, idx):
    """Collect column ``idx`` from every row, zero-padding each to the longest row."""
    widest = max(len(row[idx]) for row in batch)
    return [NerModel.fill(row[idx], widest, 0) for row in batch]
|
|
374
|
+
|
|
375
|
+
@staticmethod
def get_tag_ids(batch, idx="tag_ids"):
    """Zero-padded tag-id matrix for the batch; thin wrapper over ``get_from_batch``."""
    return NerModel.get_from_batch(batch, idx)
|
|
378
|
+
|
|
379
|
+
@staticmethod
def get_word_embeddings(batch, idx="word_embeddings"):
    """Embedding tensor (len(batch), max_words, dim); short sentences get zero vectors.

    The embedding dimension is taken from the first word of the first sentence,
    so the batch must be non-empty and every row non-empty.
    """
    embeddings_dim = len(batch[0][idx][0])
    max_words = max(len(sentence[idx]) for sentence in batch)
    zero_vector = [0] * embeddings_dim
    padded = []
    for sentence in batch:
        padded.append(NerModel.fill(list(sentence[idx]), max_words, zero_vector))
    return padded
|
|
388
|
+
|
|
389
|
+
@staticmethod
def slice(dataset, batch_size=10):
    """Batch ``dataset`` using a SentenceGrouper with fixed length buckets [5, 10, 20, 50]."""
    return SentenceGrouper([5, 10, 20, 50]).slice(dataset, batch_size)
|
|
393
|
+
|
|
394
|
+
def init_variables(self):
    """Run this model's variable-initializer op on its session."""
    self.session.run(self.init_op)
|
|
396
|
+
|
|
397
|
+
def train(self, train,
          epoch_start=0,
          epoch_end=100,
          batch_size=32,
          lr=0.01,
          po=0,
          dropout=0.65,
          init_variables=False
          ):
    """Train the model in-place over ``train`` for the given epoch range.

    Parameters
    ----------
    train : list of dict
        Dataset rows; shuffled in place each epoch. Rows must carry the keys
        the feed helpers read ('word_embeddings', 'char_ids', 'tag_ids', ...).
    epoch_start, epoch_end : int
        Half-open epoch range to run.
    batch_size : int
        Batch size passed to ``NerModel.slice``.
    lr : float
        Base learning rate; decayed per epoch as ``lr / (1 + po * epoch)``.
    po : float
        Learning-rate decay factor (0 disables decay).
    dropout : float
        Dropout value fed to the graph each step.
    init_variables : bool
        When True, run the global variable initializer before training.
    """
    # Bug fix: the original asserted a (condition, message) tuple, which is
    # always truthy, so the precondition was never actually checked.
    assert self._training_added, "Add training layer by method add_training_op before running training"

    if init_variables:
        with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
            self.session.run(tf.compat.v1.global_variables_initializer())

    # Typo fix: was 'trainig started'.
    print('training started')
    for epoch in range(epoch_start, epoch_end):
        random.shuffle(train)
        sum_loss = 0
        for batch in NerModel.slice(train, batch_size):
            feed_dict = {
                self.sentence_lengths: NerModel.get_sentence_lengths(batch),
                self.word_embeddings: NerModel.get_word_embeddings(batch),

                self.word_lengths: NerModel.get_word_lengths(batch),
                self.char_ids: NerModel.get_char_ids(batch),
                self.labels: NerModel.get_tag_ids(batch),

                self.dropout: dropout,
                self.lr: lr / (1 + po * epoch)
            }
            mean_loss, _ = self.session.run([self.loss, self.train_op], feed_dict=feed_dict)
            sum_loss += mean_loss

        # sum_loss is the sum of per-batch mean losses, not a dataset mean.
        print("epoch {}".format(epoch))
        print("mean loss: {}".format(sum_loss))
        print()
        sys.stdout.flush()
|
|
436
|
+
|
|
437
|
+
def measure(self, dataset, batch_size=20, dropout=1.0):
    """Evaluate the model on ``dataset`` and return ``(precision, recall, f1)``.

    Tags listed in ``self.dummy_tags`` and non-word-start sub-tokens are
    excluded; tag id 0 is treated as the "outside" tag and excluded from the
    micro-averaged counts.

    Parameters
    ----------
    dataset : list of dict
        Rows with the feed keys plus 'is_word_start' per-token flags.
    batch_size : int
        Evaluation batch size.
    dropout : float
        Dropout value fed during evaluation (1.0 = keep everything).
    """
    predicted = {}
    correct = {}
    correct_predicted = {}

    for batch in NerModel.slice(dataset, batch_size):
        tags_ids = NerModel.get_tag_ids(batch)
        sentence_lengths = NerModel.get_sentence_lengths(batch)

        feed_dict = {
            self.sentence_lengths: sentence_lengths,
            self.word_embeddings: NerModel.get_word_embeddings(batch),

            self.word_lengths: NerModel.get_word_lengths(batch),
            self.char_ids: NerModel.get_char_ids(batch),
            self.labels: tags_ids,

            self.dropout: dropout
        }

        prediction = self.session.run(self.prediction, feed_dict=feed_dict)
        batch_prediction = np.reshape(prediction, (len(batch), -1))

        for i in range(len(batch)):
            is_word_start = batch[i]['is_word_start']

            for word in range(sentence_lengths[i]):
                # Only score the first sub-token of each word.
                if not is_word_start[word]:
                    continue

                p = batch_prediction[i][word]
                c = tags_ids[i][word]

                if c in self.dummy_tags:
                    continue

                predicted[p] = predicted.get(p, 0) + 1
                correct[c] = correct.get(c, 0) + 1
                if p == c:
                    correct_predicted[p] = correct_predicted.get(p, 0) + 1

    # Micro-averaged counts over real tags (ids 1..ntags-1).
    num_correct_predicted = sum([correct_predicted.get(i, 0) for i in range(1, self.ntags)])
    num_predicted = sum([predicted.get(i, 0) for i in range(1, self.ntags)])
    num_correct = sum([correct.get(i, 0) for i in range(1, self.ntags)])

    prec = num_correct_predicted / (num_predicted or 1.)
    rec = num_correct_predicted / (num_correct or 1.)

    # Bug fix: when prec == rec == 0 the original raised ZeroDivisionError;
    # guard the denominator the same way as prec/rec, yielding f1 == 0.
    f1 = 2 * prec * rec / ((rec + prec) or 1.)

    return prec, rec, f1
|
|
488
|
+
|
|
489
|
+
@staticmethod
def get_softmax(scores, threshold=None):
    """Convert raw scores to per-token softmax probabilities.

    Parameters
    ----------
    scores : np.ndarray
        Raw logits; indexed as scores[sentence][token][tag] below
        (assumed shape (batch, tokens, ntags) -- TODO confirm with callers).
    threshold : float or None
        When given, probabilities below it are zeroed (rows are NOT
        re-normalized afterwards).

    Returns
    -------
    np.ndarray of the same shape, rows normalized to sum to 1 before
    thresholding.
    """
    exp_scores = np.exp(scores)

    # Bug fix: the original wrapped this in a spurious extra
    # `for _ in exp_scores:` pass, normalizing every row once per sentence.
    # With a threshold set, the repeated passes re-normalized the already
    # thresholded rows and corrupted the probabilities; a single pass is
    # the intended behavior (and a no-op difference when threshold is None).
    for sentence in exp_scores:
        for i in range(len(sentence)):
            probabilities = sentence[i] / np.sum(sentence[i])
            sentence[i] = [p if threshold is None or p >= threshold else 0 for p in probabilities]

    return exp_scores
|
|
500
|
+
|
|
501
|
+
def predict(self, sentences, batch_size=20, threshold=None):
    """Predict tag ids for each token of each sentence.

    Parameters
    ----------
    sentences : list of dict
        Rows carrying the feed keys ('word_embeddings', 'char_ids', ...).
    batch_size : int
        Inference batch size.
    threshold : float or None
        Currently unused by this method (kept for interface compatibility).

    Returns
    -------
    list of list
        One list of predicted tag ids per input sentence, trimmed to the
        sentence's real length.
    """
    result = []

    for batch in NerModel.slice(sentences, batch_size):
        sentence_lengths = NerModel.get_sentence_lengths(batch)

        feed_dict = {
            self.sentence_lengths: sentence_lengths,
            self.word_embeddings: NerModel.get_word_embeddings(batch),

            self.word_lengths: NerModel.get_word_lengths(batch),
            self.char_ids: NerModel.get_char_ids(batch),

            # Bug fix: the original fed 1.1, an invalid keep-probability that
            # silently rescales activations at inference; `measure` uses 1.0,
            # which disables dropout as intended.
            self.dropout: 1.0
        }

        prediction = self.session.run(self.prediction, feed_dict=feed_dict)
        batch_prediction = np.reshape(prediction, (len(batch), -1))

        for i in range(len(batch)):
            sentence = []
            for word in range(sentence_lengths[i]):
                tag = batch_prediction[i][word]
                sentence.append(tag)

            result.append(sentence)

    return result
|
|
529
|
+
|
|
530
|
+
def close(self):
    """Close the TF session, but only if this instance created it itself."""
    if not self.session_created:
        return
    self.session.close()
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NerModelSaver:
    """Exports a trained NerModel graph, variables and encoder metadata to disk.

    Uses the TF1-style 'save/*' ops: a ``tf.train.Saver`` is constructed once
    per graph purely for its side effect of adding those ops.
    """

    def __init__(self, ner, encoder, embeddings_file=None):
        # ner: NerModel holding a live TF session.
        # encoder: provides tag2id, char2id and embeddings mappings.
        # embeddings_file: optional path copied verbatim into the export dir.
        self.ner = ner
        self.encoder = encoder
        self.embeddings_file = embeddings_file

    @staticmethod
    def restore_tensorflow_state(session, export_dir):
        """Restore variables saved under ``export_dir`` into ``session``."""
        with tf.device('/gpu:0'):
            save_nodes = list([n.name for n in tf.get_default_graph().as_graph_def().node
                               if n.name.startswith('save/')])
            if len(save_nodes) == 0:
                # Constructing a Saver adds the 'save/*' ops to the graph;
                # the object itself is intentionally unused.
                saver = tf.train.Saver()

            variables_file = os.path.join(export_dir, 'variables')
            session.run("save/restore_all", feed_dict={'save/Const:0': variables_file})

    def save_models(self, folder):
        """Write the session's variables and graph definition into ``folder``."""
        with tf.device('/gpu:0'):
            save_nodes = list([n.name for n in tf.get_default_graph().as_graph_def().node
                               if n.name.startswith('save/')])
            if len(save_nodes) == 0:
                # Side effect only: adds the 'save/*' ops used just below.
                saver = tf.train.Saver()

            variables_file = os.path.join(folder, 'variables')
            self.ner.session.run('save/control_dependency', feed_dict={'save/Const:0': variables_file})
            tf.train.write_graph(self.ner.session.graph, folder, 'saved_model.pb', False)

    def save(self, export_dir):
        """Export graph/variables plus tags.csv, chars.csv and optional embeddings."""

        def save_tags(file):
            # One tag per line, ordered by id.
            # Assumes tag ids are contiguous 0..len-1 -- TODO confirm encoder contract.
            id2tag = {id: tag for (tag, id) in self.encoder.tag2id.items()}

            with open(file, 'w') as f:
                for i in range(len(id2tag)):
                    tag = id2tag[i]
                    f.write(tag)
                    f.write('\n')

        def save_embeddings(src, dst):
            # Copy the raw embeddings file and record its dimension in a .meta file.
            from shutil import copyfile
            copyfile(src, dst)
            with open(dst + '.meta', 'w') as f:
                embeddings = self.encoder.embeddings
                dim = len(embeddings[0]) if embeddings else 0
                f.write(str(dim))

        def save_chars(file):
            # All chars concatenated in id order.
            # Assumes char ids are contiguous starting at 1 -- TODO confirm encoder contract.
            id2char = {id: char for (char, id) in self.encoder.char2id.items()}
            with open(file, 'w') as f:
                for i in range(1, len(id2char) + 1):
                    f.write(id2char[i])

        # Bug fix: the original called the bare name `save_models(export_dir)`,
        # which raises NameError at runtime -- it is a method and must be
        # invoked through self.
        self.save_models(export_dir)
        save_tags(os.path.join(export_dir, 'tags.csv'))

        if self.embeddings_file:
            save_embeddings(self.embeddings_file, os.path.join(export_dir, 'embeddings'))

        save_chars(os.path.join(export_dir, 'chars.csv'))
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
class SentenceGrouper:
    """Buckets sentences by length and slices them into near-uniform batches.

    Sentences whose 'words' length fits no configured bucket fall into a
    final overflow bucket.
    """

    def __init__(self, bucket_lengths):
        # Ascending list of bucket capacity limits.
        self.bucket_lengths = bucket_lengths

    def get_bucket_id(self, length):
        """Index of the first bucket whose limit fits ``length``; overflow index otherwise."""
        for bucket_id, limit in enumerate(self.bucket_lengths):
            if length <= limit:
                return bucket_id
        return len(self.bucket_lengths)

    def slice(self, dataset, batch_size=32):
        """Yield batches of entries grouped by sentence-length bucket.

        A bucket is flushed as soon as it reaches ``batch_size``; any
        non-empty remainders are yielded at the end (possibly smaller).
        """
        # One list per configured bucket, plus the overflow bucket.
        buckets = [[] for _ in range(len(self.bucket_lengths) + 1)]

        for entry in dataset:
            bucket_id = self.get_bucket_id(len(entry['words']))
            buckets[bucket_id].append(entry)

            if len(buckets[bucket_id]) >= batch_size:
                yield buckets[bucket_id][:]
                buckets[bucket_id] = []

        for leftover in buckets:
            if leftover:
                yield leftover
|