spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- com/johnsnowlabs/ml/__init__.py +0 -0
- com/johnsnowlabs/ml/ai/__init__.py +10 -0
- com/johnsnowlabs/nlp/__init__.py +4 -2
- spark_nlp-6.2.1.dist-info/METADATA +362 -0
- spark_nlp-6.2.1.dist-info/RECORD +292 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +281 -27
- sparknlp/annotation.py +137 -6
- sparknlp/annotation_audio.py +61 -0
- sparknlp/annotation_image.py +82 -0
- sparknlp/annotator/__init__.py +93 -0
- sparknlp/annotator/audio/__init__.py +16 -0
- sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
- sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
- sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
- sparknlp/annotator/chunk2_doc.py +85 -0
- sparknlp/annotator/chunker.py +137 -0
- sparknlp/annotator/classifier_dl/__init__.py +61 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
- sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
- sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
- sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
- sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
- sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
- sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
- sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
- sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
- sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
- sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
- sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
- sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
- sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
- sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/coref/__init__.py +1 -0
- sparknlp/annotator/coref/spanbert_coref.py +221 -0
- sparknlp/annotator/cv/__init__.py +29 -0
- sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
- sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
- sparknlp/annotator/cv/florence2_transformer.py +180 -0
- sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
- sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
- sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
- sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
- sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
- sparknlp/annotator/dataframe_optimizer.py +216 -0
- sparknlp/annotator/date2_chunk.py +88 -0
- sparknlp/annotator/dependency/__init__.py +17 -0
- sparknlp/annotator/dependency/dependency_parser.py +294 -0
- sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
- sparknlp/annotator/document_character_text_splitter.py +228 -0
- sparknlp/annotator/document_normalizer.py +235 -0
- sparknlp/annotator/document_token_splitter.py +175 -0
- sparknlp/annotator/document_token_splitter_test.py +85 -0
- sparknlp/annotator/embeddings/__init__.py +45 -0
- sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
- sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
- sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
- sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
- sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
- sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
- sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
- sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
- sparknlp/annotator/embeddings/doc2vec.py +352 -0
- sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
- sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
- sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
- sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
- sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
- sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
- sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
- sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
- sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
- sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
- sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
- sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
- sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
- sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
- sparknlp/annotator/embeddings/word2vec.py +353 -0
- sparknlp/annotator/embeddings/word_embeddings.py +385 -0
- sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
- sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
- sparknlp/annotator/er/__init__.py +16 -0
- sparknlp/annotator/er/entity_ruler.py +267 -0
- sparknlp/annotator/graph_extraction.py +368 -0
- sparknlp/annotator/keyword_extraction/__init__.py +16 -0
- sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
- sparknlp/annotator/ld_dl/__init__.py +16 -0
- sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
- sparknlp/annotator/lemmatizer.py +250 -0
- sparknlp/annotator/matcher/__init__.py +20 -0
- sparknlp/annotator/matcher/big_text_matcher.py +272 -0
- sparknlp/annotator/matcher/date_matcher.py +303 -0
- sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
- sparknlp/annotator/matcher/regex_matcher.py +221 -0
- sparknlp/annotator/matcher/text_matcher.py +290 -0
- sparknlp/annotator/n_gram_generator.py +141 -0
- sparknlp/annotator/ner/__init__.py +21 -0
- sparknlp/annotator/ner/ner_approach.py +94 -0
- sparknlp/annotator/ner/ner_converter.py +148 -0
- sparknlp/annotator/ner/ner_crf.py +397 -0
- sparknlp/annotator/ner/ner_dl.py +591 -0
- sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
- sparknlp/annotator/ner/ner_overwriter.py +166 -0
- sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
- sparknlp/annotator/normalizer.py +230 -0
- sparknlp/annotator/openai/__init__.py +16 -0
- sparknlp/annotator/openai/openai_completion.py +349 -0
- sparknlp/annotator/openai/openai_embeddings.py +106 -0
- sparknlp/annotator/param/__init__.py +17 -0
- sparknlp/annotator/param/classifier_encoder.py +98 -0
- sparknlp/annotator/param/evaluation_dl_params.py +130 -0
- sparknlp/annotator/pos/__init__.py +16 -0
- sparknlp/annotator/pos/perceptron.py +263 -0
- sparknlp/annotator/sentence/__init__.py +17 -0
- sparknlp/annotator/sentence/sentence_detector.py +290 -0
- sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
- sparknlp/annotator/sentiment/__init__.py +17 -0
- sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
- sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
- sparknlp/annotator/seq2seq/__init__.py +35 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
- sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
- sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
- sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
- sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
- sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
- sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
- sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
- sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
- sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
- sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
- sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
- sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
- sparknlp/annotator/similarity/__init__.py +0 -0
- sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
- sparknlp/annotator/spell_check/__init__.py +18 -0
- sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
- sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
- sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
- sparknlp/annotator/stemmer.py +79 -0
- sparknlp/annotator/stop_words_cleaner.py +190 -0
- sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
- sparknlp/annotator/token/__init__.py +19 -0
- sparknlp/annotator/token/chunk_tokenizer.py +118 -0
- sparknlp/annotator/token/recursive_tokenizer.py +205 -0
- sparknlp/annotator/token/regex_tokenizer.py +208 -0
- sparknlp/annotator/token/tokenizer.py +561 -0
- sparknlp/annotator/token2_chunk.py +76 -0
- sparknlp/annotator/ws/__init__.py +16 -0
- sparknlp/annotator/ws/word_segmenter.py +429 -0
- sparknlp/base/__init__.py +30 -0
- sparknlp/base/audio_assembler.py +95 -0
- sparknlp/base/doc2_chunk.py +169 -0
- sparknlp/base/document_assembler.py +164 -0
- sparknlp/base/embeddings_finisher.py +201 -0
- sparknlp/base/finisher.py +217 -0
- sparknlp/base/gguf_ranking_finisher.py +234 -0
- sparknlp/base/graph_finisher.py +125 -0
- sparknlp/base/has_recursive_fit.py +24 -0
- sparknlp/base/has_recursive_transform.py +22 -0
- sparknlp/base/image_assembler.py +172 -0
- sparknlp/base/light_pipeline.py +429 -0
- sparknlp/base/multi_document_assembler.py +164 -0
- sparknlp/base/prompt_assembler.py +207 -0
- sparknlp/base/recursive_pipeline.py +107 -0
- sparknlp/base/table_assembler.py +145 -0
- sparknlp/base/token_assembler.py +124 -0
- sparknlp/common/__init__.py +26 -0
- sparknlp/common/annotator_approach.py +41 -0
- sparknlp/common/annotator_model.py +47 -0
- sparknlp/common/annotator_properties.py +114 -0
- sparknlp/common/annotator_type.py +38 -0
- sparknlp/common/completion_post_processing.py +37 -0
- sparknlp/common/coverage_result.py +22 -0
- sparknlp/common/match_strategy.py +33 -0
- sparknlp/common/properties.py +1298 -0
- sparknlp/common/read_as.py +33 -0
- sparknlp/common/recursive_annotator_approach.py +35 -0
- sparknlp/common/storage.py +149 -0
- sparknlp/common/utils.py +39 -0
- sparknlp/functions.py +315 -5
- sparknlp/internal/__init__.py +1199 -0
- sparknlp/internal/annotator_java_ml.py +32 -0
- sparknlp/internal/annotator_transformer.py +37 -0
- sparknlp/internal/extended_java_wrapper.py +63 -0
- sparknlp/internal/params_getters_setters.py +71 -0
- sparknlp/internal/recursive.py +70 -0
- sparknlp/logging/__init__.py +15 -0
- sparknlp/logging/comet.py +467 -0
- sparknlp/partition/__init__.py +16 -0
- sparknlp/partition/partition.py +244 -0
- sparknlp/partition/partition_properties.py +902 -0
- sparknlp/partition/partition_transformer.py +200 -0
- sparknlp/pretrained/__init__.py +17 -0
- sparknlp/pretrained/pretrained_pipeline.py +158 -0
- sparknlp/pretrained/resource_downloader.py +216 -0
- sparknlp/pretrained/utils.py +35 -0
- sparknlp/reader/__init__.py +15 -0
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +190 -0
- sparknlp/reader/reader2doc.py +124 -0
- sparknlp/reader/reader2image.py +136 -0
- sparknlp/reader/reader2table.py +44 -0
- sparknlp/reader/reader_assembler.py +159 -0
- sparknlp/reader/sparknlp_reader.py +461 -0
- sparknlp/training/__init__.py +20 -0
- sparknlp/training/_tf_graph_builders/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
- sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
- sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
- sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/conll.py +150 -0
- sparknlp/training/conllu.py +103 -0
- sparknlp/training/pos.py +103 -0
- sparknlp/training/pub_tator.py +76 -0
- sparknlp/training/spacy_to_annotation.py +57 -0
- sparknlp/training/tfgraphs.py +5 -0
- sparknlp/upload_to_hub.py +149 -0
- sparknlp/util.py +51 -5
- com/__init__.pyc +0 -0
- com/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/__init__.pyc +0 -0
- com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/nlp/__init__.pyc +0 -0
- com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
- spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
- spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
- sparknlp/__init__.pyc +0 -0
- sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
- sparknlp/__pycache__/base.cpython-36.pyc +0 -0
- sparknlp/__pycache__/common.cpython-36.pyc +0 -0
- sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
- sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
- sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
- sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
- sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
- sparknlp/__pycache__/training.cpython-36.pyc +0 -0
- sparknlp/__pycache__/util.cpython-36.pyc +0 -0
- sparknlp/annotation.pyc +0 -0
- sparknlp/annotator.py +0 -3006
- sparknlp/annotator.pyc +0 -0
- sparknlp/base.py +0 -347
- sparknlp/base.pyc +0 -0
- sparknlp/common.py +0 -193
- sparknlp/common.pyc +0 -0
- sparknlp/embeddings.py +0 -40
- sparknlp/embeddings.pyc +0 -0
- sparknlp/internal.py +0 -288
- sparknlp/internal.pyc +0 -0
- sparknlp/pretrained.py +0 -123
- sparknlp/pretrained.pyc +0 -0
- sparknlp/storage.py +0 -32
- sparknlp/storage.pyc +0 -0
- sparknlp/training.py +0 -62
- sparknlp/training.pyc +0 -0
- sparknlp/util.pyc +0 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
# Copyright 2017-2022 John Snow Labs
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Contains classes for ClassifierDL."""
|
|
15
|
+
|
|
16
|
+
from sparknlp.annotator.param import EvaluationDLParams, ClassifierEncoder
|
|
17
|
+
from sparknlp.base import DocumentAssembler
|
|
18
|
+
from sparknlp.common import *
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ClassifierDLApproach(AnnotatorApproach, EvaluationDLParams, ClassifierEncoder):
    """Trains a ClassifierDL for generic Multi-class Text Classification.

    ClassifierDL uses the state-of-the-art Universal Sentence Encoder as an
    input for text classifications.
    The ClassifierDL annotator uses a deep learning model (DNNs) we have built
    inside TensorFlow and supports up to 100 classes.

    For instantiated/pretrained models, see :class:`.ClassifierDLModel`.

    Setting a test dataset to monitor model metrics can be done with
    ``.setTestDataset``. The method expects a path to a parquet file containing a
    dataframe that has the same required columns as the training dataframe. The
    pre-processing steps for the training dataframe should also be applied to the test
    dataframe. The following example will show how to create the test dataset:

    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("document")
    >>> embeddings = UniversalSentenceEncoder.pretrained() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("sentence_embeddings")
    >>> preProcessingPipeline = Pipeline().setStages([documentAssembler, embeddings])
    >>> (train, test) = data.randomSplit([0.8, 0.2])
    >>> preProcessingPipeline \\
    ...     .fit(test) \\
    ...     .transform(test) \\
    ...     .write \\
    ...     .mode("overwrite") \\
    ...     .parquet("test_data")
    >>> classifier = ClassifierDLApproach() \\
    ...     .setInputCols(["sentence_embeddings"]) \\
    ...     .setOutputCol("category") \\
    ...     .setLabelColumn("label") \\
    ...     .setTestDataset("test_data")

    For extended examples of usage, see the
    `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/ClassifierDL_Train_multi_class_news_category_classifier.ipynb>`__.

    ======================= ======================
    Input Annotation types  Output Annotation type
    ======================= ======================
    ``SENTENCE_EMBEDDINGS`` ``CATEGORY``
    ======================= ======================

    Parameters
    ----------
    batchSize
        Batch size, by default 64
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    dropout
        Dropout coefficient, by default 0.5
    enableOutputLogs
        Whether to use stdout in addition to Spark logs, by default False
    evaluationLogExtended
        Whether logs for validation to be extended: it displays time and evaluation of
        each label. Default is False.
    labelColumn
        Column containing the label of each document (one label per row)
    lr
        Learning Rate, by default 0.005
    maxEpochs
        Maximum number of epochs to train, by default 30
    outputLogsPath
        Folder path to save training logs
    randomSeed
        Random seed for shuffling
    testDataset
        Path to test dataset. If set, it is used to calculate statistics on it
        during training.
    validationSplit
        Choose the proportion of training dataset to be validated against the
        model on each Epoch. The value should be between 0.0 and 1.0 and by
        default it is 0.0 and off.
    verbose
        Level of verbosity during training

    Notes
    -----
    - This annotator accepts a label column of a single item in either type of
      String, Int, Float, or Double.
    - UniversalSentenceEncoder, Transformer based embeddings, or
      SentenceEmbeddings can be used for the ``inputCol``.

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline

    In this example, the training data ``"sentiment.csv"`` has the form of::

        text,label
        This movie is the best movie I have wached ever! In my opinion this movie can win an award.,0
        This was a terrible movie! The acting was bad really bad!,1
        ...

    Then training can be done like so:

    >>> smallCorpus = spark.read.option("header","True").csv("src/test/resources/classifier/sentiment.csv")
    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("document")
    >>> useEmbeddings = UniversalSentenceEncoder.pretrained() \\
    ...     .setInputCols("document") \\
    ...     .setOutputCol("sentence_embeddings")
    >>> docClassifier = ClassifierDLApproach() \\
    ...     .setInputCols("sentence_embeddings") \\
    ...     .setOutputCol("category") \\
    ...     .setLabelColumn("label") \\
    ...     .setBatchSize(64) \\
    ...     .setMaxEpochs(20) \\
    ...     .setLr(5e-3) \\
    ...     .setDropout(0.5)
    >>> pipeline = Pipeline().setStages([
    ...     documentAssembler,
    ...     useEmbeddings,
    ...     docClassifier
    ... ])
    >>> pipelineModel = pipeline.fit(smallCorpus)

    See Also
    --------
    MultiClassifierDLApproach : for multi-label classification
    SentimentDLApproach : for sentiment analysis
    """
    inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS]

    outputAnnotatorType = AnnotatorType.CATEGORY

    dropout = Param(Params._dummy(), "dropout", "Dropout coefficient", TypeConverters.toFloat)

    def setDropout(self, v):
        """Sets dropout coefficient, by default 0.5

        Parameters
        ----------
        v : float
            Dropout coefficient

        Returns
        -------
        ClassifierDLApproach
            This instance, to allow chained setter calls.
        """
        # ``Params._set`` returns ``self``; returning it directly mirrors the
        # other setters in this module (e.g. ``setConfigProtoBytes``).
        return self._set(dropout=v)

    def _create_model(self, java_model):
        # Wraps the JVM model produced by fitting into its Python counterpart.
        return ClassifierDLModel(java_model=java_model)

    @keyword_only
    def __init__(self):
        super(ClassifierDLApproach, self).__init__(
            classname="com.johnsnowlabs.nlp.annotators.classifier.dl.ClassifierDLApproach")
        # Defaults mirror the values documented in the class docstring.
        self._setDefault(
            maxEpochs=30,
            lr=float(0.005),
            batchSize=64,
            dropout=float(0.5),
            enableOutputLogs=False,
            evaluationLogExtended=False
        )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class ClassifierDLModel(AnnotatorModel, HasStorageRef, HasEngine):
|
|
183
|
+
"""ClassifierDL for generic Multi-class Text Classification.
|
|
184
|
+
|
|
185
|
+
ClassifierDL uses the state-of-the-art Universal Sentence Encoder as an
|
|
186
|
+
input for text classifications. The ClassifierDL annotator uses a deep
|
|
187
|
+
learning model (DNNs) we have built inside TensorFlow and supports up to
|
|
188
|
+
100 classes.
|
|
189
|
+
|
|
190
|
+
This is the instantiated model of the :class:`.ClassifierDLApproach`.
|
|
191
|
+
For training your own model, please see the documentation of that class.
|
|
192
|
+
|
|
193
|
+
Pretrained models can be loaded with :meth:`.pretrained` of the companion
|
|
194
|
+
object:
|
|
195
|
+
|
|
196
|
+
>>> classifierDL = ClassifierDLModel.pretrained() \\
|
|
197
|
+
... .setInputCols(["sentence_embeddings"]) \\
|
|
198
|
+
... .setOutputCol("classification")
|
|
199
|
+
|
|
200
|
+
The default model is ``"classifierdl_use_trec6"``, if no name is provided.
|
|
201
|
+
It uses embeddings from the UniversalSentenceEncoder and is trained on the
|
|
202
|
+
`TREC-6 <https://deepai.org/dataset/trec-6#:~:text=The%20TREC%20dataset%20is%20dataset,50%20has%20finer%2Dgrained%20labels>`__
|
|
203
|
+
dataset.
|
|
204
|
+
|
|
205
|
+
For available pretrained models please see the
|
|
206
|
+
`Models Hub <https://sparknlp.org/models?task=Text+Classification>`__.
|
|
207
|
+
|
|
208
|
+
For extended examples of usage, see the
|
|
209
|
+
`Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/ClassifierDL_Train_multi_class_news_category_classifier.ipynb>`__.
|
|
210
|
+
|
|
211
|
+
======================= ======================
|
|
212
|
+
Input Annotation types Output Annotation type
|
|
213
|
+
======================= ======================
|
|
214
|
+
``SENTENCE_EMBEDDINGS`` ``CATEGORY``
|
|
215
|
+
======================= ======================
|
|
216
|
+
|
|
217
|
+
Parameters
|
|
218
|
+
----------
|
|
219
|
+
configProtoBytes
|
|
220
|
+
ConfigProto from tensorflow, serialized into byte array.
|
|
221
|
+
classes
|
|
222
|
+
Get the tags used to trained this ClassifierDLModel
|
|
223
|
+
|
|
224
|
+
Examples
|
|
225
|
+
--------
|
|
226
|
+
>>> import sparknlp
|
|
227
|
+
>>> from sparknlp.base import *
|
|
228
|
+
>>> from sparknlp.annotator import *
|
|
229
|
+
>>> from pyspark.ml import Pipeline
|
|
230
|
+
>>> documentAssembler = DocumentAssembler() \\
|
|
231
|
+
... .setInputCol("text") \\
|
|
232
|
+
... .setOutputCol("document")
|
|
233
|
+
>>> sentence = SentenceDetector() \\
|
|
234
|
+
... .setInputCols("document") \\
|
|
235
|
+
... .setOutputCol("sentence")
|
|
236
|
+
>>> useEmbeddings = UniversalSentenceEncoder.pretrained() \\
|
|
237
|
+
... .setInputCols("document") \\
|
|
238
|
+
... .setOutputCol("sentence_embeddings")
|
|
239
|
+
>>> sarcasmDL = ClassifierDLModel.pretrained("classifierdl_use_sarcasm") \\
|
|
240
|
+
... .setInputCols("sentence_embeddings") \\
|
|
241
|
+
... .setOutputCol("sarcasm")
|
|
242
|
+
>>> pipeline = Pipeline() \\
|
|
243
|
+
... .setStages([
|
|
244
|
+
... documentAssembler,
|
|
245
|
+
... sentence,
|
|
246
|
+
... useEmbeddings,
|
|
247
|
+
... sarcasmDL
|
|
248
|
+
... ])
|
|
249
|
+
>>> data = spark.createDataFrame([
|
|
250
|
+
... ["I'm ready!"],
|
|
251
|
+
... ["If I could put into words how much I love waking up at 6 am on Mondays I would."]
|
|
252
|
+
... ]).toDF("text")
|
|
253
|
+
>>> result = pipeline.fit(data).transform(data)
|
|
254
|
+
>>> result.selectExpr("explode(arrays_zip(sentence, sarcasm)) as out") \\
|
|
255
|
+
... .selectExpr("out.sentence.result as sentence", "out.sarcasm.result as sarcasm") \\
|
|
256
|
+
... .show(truncate=False)
|
|
257
|
+
+-------------------------------------------------------------------------------+-------+
|
|
258
|
+
|sentence |sarcasm|
|
|
259
|
+
+-------------------------------------------------------------------------------+-------+
|
|
260
|
+
|I'm ready! |normal |
|
|
261
|
+
|If I could put into words how much I love waking up at 6 am on Mondays I would.|sarcasm|
|
|
262
|
+
+-------------------------------------------------------------------------------+-------+
|
|
263
|
+
|
|
264
|
+
See Also
|
|
265
|
+
--------
|
|
266
|
+
MultiClassifierDLModel : for multi-class classification
|
|
267
|
+
SentimentDLModel : for sentiment analysis
|
|
268
|
+
"""
|
|
269
|
+
|
|
270
|
+
name = "ClassifierDLModel"
|
|
271
|
+
|
|
272
|
+
inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS]
|
|
273
|
+
|
|
274
|
+
outputAnnotatorType = AnnotatorType.CATEGORY
|
|
275
|
+
|
|
276
|
+
def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.ClassifierDLModel", java_model=None):
|
|
277
|
+
super(ClassifierDLModel, self).__init__(
|
|
278
|
+
classname=classname,
|
|
279
|
+
java_model=java_model
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
# Serialized TensorFlow ConfigProto (list of ints).
configProtoBytes = Param(parent=Params._dummy(),
                         name="configProtoBytes",
                         doc="ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                         typeConverter=TypeConverters.toListInt)

# Labels the model was trained with (list of strings).
classes = Param(parent=Params._dummy(),
                name="classes",
                doc="get the tags used to trained this ClassifierDLModel",
                typeConverter=TypeConverters.toListString)

def setConfigProtoBytes(self, b):
    """Store a TensorFlow ConfigProto that has been serialized to bytes.

    Parameters
    ----------
    b : List[int]
        Serialized ConfigProto, e.g. obtained with
        ``config_proto.SerializeToString()``.

    Returns
    -------
    The result of ``self._set``, allowing calls to be chained.
    """
    updated = self._set(configProtoBytes=b)
    return updated
|
|
299
|
+
|
|
300
|
+
@staticmethod
def pretrained(name="classifierdl_use_trec6", lang="en", remote_loc=None):
    """Download a pretrained ClassifierDLModel and load it.

    Parameters
    ----------
    name : str, optional
        Name of the pretrained model, by default "classifierdl_use_trec6"
    lang : str, optional
        Language of the pretrained model, by default "en"
    remote_loc : str, optional
        Optional remote address of the resource, by default None. When not
        given, Spark NLP's own repositories are used.

    Returns
    -------
    ClassifierDLModel
        The downloaded and restored model
    """
    from sparknlp.pretrained import ResourceDownloader
    model = ResourceDownloader.downloadModel(ClassifierDLModel, name, lang, remote_loc)
    return model
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright 2017-2022 John Snow Labs
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from sparknlp.common import *
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DeBertaForQuestionAnswering(AnnotatorModel,
                                  HasCaseSensitiveProperties,
                                  HasBatchedAnnotate,
                                  HasEngine,
                                  HasMaxSentenceLengthLimit):
    """DeBertaForQuestionAnswering can load DeBERTa Models with a span classification head on top for extractive
    question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start
    logits and span end logits).

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> spanClassifier = DeBertaForQuestionAnswering.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer")

    The default model is ``"deberta_v3_xsmall_qa_squad2"``, if no name is
    provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Question+Answering>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT, DOCUMENT`` ``CHUNK``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Large values allow faster processing but require more
        memory, by default 8
    caseSensitive
        Whether the input text is treated as case sensitive (``False`` means
        case is ignored), by default False
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    maxSentenceLength
        Max sentence length to process, by default 128

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = MultiDocumentAssembler() \\
    ...     .setInputCols(["question", "context"]) \\
    ...     .setOutputCols(["document_question", "document_context"])
    >>> spanClassifier = DeBertaForQuestionAnswering.pretrained() \\
    ...     .setInputCols(["document_question", "document_context"]) \\
    ...     .setOutputCol("answer") \\
    ...     .setCaseSensitive(False)
    >>> pipeline = Pipeline().setStages([
    ...     documentAssembler,
    ...     spanClassifier
    ... ])
    >>> data = spark.createDataFrame([["What's my name?", "My name is Clara and I live in Berkeley."]]).toDF("question", "context")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("answer.result").show(truncate=False)
    +--------------------+
    |result              |
    +--------------------+
    |[Clara]             |
    +--------------------+
    """
    name = "DeBertaForQuestionAnswering"

    # Two DOCUMENT inputs: the question and its context.
    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]

    outputAnnotatorType = AnnotatorType.CHUNK

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    # NOTE(review): this class defines neither a setter nor a default for
    # coalesceSentences and does not document it; it looks carried over from
    # the sequence classifiers. Kept so the class attribute stays available.
    coalesceSentences = Param(Params._dummy(), "coalesceSentences",
                              "Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.",
                              TypeConverters.toBoolean)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.DeBertaForQuestionAnswering",
                 java_model=None):
        super(DeBertaForQuestionAnswering, self).__init__(
            classname=classname,
            java_model=java_model
        )
        # Defaults documented in the class docstring above.
        self._setDefault(
            batchSize=8,
            maxSentenceLength=128,
            caseSensitive=False
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        DeBertaForQuestionAnswering
            The restored model
        """
        from sparknlp.internal import _DeBertaQuestionAnsweringLoader
        jModel = _DeBertaQuestionAnsweringLoader(folder, spark_session._jsparkSession)._java_obj
        return DeBertaForQuestionAnswering(java_model=jModel)

    @staticmethod
    def pretrained(name="deberta_v3_xsmall_qa_squad2", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "deberta_v3_xsmall_qa_squad2"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        DeBertaForQuestionAnswering
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(DeBertaForQuestionAnswering, name, lang, remote_loc)
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
# Copyright 2017-2022 John Snow Labs
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Contains classes for DeBertaForSequenceClassification."""
|
|
15
|
+
from sparknlp.common import *
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DeBertaForSequenceClassification(AnnotatorModel,
                                       HasCaseSensitiveProperties,
                                       HasBatchedAnnotate,
                                       HasClassifierActivationProperties,
                                       HasEngine,
                                       HasMaxSentenceLengthLimit):
    """DeBertaForSequenceClassification can load DeBERTa v2 & v3 Models with sequence classification/regression head on
    top (a linear layer on top of the pooled output) e.g. for multi-class document classification tasks.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> sequenceClassifier = DeBertaForSequenceClassification.pretrained() \\
    ...     .setInputCols(["token", "document"]) \\
    ...     .setOutputCol("label")

    The default model is ``"deberta_v3_xsmall_sequence_classifier_imdb"``, if no name is
    provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Text+Classification>`__.

    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT, TOKEN``    ``CATEGORY``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Large values allow faster processing but require more
        memory, by default 8
    caseSensitive
        Whether the input text is treated as case sensitive (``False`` means
        case is ignored), by default True
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    maxSentenceLength
        Max sentence length to process, by default 128
    coalesceSentences
        Instead of 1 class per sentence (if inputCols is `sentence`) output
        1 class per document by averaging probabilities in all sentences, by
        default False.

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("document")
    >>> tokenizer = Tokenizer() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("token")
    >>> sequenceClassifier = DeBertaForSequenceClassification.pretrained() \\
    ...     .setInputCols(["token", "document"]) \\
    ...     .setOutputCol("label") \\
    ...     .setCaseSensitive(True)
    >>> pipeline = Pipeline().setStages([
    ...     documentAssembler,
    ...     tokenizer,
    ...     sequenceClassifier
    ... ])
    >>> data = spark.createDataFrame([["I loved this movie when I was a child."], ["It was pretty boring."]]).toDF("text")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("label.result").show(truncate=False)
    +------+
    |result|
    +------+
    |[pos] |
    |[neg] |
    +------+
    """
    name = "DeBertaForSequenceClassification"

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.TOKEN]

    outputAnnotatorType = AnnotatorType.CATEGORY

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    coalesceSentences = Param(Params._dummy(), "coalesceSentences",
                              "Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.",
                              TypeConverters.toBoolean)

    def getClasses(self):
        """
        Returns labels used to train this model, as reported by the
        underlying Java model's ``getClasses``.
        """
        return self._call_java("getClasses")

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    def setCoalesceSentences(self, value):
        """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging
        probabilities in all sentences. Due to max sequence length limit in almost all transformer models such as
        BERT (512 tokens), this parameter helps to feed all the sentences into the model and averaging all the
        probabilities for the entire document instead of probabilities per sentence. (Default: False)

        Parameters
        ----------
        value : bool
            If the output of all sentences will be averaged to one output
        """
        return self._set(coalesceSentences=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.DeBertaForSequenceClassification",
                 java_model=None):
        super(DeBertaForSequenceClassification, self).__init__(
            classname=classname,
            java_model=java_model
        )
        # Defaults documented in the class docstring above.
        self._setDefault(
            batchSize=8,
            maxSentenceLength=128,
            caseSensitive=True,
            coalesceSentences=False,
            activation="softmax"
        )

    @staticmethod
    def loadSavedModel(folder, spark_session):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession

        Returns
        -------
        DeBertaForSequenceClassification
            The restored model
        """
        from sparknlp.internal import _DeBertaSequenceClassifierLoader
        jModel = _DeBertaSequenceClassifierLoader(folder, spark_session._jsparkSession)._java_obj
        return DeBertaForSequenceClassification(java_model=jModel)

    @staticmethod
    def pretrained(name="deberta_v3_xsmall_sequence_classifier_imdb", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "deberta_v3_xsmall_sequence_classifier_imdb"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        DeBertaForSequenceClassification
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(DeBertaForSequenceClassification, name, lang, remote_loc)
|