spark-nlp 4.2.6__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
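For readers reproducing the comparison locally, a minimal sketch for confirming which version is installed after an upgrade (assuming a standard PyPI install; sparknlp.version() is part of the package's public API):

    # e.g. after: pip install --upgrade spark-nlp==6.2.1
    import sparknlp

    print(sparknlp.version())  # expected to print 6.2.1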
Files changed (221)
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  4. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  5. {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  6. sparknlp/__init__.py +81 -28
  7. sparknlp/annotation.py +3 -2
  8. sparknlp/annotator/__init__.py +6 -0
  9. sparknlp/annotator/audio/__init__.py +2 -0
  10. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  11. sparknlp/annotator/audio/wav2vec2_for_ctc.py +14 -14
  12. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  13. sparknlp/{base → annotator}/chunk2_doc.py +4 -7
  14. sparknlp/annotator/chunker.py +1 -2
  15. sparknlp/annotator/classifier_dl/__init__.py +17 -0
  16. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  17. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +3 -15
  18. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +4 -18
  19. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +3 -17
  20. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  21. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  22. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  23. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +6 -20
  24. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +3 -17
  25. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +3 -17
  26. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  27. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  28. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +5 -19
  29. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +5 -19
  30. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  31. sparknlp/annotator/classifier_dl/classifier_dl.py +4 -4
  32. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +3 -17
  33. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +4 -19
  34. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +5 -21
  35. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  36. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +3 -17
  37. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +4 -18
  38. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +3 -17
  39. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  40. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  41. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +3 -17
  42. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +4 -18
  43. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +3 -17
  44. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  45. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  46. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  47. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +3 -3
  48. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  49. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +3 -17
  50. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +4 -18
  51. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +1 -1
  52. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  53. sparknlp/annotator/classifier_dl/sentiment_dl.py +4 -4
  54. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +2 -2
  55. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  56. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +3 -17
  57. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +4 -18
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +6 -20
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  60. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +4 -18
  61. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +3 -17
  62. sparknlp/annotator/cleaners/__init__.py +15 -0
  63. sparknlp/annotator/cleaners/cleaner.py +202 -0
  64. sparknlp/annotator/cleaners/extractor.py +191 -0
  65. sparknlp/annotator/coref/spanbert_coref.py +4 -18
  66. sparknlp/annotator/cv/__init__.py +15 -0
  67. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  68. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  69. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  70. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  71. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  72. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  73. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  74. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  75. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  76. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  77. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  78. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  79. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  80. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  81. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  82. sparknlp/annotator/cv/vit_for_image_classification.py +36 -4
  83. sparknlp/annotator/dataframe_optimizer.py +216 -0
  84. sparknlp/annotator/date2_chunk.py +88 -0
  85. sparknlp/annotator/dependency/dependency_parser.py +2 -3
  86. sparknlp/annotator/dependency/typed_dependency_parser.py +3 -4
  87. sparknlp/annotator/document_character_text_splitter.py +228 -0
  88. sparknlp/annotator/document_normalizer.py +37 -1
  89. sparknlp/annotator/document_token_splitter.py +175 -0
  90. sparknlp/annotator/document_token_splitter_test.py +85 -0
  91. sparknlp/annotator/embeddings/__init__.py +11 -0
  92. sparknlp/annotator/embeddings/albert_embeddings.py +4 -18
  93. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  94. sparknlp/annotator/embeddings/bert_embeddings.py +9 -22
  95. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +12 -24
  96. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  97. sparknlp/annotator/embeddings/camembert_embeddings.py +4 -20
  98. sparknlp/annotator/embeddings/chunk_embeddings.py +1 -2
  99. sparknlp/annotator/embeddings/deberta_embeddings.py +2 -16
  100. sparknlp/annotator/embeddings/distil_bert_embeddings.py +5 -19
  101. sparknlp/annotator/embeddings/doc2vec.py +7 -1
  102. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  103. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  104. sparknlp/annotator/embeddings/elmo_embeddings.py +2 -2
  105. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  106. sparknlp/annotator/embeddings/longformer_embeddings.py +3 -17
  107. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  108. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  109. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  110. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  111. sparknlp/annotator/embeddings/roberta_embeddings.py +9 -21
  112. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +7 -21
  113. sparknlp/annotator/embeddings/sentence_embeddings.py +2 -3
  114. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  115. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  116. sparknlp/annotator/embeddings/universal_sentence_encoder.py +3 -3
  117. sparknlp/annotator/embeddings/word2vec.py +7 -1
  118. sparknlp/annotator/embeddings/word_embeddings.py +4 -5
  119. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +9 -21
  120. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +7 -21
  121. sparknlp/annotator/embeddings/xlnet_embeddings.py +4 -18
  122. sparknlp/annotator/er/entity_ruler.py +37 -23
  123. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +2 -3
  124. sparknlp/annotator/ld_dl/language_detector_dl.py +2 -2
  125. sparknlp/annotator/lemmatizer.py +3 -4
  126. sparknlp/annotator/matcher/date_matcher.py +35 -3
  127. sparknlp/annotator/matcher/multi_date_matcher.py +1 -2
  128. sparknlp/annotator/matcher/regex_matcher.py +3 -3
  129. sparknlp/annotator/matcher/text_matcher.py +2 -3
  130. sparknlp/annotator/n_gram_generator.py +1 -2
  131. sparknlp/annotator/ner/__init__.py +3 -1
  132. sparknlp/annotator/ner/ner_converter.py +18 -0
  133. sparknlp/annotator/ner/ner_crf.py +4 -5
  134. sparknlp/annotator/ner/ner_dl.py +10 -5
  135. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  136. sparknlp/annotator/ner/ner_overwriter.py +2 -2
  137. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  138. sparknlp/annotator/normalizer.py +2 -2
  139. sparknlp/annotator/openai/__init__.py +16 -0
  140. sparknlp/annotator/openai/openai_completion.py +349 -0
  141. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  142. sparknlp/annotator/pos/perceptron.py +6 -7
  143. sparknlp/annotator/sentence/sentence_detector.py +2 -2
  144. sparknlp/annotator/sentence/sentence_detector_dl.py +3 -3
  145. sparknlp/annotator/sentiment/sentiment_detector.py +4 -5
  146. sparknlp/annotator/sentiment/vivekn_sentiment.py +4 -5
  147. sparknlp/annotator/seq2seq/__init__.py +17 -0
  148. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  149. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  150. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  151. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  152. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  153. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  154. sparknlp/annotator/seq2seq/gpt2_transformer.py +1 -1
  155. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  156. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  157. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  158. sparknlp/annotator/seq2seq/marian_transformer.py +124 -3
  159. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  160. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  161. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  162. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  163. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  164. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  165. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  166. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  167. sparknlp/annotator/seq2seq/t5_transformer.py +54 -4
  168. sparknlp/annotator/similarity/__init__.py +0 -0
  169. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  170. sparknlp/annotator/spell_check/context_spell_checker.py +116 -17
  171. sparknlp/annotator/spell_check/norvig_sweeting.py +3 -6
  172. sparknlp/annotator/spell_check/symmetric_delete.py +1 -1
  173. sparknlp/annotator/stemmer.py +2 -3
  174. sparknlp/annotator/stop_words_cleaner.py +3 -4
  175. sparknlp/annotator/tf_ner_dl_graph_builder.py +1 -1
  176. sparknlp/annotator/token/__init__.py +0 -1
  177. sparknlp/annotator/token/recursive_tokenizer.py +2 -3
  178. sparknlp/annotator/token/tokenizer.py +2 -3
  179. sparknlp/annotator/ws/word_segmenter.py +35 -10
  180. sparknlp/base/__init__.py +2 -3
  181. sparknlp/base/doc2_chunk.py +0 -3
  182. sparknlp/base/document_assembler.py +5 -5
  183. sparknlp/base/embeddings_finisher.py +14 -2
  184. sparknlp/base/finisher.py +15 -4
  185. sparknlp/base/gguf_ranking_finisher.py +234 -0
  186. sparknlp/base/image_assembler.py +69 -0
  187. sparknlp/base/light_pipeline.py +53 -21
  188. sparknlp/base/multi_document_assembler.py +9 -13
  189. sparknlp/base/prompt_assembler.py +207 -0
  190. sparknlp/base/token_assembler.py +1 -2
  191. sparknlp/common/__init__.py +2 -0
  192. sparknlp/common/annotator_type.py +1 -0
  193. sparknlp/common/completion_post_processing.py +37 -0
  194. sparknlp/common/match_strategy.py +33 -0
  195. sparknlp/common/properties.py +914 -9
  196. sparknlp/internal/__init__.py +841 -116
  197. sparknlp/internal/annotator_java_ml.py +1 -1
  198. sparknlp/internal/annotator_transformer.py +3 -0
  199. sparknlp/logging/comet.py +2 -2
  200. sparknlp/partition/__init__.py +16 -0
  201. sparknlp/partition/partition.py +244 -0
  202. sparknlp/partition/partition_properties.py +902 -0
  203. sparknlp/partition/partition_transformer.py +200 -0
  204. sparknlp/pretrained/pretrained_pipeline.py +1 -1
  205. sparknlp/pretrained/resource_downloader.py +126 -2
  206. sparknlp/reader/__init__.py +15 -0
  207. sparknlp/reader/enums.py +19 -0
  208. sparknlp/reader/pdf_to_text.py +190 -0
  209. sparknlp/reader/reader2doc.py +124 -0
  210. sparknlp/reader/reader2image.py +136 -0
  211. sparknlp/reader/reader2table.py +44 -0
  212. sparknlp/reader/reader_assembler.py +159 -0
  213. sparknlp/reader/sparknlp_reader.py +461 -0
  214. sparknlp/training/__init__.py +1 -0
  215. sparknlp/training/conll.py +8 -2
  216. sparknlp/training/spacy_to_annotation.py +57 -0
  217. sparknlp/util.py +26 -0
  218. spark_nlp-4.2.6.dist-info/METADATA +0 -1256
  219. spark_nlp-4.2.6.dist-info/RECORD +0 -196
  220. {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
  221. /sparknlp/annotator/{token/token2_chunk.py → token2_chunk.py} +0 -0
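Of the 221 entries above, one representative addition is shown in full below: entry 155, sparknlp/annotator/seq2seq/llama2_transformer.py (+343 lines). As orientation, a minimal sketch of bringing the new annotator into scope, assuming it is re-exported by the updated sparknlp/annotator/__init__.py; the local folder path is hypothetical, for illustration only:

    import sparknlp
    from sparknlp.annotator import LLAMA2Transformer

    spark = sparknlp.start()

    # Download from the Models Hub (default name taken from the file below) ...
    llama2 = LLAMA2Transformer.pretrained("llama_2_7b_chat_hf_int4", lang="en")

    # ... or restore a model exported to a local folder (hypothetical path).
    llama2_local = LLAMA2Transformer.loadSavedModel("/models/llama2_exported", spark)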
sparknlp/annotator/seq2seq/llama2_transformer.py
@@ -0,0 +1,343 @@
+ # Copyright 2017-2022 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Contains classes for the LLAMA2Transformer."""
+
+ from sparknlp.common import *
+
+
+ class LLAMA2Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
+     """Llama 2: Open Foundation and Fine-Tuned Chat Models
+
+     The Llama 2 release introduces a family of pretrained and fine-tuned LLMs, ranging in scale
+     from 7B to 70B parameters (7B, 13B, 70B). The pretrained models come with significant
+     improvements over the Llama 1 models, including being trained on 40% more tokens, having a
+     much longer context length (4k tokens 🤯), and using grouped-query attention for fast
+     inference of the 70B model🔥!
+
+     However, the most exciting part of this release is the fine-tuned models (Llama 2-Chat), which
+     have been optimized for dialogue applications using Reinforcement Learning from Human Feedback
+     (RLHF). Across a wide range of helpfulness and safety benchmarks, the Llama 2-Chat models
+     perform better than most open models and achieve comparable performance to ChatGPT according
+     to human evaluations.
+
+     Pretrained models can be loaded with :meth:`.pretrained` of the companion
+     object:
+
+     >>> llama2 = LLAMA2Transformer.pretrained() \\
+     ...     .setInputCols(["document"]) \\
+     ...     .setOutputCol("generation")
+
+
+     The default model is ``"llama_2_7b_chat_hf_int4"``, if no name is provided. For available
+     pretrained models please see the `Models Hub
+     <https://sparknlp.org/models?q=llama2>`__.
+
+     ====================== ======================
+     Input Annotation types Output Annotation type
+     ====================== ======================
+     ``DOCUMENT``           ``DOCUMENT``
+     ====================== ======================
+
+     Parameters
+     ----------
+     configProtoBytes
+         ConfigProto from tensorflow, serialized into byte array.
+     minOutputLength
+         Minimum length of the sequence to be generated, by default 0
+     maxOutputLength
+         Maximum length of output text, by default 20
+     doSample
+         Whether or not to use sampling; use greedy decoding otherwise, by default False
+     temperature
+         The value used to modulate the next token probabilities, by default 0.6
+     topK
+         The number of highest probability vocabulary tokens to keep for
+         top-k-filtering, by default 50
+     topP
+         Top cumulative probability for vocabulary tokens, by default 0.9
+
+         If set to float < 1, only the most probable tokens with probabilities
+         that add up to ``topP`` or higher are kept for generation.
+     repetitionPenalty
+         The parameter for repetition penalty, where 1.0 means no penalty, by default
+         1.0
+     noRepeatNgramSize
+         If set to int > 0, all ngrams of that size can only occur once, by
+         default 0
+     ignoreTokenIds
+         A list of token ids which are ignored in the decoder's output, by
+         default []
+
+     Notes
+     -----
+     This is a very computationally expensive module, especially on longer
+     sequences. The use of an accelerator such as a GPU is recommended.
+
+     References
+     ----------
+     - `Llama 2: Open Foundation and Fine-Tuned Chat Models
+       <https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/>`__
+     - https://github.com/facebookresearch/llama
+
+     **Paper Abstract:**
+
+     *In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned
+     large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our
+     fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models
+     outperform open-source chat models on most benchmarks we tested, and based on our human
+     evaluations for helpfulness and safety, may be a suitable substitute for closed-source models.
+     We provide a detailed description of our approach to fine-tuning and safety improvements of
+     Llama 2-Chat in order to enable the community to build on our work and contribute to the
+     responsible development of LLMs.*
+
+     Examples
+     --------
+     >>> import sparknlp
+     >>> from sparknlp.base import *
+     >>> from sparknlp.annotator import *
+     >>> from pyspark.ml import Pipeline
+     >>> documentAssembler = DocumentAssembler() \\
+     ...     .setInputCol("text") \\
+     ...     .setOutputCol("documents")
+     >>> llama2 = LLAMA2Transformer.pretrained("llama_2_7b_chat_hf_int4") \\
+     ...     .setInputCols(["documents"]) \\
+     ...     .setMaxOutputLength(50) \\
+     ...     .setOutputCol("generation")
+     >>> pipeline = Pipeline().setStages([documentAssembler, llama2])
+     >>> data = spark.createDataFrame([["My name is Leonardo."]]).toDF("text")
+     >>> result = pipeline.fit(data).transform(data)
+     >>> result.select("generation.result").show(truncate=False)
+     +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+     |result                                                                                                                                                                                              |
+     +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+     |[My name is Leonardo. I am a man of letters. I have been a man for many years. I was born in the year 1776. I came to the United States in 1776, and I have lived in the United Kingdom since 1776.]|
+     +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+     """
+
+     name = "LLAMA2Transformer"
+
+     inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
+
+     outputAnnotatorType = AnnotatorType.DOCUMENT
+
+
+     configProtoBytes = Param(Params._dummy(),
+                              "configProtoBytes",
+                              "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                              TypeConverters.toListInt)
+
+     minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
+                             typeConverter=TypeConverters.toInt)
+
+     maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
+                             typeConverter=TypeConverters.toInt)
+
+     doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
+                      typeConverter=TypeConverters.toBoolean)
+
+     temperature = Param(Params._dummy(), "temperature", "The value used to modulate the next token probabilities",
+                         typeConverter=TypeConverters.toFloat)
+
+     topK = Param(Params._dummy(), "topK",
+                  "The number of highest probability vocabulary tokens to keep for top-k-filtering",
+                  typeConverter=TypeConverters.toInt)
+
+     topP = Param(Params._dummy(), "topP",
+                  "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
+                  typeConverter=TypeConverters.toFloat)
+
+     repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
+                               "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
+                               typeConverter=TypeConverters.toFloat)
+
+     noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
+                               "If set to int > 0, all ngrams of that size can only occur once",
+                               typeConverter=TypeConverters.toInt)
+
+     ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
+                            "A list of token ids which are ignored in the decoder's output",
+                            typeConverter=TypeConverters.toListInt)
+
+
+     def setIgnoreTokenIds(self, value):
+         """A list of token ids which are ignored in the decoder's output.
+
+         Parameters
+         ----------
+         value : List[int]
+             The token ids to be filtered out
+         """
+         return self._set(ignoreTokenIds=value)
+
+     def setConfigProtoBytes(self, b):
+         """Sets configProto from tensorflow, serialized into byte array.
+
+         Parameters
+         ----------
+         b : List[int]
+             ConfigProto from tensorflow, serialized into byte array
+         """
+         return self._set(configProtoBytes=b)
+
+     def setMinOutputLength(self, value):
+         """Sets minimum length of the sequence to be generated.
+
+         Parameters
+         ----------
+         value : int
+             Minimum length of the sequence to be generated
+         """
+         return self._set(minOutputLength=value)
+
+     def setMaxOutputLength(self, value):
+         """Sets maximum length of output text.
+
+         Parameters
+         ----------
+         value : int
+             Maximum length of output text
+         """
+         return self._set(maxOutputLength=value)
+
+     def setDoSample(self, value):
+         """Sets whether or not to use sampling; use greedy decoding otherwise.
+
+         Parameters
+         ----------
+         value : bool
+             Whether or not to use sampling; use greedy decoding otherwise
+         """
+         return self._set(doSample=value)
+
+     def setTemperature(self, value):
+         """Sets the value used to modulate the next token probabilities.
+
+         Parameters
+         ----------
+         value : float
+             The value used to modulate the next token probabilities
+         """
+         return self._set(temperature=value)
+
+     def setTopK(self, value):
+         """Sets the number of highest probability vocabulary tokens to keep for
+         top-k-filtering.
+
+         Parameters
+         ----------
+         value : int
+             Number of highest probability vocabulary tokens to keep
+         """
+         return self._set(topK=value)
+
+     def setTopP(self, value):
+         """Sets the top cumulative probability for vocabulary tokens.
+
+         If set to float < 1, only the most probable tokens with probabilities
+         that add up to ``topP`` or higher are kept for generation.
+
+         Parameters
+         ----------
+         value : float
+             Cumulative probability for vocabulary tokens
+         """
+         return self._set(topP=value)
+
+     def setRepetitionPenalty(self, value):
+         """Sets the parameter for repetition penalty. 1.0 means no penalty.
+
+         Parameters
+         ----------
+         value : float
+             The repetition penalty
+
+         References
+         ----------
+         See `Ctrl: A Conditional Transformer Language Model For Controllable
+         Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
+         """
+         return self._set(repetitionPenalty=value)
+
+     def setNoRepeatNgramSize(self, value):
+         """Sets size of n-grams that can only occur once.
+
+         If set to int > 0, all ngrams of that size can only occur once.
+
+         Parameters
+         ----------
+         value : int
+             N-gram size that can only occur once
+         """
+         return self._set(noRepeatNgramSize=value)
+
+     @keyword_only
+     def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Transformer", java_model=None):
+         super(LLAMA2Transformer, self).__init__(
+             classname=classname,
+             java_model=java_model
+         )
+         self._setDefault(
+             minOutputLength=0,
+             maxOutputLength=20,
+             doSample=False,
+             temperature=0.6,
+             topK=50,
+             topP=0.9,
+             repetitionPenalty=1.0,
+             noRepeatNgramSize=0,
+             ignoreTokenIds=[],
+             batchSize=1
+         )
+
+     @staticmethod
+     def loadSavedModel(folder, spark_session, use_openvino=False):
+         """Loads a locally saved model.
+
+         Parameters
+         ----------
+         folder : str
+             Folder of the saved model
+         spark_session : pyspark.sql.SparkSession
+             The current SparkSession
+
+         Returns
+         -------
+         LLAMA2Transformer
+             The restored model
+         """
+         from sparknlp.internal import _LLAMA2Loader
+         jModel = _LLAMA2Loader(folder, spark_session._jsparkSession, use_openvino)._java_obj
+         return LLAMA2Transformer(java_model=jModel)
+
+     @staticmethod
+     def pretrained(name="llama_2_7b_chat_hf_int4", lang="en", remote_loc=None):
+         """Downloads and loads a pretrained model.
+
+         Parameters
+         ----------
+         name : str, optional
+             Name of the pretrained model, by default "llama_2_7b_chat_hf_int4"
+         lang : str, optional
+             Language of the pretrained model, by default "en"
+         remote_loc : str, optional
+             Optional remote address of the resource, by default None. Will use
+             Spark NLP's repositories otherwise.
+
+         Returns
+         -------
+         LLAMA2Transformer
+             The restored model
+         """
+         from sparknlp.pretrained import ResourceDownloader
+         return ResourceDownloader.downloadModel(LLAMA2Transformer, name, lang, remote_loc)
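Since each generation parameter in the file above has a setter that returns self, configuration can be chained fluently. A sketch with illustrative values only (not tuned recommendations); setInputCols and setOutputCol are inherited from the annotator base classes, as the docstring example shows:

    from sparknlp.annotator import LLAMA2Transformer

    # Configure sampling-based generation; every setter below is defined in the file above.
    llama2 = LLAMA2Transformer.pretrained() \
        .setInputCols(["documents"]) \
        .setOutputCol("generation") \
        .setDoSample(True) \
        .setTemperature(0.6) \
        .setTopK(50) \
        .setTopP(0.9) \
        .setRepetitionPenalty(1.1) \
        .setNoRepeatNgramSize(3) \
        .setMinOutputLength(10) \
        .setMaxOutputLength(100)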