spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- com/johnsnowlabs/ml/__init__.py +0 -0
- com/johnsnowlabs/ml/ai/__init__.py +10 -0
- com/johnsnowlabs/nlp/__init__.py +4 -2
- spark_nlp-6.2.1.dist-info/METADATA +362 -0
- spark_nlp-6.2.1.dist-info/RECORD +292 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +281 -27
- sparknlp/annotation.py +137 -6
- sparknlp/annotation_audio.py +61 -0
- sparknlp/annotation_image.py +82 -0
- sparknlp/annotator/__init__.py +93 -0
- sparknlp/annotator/audio/__init__.py +16 -0
- sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
- sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
- sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
- sparknlp/annotator/chunk2_doc.py +85 -0
- sparknlp/annotator/chunker.py +137 -0
- sparknlp/annotator/classifier_dl/__init__.py +61 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
- sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
- sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
- sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
- sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
- sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
- sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
- sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
- sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
- sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
- sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
- sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
- sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
- sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
- sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
- sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/coref/__init__.py +1 -0
- sparknlp/annotator/coref/spanbert_coref.py +221 -0
- sparknlp/annotator/cv/__init__.py +29 -0
- sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
- sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
- sparknlp/annotator/cv/florence2_transformer.py +180 -0
- sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
- sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
- sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
- sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
- sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
- sparknlp/annotator/dataframe_optimizer.py +216 -0
- sparknlp/annotator/date2_chunk.py +88 -0
- sparknlp/annotator/dependency/__init__.py +17 -0
- sparknlp/annotator/dependency/dependency_parser.py +294 -0
- sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
- sparknlp/annotator/document_character_text_splitter.py +228 -0
- sparknlp/annotator/document_normalizer.py +235 -0
- sparknlp/annotator/document_token_splitter.py +175 -0
- sparknlp/annotator/document_token_splitter_test.py +85 -0
- sparknlp/annotator/embeddings/__init__.py +45 -0
- sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
- sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
- sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
- sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
- sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
- sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
- sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
- sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
- sparknlp/annotator/embeddings/doc2vec.py +352 -0
- sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
- sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
- sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
- sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
- sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
- sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
- sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
- sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
- sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
- sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
- sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
- sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
- sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
- sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
- sparknlp/annotator/embeddings/word2vec.py +353 -0
- sparknlp/annotator/embeddings/word_embeddings.py +385 -0
- sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
- sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
- sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
- sparknlp/annotator/er/__init__.py +16 -0
- sparknlp/annotator/er/entity_ruler.py +267 -0
- sparknlp/annotator/graph_extraction.py +368 -0
- sparknlp/annotator/keyword_extraction/__init__.py +16 -0
- sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
- sparknlp/annotator/ld_dl/__init__.py +16 -0
- sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
- sparknlp/annotator/lemmatizer.py +250 -0
- sparknlp/annotator/matcher/__init__.py +20 -0
- sparknlp/annotator/matcher/big_text_matcher.py +272 -0
- sparknlp/annotator/matcher/date_matcher.py +303 -0
- sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
- sparknlp/annotator/matcher/regex_matcher.py +221 -0
- sparknlp/annotator/matcher/text_matcher.py +290 -0
- sparknlp/annotator/n_gram_generator.py +141 -0
- sparknlp/annotator/ner/__init__.py +21 -0
- sparknlp/annotator/ner/ner_approach.py +94 -0
- sparknlp/annotator/ner/ner_converter.py +148 -0
- sparknlp/annotator/ner/ner_crf.py +397 -0
- sparknlp/annotator/ner/ner_dl.py +591 -0
- sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
- sparknlp/annotator/ner/ner_overwriter.py +166 -0
- sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
- sparknlp/annotator/normalizer.py +230 -0
- sparknlp/annotator/openai/__init__.py +16 -0
- sparknlp/annotator/openai/openai_completion.py +349 -0
- sparknlp/annotator/openai/openai_embeddings.py +106 -0
- sparknlp/annotator/param/__init__.py +17 -0
- sparknlp/annotator/param/classifier_encoder.py +98 -0
- sparknlp/annotator/param/evaluation_dl_params.py +130 -0
- sparknlp/annotator/pos/__init__.py +16 -0
- sparknlp/annotator/pos/perceptron.py +263 -0
- sparknlp/annotator/sentence/__init__.py +17 -0
- sparknlp/annotator/sentence/sentence_detector.py +290 -0
- sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
- sparknlp/annotator/sentiment/__init__.py +17 -0
- sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
- sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
- sparknlp/annotator/seq2seq/__init__.py +35 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
- sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
- sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
- sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
- sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
- sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
- sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
- sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
- sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
- sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
- sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
- sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
- sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
- sparknlp/annotator/similarity/__init__.py +0 -0
- sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
- sparknlp/annotator/spell_check/__init__.py +18 -0
- sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
- sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
- sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
- sparknlp/annotator/stemmer.py +79 -0
- sparknlp/annotator/stop_words_cleaner.py +190 -0
- sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
- sparknlp/annotator/token/__init__.py +19 -0
- sparknlp/annotator/token/chunk_tokenizer.py +118 -0
- sparknlp/annotator/token/recursive_tokenizer.py +205 -0
- sparknlp/annotator/token/regex_tokenizer.py +208 -0
- sparknlp/annotator/token/tokenizer.py +561 -0
- sparknlp/annotator/token2_chunk.py +76 -0
- sparknlp/annotator/ws/__init__.py +16 -0
- sparknlp/annotator/ws/word_segmenter.py +429 -0
- sparknlp/base/__init__.py +30 -0
- sparknlp/base/audio_assembler.py +95 -0
- sparknlp/base/doc2_chunk.py +169 -0
- sparknlp/base/document_assembler.py +164 -0
- sparknlp/base/embeddings_finisher.py +201 -0
- sparknlp/base/finisher.py +217 -0
- sparknlp/base/gguf_ranking_finisher.py +234 -0
- sparknlp/base/graph_finisher.py +125 -0
- sparknlp/base/has_recursive_fit.py +24 -0
- sparknlp/base/has_recursive_transform.py +22 -0
- sparknlp/base/image_assembler.py +172 -0
- sparknlp/base/light_pipeline.py +429 -0
- sparknlp/base/multi_document_assembler.py +164 -0
- sparknlp/base/prompt_assembler.py +207 -0
- sparknlp/base/recursive_pipeline.py +107 -0
- sparknlp/base/table_assembler.py +145 -0
- sparknlp/base/token_assembler.py +124 -0
- sparknlp/common/__init__.py +26 -0
- sparknlp/common/annotator_approach.py +41 -0
- sparknlp/common/annotator_model.py +47 -0
- sparknlp/common/annotator_properties.py +114 -0
- sparknlp/common/annotator_type.py +38 -0
- sparknlp/common/completion_post_processing.py +37 -0
- sparknlp/common/coverage_result.py +22 -0
- sparknlp/common/match_strategy.py +33 -0
- sparknlp/common/properties.py +1298 -0
- sparknlp/common/read_as.py +33 -0
- sparknlp/common/recursive_annotator_approach.py +35 -0
- sparknlp/common/storage.py +149 -0
- sparknlp/common/utils.py +39 -0
- sparknlp/functions.py +315 -5
- sparknlp/internal/__init__.py +1199 -0
- sparknlp/internal/annotator_java_ml.py +32 -0
- sparknlp/internal/annotator_transformer.py +37 -0
- sparknlp/internal/extended_java_wrapper.py +63 -0
- sparknlp/internal/params_getters_setters.py +71 -0
- sparknlp/internal/recursive.py +70 -0
- sparknlp/logging/__init__.py +15 -0
- sparknlp/logging/comet.py +467 -0
- sparknlp/partition/__init__.py +16 -0
- sparknlp/partition/partition.py +244 -0
- sparknlp/partition/partition_properties.py +902 -0
- sparknlp/partition/partition_transformer.py +200 -0
- sparknlp/pretrained/__init__.py +17 -0
- sparknlp/pretrained/pretrained_pipeline.py +158 -0
- sparknlp/pretrained/resource_downloader.py +216 -0
- sparknlp/pretrained/utils.py +35 -0
- sparknlp/reader/__init__.py +15 -0
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +190 -0
- sparknlp/reader/reader2doc.py +124 -0
- sparknlp/reader/reader2image.py +136 -0
- sparknlp/reader/reader2table.py +44 -0
- sparknlp/reader/reader_assembler.py +159 -0
- sparknlp/reader/sparknlp_reader.py +461 -0
- sparknlp/training/__init__.py +20 -0
- sparknlp/training/_tf_graph_builders/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
- sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
- sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
- sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
- sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
- sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
- sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
- sparknlp/training/conll.py +150 -0
- sparknlp/training/conllu.py +103 -0
- sparknlp/training/pos.py +103 -0
- sparknlp/training/pub_tator.py +76 -0
- sparknlp/training/spacy_to_annotation.py +57 -0
- sparknlp/training/tfgraphs.py +5 -0
- sparknlp/upload_to_hub.py +149 -0
- sparknlp/util.py +51 -5
- com/__init__.pyc +0 -0
- com/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/__init__.pyc +0 -0
- com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
- com/johnsnowlabs/nlp/__init__.pyc +0 -0
- com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
- spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
- spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
- sparknlp/__init__.pyc +0 -0
- sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
- sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
- sparknlp/__pycache__/base.cpython-36.pyc +0 -0
- sparknlp/__pycache__/common.cpython-36.pyc +0 -0
- sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
- sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
- sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
- sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
- sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
- sparknlp/__pycache__/training.cpython-36.pyc +0 -0
- sparknlp/__pycache__/util.cpython-36.pyc +0 -0
- sparknlp/annotation.pyc +0 -0
- sparknlp/annotator.py +0 -3006
- sparknlp/annotator.pyc +0 -0
- sparknlp/base.py +0 -347
- sparknlp/base.pyc +0 -0
- sparknlp/common.py +0 -193
- sparknlp/common.pyc +0 -0
- sparknlp/embeddings.py +0 -40
- sparknlp/embeddings.pyc +0 -0
- sparknlp/internal.py +0 -288
- sparknlp/internal.pyc +0 -0
- sparknlp/pretrained.py +0 -123
- sparknlp/pretrained.pyc +0 -0
- sparknlp/storage.py +0 -32
- sparknlp/storage.pyc +0 -0
- sparknlp/training.py +0 -62
- sparknlp/training.pyc +0 -0
- sparknlp/util.pyc +0 -0
- {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
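The most visible change in this listing is structural: the flat 2.6.x modules removed near the bottom (sparknlp/annotator.py, sparknlp/base.py, sparknlp/common.py, ...) are split into the sparknlp/annotator/, sparknlp/base/, and sparknlp/common/ package trees added above, and the compiled .pyc artifacts are dropped from the wheel. A minimal sketch of what this means for user code, assuming (as the wildcard imports in the bundled sources below suggest) that the new package __init__ files re-export the public classes:

# Hedged sketch: module paths changed between 2.6.x and 6.x, but the public
# import surface is assumed to be preserved by the package __init__ re-exports.
import sparknlp
from sparknlp.base import DocumentAssembler  # 2.6.x: sparknlp/base.py; 6.x: sparknlp/base/document_assembler.py
from sparknlp.annotator import MultiClassifierDLApproach  # 6.x: sparknlp/annotator/classifier_dl/multi_classifier_dl.py

spark = sparknlp.start()  # starts a SparkSession with the Spark NLP jar attached
assembler = DocumentAssembler().setInputCol("text").setOutputCol("document")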
sparknlp/annotator/classifier_dl/multi_classifier_dl.py

@@ -0,0 +1,395 @@
+# Copyright 2017-2022 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains classes for MultiClassifierDL."""
+
+from sparknlp.annotator.param import EvaluationDLParams, ClassifierEncoder
+from sparknlp.annotator.classifier_dl import ClassifierDLModel
+from sparknlp.common import *
+
+
+class MultiClassifierDLApproach(AnnotatorApproach, EvaluationDLParams, ClassifierEncoder):
+    """Trains a MultiClassifierDL for Multi-label Text Classification.
+
+    MultiClassifierDL uses a Bidirectional GRU with a convolutional model that
+    we have built inside TensorFlow and supports up to 100 classes.
+
+    In machine learning, multi-label classification and the strongly related
+    problem of multi-output classification are variants of the classification
+    problem where multiple labels may be assigned to each instance. Multi-label
+    classification is a generalization of multiclass classification, which is
+    the single-label problem of categorizing instances into precisely one of
+    more than two classes; in the multi-label problem there is no constraint on
+    how many of the classes the instance can be assigned to. Formally,
+    multi-label classification is the problem of finding a model that maps
+    inputs x to binary vectors y (assigning a value of 0 or 1 for each element
+    (label) in y).
+
+    For instantiated/pretrained models, see :class:`.MultiClassifierDLModel`.
+
+    The input to ``MultiClassifierDL`` is Sentence Embeddings such as the
+    state-of-the-art :class:`.UniversalSentenceEncoder`,
+    :class:`.BertSentenceEmbeddings`, :class:`.SentenceEmbeddings` or other
+    sentence embeddings.
+
+    Setting a test dataset to monitor model metrics can be done with
+    ``.setTestDataset``. The method expects a path to a parquet file containing a
+    dataframe that has the same required columns as the training dataframe. The
+    pre-processing steps for the training dataframe should also be applied to the
+    test dataframe. The following example shows how to create the test dataset:
+
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("document")
+    >>> embeddings = UniversalSentenceEncoder.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("sentence_embeddings")
+    >>> preProcessingPipeline = Pipeline().setStages([documentAssembler, embeddings])
+    >>> (train, test) = data.randomSplit([0.8, 0.2])
+    >>> preProcessingPipeline \\
+    ...     .fit(test) \\
+    ...     .transform(test) \\
+    ...     .write \\
+    ...     .mode("overwrite") \\
+    ...     .parquet("test_data")
+    >>> multiClassifier = MultiClassifierDLApproach() \\
+    ...     .setInputCols(["sentence_embeddings"]) \\
+    ...     .setOutputCol("category") \\
+    ...     .setLabelColumn("label") \\
+    ...     .setTestDataset("test_data")
+
+    For extended examples of usage, see the `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/MultiClassifierDL_train_multi_label_E2E_challenge_classifier.ipynb>`__.
+
+    ======================= ======================
+    Input Annotation types  Output Annotation type
+    ======================= ======================
+    ``SENTENCE_EMBEDDINGS`` ``CATEGORY``
+    ======================= ======================
+
+    Parameters
+    ----------
+    batchSize
+        Batch size, by default 64
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+    enableOutputLogs
+        Whether to use stdout in addition to Spark logs, by default False
+    evaluationLogExtended
+        Whether validation logs are extended: displays time and evaluation of
+        each label, by default False
+    labelColumn
+        Column with the label(s) of each document
+    lr
+        Learning Rate, by default 0.001
+    maxEpochs
+        Maximum number of epochs to train, by default 10
+    outputLogsPath
+        Folder path to save training logs
+    randomSeed
+        Random seed, by default 44
+    shufflePerEpoch
+        Whether to shuffle the training data on each epoch, by default False
+    testDataset
+        Path to a test dataset. If set, it is used to calculate statistics
+        during training.
+    threshold
+        The minimum threshold for each label to be accepted, by default 0.5
+    validationSplit
+        Choose the proportion of training dataset to be validated against the
+        model on each epoch. The value should be between 0.0 and 1.0; by
+        default it is 0.0, i.e. off.
+    verbose
+        Level of verbosity during training
+
+    Notes
+    -----
+    - This annotator requires an array of labels of type String.
+    - UniversalSentenceEncoder, BertSentenceEmbeddings, SentenceEmbeddings or
+      other sentence embeddings can be used for the ``inputCol``.
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+
+    In this example, the training data has the form::
+
+        +----------------+--------------------+--------------------+
+        |              id|                text|              labels|
+        +----------------+--------------------+--------------------+
+        |ed58abb40640f983| PN NewsYou mean ...|             [toxic]|
+        |a1237f726b5f5d89| Dude. Place the ...|   [obscene, insult]|
+        |24b0d6c8733c2abe|Thanks - thanks ...|            [insult]|
+        |8c4478fb239bcfc0|" Gee, 5 minutes ...|[toxic, obscene, ...|
+        +----------------+--------------------+--------------------+
+
+    Process training data to create text with associated array of labels:
+
+    >>> trainDataset.printSchema()
+    root
+    |-- id: string (nullable = true)
+    |-- text: string (nullable = true)
+    |-- labels: array (nullable = true)
+    |    |-- element: string (containsNull = true)
+
+    Then create the pipeline for training:
+
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("document") \\
+    ...     .setCleanupMode("shrink")
+    >>> embeddings = UniversalSentenceEncoder.pretrained() \\
+    ...     .setInputCols("document") \\
+    ...     .setOutputCol("embeddings")
+    >>> docClassifier = MultiClassifierDLApproach() \\
+    ...     .setInputCols("embeddings") \\
+    ...     .setOutputCol("category") \\
+    ...     .setLabelColumn("labels") \\
+    ...     .setBatchSize(128) \\
+    ...     .setMaxEpochs(10) \\
+    ...     .setLr(1e-3) \\
+    ...     .setThreshold(0.5) \\
+    ...     .setValidationSplit(0.1)
+    >>> pipeline = Pipeline().setStages([
+    ...     documentAssembler,
+    ...     embeddings,
+    ...     docClassifier
+    ... ])
+    >>> pipelineModel = pipeline.fit(trainDataset)
+
+    See Also
+    --------
+    ClassifierDLApproach : for single-class classification
+    SentimentDLApproach : for sentiment analysis
+    """
+    inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS]
+
+    outputAnnotatorType = AnnotatorType.CATEGORY
+
+    shufflePerEpoch = Param(Params._dummy(), "shufflePerEpoch",
+                            "whether to shuffle the training data on each Epoch",
+                            TypeConverters.toBoolean)
+    threshold = Param(Params._dummy(), "threshold",
+                      "The minimum threshold for each label to be accepted. Default is 0.5",
+                      TypeConverters.toFloat)
+
+    def setVerbose(self, v):
+        """Sets level of verbosity during training.
+
+        Parameters
+        ----------
+        v : int
+            Level of verbosity
+        """
+        return self._set(verbose=v)
+
+    def setShufflePerEpoch(self, v):
+        """Sets whether to shuffle the training data on each epoch, by default False.
+
+        Parameters
+        ----------
+        v : bool
+            Whether to shuffle the training data on each epoch
+        """
+        return self._set(shufflePerEpoch=v)
+
+    def setThreshold(self, v):
+        """Sets minimum threshold for each label to be accepted, by default 0.5.
+
+        Parameters
+        ----------
+        v : float
+            The minimum threshold for each label to be accepted, by default 0.5
+        """
+        self._set(threshold=v)
+        return self
+
+    def _create_model(self, java_model):
+        # Wrap the fitted Java model as a MultiClassifierDLModel.
+        return MultiClassifierDLModel(java_model=java_model)
+
+    @keyword_only
+    def __init__(self):
+        super(MultiClassifierDLApproach, self).__init__(
+            classname="com.johnsnowlabs.nlp.annotators.classifier.dl.MultiClassifierDLApproach")
+        self._setDefault(
+            maxEpochs=10,
+            lr=float(0.001),
+            batchSize=64,
+            validationSplit=float(0.0),
+            threshold=float(0.5),
+            randomSeed=44,
+            shufflePerEpoch=False,
+            enableOutputLogs=False
+        )
+
+
+class MultiClassifierDLModel(AnnotatorModel, HasStorageRef, HasEngine):
+    """MultiClassifierDL for Multi-label Text Classification.
+
+    MultiClassifierDL is a Bidirectional GRU with a convolutional model that
+    we have built inside TensorFlow and supports up to 100 classes.
+
+    In machine learning, multi-label classification and the strongly related
+    problem of multi-output classification are variants of the classification
+    problem where multiple labels may be assigned to each instance. Multi-label
+    classification is a generalization of multiclass classification, which is
+    the single-label problem of categorizing instances into precisely one of
+    more than two classes; in the multi-label problem there is no constraint on
+    how many of the classes the instance can be assigned to. Formally,
+    multi-label classification is the problem of finding a model that maps
+    inputs x to binary vectors y (assigning a value of 0 or 1 for each element
+    (label) in y).
+
+    The input to ``MultiClassifierDL`` is Sentence Embeddings such as the
+    state-of-the-art :class:`.UniversalSentenceEncoder`,
+    :class:`.BertSentenceEmbeddings`, :class:`.SentenceEmbeddings` or other
+    sentence embeddings.
+
+    This is the instantiated model of the :class:`.MultiClassifierDLApproach`.
+    For training your own model, please see the documentation of that class.
+
+    Pretrained models can be loaded with :meth:`.pretrained` of the companion
+    object:
+
+    >>> multiClassifier = MultiClassifierDLModel.pretrained() \\
+    ...     .setInputCols(["sentence_embeddings"]) \\
+    ...     .setOutputCol("categories")
+
+    The default model is ``"multiclassifierdl_use_toxic"``, if no name is
+    provided. It uses embeddings from the UniversalSentenceEncoder and
+    classifies toxic comments.
+
+    The data is based on the
+    `Jigsaw Toxic Comment Classification Challenge <https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/overview>`__.
+    For available pretrained models please see the `Models Hub <https://sparknlp.org/models?task=Text+Classification>`__.
+
+    For extended examples of usage, see the `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/MultiClassifierDL_train_multi_label_E2E_challenge_classifier.ipynb>`__.
+
+    ======================= ======================
+    Input Annotation types  Output Annotation type
+    ======================= ======================
+    ``SENTENCE_EMBEDDINGS`` ``CATEGORY``
+    ======================= ======================
+
+    Parameters
+    ----------
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+    threshold
+        The minimum threshold for each label to be accepted, by default 0.5
+    classes
+        The tags used to train this MultiClassifierDLModel
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("document")
+    >>> useEmbeddings = UniversalSentenceEncoder.pretrained() \\
+    ...     .setInputCols("document") \\
+    ...     .setOutputCol("sentence_embeddings")
+    >>> multiClassifierDl = MultiClassifierDLModel.pretrained() \\
+    ...     .setInputCols("sentence_embeddings") \\
+    ...     .setOutputCol("classifications")
+    >>> pipeline = Pipeline() \\
+    ...     .setStages([
+    ...         documentAssembler,
+    ...         useEmbeddings,
+    ...         multiClassifierDl
+    ...     ])
+    >>> data = spark.createDataFrame([
+    ...     ["This is pretty good stuff!"],
+    ...     ["Wtf kind of crap is this"]
+    ... ]).toDF("text")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.select("text", "classifications.result").show(truncate=False)
+    +--------------------------+----------------+
+    |text                      |result          |
+    +--------------------------+----------------+
+    |This is pretty good stuff!|[]              |
+    |Wtf kind of crap is this  |[toxic, obscene]|
+    +--------------------------+----------------+
+
+    See Also
+    --------
+    ClassifierDLModel : for single-class classification
+    SentimentDLModel : for sentiment analysis
+    """
+    name = "MultiClassifierDLModel"
+
+    inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS]
+
+    outputAnnotatorType = AnnotatorType.CATEGORY
+
+    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.MultiClassifierDLModel",
+                 java_model=None):
+        super(MultiClassifierDLModel, self).__init__(
+            classname=classname,
+            java_model=java_model
+        )
+        self._setDefault(
+            threshold=float(0.5)
+        )
+
+    configProtoBytes = Param(Params._dummy(), "configProtoBytes",
+                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                             TypeConverters.toListInt)
+
+    threshold = Param(Params._dummy(), "threshold",
+                      "The minimum threshold for each label to be accepted. Default is 0.5",
+                      TypeConverters.toFloat)
+
+    classes = Param(Params._dummy(), "classes",
+                    "the tags used to train this MultiClassifierDLModel",
+                    TypeConverters.toListString)
+
+    def setThreshold(self, v):
+        """Sets minimum threshold for each label to be accepted, by default 0.5.
+
+        Parameters
+        ----------
+        v : float
+            The minimum threshold for each label to be accepted, by default 0.5
+        """
+        self._set(threshold=v)
+        return self
+
+    def setConfigProtoBytes(self, b):
+        """Sets configProto from tensorflow, serialized into byte array.
+
+        Parameters
+        ----------
+        b : List[int]
+            ConfigProto from tensorflow, serialized into byte array
+        """
+        return self._set(configProtoBytes=b)
+
+    @staticmethod
+    def pretrained(name="multiclassifierdl_use_toxic", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the pretrained model, by default
+            "multiclassifierdl_use_toxic"
+        lang : str, optional
+            Language of the pretrained model, by default "en"
+        remote_loc : str, optional
+            Optional remote address of the resource, by default None. Will use
+            Spark NLP's repositories otherwise.
+
+        Returns
+        -------
+        MultiClassifierDLModel
+            The restored model
+        """
+        from sparknlp.pretrained import ResourceDownloader
+        return ResourceDownloader.downloadModel(MultiClassifierDLModel, name, lang, remote_loc)
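One behavioral note on the threshold parameter documented above: each candidate label is scored independently, and only labels whose score clears the threshold are emitted, which is why a row in the example output can carry zero, one, or several labels. A small hedged sketch of tightening it (the model name is the documented default):

multiClassifierDl = MultiClassifierDLModel.pretrained("multiclassifierdl_use_toxic") \
    .setInputCols(["sentence_embeddings"]) \
    .setOutputCol("classifications") \
    .setThreshold(0.7)  # stricter than the 0.5 default, so fewer labels are emitted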
sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py

@@ -0,0 +1,161 @@
+# Copyright 2017-2025 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sparknlp.common import *
+
+
+class RoBertaForMultipleChoice(AnnotatorModel,
+                               HasCaseSensitiveProperties,
+                               HasBatchedAnnotate,
+                               HasEngine,
+                               HasMaxSentenceLengthLimit):
+    """RoBertaForMultipleChoice can load RoBERTa models with a multiple choice
+    classification head on top (a linear layer on top of the pooled output and
+    a softmax), e.g. for RocStories/SWAG tasks.
+
+    Pretrained models can be loaded with :meth:`.pretrained` of the companion
+    object:
+
+    >>> spanClassifier = RoBertaForMultipleChoice.pretrained() \\
+    ...     .setInputCols(["document_question", "document_context"]) \\
+    ...     .setOutputCol("answer")
+
+    The default model is ``"roberta_base_uncased_multiple_choice"``, if no name
+    is provided.
+
+    For available pretrained models please see the `Models Hub
+    <https://sparknlp.org/models?task=Multiple+Choice>`__.
+
+    To see which models are compatible and how to import them see
+    `Import Transformers into Spark NLP 🚀
+    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``DOCUMENT, DOCUMENT`` ``CHUNK``
+    ====================== ======================
+
+    Parameters
+    ----------
+    batchSize
+        Batch size. Large values allow faster processing but require more
+        memory, by default 4
+    caseSensitive
+        Whether to ignore case in tokens for embeddings matching, by default
+        False
+    maxSentenceLength
+        Max sentence length to process, by default 512
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = MultiDocumentAssembler() \\
+    ...     .setInputCols(["question", "context"]) \\
+    ...     .setOutputCols(["document_question", "document_context"])
+    >>> questionAnswering = RoBertaForMultipleChoice.pretrained() \\
+    ...     .setInputCols(["document_question", "document_context"]) \\
+    ...     .setOutputCol("answer") \\
+    ...     .setCaseSensitive(False)
+    >>> pipeline = Pipeline().setStages([
+    ...     documentAssembler,
+    ...     questionAnswering
+    ... ])
+    >>> data = spark.createDataFrame([["The Eiffel Tower is located in which country?", "Germany, France, Italy"]]).toDF("question", "context")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.select("answer.result").show(truncate=False)
+    +--------------------+
+    |result              |
+    +--------------------+
+    |[France]            |
+    +--------------------+
+    """
+    name = "RoBertaForMultipleChoice"
+
+    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]
+
+    outputAnnotatorType = AnnotatorType.CHUNK
+
+    choicesDelimiter = Param(Params._dummy(),
+                             "choicesDelimiter",
+                             "Delimiter character used to split the choices",
+                             TypeConverters.toString)
+
+    def setChoicesDelimiter(self, value):
+        """Sets the delimiter character used to split the choices.
+
+        Parameters
+        ----------
+        value : str
+            Delimiter character used to split the choices
+        """
+        return self._set(choicesDelimiter=value)
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.RoBertaForMultipleChoice",
+                 java_model=None):
+        super(RoBertaForMultipleChoice, self).__init__(
+            classname=classname,
+            java_model=java_model
+        )
+        self._setDefault(
+            batchSize=4,
+            maxSentenceLength=512,
+            caseSensitive=False,
+            choicesDelimiter=","
+        )
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+
+        Returns
+        -------
+        RoBertaForMultipleChoice
+            The restored model
+        """
+        from sparknlp.internal import _RoBertaMultipleChoiceLoader
+        jModel = _RoBertaMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj
+        return RoBertaForMultipleChoice(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="roberta_base_uncased_multiple_choice", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the pretrained model, by default
+            "roberta_base_uncased_multiple_choice"
+        lang : str, optional
+            Language of the pretrained model, by default "en"
+        remote_loc : str, optional
+            Optional remote address of the resource, by default None. Will use
+            Spark NLP's repositories otherwise.
+
+        Returns
+        -------
+        RoBertaForMultipleChoice
+            The restored model
+        """
+        from sparknlp.pretrained import ResourceDownloader
+        return ResourceDownloader.downloadModel(RoBertaForMultipleChoice, name, lang, remote_loc)
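Unlike pretrained, the loadSavedModel method above carries no usage example of its own; a minimal sketch, assuming a RoBERTa multiple-choice model has already been exported to a local folder (the folder and save paths here are placeholders):

import sparknlp
from sparknlp.annotator import RoBertaForMultipleChoice

spark = sparknlp.start()

# "exported_roberta_mc" is a hypothetical folder produced by a model export step.
model = RoBertaForMultipleChoice.loadSavedModel("exported_roberta_mc", spark) \
    .setInputCols(["document_question", "document_context"]) \
    .setOutputCol("answer") \
    .setChoicesDelimiter(",")
model.write().overwrite().save("roberta_mc_spark_nlp")  # persist for later Pipeline use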