spark-nlp: spark_nlp-2.6.3rc1-py2.py3-none-any.whl → spark_nlp-6.2.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (329)
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. com/johnsnowlabs/nlp/__init__.py +4 -2
  4. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  5. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  6. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  7. sparknlp/__init__.py +281 -27
  8. sparknlp/annotation.py +137 -6
  9. sparknlp/annotation_audio.py +61 -0
  10. sparknlp/annotation_image.py +82 -0
  11. sparknlp/annotator/__init__.py +93 -0
  12. sparknlp/annotator/audio/__init__.py +16 -0
  13. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  14. sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
  15. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  16. sparknlp/annotator/chunk2_doc.py +85 -0
  17. sparknlp/annotator/chunker.py +137 -0
  18. sparknlp/annotator/classifier_dl/__init__.py +61 -0
  19. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  20. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
  21. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
  22. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
  23. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  24. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  25. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  26. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
  27. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
  28. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
  29. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  30. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  31. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
  32. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
  33. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  34. sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
  35. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
  36. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
  37. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
  38. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  39. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
  40. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
  41. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
  42. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  43. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  44. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
  45. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
  46. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
  47. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  48. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  49. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  50. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
  51. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  52. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
  53. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
  54. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
  55. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  56. sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
  57. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
  60. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
  61. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
  62. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  63. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
  64. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
  65. sparknlp/annotator/cleaners/__init__.py +15 -0
  66. sparknlp/annotator/cleaners/cleaner.py +202 -0
  67. sparknlp/annotator/cleaners/extractor.py +191 -0
  68. sparknlp/annotator/coref/__init__.py +1 -0
  69. sparknlp/annotator/coref/spanbert_coref.py +221 -0
  70. sparknlp/annotator/cv/__init__.py +29 -0
  71. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  72. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  73. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  74. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  75. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  76. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  77. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  78. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  79. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  80. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  81. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  82. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  83. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  84. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  85. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  86. sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
  87. sparknlp/annotator/dataframe_optimizer.py +216 -0
  88. sparknlp/annotator/date2_chunk.py +88 -0
  89. sparknlp/annotator/dependency/__init__.py +17 -0
  90. sparknlp/annotator/dependency/dependency_parser.py +294 -0
  91. sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
  92. sparknlp/annotator/document_character_text_splitter.py +228 -0
  93. sparknlp/annotator/document_normalizer.py +235 -0
  94. sparknlp/annotator/document_token_splitter.py +175 -0
  95. sparknlp/annotator/document_token_splitter_test.py +85 -0
  96. sparknlp/annotator/embeddings/__init__.py +45 -0
  97. sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
  98. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  99. sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
  100. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
  101. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  102. sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
  103. sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
  104. sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
  105. sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
  106. sparknlp/annotator/embeddings/doc2vec.py +352 -0
  107. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  108. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  109. sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
  110. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  111. sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
  112. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  113. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  114. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  115. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  116. sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
  117. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
  118. sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
  119. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  120. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  121. sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
  122. sparknlp/annotator/embeddings/word2vec.py +353 -0
  123. sparknlp/annotator/embeddings/word_embeddings.py +385 -0
  124. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
  125. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
  126. sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
  127. sparknlp/annotator/er/__init__.py +16 -0
  128. sparknlp/annotator/er/entity_ruler.py +267 -0
  129. sparknlp/annotator/graph_extraction.py +368 -0
  130. sparknlp/annotator/keyword_extraction/__init__.py +16 -0
  131. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
  132. sparknlp/annotator/ld_dl/__init__.py +16 -0
  133. sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
  134. sparknlp/annotator/lemmatizer.py +250 -0
  135. sparknlp/annotator/matcher/__init__.py +20 -0
  136. sparknlp/annotator/matcher/big_text_matcher.py +272 -0
  137. sparknlp/annotator/matcher/date_matcher.py +303 -0
  138. sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
  139. sparknlp/annotator/matcher/regex_matcher.py +221 -0
  140. sparknlp/annotator/matcher/text_matcher.py +290 -0
  141. sparknlp/annotator/n_gram_generator.py +141 -0
  142. sparknlp/annotator/ner/__init__.py +21 -0
  143. sparknlp/annotator/ner/ner_approach.py +94 -0
  144. sparknlp/annotator/ner/ner_converter.py +148 -0
  145. sparknlp/annotator/ner/ner_crf.py +397 -0
  146. sparknlp/annotator/ner/ner_dl.py +591 -0
  147. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  148. sparknlp/annotator/ner/ner_overwriter.py +166 -0
  149. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  150. sparknlp/annotator/normalizer.py +230 -0
  151. sparknlp/annotator/openai/__init__.py +16 -0
  152. sparknlp/annotator/openai/openai_completion.py +349 -0
  153. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  154. sparknlp/annotator/param/__init__.py +17 -0
  155. sparknlp/annotator/param/classifier_encoder.py +98 -0
  156. sparknlp/annotator/param/evaluation_dl_params.py +130 -0
  157. sparknlp/annotator/pos/__init__.py +16 -0
  158. sparknlp/annotator/pos/perceptron.py +263 -0
  159. sparknlp/annotator/sentence/__init__.py +17 -0
  160. sparknlp/annotator/sentence/sentence_detector.py +290 -0
  161. sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
  162. sparknlp/annotator/sentiment/__init__.py +17 -0
  163. sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
  164. sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
  165. sparknlp/annotator/seq2seq/__init__.py +35 -0
  166. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  167. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  168. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  169. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  170. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  171. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  172. sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
  173. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  174. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  175. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  176. sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
  177. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  178. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  179. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  180. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  181. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  182. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  183. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  184. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  185. sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
  186. sparknlp/annotator/similarity/__init__.py +0 -0
  187. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  188. sparknlp/annotator/spell_check/__init__.py +18 -0
  189. sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
  190. sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
  191. sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
  192. sparknlp/annotator/stemmer.py +79 -0
  193. sparknlp/annotator/stop_words_cleaner.py +190 -0
  194. sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
  195. sparknlp/annotator/token/__init__.py +19 -0
  196. sparknlp/annotator/token/chunk_tokenizer.py +118 -0
  197. sparknlp/annotator/token/recursive_tokenizer.py +205 -0
  198. sparknlp/annotator/token/regex_tokenizer.py +208 -0
  199. sparknlp/annotator/token/tokenizer.py +561 -0
  200. sparknlp/annotator/token2_chunk.py +76 -0
  201. sparknlp/annotator/ws/__init__.py +16 -0
  202. sparknlp/annotator/ws/word_segmenter.py +429 -0
  203. sparknlp/base/__init__.py +30 -0
  204. sparknlp/base/audio_assembler.py +95 -0
  205. sparknlp/base/doc2_chunk.py +169 -0
  206. sparknlp/base/document_assembler.py +164 -0
  207. sparknlp/base/embeddings_finisher.py +201 -0
  208. sparknlp/base/finisher.py +217 -0
  209. sparknlp/base/gguf_ranking_finisher.py +234 -0
  210. sparknlp/base/graph_finisher.py +125 -0
  211. sparknlp/base/has_recursive_fit.py +24 -0
  212. sparknlp/base/has_recursive_transform.py +22 -0
  213. sparknlp/base/image_assembler.py +172 -0
  214. sparknlp/base/light_pipeline.py +429 -0
  215. sparknlp/base/multi_document_assembler.py +164 -0
  216. sparknlp/base/prompt_assembler.py +207 -0
  217. sparknlp/base/recursive_pipeline.py +107 -0
  218. sparknlp/base/table_assembler.py +145 -0
  219. sparknlp/base/token_assembler.py +124 -0
  220. sparknlp/common/__init__.py +26 -0
  221. sparknlp/common/annotator_approach.py +41 -0
  222. sparknlp/common/annotator_model.py +47 -0
  223. sparknlp/common/annotator_properties.py +114 -0
  224. sparknlp/common/annotator_type.py +38 -0
  225. sparknlp/common/completion_post_processing.py +37 -0
  226. sparknlp/common/coverage_result.py +22 -0
  227. sparknlp/common/match_strategy.py +33 -0
  228. sparknlp/common/properties.py +1298 -0
  229. sparknlp/common/read_as.py +33 -0
  230. sparknlp/common/recursive_annotator_approach.py +35 -0
  231. sparknlp/common/storage.py +149 -0
  232. sparknlp/common/utils.py +39 -0
  233. sparknlp/functions.py +315 -5
  234. sparknlp/internal/__init__.py +1199 -0
  235. sparknlp/internal/annotator_java_ml.py +32 -0
  236. sparknlp/internal/annotator_transformer.py +37 -0
  237. sparknlp/internal/extended_java_wrapper.py +63 -0
  238. sparknlp/internal/params_getters_setters.py +71 -0
  239. sparknlp/internal/recursive.py +70 -0
  240. sparknlp/logging/__init__.py +15 -0
  241. sparknlp/logging/comet.py +467 -0
  242. sparknlp/partition/__init__.py +16 -0
  243. sparknlp/partition/partition.py +244 -0
  244. sparknlp/partition/partition_properties.py +902 -0
  245. sparknlp/partition/partition_transformer.py +200 -0
  246. sparknlp/pretrained/__init__.py +17 -0
  247. sparknlp/pretrained/pretrained_pipeline.py +158 -0
  248. sparknlp/pretrained/resource_downloader.py +216 -0
  249. sparknlp/pretrained/utils.py +35 -0
  250. sparknlp/reader/__init__.py +15 -0
  251. sparknlp/reader/enums.py +19 -0
  252. sparknlp/reader/pdf_to_text.py +190 -0
  253. sparknlp/reader/reader2doc.py +124 -0
  254. sparknlp/reader/reader2image.py +136 -0
  255. sparknlp/reader/reader2table.py +44 -0
  256. sparknlp/reader/reader_assembler.py +159 -0
  257. sparknlp/reader/sparknlp_reader.py +461 -0
  258. sparknlp/training/__init__.py +20 -0
  259. sparknlp/training/_tf_graph_builders/__init__.py +0 -0
  260. sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
  261. sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
  262. sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
  263. sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
  264. sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
  265. sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
  266. sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
  267. sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
  268. sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
  269. sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
  270. sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
  271. sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
  272. sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
  273. sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
  274. sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
  275. sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
  276. sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
  277. sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
  278. sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
  279. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
  280. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
  281. sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
  282. sparknlp/training/conll.py +150 -0
  283. sparknlp/training/conllu.py +103 -0
  284. sparknlp/training/pos.py +103 -0
  285. sparknlp/training/pub_tator.py +76 -0
  286. sparknlp/training/spacy_to_annotation.py +57 -0
  287. sparknlp/training/tfgraphs.py +5 -0
  288. sparknlp/upload_to_hub.py +149 -0
  289. sparknlp/util.py +51 -5
  290. com/__init__.pyc +0 -0
  291. com/__pycache__/__init__.cpython-36.pyc +0 -0
  292. com/johnsnowlabs/__init__.pyc +0 -0
  293. com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
  294. com/johnsnowlabs/nlp/__init__.pyc +0 -0
  295. com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
  296. spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
  297. spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
  298. sparknlp/__init__.pyc +0 -0
  299. sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
  300. sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
  301. sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
  302. sparknlp/__pycache__/base.cpython-36.pyc +0 -0
  303. sparknlp/__pycache__/common.cpython-36.pyc +0 -0
  304. sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
  305. sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
  306. sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
  307. sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
  308. sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
  309. sparknlp/__pycache__/training.cpython-36.pyc +0 -0
  310. sparknlp/__pycache__/util.cpython-36.pyc +0 -0
  311. sparknlp/annotation.pyc +0 -0
  312. sparknlp/annotator.py +0 -3006
  313. sparknlp/annotator.pyc +0 -0
  314. sparknlp/base.py +0 -347
  315. sparknlp/base.pyc +0 -0
  316. sparknlp/common.py +0 -193
  317. sparknlp/common.pyc +0 -0
  318. sparknlp/embeddings.py +0 -40
  319. sparknlp/embeddings.pyc +0 -0
  320. sparknlp/internal.py +0 -288
  321. sparknlp/internal.pyc +0 -0
  322. sparknlp/pretrained.py +0 -123
  323. sparknlp/pretrained.pyc +0 -0
  324. sparknlp/storage.py +0 -32
  325. sparknlp/storage.pyc +0 -0
  326. sparknlp/training.py +0 -62
  327. sparknlp/training.pyc +0 -0
  328. sparknlp/util.pyc +0 -0
  329. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py
@@ -0,0 +1,240 @@
+ # Copyright 2017-2022 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Contains classes concerning VisionEncoderDecoderForImageCaptioning."""
+
+ from sparknlp.common import *
+
+
+ class VisionEncoderDecoderForImageCaptioning(AnnotatorModel,
+                                              HasBatchedAnnotateImage,
+                                              HasImageFeatureProperties,
+                                              HasGeneratorProperties,
+                                              HasRescaleFactor,
+                                              HasEngine):
+     """VisionEncoderDecoder model that converts images into text captions. It allows for the use of
+     pretrained vision auto-encoding models, such as ViT, BEiT, or DeiT as the encoder, in
+     combination with pretrained language models, like RoBERTa, GPT2, or BERT as the decoder.
+
+     Pretrained models can be loaded with ``pretrained`` of the companion object:
+
+     .. code-block:: python
+
+         imageClassifier = VisionEncoderDecoderForImageCaptioning.pretrained() \\
+             .setInputCols(["image_assembler"]) \\
+             .setOutputCol("caption")
+
+
+     The default model is ``"image_captioning_vit_gpt2"``, if no name is provided.
+
+     For available pretrained models please see the
+     `Models Hub <https://sparknlp.org/models?task=Image+Captioning>`__.
+
+     Models from the HuggingFace 🤗 Transformers library are also compatible with Spark NLP 🚀. To
+     see which models are compatible and how to import them see
+     https://github.com/JohnSnowLabs/spark-nlp/discussions/5669 and to see more extended
+     examples, see
+     `VisionEncoderDecoderTestSpec <https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/cv/VisionEncoderDecoderForImageCaptioningTestSpec.scala>`__.
+
+     Notes
+     -----
+     This is a very computationally expensive module, especially on larger
+     batch sizes. The use of an accelerator such as GPU is recommended.
+
+
+     ====================== ======================
+     Input Annotation types Output Annotation type
+     ====================== ======================
+     ``IMAGE``              ``DOCUMENT``
+     ====================== ======================
+
+     Parameters
+     ----------
+     configProtoBytes
+         ConfigProto from tensorflow, serialized into byte array.
+     doResize
+         Whether to resize the input to a certain size
+     doNormalize
+         Whether to normalize the input with mean and standard deviation
+     featureExtractorType
+         Name of model's architecture for feature extraction
+     imageMean
+         The sequence of means for each channel, to be used when normalizing images
+     imageStd
+         The sequence of standard deviations for each channel, to be used when normalizing images
+     resample
+         An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BILINEAR` or
+         `PIL.Image.BICUBIC`. Only has an effect if do_resize is set to True.
+     size
+         Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an integer is
+         provided, then the input will be resized to (size, size). Only has an effect if do_resize is set to True.
+     doRescale
+         Whether to rescale the image values by rescaleFactor
+     rescaleFactor
+         Factor to scale the image values
+     minOutputLength
+         Minimum length of the sequence to be generated
+     maxOutputLength
+         Maximum length of output text
+     doSample
+         Whether or not to use sampling; use greedy decoding otherwise
+     temperature
+         The value used to modulate the next token probabilities
+     topK
+         The number of highest probability vocabulary tokens to keep for top-k-filtering
+     topP
+         If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are
+         kept for generation
+     repetitionPenalty
+         The parameter for repetition penalty. 1.0 means no penalty.
+         See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details
+     noRepeatNgramSize
+         If set to int > 0, all ngrams of that size can only occur once
+     beamSize
+         The number of beams for beam search
+     nReturnSequences
+         The number of sequences to return from the beam search
+
+     Examples
+     --------
+     >>> import sparknlp
+     >>> from sparknlp.base import *
+     >>> from sparknlp.annotator import *
+     >>> from pyspark.ml import Pipeline
+     >>> imageDF = spark.read \\
+     ...     .format("image") \\
+     ...     .option("dropInvalid", value = True) \\
+     ...     .load("src/test/resources/image/")
+     >>> imageAssembler = ImageAssembler() \\
+     ...     .setInputCol("image") \\
+     ...     .setOutputCol("image_assembler")
+     >>> imageCaptioning = VisionEncoderDecoderForImageCaptioning \\
+     ...     .pretrained() \\
+     ...     .setBeamSize(2) \\
+     ...     .setDoSample(False) \\
+     ...     .setInputCols(["image_assembler"]) \\
+     ...     .setOutputCol("caption")
+     >>> pipeline = Pipeline().setStages([imageAssembler, imageCaptioning])
+     >>> pipelineDF = pipeline.fit(imageDF).transform(imageDF)
+     >>> pipelineDF \\
+     ...     .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "caption.result") \\
+     ...     .show(truncate = False)
+     +-----------------+---------------------------------------------------------+
+     |image_name       |result                                                   |
+     +-----------------+---------------------------------------------------------+
+     |palace.JPEG      |[a large room filled with furniture and a large window]  |
+     |egyptian_cat.jpeg|[a cat laying on a couch next to another cat]            |
+     |hippopotamus.JPEG|[a brown bear in a body of water]                        |
+     |hen.JPEG         |[a flock of chickens standing next to each other]        |
+     |ostrich.JPEG     |[a large bird standing on top of a lush green field]     |
+     |junco.JPEG       |[a small bird standing on a wet ground]                  |
+     |bluetick.jpg     |[a small dog standing on a wooden floor]                 |
+     |chihuahua.jpg    |[a small brown dog wearing a blue sweater]               |
+     |tractor.JPEG     |[a man is standing in a field with a tractor]            |
+     |ox.JPEG          |[a large brown cow standing on top of a lush green field]|
+     +-----------------+---------------------------------------------------------+
+
+     """
+     name = "VisionEncoderDecoderForImageCaptioning"
+
+     inputAnnotatorTypes = [AnnotatorType.IMAGE]
+
+     outputAnnotatorType = AnnotatorType.DOCUMENT
+
+     configProtoBytes = Param(Params._dummy(),
+                              "configProtoBytes",
+                              "ConfigProto from tensorflow, serialized into byte array. Get with "
+                              "config_proto.SerializeToString()",
+                              TypeConverters.toListInt)
+
+     def setConfigProtoBytes(self, b):
+         """Sets configProto from tensorflow, serialized into byte array.
+
+         Parameters
+         ----------
+         b : List[int]
+             ConfigProto from tensorflow, serialized into byte array
+         """
+         return self._set(configProtoBytes=b)
+
+     @keyword_only
+     def __init__(self, classname="com.johnsnowlabs.nlp.annotators.cv.VisionEncoderDecoderForImageCaptioning",
+                  java_model=None):
+         super(VisionEncoderDecoderForImageCaptioning, self).__init__(
+             classname=classname,
+             java_model=java_model
+         )
+         self._setDefault(
+             batchSize=2,
+             beamSize=1,
+             doNormalize=True,
+             doRescale=True,
+             doResize=True,
+             doSample=True,
+             imageMean=[0.5, 0.5, 0.5],
+             imageStd=[0.5, 0.5, 0.5],
+             maxOutputLength=50,
+             minOutputLength=0,
+             nReturnSequences=1,
+             noRepeatNgramSize=0,
+             repetitionPenalty=1.0,
+             resample=2,
+             rescaleFactor=1 / 255.0,
+             size=224,
+             temperature=1.0,
+             topK=50,
+             topP=1.0)
+
+     @staticmethod
+     def loadSavedModel(folder, spark_session):
+         """Loads a locally saved model.
+
+         Parameters
+         ----------
+         folder : str
+             Folder of the saved model
+         spark_session : pyspark.sql.SparkSession
+             The current SparkSession
+
+         Returns
+         -------
+         VisionEncoderDecoderForImageCaptioning
+             The restored model
+         """
+         from sparknlp.internal import _VisionEncoderDecoderForImageCaptioning
+         jModel = _VisionEncoderDecoderForImageCaptioning(folder, spark_session._jsparkSession)._java_obj
+         return VisionEncoderDecoderForImageCaptioning(java_model=jModel)
+
+     @staticmethod
+     def pretrained(name="image_captioning_vit_gpt2", lang="en", remote_loc=None):
+         """Downloads and loads a pretrained model.
+
+         Parameters
+         ----------
+         name : str, optional
+             Name of the pretrained model, by default
+             "image_captioning_vit_gpt2"
+         lang : str, optional
+             Language of the pretrained model, by default "en"
+         remote_loc : str, optional
+             Optional remote address of the resource, by default None. Will use
+             Spark NLP's repositories otherwise.
+
+         Returns
+         -------
+         VisionEncoderDecoderForImageCaptioning
+             The restored model
+         """
+         from sparknlp.pretrained import ResourceDownloader
+         return ResourceDownloader.downloadModel(VisionEncoderDecoderForImageCaptioning, name, lang, remote_loc)
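
The docstring example above exercises pretrained(); the hunk also adds loadSavedModel for importing a locally exported encoder-decoder checkpoint. A minimal sketch of that path, with hypothetical export and save locations:

import sparknlp
from sparknlp.annotator import VisionEncoderDecoderForImageCaptioning

spark = sparknlp.start()

# Import a locally exported checkpoint (hypothetical folder), then persist it
# as a Spark NLP model so later runs can call .load() instead of re-importing.
captioning = VisionEncoderDecoderForImageCaptioning \
    .loadSavedModel("/models/vit-gpt2-export", spark) \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("caption")
captioning.write().overwrite().save("/models/vit_gpt2_spark_nlp")

# Reload the persisted model without touching the original export
restored = VisionEncoderDecoderForImageCaptioning.load("/models/vit_gpt2_spark_nlp")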
sparknlp/annotator/cv/vit_for_image_classification.py
@@ -0,0 +1,217 @@
+ # Copyright 2017-2022 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Contains classes concerning ViTForImageClassification."""
+
+ from sparknlp.common import *
+
+
+ class ViTForImageClassification(AnnotatorModel,
+                                 HasBatchedAnnotateImage,
+                                 HasImageFeatureProperties,
+                                 HasEngine):
+     """Vision Transformer (ViT) for image classification.
+
+     ViT is a transformer-based alternative to the convolutional neural networks usually
+     used for image recognition tasks.
+
+     Pretrained models can be loaded with ``pretrained`` of the companion object:
+
+     .. code-block:: python
+
+         imageClassifier = ViTForImageClassification.pretrained() \\
+             .setInputCols(["image_assembler"]) \\
+             .setOutputCol("class")
+
+
+     The default model is ``"image_classifier_vit_base_patch16_224"``, if no name is
+     provided.
+
+     For available pretrained models please see the
+     `Models Hub <https://sparknlp.org/models?task=Image+Classification>`__.
+
+     Models from the HuggingFace 🤗 Transformers library are also compatible with Spark
+     NLP 🚀. To see which models are compatible and how to import them see
+     https://github.com/JohnSnowLabs/spark-nlp/discussions/5669 and to see more extended
+     examples, see
+     `ViTImageClassificationTestSpec <https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/cv/ViTImageClassificationTestSpec.scala>`__.
+
+     **Paper Abstract:**
+
+     *While the Transformer architecture has become the de-facto standard for natural
+     language processing tasks, its applications to computer vision remain limited. In
+     vision, attention is either applied in conjunction with convolutional networks, or
+     used to replace certain components of convolutional networks while keeping their
+     overall structure in place. We show that this reliance on CNNs is not necessary and
+     a pure transformer applied directly to sequences of image patches can perform very
+     well on image classification tasks. When pre-trained on large amounts of data and
+     transferred to multiple mid-sized or small image recognition benchmarks (ImageNet,
+     CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared
+     to state-of-the-art convolutional networks while requiring substantially fewer
+     computational resources to train.*
+
+
+     ====================== ======================
+     Input Annotation types Output Annotation type
+     ====================== ======================
+     ``IMAGE``              ``CATEGORY``
+     ====================== ======================
+
+     References
+     ----------
+
+     `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
+     <https://arxiv.org/abs/2010.11929>`__
+
+
+     Parameters
+     ----------
+     doResize
+         Whether to resize the input to a certain size
+     doNormalize
+         Whether to normalize the input with mean and standard deviation
+     featureExtractorType
+         Name of model's architecture for feature extraction
+     imageMean
+         The sequence of means for each channel, to be used when normalizing images
+     imageStd
+         The sequence of standard deviations for each channel, to be used when normalizing images
+     resample
+         An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BILINEAR` or
+         `PIL.Image.BICUBIC`. Only has an effect if do_resize is set to True.
+     size
+         Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an integer is
+         provided, then the input will be resized to (size, size). Only has an effect if do_resize is set to True.
+     configProtoBytes
+         ConfigProto from tensorflow, serialized into byte array.
+
+     Examples
+     --------
+     >>> import sparknlp
+     >>> from sparknlp.base import *
+     >>> from sparknlp.annotator import *
+     >>> from pyspark.ml import Pipeline
+     >>> imageDF = spark.read \\
+     ...     .format("image") \\
+     ...     .option("dropInvalid", value = True) \\
+     ...     .load("src/test/resources/image/")
+     >>> imageAssembler = ImageAssembler() \\
+     ...     .setInputCol("image") \\
+     ...     .setOutputCol("image_assembler")
+     >>> imageClassifier = ViTForImageClassification \\
+     ...     .pretrained() \\
+     ...     .setInputCols(["image_assembler"]) \\
+     ...     .setOutputCol("class")
+     >>> pipeline = Pipeline().setStages([imageAssembler, imageClassifier])
+     >>> pipelineDF = pipeline.fit(imageDF).transform(imageDF)
+     >>> pipelineDF \\
+     ...     .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "class.result") \\
+     ...     .show(truncate=False)
+     +-----------------+----------------------------------------------------------+
+     |image_name       |result                                                    |
+     +-----------------+----------------------------------------------------------+
+     |palace.JPEG      |[palace]                                                  |
+     |egyptian_cat.jpeg|[Egyptian cat]                                            |
+     |hippopotamus.JPEG|[hippopotamus, hippo, river horse, Hippopotamus amphibius]|
+     |hen.JPEG         |[hen]                                                     |
+     |ostrich.JPEG     |[ostrich, Struthio camelus]                               |
+     |junco.JPEG       |[junco, snowbird]                                         |
+     |bluetick.jpg     |[bluetick]                                                |
+     |chihuahua.jpg    |[Chihuahua]                                               |
+     |tractor.JPEG     |[tractor]                                                 |
+     |ox.JPEG          |[ox]                                                      |
+     +-----------------+----------------------------------------------------------+
+
+     """
+     name = "ViTForImageClassification"
+
+     inputAnnotatorTypes = [AnnotatorType.IMAGE]
+
+     outputAnnotatorType = AnnotatorType.CATEGORY
+
+     configProtoBytes = Param(Params._dummy(),
+                              "configProtoBytes",
+                              "ConfigProto from tensorflow, serialized into byte array. Get with "
+                              "config_proto.SerializeToString()",
+                              TypeConverters.toListInt)
+
+     def getClasses(self):
+         """
+         Returns labels used to train this model
+         """
+         return self._call_java("getClasses")
+
+     def setConfigProtoBytes(self, b):
+         """Sets configProto from tensorflow, serialized into byte array.
+
+         Parameters
+         ----------
+         b : List[int]
+             ConfigProto from tensorflow, serialized into byte array
+         """
+         return self._set(configProtoBytes=b)
+
+     @keyword_only
+     def __init__(self, classname="com.johnsnowlabs.nlp.annotators.cv.ViTForImageClassification",
+                  java_model=None):
+         super(ViTForImageClassification, self).__init__(
+             classname=classname,
+             java_model=java_model
+         )
+         self._setDefault(
+             batchSize=2
+         )
+
+     @staticmethod
+     def loadSavedModel(folder, spark_session):
+         """Loads a locally saved model.
+
+         Parameters
+         ----------
+         folder : str
+             Folder of the saved model
+         spark_session : pyspark.sql.SparkSession
+             The current SparkSession
+
+         Returns
+         -------
+         ViTForImageClassification
+             The restored model
+         """
+         from sparknlp.internal import _ViTForImageClassification
+         jModel = _ViTForImageClassification(folder, spark_session._jsparkSession)._java_obj
+         return ViTForImageClassification(java_model=jModel)
+
+     @staticmethod
+     def pretrained(name="image_classifier_vit_base_patch16_224", lang="en", remote_loc=None):
+         """Downloads and loads a pretrained model.
+
+         Parameters
+         ----------
+         name : str, optional
+             Name of the pretrained model, by default
+             "image_classifier_vit_base_patch16_224"
+         lang : str, optional
+             Language of the pretrained model, by default "en"
+         remote_loc : str, optional
+             Optional remote address of the resource, by default None. Will use
+             Spark NLP's repositories otherwise.
+
+         Returns
+         -------
+         ViTForImageClassification
+             The restored model
+         """
+         from sparknlp.pretrained import ResourceDownloader
+         return ResourceDownloader.downloadModel(ViTForImageClassification, name, lang, remote_loc)
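
One accessor the docstring example does not demonstrate is getClasses(), added in the hunk above, which exposes the label set the checkpoint was trained on. A short sketch built only from the calls shown in this file (the printed values depend on the downloaded checkpoint):

import sparknlp
from sparknlp.annotator import ViTForImageClassification

spark = sparknlp.start()

# Download the default checkpoint and inspect its training labels;
# the default ViT base model carries an ImageNet-style class set.
classifier = ViTForImageClassification.pretrained() \
    .setInputCols(["image_assembler"]) \
    .setOutputCol("class")

labels = list(classifier.getClasses())
print(len(labels))   # number of classes in the loaded checkpoint
print(labels[:5])    # first few label strings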
sparknlp/annotator/dataframe_optimizer.py
@@ -0,0 +1,216 @@
+ # Copyright 2017-2025 John Snow Labs
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from pyspark.ml import Transformer
+ from pyspark.ml.param.shared import *
+ from pyspark.sql import DataFrame
+ from typing import Any
+
+ # Custom converter for string-to-string dictionaries
+ def toStringDict(value):
+     if not isinstance(value, dict):
+         raise TypeError("Expected a dictionary of strings.")
+     return {str(k): str(v) for k, v in value.items()}
+
+ class DataFrameOptimizer(Transformer):
+     """
+     Optimizes a Spark DataFrame by repartitioning, optionally caching, and persisting it to disk.
+
+     This transformer is intended to improve performance for Spark NLP pipelines or when preparing
+     data for export. It allows partition tuning via `numPartitions` directly, or indirectly using
+     `executorCores` and `numWorkers`. The DataFrame can also be persisted in a specified format
+     (`csv`, `json`, or `parquet`) with additional writer options.
+
+     Parameters
+     ----------
+     executorCores : int, optional
+         Number of cores per Spark executor (used to compute number of partitions if `numPartitions` is not set).
+
+     numWorkers : int, optional
+         Number of executor nodes (used to compute number of partitions if `numPartitions` is not set).
+
+     numPartitions : int, optional
+         Target number of partitions for the DataFrame (overrides calculation via cores × workers).
+
+     doCache : bool, default False
+         Whether to cache the DataFrame after repartitioning.
+
+     persistPath : str, optional
+         Path to save the DataFrame output (if persistence is enabled).
+
+     persistFormat : str, optional
+         Format to persist the DataFrame in: one of `'csv'`, `'json'`, or `'parquet'`.
+
+     outputOptions : dict, optional
+         Dictionary of options for the DataFrameWriter (e.g., `{"compression": "snappy"}` for parquet).
+
+     Examples
+     --------
+     >>> optimizer = DataFrameOptimizer() \\
+     ...     .setExecutorCores(4) \\
+     ...     .setNumWorkers(5) \\
+     ...     .setDoCache(True) \\
+     ...     .setPersistPath("/tmp/out") \\
+     ...     .setPersistFormat("parquet") \\
+     ...     .setOutputOptions({"compression": "snappy"})
+
+     >>> optimized_df = optimizer.transform(input_df)
+
+     Notes
+     -----
+     - You must specify either `numPartitions`, or both `executorCores` and `numWorkers`.
+     - Schema is preserved; no columns are modified or removed.
+     """
+
+     executorCores = Param(
+         Params._dummy(),
+         "executorCores",
+         "Number of cores per executor",
+         typeConverter=TypeConverters.toInt
+     )
+     numWorkers = Param(
+         Params._dummy(),
+         "numWorkers",
+         "Number of Spark workers",
+         typeConverter=TypeConverters.toInt
+     )
+     numPartitions = Param(
+         Params._dummy(),
+         "numPartitions",
+         "Total number of partitions (overrides executorCores * numWorkers)",
+         typeConverter=TypeConverters.toInt
+     )
+     doCache = Param(
+         Params._dummy(),
+         "doCache",
+         "Whether to cache the DataFrame",
+         typeConverter=TypeConverters.toBoolean
+     )
+
+     persistPath = Param(
+         Params._dummy(),
+         "persistPath",
+         "Optional path to persist the DataFrame",
+         typeConverter=TypeConverters.toString
+     )
+     persistFormat = Param(
+         Params._dummy(),
+         "persistFormat",
+         "Format to persist: parquet, json, csv",
+         typeConverter=TypeConverters.toString
+     )
+
+     outputOptions = Param(
+         Params._dummy(),
+         "outputOptions",
+         "Additional writer options",
+         typeConverter=toStringDict
+     )
+
+     def __init__(self):
+         super().__init__()
+         self._setDefault(
+             doCache=False,
+             persistFormat="none",
+             numPartitions=1,
+             executorCores=1,
+             numWorkers=1
+         )
+
+     # Parameter setters
+     def setExecutorCores(self, value: int):
+         """Set the number of executor cores."""
+         return self._set(executorCores=value)
+
+     def setNumWorkers(self, value: int):
+         """Set the number of Spark workers."""
+         return self._set(numWorkers=value)
+
+     def setNumPartitions(self, value: int):
+         """Set the total number of partitions (overrides cores * workers)."""
+         return self._set(numPartitions=value)
+
+     def setDoCache(self, value: bool):
+         """Set whether to cache the DataFrame."""
+         return self._set(doCache=value)
+
+     def setPersistPath(self, value: str):
+         """Set the path where the DataFrame should be persisted."""
+         return self._set(persistPath=value)
+
+     def setPersistFormat(self, value: str):
+         """Set the format to persist the DataFrame (parquet, json, csv)."""
+         return self._set(persistFormat=value)
+
+     def setOutputOptions(self, value: dict):
+         """Set additional writer options (e.g. for csv headers)."""
+         return self._set(outputOptions=value)
+
+     # Optional bulk setter
+     def setParams(self, **kwargs: Any):
+         for param, value in kwargs.items():
+             self._set(**{param: value})
+         return self
+
+     def _transform(self, dataset: DataFrame) -> DataFrame:
+         self._validate_params()
+         part_count = self.getOrDefault(self.numPartitions)
+         cores = self.getOrDefault(self.executorCores)
+         workers = self.getOrDefault(self.numWorkers)
+         if cores is None or workers is None:
+             raise ValueError("Provide either numPartitions or both executorCores and numWorkers")
+         if part_count == 1:
+             part_count = cores * workers
+
+         optimized_df = dataset.repartition(part_count)
+
+         if self.getOrDefault(self.doCache):
+             optimized_df = optimized_df.cache()
+
+         format = self.getOrDefault(self.persistFormat).lower()
+         if format != "none":
+             path = self.getOrDefault(self.persistPath)
+             if not path:
+                 raise ValueError("persistPath must be set when persistFormat is not 'none'")
+             writer = optimized_df.write.mode("overwrite")
+             if self.isDefined(self.outputOptions):
+                 writer = writer.options(**self.getOrDefault(self.outputOptions))
+             if format == "parquet":
+                 writer.parquet(path)
+             elif format == "json":
+                 writer.json(path)
+             elif format == "csv":
+                 writer.csv(path)
+             else:
+                 raise ValueError(f"Unsupported format: {format}")
+
+         return optimized_df
+
+     def _validate_params(self):
+         if self.isDefined(self.executorCores):
+             val = self.getOrDefault(self.executorCores)
+             if val <= 0:
+                 raise ValueError("executorCores must be > 0")
+
+         if self.isDefined(self.numWorkers):
+             val = self.getOrDefault(self.numWorkers)
+             if val <= 0:
+                 raise ValueError("numWorkers must be > 0")
+
+         if self.isDefined(self.numPartitions):
+             val = self.getOrDefault(self.numPartitions)
+             if val <= 0:
+                 raise ValueError("numPartitions must be > 0")
+
+         if self.isDefined(self.persistPath) and not self.isDefined(self.persistFormat):
+             raise ValueError("persistFormat must be defined when persistPath is set")