spark-nlp 5.4.1__py2.py3-none-any.whl → 5.5.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of spark-nlp might be problematic.
Files changed (26)
  1. spark_nlp-5.5.0.dist-info/METADATA +345 -0
  2. {spark_nlp-5.4.1.dist-info → spark_nlp-5.5.0.dist-info}/RECORD +25 -13
  3. sparknlp/__init__.py +2 -2
  4. sparknlp/annotator/classifier_dl/__init__.py +4 -1
  5. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  6. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  7. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +2 -15
  8. sparknlp/annotator/embeddings/__init__.py +3 -0
  9. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  10. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  11. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  12. sparknlp/annotator/matcher/date_matcher.py +15 -0
  13. sparknlp/annotator/seq2seq/__init__.py +7 -0
  14. sparknlp/annotator/seq2seq/auto_gguf_model.py +804 -0
  15. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  16. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  17. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  18. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  19. sparknlp/annotator/seq2seq/qwen_transformer.py +339 -0
  20. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  21. sparknlp/annotator/similarity/document_similarity_ranker.py +22 -0
  22. sparknlp/internal/__init__.py +89 -0
  23. spark_nlp-5.4.1.dist-info/METADATA +0 -1357
  24. {spark_nlp-5.4.1.dist-info → spark_nlp-5.5.0.dist-info}/.uuid +0 -0
  25. {spark_nlp-5.4.1.dist-info → spark_nlp-5.5.0.dist-info}/WHEEL +0 -0
  26. {spark_nlp-5.4.1.dist-info → spark_nlp-5.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,804 @@
1
+ # Copyright 2017-2023 John Snow Labs
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains classes for the AutoGGUFModel."""
15
+ from typing import List, Dict
16
+
17
+ from sparknlp.common import *
18
+
19
+
20
+ class AutoGGUFModel(AnnotatorModel, HasBatchedAnnotate):
21
+ """
22
+ Annotator that uses the llama.cpp library to generate text completions with large language
23
+ models.
24
+
25
+ For settable parameters and their explanations, see the parameters of this class and refer to
26
+ the llama.cpp documentation of
27
+ `server.cpp <https://github.com/ggerganov/llama.cpp/tree/7d5e8777ae1d21af99d4f95be10db4870720da91/examples/server>`__
28
+ for more information.
29
+
30
+ If the parameters are not set, the annotator will default to use the parameters provided by
31
+ the model.
32
+
33
+ Pretrained models can be loaded with :meth:`.pretrained` of the companion
34
+ object:
35
+
36
+ >>> auto_gguf_model = AutoGGUFModel.pretrained() \\
37
+ ... .setInputCols(["document"]) \\
38
+ ... .setOutputCol("completions")
39
+
40
+ The default model is ``"phi3.5_mini_4k_instruct_q4_gguf"``, if no name is provided.
41
+
42
+ For extended examples of usage, see the
43
+ `AutoGGUFModelTest <https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModelTest.scala>`__
44
+ and the
45
+ `example notebook <https://github.com/JohnSnowLabs/spark-nlp/tree/master/examples/python/llama.cpp/llama.cpp_in_Spark_NLP_AutoGGUFModel.ipynb>`__.
46
+
47
+ For available pretrained models please see the `Models Hub <https://sparknlp.org/models>`__.
48
+
49
+ ====================== ======================
50
+ Input Annotation types Output Annotation type
51
+ ====================== ======================
52
+ ``DOCUMENT`` ``DOCUMENT``
53
+ ====================== ======================
54
+
55
+ Parameters
56
+ ----------
57
+ nThreads
58
+ Set the number of threads to use during generation
59
+ nThreadsDraft
60
+ Set the number of threads to use during draft generation
61
+ nThreadsBatch
62
+ Set the number of threads to use during batch and prompt processing
63
+ nThreadsBatchDraft
64
+ Set the number of threads to use during batch and prompt processing
65
+ nCtx
66
+ Set the size of the prompt context
67
+ nBatch
68
+ Set the logical batch size for prompt processing (must be >=32 to use BLAS)
69
+ nUbatch
70
+ Set the physical batch size for prompt processing (must be >=32 to use BLAS)
71
+ nDraft
72
+ Set the number of tokens to draft for speculative decoding
73
+ nChunks
74
+ Set the maximal number of chunks to process
75
+ nSequences
76
+ Set the number of sequences to decode
77
+ pSplit
78
+ Set the speculative decoding split probability
79
+ nGpuLayers
80
+ Set the number of layers to store in VRAM (-1 - use default)
81
+ nGpuLayersDraft
82
+ Set the number of layers to store in VRAM for the draft model (-1 - use default)
83
+ gpuSplitMode
84
+ Set how to split the model across GPUs
85
+ mainGpu
86
+ Set the main GPU that is used for scratch and small tensors.
87
+ tensorSplit
88
+ Set how split tensors should be distributed across GPUs
89
+ grpAttnN
90
+ Set the group-attention factor
91
+ grpAttnW
92
+ Set the group-attention width
93
+ ropeFreqBase
94
+ Set the RoPE base frequency, used by NTK-aware scaling
95
+ ropeFreqScale
96
+ Set the RoPE frequency scaling factor, expands context by a factor of 1/N
97
+ yarnExtFactor
98
+ Set the YaRN extrapolation mix factor
99
+ yarnAttnFactor
100
+ Set the YaRN scale sqrt(t) or attention magnitude
101
+ yarnBetaFast
102
+ Set the YaRN low correction dim or beta
103
+ yarnBetaSlow
104
+ Set the YaRN high correction dim or alpha
105
+ yarnOrigCtx
106
+ Set the YaRN original context size of model
107
+ defragmentationThreshold
108
+ Set the KV cache defragmentation threshold
109
+ numaStrategy
110
+ Set optimization strategies that help on some NUMA systems (if available)
111
+ ropeScalingType
112
+ Set the RoPE frequency scaling method, defaults to linear unless specified by the model
113
+ poolingType
114
+ Set the pooling type for embeddings, use model default if unspecified
115
+ modelDraft
116
+ Set the draft model for speculative decoding
117
+ modelAlias
118
+ Set a model alias
119
+ lookupCacheStaticFilePath
120
+ Set path to static lookup cache to use for lookup decoding (not updated by generation)
121
+ lookupCacheDynamicFilePath
122
+ Set path to dynamic lookup cache to use for lookup decoding (updated by generation)
123
+ embedding
124
+ Whether to load model with embedding support
125
+ flashAttention
126
+ Whether to enable Flash Attention
127
+ inputPrefixBos
128
+ Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string
129
+ useMmap
130
+ Whether to use memory-map model (faster load but may increase pageouts if not using mlock)
131
+ useMlock
132
+ Whether to force the system to keep model in RAM rather than swapping or compressing
133
+ noKvOffload
134
+ Whether to disable KV offload
135
+ systemPrompt
136
+ Set a system prompt to use
137
+ chatTemplate
138
+ The chat template to use
139
+ inputPrefix
140
+ Set the prompt to start generation with
141
+ inputSuffix
142
+ Set a suffix for infilling
143
+ cachePrompt
144
+ Whether to remember the prompt to avoid reprocessing it
145
+ nPredict
146
+ Set the number of tokens to predict
147
+ topK
148
+ Set top-k sampling
149
+ topP
150
+ Set top-p sampling
151
+ minP
152
+ Set min-p sampling
153
+ tfsZ
154
+ Set tail free sampling, parameter z
155
+ typicalP
156
+ Set locally typical sampling, parameter p
157
+ temperature
158
+ Set the temperature
159
+ dynatempRange
160
+ Set the dynamic temperature range
161
+ dynatempExponent
162
+ Set the dynamic temperature exponent
163
+ repeatLastN
164
+ Set the last n tokens to consider for penalties
165
+ repeatPenalty
166
+ Set the penalty of repeated sequences of tokens
167
+ frequencyPenalty
168
+ Set the repetition alpha frequency penalty
169
+ presencePenalty
170
+ Set the repetition alpha presence penalty
171
+ miroStat
172
+ Set MiroStat sampling strategies.
173
+ mirostatTau
174
+ Set the MiroStat target entropy, parameter tau
175
+ mirostatEta
176
+ Set the MiroStat learning rate, parameter eta
177
+ penalizeNl
178
+ Whether to penalize newline tokens
179
+ nKeep
180
+ Set the number of tokens to keep from the initial prompt
181
+ seed
182
+ Set the RNG seed
183
+ nProbs
184
+ Set the number of top token probabilities to output if greater than 0
185
+ minKeep
186
+ Set the minimum number of tokens the samplers should return (0 = disabled)
187
+ grammar
188
+ Set BNF-like grammar to constrain generations
189
+ penaltyPrompt
190
+ Override which part of the prompt is penalized for repetition.
191
+ ignoreEos
192
+ Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)
193
+ disableTokenIds
194
+ Set the token ids to disable in the completion
195
+ stopStrings
196
+ Set strings that stop token generation when encountered
197
+ samplers
198
+ Set which samplers to use for token generation in the given order
199
+ useChatTemplate
200
+ Set whether or not generate should apply a chat template
201
+
202
+
203
+ Notes
204
+ -----
205
+ To use GPU inference with this annotator, make sure to use the Spark NLP GPU package and set
206
+ the number of GPU layers with the `setNGpuLayers` method.
207
+
208
+ When using larger models, we recommend adjusting GPU usage with `setNCtx` and `setNGpuLayers`
209
+ according to your hardware to avoid out-of-memory errors.
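+
+ For example, on a GPU with limited memory you might lower the context size and
+ offload only part of the model to VRAM (the values below are illustrative and
+ should be adjusted to your hardware):
+
+ >>> auto_gguf_model = AutoGGUFModel.pretrained() \\
+ ... .setInputCols(["document"]) \\
+ ... .setOutputCol("completions") \\
+ ... .setNCtx(2048) \\
+ ... .setNGpuLayers(40)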
210
+
211
+ References
212
+ ----------
213
+ - `Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
214
+ <https://arxiv.org/abs/1910.13461>`__
215
+ - https://github.com/pytorch/fairseq
216
+
217
+ **Paper Abstract:**
218
+ *We present BART, a denoising autoencoder for pretraining sequence-to-sequence models.
219
+ BART is trained by (1) corrupting text with an arbitrary noising function, and (2)
220
+ learning a model to reconstruct the original text. It uses a standard Transformer-based
221
+ neural machine translation architecture which, despite its simplicity, can be seen as
222
+ generalizing BERT (due to the bidirectional encoder), GPT (with the left-to-right decoder),
223
+ and many other more recent pretraining schemes. We evaluate a number of noising approaches,
224
+ finding the best performance by both randomly shuffling the order of the original sentences
225
+ and using a novel in-filling scheme, where spans of text are replaced with a single mask token.
226
+ BART is particularly effective when fine tuned for text generation but also works well for
227
+ comprehension tasks. It matches the performance of RoBERTa with comparable training resources
228
+ on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue,
229
+ question answering, and summarization tasks, with gains of up to 6 ROUGE. BART also provides
230
+ a 1.1 BLEU increase over a back-translation system for machine translation, with only target
231
+ language pretraining. We also report ablation experiments that replicate other pretraining
232
+ schemes within the BART framework, to better measure which factors most influence end-task performance.*
233
+
234
+ Examples
235
+ --------
236
+ >>> import sparknlp
237
+ >>> from sparknlp.base import *
238
+ >>> from sparknlp.annotator import *
239
+ >>> from pyspark.ml import Pipeline
240
+ >>> document = DocumentAssembler() \\
241
+ ... .setInputCol("text") \\
242
+ ... .setOutputCol("document")
243
+ >>> autoGGUFModel = AutoGGUFModel.pretrained() \\
244
+ ... .setInputCols(["document"]) \\
245
+ ... .setOutputCol("completions") \\
246
+ ... .setBatchSize(4) \\
247
+ ... .setNPredict(20) \\
248
+ ... .setNGpuLayers(99) \\
249
+ ... .setTemperature(0.4) \\
250
+ ... .setTopK(40) \\
251
+ ... .setTopP(0.9) \\
252
+ ... .setPenalizeNl(True)
253
+ >>> pipeline = Pipeline().setStages([document, autoGGUFModel])
254
+ >>> data = spark.createDataFrame([["Hello, I am a"]]).toDF("text")
255
+ >>> result = pipeline.fit(data).transform(data)
256
+ >>> result.select("completions").show(truncate = False)
257
+ +-----------------------------------------------------------------------------------------------------------------------------------+
258
+ |completions |
259
+ +-----------------------------------------------------------------------------------------------------------------------------------+
260
+ |[{document, 0, 78, new user. I am currently working on a project and I need to create a list of , {prompt -> Hello, I am a}, []}]|
261
+ +-----------------------------------------------------------------------------------------------------------------------------------+
262
+ """
263
+
264
+ name = "AutoGGUFModel"
265
+ inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
266
+ outputAnnotatorType = AnnotatorType.DOCUMENT
267
+
268
+ # -------- MODEL PARAMETERS --------
269
+ nThreads = Param(Params._dummy(), "nThreads", "Set the number of threads to use during generation",
270
+ typeConverter=TypeConverters.toInt)
271
+ nThreadsDraft = Param(Params._dummy(), "nThreadsDraft", "Set the number of threads to use during draft generation",
272
+ typeConverter=TypeConverters.toInt)
273
+ nThreadsBatch = Param(Params._dummy(), "nThreadsBatch",
274
+ "Set the number of threads to use during batch and prompt processing",
275
+ typeConverter=TypeConverters.toInt)
276
+ nThreadsBatchDraft = Param(Params._dummy(), "nThreadsBatchDraft",
277
+ "Set the number of threads to use during batch and prompt processing",
278
+ typeConverter=TypeConverters.toInt)
279
+ nCtx = Param(Params._dummy(), "nCtx", "Set the size of the prompt context", typeConverter=TypeConverters.toInt)
280
+ nBatch = Param(Params._dummy(), "nBatch",
281
+ "Set the logical batch size for prompt processing (must be >=32 to use BLAS)",
282
+ typeConverter=TypeConverters.toInt)
283
+ nUbatch = Param(Params._dummy(), "nUbatch",
284
+ "Set the physical batch size for prompt processing (must be >=32 to use BLAS)",
285
+ typeConverter=TypeConverters.toInt)
286
+ nDraft = Param(Params._dummy(), "nDraft", "Set the number of tokens to draft for speculative decoding",
287
+ typeConverter=TypeConverters.toInt)
288
+ nChunks = Param(Params._dummy(), "nChunks", "Set the maximal number of chunks to process",
289
+ typeConverter=TypeConverters.toInt)
290
+ nSequences = Param(Params._dummy(), "nSequences", "Set the number of sequences to decode",
291
+ typeConverter=TypeConverters.toInt)
292
+ pSplit = Param(Params._dummy(), "pSplit", "Set the speculative decoding split probability",
293
+ typeConverter=TypeConverters.toFloat)
294
+ nGpuLayers = Param(Params._dummy(), "nGpuLayers", "Set the number of layers to store in VRAM (-1 - use default)",
295
+ typeConverter=TypeConverters.toInt)
296
+ nGpuLayersDraft = Param(Params._dummy(), "nGpuLayersDraft",
297
+ "Set the number of layers to store in VRAM for the draft model (-1 - use default)",
298
+ typeConverter=TypeConverters.toInt)
299
+ # Set how to split the model across GPUs
300
+ #
301
+ # - NONE: No GPU split
302
+ # - LAYER: Split the model across GPUs by layer
303
+ # - ROW: Split the model across GPUs by rows
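+ #
+ # For instance, to split by layer (illustrative; the value names follow the list above):
+ # annotator.setGpuSplitMode("LAYER")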
304
+ gpuSplitMode = Param(Params._dummy(), "gpuSplitMode", "Set how to split the model across GPUs",
305
+ typeConverter=TypeConverters.toString)
306
+ mainGpu = Param(Params._dummy(), "mainGpu", "Set the main GPU that is used for scratch and small tensors.",
307
+ typeConverter=TypeConverters.toInt)
308
+ tensorSplit = Param(Params._dummy(), "tensorSplit", "Set how split tensors should be distributed across GPUs",
309
+ typeConverter=TypeConverters.toListFloat)
310
+ grpAttnN = Param(Params._dummy(), "grpAttnN", "Set the group-attention factor", typeConverter=TypeConverters.toInt)
311
+ grpAttnW = Param(Params._dummy(), "grpAttnW", "Set the group-attention width", typeConverter=TypeConverters.toInt)
312
+ ropeFreqBase = Param(Params._dummy(), "ropeFreqBase", "Set the RoPE base frequency, used by NTK-aware scaling",
313
+ typeConverter=TypeConverters.toFloat)
314
+ ropeFreqScale = Param(Params._dummy(), "ropeFreqScale",
315
+ "Set the RoPE frequency scaling factor, expands context by a factor of 1/N",
316
+ typeConverter=TypeConverters.toFloat)
317
+ yarnExtFactor = Param(Params._dummy(), "yarnExtFactor", "Set the YaRN extrapolation mix factor",
318
+ typeConverter=TypeConverters.toFloat)
319
+ yarnAttnFactor = Param(Params._dummy(), "yarnAttnFactor", "Set the YaRN scale sqrt(t) or attention magnitude",
320
+ typeConverter=TypeConverters.toFloat)
321
+ yarnBetaFast = Param(Params._dummy(), "yarnBetaFast", "Set the YaRN low correction dim or beta",
322
+ typeConverter=TypeConverters.toFloat)
323
+ yarnBetaSlow = Param(Params._dummy(), "yarnBetaSlow", "Set the YaRN high correction dim or alpha",
324
+ typeConverter=TypeConverters.toFloat)
325
+ yarnOrigCtx = Param(Params._dummy(), "yarnOrigCtx", "Set the YaRN original context size of model",
326
+ typeConverter=TypeConverters.toInt)
327
+ defragmentationThreshold = Param(Params._dummy(), "defragmentationThreshold",
328
+ "Set the KV cache defragmentation threshold", typeConverter=TypeConverters.toFloat)
329
+ # Set optimization strategies that help on some NUMA systems (if available)
330
+ #
331
+ # Available Strategies:
332
+ #
333
+ # - DISABLED: No NUMA optimizations
334
+ # - DISTRIBUTE: Spread execution evenly over all
335
+ # - ISOLATE: Only spawn threads on CPUs on the node that execution started on
336
+ # - NUMA_CTL: Use the CPU map provided by numactl
337
+ # - MIRROR: Mirrors the model across NUMA nodes
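+ #
+ # For instance (illustrative; the strategy names follow the list above):
+ # annotator.setNumaStrategy("DISTRIBUTE")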
338
+ numaStrategy = Param(Params._dummy(), "numaStrategy",
339
+ "Set optimization strategies that help on some NUMA systems (if available)",
340
+ typeConverter=TypeConverters.toString)
341
+ # Set the RoPE frequency scaling method, defaults to linear unless specified by the model.
342
+ #
343
+ # - UNSPECIFIED: Don't use any scaling
344
+ # - LINEAR: Linear scaling
345
+ # - YARN: YaRN RoPE scaling
346
+ ropeScalingType = Param(Params._dummy(), "ropeScalingType",
347
+ "Set the RoPE frequency scaling method, defaults to linear unless specified by the model",
348
+ typeConverter=TypeConverters.toString)
349
+ # Set the pooling type for embeddings, use model default if unspecified
350
+ #
351
+ # - 0 UNSPECIFIED: Don't use any pooling
352
+ # - 1 MEAN: Mean Pooling
353
+ # - 2 CLS: CLS Pooling
354
+ poolingType = Param(Params._dummy(), "poolingType",
355
+ "Set the pooling type for embeddings, use model default if unspecified",
356
+ typeConverter=TypeConverters.toString)
357
+ modelDraft = Param(Params._dummy(), "modelDraft", "Set the draft model for speculative decoding",
358
+ typeConverter=TypeConverters.toString)
359
+ modelAlias = Param(Params._dummy(), "modelAlias", "Set a model alias", typeConverter=TypeConverters.toString)
360
+ lookupCacheStaticFilePath = Param(Params._dummy(), "lookupCacheStaticFilePath",
361
+ "Set path to static lookup cache to use for lookup decoding (not updated by generation)",
362
+ typeConverter=TypeConverters.toString)
363
+ lookupCacheDynamicFilePath = Param(Params._dummy(), "lookupCacheDynamicFilePath",
364
+ "Set path to dynamic lookup cache to use for lookup decoding (updated by generation)",
365
+ typeConverter=TypeConverters.toString)
366
+ # loraAdapters = new StructFeature[Map[String, Float]](this, "loraAdapters")
367
+ embedding = Param(Params._dummy(), "embedding", "Whether to load model with embedding support",
368
+ typeConverter=TypeConverters.toBoolean)
369
+ flashAttention = Param(Params._dummy(), "flashAttention", "Whether to enable Flash Attention",
370
+ typeConverter=TypeConverters.toBoolean)
371
+ inputPrefixBos = Param(Params._dummy(), "inputPrefixBos",
372
+ "Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string",
373
+ typeConverter=TypeConverters.toBoolean)
374
+ useMmap = Param(Params._dummy(), "useMmap",
375
+ "Whether to use memory-map model (faster load but may increase pageouts if not using mlock)",
376
+ typeConverter=TypeConverters.toBoolean)
377
+ useMlock = Param(Params._dummy(), "useMlock",
378
+ "Whether to force the system to keep model in RAM rather than swapping or compressing",
379
+ typeConverter=TypeConverters.toBoolean)
380
+ noKvOffload = Param(Params._dummy(), "noKvOffload", "Whether to disable KV offload",
381
+ typeConverter=TypeConverters.toBoolean)
382
+ systemPrompt = Param(Params._dummy(), "systemPrompt", "Set a system prompt to use",
383
+ typeConverter=TypeConverters.toString)
384
+ chatTemplate = Param(Params._dummy(), "chatTemplate", "The chat template to use",
385
+ typeConverter=TypeConverters.toString)
386
+
387
+ # -------- INFERENCE PARAMETERS --------
388
+ inputPrefix = Param(Params._dummy(), "inputPrefix", "Set the prompt to start generation with",
389
+ typeConverter=TypeConverters.toString)
390
+ inputSuffix = Param(Params._dummy(), "inputSuffix", "Set a suffix for infilling",
391
+ typeConverter=TypeConverters.toString)
392
+ cachePrompt = Param(Params._dummy(), "cachePrompt", "Whether to remember the prompt to avoid reprocessing it",
393
+ typeConverter=TypeConverters.toBoolean)
394
+ nPredict = Param(Params._dummy(), "nPredict", "Set the number of tokens to predict",
395
+ typeConverter=TypeConverters.toInt)
396
+ topK = Param(Params._dummy(), "topK", "Set top-k sampling", typeConverter=TypeConverters.toInt)
397
+ topP = Param(Params._dummy(), "topP", "Set top-p sampling", typeConverter=TypeConverters.toFloat)
398
+ minP = Param(Params._dummy(), "minP", "Set min-p sampling", typeConverter=TypeConverters.toFloat)
399
+ tfsZ = Param(Params._dummy(), "tfsZ", "Set tail free sampling, parameter z", typeConverter=TypeConverters.toFloat)
400
+ typicalP = Param(Params._dummy(), "typicalP", "Set locally typical sampling, parameter p",
401
+ typeConverter=TypeConverters.toFloat)
402
+ temperature = Param(Params._dummy(), "temperature", "Set the temperature", typeConverter=TypeConverters.toFloat)
403
+ dynamicTemperatureRange = Param(Params._dummy(), "dynatempRange", "Set the dynamic temperature range",
404
+ typeConverter=TypeConverters.toFloat)
405
+ dynamicTemperatureExponent = Param(Params._dummy(), "dynatempExponent", "Set the dynamic temperature exponent",
406
+ typeConverter=TypeConverters.toFloat)
407
+ repeatLastN = Param(Params._dummy(), "repeatLastN", "Set the last n tokens to consider for penalties",
408
+ typeConverter=TypeConverters.toInt)
409
+ repeatPenalty = Param(Params._dummy(), "repeatPenalty", "Set the penalty of repeated sequences of tokens",
410
+ typeConverter=TypeConverters.toFloat)
411
+ frequencyPenalty = Param(Params._dummy(), "frequencyPenalty", "Set the repetition alpha frequency penalty",
412
+ typeConverter=TypeConverters.toFloat)
413
+ presencePenalty = Param(Params._dummy(), "presencePenalty", "Set the repetition alpha presence penalty",
414
+ typeConverter=TypeConverters.toFloat)
415
+ miroStat = Param(Params._dummy(), "miroStat", "Set MiroStat sampling strategies.",
416
+ typeConverter=TypeConverters.toString)
417
+ miroStatTau = Param(Params._dummy(), "mirostatTau", "Set the MiroStat target entropy, parameter tau",
418
+ typeConverter=TypeConverters.toFloat)
419
+ miroStatEta = Param(Params._dummy(), "mirostatEta", "Set the MiroStat learning rate, parameter eta",
420
+ typeConverter=TypeConverters.toFloat)
421
+ penalizeNl = Param(Params._dummy(), "penalizeNl", "Whether to penalize newline tokens",
422
+ typeConverter=TypeConverters.toBoolean)
423
+ nKeep = Param(Params._dummy(), "nKeep", "Set the number of tokens to keep from the initial prompt",
424
+ typeConverter=TypeConverters.toInt)
425
+ seed = Param(Params._dummy(), "seed", "Set the RNG seed", typeConverter=TypeConverters.toInt)
426
+ nProbs = Param(Params._dummy(), "nProbs", "Set the number of top token probabilities to output if greater than 0.",
427
+ typeConverter=TypeConverters.toInt)
428
+ minKeep = Param(Params._dummy(), "minKeep",
429
+ "Set the amount of tokens the samplers should return at least (0 = disabled)",
430
+ typeConverter=TypeConverters.toInt)
431
+ grammar = Param(Params._dummy(), "grammar", "Set BNF-like grammar to constrain generations",
432
+ typeConverter=TypeConverters.toString)
433
+ penaltyPrompt = Param(Params._dummy(), "penaltyPrompt",
434
+ "Override which part of the prompt is penalized for repetition.",
435
+ typeConverter=TypeConverters.toString)
436
+ ignoreEos = Param(Params._dummy(), "ignoreEos",
437
+ "Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)",
438
+ typeConverter=TypeConverters.toBoolean)
439
+ disableTokenIds = Param(Params._dummy(), "disableTokenIds", "Set the token ids to disable in the completion",
440
+ typeConverter=TypeConverters.toListInt)
441
+ stopStrings = Param(Params._dummy(), "stopStrings", "Set strings that stop token generation when encountered",
442
+ typeConverter=TypeConverters.toListString)
443
+ samplers = Param(Params._dummy(), "samplers", "Set which samplers to use for token generation in the given order",
444
+ typeConverter=TypeConverters.toListString)
445
+ useChatTemplate = Param(Params._dummy(), "useChatTemplate",
446
+ "Set whether or not generate should apply a chat template",
447
+ typeConverter=TypeConverters.toBoolean)
448
+
449
+ # -------- MODEL SETTERS --------
450
+ def setNThreads(self, nThreads: int):
451
+ """Set the number of threads to use during generation"""
452
+ return self._set(nThreads=nThreads)
453
+
454
+ def setNThreadsDraft(self, nThreadsDraft: int):
455
+ """Set the number of threads to use during draft generation"""
456
+ return self._set(nThreadsDraft=nThreadsDraft)
457
+
458
+ def setNThreadsBatch(self, nThreadsBatch: int):
459
+ """Set the number of threads to use during batch and prompt processing"""
460
+ return self._set(nThreadsBatch=nThreadsBatch)
461
+
462
+ def setNThreadsBatchDraft(self, nThreadsBatchDraft: int):
463
+ """Set the number of threads to use during batch and prompt processing"""
464
+ return self._set(nThreadsBatchDraft=nThreadsBatchDraft)
465
+
466
+ def setNCtx(self, nCtx: int):
467
+ """Set the size of the prompt context"""
468
+ return self._set(nCtx=nCtx)
469
+
470
+ def setNBatch(self, nBatch: int):
471
+ """Set the logical batch size for prompt processing (must be >=32 to use BLAS)"""
472
+ return self._set(nBatch=nBatch)
473
+
474
+ def setNUbatch(self, nUbatch: int):
475
+ """Set the physical batch size for prompt processing (must be >=32 to use BLAS)"""
476
+ return self._set(nUbatch=nUbatch)
477
+
478
+ def setNDraft(self, nDraft: int):
479
+ """Set the number of tokens to draft for speculative decoding"""
480
+ return self._set(nDraft=nDraft)
481
+
482
+ def setNChunks(self, nChunks: int):
483
+ """Set the maximal number of chunks to process"""
484
+ return self._set(nChunks=nChunks)
485
+
486
+ def setNSequences(self, nSequences: int):
487
+ """Set the number of sequences to decode"""
488
+ return self._set(nSequences=nSequences)
489
+
490
+ def setPSplit(self, pSplit: float):
491
+ """Set the speculative decoding split probability"""
492
+ return self._set(pSplit=pSplit)
493
+
494
+ def setNGpuLayers(self, nGpuLayers: int):
495
+ """Set the number of layers to store in VRAM (-1 - use default)"""
496
+ return self._set(nGpuLayers=nGpuLayers)
497
+
498
+ def setNGpuLayersDraft(self, nGpuLayersDraft: int):
499
+ """Set the number of layers to store in VRAM for the draft model (-1 - use default)"""
500
+ return self._set(nGpuLayersDraft=nGpuLayersDraft)
501
+
502
+ def setGpuSplitMode(self, gpuSplitMode: str):
503
+ """Set how to split the model across GPUs"""
504
+ return self._set(gpuSplitMode=gpuSplitMode)
505
+
506
+ def setMainGpu(self, mainGpu: int):
507
+ """Set the main GPU that is used for scratch and small tensors."""
508
+ return self._set(mainGpu=mainGpu)
509
+
510
+ def setTensorSplit(self, tensorSplit: List[float]):
511
+ """Set how split tensors should be distributed across GPUs"""
512
+ return self._set(tensorSplit=tensorSplit)
513
+
514
+ def setGrpAttnN(self, grpAttnN: int):
515
+ """Set the group-attention factor"""
516
+ return self._set(grpAttnN=grpAttnN)
517
+
518
+ def setGrpAttnW(self, grpAttnW: int):
519
+ """Set the group-attention width"""
520
+ return self._set(grpAttnW=grpAttnW)
521
+
522
+ def setRopeFreqBase(self, ropeFreqBase: float):
523
+ """Set the RoPE base frequency, used by NTK-aware scaling"""
524
+ return self._set(ropeFreqBase=ropeFreqBase)
525
+
526
+ def setRopeFreqScale(self, ropeFreqScale: float):
527
+ """Set the RoPE frequency scaling factor, expands context by a factor of 1/N"""
528
+ return self._set(ropeFreqScale=ropeFreqScale)
529
+
530
+ def setYarnExtFactor(self, yarnExtFactor: float):
531
+ """Set the YaRN extrapolation mix factor"""
532
+ return self._set(yarnExtFactor=yarnExtFactor)
533
+
534
+ def setYarnAttnFactor(self, yarnAttnFactor: float):
535
+ """Set the YaRN scale sqrt(t) or attention magnitude"""
536
+ return self._set(yarnAttnFactor=yarnAttnFactor)
537
+
538
+ def setYarnBetaFast(self, yarnBetaFast: float):
539
+ """Set the YaRN low correction dim or beta"""
540
+ return self._set(yarnBetaFast=yarnBetaFast)
541
+
542
+ def setYarnBetaSlow(self, yarnBetaSlow: float):
543
+ """Set the YaRN high correction dim or alpha"""
544
+ return self._set(yarnBetaSlow=yarnBetaSlow)
545
+
546
+ def setYarnOrigCtx(self, yarnOrigCtx: int):
547
+ """Set the YaRN original context size of model"""
548
+ return self._set(yarnOrigCtx=yarnOrigCtx)
549
+
550
+ def setDefragmentationThreshold(self, defragmentationThreshold: float):
551
+ """Set the KV cache defragmentation threshold"""
552
+ return self._set(defragmentationThreshold=defragmentationThreshold)
553
+
554
+ def setNumaStrategy(self, numaStrategy: str):
555
+ """Set optimization strategies that help on some NUMA systems (if available)"""
556
+ return self._set(numaStrategy=numaStrategy)
557
+
558
+ def setRopeScalingType(self, ropeScalingType: str):
559
+ """Set the RoPE frequency scaling method, defaults to linear unless specified by the model"""
560
+ return self._set(ropeScalingType=ropeScalingType)
561
+
562
+ def setPoolingType(self, poolingType: str):
563
+ """Set the pooling type for embeddings, use model default if unspecified"""
564
+ return self._set(poolingType=poolingType)
565
+
566
+ def setModelDraft(self, modelDraft: str):
567
+ """Set the draft model for speculative decoding"""
568
+ return self._set(modelDraft=modelDraft)
569
+
570
+ def setModelAlias(self, modelAlias: str):
571
+ """Set a model alias"""
572
+ return self._set(modelAlias=modelAlias)
573
+
574
+ def setLookupCacheStaticFilePath(self, lookupCacheStaticFilePath: str):
575
+ """Set path to static lookup cache to use for lookup decoding (not updated by generation)"""
576
+ return self._set(lookupCacheStaticFilePath=lookupCacheStaticFilePath)
577
+
578
+ def setLookupCacheDynamicFilePath(self, lookupCacheDynamicFilePath: str):
579
+ """Set path to dynamic lookup cache to use for lookup decoding (updated by generation)"""
580
+ return self._set(lookupCacheDynamicFilePath=lookupCacheDynamicFilePath)
581
+
582
+ def setEmbedding(self, embedding: bool):
583
+ """Whether to load model with embedding support"""
584
+ return self._set(embedding=embedding)
585
+
586
+ def setFlashAttention(self, flashAttention: bool):
587
+ """Whether to enable Flash Attention"""
588
+ return self._set(flashAttention=flashAttention)
589
+
590
+ def setInputPrefixBos(self, inputPrefixBos: bool):
591
+ """Whether to add prefix BOS to user inputs, preceding the `--in-prefix` bool"""
592
+ return self._set(inputPrefixBos=inputPrefixBos)
593
+
594
+ def setUseMmap(self, useMmap: bool):
595
+ """Whether to use memory-map model (faster load but may increase pageouts if not using mlock)"""
596
+ return self._set(useMmap=useMmap)
597
+
598
+ def setUseMlock(self, useMlock: bool):
599
+ """Whether to force the system to keep model in RAM rather than swapping or compressing"""
600
+ return self._set(useMlock=useMlock)
601
+
602
+ def setNoKvOffload(self, noKvOffload: bool):
603
+ """Whether to disable KV offload"""
604
+ return self._set(noKvOffload=noKvOffload)
605
+
606
+ def setSystemPrompt(self, systemPrompt: str):
607
+ """Set a system prompt to use"""
608
+ return self._set(systemPrompt=systemPrompt)
609
+
610
+ def setChatTemplate(self, chatTemplate: str):
611
+ """The chat template to use"""
612
+ return self._set(chatTemplate=chatTemplate)
613
+
614
+ # -------- INFERENCE SETTERS --------
615
+ def setInputPrefix(self, inputPrefix: str):
616
+ """Set the prompt to start generation with"""
617
+ return self._set(inputPrefix=inputPrefix)
618
+
619
+ def setInputSuffix(self, inputSuffix: str):
620
+ """Set a suffix for infilling"""
621
+ return self._set(inputSuffix=inputSuffix)
622
+
623
+ def setCachePrompt(self, cachePrompt: bool):
624
+ """Whether to remember the prompt to avoid reprocessing it"""
625
+ return self._set(cachePrompt=cachePrompt)
626
+
627
+ def setNPredict(self, nPredict: int):
628
+ """Set the number of tokens to predict"""
629
+ return self._set(nPredict=nPredict)
630
+
631
+ def setTopK(self, topK: int):
632
+ """Set top-k sampling"""
633
+ return self._set(topK=topK)
634
+
635
+ def setTopP(self, topP: float):
636
+ """Set top-p sampling"""
637
+ return self._set(topP=topP)
638
+
639
+ def setMinP(self, minP: float):
640
+ """Set min-p sampling"""
641
+ return self._set(minP=minP)
642
+
643
+ def setTfsZ(self, tfsZ: float):
644
+ """Set tail free sampling, parameter z"""
645
+ return self._set(tfsZ=tfsZ)
646
+
647
+ def setTypicalP(self, typicalP: float):
648
+ """Set locally typical sampling, parameter p"""
649
+ return self._set(typicalP=typicalP)
650
+
651
+ def setTemperature(self, temperature: float):
652
+ """Set the temperature"""
653
+ return self._set(temperature=temperature)
654
+
655
+ def setDynamicTemperatureRange(self, dynamicTemperatureRange: float):
656
+ """Set the dynamic temperature range"""
657
+ return self._set(dynamicTemperatureRange=dynamicTemperatureRange)
658
+
659
+ def setDynamicTemperatureExponent(self, dynamicTemperatureExponent: float):
660
+ """Set the dynamic temperature exponent"""
661
+ return self._set(dynamicTemperatureExponent=dynamicTemperatureExponent)
662
+
663
+ def setRepeatLastN(self, repeatLastN: int):
664
+ """Set the last n tokens to consider for penalties"""
665
+ return self._set(repeatLastN=repeatLastN)
666
+
667
+ def setRepeatPenalty(self, repeatPenalty: float):
668
+ """Set the penalty of repeated sequences of tokens"""
669
+ return self._set(repeatPenalty=repeatPenalty)
670
+
671
+ def setFrequencyPenalty(self, frequencyPenalty: float):
672
+ """Set the repetition alpha frequency penalty"""
673
+ return self._set(frequencyPenalty=frequencyPenalty)
674
+
675
+ def setPresencePenalty(self, presencePenalty: float):
676
+ """Set the repetition alpha presence penalty"""
677
+ return self._set(presencePenalty=presencePenalty)
678
+
679
+ def setMiroStat(self, miroStat: str):
680
+ """Set MiroStat sampling strategies."""
681
+ return self._set(miroStat=miroStat)
682
+
683
+ def setMiroStatTau(self, miroStatTau: float):
684
+ """Set the MiroStat target entropy, parameter tau"""
685
+ return self._set(miroStatTau=miroStatTau)
686
+
687
+ def setMiroStatEta(self, miroStatEta: float):
688
+ """Set the MiroStat learning rate, parameter eta"""
689
+ return self._set(miroStatEta=miroStatEta)
690
+
691
+ def setPenalizeNl(self, penalizeNl: bool):
692
+ """Whether to penalize newline tokens"""
693
+ return self._set(penalizeNl=penalizeNl)
694
+
695
+ def setNKeep(self, nKeep: int):
696
+ """Set the number of tokens to keep from the initial prompt"""
697
+ return self._set(nKeep=nKeep)
698
+
699
+ def setSeed(self, seed: int):
700
+ """Set the RNG seed"""
701
+ return self._set(seed=seed)
702
+
703
+ def setNProbs(self, nProbs: int):
704
+ """Set the amount top tokens probabilities to output if greater than 0."""
705
+ return self._set(nProbs=nProbs)
706
+
707
+ def setMinKeep(self, minKeep: int):
708
+ """Set the amount of tokens the samplers should return at least (0 = disabled)"""
709
+ return self._set(minKeep=minKeep)
710
+
711
+ def setGrammar(self, grammar: str):
712
+ """Set BNF-like grammar to constrain generations"""
713
+ return self._set(grammar=grammar)
714
+
715
+ def setPenaltyPrompt(self, penaltyPrompt: str):
716
+ """Override which part of the prompt is penalized for repetition."""
717
+ return self._set(penaltyPrompt=penaltyPrompt)
718
+
719
+ def setIgnoreEos(self, ignoreEos: bool):
720
+ """Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)"""
721
+ return self._set(ignoreEos=ignoreEos)
722
+
723
+ def setDisableTokenIds(self, disableTokenIds: List[int]):
724
+ """Set the token ids to disable in the completion"""
725
+ return self._set(disableTokenIds=disableTokenIds)
726
+
727
+ def setStopStrings(self, stopStrings: List[str]):
728
+ """Set strings upon seeing which token generation is stopped"""
729
+ return self._set(stopStrings=stopStrings)
730
+
731
+ def setSamplers(self, samplers: List[str]):
732
+ """Set which samplers to use for token generation in the given order"""
733
+ return self._set(samplers=samplers)
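+
+ # Example (illustrative): stop generation at custom strings and pick an explicit
+ # sampler order. The stop strings and sampler names below are hypothetical and
+ # follow llama.cpp naming conventions.
+ # model.setStopStrings(["</s>", "User:"]).setSamplers(["top_k", "top_p", "temperature"])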
734
+
735
+ def setUseChatTemplate(self, useChatTemplate: bool):
736
+ """Set whether generate should apply a chat template"""
737
+ return self._set(useChatTemplate=useChatTemplate)
738
+
739
+ # -------- JAVA SETTERS --------
740
+ def setTokenIdBias(self, tokenIdBias: Dict[int, float]):
741
+ """Set token id bias"""
742
+ return self._call_java("setTokenIdBias", tokenIdBias)
743
+
744
+ def setTokenBias(self, tokenBias: Dict[str, float]):
745
+ """Set token id bias"""
746
+ return self._call_java("setTokenBias", tokenBias)
747
+
748
+ def setLoraAdapters(self, loraAdapters: Dict[str, float]):
749
+ """Set token id bias"""
750
+ return self._call_java("setLoraAdapters", loraAdapters)
751
+
752
+ def getMetadata(self):
753
+ """Gets the metadata of the model"""
754
+ return self._call_java("getMetadata")
755
+
756
+ @keyword_only
757
+ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.AutoGGUFModel", java_model=None):
758
+ super(AutoGGUFModel, self).__init__(
759
+ classname=classname,
760
+ java_model=java_model
761
+ )
762
+ # self._setDefault()
763
+
764
+ @staticmethod
765
+ def loadSavedModel(folder, spark_session):
766
+ """Loads a locally saved model.
767
+
768
+ Parameters
769
+ ----------
770
+ folder : str
771
+ Folder of the saved model
772
+ spark_session : pyspark.sql.SparkSession
773
+ The current SparkSession
774
+
775
+ Returns
776
+ -------
777
+ AutoGGUFModel
778
+ The restored model
779
+ """
780
+ from sparknlp.internal import _AutoGGUFLoader
781
+ jModel = _AutoGGUFLoader(folder, spark_session._jsparkSession)._java_obj
782
+ return AutoGGUFModel(java_model=jModel)
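+
+ # Example (illustrative; the folder path is hypothetical):
+ # local_model = AutoGGUFModel.loadSavedModel("/models/my_gguf_model", spark)
+ # local_model.setInputCols(["document"]).setOutputCol("completions")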
783
+
784
+ @staticmethod
785
+ def pretrained(name="phi3.5_mini_4k_instruct_q4_gguf", lang="en", remote_loc=None):
786
+ """Downloads and loads a pretrained model.
787
+
788
+ Parameters
789
+ ----------
790
+ name : str, optional
791
+ Name of the pretrained model, by default "phi3.5_mini_4k_instruct_q4_gguf"
792
+ lang : str, optional
793
+ Language of the pretrained model, by default "en"
794
+ remote_loc : str, optional
795
+ Optional remote address of the resource, by default None. Will use
796
+ Spark NLP's repositories otherwise.
797
+
798
+ Returns
799
+ -------
800
+ AutoGGUFModel
801
+ The restored model
802
+ """
803
+ from sparknlp.pretrained import ResourceDownloader
804
+ return ResourceDownloader.downloadModel(AutoGGUFModel, name, lang, remote_loc)