spark-nlp 5.5.3__py2.py3-none-any.whl → 6.0.1rc1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (37)
  1. {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/METADATA +20 -11
  2. {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/RECORD +36 -17
  3. {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/WHEEL +1 -1
  4. sparknlp/__init__.py +2 -2
  5. sparknlp/annotator/classifier_dl/__init__.py +4 -0
  6. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  7. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +2 -2
  8. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  9. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  10. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  11. sparknlp/annotator/cleaners/__init__.py +15 -0
  12. sparknlp/annotator/cleaners/cleaner.py +202 -0
  13. sparknlp/annotator/cleaners/extractor.py +191 -0
  14. sparknlp/annotator/cv/__init__.py +9 -1
  15. sparknlp/annotator/cv/gemma3_for_multimodal.py +351 -0
  16. sparknlp/annotator/cv/janus_for_multimodal.py +356 -0
  17. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  18. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  19. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  20. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  21. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  22. sparknlp/annotator/cv/smolvlm_transformer.py +432 -0
  23. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +10 -6
  24. sparknlp/annotator/seq2seq/__init__.py +3 -0
  25. sparknlp/annotator/seq2seq/auto_gguf_model.py +8 -503
  26. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +333 -0
  27. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  28. sparknlp/annotator/seq2seq/llama3_transformer.py +4 -4
  29. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  30. sparknlp/base/image_assembler.py +58 -0
  31. sparknlp/common/properties.py +605 -96
  32. sparknlp/internal/__init__.py +127 -2
  33. sparknlp/reader/enums.py +19 -0
  34. sparknlp/reader/pdf_to_text.py +111 -0
  35. sparknlp/reader/sparknlp_reader.py +222 -14
  36. spark_nlp-5.5.3.dist-info/.uuid +0 -1
  37. {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/top_level.txt +0 -0
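The headline additions are new multiple-choice classifiers, a set of multimodal annotators, llama.cpp vision support (AutoGGUFVisionModel), and new seq2seq transformers (Cohere, OLMo). A quick way to sanity-check the release candidate before reading the two largest new files below; a minimal sketch, assuming the rc wheel and a matching pyspark are installed (e.g. with `pip install spark-nlp==6.0.1rc1`):

    # Minimal smoke test for the 6.0.1rc1 wheel; sparknlp.start() is the standard
    # entry point, though a release candidate may need its matching jar configured.
    import sparknlp

    spark = sparknlp.start()
    print(sparknlp.version())  # expected to report 6.0.1rc1

    # The new annotators from this diff should be importable from sparknlp.annotator:
    from sparknlp.annotator import AutoGGUFVisionModel, CoHereTransformer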
@@ -0,0 +1,333 @@
+ # Copyright 2017-2025 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Contains classes for the AutoGGUFVisionModel."""
+ from sparknlp.common import *
+
+
+ class AutoGGUFVisionModel(AnnotatorModel, HasBatchedAnnotate, HasLlamaCppProperties):
+     """Multimodal annotator that uses the llama.cpp library to generate text completions
+     with large language models. It supports ingesting images for captioning.
+
+     At the moment only CLIP-based models are supported.
+
+     For settable parameters and their explanations, see HasLlamaCppInferenceProperties and
+     HasLlamaCppModelProperties, and refer to the llama.cpp documentation of
+     `server.cpp <https://github.com/ggerganov/llama.cpp/tree/7d5e8777ae1d21af99d4f95be10db4870720da91/examples/server>`__
+     for more information.
+
+     If the parameters are not set, the annotator defaults to the parameters provided by
+     the model.
+
+     This annotator expects a column of annotator type AnnotationImage for the image and
+     Annotation for the caption. Note that the image bytes in the image annotation need to
+     be raw image bytes without preprocessing. The helper function
+     ImageAssembler.loadImagesAsBytes loads the image bytes from a directory.
+
+     Pretrained models can be loaded with ``pretrained`` of the companion object:
+
+     .. code-block:: python
+
+         autoGGUFVisionModel = AutoGGUFVisionModel.pretrained() \\
+             .setInputCols(["image", "document"]) \\
+             .setOutputCol("completions")
+
+     The default model is ``"llava_v1.5_7b_Q4_0_gguf"``, if no name is provided.
+
+     For available pretrained models please see the `Models Hub <https://sparknlp.org/models>`__.
+
+     For extended examples of usage, see the
+     `AutoGGUFVisionModelTest <https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFVisionModelTest.scala>`__
+     and the
+     `example notebook <https://github.com/JohnSnowLabs/spark-nlp/tree/master/examples/python/llama.cpp/llama.cpp_in_Spark_NLP_AutoGGUFVisionModel.ipynb>`__.
+
+     ====================== ======================
+     Input Annotation types Output Annotation type
+     ====================== ======================
+     ``IMAGE, DOCUMENT``    ``DOCUMENT``
+     ====================== ======================
+
+     Parameters
+     ----------
+     nThreads
+         Set the number of threads to use during generation
+     nThreadsDraft
+         Set the number of threads to use during draft generation
+     nThreadsBatch
+         Set the number of threads to use during batch and prompt processing
+     nThreadsBatchDraft
+         Set the number of threads to use during batch and prompt processing for the draft model
+     nCtx
+         Set the size of the prompt context
+     nBatch
+         Set the logical batch size for prompt processing (must be >=32 to use BLAS)
+     nUbatch
+         Set the physical batch size for prompt processing (must be >=32 to use BLAS)
+     nDraft
+         Set the number of tokens to draft for speculative decoding
+     nChunks
+         Set the maximal number of chunks to process
+     nSequences
+         Set the number of sequences to decode
+     pSplit
+         Set the speculative decoding split probability
+     nGpuLayers
+         Set the number of layers to store in VRAM (-1: use default)
+     nGpuLayersDraft
+         Set the number of layers to store in VRAM for the draft model (-1: use default)
+     gpuSplitMode
+         Set how to split the model across GPUs
+     mainGpu
+         Set the main GPU that is used for scratch and small tensors
+     tensorSplit
+         Set how split tensors should be distributed across GPUs
+     grpAttnN
+         Set the group-attention factor
+     grpAttnW
+         Set the group-attention width
+     ropeFreqBase
+         Set the RoPE base frequency, used by NTK-aware scaling
+     ropeFreqScale
+         Set the RoPE frequency scaling factor, which expands the context by a factor of 1/N
+     yarnExtFactor
+         Set the YaRN extrapolation mix factor
+     yarnAttnFactor
+         Set the YaRN scale sqrt(t) or attention magnitude
+     yarnBetaFast
+         Set the YaRN low correction dim or beta
+     yarnBetaSlow
+         Set the YaRN high correction dim or alpha
+     yarnOrigCtx
+         Set the YaRN original context size of the model
+     defragmentationThreshold
+         Set the KV cache defragmentation threshold
+     numaStrategy
+         Set optimization strategies that help on some NUMA systems (if available)
+     ropeScalingType
+         Set the RoPE frequency scaling method, defaults to linear unless specified by the model
+     poolingType
+         Set the pooling type for embeddings, use model default if unspecified
+     modelDraft
+         Set the draft model for speculative decoding
+     modelAlias
+         Set a model alias
+     lookupCacheStaticFilePath
+         Set the path to a static lookup cache to use for lookup decoding (not updated by generation)
+     lookupCacheDynamicFilePath
+         Set the path to a dynamic lookup cache to use for lookup decoding (updated by generation)
+     embedding
+         Whether to load the model with embedding support
+     flashAttention
+         Whether to enable Flash Attention
+     inputPrefixBos
+         Whether to add a prefix BOS to user inputs, preceding the `--in-prefix` string
+     useMmap
+         Whether to memory-map the model (faster load but may increase pageouts if not using mlock)
+     useMlock
+         Whether to force the system to keep the model in RAM rather than swapping or compressing
+     noKvOffload
+         Whether to disable KV offload
+     systemPrompt
+         Set a system prompt to use
+     chatTemplate
+         The chat template to use
+     inputPrefix
+         Set the prompt to start generation with
+     inputSuffix
+         Set a suffix for infilling
+     cachePrompt
+         Whether to remember the prompt to avoid reprocessing it
+     nPredict
+         Set the number of tokens to predict
+     topK
+         Set top-k sampling
+     topP
+         Set top-p sampling
+     minP
+         Set min-p sampling
+     tfsZ
+         Set tail free sampling, parameter z
+     typicalP
+         Set locally typical sampling, parameter p
+     temperature
+         Set the temperature
+     dynatempRange
+         Set the dynamic temperature range
+     dynatempExponent
+         Set the dynamic temperature exponent
+     repeatLastN
+         Set the last n tokens to consider for penalties
+     repeatPenalty
+         Set the penalty for repeated sequences of tokens
+     frequencyPenalty
+         Set the repetition alpha frequency penalty
+     presencePenalty
+         Set the repetition alpha presence penalty
+     miroStat
+         Set MiroStat sampling strategies
+     mirostatTau
+         Set the MiroStat target entropy, parameter tau
+     mirostatEta
+         Set the MiroStat learning rate, parameter eta
+     penalizeNl
+         Whether to penalize newline tokens
+     nKeep
+         Set the number of tokens to keep from the initial prompt
+     seed
+         Set the RNG seed
+     nProbs
+         Set the number of top token probabilities to output, if greater than 0
+     minKeep
+         Set the minimum number of tokens the samplers should return (0 = disabled)
+     grammar
+         Set a BNF-like grammar to constrain generations
+     penaltyPrompt
+         Override which part of the prompt is penalized for repetition
+     ignoreEos
+         Whether to ignore the end-of-stream token and continue generating (implies --logit-bias 2-inf)
+     disableTokenIds
+         Set the token ids to disable in the completion
+     stopStrings
+         Set strings upon seeing which token generation is stopped
+     samplers
+         Set which samplers to use for token generation in the given order
+     useChatTemplate
+         Set whether or not generate should apply a chat template
+
+     Notes
+     -----
+     To use GPU inference with this annotator, make sure to use the Spark NLP GPU package and
+     set the number of GPU layers with the `setNGpuLayers` method.
+
+     When using larger models, we recommend adjusting GPU usage with `setNCtx` and
+     `setNGpuLayers` according to your hardware to avoid out-of-memory errors.
+
+     Examples
+     --------
+     >>> import sparknlp
+     >>> from sparknlp.base import *
+     >>> from sparknlp.annotator import *
+     >>> from pyspark.ml import Pipeline
+     >>> from pyspark.sql.functions import lit
+     >>> documentAssembler = DocumentAssembler() \\
+     ...     .setInputCol("caption") \\
+     ...     .setOutputCol("caption_document")
+     >>> imageAssembler = ImageAssembler() \\
+     ...     .setInputCol("image") \\
+     ...     .setOutputCol("image_assembler")
+     >>> imagesPath = "src/test/resources/image/"
+     >>> data = ImageAssembler \\
+     ...     .loadImagesAsBytes(spark, imagesPath) \\
+     ...     .withColumn("caption", lit("Caption this image."))  # Add a caption to each image.
+     >>> nPredict = 40
+     >>> model = AutoGGUFVisionModel.pretrained() \\
+     ...     .setInputCols(["caption_document", "image_assembler"]) \\
+     ...     .setOutputCol("completions") \\
+     ...     .setBatchSize(4) \\
+     ...     .setNGpuLayers(99) \\
+     ...     .setNCtx(4096) \\
+     ...     .setMinKeep(0) \\
+     ...     .setMinP(0.05) \\
+     ...     .setNPredict(nPredict) \\
+     ...     .setNProbs(0) \\
+     ...     .setPenalizeNl(False) \\
+     ...     .setRepeatLastN(256) \\
+     ...     .setRepeatPenalty(1.18) \\
+     ...     .setStopStrings(["</s>", "Llama:", "User:"]) \\
+     ...     .setTemperature(0.05) \\
+     ...     .setTfsZ(1) \\
+     ...     .setTypicalP(1) \\
+     ...     .setTopK(40) \\
+     ...     .setTopP(0.95)
+     >>> pipeline = Pipeline().setStages([documentAssembler, imageAssembler, model])
+     >>> pipeline.fit(data).transform(data) \\
+     ...     .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "completions.result") \\
+     ...     .show(truncate = False)
+     +-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+     |image_name       |result                                                                                                                                                                                        |
+     +-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+     |palace.JPEG      |[ The image depicts a large, ornate room with high ceilings and beautifully decorated walls. There are several chairs placed throughout the space, some of which have cushions]               |
+     |egyptian_cat.jpeg|[ The image features two cats lying on a pink surface, possibly a bed or sofa. One cat is positioned towards the left side of the scene and appears to be sleeping while holding]             |
+     |hippopotamus.JPEG|[ A large brown hippo is swimming in a body of water, possibly an aquarium. The hippo appears to be enjoying its time in the water and seems relaxed as it floats]                            |
+     |hen.JPEG         |[ The image features a large chicken standing next to several baby chickens. In total, there are five birds in the scene: one adult and four young ones. They appear to be gathered together] |
+     |ostrich.JPEG     |[ The image features a large, long-necked bird standing in the grass. It appears to be an ostrich or similar species with its head held high and looking around. In addition to]              |
+     |junco.JPEG       |[ A small bird with a black head and white chest is standing on the snow. It appears to be looking at something, possibly food or another animal in its vicinity. The scene takes place out]  |
+     |bluetick.jpg     |[ A dog with a red collar is sitting on the floor, looking at something. The dog appears to be staring into the distance or focusing its attention on an object in front of it.]              |
+     |chihuahua.jpg    |[ A small brown dog wearing a sweater is sitting on the floor. The dog appears to be looking at something, possibly its owner or another animal in the room. It seems comfortable and relaxed]|
+     |tractor.JPEG     |[ A man is sitting in the driver's seat of a green tractor, which has yellow wheels and tires. The tractor appears to be parked on top of an empty field with]                                |
+     |ox.JPEG          |[ A large bull with horns is standing in a grassy field.]                                                                                                                                     |
+     +-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+     """
+
+     name = "AutoGGUFVisionModel"
+     inputAnnotatorTypes = [AnnotatorType.IMAGE, AnnotatorType.DOCUMENT]
+     outputAnnotatorType = AnnotatorType.DOCUMENT
+
+     @keyword_only
+     def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.AutoGGUFVisionModel", java_model=None):
+         super(AutoGGUFVisionModel, self).__init__(
+             classname=classname,
+             java_model=java_model
+         )
+
+         self._setDefault(
+             useChatTemplate=True,
+             nCtx=4096,
+             nBatch=512,
+             embedding=False,
+             nPredict=100
+         )
+
+     @staticmethod
+     def loadSavedModel(modelPath, mmprojPath, spark_session):
+         """Loads a locally saved model.
+
+         Parameters
+         ----------
+         modelPath : str
+             Path to the model file
+         mmprojPath : str
+             Path to the multimodal projector (mmproj) file
+         spark_session : pyspark.sql.SparkSession
+             The current SparkSession
+
+         Returns
+         -------
+         AutoGGUFVisionModel
+             The restored model
+         """
+         from sparknlp.internal import _AutoGGUFVisionLoader
+         jModel = _AutoGGUFVisionLoader(modelPath, mmprojPath, spark_session._jsparkSession)._java_obj
+         return AutoGGUFVisionModel(java_model=jModel)
+
+     @staticmethod
+     def pretrained(name="llava_v1.5_7b_Q4_0_gguf", lang="en", remote_loc=None):
+         """Downloads and loads a pretrained model.
+
+         Parameters
+         ----------
+         name : str, optional
+             Name of the pretrained model, by default "llava_v1.5_7b_Q4_0_gguf"
+         lang : str, optional
+             Language of the pretrained model, by default "en"
+         remote_loc : str, optional
+             Optional remote address of the resource, by default None. Will use
+             Spark NLP's repositories otherwise.
+
+         Returns
+         -------
+         AutoGGUFVisionModel
+             The restored model
+         """
+         from sparknlp.pretrained import ResourceDownloader
+         return ResourceDownloader.downloadModel(AutoGGUFVisionModel, name, lang, remote_loc)
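Besides ``pretrained``, the new ``loadSavedModel(modelPath, mmprojPath, spark_session)`` entry point above takes the GGUF weights and the multimodal projector as separate files. A minimal sketch of that path, assuming a Spark session from ``sparknlp.start()``; the file and directory paths are illustrative only, not part of the diff:

    from pyspark.ml import Pipeline
    from pyspark.sql.functions import lit
    from sparknlp.base import DocumentAssembler, ImageAssembler
    from sparknlp.annotator import AutoGGUFVisionModel

    # Hypothetical local files: GGUF weights plus the CLIP projector (mmproj).
    model = AutoGGUFVisionModel.loadSavedModel(
        "/models/llava-v1.5-7b-Q4_0.gguf",
        "/models/llava-v1.5-7b-mmproj-f16.gguf",
        spark,
    ) \
        .setInputCols(["caption_document", "image_assembler"]) \
        .setOutputCol("completions")

    documentAssembler = DocumentAssembler() \
        .setInputCol("caption") \
        .setOutputCol("caption_document")
    imageAssembler = ImageAssembler() \
        .setInputCol("image") \
        .setOutputCol("image_assembler")

    # loadImagesAsBytes yields the raw image bytes the annotator expects.
    data = ImageAssembler.loadImagesAsBytes(spark, "/data/images/") \
        .withColumn("caption", lit("Caption this image."))

    pipeline = Pipeline().setStages([documentAssembler, imageAssembler, model])
    pipeline.fit(data).transform(data).select("completions.result").show(truncate=False)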
@@ -0,0 +1,357 @@
+ # Copyright 2017-2022 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Contains classes for the CoHereTransformer."""
+
+ from sparknlp.common import *
+
+
+ class CoHereTransformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
+     """Cohere: Command-R Transformer
+
+     C4AI Command-R is a research release of a 35-billion-parameter, highly performant
+     generative model. Command-R is a large language model with open weights, optimized
+     for a variety of use cases including reasoning, summarization, and question
+     answering. It supports multilingual generation, evaluated in 10 languages, and has
+     highly performant RAG capabilities.
+
+     Pretrained models can be loaded with :meth:`.pretrained` of the companion
+     object:
+
+     >>> CoHere = CoHereTransformer.pretrained() \\
+     ...     .setInputCols(["document"]) \\
+     ...     .setOutputCol("generation")
+
+     The default model is ``"c4ai_command_r_v01_int4"``, if no name is provided. For available
+     pretrained models please see the `Models Hub
+     <https://sparknlp.org/models?q=CoHere>`__.
+
+     ====================== ======================
+     Input Annotation types Output Annotation type
+     ====================== ======================
+     ``DOCUMENT``           ``DOCUMENT``
+     ====================== ======================
+
+     Parameters
+     ----------
+     configProtoBytes
+         ConfigProto from tensorflow, serialized into byte array.
+     minOutputLength
+         Minimum length of the sequence to be generated, by default 0
+     maxOutputLength
+         Maximum length of output text, by default 20
+     doSample
+         Whether or not to use sampling; use greedy decoding otherwise, by default False
+     temperature
+         The value used to modulate the next token probabilities, by default 0.6
+     topK
+         The number of highest probability vocabulary tokens to keep for
+         top-k-filtering, by default -1 (disabled)
+     topP
+         Top cumulative probability for vocabulary tokens, by default 0.9
+
+         If set to float < 1, only the most probable tokens with probabilities
+         that add up to ``topP`` or higher are kept for generation.
+     repetitionPenalty
+         The parameter for repetition penalty; 1.0 means no penalty, by default 1.0
+     noRepeatNgramSize
+         If set to int > 0, all ngrams of that size can only occur once, by default 3
+     ignoreTokenIds
+         A list of token ids which are ignored in the decoder's output, by default []
+
+     Notes
+     -----
+     This is a very computationally expensive module, especially on larger
+     sequences. The use of an accelerator such as GPU is recommended.
+
+     References
+     ----------
+     - `Cohere <https://cohere.for.ai/>`__
+
+     Examples
+     --------
+     >>> import sparknlp
+     >>> from sparknlp.base import *
+     >>> from sparknlp.annotator import *
+     >>> from pyspark.ml import Pipeline
+     >>> documentAssembler = DocumentAssembler() \\
+     ...     .setInputCol("text") \\
+     ...     .setOutputCol("documents")
+     >>> CoHere = CoHereTransformer.pretrained("c4ai_command_r_v01_int4", "en") \\
+     ...     .setInputCols(["documents"]) \\
+     ...     .setMaxOutputLength(60) \\
+     ...     .setOutputCol("generation")
+     >>> pipeline = Pipeline().setStages([documentAssembler, CoHere])
+     >>> data = spark.createDataFrame([
+     ...     (
+     ...         1,
+     ...         "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+     ...     )
+     ... ]).toDF("id", "text")
+     >>> result = pipeline.fit(data).transform(data)
+     >>> result.select("generation.result").show(truncate=False)
+     +------------------------------------------------------------------------------------------------------------------------------------+
+     |result                                                                                                                              |
+     +------------------------------------------------------------------------------------------------------------------------------------+
+     |[Hello! I'm doing well, thank you for asking! I'm excited to help you with whatever questions you have today. How can I assist you?]|
+     +------------------------------------------------------------------------------------------------------------------------------------+
+     """
+
+     name = "CoHereTransformer"
+
+     inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
+
+     outputAnnotatorType = AnnotatorType.DOCUMENT
+
+     configProtoBytes = Param(Params._dummy(),
+                              "configProtoBytes",
+                              "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                              TypeConverters.toListInt)
+
+     minOutputLength = Param(Params._dummy(), "minOutputLength",
+                             "Minimum length of the sequence to be generated",
+                             typeConverter=TypeConverters.toInt)
+
+     maxOutputLength = Param(Params._dummy(), "maxOutputLength",
+                             "Maximum length of output text",
+                             typeConverter=TypeConverters.toInt)
+
+     doSample = Param(Params._dummy(), "doSample",
+                      "Whether or not to use sampling; use greedy decoding otherwise",
+                      typeConverter=TypeConverters.toBoolean)
+
+     temperature = Param(Params._dummy(), "temperature",
+                         "The value used to modulate the next token probabilities",
+                         typeConverter=TypeConverters.toFloat)
+
+     topK = Param(Params._dummy(), "topK",
+                  "The number of highest probability vocabulary tokens to keep for top-k-filtering",
+                  typeConverter=TypeConverters.toInt)
+
+     topP = Param(Params._dummy(), "topP",
+                  "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
+                  typeConverter=TypeConverters.toFloat)
+
+     repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
+                               "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
+                               typeConverter=TypeConverters.toFloat)
+
+     noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
+                               "If set to int > 0, all ngrams of that size can only occur once",
+                               typeConverter=TypeConverters.toInt)
+
+     ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
+                            "A list of token ids which are ignored in the decoder's output",
+                            typeConverter=TypeConverters.toListInt)
+
+     beamSize = Param(Params._dummy(), "beamSize",
+                      "The number of beams to use for beam search",
+                      typeConverter=TypeConverters.toInt)
+
+     stopTokenIds = Param(Params._dummy(), "stopTokenIds",
+                          "A list of token ids which are considered as stop tokens in the decoder's output",
+                          typeConverter=TypeConverters.toListInt)
+
+     def setIgnoreTokenIds(self, value):
+         """Sets a list of token ids which are ignored in the decoder's output.
+
+         Parameters
+         ----------
+         value : List[int]
+             The words to be filtered out
+         """
+         return self._set(ignoreTokenIds=value)
+
+     def setConfigProtoBytes(self, b):
+         """Sets configProto from tensorflow, serialized into byte array.
+
+         Parameters
+         ----------
+         b : List[int]
+             ConfigProto from tensorflow, serialized into byte array
+         """
+         return self._set(configProtoBytes=b)
+
+     def setMinOutputLength(self, value):
+         """Sets the minimum length of the sequence to be generated.
+
+         Parameters
+         ----------
+         value : int
+             Minimum length of the sequence to be generated
+         """
+         return self._set(minOutputLength=value)
+
+     def setMaxOutputLength(self, value):
+         """Sets the maximum length of output text.
+
+         Parameters
+         ----------
+         value : int
+             Maximum length of output text
+         """
+         return self._set(maxOutputLength=value)
+
+     def setDoSample(self, value):
+         """Sets whether or not to use sampling; use greedy decoding otherwise.
+
+         Parameters
+         ----------
+         value : bool
+             Whether or not to use sampling; use greedy decoding otherwise
+         """
+         return self._set(doSample=value)
+
+     def setTemperature(self, value):
+         """Sets the value used to modulate the next token probabilities.
+
+         Parameters
+         ----------
+         value : float
+             The value used to modulate the next token probabilities
+         """
+         return self._set(temperature=value)
+
+     def setTopK(self, value):
+         """Sets the number of highest probability vocabulary tokens to keep for
+         top-k-filtering.
+
+         Parameters
+         ----------
+         value : int
+             Number of highest probability vocabulary tokens to keep
+         """
+         return self._set(topK=value)
+
+     def setTopP(self, value):
+         """Sets the top cumulative probability for vocabulary tokens.
+
+         If set to float < 1, only the most probable tokens with probabilities
+         that add up to ``topP`` or higher are kept for generation.
+
+         Parameters
+         ----------
+         value : float
+             Cumulative probability for vocabulary tokens
+         """
+         return self._set(topP=value)
+
+     def setRepetitionPenalty(self, value):
+         """Sets the parameter for repetition penalty. 1.0 means no penalty.
+
+         Parameters
+         ----------
+         value : float
+             The repetition penalty
+
+         References
+         ----------
+         See `Ctrl: A Conditional Transformer Language Model For Controllable
+         Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
+         """
+         return self._set(repetitionPenalty=value)
+
+     def setNoRepeatNgramSize(self, value):
+         """Sets the size of n-grams that can only occur once.
+
+         If set to int > 0, all ngrams of that size can only occur once.
+
+         Parameters
+         ----------
+         value : int
+             N-gram size that can only occur once
+         """
+         return self._set(noRepeatNgramSize=value)
+
+     def setBeamSize(self, value):
+         """Sets the number of beams to use for beam search.
+
+         Parameters
+         ----------
+         value : int
+             The number of beams to use for beam search
+         """
+         return self._set(beamSize=value)
+
+     def setStopTokenIds(self, value):
+         """Sets a list of token ids which are considered as stop tokens in the decoder's output.
+
+         Parameters
+         ----------
+         value : List[int]
+             The tokens to be considered as stop tokens
+         """
+         return self._set(stopTokenIds=value)
+
+     @keyword_only
+     def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.CoHereTransformer", java_model=None):
+         super(CoHereTransformer, self).__init__(
+             classname=classname,
+             java_model=java_model
+         )
+         self._setDefault(
+             minOutputLength=0,
+             maxOutputLength=20,
+             doSample=False,
+             temperature=0.6,
+             topK=-1,
+             topP=0.9,
+             repetitionPenalty=1.0,
+             noRepeatNgramSize=3,
+             ignoreTokenIds=[],
+             batchSize=1,
+             beamSize=1,
+             stopTokenIds=[128001]
+         )
+
+     @staticmethod
+     def loadSavedModel(folder, spark_session, use_openvino=False):
+         """Loads a locally saved model.
+
+         Parameters
+         ----------
+         folder : str
+             Folder of the saved model
+         spark_session : pyspark.sql.SparkSession
+             The current SparkSession
+         use_openvino : bool, optional
+             Whether to load the model with the OpenVINO engine, by default False
+
+         Returns
+         -------
+         CoHereTransformer
+             The restored model
+         """
+         from sparknlp.internal import _CoHereLoader
+         jModel = _CoHereLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
+         return CoHereTransformer(java_model=jModel)
+
+     @staticmethod
+     def pretrained(name="c4ai_command_r_v01_int4", lang="en", remote_loc=None):
+         """Downloads and loads a pretrained model.
+
+         Parameters
+         ----------
+         name : str, optional
+             Name of the pretrained model, by default "c4ai_command_r_v01_int4"
+         lang : str, optional
+             Language of the pretrained model, by default "en"
+         remote_loc : str, optional
+             Optional remote address of the resource, by default None. Will use
+             Spark NLP's repositories otherwise.
+
+         Returns
+         -------
+         CoHereTransformer
+             The restored model
+         """
+         from sparknlp.pretrained import ResourceDownloader
+         return ResourceDownloader.downloadModel(CoHereTransformer, name, lang, remote_loc)
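The same local-loading pattern applies here: ``loadSavedModel(folder, spark_session, use_openvino)`` restores an exported model from a directory. A sketch using only the setters defined above; the export folder is hypothetical, and ``spark`` is assumed to come from ``sparknlp.start()``:

    from pyspark.ml import Pipeline
    from sparknlp.base import DocumentAssembler
    from sparknlp.annotator import CoHereTransformer

    # Hypothetical folder holding an exported Command-R checkpoint.
    cohere = CoHereTransformer.loadSavedModel("/models/command_r_int4", spark) \
        .setInputCols(["documents"]) \
        .setOutputCol("generation") \
        .setMaxOutputLength(60) \
        .setDoSample(True) \
        .setTemperature(0.6) \
        .setTopP(0.9)

    pipeline = Pipeline().setStages([
        DocumentAssembler().setInputCol("text").setOutputCol("documents"),
        cohere,
    ])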