spark-nlp 6.0.0__py2.py3-none-any.whl → 6.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release has been flagged as potentially problematic. See the registry's advisory page for this version of spark-nlp for further details.

@@ -0,0 +1,308 @@
1
+ # Copyright 2017-2024 John Snow Labs
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from sparknlp.common import *
16
+
17
class PaliGemmaForMultiModal(AnnotatorModel,
                             HasBatchedAnnotateImage,
                             HasImageFeatureProperties,
                             HasEngine,
                             HasCandidateLabelsProperties,
                             HasRescaleFactor):
    """PaliGemmaForMultiModal can load PaliGemma models for visual question answering.

    The model consists of a vision encoder, a text encoder, a text decoder and a
    model merger. The vision encoder encodes the input image, the text encoder
    encodes the input text, the model merger merges the image and text embeddings,
    and the text decoder outputs the answer.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> visualQAClassifier = PaliGemmaForMultiModal.pretrained() \
    ...     .setInputCols(["image_assembler"]) \
    ...     .setOutputCol("answer")

    The default model is ``"paligemma_3b_pt_224_int4"``, if no name is
    provided.

    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?task=Question+Answering>`__.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``IMAGE``              ``DOCUMENT``
    ====================== ======================

    Parameters
    ----------
    batchSize
        Batch size. Large values allow faster processing but require more
        memory, by default 2
    maxSentenceLength
        Max sentence length to process, by default 50

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> from pyspark.sql.functions import lit
    >>> image_df = spark.read.format("image").load(path=images_path)
    >>> test_df = image_df.withColumn("text", lit("USER: \n <image> \nDescribe this image. \nASSISTANT:\n"))
    >>> imageAssembler = ImageAssembler() \
    ...     .setInputCol("image") \
    ...     .setOutputCol("image_assembler")
    >>> visualQAClassifier = PaliGemmaForMultiModal.pretrained() \
    ...     .setInputCols("image_assembler") \
    ...     .setOutputCol("answer")
    >>> pipeline = Pipeline().setStages([
    ...     imageAssembler,
    ...     visualQAClassifier
    ... ])
    >>> result = pipeline.fit(test_df).transform(test_df)
    >>> result.select("image_assembler.origin", "answer.result").show(truncate=False)
    +--------------------------------------+--------------------------------------+
    |origin                                |result                                |
    +--------------------------------------+--------------------------------------+
    |[file:///content/images/bluetick.jpg] |[A dog is standing on a grassy field.]|
    +--------------------------------------+--------------------------------------+
    """

    name = "PaliGemmaForMultiModal"

    inputAnnotatorTypes = [AnnotatorType.IMAGE]

    outputAnnotatorType = AnnotatorType.DOCUMENT

    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
                            typeConverter=TypeConverters.toInt)

    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
                            typeConverter=TypeConverters.toInt)

    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
                     typeConverter=TypeConverters.toBoolean)

    temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities",
                        typeConverter=TypeConverters.toFloat)

    topK = Param(Params._dummy(), "topK",
                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
                 typeConverter=TypeConverters.toInt)

    topP = Param(Params._dummy(), "topP",
                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
                 typeConverter=TypeConverters.toFloat)

    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
                              typeConverter=TypeConverters.toFloat)

    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
                              "If set to int > 0, all ngrams of that size can only occur once",
                              typeConverter=TypeConverters.toInt)

    ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
                           "A list of token ids which are ignored in the decoder's output",
                           typeConverter=TypeConverters.toListInt)

    beamSize = Param(Params._dummy(), "beamSize",
                     "The Number of beams for beam search.",
                     typeConverter=TypeConverters.toInt)

    def setMaxSentenceSize(self, value):
        """Sets Maximum sentence length that the annotator will process, by
        default 50.

        Parameters
        ----------
        value : int
            Maximum sentence length that the annotator will process
        """
        # NOTE(review): no ``maxSentenceLength`` Param is declared in this class —
        # presumably it is inherited from a mixin; verify, otherwise this setter
        # will raise at call time.
        return self._set(maxSentenceLength=value)

    def setIgnoreTokenIds(self, value):
        """A list of token ids which are ignored in the decoder's output.

        Parameters
        ----------
        value : List[int]
            The words to be filtered out
        """
        return self._set(ignoreTokenIds=value)

    def setMinOutputLength(self, value):
        """Sets minimum length of the sequence to be generated.

        Parameters
        ----------
        value : int
            Minimum length of the sequence to be generated
        """
        return self._set(minOutputLength=value)

    def setMaxOutputLength(self, value):
        """Sets maximum length of output text.

        Parameters
        ----------
        value : int
            Maximum length of output text
        """
        return self._set(maxOutputLength=value)

    def setDoSample(self, value):
        """Sets whether or not to use sampling, use greedy decoding otherwise.

        Parameters
        ----------
        value : bool
            Whether or not to use sampling; use greedy decoding otherwise
        """
        return self._set(doSample=value)

    def setTemperature(self, value):
        """Sets the value used to module the next token probabilities.

        Parameters
        ----------
        value : float
            The value used to module the next token probabilities
        """
        return self._set(temperature=value)

    def setTopK(self, value):
        """Sets the number of highest probability vocabulary tokens to keep for
        top-k-filtering.

        Parameters
        ----------
        value : int
            Number of highest probability vocabulary tokens to keep
        """
        return self._set(topK=value)

    def setTopP(self, value):
        """Sets the top cumulative probability for vocabulary tokens.

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.

        Parameters
        ----------
        value : float
            Cumulative probability for vocabulary tokens
        """
        return self._set(topP=value)

    def setRepetitionPenalty(self, value):
        """Sets the parameter for repetition penalty. 1.0 means no penalty.

        Parameters
        ----------
        value : float
            The repetition penalty

        References
        ----------
        See `Ctrl: A Conditional Transformer Language Model For Controllable
        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        """
        return self._set(repetitionPenalty=value)

    def setNoRepeatNgramSize(self, value):
        """Sets size of n-grams that can only occur once.

        If set to int > 0, all ngrams of that size can only occur once.

        Parameters
        ----------
        value : int
            N-gram size can only occur once
        """
        return self._set(noRepeatNgramSize=value)

    def setBeamSize(self, value):
        """Sets the number of beam size for beam search, by default `4`.

        Parameters
        ----------
        value : int
            Number of beam size for beam search
        """
        return self._set(beamSize=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.cv.PaliGemmaForMultiModal",
                 java_model=None):
        super(PaliGemmaForMultiModal, self).__init__(
            classname=classname,
            java_model=java_model
        )
        # Defaults mirror the Scala-side annotator; float-typed params use
        # float literals so the declared TypeConverters are a no-op.
        self._setDefault(
            batchSize=2,
            minOutputLength=0,
            maxOutputLength=200,
            doSample=False,
            temperature=1.0,
            topK=50,
            topP=1.0,
            repetitionPenalty=1.0,
            noRepeatNgramSize=0,
            ignoreTokenIds=[],
            beamSize=1,
        )

    @staticmethod
    def loadSavedModel(folder, spark_session, use_openvino=False):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession
        use_openvino : bool, optional
            Whether to load the model with the OpenVINO engine instead of the
            default one, by default False

        Returns
        -------
        PaliGemmaForMultiModal
            The restored model
        """
        from sparknlp.internal import _PaliGemmaForMultiModalLoader
        jModel = _PaliGemmaForMultiModalLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return PaliGemmaForMultiModal(java_model=jModel)

    @staticmethod
    def pretrained(name="paligemma_3b_pt_224_int4", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default
            "paligemma_3b_pt_224_int4"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLPs repositories otherwise.

        Returns
        -------
        PaliGemmaForMultiModal
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(PaliGemmaForMultiModal, name, lang, remote_loc)
+