spark-nlp 5.5.0__py2.py3-none-any.whl → 5.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of spark-nlp might be problematic.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: spark-nlp
- Version: 5.5.0
+ Version: 5.5.1
  Summary: John Snow Labs Spark NLP is a natural language processing library built on top of Apache Spark ML. It provides simple, performant & accurate NLP annotations for machine learning pipelines, that scale easily in a distributed environment.
  Home-page: https://github.com/JohnSnowLabs/spark-nlp
  Author: John Snow Labs
@@ -95,7 +95,7 @@ $ java -version
  $ conda create -n sparknlp python=3.7 -y
  $ conda activate sparknlp
  # spark-nlp by default is based on pyspark 3.x
- $ pip install spark-nlp==5.5.0 pyspark==3.3.1
+ $ pip install spark-nlp==5.5.1 pyspark==3.3.1
  ```
 
  In Python console or Jupyter `Python3` kernel:
@@ -161,7 +161,7 @@ For a quick example of using pipelines and models take a look at our official [d
 
  ### Apache Spark Support
 
- Spark NLP *5.5.0* has been built on top of Apache Spark 3.4 while fully supports Apache Spark 3.0.x, 3.1.x, 3.2.x, 3.3.x, 3.4.x, and 3.5.x
+ Spark NLP *5.5.1* has been built on top of Apache Spark 3.4 while fully supports Apache Spark 3.0.x, 3.1.x, 3.2.x, 3.3.x, 3.4.x, and 3.5.x
 
  | Spark NLP | Apache Spark 3.5.x | Apache Spark 3.4.x | Apache Spark 3.3.x | Apache Spark 3.2.x | Apache Spark 3.1.x | Apache Spark 3.0.x | Apache Spark 2.4.x | Apache Spark 2.3.x |
  |-----------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
@@ -189,7 +189,7 @@ Find out more about 4.x `SparkNLP` versions in our official [documentation](http
 
  ### Databricks Support
 
- Spark NLP 5.5.0 has been tested and is compatible with the following runtimes:
+ Spark NLP 5.5.1 has been tested and is compatible with the following runtimes:
 
  | **CPU** | **GPU** |
  |--------------------|--------------------|
@@ -206,7 +206,7 @@ We are compatible with older runtimes. For a full list check databricks support
 
  ### EMR Support
 
- Spark NLP 5.5.0 has been tested and is compatible with the following EMR releases:
+ Spark NLP 5.5.1 has been tested and is compatible with the following EMR releases:
 
  | **EMR Release** |
  |--------------------|
@@ -237,7 +237,7 @@ deployed to Maven central. To add any of our packages as a dependency in your ap
  from our official documentation.
 
  If you are interested, there is a simple SBT project for Spark NLP to guide you on how to use it in your
- projects [Spark NLP SBT S5.5.0r](https://github.com/maziyarpanahi/spark-nlp-starter)
+ projects [Spark NLP SBT S5.5.1r](https://github.com/maziyarpanahi/spark-nlp-starter)
 
  ### Python
 
@@ -282,7 +282,7 @@ In Spark NLP we can define S3 locations to:
 
  Please check [these instructions](https://sparknlp.org/docs/en/install#s3-integration) from our official documentation.
 
- ## Document5.5.0
+ ## Document5.5.1
 
  ### Examples
 
@@ -315,7 +315,7 @@ the Spark NLP library:
  keywords = {Spark, Natural language processing, Deep learning, Tensorflow, Cluster},
  abstract = {Spark NLP is a Natural Language Processing (NLP) library built on top of Apache Spark ML. It provides simple, performant & accurate NLP annotations for machine learning pipelines that can scale easily in a distributed environment. Spark NLP comes with 1100+ pretrained pipelines and models in more than 192+ languages. It supports nearly all the NLP tasks and modules that can be used seamlessly in a cluster. Downloaded more than 2.7 million times and experiencing 9x growth since January 2020, Spark NLP is used by 54% of healthcare organizations as the world’s most widely used NLP library in the enterprise.}
  }
- }5.5.0
+ }5.5.1
  ```
 
  ## Community support
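The README hunks above state that Spark NLP 5.5.1 is built on Apache Spark 3.4 and fully supports Apache Spark 3.0.x through 3.5.x. A minimal pre-flight check encoding only that statement might look like the sketch below; the helper name `spark_supported` is hypothetical, and in a real session the version string would typically come from `pyspark.__version__`.

```python
def spark_supported(spark_version: str) -> bool:
    """Return True if a Spark version string falls in the 3.0.x to 3.5.x range.

    Hypothetical helper: it encodes the support statement from the README and
    only inspects the major and minor components (e.g. "3.3.1" -> 3, 3).
    """
    parts = spark_version.split(".")
    if len(parts) < 2:
        raise ValueError(f"expected at least major.minor, got {spark_version!r}")
    major, minor = int(parts[0]), int(parts[1])
    return major == 3 and 0 <= minor <= 5


# The pip command in the README pins pyspark==3.3.1.
print(spark_supported("3.3.1"))
```

The check deliberately ignores patch versions, since the README states support per minor line (3.0.x, 3.1.x, and so on).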
@@ -3,7 +3,7 @@ com/johnsnowlabs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
  com/johnsnowlabs/ml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  com/johnsnowlabs/ml/ai/__init__.py,sha256=YQiK2M7U4d8y5irPy_HB8ae0mSpqS9583MH44pnKJXc,295
  com/johnsnowlabs/nlp/__init__.py,sha256=DPIVXtONO5xXyOk-HB0-sNiHAcco17NN13zPS_6Uw8c,294
- sparknlp/__init__.py,sha256=AKjfuQ0s3OwAuC0XQj4FjGRISLl5RY8bO3WdLpZpQaA,13638
+ sparknlp/__init__.py,sha256=26U34YGDYCBbYHr3rpxvy71snUYuoFhl5XfUvxkPv7M,13638
  sparknlp/annotation.py,sha256=I5zOxG5vV2RfPZfqN9enT1i4mo6oBcn3Lrzs37QiOiA,5635
  sparknlp/annotation_audio.py,sha256=iRV_InSVhgvAwSRe9NTbUH9v6OGvTM-FPCpSAKVu0mE,1917
  sparknlp/annotation_image.py,sha256=xhCe8Ko-77XqWVuuYHFrjKqF6zPd8Z-RY_rmZXNwCXU,2547
@@ -30,12 +30,13 @@ sparknlp/annotator/audio/__init__.py,sha256=dXjtvi5c0aTZFq1Q_JciUd1uFTBVSJoUdcq0
  sparknlp/annotator/audio/hubert_for_ctc.py,sha256=76PfwPZZvOHU5kfDqLueCFbmqa4W8pMNRGoCvOqjsEA,7859
  sparknlp/annotator/audio/wav2vec2_for_ctc.py,sha256=K78P1U6vA4O1UufsLYzy0H7arsKNmwPcIV7kzDFsA5Q,6210
  sparknlp/annotator/audio/whisper_for_ctc.py,sha256=uII51umuohqwnAW0Q7VdxEFyr_j5LMnfpcRlf8TbetA,9800
- sparknlp/annotator/classifier_dl/__init__.py,sha256=74WL0W2zBfx6v0tJpx1DcRfZENs86n9JxizDDBEE41A,3934
+ sparknlp/annotator/classifier_dl/__init__.py,sha256=4v2_3kSWQFFBc_KzaJ0gEC6ANVJpy5tsHa6CJGc4nCw,4005
  sparknlp/annotator/classifier_dl/albert_for_question_answering.py,sha256=LG2dL6Fky1T35yXTUZBfIihIIGnkRFQ7ECQ3HRXXEG8,6517
  sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py,sha256=kWx7f9pcKE2qw319gn8FN0Md5dX38gbmfeoY9gWCLNk,7842
  sparknlp/annotator/classifier_dl/albert_for_token_classification.py,sha256=5rdsjWnsAVmtP-idU7ATKJ8lkH2rtlKZLnpi4Mq27eI,6839
  sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py,sha256=_TgV6EiIOiD_djA3fxfoz-o37mzMeKbn6iL2kZ6GzO0,8366
  sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py,sha256=yqQeDdpLbNOKuSZejZjSAjT8ydYyxsTVf2aFDgSSDfc,8767
+ sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py,sha256=Ew_NGBj7F5ApgK3SyQh2HIfjD7ZTqTs0LZEQxjwoyto,5936
  sparknlp/annotator/classifier_dl/bert_for_question_answering.py,sha256=2euY_RAdMPA4IHJXZAd5MkQojFOtFNhB_hSc1iVQ5DQ,6433
  sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py,sha256=AzD3RQcRuQc0DDTbL6vGiacTtHlZnbAqksNvRQq7EQE,7800
  sparknlp/annotator/classifier_dl/bert_for_token_classification.py,sha256=uJXoDLPfPWiRmKqtw_3lLBvneIirj87S2JWwfd33zq8,6668
@@ -173,7 +174,7 @@ sparknlp/annotator/token/regex_tokenizer.py,sha256=FG2HvFwMb1G_4grfyIQaeBpaAgKv_
  sparknlp/annotator/token/tokenizer.py,sha256=Me3P3wogUKUJ7O7_2wLdPzF00vKpp_sHuiztpGWRVpU,19939
  sparknlp/annotator/ws/__init__.py,sha256=-l8bnl8Z6lGXWOBdRIBZ6958fzTHt4o87QhhLHIFF8A,693
  sparknlp/annotator/ws/word_segmenter.py,sha256=rrbshwn5wzXIHpCCDji6ZcsmiARpuA82_p_6TgNHfRc,16365
- sparknlp/base/__init__.py,sha256=iC4b4NzTDsqEu3eE_f5QL8JD-uoig4Pn2h1ZMyPHR6Q,1266
+ sparknlp/base/__init__.py,sha256=fCL-kReIavZceUa1OC99pSRH7MsXzqGB8BXgzVS_f7s,1311
  sparknlp/base/audio_assembler.py,sha256=HKa9mXvmuMUrjTihUZkppGj-WJjcUrm2BGapNuPifyI,3320
  sparknlp/base/doc2_chunk.py,sha256=TyvbdJNkVo9favHlOEoH5JwKbjpk5ZVJ75p8Cilp9jM,6551
  sparknlp/base/document_assembler.py,sha256=zl-SXWMTR3B0EZ8z6SWYchCwEo-61FhU6u7dHUKDIOg,6697
@@ -185,6 +186,7 @@ sparknlp/base/has_recursive_transform.py,sha256=UkGNgo4LMsjQC-Coeefg4bJcg7FoPcPi
  sparknlp/base/image_assembler.py,sha256=HytRoYJTLMqGtvScHoFnp6CasG9IVNYAHYiT2_rrmeE,3719
  sparknlp/base/light_pipeline.py,sha256=Jk2DLpT4PLHCANlOo_WetTdPba_5lYs3ywiyY3lM-PE,16577
  sparknlp/base/multi_document_assembler.py,sha256=4htET1fRAeOB6zhsNXsBq5rKZvn-LGD4vrFRjPZeqow,7070
+ sparknlp/base/prompt_assembler.py,sha256=ysU4Vbmnuv2UBHK0JBkYrxiZiJ7_GTcVMip1-QRmheI,11570
  sparknlp/base/recursive_pipeline.py,sha256=V9rTnu8KMwgjoceykN9pF1mKGtOkkuiC_n9v8dE3LDk,4279
  sparknlp/base/table_assembler.py,sha256=Kxu3R2fY6JgCxEc07ibsMsjip6dgcPDHLiWAZ8gC_d8,5102
  sparknlp/base/token_assembler.py,sha256=qiHry07L7mVCqeHSH6hHxLygv1AsfZIE4jy1L75L3Do,5075
@@ -200,7 +202,7 @@ sparknlp/common/read_as.py,sha256=imxPGwV7jr4Li_acbo0OAHHRGCBbYv-akzEGaBWEfcY,12
  sparknlp/common/recursive_annotator_approach.py,sha256=vqugBw22cE3Ff7PIpRlnYFuOlchgL0nM26D8j-NdpqU,1449
  sparknlp/common/storage.py,sha256=D91H3p8EIjNspjqAYu6ephRpCUtdcAir4_PrAbkIQWE,4842
  sparknlp/common/utils.py,sha256=Yne6yYcwKxhOZC-U4qfYoDhWUP_6BIaAjI5X_P_df1E,1306
- sparknlp/internal/__init__.py,sha256=nK-9lncAVRXmyP8ATbiMwRnLJVe4IEd_r5Z3gEqDK3g,33672
+ sparknlp/internal/__init__.py,sha256=ljEf4IUraCdKU7gKFxNwFxlX-FHcnkG6sqs1MxEhLSQ,33967
  sparknlp/internal/annotator_java_ml.py,sha256=UGPoThG0rGXUOXGSQnDzEDW81Mu1s5RPF29v7DFyE3c,1187
  sparknlp/internal/annotator_transformer.py,sha256=fXmc2IWXGybqZpbEU9obmbdBYPc798y42zvSB4tqV9U,1448
  sparknlp/internal/extended_java_wrapper.py,sha256=hwP0133-hDiDf5sBF-P3MtUsuuDj1PpQbtGZQIRwzfk,2240
@@ -242,8 +244,8 @@ sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py,sha256=R4yHFN3
  sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py,sha256=EoCSdcIjqQ3wv13MAuuWrKV8wyVBP0SbOEW41omHlR0,23189
  sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py,sha256=k5CQ7gKV6HZbZMB8cKLUJuZxoZWlP_DFWdZ--aIDwsc,2356
  sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py,sha256=pAxjWhjazSX8Vg0MFqJiuRVw1IbnQNSs-8Xp26L4nko,870
- spark_nlp-5.5.0.dist-info/.uuid,sha256=1f6hF51aIuv9yCvh31NU9lOpS34NE-h3a0Et7R9yR6A,36
- spark_nlp-5.5.0.dist-info/METADATA,sha256=FccpjBJS2ERU0kJM5kPn_bUo4VyX4l8tHLImemv6czo,19156
- spark_nlp-5.5.0.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
- spark_nlp-5.5.0.dist-info/top_level.txt,sha256=uuytur4pyMRw2H_txNY2ZkaucZHUs22QF8-R03ch_-E,13
- spark_nlp-5.5.0.dist-info/RECORD,,
+ spark_nlp-5.5.1.dist-info/.uuid,sha256=1f6hF51aIuv9yCvh31NU9lOpS34NE-h3a0Et7R9yR6A,36
+ spark_nlp-5.5.1.dist-info/METADATA,sha256=Y7Y0nf18tO2RfHzagHWWZpn4QRrF50d5wP3hXG1eFyw,19156
+ spark_nlp-5.5.1.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+ spark_nlp-5.5.1.dist-info/top_level.txt,sha256=uuytur4pyMRw2H_txNY2ZkaucZHUs22QF8-R03ch_-E,13
+ spark_nlp-5.5.1.dist-info/RECORD,,
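Each RECORD line has the form `path,sha256=<digest>,<size>`, where, in the wheel RECORD format, the digest is the SHA-256 of the file encoded as URL-safe base64 with the trailing `=` padding removed. This explains why `sparknlp/__init__.py` keeps its size of 13638 bytes (the strings `5.5.0` and `5.5.1` have the same length) while its hash changes. A minimal standard-library sketch, with `record_entry` as a hypothetical helper name:

```python
import base64
import hashlib


def record_entry(path: str, data: bytes) -> str:
    """Build one wheel RECORD line: path, sha256 digest, size in bytes.

    The digest is URL-safe base64 of the raw SHA-256 bytes with the
    trailing '=' padding stripped.
    """
    digest = hashlib.sha256(data).digest()
    encoded = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return f"{path},sha256={encoded},{len(data)}"


# An empty file yields the entry shown for com/johnsnowlabs/ml/__init__.py above.
print(record_entry("com/johnsnowlabs/ml/__init__.py", b""))

# A same-length edit keeps the size field but changes the hash.
old = record_entry("v.py", b'return "5.5.0"\n')
new = record_entry("v.py", b'return "5.5.1"\n')
print(old.rsplit(",", 1)[1] == new.rsplit(",", 1)[1], old != new)
```

The `RECORD,,` line has empty hash and size fields because a file cannot contain its own digest.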
sparknlp/__init__.py CHANGED
@@ -129,7 +129,7 @@ def start(gpu=False,
          The initiated Spark session.
 
      """
-     current_version = "5.5.0"
+     current_version = "5.5.1"
 
      if params is None:
          params = {}
@@ -310,4 +310,4 @@ def version():
      str
          The current Spark NLP version.
      """
-     return '5.5.0'
+     return '5.5.1'
sparknlp/annotator/classifier_dl/__init__.py CHANGED
@@ -54,4 +54,4 @@ from sparknlp.annotator.classifier_dl.mpnet_for_question_answering import *
  from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import *
  from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import *
  from sparknlp.annotator.classifier_dl.camembert_for_zero_shot_classification import *
-
+ from sparknlp.annotator.classifier_dl.bert_for_multiple_choice import *
sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py ADDED
@@ -0,0 +1,161 @@
+ # Copyright 2017-2024 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from sparknlp.common import *
+
+ class BertForMultipleChoice(AnnotatorModel,
+                             HasCaseSensitiveProperties,
+                             HasBatchedAnnotate,
+                             HasEngine,
+                             HasMaxSentenceLengthLimit):
+     """BertForMultipleChoice can load BERT Models with a multiple choice classification head on top
+     (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
+
+     Pretrained models can be loaded with :meth:`.pretrained` of the companion
+     object:
+
+     >>> spanClassifier = BertForMultipleChoice.pretrained() \\
+     ...     .setInputCols(["document_question", "document_context"]) \\
+     ...     .setOutputCol("answer")
+
+     The default model is ``"bert_base_uncased_multiple_choice"``, if no name is
+     provided.
+
+     For available pretrained models please see the `Models Hub
+     <https://sparknlp.org/models?task=Multiple+Choice>`__.
+
+     To see which models are compatible and how to import them see
+     `Import Transformers into Spark NLP 🚀
+     <https://github.com/JohnSnowLabs/spark-nlp/discussions/5669>`_.
+
+     ====================== ======================
+     Input Annotation types Output Annotation type
+     ====================== ======================
+     ``DOCUMENT, DOCUMENT`` ``CHUNK``
+     ====================== ======================
+
+     Parameters
+     ----------
+     batchSize
+         Batch size. Large values allows faster processing but requires more
+         memory, by default 8
+     caseSensitive
+         Whether to ignore case in tokens for embeddings matching, by default
+         False
+     maxSentenceLength
+         Max sentence length to process, by default 512
+
+     Examples
+     --------
+     >>> import sparknlp
+     >>> from sparknlp.base import *
+     >>> from sparknlp.annotator import *
+     >>> from pyspark.ml import Pipeline
+     >>> documentAssembler = MultiDocumentAssembler() \\
+     ...     .setInputCols(["question", "context"]) \\
+     ...     .setOutputCols(["document_question", "document_context"])
+     >>> questionAnswering = BertForMultipleChoice.pretrained() \\
+     ...     .setInputCols(["document_question", "document_context"]) \\
+     ...     .setOutputCol("answer") \\
+     ...     .setCaseSensitive(False)
+     >>> pipeline = Pipeline().setStages([
+     ...     documentAssembler,
+     ...     questionAnswering
+     ... ])
+     >>> data = spark.createDataFrame([["The Eiffel Tower is located in which country??", "Germany, France, Italy"]]).toDF("question", "context")
+     >>> result = pipeline.fit(data).transform(data)
+     >>> result.select("answer.result").show(truncate=False)
+     +--------------------+
+     |result              |
+     +--------------------+
+     |[France]            |
+     +--------------------+
+     """
+     name = "BertForMultipleChoice"
+
+     inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT]
+
+     outputAnnotatorType = AnnotatorType.CHUNK
+
+     choicesDelimiter = Param(Params._dummy(),
+                              "choicesDelimiter",
+                              "Delimiter character use to split the choices",
+                              TypeConverters.toString)
+
+     def setChoicesDelimiter(self, value):
+         """Sets delimiter character use to split the choices
+
+         Parameters
+         ----------
+         value : string
+             Delimiter character use to split the choices
+         """
+         return self._set(caseSensitive=value)
+
+     @keyword_only
+     def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice",
+                  java_model=None):
+         super(BertForMultipleChoice, self).__init__(
+             classname=classname,
+             java_model=java_model
+         )
+         self._setDefault(
+             batchSize=4,
+             maxSentenceLength=512,
+             caseSensitive=False,
+             choicesDelimiter = ","
+         )
+
+     @staticmethod
+     def loadSavedModel(folder, spark_session):
+         """Loads a locally saved model.
+
+         Parameters
+         ----------
+         folder : str
+             Folder of the saved model
+         spark_session : pyspark.sql.SparkSession
+             The current SparkSession
+
+         Returns
+         -------
+         BertForQuestionAnswering
+             The restored model
+         """
+         from sparknlp.internal import _BertMultipleChoiceLoader
+         jModel = _BertMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj
+         return BertForMultipleChoice(java_model=jModel)
+
+     @staticmethod
+     def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=None):
+         """Downloads and loads a pretrained model.
+
+         Parameters
+         ----------
+         name : str, optional
+             Name of the pretrained model, by default
+             "bert_base_uncased_multiple_choice"
+         lang : str, optional
+             Language of the pretrained model, by default "en"
+         remote_loc : str, optional
+             Optional remote address of the resource, by default None. Will use
+             Spark NLPs repositories otherwise.
+
+         Returns
+         -------
+         BertForQuestionAnswering
+             The restored model
+         """
+         from sparknlp.pretrained import ResourceDownloader
+         return ResourceDownloader.downloadModel(BertForMultipleChoice, name, lang, remote_loc)
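The `BertForMultipleChoice` docstring describes the head as a linear layer on the pooled output followed by a softmax, and the `choicesDelimiter` param (default `","`) suggests the `context` column carries the candidate answers as one delimited string, as in `"Germany, France, Italy"`. The pure-Python sketch below illustrates that selection step under those assumptions only; it is not the Scala implementation. `split_choices`, `pick_choice`, and `toy_score` are hypothetical, and `toy_score` stands in for the BERT encoder plus linear head by returning one logit per (question, choice) pair.

```python
import math
from typing import Callable, List, Tuple


def split_choices(context: str, delimiter: str = ",") -> List[str]:
    """Split the context string into candidate answers (assumed behaviour)."""
    return [c.strip() for c in context.split(delimiter) if c.strip()]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over the per-choice logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_choice(question: str, context: str,
                score: Callable[[str, str], float],
                delimiter: str = ",") -> Tuple[str, float]:
    """Score each (question, choice) pair, softmax across choices, return the best."""
    choices = split_choices(context, delimiter)
    probs = softmax([score(question, c) for c in choices])
    best = max(range(len(choices)), key=lambda i: probs[i])
    return choices[best], probs[best]


def toy_score(question: str, choice: str) -> float:
    """Hypothetical stand-in for the model: one logit per pair."""
    return 3.0 if choice == "France" else 0.0


answer, prob = pick_choice("The Eiffel Tower is located in which country?",
                           "Germany, France, Italy", toy_score)
print(answer, round(prob, 3))
```

Two details of the released file are worth checking against this picture: `setChoicesDelimiter` calls `self._set(caseSensitive=value)`, so as written it does not update `choicesDelimiter`, and the docstring lists a `batchSize` default of 8 while `_setDefault` uses `batchSize=4`.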
sparknlp/base/__init__.py CHANGED
@@ -26,3 +26,4 @@ from sparknlp.base.token_assembler import *
  from sparknlp.base.image_assembler import *
  from sparknlp.base.audio_assembler import *
  from sparknlp.base.table_assembler import *
+ from sparknlp.base.prompt_assembler import *
@@ -0,0 +1,207 @@
1
+ # Copyright 2017-2024 John Snow Labs
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains classes for the PromptAssembler."""
15
+
16
+ from pyspark import keyword_only
17
+ from pyspark.ml.param import TypeConverters, Params, Param
18
+
19
+ from sparknlp.common import AnnotatorType
20
+ from sparknlp.internal import AnnotatorTransformer
21
+
22
+
23
+ class PromptAssembler(AnnotatorTransformer):
24
+ """Assembles a sequence of messages into a single string using a template. These strings can then
25
+ be used as prompts for large language models.
26
+
27
+ This annotator expects an array of two-tuples as the type of the input column (one array of
28
+ tuples per row). The first element of the tuples should be the role and the second element is
29
+ the text of the message. Possible roles are "system", "user" and "assistant".
30
+
31
+ An assistant header can be added to the end of the generated string by using
32
+ ``setAddAssistant(True)``.
33
+
34
+ At the moment, this annotator uses llama.cpp as a backend to parse and apply the templates.
35
+ llama.cpp uses basic pattern matching to determine the type of the template, then applies a
36
+ basic version of the template to the messages. This means that more advanced templates are not
37
+ supported.
38
+
39
+ For an extended example see the
40
+ `example notebook <https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/llama.cpp/PromptAssember_with_AutoGGUFModel.ipynb>`__.
41
+
42
+ ====================== ======================
43
+ Input Annotation types Output Annotation type
44
+ ====================== ======================
45
+ ``NONE`` ``DOCUMENT``
46
+ ====================== ======================
47
+
48
+ Parameters
49
+ ----------
50
+ inputCol
51
+ Input column name
52
+ outputCol
53
+ Output column name
54
+ chatTemplate
55
+ Template used for the chat
56
+ addAssistant
57
+ Whether to add an assistant header to the end of the generated string
58
+
59
+ Examples
60
+ --------
61
+ >>> from sparknlp.base import *
62
+ >>> messages = [
63
+ ... [
64
+ ... ("system", "You are a helpful assistant."),
65
+ ... ("assistant", "Hello there, how can I help you?"),
66
+ ... ("user", "I need help with organizing my room."),
67
+ ... ]
68
+ ... ]
69
+ >>> df = spark.createDataFrame([messages]).toDF("messages")
70
+ >>> template = (
71
+ ... "{{- bos_token }} {%- if custom_tools is defined %} {%- set tools = custom_tools %} {%- "
72
+ ... "endif %} {%- if not tools_in_user_message is defined %} {%- set tools_in_user_message = true %} {%- "
73
+ ... 'endif %} {%- if not date_string is defined %} {%- set date_string = "26 Jul 2024" %} {%- endif %} '
74
+ ... "{%- if not tools is defined %} {%- set tools = none %} {%- endif %} {#- This block extracts the "
75
+ ... "system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %}"
76
+ ... " {%- set system_message = messages[0]['content']|trim %} {%- set messages = messages[1:] %} {%- else"
77
+ ... ' %} {%- set system_message = "" %} {%- endif %} {#- System message + builtin tools #} {{- '
78
+ ... '"<|start_header_id|>system<|end_header_id|>\\n\\n" }} {%- if builtin_tools is defined or tools is '
79
+ ... 'not none %} {{- "Environment: ipython\\n" }} {%- endif %} {%- if builtin_tools is defined %} {{- '
80
+ ... '"Tools: " + builtin_tools | reject(\\'equalto\', \\'code_interpreter\\') | join(", ") + "\\n\\n"}} '
81
+ ... '{%- endif %} {{- "Cutting Knowledge Date: December 2023\\n" }} {{- "Today Date: " + date_string '
82
+ ... '+ "\\n\\n" }} {%- if tools is not none and not tools_in_user_message %} {{- "You have access to '
83
+ ... 'the following functions. To call a function, please respond with JSON for a function call." }} {{- '
84
+ ... '\\'Respond in the format {"name": function name, "parameters": dictionary of argument name and its'
85
+ ... ' value}.\\' }} {{- "Do not use variables.\\n\\n" }} {%- for t in tools %} {{- t | tojson(indent=4) '
86
+ ... '}} {{- "\\n\\n" }} {%- endfor %} {%- endif %} {{- system_message }} {{- "<|eot_id|>" }} {#- '
87
+ ... "Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message "
88
+ ... "and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if "
89
+ ... "messages | length != 0 %} {%- set first_user_message = messages[0]['content']|trim %} {%- set "
90
+ ... 'messages = messages[1:] %} {%- else %} {{- raise_exception("Cannot put tools in the first user '
91
+ ... "message when there's no first user message!\\") }} {%- endif %} {{- "
92
+ ... "'<|start_header_id|>user<|end_header_id|>\\n\\n' -}} {{- \\"Given the following functions, please "
93
+ ... 'respond with a JSON for a function call " }} {{- "with its proper arguments that best answers the '
94
+ ... 'given prompt.\\n\\n" }} {{- \\'Respond in the format {"name": function name, "parameters": '
95
+ ... 'dictionary of argument name and its value}.\\' }} {{- "Do not use variables.\\n\\n" }} {%- for t in '
96
+ ... 'tools %} {{- t | tojson(indent=4) }} {{- "\\n\\n" }} {%- endfor %} {{- first_user_message + '
97
+ ... "\\"<|eot_id|>\\"}} {%- endif %} {%- for message in messages %} {%- if not (message.role == 'ipython' "
98
+ ... "or message.role == 'tool' or 'tool_calls' in message) %} {{- '<|start_header_id|>' + message['role']"
99
+ ... " + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }} {%- elif 'tool_calls' in "
100
+ ... 'message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception("This model only '
101
+ ... 'supports single tool-calls at once!") }} {%- endif %} {%- set tool_call = message.tool_calls[0]'
102
+ ... ".function %} {%- if builtin_tools is defined and tool_call.name in builtin_tools %} {{- "
103
+ ... "'<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- \\"<|python_tag|>\\" + tool_call.name + "
104
+ ... '".call(" }} {%- for arg_name, arg_val in tool_call.arguments | items %} {{- arg_name + \\'="\\' + '
105
+ ... 'arg_val + \\'"\\' }} {%- if not loop.last %} {{- ", " }} {%- endif %} {%- endfor %} {{- ")" }} {%- '
106
+ ... "else %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- '{\\"name\": \\"' + "
107
+ ... 'tool_call.name + \\'", \\' }} {{- \\'"parameters": \\' }} {{- tool_call.arguments | tojson }} {{- "}" '
108
+ ... "}} {%- endif %} {%- if builtin_tools is defined %} {#- This means we're in ipython mode #} {{- "
109
+ ... '"<|eom_id|>" }} {%- else %} {{- "<|eot_id|>" }} {%- endif %} {%- elif message.role == "tool" '
110
+ ... 'or message.role == "ipython" %} {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }} {%- '
111
+ ... "if message.content is mapping or message.content is iterable %} {{- message.content | tojson }} {%- "
112
+ ... 'else %} {{- message.content }} {%- endif %} {{- "<|eot_id|>" }} {%- endif %} {%- endfor %} {%- if '
113
+ ... "add_generation_prompt %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }} {%- endif %} "
114
+ ... )
115
+ >>> prompt_assembler = (
116
+ ... PromptAssembler()
117
+ ... .setInputCol("messages")
118
+ ... .setOutputCol("prompt")
119
+ ... .setChatTemplate(template)
120
+ ... )
121
+ >>> prompt_assembler.transform(df).select("prompt.result").show(truncate=False)
122
+ +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
123
+ |result |
124
+ +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
125
+ |[<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello there, how can I help you?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI need help with organizing my room.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n]|
126
+ +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
127
+ """
128
+
129
+ outputAnnotatorType = AnnotatorType.DOCUMENT
130
+
131
+ inputCol = Param(
132
+ Params._dummy(),
133
+ "inputCol",
134
+ "input column name",
135
+ typeConverter=TypeConverters.toString,
136
+ )
137
+ outputCol = Param(
138
+ Params._dummy(),
139
+ "outputCol",
140
+ "output column name",
141
+ typeConverter=TypeConverters.toString,
142
+ )
143
+ chatTemplate = Param(
144
+ Params._dummy(),
145
+ "chatTemplate",
146
+ "Template used for the chat",
147
+ typeConverter=TypeConverters.toString,
148
+ )
149
+ addAssistant = Param(
150
+ Params._dummy(),
151
+ "addAssistant",
152
+ "Whether to add an assistant header to the end of the generated string",
153
+ typeConverter=TypeConverters.toBoolean,
154
+ )
155
+ name = "PromptAssembler"
156
+
157
+ @keyword_only
158
+ def __init__(self):
159
+ super(PromptAssembler, self).__init__(
160
+ classname="com.johnsnowlabs.nlp.PromptAssembler"
161
+ )
162
+ self._setDefault(outputCol="prompt", addAssistant=True)
163
+
164
+ @keyword_only
165
+ def setParams(self):
166
+ kwargs = self._input_kwargs
167
+ return self._set(**kwargs)
168
+
169
+ def setInputCol(self, value):
170
+ """Sets input column name.
171
+
172
+ Parameters
173
+ ----------
174
+ value : str
175
+ Name of the input column
176
+ """
177
+ return self._set(inputCol=value)
178
+
179
+ def setOutputCol(self, value):
180
+ """Sets output column name.
181
+
182
+ Parameters
183
+ ----------
184
+ value : str
185
+ Name of the Output Column
186
+ """
187
+ return self._set(outputCol=value)
188
+
189
+ def setChatTemplate(self, value):
190
+ """Sets the chat template.
191
+
192
+ Parameters
193
+ ----------
194
+ value : str
195
+ Template used for the chat
196
+ """
197
+ return self._set(chatTemplate=value)
198
+
199
+ def setAddAssistant(self, value):
200
+ """Sets whether to add an assistant header to the end of the generated string.
201
+
202
+ Parameters
203
+ ----------
204
+ value : bool
205
+ Whether to add an assistant header to the end of the generated string
206
+ """
207
+ return self._set(addAssistant=value)
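The expected output in the `PromptAssembler` docstring contains none of the template's `bos_token`, date, or tool sections, which is consistent with the note that llama.cpp applies only a basic version of the detected template. The layout that output shows can be reproduced with the sketch below. It assumes only the `<|start_header_id|>`, `<|end_header_id|>`, and `<|eot_id|>` tokens visible there; `format_llama3` is a hypothetical helper, not part of Spark NLP or llama.cpp.

```python
from typing import List, Tuple

# Special tokens as they appear in the docstring's expected output.
START = "<|start_header_id|>"
END = "<|end_header_id|>"
EOT = "<|eot_id|>"


def format_llama3(messages: List[Tuple[str, str]], add_assistant: bool = True) -> str:
    """Render (role, text) pairs in the basic Llama 3 chat layout.

    Hypothetical helper: it mirrors the docstring output, not llama.cpp's code.
    """
    parts = []
    for role, text in messages:
        if role not in ("system", "user", "assistant"):
            raise ValueError(f"unsupported role: {role!r}")
        # Each turn: header, blank line, trimmed text, end-of-turn token.
        parts.append(f"{START}{role}{END}\n\n{text.strip()}{EOT}")
    if add_assistant:
        # Open an empty assistant turn so a model continues as the assistant.
        parts.append(f"{START}assistant{END}\n\n")
    return "".join(parts)


messages = [
    ("system", "You are a helpful assistant."),
    ("assistant", "Hello there, how can I help you?"),
    ("user", "I need help with organizing my room."),
]
print(format_llama3(messages))
```

Passing `add_assistant=True` mirrors the `addAssistant=True` default set in `PromptAssembler.__init__`; the trailing assistant header is what lets a downstream model generate the next assistant turn.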
sparknlp/internal/__init__.py CHANGED
@@ -113,6 +113,13 @@ class _BertQuestionAnsweringLoader(ExtendedJavaWrapper):
              jspark,
          )
 
+ class _BertMultipleChoiceLoader(ExtendedJavaWrapper):
+     def __init__(self, path, jspark):
+         super(_BertMultipleChoiceLoader, self).__init__(
+             "com.johnsnowlabs.nlp.annotators.classifier.dl.BertForMultipleChoice.loadSavedModel",
+             path,
+             jspark,
+         )
 
  class _DeBERTaLoader(ExtendedJavaWrapper):
      def __init__(self, path, jspark):