spark-nlp 5.5.3__py2.py3-none-any.whl → 6.0.1rc1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of spark-nlp might be problematic. Click here for more details.
- {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/METADATA +20 -11
- {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/RECORD +36 -17
- {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +2 -2
- sparknlp/annotator/classifier_dl/__init__.py +4 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +2 -2
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/cv/__init__.py +9 -1
- sparknlp/annotator/cv/gemma3_for_multimodal.py +351 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +356 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +432 -0
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +10 -6
- sparknlp/annotator/seq2seq/__init__.py +3 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +8 -503
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +333 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +4 -4
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/base/image_assembler.py +58 -0
- sparknlp/common/properties.py +605 -96
- sparknlp/internal/__init__.py +127 -2
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +111 -0
- sparknlp/reader/sparknlp_reader.py +222 -14
- spark_nlp-5.5.3.dist-info/.uuid +0 -1
- {spark_nlp-5.5.3.dist-info → spark_nlp-6.0.1rc1.dist-info}/top_level.txt +0 -0
|
@@ -38,7 +38,7 @@ class LLAMA3Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
|
|
|
38
38
|
... .setOutputCol("generation")
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
The default model is ``"
|
|
41
|
+
The default model is ``"llama_3_7b_instruct_hf_int4"``, if no name is provided. For available
|
|
42
42
|
pretrained models please see the `Models Hub
|
|
43
43
|
<https://sparknlp.org/models?q=llama3>`__.
|
|
44
44
|
|
|
@@ -108,7 +108,7 @@ class LLAMA3Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
|
|
|
108
108
|
>>> documentAssembler = DocumentAssembler() \\
|
|
109
109
|
... .setInputCol("text") \\
|
|
110
110
|
... .setOutputCol("documents")
|
|
111
|
-
>>> llama3 = LLAMA3Transformer.pretrained("
|
|
111
|
+
>>> llama3 = LLAMA3Transformer.pretrained("llama_3_7b_instruct_hf_int4") \\
|
|
112
112
|
... .setInputCols(["documents"]) \\
|
|
113
113
|
... .setMaxOutputLength(60) \\
|
|
114
114
|
... .setOutputCol("generation")
|
|
@@ -359,13 +359,13 @@ class LLAMA3Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
|
|
|
359
359
|
return LLAMA3Transformer(java_model=jModel)
|
|
360
360
|
|
|
361
361
|
@staticmethod
|
|
362
|
-
def pretrained(name="
|
|
362
|
+
def pretrained(name="llama_3_7b_instruct_hf_int4", lang="en", remote_loc=None):
|
|
363
363
|
"""Downloads and loads a pretrained model.
|
|
364
364
|
|
|
365
365
|
Parameters
|
|
366
366
|
----------
|
|
367
367
|
name : str, optional
|
|
368
|
-
Name of the pretrained model, by default "
|
|
368
|
+
Name of the pretrained model, by default "llama_3_7b_instruct_hf_int4"
|
|
369
369
|
lang : str, optional
|
|
370
370
|
Language of the pretrained model, by default "en"
|
|
371
371
|
remote_loc : str, optional
|
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
# Copyright 2017-2022 John Snow Labs
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Contains classes for the OLMoTransformer."""
|
|
15
|
+
|
|
16
|
+
from sparknlp.common import *
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OLMoTransformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
|
|
20
|
+
"""OLMo: Open Language Models
|
|
21
|
+
|
|
22
|
+
OLMo is a series of Open Language Models designed to enable the science of language models.
|
|
23
|
+
The OLMo models are trained on the Dolma dataset. We release all code, checkpoints, logs
|
|
24
|
+
(coming soon), and details involved in training these models.
|
|
25
|
+
|
|
26
|
+
Pretrained models can be loaded with :meth:`.pretrained` of the companion
|
|
27
|
+
object:
|
|
28
|
+
|
|
29
|
+
>>> olmo = OLMoTransformer.pretrained() \\
|
|
30
|
+
... .setInputCols(["document"]) \\
|
|
31
|
+
... .setOutputCol("generation")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
The default model is ``"olmo_1b_int4"``, if no name is provided. For available
|
|
35
|
+
pretrained models please see the `Models Hub
|
|
36
|
+
<https://sparknlp.org/models?q=olmo>`__.
|
|
37
|
+
|
|
38
|
+
====================== ======================
|
|
39
|
+
Input Annotation types Output Annotation type
|
|
40
|
+
====================== ======================
|
|
41
|
+
``DOCUMENT`` ``DOCUMENT``
|
|
42
|
+
====================== ======================
|
|
43
|
+
|
|
44
|
+
Parameters
|
|
45
|
+
----------
|
|
46
|
+
configProtoBytes
|
|
47
|
+
ConfigProto from tensorflow, serialized into byte array.
|
|
48
|
+
minOutputLength
|
|
49
|
+
Minimum length of the sequence to be generated, by default 0
|
|
50
|
+
maxOutputLength
|
|
51
|
+
Maximum length of output text, by default 20
|
|
52
|
+
doSample
|
|
53
|
+
Whether or not to use sampling; use greedy decoding otherwise, by default False
|
|
54
|
+
temperature
|
|
55
|
+
The value used to modulate the next token probabilities, by default 0.6
|
|
56
|
+
topK
|
|
57
|
+
The number of highest probability vocabulary tokens to keep for
|
|
58
|
+
top-k-filtering, by default 50
|
|
59
|
+
topP
|
|
60
|
+
Top cumulative probability for vocabulary tokens, by default 0.9
|
|
61
|
+
|
|
62
|
+
If set to float < 1, only the most probable tokens with probabilities
|
|
63
|
+
that add up to ``topP`` or higher are kept for generation.
|
|
64
|
+
repetitionPenalty
|
|
65
|
+
The parameter for repetition penalty, 1.0 means no penalty, by default
|
|
66
|
+
1.0
|
|
67
|
+
noRepeatNgramSize
|
|
68
|
+
If set to int > 0, all ngrams of that size can only occur once, by
|
|
69
|
+
default 0
|
|
70
|
+
ignoreTokenIds
|
|
71
|
+
A list of token ids which are ignored in the decoder's output, by
|
|
72
|
+
default []
|
|
73
|
+
|
|
74
|
+
Notes
|
|
75
|
+
-----
|
|
76
|
+
This is a very computationally expensive module especially on larger
|
|
77
|
+
sequences. The use of an accelerator such as a GPU is recommended.
|
|
78
|
+
|
|
79
|
+
References
|
|
80
|
+
----------
|
|
81
|
+
- `OLMo Project Page.
|
|
82
|
+
<https://allenai.org/olmo>`__
|
|
83
|
+
- `OLMo GitHub Repository.
|
|
84
|
+
<https://github.com/allenai/OLMo>`__
|
|
85
|
+
- `OLMo: Accelerating the Science of Language Models
|
|
86
|
+
<https://arxiv.org/pdf/2402.00838.pdf>`__
|
|
87
|
+
|
|
88
|
+
**Paper Abstract:**
|
|
89
|
+
|
|
90
|
+
*Language models (LMs) have become ubiquitous in both NLP research and in commercial product offerings.
|
|
91
|
+
As their commercial importance has surged, the most powerful models have become closed off, gated behind
|
|
92
|
+
proprietary interfaces, with important details of their training data, architectures, and development
|
|
93
|
+
undisclosed. Given the importance of these details in scientifically studying these models, including
|
|
94
|
+
their biases and potential risks, we believe it is essential for the research community to have access
|
|
95
|
+
to powerful, truly open LMs. To this end, this technical report details the first release of OLMo,
|
|
96
|
+
a state-of-the-art, truly Open Language Model and its framework to build and study the science of
|
|
97
|
+
language modeling. Unlike most prior efforts that have only released model weights and inference code,
|
|
98
|
+
we release OLMo and the whole framework, including training data and training and evaluation code.
|
|
99
|
+
We hope this release will empower and strengthen the open research community and inspire a new wave
|
|
100
|
+
of innovation.*
|
|
101
|
+
|
|
102
|
+
Examples
|
|
103
|
+
--------
|
|
104
|
+
>>> import sparknlp
|
|
105
|
+
>>> from sparknlp.base import *
|
|
106
|
+
>>> from sparknlp.annotator import *
|
|
107
|
+
>>> from pyspark.ml import Pipeline
|
|
108
|
+
>>> documentAssembler = DocumentAssembler() \\
|
|
109
|
+
... .setInputCol("text") \\
|
|
110
|
+
... .setOutputCol("documents")
|
|
111
|
+
>>> olmo = OLMoTransformer.pretrained("olmo_1b_int4") \\
|
|
112
|
+
... .setInputCols(["documents"]) \\
|
|
113
|
+
... .setMaxOutputLength(50) \\
|
|
114
|
+
... .setOutputCol("generation")
|
|
115
|
+
>>> pipeline = Pipeline().setStages([documentAssembler, olmo])
|
|
116
|
+
>>> data = spark.createDataFrame([["My name is Leonardo."]]).toDF("text")
|
|
117
|
+
>>> result = pipeline.fit(data).transform(data)
|
|
118
|
+
>>> result.select("generation.result").show(truncate=False)
|
|
119
|
+
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
|
120
|
+
|result |
|
|
121
|
+
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
|
122
|
+
|[My name is Leonardo . I am a student of the University of California, Berkeley. I am interested in the field of Artificial Intelligence and its applications in the real world. I have a strong |
|
|
123
|
+
| passion for learning and am always looking for ways to improve my knowledge and skills] |
|
|
124
|
+
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
|
125
|
+
"""
|
|
126
|
+
|
|
127
|
+
name = "OLMoTransformer"
|
|
128
|
+
|
|
129
|
+
inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
|
|
130
|
+
|
|
131
|
+
outputAnnotatorType = AnnotatorType.DOCUMENT
|
|
132
|
+
|
|
133
|
+
configProtoBytes = Param(Params._dummy(), "configProtoBytes",
|
|
134
|
+
"ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
|
|
135
|
+
TypeConverters.toListInt)
|
|
136
|
+
|
|
137
|
+
minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
|
|
138
|
+
typeConverter=TypeConverters.toInt)
|
|
139
|
+
|
|
140
|
+
maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
|
|
141
|
+
typeConverter=TypeConverters.toInt)
|
|
142
|
+
|
|
143
|
+
doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
|
|
144
|
+
typeConverter=TypeConverters.toBoolean)
|
|
145
|
+
|
|
146
|
+
temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities",
|
|
147
|
+
typeConverter=TypeConverters.toFloat)
|
|
148
|
+
|
|
149
|
+
topK = Param(Params._dummy(), "topK",
|
|
150
|
+
"The number of highest probability vocabulary tokens to keep for top-k-filtering",
|
|
151
|
+
typeConverter=TypeConverters.toInt)
|
|
152
|
+
|
|
153
|
+
topP = Param(Params._dummy(), "topP",
|
|
154
|
+
"If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
|
|
155
|
+
typeConverter=TypeConverters.toFloat)
|
|
156
|
+
|
|
157
|
+
repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
|
|
158
|
+
"The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
|
|
159
|
+
typeConverter=TypeConverters.toFloat)
|
|
160
|
+
|
|
161
|
+
noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
|
|
162
|
+
"If set to int > 0, all ngrams of that size can only occur once",
|
|
163
|
+
typeConverter=TypeConverters.toInt)
|
|
164
|
+
|
|
165
|
+
ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
|
|
166
|
+
"A list of token ids which are ignored in the decoder's output",
|
|
167
|
+
typeConverter=TypeConverters.toListInt)
|
|
168
|
+
|
|
169
|
+
def setIgnoreTokenIds(self, value):
|
|
170
|
+
"""A list of token ids which are ignored in the decoder's output.
|
|
171
|
+
|
|
172
|
+
Parameters
|
|
173
|
+
----------
|
|
174
|
+
value : List[int]
|
|
175
|
+
The token ids to be ignored in the decoder's output
|
|
176
|
+
"""
|
|
177
|
+
return self._set(ignoreTokenIds=value)
|
|
178
|
+
|
|
179
|
+
def setConfigProtoBytes(self, b):
|
|
180
|
+
"""Sets configProto from tensorflow, serialized into byte array.
|
|
181
|
+
|
|
182
|
+
Parameters
|
|
183
|
+
----------
|
|
184
|
+
b : List[int]
|
|
185
|
+
ConfigProto from tensorflow, serialized into byte array
|
|
186
|
+
"""
|
|
187
|
+
return self._set(configProtoBytes=b)
|
|
188
|
+
|
|
189
|
+
def setMinOutputLength(self, value):
|
|
190
|
+
"""Sets minimum length of the sequence to be generated.
|
|
191
|
+
|
|
192
|
+
Parameters
|
|
193
|
+
----------
|
|
194
|
+
value : int
|
|
195
|
+
Minimum length of the sequence to be generated
|
|
196
|
+
"""
|
|
197
|
+
return self._set(minOutputLength=value)
|
|
198
|
+
|
|
199
|
+
def setMaxOutputLength(self, value):
|
|
200
|
+
"""Sets maximum length of output text.
|
|
201
|
+
|
|
202
|
+
Parameters
|
|
203
|
+
----------
|
|
204
|
+
value : int
|
|
205
|
+
Maximum length of output text
|
|
206
|
+
"""
|
|
207
|
+
return self._set(maxOutputLength=value)
|
|
208
|
+
|
|
209
|
+
def setDoSample(self, value):
|
|
210
|
+
"""Sets whether or not to use sampling, use greedy decoding otherwise.
|
|
211
|
+
|
|
212
|
+
Parameters
|
|
213
|
+
----------
|
|
214
|
+
value : bool
|
|
215
|
+
Whether or not to use sampling; use greedy decoding otherwise
|
|
216
|
+
"""
|
|
217
|
+
return self._set(doSample=value)
|
|
218
|
+
|
|
219
|
+
def setTemperature(self, value):
|
|
220
|
+
"""Sets the value used to modulate the next token probabilities.
|
|
221
|
+
|
|
222
|
+
Parameters
|
|
223
|
+
----------
|
|
224
|
+
value : float
|
|
225
|
+
The value used to modulate the next token probabilities
|
|
226
|
+
"""
|
|
227
|
+
return self._set(temperature=value)
|
|
228
|
+
|
|
229
|
+
def setTopK(self, value):
|
|
230
|
+
"""Sets the number of highest probability vocabulary tokens to keep for
|
|
231
|
+
top-k-filtering.
|
|
232
|
+
|
|
233
|
+
Parameters
|
|
234
|
+
----------
|
|
235
|
+
value : int
|
|
236
|
+
Number of highest probability vocabulary tokens to keep
|
|
237
|
+
"""
|
|
238
|
+
return self._set(topK=value)
|
|
239
|
+
|
|
240
|
+
def setTopP(self, value):
|
|
241
|
+
"""Sets the top cumulative probability for vocabulary tokens.
|
|
242
|
+
|
|
243
|
+
If set to float < 1, only the most probable tokens with probabilities
|
|
244
|
+
that add up to ``topP`` or higher are kept for generation.
|
|
245
|
+
|
|
246
|
+
Parameters
|
|
247
|
+
----------
|
|
248
|
+
value : float
|
|
249
|
+
Cumulative probability for vocabulary tokens
|
|
250
|
+
"""
|
|
251
|
+
return self._set(topP=value)
|
|
252
|
+
|
|
253
|
+
def setRepetitionPenalty(self, value):
|
|
254
|
+
"""Sets the parameter for repetition penalty. 1.0 means no penalty.
|
|
255
|
+
|
|
256
|
+
Parameters
|
|
257
|
+
----------
|
|
258
|
+
value : float
|
|
259
|
+
The repetition penalty
|
|
260
|
+
|
|
261
|
+
References
|
|
262
|
+
----------
|
|
263
|
+
See `Ctrl: A Conditional Transformer Language Model For Controllable
|
|
264
|
+
Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
|
|
265
|
+
"""
|
|
266
|
+
return self._set(repetitionPenalty=value)
|
|
267
|
+
|
|
268
|
+
def setNoRepeatNgramSize(self, value):
|
|
269
|
+
"""Sets size of n-grams that can only occur once.
|
|
270
|
+
|
|
271
|
+
If set to int > 0, all ngrams of that size can only occur once.
|
|
272
|
+
|
|
273
|
+
Parameters
|
|
274
|
+
----------
|
|
275
|
+
value : int
|
|
276
|
+
Size of the n-grams that can only occur once
|
|
277
|
+
"""
|
|
278
|
+
return self._set(noRepeatNgramSize=value)
|
|
279
|
+
|
|
280
|
+
@keyword_only
|
|
281
|
+
def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.OLMoTransformer", java_model=None):
|
|
282
|
+
super(OLMoTransformer, self).__init__(classname=classname, java_model=java_model)
|
|
283
|
+
self._setDefault(minOutputLength=0, maxOutputLength=20, doSample=False, temperature=0.6, topK=50, topP=0.9,
|
|
284
|
+
repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1)
|
|
285
|
+
|
|
286
|
+
@staticmethod
|
|
287
|
+
def loadSavedModel(folder, spark_session):
|
|
288
|
+
"""Loads a locally saved model.
|
|
289
|
+
|
|
290
|
+
Parameters
|
|
291
|
+
----------
|
|
292
|
+
folder : str
|
|
293
|
+
Folder of the saved model
|
|
294
|
+
spark_session : pyspark.sql.SparkSession
|
|
295
|
+
The current SparkSession
|
|
296
|
+
|
|
297
|
+
Returns
|
|
298
|
+
-------
|
|
299
|
+
OLMoTransformer
|
|
300
|
+
The restored model
|
|
301
|
+
"""
|
|
302
|
+
from sparknlp.internal import _OLMoLoader
|
|
303
|
+
jModel = _OLMoLoader(folder, spark_session._jsparkSession)._java_obj
|
|
304
|
+
return OLMoTransformer(java_model=jModel)
|
|
305
|
+
|
|
306
|
+
@staticmethod
|
|
307
|
+
def pretrained(name="olmo_1b_int4", lang="en", remote_loc=None):
|
|
308
|
+
"""Downloads and loads a pretrained model.
|
|
309
|
+
|
|
310
|
+
Parameters
|
|
311
|
+
----------
|
|
312
|
+
name : str, optional
|
|
313
|
+
Name of the pretrained model, by default "olmo_1b_int4"
|
|
314
|
+
lang : str, optional
|
|
315
|
+
Language of the pretrained model, by default "en"
|
|
316
|
+
remote_loc : str, optional
|
|
317
|
+
Optional remote address of the resource, by default None. Will use
|
|
318
|
+
Spark NLP's repositories otherwise.
|
|
319
|
+
|
|
320
|
+
Returns
|
|
321
|
+
-------
|
|
322
|
+
OLMoTransformer
|
|
323
|
+
The restored model
|
|
324
|
+
"""
|
|
325
|
+
from sparknlp.pretrained import ResourceDownloader
|
|
326
|
+
return ResourceDownloader.downloadModel(OLMoTransformer, name, lang, remote_loc)
|
sparknlp/base/image_assembler.py
CHANGED
|
@@ -15,6 +15,8 @@
|
|
|
15
15
|
|
|
16
16
|
from pyspark import keyword_only
|
|
17
17
|
from pyspark.ml.param import TypeConverters, Params, Param
|
|
18
|
+
from pyspark.sql import SparkSession, DataFrame
|
|
19
|
+
from pyspark.sql.functions import regexp_replace, col
|
|
18
20
|
|
|
19
21
|
from sparknlp.common import AnnotatorType
|
|
20
22
|
from sparknlp.internal import AnnotatorTransformer
|
|
@@ -112,3 +114,59 @@ class ImageAssembler(AnnotatorTransformer):
|
|
|
112
114
|
Name of an optional input text column
|
|
113
115
|
"""
|
|
114
116
|
return self._set(inputCol=value)
|
|
117
|
+
|
|
118
|
+
@classmethod
|
|
119
|
+
def loadImagesAsBytes(cls, spark: SparkSession, path: str):
|
|
120
|
+
"""
|
|
121
|
+
Loads images from a given path and returns them as raw bytes, instead of the default
|
|
122
|
+
OpenCV-compatible format. Supported image types include JPEG, PNG, GIF, and BMP.
|
|
123
|
+
|
|
124
|
+
Multimodal inference with llama.cpp requires raw bytes as input.
|
|
125
|
+
|
|
126
|
+
Parameters
|
|
127
|
+
----------
|
|
128
|
+
spark : SparkSession
|
|
129
|
+
The active SparkSession.
|
|
130
|
+
path : str
|
|
131
|
+
The path to the images. Supported image types are JPEG, PNG, GIF, and BMP.
|
|
132
|
+
|
|
133
|
+
Returns
|
|
134
|
+
-------
|
|
135
|
+
DataFrame
|
|
136
|
+
A DataFrame containing the images as raw bytes along with their metadata.
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
# Replace the path separator in the `origin` field and `path` column, so that they match
|
|
140
|
+
def replace_path(column_name: str):
|
|
141
|
+
return regexp_replace(col(column_name), ":///", ":/")
|
|
142
|
+
|
|
143
|
+
# Load the images as metadata with the default Spark image format
|
|
144
|
+
data = (
|
|
145
|
+
spark.read.format("image")
|
|
146
|
+
.option("dropInvalid", True)
|
|
147
|
+
.load(path)
|
|
148
|
+
.withColumn(
|
|
149
|
+
"image", col("image").withField("origin", replace_path("image.origin"))
|
|
150
|
+
)
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# Load the images as raw binary files
|
|
154
|
+
image_bytes = (
|
|
155
|
+
spark.read.format("binaryFile")
|
|
156
|
+
.option("pathGlobFilter", "*.{jpeg,jpg,png,gif,bmp,JPEG,JPG,PNG,GIF,BMP}")
|
|
157
|
+
.option("dropInvalid", True)
|
|
158
|
+
.load(path)
|
|
159
|
+
.withColumn("path", replace_path("path"))
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
# Join the two datasets on the file path
|
|
163
|
+
df_joined = data.join(
|
|
164
|
+
image_bytes, data["image.origin"] == image_bytes["path"], "inner"
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
# Replace the `data` field of the `image` column with raw bytes
|
|
168
|
+
df_image_replaced = df_joined.withColumn(
|
|
169
|
+
"image", df_joined["image"].withField("data", df_joined["content"])
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
return df_image_replaced
|