spark-nlp 4.2.6__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- com/johnsnowlabs/ml/__init__.py +0 -0
- com/johnsnowlabs/ml/ai/__init__.py +10 -0
- spark_nlp-6.2.1.dist-info/METADATA +362 -0
- spark_nlp-6.2.1.dist-info/RECORD +292 -0
- {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
- sparknlp/__init__.py +81 -28
- sparknlp/annotation.py +3 -2
- sparknlp/annotator/__init__.py +6 -0
- sparknlp/annotator/audio/__init__.py +2 -0
- sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
- sparknlp/annotator/audio/wav2vec2_for_ctc.py +14 -14
- sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
- sparknlp/{base → annotator}/chunk2_doc.py +4 -7
- sparknlp/annotator/chunker.py +1 -2
- sparknlp/annotator/classifier_dl/__init__.py +17 -0
- sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/albert_for_question_answering.py +3 -15
- sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +4 -18
- sparknlp/annotator/classifier_dl/albert_for_token_classification.py +3 -17
- sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/bert_for_question_answering.py +6 -20
- sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +3 -17
- sparknlp/annotator/classifier_dl/bert_for_token_classification.py +3 -17
- sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
- sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
- sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +5 -19
- sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +5 -19
- sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
- sparknlp/annotator/classifier_dl/classifier_dl.py +4 -4
- sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +3 -17
- sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +4 -19
- sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +5 -21
- sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +3 -17
- sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +4 -18
- sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +3 -17
- sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
- sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +3 -17
- sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +4 -18
- sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +3 -17
- sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
- sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
- sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
- sparknlp/annotator/classifier_dl/multi_classifier_dl.py +3 -3
- sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
- sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +3 -17
- sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +4 -18
- sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +1 -1
- sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/sentiment_dl.py +4 -4
- sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +2 -2
- sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
- sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +3 -17
- sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +4 -18
- sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +6 -20
- sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
- sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +4 -18
- sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +3 -17
- sparknlp/annotator/cleaners/__init__.py +15 -0
- sparknlp/annotator/cleaners/cleaner.py +202 -0
- sparknlp/annotator/cleaners/extractor.py +191 -0
- sparknlp/annotator/coref/spanbert_coref.py +4 -18
- sparknlp/annotator/cv/__init__.py +15 -0
- sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
- sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
- sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
- sparknlp/annotator/cv/florence2_transformer.py +180 -0
- sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
- sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
- sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
- sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
- sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
- sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
- sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
- sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
- sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
- sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
- sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
- sparknlp/annotator/cv/vit_for_image_classification.py +36 -4
- sparknlp/annotator/dataframe_optimizer.py +216 -0
- sparknlp/annotator/date2_chunk.py +88 -0
- sparknlp/annotator/dependency/dependency_parser.py +2 -3
- sparknlp/annotator/dependency/typed_dependency_parser.py +3 -4
- sparknlp/annotator/document_character_text_splitter.py +228 -0
- sparknlp/annotator/document_normalizer.py +37 -1
- sparknlp/annotator/document_token_splitter.py +175 -0
- sparknlp/annotator/document_token_splitter_test.py +85 -0
- sparknlp/annotator/embeddings/__init__.py +11 -0
- sparknlp/annotator/embeddings/albert_embeddings.py +4 -18
- sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
- sparknlp/annotator/embeddings/bert_embeddings.py +9 -22
- sparknlp/annotator/embeddings/bert_sentence_embeddings.py +12 -24
- sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
- sparknlp/annotator/embeddings/camembert_embeddings.py +4 -20
- sparknlp/annotator/embeddings/chunk_embeddings.py +1 -2
- sparknlp/annotator/embeddings/deberta_embeddings.py +2 -16
- sparknlp/annotator/embeddings/distil_bert_embeddings.py +5 -19
- sparknlp/annotator/embeddings/doc2vec.py +7 -1
- sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
- sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
- sparknlp/annotator/embeddings/elmo_embeddings.py +2 -2
- sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
- sparknlp/annotator/embeddings/longformer_embeddings.py +3 -17
- sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
- sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
- sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
- sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
- sparknlp/annotator/embeddings/roberta_embeddings.py +9 -21
- sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +7 -21
- sparknlp/annotator/embeddings/sentence_embeddings.py +2 -3
- sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
- sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
- sparknlp/annotator/embeddings/universal_sentence_encoder.py +3 -3
- sparknlp/annotator/embeddings/word2vec.py +7 -1
- sparknlp/annotator/embeddings/word_embeddings.py +4 -5
- sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +9 -21
- sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +7 -21
- sparknlp/annotator/embeddings/xlnet_embeddings.py +4 -18
- sparknlp/annotator/er/entity_ruler.py +37 -23
- sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +2 -3
- sparknlp/annotator/ld_dl/language_detector_dl.py +2 -2
- sparknlp/annotator/lemmatizer.py +3 -4
- sparknlp/annotator/matcher/date_matcher.py +35 -3
- sparknlp/annotator/matcher/multi_date_matcher.py +1 -2
- sparknlp/annotator/matcher/regex_matcher.py +3 -3
- sparknlp/annotator/matcher/text_matcher.py +2 -3
- sparknlp/annotator/n_gram_generator.py +1 -2
- sparknlp/annotator/ner/__init__.py +3 -1
- sparknlp/annotator/ner/ner_converter.py +18 -0
- sparknlp/annotator/ner/ner_crf.py +4 -5
- sparknlp/annotator/ner/ner_dl.py +10 -5
- sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
- sparknlp/annotator/ner/ner_overwriter.py +2 -2
- sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
- sparknlp/annotator/normalizer.py +2 -2
- sparknlp/annotator/openai/__init__.py +16 -0
- sparknlp/annotator/openai/openai_completion.py +349 -0
- sparknlp/annotator/openai/openai_embeddings.py +106 -0
- sparknlp/annotator/pos/perceptron.py +6 -7
- sparknlp/annotator/sentence/sentence_detector.py +2 -2
- sparknlp/annotator/sentence/sentence_detector_dl.py +3 -3
- sparknlp/annotator/sentiment/sentiment_detector.py +4 -5
- sparknlp/annotator/sentiment/vivekn_sentiment.py +4 -5
- sparknlp/annotator/seq2seq/__init__.py +17 -0
- sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
- sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
- sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
- sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
- sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
- sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
- sparknlp/annotator/seq2seq/gpt2_transformer.py +1 -1
- sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
- sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
- sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
- sparknlp/annotator/seq2seq/marian_transformer.py +124 -3
- sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
- sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
- sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
- sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
- sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
- sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
- sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
- sparknlp/annotator/seq2seq/t5_transformer.py +54 -4
- sparknlp/annotator/similarity/__init__.py +0 -0
- sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
- sparknlp/annotator/spell_check/context_spell_checker.py +116 -17
- sparknlp/annotator/spell_check/norvig_sweeting.py +3 -6
- sparknlp/annotator/spell_check/symmetric_delete.py +1 -1
- sparknlp/annotator/stemmer.py +2 -3
- sparknlp/annotator/stop_words_cleaner.py +3 -4
- sparknlp/annotator/tf_ner_dl_graph_builder.py +1 -1
- sparknlp/annotator/token/__init__.py +0 -1
- sparknlp/annotator/token/recursive_tokenizer.py +2 -3
- sparknlp/annotator/token/tokenizer.py +2 -3
- sparknlp/annotator/ws/word_segmenter.py +35 -10
- sparknlp/base/__init__.py +2 -3
- sparknlp/base/doc2_chunk.py +0 -3
- sparknlp/base/document_assembler.py +5 -5
- sparknlp/base/embeddings_finisher.py +14 -2
- sparknlp/base/finisher.py +15 -4
- sparknlp/base/gguf_ranking_finisher.py +234 -0
- sparknlp/base/image_assembler.py +69 -0
- sparknlp/base/light_pipeline.py +53 -21
- sparknlp/base/multi_document_assembler.py +9 -13
- sparknlp/base/prompt_assembler.py +207 -0
- sparknlp/base/token_assembler.py +1 -2
- sparknlp/common/__init__.py +2 -0
- sparknlp/common/annotator_type.py +1 -0
- sparknlp/common/completion_post_processing.py +37 -0
- sparknlp/common/match_strategy.py +33 -0
- sparknlp/common/properties.py +914 -9
- sparknlp/internal/__init__.py +841 -116
- sparknlp/internal/annotator_java_ml.py +1 -1
- sparknlp/internal/annotator_transformer.py +3 -0
- sparknlp/logging/comet.py +2 -2
- sparknlp/partition/__init__.py +16 -0
- sparknlp/partition/partition.py +244 -0
- sparknlp/partition/partition_properties.py +902 -0
- sparknlp/partition/partition_transformer.py +200 -0
- sparknlp/pretrained/pretrained_pipeline.py +1 -1
- sparknlp/pretrained/resource_downloader.py +126 -2
- sparknlp/reader/__init__.py +15 -0
- sparknlp/reader/enums.py +19 -0
- sparknlp/reader/pdf_to_text.py +190 -0
- sparknlp/reader/reader2doc.py +124 -0
- sparknlp/reader/reader2image.py +136 -0
- sparknlp/reader/reader2table.py +44 -0
- sparknlp/reader/reader_assembler.py +159 -0
- sparknlp/reader/sparknlp_reader.py +461 -0
- sparknlp/training/__init__.py +1 -0
- sparknlp/training/conll.py +8 -2
- sparknlp/training/spacy_to_annotation.py +57 -0
- sparknlp/util.py +26 -0
- spark_nlp-4.2.6.dist-info/METADATA +0 -1256
- spark_nlp-4.2.6.dist-info/RECORD +0 -196
- {spark_nlp-4.2.6.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
- /sparknlp/annotator/{token/token2_chunk.py → token2_chunk.py} +0 -0
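
Both wheels expose the package version at runtime, so after upgrading it is easy to confirm which release is actually on the classpath (a minimal sketch; the printed values are illustrative):

```python
import sparknlp

# start() creates (or reuses) a SparkSession with the Spark NLP jar attached.
spark = sparknlp.start()

print("Spark NLP version:", sparknlp.version())  # e.g. 6.2.1 after the upgrade
print("Apache Spark version:", spark.version)
```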
sparknlp/annotator/embeddings/bert_sentence_embeddings.py

@@ -17,11 +17,12 @@ from sparknlp.common import *


class BertSentenceEmbeddings(AnnotatorModel,
-                             HasEmbeddingsProperties,
-                             HasCaseSensitiveProperties,
-                             HasStorageRef,
-                             HasBatchedAnnotate,
-                             HasEngine):
+                             HasEmbeddingsProperties,
+                             HasCaseSensitiveProperties,
+                             HasStorageRef,
+                             HasBatchedAnnotate,
+                             HasEngine,
+                             HasMaxSentenceLengthLimit):
    """Sentence-level embeddings using BERT. BERT (Bidirectional Encoder
    Representations from Transformers) provides dense vector representations for
    natural language by using a deep, pre-trained neural network with the

@@ -38,10 +39,10 @@ class BertSentenceEmbeddings(AnnotatorModel,
    The default model is ``"sent_small_bert_L2_768"``, if no name is provided.

    For available pretrained models please see the
-    `Models Hub <https://
+    `Models Hub <https://sparknlp.org/models?task=Embeddings>`__.

    For extended examples of usage, see the
-    `
+    `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20BERT%20Sentence.ipynb>`__.

    ====================== =======================
    Input Annotation types Output Annotation type

@@ -133,11 +134,6 @@ class BertSentenceEmbeddings(AnnotatorModel,

    outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS

-    maxSentenceLength = Param(Params._dummy(),
-                              "maxSentenceLength",
-                              "Max sentence length to process",
-                              typeConverter=TypeConverters.toInt)
-
    isLong = Param(Params._dummy(),
                   "isLong",
                   "Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int.",

@@ -158,16 +154,6 @@ class BertSentenceEmbeddings(AnnotatorModel,
        """
        return self._set(configProtoBytes=b)

-    def setMaxSentenceLength(self, value):
-        """Sets max sentence length to process.
-
-        Parameters
-        ----------
-        value : int
-            Max sentence length to process
-        """
-        return self._set(maxSentenceLength=value)
-
    def setIsLong(self, value):
        """Sets whether to use Long type instead of Int type for inputs buffer.

@@ -194,7 +180,7 @@ class BertSentenceEmbeddings(AnnotatorModel,
        )

    @staticmethod
-    def loadSavedModel(folder, spark_session):
+    def loadSavedModel(folder, spark_session, use_openvino=False):
        """Loads a locally saved model.

        Parameters

@@ -203,6 +189,8 @@ class BertSentenceEmbeddings(AnnotatorModel,
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession
+        use_openvino: bool
+            Use OpenVINO backend

        Returns
        -------

@@ -210,7 +198,7 @@ class BertSentenceEmbeddings(AnnotatorModel,
            The restored model
        """
        from sparknlp.internal import _BertSentenceLoader
-        jModel = _BertSentenceLoader(folder, spark_session._jsparkSession)._java_obj
+        jModel = _BertSentenceLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return BertSentenceEmbeddings(java_model=jModel)

    @staticmethod
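
The only API change in this file is the `use_openvino` flag, which `loadSavedModel` now forwards to the Java-side `_BertSentenceLoader`. A minimal sketch of loading a locally exported model with the new backend; the export folder path here is hypothetical:

```python
import sparknlp
from sparknlp.annotator import BertSentenceEmbeddings

spark = sparknlp.start()

# use_openvino defaults to False, so 4.2.6-era call sites are unaffected;
# True selects the OpenVINO backend for the restored model.
embeddings = BertSentenceEmbeddings.loadSavedModel(
    "/models/exported_bert_sentence",  # hypothetical export folder
    spark,
    use_openvino=True
).setInputCols(["sentence"]).setOutputCol("sentence_embeddings")
```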
sparknlp/annotator/embeddings/bge_embeddings.py (new file)

@@ -0,0 +1,199 @@
+# Copyright 2017-2022 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains classes for BGEEmbeddings."""
+
+from sparknlp.common import *
+
+
+class BGEEmbeddings(AnnotatorModel,
+                    HasEmbeddingsProperties,
+                    HasCaseSensitiveProperties,
+                    HasStorageRef,
+                    HasBatchedAnnotate,
+                    HasMaxSentenceLengthLimit,
+                    HasClsTokenProperties):
+    """Sentence embeddings using BGE.
+
+    BGE, or BAAI General Embeddings, a model that can map any text to a low-dimensional dense
+    vector which can be used for tasks like retrieval, classification, clustering, or semantic search.
+
+    Note that this annotator is only supported for Spark Versions 3.4 and up.
+
+    Pretrained models can be loaded with `pretrained` of the companion object:
+
+    >>> embeddings = BGEEmbeddings.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("bge_embeddings")
+
+
+    The default model is ``"bge_base"``, if no name is provided.
+
+    For available pretrained models please see the
+    `Models Hub <https://sparknlp.org/models?q=BGE>`__.
+
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``DOCUMENT``           ``SENTENCE_EMBEDDINGS``
+    ====================== ======================
+
+
+    **References**
+
+    `C-Pack: Packaged Resources To Advance General Chinese Embedding <https://arxiv.org/pdf/2309.07597>`__
+    `BGE Github Repository <https://github.com/FlagOpen/FlagEmbedding>`__
+
+    **Paper abstract**
+
+    *We introduce C-Pack, a package of resources that significantly advance the field of general
+    Chinese embeddings. C-Pack includes three critical resources.
+    1) C-MTEB is a comprehensive benchmark for Chinese text embeddings covering 6 tasks and 35 datasets.
+    2) C-MTP is a massive text embedding dataset curated from labeled and unlabeled Chinese corpora
+    for training embedding models.
+    3) C-TEM is a family of embedding models covering multiple sizes.
+    Our models outperform all prior Chinese text embeddings on C-MTEB by up to +10% upon the
+    time of the release. We also integrate and optimize the entire suite of training methods for
+    C-TEM. Along with our resources on general Chinese embedding, we release our data and models for
+    English text embeddings. The English models achieve stateof-the-art performance on the MTEB
+    benchmark; meanwhile, our released English data is 2 times larger than the Chinese data. All
+    these resources are made publicly available at https://github.com/FlagOpen/FlagEmbedding.*
+
+
+    Parameters
+    ----------
+    batchSize
+        Size of every batch , by default 8
+    dimension
+        Number of embedding dimensions, by default 768
+    caseSensitive
+        Whether to ignore case in tokens for embeddings matching, by default False
+    maxSentenceLength
+        Max sentence length to process, by default 512
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+    useCLSToken
+        Whether to use the CLS token for sentence embeddings, by default True
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("document")
+    >>> embeddings = BGEEmbeddings.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("bge_embeddings")
+    >>> embeddingsFinisher = EmbeddingsFinisher() \\
+    ...     .setInputCols(["bge_embeddings"]) \\
+    ...     .setOutputCols("finished_embeddings") \\
+    ...     .setOutputAsVector(True)
+    >>> pipeline = Pipeline().setStages([
+    ...     documentAssembler,
+    ...     embeddings,
+    ...     embeddingsFinisher
+    ... ])
+    >>> data = spark.createDataFrame([["query: how much protein should a female eat",
+    ...     "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day." + \\
+    ...     "But, as you can see from this chart, you'll need to increase that if you're expecting or training for a" + \\
+    ...     "marathon. Check out the chart below to see how much protein you should be eating each day.",
+    ... ]]).toDF("text")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
+    +--------------------------------------------------------------------------------+
+    |                                                                          result|
+    +--------------------------------------------------------------------------------+
+    |[[8.0190285E-4, -0.005974853, -0.072875895, 0.007944068, 0.026059335, -0.0080...|
+    |[[0.050514214, 0.010061974, -0.04340176, -0.020937217, 0.05170225, 0.01157857...|
+    +--------------------------------------------------------------------------------+
+    """
+
+    name = "BGEEmbeddings"
+
+    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
+
+    outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
+    configProtoBytes = Param(Params._dummy(),
+                             "configProtoBytes",
+                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                             TypeConverters.toListInt)
+
+    def setConfigProtoBytes(self, b):
+        """Sets configProto from tensorflow, serialized into byte array.
+
+        Parameters
+        ----------
+        b : List[int]
+            ConfigProto from tensorflow, serialized into byte array
+        """
+        return self._set(configProtoBytes=b)
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.BGEEmbeddings", java_model=None):
+        super(BGEEmbeddings, self).__init__(
+            classname=classname,
+            java_model=java_model
+        )
+        self._setDefault(
+            dimension=768,
+            batchSize=8,
+            maxSentenceLength=512,
+            caseSensitive=False,
+            useCLSToken=True
+        )
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+
+        Returns
+        -------
+        BGEEmbeddings
+            The restored model
+        """
+        from sparknlp.internal import _BGELoader
+        jModel = _BGELoader(folder, spark_session._jsparkSession)._java_obj
+        return BGEEmbeddings(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="bge_small_en_v1.5", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the pretrained model, by default "bge_small_en_v1.5"
+        lang : str, optional
+            Language of the pretrained model, by default "en"
+        remote_loc : str, optional
+            Optional remote address of the resource, by default None. Will use
+            Spark NLPs repositories otherwise.
+
+        Returns
+        -------
+        BGEEmbeddings
+            The restored model
+        """
+        from sparknlp.pretrained import ResourceDownloader
+        return ResourceDownloader.downloadModel(BGEEmbeddings, name, lang, remote_loc)
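
Unlike the other embedding annotators in this release, `BGEEmbeddings` also mixes in `HasClsTokenProperties`, which backs the `useCLSToken` parameter documented above. A short sketch, assuming the mixin follows the usual Param convention and exposes a `setUseCLSToken` setter:

```python
import sparknlp
from sparknlp.annotator import BGEEmbeddings

sparknlp.start()

# useCLSToken=True (the default) pools the [CLS] token as the sentence vector;
# False is assumed here to select the annotator's alternative pooling.
embeddings = BGEEmbeddings.pretrained() \
    .setInputCols(["document"]) \
    .setOutputCol("bge_embeddings") \
    .setUseCLSToken(False)
```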
sparknlp/annotator/embeddings/camembert_embeddings.py

@@ -21,7 +21,8 @@ class CamemBertEmbeddings(AnnotatorModel,
                          HasCaseSensitiveProperties,
                          HasStorageRef,
                          HasBatchedAnnotate,
-                          HasEngine):
+                          HasEngine,
+                          HasMaxSentenceLengthLimit):
    """The CamemBERT model was proposed in CamemBERT: a Tasty French Language Model by
    Louis Martin, Benjamin Muller, Pedro Javier Ortiz Suárez, Yoann Dupont, Laurent
    Romary, Éric Villemonte de la Clergerie, Djamé Seddah, and Benoît Sagot.

@@ -39,10 +40,10 @@ class CamemBertEmbeddings(AnnotatorModel,
    The default model is ``"camembert_base"``, if no name is provided.

    For available pretrained models please see the
-    `Models Hub <https://
+    `Models Hub <https://sparknlp.org/models?task=Embeddings>`__.

    For extended examples of usage, see the
-    `
+    `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/dl-ner/ner_bert.ipynb>`__
    and the
    `CamemBertEmbeddingsTestSpec <https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddingsTestSpec.scala>`__.

@@ -143,13 +144,6 @@ class CamemBertEmbeddings(AnnotatorModel,
        TypeConverters.toListInt,
    )

-    maxSentenceLength = Param(
-        Params._dummy(),
-        "maxSentenceLength",
-        "Max sentence length to process",
-        typeConverter=TypeConverters.toInt,
-    )
-
    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

@@ -160,16 +154,6 @@ class CamemBertEmbeddings(AnnotatorModel,
        """
        return self._set(configProtoBytes=b)

-    def setMaxSentenceLength(self, value):
-        """Sets max sentence length to process.
-
-        Parameters
-        ----------
-        value : int
-            Max sentence length to process
-        """
-        return self._set(maxSentenceLength=value)
-
    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.CamemBertEmbeddings", java_model=None):
        super(CamemBertEmbeddings, self).__init__(
sparknlp/annotator/embeddings/chunk_embeddings.py

@@ -21,7 +21,7 @@ class ChunkEmbeddings(AnnotatorModel):
    chunk embeddings from either Chunker, NGramGenerator, or NerConverter
    outputs.

-    For extended examples of usage, see the `
+    For extended examples of usage, see the `Examples <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/embeddings/ChunkEmbeddings.ipynb>`__.

    ========================== ======================
    Input Annotation types     Output Annotation type

@@ -147,4 +147,3 @@ class ChunkEmbeddings(AnnotatorModel):
        aggregation/pooling.
        """
        return self._set(skipOOV=value)
-
sparknlp/annotator/embeddings/deberta_embeddings.py

@@ -20,7 +20,8 @@ class DeBertaEmbeddings(AnnotatorModel,
                        HasCaseSensitiveProperties,
                        HasStorageRef,
                        HasBatchedAnnotate,
-                        HasEngine):
+                        HasEngine,
+                        HasMaxSentenceLengthLimit):
    """The DeBERTa model was proposed in DeBERTa: Decoding-enhanced BERT with
    Disentangled Attention by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu
    Chen It is based on Google’s BERT model released in 2018 and Facebook’s

@@ -141,11 +142,6 @@ class DeBertaEmbeddings(AnnotatorModel,
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

-    maxSentenceLength = Param(Params._dummy(),
-                              "maxSentenceLength",
-                              "Max sentence length to process",
-                              typeConverter=TypeConverters.toInt)
-
    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

@@ -156,16 +152,6 @@ class DeBertaEmbeddings(AnnotatorModel,
        """
        return self._set(configProtoBytes=b)

-    def setMaxSentenceLength(self, value):
-        """Sets max sentence length to process.
-
-        Parameters
-        ----------
-        value : int
-            Max sentence length to process
-        """
-        return self._set(maxSentenceLength=value)
-
    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.DeBertaEmbeddings", java_model=None):
        super(DeBertaEmbeddings, self).__init__(
sparknlp/annotator/embeddings/distil_bert_embeddings.py

@@ -21,7 +21,8 @@ class DistilBertEmbeddings(AnnotatorModel,
                           HasCaseSensitiveProperties,
                           HasStorageRef,
                           HasBatchedAnnotate,
-                           HasEngine):
+                           HasEngine,
+                           HasMaxSentenceLengthLimit):
    """DistilBERT is a small, fast, cheap and light Transformer model trained by
    distilling BERT base. It has 40% less parameters than ``bert-base-uncased``,
    runs 60% faster while preserving over 95% of BERT's performances as measured

@@ -37,10 +38,10 @@ class DistilBertEmbeddings(AnnotatorModel,

    The default model is ``"distilbert_base_cased"``, if no name is provided.
    For available pretrained models please see the
-    `Models Hub <https://
+    `Models Hub <https://sparknlp.org/models?task=Embeddings>`__.

-    For extended examples of usage, see the `
-    <https://github.com/JohnSnowLabs/spark-nlp
+    For extended examples of usage, see the `Examples
+    <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20DistilBERT.ipynb>`__.
    To see which models are compatible and how to import them see
    `Import Transformers into Spark NLP 🚀
    <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.

@@ -149,11 +150,6 @@ class DistilBertEmbeddings(AnnotatorModel,

    outputAnnotatorType = AnnotatorType.WORD_EMBEDDINGS

-    maxSentenceLength = Param(Params._dummy(),
-                              "maxSentenceLength",
-                              "Max sentence length to process",
-                              typeConverter=TypeConverters.toInt)
-
    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",

@@ -169,16 +165,6 @@ class DistilBertEmbeddings(AnnotatorModel,
        """
        return self._set(configProtoBytes=b)

-    def setMaxSentenceLength(self, value):
-        """Sets max sentence length to process.
-
-        Parameters
-        ----------
-        value : int
-            Max sentence length to process
-        """
-        return self._set(maxSentenceLength=value)
-
    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.DistilBertEmbeddings", java_model=None):
        super(DistilBertEmbeddings, self).__init__(
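
The same refactor runs through BertSentenceEmbeddings, CamemBertEmbeddings, DeBertaEmbeddings, and DistilBertEmbeddings above: each class drops its private `maxSentenceLength` Param and `setMaxSentenceLength` method in favor of the shared `HasMaxSentenceLengthLimit` mixin, so existing call sites keep working unchanged. A minimal sketch against the new class bases:

```python
import sparknlp
from sparknlp.annotator import DistilBertEmbeddings

sparknlp.start()

# setMaxSentenceLength is now inherited from HasMaxSentenceLengthLimit rather
# than defined on the class; the user-facing call is identical to 4.2.6.
embeddings = DistilBertEmbeddings.pretrained() \
    .setInputCols(["document", "token"]) \
    .setOutputCol("embeddings") \
    .setMaxSentenceLength(256)
```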
sparknlp/annotator/embeddings/doc2vec.py

@@ -31,7 +31,7 @@ class Doc2VecApproach(AnnotatorApproach, HasStorageRef, HasEnableCachingProperti

    For instantiated/pretrained models, see :class:`.Doc2VecModel`.

-    For available pretrained models please see the `Models Hub <https://
+    For available pretrained models please see the `Models Hub <https://sparknlp.org/models>`__.

    ====================== =======================
    Input Annotation types Output Annotation type

@@ -344,3 +344,9 @@ class Doc2VecModel(AnnotatorModel, HasStorageRef, HasEmbeddingsProperties):
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(Doc2VecModel, name, lang, remote_loc)

+    def getVectors(self):
+        """
+        Returns the vector representation of the words as a dataframe
+        with two fields, word and vector.
+        """
+        return self._call_java("getVectors")
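
`getVectors` is a thin wrapper that calls into the Java model and returns a Spark DataFrame. A minimal sketch, assuming the default pretrained Doc2Vec model:

```python
import sparknlp
from sparknlp.annotator import Doc2VecModel

sparknlp.start()

model = Doc2VecModel.pretrained() \
    .setInputCols(["token"]) \
    .setOutputCol("embeddings")

# A DataFrame with two fields, word and vector, per the new docstring.
model.getVectors().show(5, truncate=False)
```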
sparknlp/annotator/embeddings/e5_embeddings.py (new file)

@@ -0,0 +1,195 @@
+# Copyright 2017-2022 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains classes for E5Embeddings."""
+
+from sparknlp.common import *
+
+
+class E5Embeddings(AnnotatorModel,
+                   HasEmbeddingsProperties,
+                   HasCaseSensitiveProperties,
+                   HasStorageRef,
+                   HasBatchedAnnotate,
+                   HasMaxSentenceLengthLimit):
+    """Sentence embeddings using E5.
+
+    E5, a weakly supervised text embedding model that can generate text embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation, etc.)
+    Note that this annotator is only supported for Spark Versions 3.4 and up.
+
+    Pretrained models can be loaded with :meth:`.pretrained` of the companion
+    object:
+
+    >>> embeddings = E5Embeddings.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("e5_embeddings")
+
+
+    The default model is ``"e5_small"``, if no name is provided.
+
+    For available pretrained models please see the
+    `Models Hub <https://sparknlp.org/models?q=E5>`__.
+
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``DOCUMENT``           ``SENTENCE_EMBEDDINGS``
+    ====================== ======================
+
+    Parameters
+    ----------
+    batchSize
+        Size of every batch , by default 8
+    dimension
+        Number of embedding dimensions, by default 768
+    caseSensitive
+        Whether to ignore case in tokens for embeddings matching, by default False
+    maxSentenceLength
+        Max sentence length to process, by default 512
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+
+    References
+    ----------
+    `Text Embeddings by Weakly-Supervised Contrastive Pre-training <https://arxiv.org/pdf/2212.03533>`__
+
+    https://github.com/microsoft/unilm/tree/master/e5
+
+    **Paper abstract**
+
+    *This paper presents E5, a family of state-of-the-art text embeddings that transfer
+    well to a wide range of tasks. The model is trained in a contrastive manner with
+    weak supervision signals from our curated large-scale text pair dataset (called
+    CCPairs). E5 can be readily used as a general-purpose embedding model for any
+    tasks requiring a single-vector representation of texts such as retrieval, clustering,
+    and classification, achieving strong performance in both zero-shot and fine-tuned
+    settings. We conduct extensive evaluations on 56 datasets from the BEIR and
+    MTEB benchmarks. For zero-shot settings, E5 is the first model that outperforms
+    the strong BM25 baseline on the BEIR retrieval benchmark without using any
+    labeled data. When fine-tuned, E5 obtains the best results on the MTEB benchmark,
+    beating existing embedding models with 40× more parameters.*
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("document")
+    >>> embeddings = E5Embeddings.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("e5_embeddings")
+    >>> embeddingsFinisher = EmbeddingsFinisher() \\
+    ...     .setInputCols(["e5_embeddings"]) \\
+    ...     .setOutputCols("finished_embeddings") \\
+    ...     .setOutputAsVector(True)
+    >>> pipeline = Pipeline().setStages([
+    ...     documentAssembler,
+    ...     embeddings,
+    ...     embeddingsFinisher
+    ... ])
+    >>> data = spark.createDataFrame([["query: how much protein should a female eat",
+    ...     "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day." + \
+    ...     "But, as you can see from this chart, you'll need to increase that if you're expecting or training for a" + \
+    ...     "marathon. Check out the chart below to see how much protein you should be eating each day.",
+    ... ]]).toDF("text")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
+    +--------------------------------------------------------------------------------+
+    |                                                                          result|
+    +--------------------------------------------------------------------------------+
+    |[[8.0190285E-4, -0.005974853, -0.072875895, 0.007944068, 0.026059335, -0.0080...|
+    |[[0.050514214, 0.010061974, -0.04340176, -0.020937217, 0.05170225, 0.01157857...|
+    +--------------------------------------------------------------------------------+
+    """
+
+    name = "E5Embeddings"
+
+    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
+
+    outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
+    configProtoBytes = Param(Params._dummy(),
+                             "configProtoBytes",
+                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                             TypeConverters.toListInt)
+
+
+    def setConfigProtoBytes(self, b):
+        """Sets configProto from tensorflow, serialized into byte array.
+
+        Parameters
+        ----------
+        b : List[int]
+            ConfigProto from tensorflow, serialized into byte array
+        """
+        return self._set(configProtoBytes=b)
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.E5Embeddings", java_model=None):
+        super(E5Embeddings, self).__init__(
+            classname=classname,
+            java_model=java_model
+        )
+        self._setDefault(
+            dimension=768,
+            batchSize=8,
+            maxSentenceLength=512,
+            caseSensitive=False,
+        )
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session, use_openvino=False):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+        use_openvino : bool
+            Use OpenVINO backend
+
+        Returns
+        -------
+        E5Embeddings
+            The restored model
+        """
+        from sparknlp.internal import _E5Loader
+        jModel = _E5Loader(folder, spark_session._jsparkSession, use_openvino)._java_obj
+        return E5Embeddings(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="e5_small", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the pretrained model, by default "e5_small"
+        lang : str, optional
+            Language of the pretrained model, by default "en"
+        remote_loc : str, optional
+            Optional remote address of the resource, by default None. Will use
+            Spark NLPs repositories otherwise.
+
+        Returns
+        -------
+        E5Embeddings
+            The restored model
+        """
+        from sparknlp.pretrained import ResourceDownloader
+        return ResourceDownloader.downloadModel(E5Embeddings, name, lang, remote_loc)
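
The doctest above prefixes each input with `query:` or `passage:`; E5 was trained with these role markers, so retrieval-style inputs should keep them. A minimal standalone sketch of an equivalent input DataFrame, one text per row (`spark` is assumed to be an active session from `sparknlp.start()`):

```python
# Each row carries its E5 role prefix in the single "text" column.
data = spark.createDataFrame(
    [
        ["query: how much protein should a female eat"],
        ["passage: As a general guideline, the CDC's average requirement of "
         "protein for women ages 19 to 70 is 46 grams per day."],
    ],
).toDF("text")
data.show(truncate=60)
```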