spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (329)
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. com/johnsnowlabs/nlp/__init__.py +4 -2
  4. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  5. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  6. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  7. sparknlp/__init__.py +281 -27
  8. sparknlp/annotation.py +137 -6
  9. sparknlp/annotation_audio.py +61 -0
  10. sparknlp/annotation_image.py +82 -0
  11. sparknlp/annotator/__init__.py +93 -0
  12. sparknlp/annotator/audio/__init__.py +16 -0
  13. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  14. sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
  15. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  16. sparknlp/annotator/chunk2_doc.py +85 -0
  17. sparknlp/annotator/chunker.py +137 -0
  18. sparknlp/annotator/classifier_dl/__init__.py +61 -0
  19. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  20. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
  21. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
  22. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
  23. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  24. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  25. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  26. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
  27. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
  28. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
  29. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  30. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  31. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
  32. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
  33. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  34. sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
  35. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
  36. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
  37. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
  38. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  39. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
  40. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
  41. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
  42. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  43. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  44. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
  45. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
  46. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
  47. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  48. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  49. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  50. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
  51. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  52. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
  53. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
  54. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
  55. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  56. sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
  57. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
  60. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
  61. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
  62. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  63. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
  64. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
  65. sparknlp/annotator/cleaners/__init__.py +15 -0
  66. sparknlp/annotator/cleaners/cleaner.py +202 -0
  67. sparknlp/annotator/cleaners/extractor.py +191 -0
  68. sparknlp/annotator/coref/__init__.py +1 -0
  69. sparknlp/annotator/coref/spanbert_coref.py +221 -0
  70. sparknlp/annotator/cv/__init__.py +29 -0
  71. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  72. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  73. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  74. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  75. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  76. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  77. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  78. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  79. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  80. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  81. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  82. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  83. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  84. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  85. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  86. sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
  87. sparknlp/annotator/dataframe_optimizer.py +216 -0
  88. sparknlp/annotator/date2_chunk.py +88 -0
  89. sparknlp/annotator/dependency/__init__.py +17 -0
  90. sparknlp/annotator/dependency/dependency_parser.py +294 -0
  91. sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
  92. sparknlp/annotator/document_character_text_splitter.py +228 -0
  93. sparknlp/annotator/document_normalizer.py +235 -0
  94. sparknlp/annotator/document_token_splitter.py +175 -0
  95. sparknlp/annotator/document_token_splitter_test.py +85 -0
  96. sparknlp/annotator/embeddings/__init__.py +45 -0
  97. sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
  98. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  99. sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
  100. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
  101. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  102. sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
  103. sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
  104. sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
  105. sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
  106. sparknlp/annotator/embeddings/doc2vec.py +352 -0
  107. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  108. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  109. sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
  110. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  111. sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
  112. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  113. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  114. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  115. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  116. sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
  117. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
  118. sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
  119. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  120. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  121. sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
  122. sparknlp/annotator/embeddings/word2vec.py +353 -0
  123. sparknlp/annotator/embeddings/word_embeddings.py +385 -0
  124. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
  125. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
  126. sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
  127. sparknlp/annotator/er/__init__.py +16 -0
  128. sparknlp/annotator/er/entity_ruler.py +267 -0
  129. sparknlp/annotator/graph_extraction.py +368 -0
  130. sparknlp/annotator/keyword_extraction/__init__.py +16 -0
  131. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
  132. sparknlp/annotator/ld_dl/__init__.py +16 -0
  133. sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
  134. sparknlp/annotator/lemmatizer.py +250 -0
  135. sparknlp/annotator/matcher/__init__.py +20 -0
  136. sparknlp/annotator/matcher/big_text_matcher.py +272 -0
  137. sparknlp/annotator/matcher/date_matcher.py +303 -0
  138. sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
  139. sparknlp/annotator/matcher/regex_matcher.py +221 -0
  140. sparknlp/annotator/matcher/text_matcher.py +290 -0
  141. sparknlp/annotator/n_gram_generator.py +141 -0
  142. sparknlp/annotator/ner/__init__.py +21 -0
  143. sparknlp/annotator/ner/ner_approach.py +94 -0
  144. sparknlp/annotator/ner/ner_converter.py +148 -0
  145. sparknlp/annotator/ner/ner_crf.py +397 -0
  146. sparknlp/annotator/ner/ner_dl.py +591 -0
  147. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  148. sparknlp/annotator/ner/ner_overwriter.py +166 -0
  149. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  150. sparknlp/annotator/normalizer.py +230 -0
  151. sparknlp/annotator/openai/__init__.py +16 -0
  152. sparknlp/annotator/openai/openai_completion.py +349 -0
  153. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  154. sparknlp/annotator/param/__init__.py +17 -0
  155. sparknlp/annotator/param/classifier_encoder.py +98 -0
  156. sparknlp/annotator/param/evaluation_dl_params.py +130 -0
  157. sparknlp/annotator/pos/__init__.py +16 -0
  158. sparknlp/annotator/pos/perceptron.py +263 -0
  159. sparknlp/annotator/sentence/__init__.py +17 -0
  160. sparknlp/annotator/sentence/sentence_detector.py +290 -0
  161. sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
  162. sparknlp/annotator/sentiment/__init__.py +17 -0
  163. sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
  164. sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
  165. sparknlp/annotator/seq2seq/__init__.py +35 -0
  166. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  167. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  168. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  169. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  170. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  171. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  172. sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
  173. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  174. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  175. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  176. sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
  177. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  178. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  179. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  180. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  181. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  182. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  183. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  184. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  185. sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
  186. sparknlp/annotator/similarity/__init__.py +0 -0
  187. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  188. sparknlp/annotator/spell_check/__init__.py +18 -0
  189. sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
  190. sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
  191. sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
  192. sparknlp/annotator/stemmer.py +79 -0
  193. sparknlp/annotator/stop_words_cleaner.py +190 -0
  194. sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
  195. sparknlp/annotator/token/__init__.py +19 -0
  196. sparknlp/annotator/token/chunk_tokenizer.py +118 -0
  197. sparknlp/annotator/token/recursive_tokenizer.py +205 -0
  198. sparknlp/annotator/token/regex_tokenizer.py +208 -0
  199. sparknlp/annotator/token/tokenizer.py +561 -0
  200. sparknlp/annotator/token2_chunk.py +76 -0
  201. sparknlp/annotator/ws/__init__.py +16 -0
  202. sparknlp/annotator/ws/word_segmenter.py +429 -0
  203. sparknlp/base/__init__.py +30 -0
  204. sparknlp/base/audio_assembler.py +95 -0
  205. sparknlp/base/doc2_chunk.py +169 -0
  206. sparknlp/base/document_assembler.py +164 -0
  207. sparknlp/base/embeddings_finisher.py +201 -0
  208. sparknlp/base/finisher.py +217 -0
  209. sparknlp/base/gguf_ranking_finisher.py +234 -0
  210. sparknlp/base/graph_finisher.py +125 -0
  211. sparknlp/base/has_recursive_fit.py +24 -0
  212. sparknlp/base/has_recursive_transform.py +22 -0
  213. sparknlp/base/image_assembler.py +172 -0
  214. sparknlp/base/light_pipeline.py +429 -0
  215. sparknlp/base/multi_document_assembler.py +164 -0
  216. sparknlp/base/prompt_assembler.py +207 -0
  217. sparknlp/base/recursive_pipeline.py +107 -0
  218. sparknlp/base/table_assembler.py +145 -0
  219. sparknlp/base/token_assembler.py +124 -0
  220. sparknlp/common/__init__.py +26 -0
  221. sparknlp/common/annotator_approach.py +41 -0
  222. sparknlp/common/annotator_model.py +47 -0
  223. sparknlp/common/annotator_properties.py +114 -0
  224. sparknlp/common/annotator_type.py +38 -0
  225. sparknlp/common/completion_post_processing.py +37 -0
  226. sparknlp/common/coverage_result.py +22 -0
  227. sparknlp/common/match_strategy.py +33 -0
  228. sparknlp/common/properties.py +1298 -0
  229. sparknlp/common/read_as.py +33 -0
  230. sparknlp/common/recursive_annotator_approach.py +35 -0
  231. sparknlp/common/storage.py +149 -0
  232. sparknlp/common/utils.py +39 -0
  233. sparknlp/functions.py +315 -5
  234. sparknlp/internal/__init__.py +1199 -0
  235. sparknlp/internal/annotator_java_ml.py +32 -0
  236. sparknlp/internal/annotator_transformer.py +37 -0
  237. sparknlp/internal/extended_java_wrapper.py +63 -0
  238. sparknlp/internal/params_getters_setters.py +71 -0
  239. sparknlp/internal/recursive.py +70 -0
  240. sparknlp/logging/__init__.py +15 -0
  241. sparknlp/logging/comet.py +467 -0
  242. sparknlp/partition/__init__.py +16 -0
  243. sparknlp/partition/partition.py +244 -0
  244. sparknlp/partition/partition_properties.py +902 -0
  245. sparknlp/partition/partition_transformer.py +200 -0
  246. sparknlp/pretrained/__init__.py +17 -0
  247. sparknlp/pretrained/pretrained_pipeline.py +158 -0
  248. sparknlp/pretrained/resource_downloader.py +216 -0
  249. sparknlp/pretrained/utils.py +35 -0
  250. sparknlp/reader/__init__.py +15 -0
  251. sparknlp/reader/enums.py +19 -0
  252. sparknlp/reader/pdf_to_text.py +190 -0
  253. sparknlp/reader/reader2doc.py +124 -0
  254. sparknlp/reader/reader2image.py +136 -0
  255. sparknlp/reader/reader2table.py +44 -0
  256. sparknlp/reader/reader_assembler.py +159 -0
  257. sparknlp/reader/sparknlp_reader.py +461 -0
  258. sparknlp/training/__init__.py +20 -0
  259. sparknlp/training/_tf_graph_builders/__init__.py +0 -0
  260. sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
  261. sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
  262. sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
  263. sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
  264. sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
  265. sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
  266. sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
  267. sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
  268. sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
  269. sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
  270. sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
  271. sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
  272. sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
  273. sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
  274. sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
  275. sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
  276. sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
  277. sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
  278. sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
  279. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
  280. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
  281. sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
  282. sparknlp/training/conll.py +150 -0
  283. sparknlp/training/conllu.py +103 -0
  284. sparknlp/training/pos.py +103 -0
  285. sparknlp/training/pub_tator.py +76 -0
  286. sparknlp/training/spacy_to_annotation.py +57 -0
  287. sparknlp/training/tfgraphs.py +5 -0
  288. sparknlp/upload_to_hub.py +149 -0
  289. sparknlp/util.py +51 -5
  290. com/__init__.pyc +0 -0
  291. com/__pycache__/__init__.cpython-36.pyc +0 -0
  292. com/johnsnowlabs/__init__.pyc +0 -0
  293. com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
  294. com/johnsnowlabs/nlp/__init__.pyc +0 -0
  295. com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
  296. spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
  297. spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
  298. sparknlp/__init__.pyc +0 -0
  299. sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
  300. sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
  301. sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
  302. sparknlp/__pycache__/base.cpython-36.pyc +0 -0
  303. sparknlp/__pycache__/common.cpython-36.pyc +0 -0
  304. sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
  305. sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
  306. sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
  307. sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
  308. sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
  309. sparknlp/__pycache__/training.cpython-36.pyc +0 -0
  310. sparknlp/__pycache__/util.cpython-36.pyc +0 -0
  311. sparknlp/annotation.pyc +0 -0
  312. sparknlp/annotator.py +0 -3006
  313. sparknlp/annotator.pyc +0 -0
  314. sparknlp/base.py +0 -347
  315. sparknlp/base.pyc +0 -0
  316. sparknlp/common.py +0 -193
  317. sparknlp/common.pyc +0 -0
  318. sparknlp/embeddings.py +0 -40
  319. sparknlp/embeddings.pyc +0 -0
  320. sparknlp/internal.py +0 -288
  321. sparknlp/internal.pyc +0 -0
  322. sparknlp/pretrained.py +0 -123
  323. sparknlp/pretrained.pyc +0 -0
  324. sparknlp/storage.py +0 -32
  325. sparknlp/storage.pyc +0 -0
  326. sparknlp/training.py +0 -62
  327. sparknlp/training.pyc +0 -0
  328. sparknlp/util.pyc +0 -0
  329. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1298 @@
1
+ # Copyright 2017-2022 John Snow Labs
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains classes for Annotator properties."""
15
+ from typing import List, Dict
16
+
17
+ from pyspark.ml.param import Param, Params, TypeConverters
18
+
19
+
20
class HasBatchedAnnotate:
    """Mixin that adds a ``batchSize`` parameter to batched annotators."""

    batchSize = Param(
        parent=Params._dummy(),
        name="batchSize",
        doc="Size of every batch",
        typeConverter=TypeConverters.toInt,
    )

    def getBatchSize(self):
        """Return the batch size currently in effect.

        Returns
        -------
        int
            The explicitly set value of ``batchSize``, or its default.
        """
        return self.getOrDefault(self.batchSize)

    def setBatchSize(self, v):
        """Configure the size of each annotation batch.

        Parameters
        ----------
        v : int
            Desired batch size.
        """
        return self._set(batchSize=v)
42
+
43
+
44
class HasCaseSensitiveProperties:
    """Mixin that adds the ``caseSensitive`` flag used for embeddings matching.

    NOTE(review): the param's runtime description says "whether to ignore
    case", which reads as the inverse of the flag's name. The method docs
    below describe the flag by its name only; confirm the exact semantics
    against the Scala implementation before relying on either wording.
    """

    caseSensitive = Param(Params._dummy(),
                          "caseSensitive",
                          "whether to ignore case in tokens for embeddings matching",
                          typeConverter=TypeConverters.toBoolean)

    def setCaseSensitive(self, value):
        """Sets the ``caseSensitive`` flag used when matching tokens to embeddings.

        Parameters
        ----------
        value : bool
            New value of the ``caseSensitive`` flag.
        """
        return self._set(caseSensitive=value)

    def getCaseSensitive(self):
        """Gets the ``caseSensitive`` flag used when matching tokens to embeddings.

        Returns
        -------
        bool
            Current (set or default) value of the ``caseSensitive`` flag.
        """
        return self.getOrDefault(self.caseSensitive)
69
+
70
+
71
class HasClsTokenProperties:
    """Mixin that adds the ``useCLSToken`` pooling-strategy flag."""

    useCLSToken = Param(Params._dummy(),
                        "useCLSToken",
                        "Whether to use CLS token for pooling (true) or attention-based average pooling (false)",
                        typeConverter=TypeConverters.toBoolean)

    def setUseCLSToken(self, value):
        """Sets whether to use the CLS token for pooling (true) or
        attention-based average pooling (false).

        Parameters
        ----------
        value : bool
            Whether to use CLS token for pooling (true) or attention-based average pooling (false)
        """
        return self._set(useCLSToken=value)

    def getUseCLSToken(self):
        """Gets whether to use the CLS token for pooling (true) or
        attention-based average pooling (false).

        Returns
        -------
        bool
            Whether to use CLS token for pooling (true) or attention-based average pooling (false)
        """
        return self.getOrDefault(self.useCLSToken)
96
+
97
+
98
class HasClassifierActivationProperties:
    """Mixin that adds classifier output params: ``activation``,
    ``multilabel`` and ``threshold``."""

    activation = Param(Params._dummy(),
                       "activation",
                       "Whether to calculate logits via Softmax or Sigmoid. Default is Softmax",
                       typeConverter=TypeConverters.toString)

    multilabel = Param(Params._dummy(),
                       "multilabel",
                       "Whether to calculate logits via Multiclass(softmax) or Multilabel(sigmoid). Default is False i.e. Multiclass",
                       typeConverter=TypeConverters.toBoolean)

    threshold = Param(Params._dummy(),
                      "threshold",
                      "Choose the threshold to determine which logits are considered to be positive or negative",
                      typeConverter=TypeConverters.toFloat)

    def setActivation(self, value):
        """Sets whether to calculate logits via Softmax or Sigmoid.

        Documented default: Softmax.

        Parameters
        ----------
        value : str
            Activation to apply to the logits (Softmax or Sigmoid).
            NOTE(review): the exact accepted spellings are validated
            elsewhere, not in this mixin.
        """
        return self._set(activation=value)

    def getActivation(self):
        """Gets whether logits are calculated via Softmax or Sigmoid.

        Returns
        -------
        str
            The configured activation (documented default: Softmax).
        """
        return self.getOrDefault(self.activation)

    def setMultilabel(self, value):
        """Sets whether the result should be multi-class or multi-label.

        Multi-class means the sum of all probabilities is 1.0. Multi-label
        means each label has its own probability between 0.0 and 1.0.
        Documented default: ``False`` (multi-class).

        Parameters
        ----------
        value : bool
            ``True`` for multi-label output, ``False`` for multi-class output.
        """
        return self._set(multilabel=value)

    def getMultilabel(self):
        """Gets whether the result is multi-class or multi-label.

        Multi-class means the sum of all probabilities is 1.0. Multi-label
        means each label has its own probability between 0.0 and 1.0.

        Returns
        -------
        bool
            ``True`` if multi-label output is configured, ``False`` for
            multi-class (documented default).
        """
        return self.getOrDefault(self.multilabel)

    def setThreshold(self, value):
        """Sets the threshold that decides which logits count as positive or negative.

        The value should be between 0.0 and 1.0 (documented default: ``0.5``).
        Changing the threshold affects the resulting labels and can be used
        to adjust the balance between precision and recall.

        Parameters
        ----------
        value : float
            The threshold to determine which logits are considered to be
            positive or negative.
        """
        return self._set(threshold=value)

    def getThreshold(self):
        """Gets the threshold that decides which logits count as positive or negative.

        Returns
        -------
        float
            The current (set or default) threshold.
        """
        return self.getOrDefault(self.threshold)
177
+
178
+
179
class HasEmbeddingsProperties(Params):
    """Params mixin that holds the embeddings ``dimension``."""

    dimension = Param(
        parent=Params._dummy(),
        name="dimension",
        doc="Number of embedding dimensions",
        typeConverter=TypeConverters.toInt,
    )

    def getDimension(self):
        """Return the number of embedding dimensions.

        Returns
        -------
        int
            Current (set or default) value of ``dimension``.
        """
        return self.getOrDefault(self.dimension)

    def setDimension(self, value):
        """Configure the number of embedding dimensions.

        Parameters
        ----------
        value : int
            Embeddings dimension to store.
        """
        return self._set(dimension=value)
198
+
199
+
200
class HasEnableCachingProperties:
    """Mixin that adds the ``enableCaching`` training flag."""

    enableCaching = Param(
        parent=Params._dummy(),
        name="enableCaching",
        doc="Whether to enable caching DataFrames or RDDs during the training",
        typeConverter=TypeConverters.toBoolean,
    )

    def getEnableCaching(self):
        """Report whether DataFrames/RDDs are cached during training.

        Returns
        -------
        bool
            Current (set or default) value of ``enableCaching``.
        """
        return self.getOrDefault(self.enableCaching)

    def setEnableCaching(self, value):
        """Turn caching of DataFrames/RDDs during training on or off.

        Parameters
        ----------
        value : bool
            ``True`` to enable caching during training.
        """
        return self._set(enableCaching=value)
225
+
226
+
227
class HasBatchedAnnotateImage:
    """Mixin that adds a ``batchSize`` parameter to batched image annotators."""

    batchSize = Param(
        parent=Params._dummy(),
        name="batchSize",
        doc="Size of every batch",
        typeConverter=TypeConverters.toInt,
    )

    def getBatchSize(self):
        """Return the batch size currently in effect.

        Returns
        -------
        int
            The explicitly set value of ``batchSize``, or its default.
        """
        return self.getOrDefault(self.batchSize)

    def setBatchSize(self, v):
        """Configure the size of each annotation batch.

        Parameters
        ----------
        v : int
            Desired batch size.
        """
        return self._set(batchSize=v)
249
+
250
+
251
class HasImageFeatureProperties:
    """Mixin that adds image preprocessing params: resizing, normalization,
    feature-extractor type, channel mean/std, resampling filter and size."""

    doResize = Param(Params._dummy(), "doResize", "Whether to resize the input to a certain size",
                     TypeConverters.toBoolean)

    doNormalize = Param(Params._dummy(), "doNormalize",
                        "Whether to normalize the input with mean and standard deviation",
                        TypeConverters.toBoolean)

    featureExtractorType = Param(Params._dummy(), "featureExtractorType",
                                 "Name of model's architecture for feature extraction",
                                 TypeConverters.toString)

    imageMean = Param(Params._dummy(), "imageMean",
                      "The sequence of means for each channel, to be used when normalizing images",
                      TypeConverters.toListFloat)

    imageStd = Param(Params._dummy(), "imageStd",
                     "The sequence of standard deviations for each channel, to be used when normalizing images",
                     TypeConverters.toListFloat)

    resample = Param(Params._dummy(), "resample",
                     "An optional resampling filter. This can be one of PIL.Image.NEAREST, PIL.Image.BILINEAR or "
                     "PIL.Image.BICUBIC. Only has an effect if do_resize is set to True.",
                     TypeConverters.toInt)

    # NOTE: the description below mentions a (width, height) tuple, but the
    # converter is TypeConverters.toInt, so only a single integer is accepted.
    size = Param(Params._dummy(), "size",
                 "Resize the input to the given size. If a tuple is provided, it should be (width, height). If only "
                 "an integer is provided, then the input will be resized to (size, size). Only has an effect if "
                 "do_resize is set to True.",
                 TypeConverters.toInt)

    def setDoResize(self, value):
        """Sets whether to resize the input to a certain size.

        Parameters
        ----------
        value : bool
            Whether to resize the input to a certain size
        """
        return self._set(doResize=value)

    def setDoNormalize(self, value):
        """Sets whether to normalize the input with mean and standard deviation.

        Parameters
        ----------
        value : bool
            Whether to normalize the input with mean and standard deviation
        """
        return self._set(doNormalize=value)

    def setFeatureExtractorType(self, value):
        """Sets the name of the model architecture used for feature extraction.

        Parameters
        ----------
        value : str
            Name of model's architecture for feature extraction
        """
        return self._set(featureExtractorType=value)

    def setImageStd(self, value):
        """Sets the per-channel standard deviations used when normalizing images.

        Parameters
        ----------
        value : List[float]
            The sequence of standard deviations for each channel, to be used when normalizing images
        """
        return self._set(imageStd=value)

    def setImageMean(self, value):
        """Sets the per-channel means used when normalizing images.

        Parameters
        ----------
        value : List[float]
            The sequence of means for each channel, to be used when normalizing images
        """
        return self._set(imageMean=value)

    def setResample(self, value):
        """Sets the resampling filter used when resizing.

        Parameters
        ----------
        value : int
            Resampling filter for resizing. This can be one of `PIL.Image.NEAREST`,
            `PIL.Image.BILINEAR` or `PIL.Image.BICUBIC`. Only has an effect if
            ``doResize`` is set to `True`.
        """
        return self._set(resample=value)

    def setSize(self, value):
        """Sets the size the input is resized to.

        Parameters
        ----------
        value : int
            Target size. Per the param description, the input is resized to
            ``(size, size)``, which only has an effect if ``doResize`` is
            `True`. The value goes through ``TypeConverters.toInt``, so a
            ``(width, height)`` tuple is rejected by this setter.
        """
        return self._set(size=value)
352
+
353
+
354
class HasRescaleFactor:
    """Mixin that adds image rescaling params: ``doRescale`` and ``rescaleFactor``."""

    doRescale = Param(Params._dummy(), "doRescale",
                      "Whether to rescale the image values by rescaleFactor.",
                      TypeConverters.toBoolean)

    rescaleFactor = Param(Params._dummy(), "rescaleFactor",
                          "Factor to scale the image values",
                          TypeConverters.toFloat)

    def setDoRescale(self, value):
        """Sets whether to rescale the image values by ``rescaleFactor``.

        Documented default: `True`. The default itself is set by the
        concrete annotator, not in this mixin.

        Parameters
        ----------
        value : bool
            Whether to rescale the image values by rescaleFactor.
        """
        return self._set(doRescale=value)

    def setRescaleFactor(self, value):
        """Sets the factor used to scale the image values.

        Documented default: `1/255.0`. The default itself is set by the
        concrete annotator, not in this mixin. The factor is only applied
        when ``doRescale`` is enabled.

        Parameters
        ----------
        value : float
            Factor to scale the image values.
        """
        return self._set(rescaleFactor=value)
382
+
383
+
384
class HasBatchedAnnotateAudio:
    """Mixin exposing a ``batchSize`` parameter for batched audio annotation."""

    batchSize = Param(
        Params._dummy(),
        "batchSize",
        "Size of every batch",
        typeConverter=TypeConverters.toInt,
    )

    def setBatchSize(self, v):
        """Sets the number of items processed per batch.

        Parameters
        ----------
        v : int
            Batch size
        """
        return self._set(batchSize=v)

    def getBatchSize(self):
        """Gets the configured batch size (or its default).

        Returns
        -------
        int
            Current batch size
        """
        return self.getOrDefault(self.batchSize)
406
+
407
+
408
class HasAudioFeatureProperties:
    """Mixin holding the audio feature-extraction parameters."""

    doNormalize = Param(Params._dummy(), "doNormalize",
                        "Whether to normalize the input",
                        typeConverter=TypeConverters.toBoolean)

    returnAttentionMask = Param(Params._dummy(), "returnAttentionMask", "",
                                typeConverter=TypeConverters.toBoolean)

    paddingSide = Param(Params._dummy(), "paddingSide", "",
                        typeConverter=TypeConverters.toString)

    featureSize = Param(Params._dummy(), "featureSize", "",
                        typeConverter=TypeConverters.toInt)

    samplingRate = Param(Params._dummy(), "samplingRate", "",
                         typeConverter=TypeConverters.toInt)

    paddingValue = Param(Params._dummy(), "paddingValue", "",
                         typeConverter=TypeConverters.toFloat)

    def setDoNormalize(self, value):
        """Sets whether the input is normalized with mean and standard deviation.

        Parameters
        ----------
        value : bool
            Whether to normalize the input with mean and standard deviation
        """
        return self._set(doNormalize=value)

    def setReturnAttentionMask(self, value):
        """Sets whether an attention mask is returned.

        Parameters
        ----------
        value : bool
            Whether to return the attention mask
        """
        return self._set(returnAttentionMask=value)

    def setPaddingSide(self, value):
        """Sets the side on which the input is padded.

        Parameters
        ----------
        value : str
            Padding side (presumably ``"left"`` or ``"right"`` -- confirm)
        """
        return self._set(paddingSide=value)

    def setFeatureSize(self, value):
        """Sets the feature size of the extracted features.

        Parameters
        ----------
        value : int
            Feature size
        """
        return self._set(featureSize=value)

    def setSamplingRate(self, value):
        """Sets the sampling rate of the input audio.

        Parameters
        ----------
        value : int
            Sampling rate (presumably in Hz -- confirm)
        """
        return self._set(samplingRate=value)

    def setPaddingValue(self, value):
        """Sets the value used to fill padded positions.

        Parameters
        ----------
        value : float
            Padding value
        """
        return self._set(paddingValue=value)
488
+
489
+
490
class HasEngine:
    """Mixin exposing the read-only ``engine`` parameter."""

    engine = Param(
        Params._dummy(),
        "engine",
        "Deep Learning engine used for this model",
        typeConverter=TypeConverters.toString,
    )

    def getEngine(self):
        """Gets the Deep Learning engine used for this model.

        Returns
        -------
        str
            Deep Learning engine used for this model
        """
        return self.getOrDefault(self.engine)
504
+
505
+
506
class HasCandidateLabelsProperties:
    """Mixin holding the candidate labels and label-id parameters for zero-shot classification."""

    # Param docs previously were copy-paste errors ("Deep Learning engine used for
    # this model" / "contradictionIdParam"); they now describe the right parameter.
    candidateLabels = Param(Params._dummy(), "candidateLabels",
                            "Candidate labels for classification",
                            typeConverter=TypeConverters.toListString)

    contradictionIdParam = Param(Params._dummy(), "contradictionIdParam",
                                 "contradictionIdParam",
                                 typeConverter=TypeConverters.toInt)

    entailmentIdParam = Param(Params._dummy(), "entailmentIdParam",
                              "entailmentIdParam",
                              typeConverter=TypeConverters.toInt)

    def setCandidateLabels(self, v):
        """Sets the candidate labels.

        Parameters
        ----------
        v : list[str]
            Candidate labels for classification
        """
        return self._set(candidateLabels=v)

    def setContradictionIdParam(self, v):
        """Sets contradictionIdParam.

        Parameters
        ----------
        v : int
            Value for contradictionIdParam
        """
        return self._set(contradictionIdParam=v)

    def setEntailmentIdParam(self, v):
        """Sets entailmentIdParam.

        Parameters
        ----------
        v : int
            Value for entailmentIdParam
        """
        return self._set(entailmentIdParam=v)
548
+
549
+
550
class HasMaxSentenceLengthLimit:
    """Mixin providing a ``maxSentenceLength`` parameter bounded by ``max_length_limit``."""

    # Default Value, can be overridden
    max_length_limit = 512

    maxSentenceLength = Param(Params._dummy(),
                              "maxSentenceLength",
                              "Max sentence length to process",
                              typeConverter=TypeConverters.toInt)

    def setMaxSentenceLength(self, value):
        """Sets max sentence length to process.

        Note that a maximum limit exists depending on the model. If you are working with long single
        sequences, consider splitting up the input first with another annotator e.g. SentenceDetector.

        Parameters
        ----------
        value : int
            Max sentence length to process, between 1 and ``max_length_limit``

        Raises
        ------
        ValueError
            If ``value`` exceeds ``max_length_limit`` or is smaller than 1.
        """
        if value > self.max_length_limit:
            raise ValueError(
                f"{self.__class__.__name__} models do not support token sequences longer than {self.max_length_limit}.\n"
                f"Consider splitting up the input first with another annotator e.g. SentenceDetector.")
        # A non-positive length cannot process any tokens; fail fast instead of
        # storing a value that is meaningless downstream.
        if value < 1:
            raise ValueError(f"maxSentenceLength must be at least 1, got {value}.")
        return self._set(maxSentenceLength=value)

    def getMaxSentenceLength(self):
        """Gets max sentence length of the model.

        Returns
        -------
        int
            Max sentence length to process
        """
        return self.getOrDefault("maxSentenceLength")
585
+
586
+
587
class HasLongMaxSentenceLengthLimit(HasMaxSentenceLengthLimit):
    """Variant of ``HasMaxSentenceLengthLimit`` for models accepting longer sequences."""

    # Raises the upper bound checked by setMaxSentenceLength.
    max_length_limit = 4096
589
+
590
+
591
class HasGeneratorProperties:
    """Mixin holding text-generation parameters (sampling, beam search, length limits)."""

    # The doc previously read "summarize>" -- a typo for the "summarize:" prefix.
    task = Param(Params._dummy(), "task", "Transformer's task, e.g. summarize:", typeConverter=TypeConverters.toString)

    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
                            typeConverter=TypeConverters.toInt)

    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
                            typeConverter=TypeConverters.toInt)

    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
                     typeConverter=TypeConverters.toBoolean)

    temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities",
                        typeConverter=TypeConverters.toFloat)

    topK = Param(Params._dummy(), "topK",
                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
                 typeConverter=TypeConverters.toInt)

    topP = Param(Params._dummy(), "topP",
                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
                 typeConverter=TypeConverters.toFloat)

    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
                              typeConverter=TypeConverters.toFloat)

    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
                              "If set to int > 0, all ngrams of that size can only occur once",
                              typeConverter=TypeConverters.toInt)

    beamSize = Param(Params._dummy(), "beamSize",
                     "The Number of beams for beam search.",
                     typeConverter=TypeConverters.toInt)

    nReturnSequences = Param(Params._dummy(),
                             "nReturnSequences",
                             "The number of sequences to return from the beam search.",
                             typeConverter=TypeConverters.toInt)

    def setTask(self, value):
        """Sets the transformer's task, e.g. ``summarize:``.

        Parameters
        ----------
        value : str
            The transformer's task
        """
        return self._set(task=value)

    def setMinOutputLength(self, value):
        """Sets minimum length of the sequence to be generated.

        Parameters
        ----------
        value : int
            Minimum length of the sequence to be generated
        """
        return self._set(minOutputLength=value)

    def setMaxOutputLength(self, value):
        """Sets maximum length of output text.

        Parameters
        ----------
        value : int
            Maximum length of output text
        """
        return self._set(maxOutputLength=value)

    def setDoSample(self, value):
        """Sets whether or not to use sampling, use greedy decoding otherwise.

        Parameters
        ----------
        value : bool
            Whether or not to use sampling; use greedy decoding otherwise
        """
        return self._set(doSample=value)

    def setTemperature(self, value):
        """Sets the value used to modulate the next token probabilities.

        Parameters
        ----------
        value : float
            The value used to modulate the next token probabilities
        """
        return self._set(temperature=value)

    def setTopK(self, value):
        """Sets the number of highest probability vocabulary tokens to keep for
        top-k-filtering.

        Parameters
        ----------
        value : int
            Number of highest probability vocabulary tokens to keep
        """
        return self._set(topK=value)

    def setTopP(self, value):
        """Sets the top cumulative probability for vocabulary tokens.

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.

        Parameters
        ----------
        value : float
            Cumulative probability for vocabulary tokens
        """
        return self._set(topP=value)

    def setRepetitionPenalty(self, value):
        """Sets the parameter for repetition penalty. 1.0 means no penalty.

        Parameters
        ----------
        value : float
            The repetition penalty

        References
        ----------
        See `Ctrl: A Conditional Transformer Language Model For Controllable
        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        """
        return self._set(repetitionPenalty=value)

    def setNoRepeatNgramSize(self, value):
        """Sets size of n-grams that can only occur once.

        If set to int > 0, all ngrams of that size can only occur once.

        Parameters
        ----------
        value : int
            N-gram size can only occur once
        """
        return self._set(noRepeatNgramSize=value)

    def setBeamSize(self, value):
        """Sets the number of beams for beam search.

        Parameters
        ----------
        value : int
            Number of beams for beam search
        """
        return self._set(beamSize=value)

    def setNReturnSequences(self, value):
        """Sets the number of sequences to return from the beam search.

        Parameters
        ----------
        value : int
            Number of sequences to return
        """
        return self._set(nReturnSequences=value)
751
+
752
+
753
+ class HasLlamaCppProperties:
754
+ # -------- MODEl PARAMETERS --------
755
+ nThreads = Param(Params._dummy(), "nThreads", "Set the number of threads to use during generation",
756
+ typeConverter=TypeConverters.toInt)
757
+ # nThreadsDraft = Param(Params._dummy(), "nThreadsDraft", "Set the number of threads to use during draft generation",
758
+ # typeConverter=TypeConverters.toInt)
759
+ nThreadsBatch = Param(Params._dummy(), "nThreadsBatch",
760
+ "Set the number of threads to use during batch and prompt processing",
761
+ typeConverter=TypeConverters.toInt)
762
+ # nThreadsBatchDraft = Param(Params._dummy(), "nThreadsBatchDraft",
763
+ # "Set the number of threads to use during batch and prompt processing",
764
+ # typeConverter=TypeConverters.toInt)
765
+ nCtx = Param(Params._dummy(), "nCtx", "Set the size of the prompt context", typeConverter=TypeConverters.toInt)
766
+ nBatch = Param(Params._dummy(), "nBatch",
767
+ "Set the logical batch size for prompt processing (must be >=32 to use BLAS)",
768
+ typeConverter=TypeConverters.toInt)
769
+ nUbatch = Param(Params._dummy(), "nUbatch",
770
+ "Set the physical batch size for prompt processing (must be >=32 to use BLAS)",
771
+ typeConverter=TypeConverters.toInt)
772
+ nDraft = Param(Params._dummy(), "nDraft", "Set the number of tokens to draft for speculative decoding",
773
+ typeConverter=TypeConverters.toInt)
774
+ # nChunks = Param(Params._dummy(), "nChunks", "Set the maximal number of chunks to process",
775
+ # typeConverter=TypeConverters.toInt)
776
+ # nSequences = Param(Params._dummy(), "nSequences", "Set the number of sequences to decode",
777
+ # typeConverter=TypeConverters.toInt)
778
+ # pSplit = Param(Params._dummy(), "pSplit", "Set the speculative decoding split probability",
779
+ # typeConverter=TypeConverters.toFloat)
780
+ nGpuLayers = Param(Params._dummy(), "nGpuLayers", "Set the number of layers to store in VRAM (-1 - use default)",
781
+ typeConverter=TypeConverters.toInt)
782
+ nGpuLayersDraft = Param(Params._dummy(), "nGpuLayersDraft",
783
+ "Set the number of layers to store in VRAM for the draft model (-1 - use default)",
784
+ typeConverter=TypeConverters.toInt)
785
+ # Set how to split the model across GPUs
786
+ #
787
+ # - NONE: No GPU split
788
+ # - LAYER: Split the model across GPUs by layer
789
+ # - ROW: Split the model across GPUs by rows
790
+ gpuSplitMode = Param(Params._dummy(), "gpuSplitMode", "Set how to split the model across GPUs",
791
+ typeConverter=TypeConverters.toString)
792
+ mainGpu = Param(Params._dummy(), "mainGpu", "Set the main GPU that is used for scratch and small tensors.",
793
+ typeConverter=TypeConverters.toInt)
794
+ # tensorSplit = Param(Params._dummy(), "tensorSplit", "Set how split tensors should be distributed across GPUs",
795
+ # typeConverter=TypeConverters.toListFloat)
796
+ # grpAttnN = Param(Params._dummy(), "grpAttnN", "Set the group-attention factor", typeConverter=TypeConverters.toInt)
797
+ # grpAttnW = Param(Params._dummy(), "grpAttnW", "Set the group-attention width", typeConverter=TypeConverters.toInt)
798
+ ropeFreqBase = Param(Params._dummy(), "ropeFreqBase", "Set the RoPE base frequency, used by NTK-aware scaling",
799
+ typeConverter=TypeConverters.toFloat)
800
+ ropeFreqScale = Param(Params._dummy(), "ropeFreqScale",
801
+ "Set the RoPE frequency scaling factor, expands context by a factor of 1/N",
802
+ typeConverter=TypeConverters.toFloat)
803
+ yarnExtFactor = Param(Params._dummy(), "yarnExtFactor", "Set the YaRN extrapolation mix factor",
804
+ typeConverter=TypeConverters.toFloat)
805
+ yarnAttnFactor = Param(Params._dummy(), "yarnAttnFactor", "Set the YaRN scale sqrt(t) or attention magnitude",
806
+ typeConverter=TypeConverters.toFloat)
807
+ yarnBetaFast = Param(Params._dummy(), "yarnBetaFast", "Set the YaRN low correction dim or beta",
808
+ typeConverter=TypeConverters.toFloat)
809
+ yarnBetaSlow = Param(Params._dummy(), "yarnBetaSlow", "Set the YaRN high correction dim or alpha",
810
+ typeConverter=TypeConverters.toFloat)
811
+ yarnOrigCtx = Param(Params._dummy(), "yarnOrigCtx", "Set the YaRN original context size of model",
812
+ typeConverter=TypeConverters.toInt)
813
+ defragmentationThreshold = Param(Params._dummy(), "defragmentationThreshold",
814
+ "Set the KV cache defragmentation threshold", typeConverter=TypeConverters.toFloat)
815
+ # Set optimization strategies that help on some NUMA systems (if available)
816
+ #
817
+ # Available Strategies:
818
+ #
819
+ # - DISABLED: No NUMA optimizations
820
+ # - DISTRIBUTE: Spread execution evenly over all
821
+ # - ISOLATE: Only spawn threads on CPUs on the node that execution started on
822
+ # - NUMA_CTL: Use the CPU map provided by numactl
823
+ # - MIRROR: Mirrors the model across NUMA nodes
824
+ numaStrategy = Param(Params._dummy(), "numaStrategy",
825
+ "Set optimization strategies that help on some NUMA systems (if available)",
826
+ typeConverter=TypeConverters.toString)
827
+ # Set the RoPE frequency scaling method, defaults to linear unless specified by the model.
828
+ #
829
+ # - NONE: Don't use any scaling
830
+ # - LINEAR: Linear scaling
831
+ # - YARN: YaRN RoPE scaling
832
+ ropeScalingType = Param(Params._dummy(), "ropeScalingType",
833
+ "Set the RoPE frequency scaling method, defaults to linear unless specified by the model",
834
+ typeConverter=TypeConverters.toString)
835
+ # Set the pooling type for embeddings, use model default if unspecified
836
+ #
837
+ # - MEAN: Mean Pooling
838
+ # - CLS: CLS Pooling
839
+ # - LAST: Last token pooling
840
+ # - RANK: For reranked models
841
+ poolingType = Param(Params._dummy(), "poolingType",
842
+ "Set the pooling type for embeddings, use model default if unspecified",
843
+ typeConverter=TypeConverters.toString)
844
+ modelDraft = Param(Params._dummy(), "modelDraft", "Set the draft model for speculative decoding",
845
+ typeConverter=TypeConverters.toString)
846
+ modelAlias = Param(Params._dummy(), "modelAlias", "Set a model alias", typeConverter=TypeConverters.toString)
847
+ # lookupCacheStaticFilePath = Param(Params._dummy(), "lookupCacheStaticFilePath",
848
+ # "Set path to static lookup cache to use for lookup decoding (not updated by generation)",
849
+ # typeConverter=TypeConverters.toString)
850
+ # lookupCacheDynamicFilePath = Param(Params._dummy(), "lookupCacheDynamicFilePath",
851
+ # "Set path to dynamic lookup cache to use for lookup decoding (updated by generation)",
852
+ # typeConverter=TypeConverters.toString)
853
+ # loraAdapters = new StructFeature[Map[String, Float]](this, "loraAdapters")
854
+ embedding = Param(Params._dummy(), "embedding", "Whether to load model with embedding support",
855
+ typeConverter=TypeConverters.toBoolean)
856
+ flashAttention = Param(Params._dummy(), "flashAttention", "Whether to enable Flash Attention",
857
+ typeConverter=TypeConverters.toBoolean)
858
+ # inputPrefixBos = Param(Params._dummy(), "inputPrefixBos",
859
+ # "Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string",
860
+ # typeConverter=TypeConverters.toBoolean)
861
+ useMmap = Param(Params._dummy(), "useMmap",
862
+ "Whether to use memory-map model (faster load but may increase pageouts if not using mlock)",
863
+ typeConverter=TypeConverters.toBoolean)
864
+ useMlock = Param(Params._dummy(), "useMlock",
865
+ "Whether to force the system to keep model in RAM rather than swapping or compressing",
866
+ typeConverter=TypeConverters.toBoolean)
867
+ noKvOffload = Param(Params._dummy(), "noKvOffload", "Whether to disable KV offload",
868
+ typeConverter=TypeConverters.toBoolean)
869
+ systemPrompt = Param(Params._dummy(), "systemPrompt", "Set a system prompt to use",
870
+ typeConverter=TypeConverters.toString)
871
+ chatTemplate = Param(Params._dummy(), "chatTemplate", "The chat template to use",
872
+ typeConverter=TypeConverters.toString)
873
+ logVerbosity = Param(Params._dummy(), "logVerbosity", "Set the log verbosity level",
874
+ typeConverter=TypeConverters.toInt)
875
+ disableLog = Param(Params._dummy(), "disableLog", "Whether to disable logging",
876
+ typeConverter=TypeConverters.toBoolean)
877
+
878
+ # -------- INFERENCE PARAMETERS --------
879
+ inputPrefix = Param(Params._dummy(), "inputPrefix", "Set the prompt to start generation with",
880
+ typeConverter=TypeConverters.toString)
881
+ inputSuffix = Param(Params._dummy(), "inputSuffix", "Set a suffix for infilling",
882
+ typeConverter=TypeConverters.toString)
883
+ cachePrompt = Param(Params._dummy(), "cachePrompt", "Whether to remember the prompt to avoid reprocessing it",
884
+ typeConverter=TypeConverters.toBoolean)
885
+ nPredict = Param(Params._dummy(), "nPredict", "Set the number of tokens to predict",
886
+ typeConverter=TypeConverters.toInt)
887
+ topK = Param(Params._dummy(), "topK", "Set top-k sampling", typeConverter=TypeConverters.toInt)
888
+ topP = Param(Params._dummy(), "topP", "Set top-p sampling", typeConverter=TypeConverters.toFloat)
889
+ minP = Param(Params._dummy(), "minP", "Set min-p sampling", typeConverter=TypeConverters.toFloat)
890
+ tfsZ = Param(Params._dummy(), "tfsZ", "Set tail free sampling, parameter z", typeConverter=TypeConverters.toFloat)
891
+ typicalP = Param(Params._dummy(), "typicalP", "Set locally typical sampling, parameter p",
892
+ typeConverter=TypeConverters.toFloat)
893
+ temperature = Param(Params._dummy(), "temperature", "Set the temperature", typeConverter=TypeConverters.toFloat)
894
+ dynamicTemperatureRange = Param(Params._dummy(), "dynatempRange", "Set the dynamic temperature range",
895
+ typeConverter=TypeConverters.toFloat)
896
+ dynamicTemperatureExponent = Param(Params._dummy(), "dynatempExponent", "Set the dynamic temperature exponent",
897
+ typeConverter=TypeConverters.toFloat)
898
+ repeatLastN = Param(Params._dummy(), "repeatLastN", "Set the last n tokens to consider for penalties",
899
+ typeConverter=TypeConverters.toInt)
900
+ repeatPenalty = Param(Params._dummy(), "repeatPenalty", "Set the penalty of repeated sequences of tokens",
901
+ typeConverter=TypeConverters.toFloat)
902
+ frequencyPenalty = Param(Params._dummy(), "frequencyPenalty", "Set the repetition alpha frequency penalty",
903
+ typeConverter=TypeConverters.toFloat)
904
+ presencePenalty = Param(Params._dummy(), "presencePenalty", "Set the repetition alpha presence penalty",
905
+ typeConverter=TypeConverters.toFloat)
906
+ miroStat = Param(Params._dummy(), "miroStat", "Set MiroStat sampling strategies.",
907
+ typeConverter=TypeConverters.toString)
908
+ miroStatTau = Param(Params._dummy(), "mirostatTau", "Set the MiroStat target entropy, parameter tau",
909
+ typeConverter=TypeConverters.toFloat)
910
+ miroStatEta = Param(Params._dummy(), "mirostatEta", "Set the MiroStat learning rate, parameter eta",
911
+ typeConverter=TypeConverters.toFloat)
912
+ penalizeNl = Param(Params._dummy(), "penalizeNl", "Whether to penalize newline tokens",
913
+ typeConverter=TypeConverters.toBoolean)
914
+ nKeep = Param(Params._dummy(), "nKeep", "Set the number of tokens to keep from the initial prompt",
915
+ typeConverter=TypeConverters.toInt)
916
+ seed = Param(Params._dummy(), "seed", "Set the RNG seed", typeConverter=TypeConverters.toInt)
917
+ nProbs = Param(Params._dummy(), "nProbs", "Set the amount top tokens probabilities to output if greater than 0.",
918
+ typeConverter=TypeConverters.toInt)
919
+ minKeep = Param(Params._dummy(), "minKeep",
920
+ "Set the amount of tokens the samplers should return at least (0 = disabled)",
921
+ typeConverter=TypeConverters.toInt)
922
+ grammar = Param(Params._dummy(), "grammar", "Set BNF-like grammar to constrain generations",
923
+ typeConverter=TypeConverters.toString)
924
+ penaltyPrompt = Param(Params._dummy(), "penaltyPrompt",
925
+ "Override which part of the prompt is penalized for repetition.",
926
+ typeConverter=TypeConverters.toString)
927
+ ignoreEos = Param(Params._dummy(), "ignoreEos",
928
+ "Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)",
929
+ typeConverter=TypeConverters.toBoolean)
930
+ disableTokenIds = Param(Params._dummy(), "disableTokenIds", "Set the token ids to disable in the completion",
931
+ typeConverter=TypeConverters.toListInt)
932
+ stopStrings = Param(Params._dummy(), "stopStrings", "Set strings upon seeing which token generation is stopped",
933
+ typeConverter=TypeConverters.toListString)
934
+ samplers = Param(Params._dummy(), "samplers", "Set which samplers to use for token generation in the given order",
935
+ typeConverter=TypeConverters.toListString)
936
+ useChatTemplate = Param(Params._dummy(), "useChatTemplate",
937
+ "Set whether or not generate should apply a chat template",
938
+ typeConverter=TypeConverters.toBoolean)
939
+
940
+ # -------- MODEL SETTERS --------
941
def setNThreads(self, nThreads: int):
    """Sets the number of threads to use during generation.

    Parameters
    ----------
    nThreads : int
        Number of generation threads
    """
    return self._set(nThreads=nThreads)
944
+
945
+ # def setNThreadsDraft(self, nThreadsDraft: int):
946
+ # """Set the number of threads to use during draft generation"""
947
+ # return self._set(nThreadsDraft=nThreadsDraft)
948
+
949
def setNThreadsBatch(self, nThreadsBatch: int):
    """Sets the number of threads to use during batch and prompt processing.

    Parameters
    ----------
    nThreadsBatch : int
        Number of batch/prompt-processing threads
    """
    return self._set(nThreadsBatch=nThreadsBatch)
952
+
953
+ # def setNThreadsBatchDraft(self, nThreadsBatchDraft: int):
954
+ # """Set the number of threads to use during batch and prompt processing"""
955
+ # return self._set(nThreadsBatchDraft=nThreadsBatchDraft)
956
+
957
def setNCtx(self, nCtx: int):
    """Sets the size of the prompt context.

    Parameters
    ----------
    nCtx : int
        Prompt context size
    """
    return self._set(nCtx=nCtx)
960
+
961
def setNBatch(self, nBatch: int):
    """Sets the logical batch size for prompt processing.

    Parameters
    ----------
    nBatch : int
        Logical batch size (must be >=32 to use BLAS)
    """
    return self._set(nBatch=nBatch)
964
+
965
def setNUbatch(self, nUbatch: int):
    """Sets the physical batch size for prompt processing.

    Parameters
    ----------
    nUbatch : int
        Physical batch size (must be >=32 to use BLAS)
    """
    return self._set(nUbatch=nUbatch)
968
+
969
def setNDraft(self, nDraft: int):
    """Sets the number of tokens to draft for speculative decoding.

    Parameters
    ----------
    nDraft : int
        Number of drafted tokens
    """
    return self._set(nDraft=nDraft)
972
+
973
+ # def setNChunks(self, nChunks: int):
974
+ # """Set the maximal number of chunks to process"""
975
+ # return self._set(nChunks=nChunks)
976
+
977
+ # def setNSequences(self, nSequences: int):
978
+ # """Set the number of sequences to decode"""
979
+ # return self._set(nSequences=nSequences)
980
+
981
+ # def setPSplit(self, pSplit: float):
982
+ # """Set the speculative decoding split probability"""
983
+ # return self._set(pSplit=pSplit)
984
+
985
def setNGpuLayers(self, nGpuLayers: int):
    """Sets the number of layers to store in VRAM.

    Parameters
    ----------
    nGpuLayers : int
        Number of layers kept in VRAM (-1 - use default)
    """
    return self._set(nGpuLayers=nGpuLayers)
988
+
989
def setNGpuLayersDraft(self, nGpuLayersDraft: int):
    """Sets the number of draft-model layers to store in VRAM.

    Parameters
    ----------
    nGpuLayersDraft : int
        Number of draft-model layers kept in VRAM (-1 - use default)
    """
    return self._set(nGpuLayersDraft=nGpuLayersDraft)
992
+
993
def setGpuSplitMode(self, gpuSplitMode: str):
    """Sets how to split the model across GPUs.

    Parameters
    ----------
    gpuSplitMode : str
        Split mode (NONE, LAYER or ROW); stored as given, without validation
    """
    return self._set(gpuSplitMode=gpuSplitMode)
996
+
997
def setMainGpu(self, mainGpu: int):
    """Sets the main GPU that is used for scratch and small tensors.

    Parameters
    ----------
    mainGpu : int
        Index of the main GPU
    """
    return self._set(mainGpu=mainGpu)
1000
+
1001
+ # def setTensorSplit(self, tensorSplit: List[float]):
1002
+ # """Set how split tensors should be distributed across GPUs"""
1003
+ # return self._set(tensorSplit=tensorSplit)
1004
+
1005
+ # def setGrpAttnN(self, grpAttnN: int):
1006
+ # """Set the group-attention factor"""
1007
+ # return self._set(grpAttnN=grpAttnN)
1008
+
1009
+ # def setGrpAttnW(self, grpAttnW: int):
1010
+ # """Set the group-attention width"""
1011
+ # return self._set(grpAttnW=grpAttnW)
1012
+
1013
def setRopeFreqBase(self, ropeFreqBase: float):
    """Sets the RoPE base frequency, used by NTK-aware scaling.

    Parameters
    ----------
    ropeFreqBase : float
        RoPE base frequency
    """
    return self._set(ropeFreqBase=ropeFreqBase)
1016
+
1017
def setRopeFreqScale(self, ropeFreqScale: float):
    """Sets the RoPE frequency scaling factor.

    Parameters
    ----------
    ropeFreqScale : float
        Scaling factor; expands context by a factor of 1/N
    """
    return self._set(ropeFreqScale=ropeFreqScale)
1020
+
1021
def setYarnExtFactor(self, yarnExtFactor: float):
    """Sets the YaRN extrapolation mix factor.

    Parameters
    ----------
    yarnExtFactor : float
        YaRN extrapolation mix factor
    """
    return self._set(yarnExtFactor=yarnExtFactor)
1024
+
1025
def setYarnAttnFactor(self, yarnAttnFactor: float):
    """Sets the YaRN scale sqrt(t) or attention magnitude.

    Parameters
    ----------
    yarnAttnFactor : float
        YaRN attention factor
    """
    return self._set(yarnAttnFactor=yarnAttnFactor)
1028
+
1029
def setYarnBetaFast(self, yarnBetaFast: float):
    """Sets the YaRN low correction dim or beta.

    Parameters
    ----------
    yarnBetaFast : float
        YaRN beta-fast value
    """
    return self._set(yarnBetaFast=yarnBetaFast)
1032
+
1033
def setYarnBetaSlow(self, yarnBetaSlow: float):
    """Sets the YaRN high correction dim or alpha.

    Parameters
    ----------
    yarnBetaSlow : float
        YaRN beta-slow value
    """
    return self._set(yarnBetaSlow=yarnBetaSlow)
1036
+
1037
def setYarnOrigCtx(self, yarnOrigCtx: int):
    """Sets the YaRN original context size of the model.

    Parameters
    ----------
    yarnOrigCtx : int
        Original context size
    """
    return self._set(yarnOrigCtx=yarnOrigCtx)
1040
+
1041
def setDefragmentationThreshold(self, defragmentationThreshold: float):
    """Sets the KV cache defragmentation threshold.

    Parameters
    ----------
    defragmentationThreshold : float
        KV cache defragmentation threshold
    """
    return self._set(defragmentationThreshold=defragmentationThreshold)
1044
+
1045
def setNumaStrategy(self, numaStrategy: str):
    """Set optimization strategies that help on some NUMA systems (if available)

    The value is matched case-insensitively and stored upper-cased.

    Possible values:

    - DISABLED: No NUMA optimizations
    - DISTRIBUTE: spread execution evenly over all
    - ISOLATE: only spawn threads on CPUs on the node that execution started on
    - NUMA_CTL: use the CPU map provided by numactl
    - MIRROR: Mirrors the model across NUMA nodes

    Raises
    ------
    ValueError
        If the strategy is not one of the values above.
    """
    numaUpper = numaStrategy.upper()
    numaStrategies = ["DISABLED", "DISTRIBUTE", "ISOLATE", "NUMA_CTL", "MIRROR"]
    if numaUpper not in numaStrategies:
        raise ValueError(
            f"Invalid NUMA strategy: {numaUpper}. "
            + f"Valid values are: {numaStrategies}"
        )
    # Store the validated canonical form (as setRopeScalingType does); storing the
    # raw input would let e.g. "distribute" through validation but keep it lower-case.
    return self._set(numaStrategy=numaUpper)
1064
+
1065
def setRopeScalingType(self, ropeScalingType: str):
    """Sets the RoPE frequency scaling method (defaults to linear unless specified by the model).

    The value is matched case-insensitively and stored upper-cased.

    Possible values:

    - NONE: Don't use any scaling
    - LINEAR: Linear scaling
    - YARN: YaRN RoPE scaling

    Raises
    ------
    ValueError
        If the scaling type is not one of the values above.
    """
    allowed = ["NONE", "LINEAR", "YARN"]
    normalized = ropeScalingType.upper()
    if normalized in allowed:
        return self._set(ropeScalingType=normalized)
    raise ValueError(
        f"Invalid RoPE scaling type: {ropeScalingType}. "
        + f"Valid values are: {allowed}"
    )
1082
+
1083
def setPoolingType(self, poolingType: str):
    """Set the pooling type for embeddings, use model default if unspecified

    The value is matched case-insensitively and stored upper-cased.

    Possible values:

    - NONE: No pooling
    - MEAN: Mean Pooling
    - CLS: CLS Pooling
    - LAST: Last token pooling
    - RANK: For reranked models

    Raises
    ------
    ValueError
        If the pooling type is not one of the values above.
    """
    poolingTypeUpper = poolingType.upper()
    poolingTypes = ["NONE", "MEAN", "CLS", "LAST", "RANK"]
    if poolingTypeUpper not in poolingTypes:
        raise ValueError(
            f"Invalid pooling type: {poolingType}. "
            + f"Valid values are: {poolingTypes}"
        )
    # Store the validated canonical form (as setRopeScalingType does) rather than
    # the raw, possibly lower-case input.
    return self._set(poolingType=poolingTypeUpper)
1101
+
1102
def setModelDraft(self, modelDraft: str):
    """Sets the draft model for speculative decoding.

    Parameters
    ----------
    modelDraft : str
        Draft model used for speculative decoding
    """
    return self._set(modelDraft=modelDraft)
1105
+
1106
def setModelAlias(self, modelAlias: str):
    """Sets a model alias.

    Parameters
    ----------
    modelAlias : str
        Alias for the model
    """
    return self._set(modelAlias=modelAlias)
1109
+
1110
+ # def setLookupCacheStaticFilePath(self, lookupCacheStaticFilePath: str):
1111
+ # """Set path to static lookup cache to use for lookup decoding (not updated by generation)"""
1112
+ # return self._set(lookupCacheStaticFilePath=lookupCacheStaticFilePath)
1113
+
1114
+ # def setLookupCacheDynamicFilePath(self, lookupCacheDynamicFilePath: str):
1115
+ # """Set path to dynamic lookup cache to use for lookup decoding (updated by generation)"""
1116
+ # return self._set(lookupCacheDynamicFilePath=lookupCacheDynamicFilePath)
1117
+
1118
def setFlashAttention(self, flashAttention: bool):
    """Sets whether to enable Flash Attention.

    Parameters
    ----------
    flashAttention : bool
        ``True`` to enable Flash Attention
    """
    return self._set(flashAttention=flashAttention)
1121
+
1122
+ # def setInputPrefixBos(self, inputPrefixBos: bool):
1123
+ # """Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string"""
1124
+ # return self._set(inputPrefixBos=inputPrefixBos)
1125
+
1126
def setUseMmap(self, useMmap: bool):
    """Sets whether to memory-map the model.

    Memory-mapping loads faster but may increase pageouts if not using mlock.

    Parameters
    ----------
    useMmap : bool
        ``True`` to memory-map the model
    """
    return self._set(useMmap=useMmap)
1129
+
1130
def setUseMlock(self, useMlock: bool):
    """Sets whether to force the system to keep the model in RAM rather than swapping or compressing.

    Parameters
    ----------
    useMlock : bool
        ``True`` to lock the model in RAM
    """
    return self._set(useMlock=useMlock)
1133
+
1134
def setNoKvOffload(self, noKvOffload: bool):
    """Toggle disabling of KV offload."""
    return self._set(**{"noKvOffload": noKvOffload})
1137
+
1138
def setSystemPrompt(self, systemPrompt: str):
    """Provide the system prompt to use."""
    return self._set(**{"systemPrompt": systemPrompt})
1141
+
1142
def setChatTemplate(self, chatTemplate: str):
    """Select the chat template to use."""
    return self._set(**{"chatTemplate": chatTemplate})
1145
+
1146
+ # -------- INFERENCE SETTERS --------
1147
def setInputPrefix(self, inputPrefix: str):
    """Provide the prompt that generation starts with."""
    return self._set(**{"inputPrefix": inputPrefix})
1150
+
1151
def setInputSuffix(self, inputSuffix: str):
    """Provide the suffix used for infilling."""
    return self._set(**{"inputSuffix": inputSuffix})
1154
+
1155
def setCachePrompt(self, cachePrompt: bool):
    """Toggle remembering the prompt so it is not reprocessed."""
    return self._set(**{"cachePrompt": cachePrompt})
1158
+
1159
def setNPredict(self, nPredict: int):
    """Choose how many tokens to predict."""
    return self._set(**{"nPredict": nPredict})
1162
+
1163
def setTopK(self, topK: int):
    """Configure top-k sampling."""
    return self._set(**{"topK": topK})
1166
+
1167
def setTopP(self, topP: float):
    """Configure top-p (nucleus) sampling."""
    return self._set(**{"topP": topP})
1170
+
1171
def setMinP(self, minP: float):
    """Configure min-p sampling."""
    return self._set(**{"minP": minP})
1174
+
1175
def setTfsZ(self, tfsZ: float):
    """Configure tail free sampling via its parameter z."""
    return self._set(**{"tfsZ": tfsZ})
1178
+
1179
def setTypicalP(self, typicalP: float):
    """Configure locally typical sampling via its parameter p."""
    return self._set(**{"typicalP": typicalP})
1182
+
1183
def setTemperature(self, temperature: float):
    """Choose the sampling temperature."""
    return self._set(**{"temperature": temperature})
1186
+
1187
def setDynamicTemperatureRange(self, dynamicTemperatureRange: float):
    """Choose the dynamic temperature range."""
    return self._set(**{"dynamicTemperatureRange": dynamicTemperatureRange})
1190
+
1191
def setDynamicTemperatureExponent(self, dynamicTemperatureExponent: float):
    """Choose the dynamic temperature exponent."""
    return self._set(**{"dynamicTemperatureExponent": dynamicTemperatureExponent})
1194
+
1195
def setRepeatLastN(self, repeatLastN: int):
    """Choose how many of the most recent tokens are considered for penalties."""
    return self._set(**{"repeatLastN": repeatLastN})
1198
+
1199
def setRepeatPenalty(self, repeatPenalty: float):
    """Choose the penalty applied to repeated token sequences."""
    return self._set(**{"repeatPenalty": repeatPenalty})
1202
+
1203
def setFrequencyPenalty(self, frequencyPenalty: float):
    """Choose the repetition alpha frequency penalty."""
    return self._set(**{"frequencyPenalty": frequencyPenalty})
1206
+
1207
def setPresencePenalty(self, presencePenalty: float):
    """Choose the repetition alpha presence penalty."""
    return self._set(**{"presencePenalty": presencePenalty})
1210
+
1211
def setMiroStat(self, miroStat: str):
    """Select the MiroStat sampling strategy, given as a string."""
    return self._set(**{"miroStat": miroStat})
1214
+
1215
def setMiroStatTau(self, miroStatTau: float):
    """Choose the MiroStat target entropy (parameter tau)."""
    return self._set(**{"miroStatTau": miroStatTau})
1218
+
1219
def setMiroStatEta(self, miroStatEta: float):
    """Choose the MiroStat learning rate (parameter eta)."""
    return self._set(**{"miroStatEta": miroStatEta})
1222
+
1223
def setPenalizeNl(self, penalizeNl: bool):
    """Toggle penalizing newline tokens."""
    return self._set(**{"penalizeNl": penalizeNl})
1226
+
1227
def setNKeep(self, nKeep: int):
    """Choose how many tokens of the initial prompt are kept."""
    return self._set(**{"nKeep": nKeep})
1230
+
1231
def setSeed(self, seed: int):
    """Fix the random number generator seed."""
    return self._set(**{"seed": seed})
1234
+
1235
def setNProbs(self, nProbs: int):
    """Set the number of top token probabilities to output.

    Probabilities are only output when this value is greater than 0.

    Parameters
    ----------
    nProbs : int
        Number of top token probabilities to output.
    """
    return self._set(nProbs=nProbs)
1238
+
1239
def setMinKeep(self, minKeep: int):
    """Choose the minimum number of tokens the samplers must return (0 disables it)."""
    return self._set(**{"minKeep": minKeep})
1242
+
1243
def setGrammar(self, grammar: str):
    """Provide a BNF-like grammar that constrains generation."""
    return self._set(**{"grammar": grammar})
1246
+
1247
def setPenaltyPrompt(self, penaltyPrompt: str):
    """Override which part of the prompt is penalized for repetition."""
    return self._set(**{"penaltyPrompt": penaltyPrompt})
1250
+
1251
def setIgnoreEos(self, ignoreEos: bool):
    """Toggle ignoring the end-of-stream token so generation continues.

    Implies --logit-bias 2-inf.
    """
    return self._set(**{"ignoreEos": ignoreEos})
1254
+
1255
def setDisableTokenIds(self, disableTokenIds: List[int]):
    """Provide the token ids that are disabled during completion."""
    return self._set(**{"disableTokenIds": disableTokenIds})
1258
+
1259
def setStopStrings(self, stopStrings: List[str]):
    """Provide strings that stop token generation once they are seen."""
    return self._set(**{"stopStrings": stopStrings})
1262
+
1263
def setSamplers(self, samplers: List[str]):
    """Choose which samplers are applied to token generation, in order."""
    return self._set(**{"samplers": samplers})
1266
+
1267
def setUseChatTemplate(self, useChatTemplate: bool):
    """Toggle whether generation applies a chat template."""
    return self._set(**{"useChatTemplate": useChatTemplate})
1270
+
1271
def setNParallel(self, nParallel: int):
    """Set the number of parallel processes used for decoding.

    This is an alias that forwards directly to ``setBatchSize``.
    """
    delegate = self.setBatchSize
    return delegate(nParallel)
1274
+
1275
def setLogVerbosity(self, logVerbosity: int):
    """Choose the log verbosity level."""
    return self._set(**{"logVerbosity": logVerbosity})
1278
+
1279
def setDisableLog(self, disableLog: bool):
    """Toggle disabling of logging."""
    return self._set(**{"disableLog": disableLog})
1282
+
1283
+ # -------- JAVA SETTERS --------
1284
def setTokenIdBias(self, tokenIdBias: Dict[int, float]):
    """Set a bias per token id.

    The mapping is forwarded unchanged to the JVM-side ``setTokenIdBias``.
    """
    java_method = "setTokenIdBias"
    return self._call_java(java_method, tokenIdBias)
1287
+
1288
def setTokenBias(self, tokenBias: Dict[str, float]):
    """Set a bias for tokens identified by their string form.

    Unlike ``setTokenIdBias``, which is keyed by integer token ids, the keys
    here are token strings. The mapping is forwarded unchanged to the
    JVM-side ``setTokenBias`` method.

    Parameters
    ----------
    tokenBias : Dict[str, float]
        Mapping from token string to bias value.
    """
    return self._call_java("setTokenBias", tokenBias)
1291
+
1292
+ # def setLoraAdapters(self, loraAdapters: Dict[str, float]):
1293
+ # """Set LoRA adapters with their scaling factors"""
1294
+ # return self._call_java("setLoraAdapters", loraAdapters)
1295
+
1296
def getMetadata(self):
    """Return the model metadata, as reported by the JVM-side ``getMetadata``."""
    java_method = "getMetadata"
    return self._call_java(java_method)