spark-nlp 4.2.6__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  4. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  5. {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  6. sparknlp/__init__.py +81 -28
  7. sparknlp/annotation.py +3 -2
  8. sparknlp/annotator/__init__.py +6 -0
  9. sparknlp/annotator/audio/__init__.py +2 -0
  10. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  11. sparknlp/annotator/audio/wav2vec2_for_ctc.py +14 -14
  12. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  13. sparknlp/{base → annotator}/chunk2_doc.py +4 -7
  14. sparknlp/annotator/chunker.py +1 -2
  15. sparknlp/annotator/classifier_dl/__init__.py +17 -0
  16. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  17. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +3 -15
  18. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +4 -18
  19. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +3 -17
  20. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  21. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  22. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  23. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +6 -20
  24. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +3 -17
  25. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +3 -17
  26. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  27. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  28. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +5 -19
  29. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +5 -19
  30. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  31. sparknlp/annotator/classifier_dl/classifier_dl.py +4 -4
  32. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +3 -17
  33. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +4 -19
  34. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +5 -21
  35. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  36. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +3 -17
  37. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +4 -18
  38. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +3 -17
  39. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  40. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  41. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +3 -17
  42. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +4 -18
  43. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +3 -17
  44. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  45. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  46. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  47. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +3 -3
  48. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  49. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +3 -17
  50. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +4 -18
  51. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +1 -1
  52. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  53. sparknlp/annotator/classifier_dl/sentiment_dl.py +4 -4
  54. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +2 -2
  55. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  56. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +3 -17
  57. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +4 -18
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +6 -20
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  60. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +4 -18
  61. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +3 -17
  62. sparknlp/annotator/cleaners/__init__.py +15 -0
  63. sparknlp/annotator/cleaners/cleaner.py +202 -0
  64. sparknlp/annotator/cleaners/extractor.py +191 -0
  65. sparknlp/annotator/coref/spanbert_coref.py +4 -18
  66. sparknlp/annotator/cv/__init__.py +15 -0
  67. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  68. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  69. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  70. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  71. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  72. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  73. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  74. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  75. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  76. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  77. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  78. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  79. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  80. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  81. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  82. sparknlp/annotator/cv/vit_for_image_classification.py +36 -4
  83. sparknlp/annotator/dataframe_optimizer.py +216 -0
  84. sparknlp/annotator/date2_chunk.py +88 -0
  85. sparknlp/annotator/dependency/dependency_parser.py +2 -3
  86. sparknlp/annotator/dependency/typed_dependency_parser.py +3 -4
  87. sparknlp/annotator/document_character_text_splitter.py +228 -0
  88. sparknlp/annotator/document_normalizer.py +37 -1
  89. sparknlp/annotator/document_token_splitter.py +175 -0
  90. sparknlp/annotator/document_token_splitter_test.py +85 -0
  91. sparknlp/annotator/embeddings/__init__.py +11 -0
  92. sparknlp/annotator/embeddings/albert_embeddings.py +4 -18
  93. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  94. sparknlp/annotator/embeddings/bert_embeddings.py +9 -22
  95. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +12 -24
  96. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  97. sparknlp/annotator/embeddings/camembert_embeddings.py +4 -20
  98. sparknlp/annotator/embeddings/chunk_embeddings.py +1 -2
  99. sparknlp/annotator/embeddings/deberta_embeddings.py +2 -16
  100. sparknlp/annotator/embeddings/distil_bert_embeddings.py +5 -19
  101. sparknlp/annotator/embeddings/doc2vec.py +7 -1
  102. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  103. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  104. sparknlp/annotator/embeddings/elmo_embeddings.py +2 -2
  105. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  106. sparknlp/annotator/embeddings/longformer_embeddings.py +3 -17
  107. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  108. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  109. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  110. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  111. sparknlp/annotator/embeddings/roberta_embeddings.py +9 -21
  112. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +7 -21
  113. sparknlp/annotator/embeddings/sentence_embeddings.py +2 -3
  114. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  115. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  116. sparknlp/annotator/embeddings/universal_sentence_encoder.py +3 -3
  117. sparknlp/annotator/embeddings/word2vec.py +7 -1
  118. sparknlp/annotator/embeddings/word_embeddings.py +4 -5
  119. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +9 -21
  120. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +7 -21
  121. sparknlp/annotator/embeddings/xlnet_embeddings.py +4 -18
  122. sparknlp/annotator/er/entity_ruler.py +37 -23
  123. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +2 -3
  124. sparknlp/annotator/ld_dl/language_detector_dl.py +2 -2
  125. sparknlp/annotator/lemmatizer.py +3 -4
  126. sparknlp/annotator/matcher/date_matcher.py +35 -3
  127. sparknlp/annotator/matcher/multi_date_matcher.py +1 -2
  128. sparknlp/annotator/matcher/regex_matcher.py +3 -3
  129. sparknlp/annotator/matcher/text_matcher.py +2 -3
  130. sparknlp/annotator/n_gram_generator.py +1 -2
  131. sparknlp/annotator/ner/__init__.py +3 -1
  132. sparknlp/annotator/ner/ner_converter.py +18 -0
  133. sparknlp/annotator/ner/ner_crf.py +4 -5
  134. sparknlp/annotator/ner/ner_dl.py +10 -5
  135. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  136. sparknlp/annotator/ner/ner_overwriter.py +2 -2
  137. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  138. sparknlp/annotator/normalizer.py +2 -2
  139. sparknlp/annotator/openai/__init__.py +16 -0
  140. sparknlp/annotator/openai/openai_completion.py +349 -0
  141. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  142. sparknlp/annotator/pos/perceptron.py +6 -7
  143. sparknlp/annotator/sentence/sentence_detector.py +2 -2
  144. sparknlp/annotator/sentence/sentence_detector_dl.py +3 -3
  145. sparknlp/annotator/sentiment/sentiment_detector.py +4 -5
  146. sparknlp/annotator/sentiment/vivekn_sentiment.py +4 -5
  147. sparknlp/annotator/seq2seq/__init__.py +17 -0
  148. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  149. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  150. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  151. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  152. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  153. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  154. sparknlp/annotator/seq2seq/gpt2_transformer.py +1 -1
  155. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  156. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  157. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  158. sparknlp/annotator/seq2seq/marian_transformer.py +124 -3
  159. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  160. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  161. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  162. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  163. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  164. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  165. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  166. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  167. sparknlp/annotator/seq2seq/t5_transformer.py +54 -4
  168. sparknlp/annotator/similarity/__init__.py +0 -0
  169. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  170. sparknlp/annotator/spell_check/context_spell_checker.py +116 -17
  171. sparknlp/annotator/spell_check/norvig_sweeting.py +3 -6
  172. sparknlp/annotator/spell_check/symmetric_delete.py +1 -1
  173. sparknlp/annotator/stemmer.py +2 -3
  174. sparknlp/annotator/stop_words_cleaner.py +3 -4
  175. sparknlp/annotator/tf_ner_dl_graph_builder.py +1 -1
  176. sparknlp/annotator/token/__init__.py +0 -1
  177. sparknlp/annotator/token/recursive_tokenizer.py +2 -3
  178. sparknlp/annotator/token/tokenizer.py +2 -3
  179. sparknlp/annotator/ws/word_segmenter.py +35 -10
  180. sparknlp/base/__init__.py +2 -3
  181. sparknlp/base/doc2_chunk.py +0 -3
  182. sparknlp/base/document_assembler.py +5 -5
  183. sparknlp/base/embeddings_finisher.py +14 -2
  184. sparknlp/base/finisher.py +15 -4
  185. sparknlp/base/gguf_ranking_finisher.py +234 -0
  186. sparknlp/base/image_assembler.py +69 -0
  187. sparknlp/base/light_pipeline.py +53 -21
  188. sparknlp/base/multi_document_assembler.py +9 -13
  189. sparknlp/base/prompt_assembler.py +207 -0
  190. sparknlp/base/token_assembler.py +1 -2
  191. sparknlp/common/__init__.py +2 -0
  192. sparknlp/common/annotator_type.py +1 -0
  193. sparknlp/common/completion_post_processing.py +37 -0
  194. sparknlp/common/match_strategy.py +33 -0
  195. sparknlp/common/properties.py +914 -9
  196. sparknlp/internal/__init__.py +841 -116
  197. sparknlp/internal/annotator_java_ml.py +1 -1
  198. sparknlp/internal/annotator_transformer.py +3 -0
  199. sparknlp/logging/comet.py +2 -2
  200. sparknlp/partition/__init__.py +16 -0
  201. sparknlp/partition/partition.py +244 -0
  202. sparknlp/partition/partition_properties.py +902 -0
  203. sparknlp/partition/partition_transformer.py +200 -0
  204. sparknlp/pretrained/pretrained_pipeline.py +1 -1
  205. sparknlp/pretrained/resource_downloader.py +126 -2
  206. sparknlp/reader/__init__.py +15 -0
  207. sparknlp/reader/enums.py +19 -0
  208. sparknlp/reader/pdf_to_text.py +190 -0
  209. sparknlp/reader/reader2doc.py +124 -0
  210. sparknlp/reader/reader2image.py +136 -0
  211. sparknlp/reader/reader2table.py +44 -0
  212. sparknlp/reader/reader_assembler.py +159 -0
  213. sparknlp/reader/sparknlp_reader.py +461 -0
  214. sparknlp/training/__init__.py +1 -0
  215. sparknlp/training/conll.py +8 -2
  216. sparknlp/training/spacy_to_annotation.py +57 -0
  217. sparknlp/util.py +26 -0
  218. spark_nlp-4.2.6.dist-info/METADATA +0 -1256
  219. spark_nlp-4.2.6.dist-info/RECORD +0 -196
  220. {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
  221. /sparknlp/annotator/{token/token2_chunk.py → token2_chunk.py} +0 -0
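The hunks below cover only a handful of the 221 files, but they are representative of the release. Before reading them, it can help to confirm what is actually installed after the upgrade. A minimal sketch (assuming a local PySpark environment; `sparknlp.start()` and `sparknlp.version()` are the library's standard entry points):

import sparknlp

# start() creates or reuses a SparkSession with the matching spark-nlp JAR
spark = sparknlp.start()
print(sparknlp.version())  # expected to print 6.2.1 after the upgrade
print(spark.version)       # the Spark version the session runs on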
sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py
@@ -20,7 +20,8 @@ class DistilBertForSequenceClassification(AnnotatorModel,
  HasCaseSensitiveProperties,
  HasBatchedAnnotate,
  HasClassifierActivationProperties,
- HasEngine):
+ HasEngine,
+ HasMaxSentenceLengthLimit):
  """DistilBertForSequenceClassification can load DistilBERT Models with sequence classification/regression head on
  top (a linear layer on top of the pooled output) e.g. for multi-class document classification tasks.

@@ -35,7 +36,7 @@ class DistilBertForSequenceClassification(AnnotatorModel,
  provided.

  For available pretrained models please see the `Models Hub
- <https://nlp.johnsnowlabs.com/models?task=Text+Classification>`__.
+ <https://sparknlp.org/models?task=Text+Classification>`__.

  To see which models are compatible and how to import them see
  `Import Transformers into Spark NLP 🚀
@@ -61,7 +62,7 @@ class DistilBertForSequenceClassification(AnnotatorModel,
  Max sentence length to process, by default 128
  coalesceSentences
  Instead of 1 class per sentence (if inputCols is `sentence`) output
- 1 class per document by averaging probabilities in all sentences, by
+ 1 class per document by averaging probabilities in all sentences, by
  default False.
  activation
  Whether to calculate logits via Softmax or Sigmoid, by default
@@ -104,11 +105,6 @@ class DistilBertForSequenceClassification(AnnotatorModel,

  outputAnnotatorType = AnnotatorType.CATEGORY

- maxSentenceLength = Param(Params._dummy(),
- "maxSentenceLength",
- "Max sentence length to process",
- typeConverter=TypeConverters.toInt)
-
  configProtoBytes = Param(Params._dummy(),
  "configProtoBytes",
  "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
@@ -134,16 +130,6 @@ class DistilBertForSequenceClassification(AnnotatorModel,
  """
  return self._set(configProtoBytes=b)

- def setMaxSentenceLength(self, value):
- """Sets max sentence length to process, by default 128.
-
- Parameters
- ----------
- value : int
- Max sentence length to process
- """
- return self._set(maxSentenceLength=value)
-
  def setCoalesceSentences(self, value):
  """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.
  Due to max sequence length limit in almost all transformer models such as BERT (512 tokens), this parameter helps feeding all the sentences
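This hunk is the template for most of the classifier changes in this diff: each annotator's copy-pasted `maxSentenceLength` Param and `setMaxSentenceLength` setter are deleted and replaced by the shared `HasMaxSentenceLengthLimit` mixin (with a `HasLongMaxSentenceLengthLimit` variant for long-context models such as Longformer, further down). A rough sketch of what such a mixin plausibly looks like, modeled on the removed per-class code; the real implementation lives in `sparknlp/common/properties.py` and may differ in detail:

from pyspark.ml.param import Param, Params, TypeConverters


class HasMaxSentenceLengthLimit:
    # Per-architecture cap; subclasses may override (value is an assumption).
    max_length_limit = 512

    maxSentenceLength = Param(Params._dummy(),
                              "maxSentenceLength",
                              "Max sentence length to process",
                              typeConverter=TypeConverters.toInt)

    def setMaxSentenceLength(self, value):
        # Mixed into pyspark Params subclasses, so self._set is available.
        if value > self.max_length_limit:
            raise ValueError(
                f"{type(self).__name__} supports sequences of at most "
                f"{self.max_length_limit} tokens")
        return self._set(maxSentenceLength=value)


class HasLongMaxSentenceLengthLimit(HasMaxSentenceLengthLimit):
    # Long-context variant, matching Longformer's 4096-token window.
    max_length_limit = 4096

The practical upside is that the length cap is declared and validated in one place instead of being re-declared in every annotator.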
sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py
@@ -19,7 +19,8 @@ from sparknlp.common import *
  class DistilBertForTokenClassification(AnnotatorModel,
  HasCaseSensitiveProperties,
  HasBatchedAnnotate,
- HasEngine):
+ HasEngine,
+ HasMaxSentenceLengthLimit):
  """DistilBertForTokenClassification can load Bert Models with a token
  classification head on top (a linear layer on top of the hidden-states
  output) e.g. for Named-Entity-Recognition (NER) tasks.
@@ -35,7 +36,7 @@ class DistilBertForTokenClassification(AnnotatorModel,
  name is provided.

  For available pretrained models please see the `Models Hub
- <https://nlp.johnsnowlabs.com/models?task=Named+Entity+Recognition>`__.
+ <https://sparknlp.org/models?task=Named+Entity+Recognition>`__.

  To see which models are compatible and how to import them see
  `Import Transformers into Spark NLP 🚀
@@ -96,11 +97,6 @@ class DistilBertForTokenClassification(AnnotatorModel,

  outputAnnotatorType = AnnotatorType.NAMED_ENTITY

- maxSentenceLength = Param(Params._dummy(),
- "maxSentenceLength",
- "Max sentence length to process",
- typeConverter=TypeConverters.toInt)
-
  configProtoBytes = Param(Params._dummy(),
  "configProtoBytes",
  "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
@@ -122,16 +118,6 @@ class DistilBertForTokenClassification(AnnotatorModel,
  """
  return self._set(configProtoBytes=b)

- def setMaxSentenceLength(self, value):
- """Sets max sentence length to process, by default 128.
-
- Parameters
- ----------
- value : int
- Max sentence length to process
- """
- return self._set(maxSentenceLength=value)
-
  @keyword_only
  def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForTokenClassification",
  java_model=None):
sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py (new file)
@@ -0,0 +1,211 @@
+ # Copyright 2017-2023 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Contains classes for DistilBertForZeroShotClassification."""
+
+ from sparknlp.common import *
+
+
+ class DistilBertForZeroShotClassification(AnnotatorModel,
+ HasCaseSensitiveProperties,
+ HasBatchedAnnotate,
+ HasClassifierActivationProperties,
+ HasCandidateLabelsProperties,
+ HasEngine,
+ HasMaxSentenceLengthLimit):
+ """DistilBertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural language
+ inference) tasks. Equivalent of `DistilBertForSequenceClassification` models, but these models don't require a hardcoded
+ number of potential classes; they can be chosen at runtime. It usually means it's slower but it is much more
+ flexible.
+
+ Note that the model will loop through all provided labels. So the more labels you have, the
+ longer this process will take.
+
+ Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
+ pair and passed to the pretrained model.
+
+ Pretrained models can be loaded with :meth:`.pretrained` of the companion
+ object:
+
+ >>> sequenceClassifier = DistilBertForZeroShotClassification.pretrained() \\
+ ... .setInputCols(["token", "document"]) \\
+ ... .setOutputCol("label")
+
+ The default model is ``"distilbert_base_zero_shot_classifier_uncased_mnli"``, if no name is
+ provided.
+
+ For available pretrained models please see the `Models Hub
+ <https://sparknlp.org/models?task=Text+Classification>`__.
+
+ To see which models are compatible and how to import them see
+ `Import Transformers into Spark NLP 🚀
+ <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.
+
+ ====================== ======================
+ Input Annotation types Output Annotation type
+ ====================== ======================
+ ``DOCUMENT, TOKEN`` ``CATEGORY``
+ ====================== ======================
+
+ Parameters
+ ----------
+ batchSize
+ Batch size. Large values allow faster processing but require more
+ memory, by default 8
+ caseSensitive
+ Whether to ignore case in tokens for embeddings matching, by default
+ True
+ configProtoBytes
+ ConfigProto from tensorflow, serialized into byte array.
+ maxSentenceLength
+ Max sentence length to process, by default 128
+ coalesceSentences
+ Instead of 1 class per sentence (if inputCols is `sentence`) output 1
+ class per document by averaging probabilities in all sentences, by
+ default False
+ activation
+ Whether to calculate logits via Softmax or Sigmoid, by default
+ `"softmax"`.
+
+ Examples
+ --------
+ >>> import sparknlp
+ >>> from sparknlp.base import *
+ >>> from sparknlp.annotator import *
+ >>> from pyspark.ml import Pipeline
+ >>> documentAssembler = DocumentAssembler() \\
+ ... .setInputCol("text") \\
+ ... .setOutputCol("document")
+ >>> tokenizer = Tokenizer() \\
+ ... .setInputCols(["document"]) \\
+ ... .setOutputCol("token")
+ >>> sequenceClassifier = DistilBertForZeroShotClassification.pretrained() \\
+ ... .setInputCols(["token", "document"]) \\
+ ... .setOutputCol("label") \\
+ ... .setCaseSensitive(True)
+ >>> pipeline = Pipeline().setStages([
+ ... documentAssembler,
+ ... tokenizer,
+ ... sequenceClassifier
+ ... ])
+ >>> data = spark.createDataFrame([["I loved this movie when I was a child.", "It was pretty boring."]]).toDF("text")
+ >>> result = pipeline.fit(data).transform(data)
+ >>> result.select("label.result").show(truncate=False)
+ +------+
+ |result|
+ +------+
+ |[pos] |
+ |[neg] |
+ +------+
+ """
+ name = "DistilBertForZeroShotClassification"
+
+ inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.TOKEN]
+
+ outputAnnotatorType = AnnotatorType.CATEGORY
+
+ configProtoBytes = Param(Params._dummy(),
+ "configProtoBytes",
+ "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+ TypeConverters.toListInt)
+
+ coalesceSentences = Param(Params._dummy(), "coalesceSentences",
+ "Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.",
+ TypeConverters.toBoolean)
+
+ def getClasses(self):
+ """
+ Returns labels used to train this model
+ """
+ return self._call_java("getClasses")
+
+ def setConfigProtoBytes(self, b):
+ """Sets configProto from tensorflow, serialized into byte array.
+
+ Parameters
+ ----------
+ b : List[int]
+ ConfigProto from tensorflow, serialized into byte array
+ """
+ return self._set(configProtoBytes=b)
+
+ def setCoalesceSentences(self, value):
+ """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging
+ probabilities in all sentences. Due to max sequence length limit in almost all transformer models such as DistilBERT
+ (512 tokens), this parameter helps to feed all the sentences into the model and average all the probabilities
+ for the entire document instead of probabilities per sentence. (Default: False)
+
+ Parameters
+ ----------
+ value : bool
+ If the output of all sentences will be averaged to one output
+ """
+ return self._set(coalesceSentences=value)
+
+ @keyword_only
+ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForZeroShotClassification",
+ java_model=None):
+ super(DistilBertForZeroShotClassification, self).__init__(
+ classname=classname,
+ java_model=java_model
+ )
+ self._setDefault(
+ batchSize=8,
+ maxSentenceLength=128,
+ caseSensitive=True,
+ coalesceSentences=False,
+ activation="softmax"
+ )
+
+ @staticmethod
+ def loadSavedModel(folder, spark_session):
+ """Loads a locally saved model.
+
+ Parameters
+ ----------
+ folder : str
+ Folder of the saved model
+ spark_session : pyspark.sql.SparkSession
+ The current SparkSession
+
+ Returns
+ -------
+ DistilBertForZeroShotClassification
+ The restored model
+ """
+ from sparknlp.internal import _DistilBertForZeroShotClassification
+ jModel = _DistilBertForZeroShotClassification(folder, spark_session._jsparkSession)._java_obj
+ return DistilBertForZeroShotClassification(java_model=jModel)
+
+ @staticmethod
+ def pretrained(name="distilbert_base_zero_shot_classifier_uncased_mnli", lang="en", remote_loc=None):
+ """Downloads and loads a pretrained model.
+
+ Parameters
+ ----------
+ name : str, optional
+ Name of the pretrained model, by default
+ "distilbert_base_zero_shot_classifier_uncased_mnli"
+ lang : str, optional
+ Language of the pretrained model, by default "en"
+ remote_loc : str, optional
+ Optional remote address of the resource, by default None. Will use
+ Spark NLP's repositories otherwise.
+
+ Returns
+ -------
+ DistilBertForZeroShotClassification
+ The restored model
+ """
+ from sparknlp.pretrained import ResourceDownloader
+ return ResourceDownloader.downloadModel(DistilBertForZeroShotClassification, name, lang, remote_loc)
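One thing the docstring example above does not show: the class mixes in `HasCandidateLabelsProperties`, so the zero-shot labels themselves are supplied at runtime. A sketch of the more typical call (the label strings are illustrative, not from the source):

# Labels chosen at runtime via the HasCandidateLabelsProperties mixin
zeroShotClassifier = DistilBertForZeroShotClassification.pretrained() \
    .setInputCols(["token", "document"]) \
    .setOutputCol("label") \
    .setCandidateLabels(["urgent", "mobile", "travel", "movie", "music"])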
sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py (new file)
@@ -0,0 +1,161 @@
+ # Copyright 2017-2024 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from sparknlp.common import *
+
+ class DistilBertForMultipleChoice(AnnotatorModel,
+ HasCaseSensitiveProperties,
+ HasBatchedAnnotate,
+ HasEngine,
+ HasMaxSentenceLengthLimit):
+ """DistilBertForMultipleChoice can load DistilBert Models with a multiple choice classification head on top
+ (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
+
+ Pretrained models can be loaded with :meth:`.pretrained` of the companion
+ object:
+
+ >>> spanClassifier = DistilBertForMultipleChoice.pretrained() \\
+ ... .setInputCols(["document_question", "document_context"]) \\
+ ... .setOutputCol("answer")
+
+ The default model is ``"distilbert_base_uncased_multiple_choice"``, if no name is
+ provided.
+
+ For available pretrained models please see the `Models Hub
+ <https://sparknlp.org/models?task=Multiple+Choice>`__.
+
+ To see which models are compatible and how to import them see
+ `Import Transformers into Spark NLP 🚀
+ <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.
+
+ ====================== ======================
+ Input Annotation types Output Annotation type
+ ====================== ======================
+ ``DOCUMENT, DOCUMENT`` ``CHUNK``
+ ====================== ======================
+
+ Parameters
+ ----------
+ batchSize
+ Batch size. Large values allow faster processing but require more
+ memory, by default 4
+ caseSensitive
+ Whether to ignore case in tokens for embeddings matching, by default
+ False
+ maxSentenceLength
+ Max sentence length to process, by default 512
+
+ Examples
+ --------
+ >>> import sparknlp
+ >>> from sparknlp.base import *
+ >>> from sparknlp.annotator import *
+ >>> from pyspark.ml import Pipeline
+ >>> documentAssembler = MultiDocumentAssembler() \\
+ ... .setInputCols(["question", "context"]) \\
+ ... .setOutputCols(["document_question", "document_context"])
+ >>> questionAnswering = DistilBertForMultipleChoice.pretrained() \\
+ ... .setInputCols(["document_question", "document_context"]) \\
+ ... .setOutputCol("answer") \\
+ ... .setCaseSensitive(False)
+ >>> pipeline = Pipeline().setStages([
+ ... documentAssembler,
+ ... questionAnswering
+ ... ])
+ >>> data = spark.createDataFrame([["The Eiffel Tower is located in which country?", "Germany, France, Italy"]]).toDF("question", "context")
+ >>> result = pipeline.fit(data).transform(data)
+ >>> result.select("answer.result").show(truncate=False)
+ +--------------------+
+ |result              |
+ +--------------------+
+ |[France]            |
+ +--------------------+
+ """
+ name = "DistilBertForMultipleChoice"
+
+ inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]
+
+ outputAnnotatorType = AnnotatorType.CHUNK
+
+ choicesDelimiter = Param(Params._dummy(),
+ "choicesDelimiter",
+ "Delimiter character used to split the choices",
+ TypeConverters.toString)
+
+ def setChoicesDelimiter(self, value):
+ """Sets the delimiter character used to split the choices
+
+ Parameters
+ ----------
+ value : string
+ Delimiter character used to split the choices
+ """
+ return self._set(choicesDelimiter=value)
+
+ @keyword_only
+ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForMultipleChoice",
+ java_model=None):
+ super(DistilBertForMultipleChoice, self).__init__(
+ classname=classname,
+ java_model=java_model
+ )
+ self._setDefault(
+ batchSize=4,
+ maxSentenceLength=512,
+ caseSensitive=False,
+ choicesDelimiter=","
+ )
+
+ @staticmethod
+ def loadSavedModel(folder, spark_session):
+ """Loads a locally saved model.
+
+ Parameters
+ ----------
+ folder : str
+ Folder of the saved model
+ spark_session : pyspark.sql.SparkSession
+ The current SparkSession
+
+ Returns
+ -------
+ DistilBertForMultipleChoice
+ The restored model
+ """
+ from sparknlp.internal import _DistilBertMultipleChoiceLoader
+ jModel = _DistilBertMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj
+ return DistilBertForMultipleChoice(java_model=jModel)
+
+ @staticmethod
+ def pretrained(name="distilbert_base_uncased_multiple_choice", lang="en", remote_loc=None):
+ """Downloads and loads a pretrained model.
+
+ Parameters
+ ----------
+ name : str, optional
+ Name of the pretrained model, by default
+ "distilbert_base_uncased_multiple_choice"
+ lang : str, optional
+ Language of the pretrained model, by default "en"
+ remote_loc : str, optional
+ Optional remote address of the resource, by default None. Will use
+ Spark NLP's repositories otherwise.
+
+ Returns
+ -------
+ DistilBertForMultipleChoice
+ The restored model
+ """
+ from sparknlp.pretrained import ResourceDownloader
+ return ResourceDownloader.downloadModel(DistilBertForMultipleChoice, name, lang, remote_loc)
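In the multiple-choice setup, all candidate answers travel in a single context column and are split on `choicesDelimiter` (default `","`). A sketch of overriding the delimiter so commas can appear inside an option; the pipe character and the sample row are illustrative:

# Pipe-separated choices instead of the default comma
multipleChoice = DistilBertForMultipleChoice.pretrained() \
    .setInputCols(["document_question", "document_context"]) \
    .setOutputCol("answer") \
    .setChoicesDelimiter("|")

data = spark.createDataFrame(
    [["Which country is the Eiffel Tower in?", "Germany|France|Italy"]]
).toDF("question", "context")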
sparknlp/annotator/classifier_dl/longformer_for_question_answering.py
@@ -18,7 +18,8 @@ from sparknlp.common import *
  class LongformerForQuestionAnswering(AnnotatorModel,
  HasCaseSensitiveProperties,
  HasBatchedAnnotate,
- HasEngine):
+ HasEngine,
+ HasLongMaxSentenceLengthLimit):
  """LongformerForQuestionAnswering can load Longformer Models with a span classification head on top for extractive
  question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute span start
  logits and span end logits).
@@ -34,7 +35,7 @@ class LongformerForQuestionAnswering(AnnotatorModel,
  provided.

  For available pretrained models please see the `Models Hub
- <https://nlp.johnsnowlabs.com/models?task=Question+Answering>`__.
+ <https://sparknlp.org/models?task=Question+Answering>`__.

  To see which models are compatible and how to import them see
  `Import Transformers into Spark NLP 🚀
@@ -91,11 +92,6 @@ class LongformerForQuestionAnswering(AnnotatorModel,

  outputAnnotatorType = AnnotatorType.CHUNK

- maxSentenceLength = Param(Params._dummy(),
- "maxSentenceLength",
- "Max sentence length to process",
- typeConverter=TypeConverters.toInt)
-
  configProtoBytes = Param(Params._dummy(),
  "configProtoBytes",
  "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
@@ -115,16 +111,6 @@ class LongformerForQuestionAnswering(AnnotatorModel,
  """
  return self._set(configProtoBytes=b)

- def setMaxSentenceLength(self, value):
- """Sets max sentence length to process, by default 128.
-
- Parameters
- ----------
- value : int
- Max sentence length to process
- """
- return self._set(maxSentenceLength=value)
-
  @keyword_only
  def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.LongformerForQuestionAnswering",
  java_model=None):
sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py
@@ -20,7 +20,8 @@ class LongformerForSequenceClassification(AnnotatorModel,
  HasCaseSensitiveProperties,
  HasBatchedAnnotate,
  HasClassifierActivationProperties,
- HasEngine):
+ HasEngine,
+ HasLongMaxSentenceLengthLimit):
  """LongformerForSequenceClassification can load Longformer Models with sequence classification/regression head on
  top (a linear layer on top of the pooled output) e.g. for multi-class document classification tasks.

@@ -35,7 +36,7 @@ class LongformerForSequenceClassification(AnnotatorModel,
  provided.

  For available pretrained models please see the `Models Hub
- <https://nlp.johnsnowlabs.com/models?task=Text+Classification>`__.
+ <https://sparknlp.org/models?task=Text+Classification>`__.

  To see which models are compatible and how to import them see
  `Import Transformers into Spark NLP 🚀
@@ -61,7 +62,7 @@ class LongformerForSequenceClassification(AnnotatorModel,
  Max sentence length to process, by default 4096
  coalesceSentences
  Instead of 1 class per sentence (if inputCols is `sentence`) output
- 1 class per document by averaging probabilities in all sentences, by
+ 1 class per document by averaging probabilities in all sentences, by
  default False.
  activation
  Whether to calculate logits via Softmax or Sigmoid, by default
@@ -104,11 +105,6 @@ class LongformerForSequenceClassification(AnnotatorModel,

  outputAnnotatorType = AnnotatorType.CATEGORY

- maxSentenceLength = Param(Params._dummy(),
- "maxSentenceLength",
- "Max sentence length to process",
- typeConverter=TypeConverters.toInt)
-
  configProtoBytes = Param(Params._dummy(),
  "configProtoBytes",
  "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
@@ -134,16 +130,6 @@ class LongformerForSequenceClassification(AnnotatorModel,
  """
  return self._set(configProtoBytes=b)

- def setMaxSentenceLength(self, value):
- """Sets max sentence length to process, by default 128.
-
- Parameters
- ----------
- value : int
- Max sentence length to process
- """
- return self._set(maxSentenceLength=value)
-
  def setCoalesceSentences(self, value):
  """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.
  Due to max sequence length limit in almost all transformer models such as BERT (512 tokens), this parameter helps feeding all the sentences
sparknlp/annotator/classifier_dl/longformer_for_token_classification.py
@@ -19,7 +19,8 @@ from sparknlp.common import *
  class LongformerForTokenClassification(AnnotatorModel,
  HasCaseSensitiveProperties,
  HasBatchedAnnotate,
- HasEngine):
+ HasEngine,
+ HasLongMaxSentenceLengthLimit):
  """LongformerForTokenClassification can load Longformer Models with a token
  classification head on top (a linear layer on top of the hidden-states
  output) e.g. for Named-Entity-Recognition (NER) tasks.
@@ -35,7 +36,7 @@ class LongformerForTokenClassification(AnnotatorModel,
  provided.

  For available pretrained models please see the `Models Hub
- <https://nlp.johnsnowlabs.com/models?task=Named+Entity+Recognition>`__.
+ <https://sparknlp.org/models?task=Named+Entity+Recognition>`__.

  To see which models are compatible and how to import them see
  `Import Transformers into Spark NLP 🚀
@@ -97,11 +98,6 @@ class LongformerForTokenClassification(AnnotatorModel,

  outputAnnotatorType = AnnotatorType.NAMED_ENTITY

- maxSentenceLength = Param(Params._dummy(),
- "maxSentenceLength",
- "Max sentence length to process",
- typeConverter=TypeConverters.toInt)
-
  configProtoBytes = Param(Params._dummy(),
  "configProtoBytes",
  "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
@@ -123,16 +119,6 @@ class LongformerForTokenClassification(AnnotatorModel,
  """
  return self._set(configProtoBytes=b)

- def setMaxSentenceLength(self, value):
- """Sets max sentence length to process, by default 128.
-
- Parameters
- ----------
- value : int
- Max sentence length to process
- """
- return self._set(maxSentenceLength=value)
-
  @keyword_only
  def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.LongformerForTokenClassification",
  java_model=None):
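The practical effect of `HasLongMaxSentenceLengthLimit` on the three Longformer annotators above is that `setMaxSentenceLength` now accepts values up to Longformer's 4096-token window rather than a BERT-style 512 cap. A sketch, assuming the standard question-answering input columns:

# Extractive QA over long contexts with the raised length cap
longQA = LongformerForQuestionAnswering.pretrained() \
    .setInputCols(["document_question", "document_context"]) \
    .setOutputCol("answer") \
    .setMaxSentenceLength(4096)  # within the long-limit mixin's cap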