spark-nlp 4.2.6__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (221)
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  4. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  5. {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  6. sparknlp/__init__.py +81 -28
  7. sparknlp/annotation.py +3 -2
  8. sparknlp/annotator/__init__.py +6 -0
  9. sparknlp/annotator/audio/__init__.py +2 -0
  10. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  11. sparknlp/annotator/audio/wav2vec2_for_ctc.py +14 -14
  12. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  13. sparknlp/{base → annotator}/chunk2_doc.py +4 -7
  14. sparknlp/annotator/chunker.py +1 -2
  15. sparknlp/annotator/classifier_dl/__init__.py +17 -0
  16. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  17. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +3 -15
  18. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +4 -18
  19. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +3 -17
  20. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  21. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  22. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  23. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +6 -20
  24. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +3 -17
  25. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +3 -17
  26. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  27. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  28. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +5 -19
  29. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +5 -19
  30. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  31. sparknlp/annotator/classifier_dl/classifier_dl.py +4 -4
  32. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +3 -17
  33. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +4 -19
  34. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +5 -21
  35. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  36. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +3 -17
  37. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +4 -18
  38. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +3 -17
  39. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  40. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  41. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +3 -17
  42. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +4 -18
  43. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +3 -17
  44. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  45. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  46. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  47. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +3 -3
  48. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  49. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +3 -17
  50. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +4 -18
  51. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +1 -1
  52. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  53. sparknlp/annotator/classifier_dl/sentiment_dl.py +4 -4
  54. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +2 -2
  55. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  56. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +3 -17
  57. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +4 -18
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +6 -20
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  60. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +4 -18
  61. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +3 -17
  62. sparknlp/annotator/cleaners/__init__.py +15 -0
  63. sparknlp/annotator/cleaners/cleaner.py +202 -0
  64. sparknlp/annotator/cleaners/extractor.py +191 -0
  65. sparknlp/annotator/coref/spanbert_coref.py +4 -18
  66. sparknlp/annotator/cv/__init__.py +15 -0
  67. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  68. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  69. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  70. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  71. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  72. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  73. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  74. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  75. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  76. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  77. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  78. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  79. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  80. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  81. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  82. sparknlp/annotator/cv/vit_for_image_classification.py +36 -4
  83. sparknlp/annotator/dataframe_optimizer.py +216 -0
  84. sparknlp/annotator/date2_chunk.py +88 -0
  85. sparknlp/annotator/dependency/dependency_parser.py +2 -3
  86. sparknlp/annotator/dependency/typed_dependency_parser.py +3 -4
  87. sparknlp/annotator/document_character_text_splitter.py +228 -0
  88. sparknlp/annotator/document_normalizer.py +37 -1
  89. sparknlp/annotator/document_token_splitter.py +175 -0
  90. sparknlp/annotator/document_token_splitter_test.py +85 -0
  91. sparknlp/annotator/embeddings/__init__.py +11 -0
  92. sparknlp/annotator/embeddings/albert_embeddings.py +4 -18
  93. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  94. sparknlp/annotator/embeddings/bert_embeddings.py +9 -22
  95. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +12 -24
  96. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  97. sparknlp/annotator/embeddings/camembert_embeddings.py +4 -20
  98. sparknlp/annotator/embeddings/chunk_embeddings.py +1 -2
  99. sparknlp/annotator/embeddings/deberta_embeddings.py +2 -16
  100. sparknlp/annotator/embeddings/distil_bert_embeddings.py +5 -19
  101. sparknlp/annotator/embeddings/doc2vec.py +7 -1
  102. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  103. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  104. sparknlp/annotator/embeddings/elmo_embeddings.py +2 -2
  105. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  106. sparknlp/annotator/embeddings/longformer_embeddings.py +3 -17
  107. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  108. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  109. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  110. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  111. sparknlp/annotator/embeddings/roberta_embeddings.py +9 -21
  112. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +7 -21
  113. sparknlp/annotator/embeddings/sentence_embeddings.py +2 -3
  114. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  115. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  116. sparknlp/annotator/embeddings/universal_sentence_encoder.py +3 -3
  117. sparknlp/annotator/embeddings/word2vec.py +7 -1
  118. sparknlp/annotator/embeddings/word_embeddings.py +4 -5
  119. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +9 -21
  120. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +7 -21
  121. sparknlp/annotator/embeddings/xlnet_embeddings.py +4 -18
  122. sparknlp/annotator/er/entity_ruler.py +37 -23
  123. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +2 -3
  124. sparknlp/annotator/ld_dl/language_detector_dl.py +2 -2
  125. sparknlp/annotator/lemmatizer.py +3 -4
  126. sparknlp/annotator/matcher/date_matcher.py +35 -3
  127. sparknlp/annotator/matcher/multi_date_matcher.py +1 -2
  128. sparknlp/annotator/matcher/regex_matcher.py +3 -3
  129. sparknlp/annotator/matcher/text_matcher.py +2 -3
  130. sparknlp/annotator/n_gram_generator.py +1 -2
  131. sparknlp/annotator/ner/__init__.py +3 -1
  132. sparknlp/annotator/ner/ner_converter.py +18 -0
  133. sparknlp/annotator/ner/ner_crf.py +4 -5
  134. sparknlp/annotator/ner/ner_dl.py +10 -5
  135. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  136. sparknlp/annotator/ner/ner_overwriter.py +2 -2
  137. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  138. sparknlp/annotator/normalizer.py +2 -2
  139. sparknlp/annotator/openai/__init__.py +16 -0
  140. sparknlp/annotator/openai/openai_completion.py +349 -0
  141. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  142. sparknlp/annotator/pos/perceptron.py +6 -7
  143. sparknlp/annotator/sentence/sentence_detector.py +2 -2
  144. sparknlp/annotator/sentence/sentence_detector_dl.py +3 -3
  145. sparknlp/annotator/sentiment/sentiment_detector.py +4 -5
  146. sparknlp/annotator/sentiment/vivekn_sentiment.py +4 -5
  147. sparknlp/annotator/seq2seq/__init__.py +17 -0
  148. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  149. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  150. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  151. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  152. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  153. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  154. sparknlp/annotator/seq2seq/gpt2_transformer.py +1 -1
  155. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  156. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  157. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  158. sparknlp/annotator/seq2seq/marian_transformer.py +124 -3
  159. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  160. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  161. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  162. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  163. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  164. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  165. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  166. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  167. sparknlp/annotator/seq2seq/t5_transformer.py +54 -4
  168. sparknlp/annotator/similarity/__init__.py +0 -0
  169. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  170. sparknlp/annotator/spell_check/context_spell_checker.py +116 -17
  171. sparknlp/annotator/spell_check/norvig_sweeting.py +3 -6
  172. sparknlp/annotator/spell_check/symmetric_delete.py +1 -1
  173. sparknlp/annotator/stemmer.py +2 -3
  174. sparknlp/annotator/stop_words_cleaner.py +3 -4
  175. sparknlp/annotator/tf_ner_dl_graph_builder.py +1 -1
  176. sparknlp/annotator/token/__init__.py +0 -1
  177. sparknlp/annotator/token/recursive_tokenizer.py +2 -3
  178. sparknlp/annotator/token/tokenizer.py +2 -3
  179. sparknlp/annotator/ws/word_segmenter.py +35 -10
  180. sparknlp/base/__init__.py +2 -3
  181. sparknlp/base/doc2_chunk.py +0 -3
  182. sparknlp/base/document_assembler.py +5 -5
  183. sparknlp/base/embeddings_finisher.py +14 -2
  184. sparknlp/base/finisher.py +15 -4
  185. sparknlp/base/gguf_ranking_finisher.py +234 -0
  186. sparknlp/base/image_assembler.py +69 -0
  187. sparknlp/base/light_pipeline.py +53 -21
  188. sparknlp/base/multi_document_assembler.py +9 -13
  189. sparknlp/base/prompt_assembler.py +207 -0
  190. sparknlp/base/token_assembler.py +1 -2
  191. sparknlp/common/__init__.py +2 -0
  192. sparknlp/common/annotator_type.py +1 -0
  193. sparknlp/common/completion_post_processing.py +37 -0
  194. sparknlp/common/match_strategy.py +33 -0
  195. sparknlp/common/properties.py +914 -9
  196. sparknlp/internal/__init__.py +841 -116
  197. sparknlp/internal/annotator_java_ml.py +1 -1
  198. sparknlp/internal/annotator_transformer.py +3 -0
  199. sparknlp/logging/comet.py +2 -2
  200. sparknlp/partition/__init__.py +16 -0
  201. sparknlp/partition/partition.py +244 -0
  202. sparknlp/partition/partition_properties.py +902 -0
  203. sparknlp/partition/partition_transformer.py +200 -0
  204. sparknlp/pretrained/pretrained_pipeline.py +1 -1
  205. sparknlp/pretrained/resource_downloader.py +126 -2
  206. sparknlp/reader/__init__.py +15 -0
  207. sparknlp/reader/enums.py +19 -0
  208. sparknlp/reader/pdf_to_text.py +190 -0
  209. sparknlp/reader/reader2doc.py +124 -0
  210. sparknlp/reader/reader2image.py +136 -0
  211. sparknlp/reader/reader2table.py +44 -0
  212. sparknlp/reader/reader_assembler.py +159 -0
  213. sparknlp/reader/sparknlp_reader.py +461 -0
  214. sparknlp/training/__init__.py +1 -0
  215. sparknlp/training/conll.py +8 -2
  216. sparknlp/training/spacy_to_annotation.py +57 -0
  217. sparknlp/util.py +26 -0
  218. spark_nlp-4.2.6.dist-info/METADATA +0 -1256
  219. spark_nlp-4.2.6.dist-info/RECORD +0 -196
  220. {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
  221. /sparknlp/annotator/{token/token2_chunk.py → token2_chunk.py} +0 -0
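The file list above shows the new annotators, readers, and partition utilities added between 4.2.6 and 6.2.1. As a quick orientation, here is a minimal Python sketch of loading one of the newly added annotators; the class name comes from item 96 (bge_embeddings.py) in the list, while the session setup and column names are illustrative assumptions rather than part of the diff:

# Sketch only: assumes a local Spark session via sparknlp.start() and the
# library's default pretrained BGE model; not taken from the diff itself.
import sparknlp
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import BGEEmbeddings
from pyspark.ml import Pipeline

spark = sparknlp.start()

document_assembler = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

# BGEEmbeddings ships in sparknlp/annotator/embeddings/bge_embeddings.py (new in this wheel).
embeddings = BGEEmbeddings.pretrained() \
    .setInputCols(["document"]) \
    .setOutputCol("embeddings")

pipeline = Pipeline(stages=[document_assembler, embeddings])
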
sparknlp/annotator/seq2seq/t5_transformer.py
@@ -39,10 +39,10 @@ class T5Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
 
  The default model is ``"t5_small"``, if no name is provided. For available
  pretrained models please see the `Models Hub
- <https://nlp.johnsnowlabs.com/models?q=t5>`__.
+ <https://sparknlp.org/models?q=t5>`__.
 
- For extended examples of usage, see the `Spark NLP Workshop
- <https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/10.Question_Answering_and_Summarization_with_T5.ipynb>`__.
+ For extended examples of usage, see the `Examples
+ <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/question-answering/Question_Answering_and_Summarization_with_T5.ipynb>`__.
 
  ====================== ======================
  Input Annotation types Output Annotation type
@@ -191,6 +191,23 @@ class T5Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
  "A list of token ids which are ignored in the decoder's output",
  typeConverter=TypeConverters.toListInt)
 
+ useCache = Param(Params._dummy(), "useCache", "Cache internal state of the model to improve performance",
+ typeConverter=TypeConverters.toBoolean)
+
+ stopAtEos = Param(
+ Params._dummy(),
+ "stopAtEos",
+ "Stop text generation when the end-of-sentence token is encountered.",
+ typeConverter=TypeConverters.toBoolean
+ )
+
+ maxNewTokens = Param(
+ Params._dummy(),
+ "maxNewTokens",
+ "Maximum number of new tokens to be generated",
+ typeConverter=TypeConverters.toInt
+ )
+
  def setIgnoreTokenIds(self, value):
  """A list of token ids which are ignored in the decoder's output.
 
@@ -241,6 +258,26 @@ class T5Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
  """
  return self._set(maxOutputLength=value)
 
+ def setStopAtEos(self, b):
+ """Stop text generation when the end-of-sentence token is encountered.
+
+ Parameters
+ ----------
+ b : bool
+ whether to stop at end-of-sentence token or not
+ """
+ return self._set(stopAtEos=b)
+
+ def setMaxNewTokens(self, value):
+ """Sets the maximum number of new tokens to be generated
+
+ Parameters
+ ----------
+ value : int
+ the maximum number of new tokens to be generated
+ """
+ return self._set(maxNewTokens=value)
+
  def setDoSample(self, value):
  """Sets whether or not to use sampling, use greedy decoding otherwise.
 
@@ -312,6 +349,16 @@ class T5Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
  """
  return self._set(noRepeatNgramSize=value)
 
+ def setUseCache(self, value):
+ """Cache internal state of the model to improve performance
+
+ Parameters
+ ----------
+ value : bool
+ Whether or not to use cache
+ """
+ return self._set(useCache=value)
+
  @keyword_only
  def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.T5Transformer", java_model=None):
  super(T5Transformer, self).__init__(
@@ -329,7 +376,10 @@ class T5Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
  repetitionPenalty=1.0,
  noRepeatNgramSize=0,
  ignoreTokenIds=[],
- batchSize=1
+ batchSize=1,
+ stopAtEos=True,
+ maxNewTokens=512,
+ useCache=False
  )
 
  @staticmethod
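The hunks above, from sparknlp/annotator/seq2seq/t5_transformer.py, add the stopAtEos, maxNewTokens, and useCache generation parameters to T5Transformer and set their defaults (True, 512, False). A minimal usage sketch follows; "t5_small" is the documented default model, while the task string and column names are illustrative assumptions:

# Sketch only: the task string and column names are illustrative, not from the diff.
from sparknlp.annotator import T5Transformer

t5 = T5Transformer.pretrained("t5_small") \
    .setTask("summarize:") \
    .setInputCols(["document"]) \
    .setOutputCol("summary") \
    .setStopAtEos(True) \
    .setMaxNewTokens(512) \
    .setUseCache(False)
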
File without changes
sparknlp/annotator/similarity/document_similarity_ranker.py
@@ -0,0 +1,379 @@
+ # Copyright 2017-2023 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Contains classes for DocumentSimilarityRanker."""
+
+ from sparknlp.common import *
+ from pyspark import keyword_only
+ from pyspark.ml.param import TypeConverters, Params, Param
+ from sparknlp.internal import AnnotatorTransformer
+
+
+ class DocumentSimilarityRankerApproach(AnnotatorApproach, HasEnableCachingProperties):
+ """Annotator that uses LSH techniques present in Spark ML lib to execute
+ approximate nearest neighbors search on top of sentence embeddings.
+
+ It aims to capture the semantic meaning of a document in a dense,
+ continuous vector space and return it to the ranker search.
+
+ For instantiated/pretrained models, see DocumentSimilarityRankerModel.
+
+ For extended examples of usage, see the jupyter notebook
+ `Document Similarity Ranker for Spark NLP <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/text-similarity/doc-sim-ranker/test_doc_sim_ranker.ipynb>`__.
+
+ ======================= ===========================
+ Input Annotation types  Output Annotation type
+ ======================= ===========================
+ ``SENTENCE_EMBEDDINGS`` ``DOC_SIMILARITY_RANKINGS``
+ ======================= ===========================
+
+ Parameters
+ ----------
+ enableCaching
+ Whether to enable caching DataFrames or RDDs during the training
+ similarityMethod
+ The similarity method used to calculate the neighbours.
+ (Default: 'brp',Bucketed Random Projection for Euclidean Distance)
+ numberOfNeighbours
+ The number of neighbours the model will return (Default:`10`)
+ bucketLength
+ Controls the average size of hash buckets. A larger bucket
+ length (i.e., fewer buckets) increases the probability of features
+ being hashed to the same bucket (increasing the numbers of true and false positives)
+ numHashTables
+ Number of hash tables, where increasing number of hash tables lowers the
+ false negative rate, and decreasing it improves the running performance.
+ visibleDistances
+ "Whether to set visibleDistances in ranking output (Default: `false`).
+ identityRanking
+ Whether to include identity in ranking result set. Useful for debug. (Default: `false`).
+
+ Examples
+ --------
+ >>> import sparknlp
+ >>> from sparknlp.base import *
+ >>> from sparknlp.annotator import *
+ >>> from pyspark.ml import Pipeline
+ >>> from sparknlp.annotator.similarity.document_similarity_ranker import *
+ >>> document_assembler = DocumentAssembler() \
+ ...     .setInputCol("text") \
+ ...     .setOutputCol("document")
+ >>> sentence_embeddings = E5Embeddings.pretrained() \
+ ...     .setInputCols(["document"]) \
+ ...     .setOutputCol("sentence_embeddings")
+ >>> document_similarity_ranker = DocumentSimilarityRankerApproach() \
+ ...     .setInputCols("sentence_embeddings") \
+ ...     .setOutputCol("doc_similarity_rankings") \
+ ...     .setSimilarityMethod("brp") \
+ ...     .setNumberOfNeighbours(1) \
+ ...     .setBucketLength(2.0) \
+ ...     .setNumHashTables(3) \
+ ...     .setVisibleDistances(True) \
+ ...     .setIdentityRanking(False)
+ >>> document_similarity_ranker_finisher = DocumentSimilarityRankerFinisher() \
+ ...     .setInputCols("doc_similarity_rankings") \
+ ...     .setOutputCols(
+ ...         "finished_doc_similarity_rankings_id",
+ ...         "finished_doc_similarity_rankings_neighbors") \
+ ...     .setExtractNearestNeighbor(True)
+ >>> pipeline = Pipeline(stages=[
+ ...     document_assembler,
+ ...     sentence_embeddings,
+ ...     document_similarity_ranker,
+ ...     document_similarity_ranker_finisher
+ ... ])
+ >>> docSimRankerPipeline = pipeline.fit(data).transform(data)
+ >>> (
+ ...     docSimRankerPipeline
+ ...     .select(
+ ...         "finished_doc_similarity_rankings_id",
+ ...         "finished_doc_similarity_rankings_neighbors"
+ ...     ).show(10, False)
+ ... )
+ +-----------------------------------+------------------------------------------+
+ |finished_doc_similarity_rankings_id|finished_doc_similarity_rankings_neighbors|
+ +-----------------------------------+------------------------------------------+
+ |1510101612 |[(1634839239,0.12448559591306324)] |
+ |1634839239 |[(1510101612,0.12448559591306324)] |
+ |-612640902 |[(1274183715,0.1220122862046063)] |
+ |1274183715 |[(-612640902,0.1220122862046063)] |
+ |-1320876223 |[(1293373212,0.17848855164122393)] |
+ |1293373212 |[(-1320876223,0.17848855164122393)] |
+ |-1548374770 |[(-1719102856,0.23297156732534166)] |
+ |-1719102856 |[(-1548374770,0.23297156732534166)] |
+ +-----------------------------------+------------------------------------------+
+ """
+
+ inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS]
+
+ outputAnnotatorType = AnnotatorType.DOC_SIMILARITY_RANKINGS
+
+ similarityMethod = Param(Params._dummy(),
+ "similarityMethod",
+ "The similarity method used to calculate the neighbours. (Default: 'brp', "
+ "Bucketed Random Projection for Euclidean Distance)",
+ typeConverter=TypeConverters.toString)
+
+ numberOfNeighbours = Param(Params._dummy(),
+ "numberOfNeighbours",
+ "The number of neighbours the model will return (Default:`10`)",
+ typeConverter=TypeConverters.toInt)
+
+ bucketLength = Param(Params._dummy(),
+ "bucketLength",
+ "The bucket length that controls the average size of hash buckets. "
+ "A larger bucket length (i.e., fewer buckets) increases the probability of features "
+ "being hashed to the same bucket (increasing the numbers of true and false positives).",
+ typeConverter=TypeConverters.toFloat)
+
+ numHashTables = Param(Params._dummy(),
+ "numHashTables",
+ "number of hash tables, where increasing number of hash tables lowers the "
+ "false negative rate,and decreasing it improves the running performance.",
+ typeConverter=TypeConverters.toInt)
+
+ visibleDistances = Param(Params._dummy(),
+ "visibleDistances",
+ "Whether to set visibleDistances in ranking output (Default: `false`).",
+ typeConverter=TypeConverters.toBoolean)
+
+ identityRanking = Param(Params._dummy(),
+ "identityRanking",
+ "Whether to include identity in ranking result set. Useful for debug. (Default: `false`).",
+ typeConverter=TypeConverters.toBoolean)
+
+ asRetrieverQuery = Param(Params._dummy(),
+ "asRetrieverQuery",
+ "Whether to set the model as retriever RAG with a specific query string."
+ "(Default: `empty`)",
+ typeConverter=TypeConverters.toString)
+
+ aggregationMethod = Param(Params._dummy(),
+ "aggregationMethod",
+ "Specifies the method used to aggregate multiple sentence embeddings into a single vector representation.",
+ typeConverter=TypeConverters.toString)
+
+
+ def setSimilarityMethod(self, value):
+ """Sets the similarity method used to calculate the neighbours.
+ (Default: `"brp"`, Bucketed Random Projection for Euclidean Distance)
+
+ Parameters
+ ----------
+ value : str
+ the similarity method to calculate the neighbours.
+ """
+ return self._set(similarityMethod=value)
+
+ def setNumberOfNeighbours(self, value):
+ """Sets The number of neighbours the model will return for each document(Default:`"10"`).
+
+ Parameters
+ ----------
+ value : str
+ the number of neighbours the model will return for each document.
+ """
+ return self._set(numberOfNeighbours=value)
+
+ def setBucketLength(self, value):
+ """Sets the bucket length that controls the average size of hash buckets (Default:`"2.0"`).
+
+ Parameters
+ ----------
+ value : float
+ Sets the bucket length that controls the average size of hash buckets.
+ """
+ return self._set(bucketLength=value)
+
+ def setNumHashTables(self, value):
+ """Sets the number of hash tables.
+
+ Parameters
+ ----------
+ value : int
+ Sets the number of hash tables.
+ """
+ return self._set(numHashTables=value)
+
+ def setVisibleDistances(self, value):
+ """Sets the document distances visible in the result set.
+
+ Parameters
+ ----------
+ value : bool
+ Sets the document distances visible in the result set.
+ Default('False')
+ """
+ return self._set(visibleDistances=value)
+
+ def setIdentityRanking(self, value):
+ """Sets the document identity ranking inclusive in the result set.
+
+ Parameters
+ ----------
+ value : bool
+ Sets the document identity ranking inclusive in the result set.
+ Useful for debugging.
+ Default('False').
+ """
+ return self._set(identityRanking=value)
+
+ def asRetriever(self, value):
+ """Sets the query to use the document similarity ranker as a retriever in a RAG fashion.
+ (Default: `""`, empty if this annotator is not used as retriever)
+
+ Parameters
+ ----------
+ value : str
+ the query to use to select nearest neighbors in the retrieval process.
+ """
+ return self._set(asRetrieverQuery=value)
+
+ def setAggregationMethod(self, value):
+ """Set the method used to aggregate multiple sentence embeddings into a single vector
+ representation.
+
+ Parameters
+ ----------
+ value : str
+ Options include
+ 'AVERAGE' (compute the mean of all embeddings),
+ 'FIRST' (use the first embedding only),
+ 'MAX' (compute the element-wise maximum across embeddings)
+
+ Default ('AVERAGE')
+ """
+ return self._set(aggregationMethod=value)
+
+ @keyword_only
+ def __init__(self):
+ super(DocumentSimilarityRankerApproach, self)\
+ .__init__(classname="com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerApproach")
+ self._setDefault(
+ similarityMethod="brp",
+ numberOfNeighbours=10,
+ bucketLength=2.0,
+ numHashTables=3,
+ visibleDistances=False,
+ identityRanking=False,
+ asRetrieverQuery=""
+ )
+
+ def _create_model(self, java_model):
+ return DocumentSimilarityRankerModel(java_model=java_model)
+
+
+ class DocumentSimilarityRankerModel(AnnotatorModel, HasEmbeddingsProperties):
+
+ name = "DocumentSimilarityRankerModel"
+ inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS]
+ outputAnnotatorType = AnnotatorType.DOC_SIMILARITY_RANKINGS
+
+ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerModel",
+ java_model=None):
+ super(DocumentSimilarityRankerModel, self).__init__(
+ classname=classname,
+ java_model=java_model
+ )
+
+
+ class DocumentSimilarityRankerFinisher(AnnotatorTransformer):
+ """Instantiated model of the DocumentSimilarityRankerApproach. For usage and examples see the
+ documentation of the main class.
+
+ ======================= ===========================
+ Input Annotation types  Output Annotation type
+ ======================= ===========================
+ ``SENTENCE_EMBEDDINGS`` ``DOC_SIMILARITY_RANKINGS``
+ ======================= ===========================
+
+ Parameters
+ ----------
+ extractNearestNeighbor
+ Whether to extract the nearest neighbor document
+ """
+ inputCols = Param(Params._dummy(),
+ "inputCols",
+ "name of input annotation cols containing document similarity ranker results",
+ typeConverter=TypeConverters.toListString)
+ outputCols = Param(Params._dummy(),
+ "outputCols",
+ "output DocumentSimilarityRankerFinisher output cols",
+ typeConverter=TypeConverters.toListString)
+ extractNearestNeighbor = Param(Params._dummy(), "extractNearestNeighbor",
+ "whether to extract the nearest neighbor document",
+ typeConverter=TypeConverters.toBoolean)
+
+ name = "DocumentSimilarityRankerFinisher"
+
+ @keyword_only
+ def __init__(self):
+ super(DocumentSimilarityRankerFinisher, self).__init__(classname="com.johnsnowlabs.nlp.finisher.DocumentSimilarityRankerFinisher")
+ self._setDefault(
+ extractNearestNeighbor=False
+ )
+
+ @keyword_only
+ def setParams(self):
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ def setInputCols(self, *value):
+ """Sets name of input annotation columns containing embeddings.
+
+ Parameters
+ ----------
+ *value : str
+ Input columns for the annotator
+ """
+
+ if len(value) == 1 and type(value[0]) == list:
+ return self._set(inputCols=value[0])
+ else:
+ return self._set(inputCols=list(value))
+
+ def setOutputCols(self, *value):
+ """Sets names of finished output columns.
+
+ Parameters
+ ----------
+ *value : List[str]
+ Input columns for the annotator
+ """
+
+ if len(value) == 1 and type(value[0]) == list:
+ return self._set(outputCols=value[0])
+ else:
+ return self._set(outputCols=list(value))
+
+ def setExtractNearestNeighbor(self, value):
+ """Sets whether to extract the nearest neighbor document, by default False.
+
+ Parameters
+ ----------
+ value : bool
+ Whether to extract the nearest neighbor document
+ """
+
+ return self._set(extractNearestNeighbor=value)
+
+ def getInputCols(self):
+ """Gets input columns name of annotations."""
+ return self.getOrDefault(self.inputCols)
+
+ def getOutputCols(self):
+ """Gets output columns name of annotations."""
+ if len(self.getOrDefault(self.outputCols)) == 0:
+ return ["finished_" + input_col for input_col in self.getInputCols()]
+ else:
+ return self.getOrDefault(self.outputCols)
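Beyond the docstring example in the new file above, DocumentSimilarityRankerApproach also exposes asRetriever, which stores a query string so the ranker can act as a retriever in a RAG-style setup. A hedged sketch; the query text and column names are placeholders, not from the diff:

# Sketch only: query and column names are placeholders.
retriever = DocumentSimilarityRankerApproach() \
    .setInputCols(["sentence_embeddings"]) \
    .setOutputCol("doc_similarity_rankings") \
    .setNumberOfNeighbours(3) \
    .asRetriever("Which documents mention spell checking?")
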
sparknlp/annotator/spell_check/context_spell_checker.py
@@ -37,7 +37,7 @@ class ContextSpellCheckerApproach(AnnotatorApproach):
 
  For extended examples of usage, see the article
  `Training a Contextual Spell Checker for Italian Language <https://towardsdatascience.com/training-a-contextual-spell-checker-for-italian-language-66dda528e4bf>`__,
- the `Spark NLP Workshop <https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/blogposts/5.TrainingContextSpellChecker.ipynb>`__.
+ the `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/italian/Training_Context_Spell_Checker_Italian.ipynb>`__.
 
  ====================== ======================
  Input Annotation types Output Annotation type
@@ -92,6 +92,10 @@ class ContextSpellCheckerApproach(AnnotatorApproach):
  correction.
  configProtoBytes
  ConfigProto from tensorflow, serialized into byte array.
+ maxSentLen
+ Maximum length for a sentence - internal use during training.
+ graphFolder
+ Folder path that contain external graph files.
 
  References
  ----------
@@ -226,6 +230,16 @@ class ContextSpellCheckerApproach(AnnotatorApproach):
  "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
  TypeConverters.toListInt)
 
+ maxSentLen = Param(Params._dummy(),
+ "maxSentLen",
+ "Maximum length of a sentence to be considered for training.",
+ typeConverter=TypeConverters.toInt)
+
+ graphFolder = Param(Params._dummy(),
+ "graphFolder",
+ "Folder path that contain external graph files.",
+ typeConverter=TypeConverters.toString)
+
  def setLanguageModelClasses(self, count):
  """Sets number of classes to use during factorization of the softmax
  output in the Language Model.
@@ -393,18 +407,6 @@ class ContextSpellCheckerApproach(AnnotatorApproach):
  """
  return self._set(weightedDistPath=path)
 
- def setWeightedDistPath(self, path):
- """Sets the path to the file containing the weights for the levenshtein
- distance.
-
- Parameters
- ----------
- path : str
- Path to the file containing the weights for the levenshtein
- distance.
- """
- return self._set(weightedDistPath=path)
-
  def setMaxWindowLen(self, length):
  """Sets the maximum size for the window used to remember history prior
  to every correction.
@@ -427,6 +429,26 @@ class ContextSpellCheckerApproach(AnnotatorApproach):
  """
  return self._set(configProtoBytes=b)
 
+ def setGraphFolder(self, path):
+ """Sets folder path that contain external graph files.
+
+ Parameters
+ ----------
+ path : str
+ Folder path that contain external graph files.
+ """
+ return self._set(graphFolder=path)
+
+ def setMaxSentLen(self, sentlen):
+ """Sets the maximum length of a sentence.
+
+ Parameters
+ ----------
+ sentlen : int
+ Maximum length of a sentence
+ """
+ return self._set(maxSentLen=sentlen)
+
  def addVocabClass(self, label, vocab, userdist=3):
  """Adds a new class of words to correct, based on a vocabulary.
 
@@ -494,9 +516,9 @@ class ContextSpellCheckerModel(AnnotatorModel, HasEngine):
 
 
  The default model is ``"spellcheck_dl"``, if no name is provided.
- For available pretrained models please see the `Models Hub <https://nlp.johnsnowlabs.com/models?task=Spell+Check>`__.
+ For available pretrained models please see the `Models Hub <https://sparknlp.org/models?task=Spell+Check>`__.
 
- For extended examples of usage, see the `Spark NLP Workshop <https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/streamlit_notebooks/SPELL_CHECKER_EN.ipynb>`__.
+ For extended examples of usage, see the `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/italian/Training_Context_Spell_Checker_Italian.ipynb>`__.
 
  ====================== ======================
  Input Annotation types Output Annotation type
@@ -525,13 +547,25 @@ class ContextSpellCheckerModel(AnnotatorModel, HasEngine):
  correctSymbols
  Whether to correct special symbols or skip spell checking for them
  compareLowcase
- If true will compare tokens in low case with vocabulary
+ If true will compare tokens in low case with vocabulary.
  configProtoBytes
  ConfigProto from tensorflow, serialized into byte array.
+ vocabFreq
+ Frequency words from the vocabulary.
+ idsVocab
+ Mapping of ids to vocabulary.
+ vocabIds
+ Mapping of vocabulary to ids.
+ classes
+ Classes the spell checker recognizes.
+ weights
+ Levenshtein weights.
+ useNewLines
+ When set to true new lines will be treated as any other character. When set to false correction is applied on paragraphs as defined by newline characters.
 
 
  References
- -------------
+ ----------
  For an in-depth explanation of the module see the article `Applying Context
  Aware Spell Checking in Spark NLP
  <https://medium.com/spark-nlp/applying-context-aware-spell-checking-in-spark-nlp-3c29c46963bc>`__.
@@ -624,6 +658,31 @@ class ContextSpellCheckerModel(AnnotatorModel, HasEngine):
  "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
  TypeConverters.toListInt)
 
+ vocabFreq = Param(
+ Params._dummy(),
+ "vocabFreq",
+ "Frequency words from the vocabulary.",
+ TypeConverters.identity,
+ )
+ idsVocab = Param(
+ Params._dummy(),
+ "idsVocab",
+ "Mapping of ids to vocabulary.",
+ TypeConverters.identity,
+ )
+ vocabIds = Param(
+ Params._dummy(),
+ "vocabIds",
+ "Mapping of vocabulary to ids.",
+ TypeConverters.identity,
+ )
+ classes = Param(
+ Params._dummy(),
+ "classes",
+ "Classes the spell checker recognizes.",
+ TypeConverters.identity,
+ )
+
  def setWordMaxDistance(self, dist):
  """Sets maximum distance for the generated candidates for every word.
 
@@ -718,6 +777,46 @@ class ContextSpellCheckerModel(AnnotatorModel, HasEngine):
  """
  return self._set(configProtoBytes=b)
 
+ def setVocabFreq(self, value: dict):
+ """Sets frequency words from the vocabulary.
+
+ Parameters
+ ----------
+ value : dict
+ Frequency words from the vocabulary.
+ """
+ return self._set(vocabFreq=value)
+
+ def setIdsVocab(self, idsVocab: dict):
+ """Sets mapping of ids to vocabulary.
+
+ Parameters
+ ----------
+ idsVocab : dict
+ Mapping of ids to vocabulary.
+ """
+ return self._set(idsVocab=idsVocab)
+
+ def setVocabIds(self, vocabIds: dict):
+ """Sets mapping of vocabulary to ids.
+
+ Parameters
+ ----------
+ vocabIds : dict
+ Mapping of vocabulary to ids.
+ """
+ return self._set(vocabIds=vocabIds)
+
+ def setClasses(self, value):
+ """Sets classes the spell checker recognizes.
+
+ Parameters
+ ----------
+ value : list
+ Classes the spell checker recognizes.
+ """
+ return self._set(classes=value)
+
  def getWordClasses(self):
  """Gets the classes of words to be corrected.
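The ContextSpellCheckerApproach hunks above add the graphFolder and maxSentLen training parameters alongside the existing setters. A minimal training-side sketch; the folder path and column names are placeholders, not from the diff:

# Sketch only: the graph folder path and column names are placeholders.
from sparknlp.annotator import ContextSpellCheckerApproach

spell_checker = ContextSpellCheckerApproach() \
    .setInputCols(["token"]) \
    .setOutputCol("checked") \
    .setGraphFolder("path/to/graph/folder") \
    .setMaxSentLen(250)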