spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (329)
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. com/johnsnowlabs/nlp/__init__.py +4 -2
  4. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  5. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  6. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  7. sparknlp/__init__.py +281 -27
  8. sparknlp/annotation.py +137 -6
  9. sparknlp/annotation_audio.py +61 -0
  10. sparknlp/annotation_image.py +82 -0
  11. sparknlp/annotator/__init__.py +93 -0
  12. sparknlp/annotator/audio/__init__.py +16 -0
  13. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  14. sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
  15. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  16. sparknlp/annotator/chunk2_doc.py +85 -0
  17. sparknlp/annotator/chunker.py +137 -0
  18. sparknlp/annotator/classifier_dl/__init__.py +61 -0
  19. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  20. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
  21. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
  22. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
  23. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  24. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  25. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  26. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
  27. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
  28. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
  29. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  30. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  31. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
  32. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
  33. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  34. sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
  35. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
  36. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
  37. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
  38. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  39. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
  40. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
  41. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
  42. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  43. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  44. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
  45. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
  46. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
  47. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  48. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  49. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  50. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
  51. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  52. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
  53. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
  54. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
  55. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  56. sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
  57. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
  60. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
  61. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
  62. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  63. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
  64. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
  65. sparknlp/annotator/cleaners/__init__.py +15 -0
  66. sparknlp/annotator/cleaners/cleaner.py +202 -0
  67. sparknlp/annotator/cleaners/extractor.py +191 -0
  68. sparknlp/annotator/coref/__init__.py +1 -0
  69. sparknlp/annotator/coref/spanbert_coref.py +221 -0
  70. sparknlp/annotator/cv/__init__.py +29 -0
  71. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  72. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  73. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  74. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  75. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  76. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  77. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  78. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  79. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  80. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  81. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  82. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  83. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  84. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  85. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  86. sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
  87. sparknlp/annotator/dataframe_optimizer.py +216 -0
  88. sparknlp/annotator/date2_chunk.py +88 -0
  89. sparknlp/annotator/dependency/__init__.py +17 -0
  90. sparknlp/annotator/dependency/dependency_parser.py +294 -0
  91. sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
  92. sparknlp/annotator/document_character_text_splitter.py +228 -0
  93. sparknlp/annotator/document_normalizer.py +235 -0
  94. sparknlp/annotator/document_token_splitter.py +175 -0
  95. sparknlp/annotator/document_token_splitter_test.py +85 -0
  96. sparknlp/annotator/embeddings/__init__.py +45 -0
  97. sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
  98. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  99. sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
  100. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
  101. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  102. sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
  103. sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
  104. sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
  105. sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
  106. sparknlp/annotator/embeddings/doc2vec.py +352 -0
  107. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  108. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  109. sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
  110. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  111. sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
  112. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  113. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  114. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  115. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  116. sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
  117. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
  118. sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
  119. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  120. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  121. sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
  122. sparknlp/annotator/embeddings/word2vec.py +353 -0
  123. sparknlp/annotator/embeddings/word_embeddings.py +385 -0
  124. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
  125. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
  126. sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
  127. sparknlp/annotator/er/__init__.py +16 -0
  128. sparknlp/annotator/er/entity_ruler.py +267 -0
  129. sparknlp/annotator/graph_extraction.py +368 -0
  130. sparknlp/annotator/keyword_extraction/__init__.py +16 -0
  131. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
  132. sparknlp/annotator/ld_dl/__init__.py +16 -0
  133. sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
  134. sparknlp/annotator/lemmatizer.py +250 -0
  135. sparknlp/annotator/matcher/__init__.py +20 -0
  136. sparknlp/annotator/matcher/big_text_matcher.py +272 -0
  137. sparknlp/annotator/matcher/date_matcher.py +303 -0
  138. sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
  139. sparknlp/annotator/matcher/regex_matcher.py +221 -0
  140. sparknlp/annotator/matcher/text_matcher.py +290 -0
  141. sparknlp/annotator/n_gram_generator.py +141 -0
  142. sparknlp/annotator/ner/__init__.py +21 -0
  143. sparknlp/annotator/ner/ner_approach.py +94 -0
  144. sparknlp/annotator/ner/ner_converter.py +148 -0
  145. sparknlp/annotator/ner/ner_crf.py +397 -0
  146. sparknlp/annotator/ner/ner_dl.py +591 -0
  147. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  148. sparknlp/annotator/ner/ner_overwriter.py +166 -0
  149. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  150. sparknlp/annotator/normalizer.py +230 -0
  151. sparknlp/annotator/openai/__init__.py +16 -0
  152. sparknlp/annotator/openai/openai_completion.py +349 -0
  153. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  154. sparknlp/annotator/param/__init__.py +17 -0
  155. sparknlp/annotator/param/classifier_encoder.py +98 -0
  156. sparknlp/annotator/param/evaluation_dl_params.py +130 -0
  157. sparknlp/annotator/pos/__init__.py +16 -0
  158. sparknlp/annotator/pos/perceptron.py +263 -0
  159. sparknlp/annotator/sentence/__init__.py +17 -0
  160. sparknlp/annotator/sentence/sentence_detector.py +290 -0
  161. sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
  162. sparknlp/annotator/sentiment/__init__.py +17 -0
  163. sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
  164. sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
  165. sparknlp/annotator/seq2seq/__init__.py +35 -0
  166. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  167. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  168. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  169. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  170. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  171. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  172. sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
  173. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  174. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  175. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  176. sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
  177. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  178. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  179. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  180. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  181. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  182. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  183. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  184. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  185. sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
  186. sparknlp/annotator/similarity/__init__.py +0 -0
  187. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  188. sparknlp/annotator/spell_check/__init__.py +18 -0
  189. sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
  190. sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
  191. sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
  192. sparknlp/annotator/stemmer.py +79 -0
  193. sparknlp/annotator/stop_words_cleaner.py +190 -0
  194. sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
  195. sparknlp/annotator/token/__init__.py +19 -0
  196. sparknlp/annotator/token/chunk_tokenizer.py +118 -0
  197. sparknlp/annotator/token/recursive_tokenizer.py +205 -0
  198. sparknlp/annotator/token/regex_tokenizer.py +208 -0
  199. sparknlp/annotator/token/tokenizer.py +561 -0
  200. sparknlp/annotator/token2_chunk.py +76 -0
  201. sparknlp/annotator/ws/__init__.py +16 -0
  202. sparknlp/annotator/ws/word_segmenter.py +429 -0
  203. sparknlp/base/__init__.py +30 -0
  204. sparknlp/base/audio_assembler.py +95 -0
  205. sparknlp/base/doc2_chunk.py +169 -0
  206. sparknlp/base/document_assembler.py +164 -0
  207. sparknlp/base/embeddings_finisher.py +201 -0
  208. sparknlp/base/finisher.py +217 -0
  209. sparknlp/base/gguf_ranking_finisher.py +234 -0
  210. sparknlp/base/graph_finisher.py +125 -0
  211. sparknlp/base/has_recursive_fit.py +24 -0
  212. sparknlp/base/has_recursive_transform.py +22 -0
  213. sparknlp/base/image_assembler.py +172 -0
  214. sparknlp/base/light_pipeline.py +429 -0
  215. sparknlp/base/multi_document_assembler.py +164 -0
  216. sparknlp/base/prompt_assembler.py +207 -0
  217. sparknlp/base/recursive_pipeline.py +107 -0
  218. sparknlp/base/table_assembler.py +145 -0
  219. sparknlp/base/token_assembler.py +124 -0
  220. sparknlp/common/__init__.py +26 -0
  221. sparknlp/common/annotator_approach.py +41 -0
  222. sparknlp/common/annotator_model.py +47 -0
  223. sparknlp/common/annotator_properties.py +114 -0
  224. sparknlp/common/annotator_type.py +38 -0
  225. sparknlp/common/completion_post_processing.py +37 -0
  226. sparknlp/common/coverage_result.py +22 -0
  227. sparknlp/common/match_strategy.py +33 -0
  228. sparknlp/common/properties.py +1298 -0
  229. sparknlp/common/read_as.py +33 -0
  230. sparknlp/common/recursive_annotator_approach.py +35 -0
  231. sparknlp/common/storage.py +149 -0
  232. sparknlp/common/utils.py +39 -0
  233. sparknlp/functions.py +315 -5
  234. sparknlp/internal/__init__.py +1199 -0
  235. sparknlp/internal/annotator_java_ml.py +32 -0
  236. sparknlp/internal/annotator_transformer.py +37 -0
  237. sparknlp/internal/extended_java_wrapper.py +63 -0
  238. sparknlp/internal/params_getters_setters.py +71 -0
  239. sparknlp/internal/recursive.py +70 -0
  240. sparknlp/logging/__init__.py +15 -0
  241. sparknlp/logging/comet.py +467 -0
  242. sparknlp/partition/__init__.py +16 -0
  243. sparknlp/partition/partition.py +244 -0
  244. sparknlp/partition/partition_properties.py +902 -0
  245. sparknlp/partition/partition_transformer.py +200 -0
  246. sparknlp/pretrained/__init__.py +17 -0
  247. sparknlp/pretrained/pretrained_pipeline.py +158 -0
  248. sparknlp/pretrained/resource_downloader.py +216 -0
  249. sparknlp/pretrained/utils.py +35 -0
  250. sparknlp/reader/__init__.py +15 -0
  251. sparknlp/reader/enums.py +19 -0
  252. sparknlp/reader/pdf_to_text.py +190 -0
  253. sparknlp/reader/reader2doc.py +124 -0
  254. sparknlp/reader/reader2image.py +136 -0
  255. sparknlp/reader/reader2table.py +44 -0
  256. sparknlp/reader/reader_assembler.py +159 -0
  257. sparknlp/reader/sparknlp_reader.py +461 -0
  258. sparknlp/training/__init__.py +20 -0
  259. sparknlp/training/_tf_graph_builders/__init__.py +0 -0
  260. sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
  261. sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
  262. sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
  263. sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
  264. sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
  265. sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
  266. sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
  267. sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
  268. sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
  269. sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
  270. sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
  271. sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
  272. sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
  273. sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
  274. sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
  275. sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
  276. sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
  277. sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
  278. sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
  279. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
  280. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
  281. sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
  282. sparknlp/training/conll.py +150 -0
  283. sparknlp/training/conllu.py +103 -0
  284. sparknlp/training/pos.py +103 -0
  285. sparknlp/training/pub_tator.py +76 -0
  286. sparknlp/training/spacy_to_annotation.py +57 -0
  287. sparknlp/training/tfgraphs.py +5 -0
  288. sparknlp/upload_to_hub.py +149 -0
  289. sparknlp/util.py +51 -5
  290. com/__init__.pyc +0 -0
  291. com/__pycache__/__init__.cpython-36.pyc +0 -0
  292. com/johnsnowlabs/__init__.pyc +0 -0
  293. com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
  294. com/johnsnowlabs/nlp/__init__.pyc +0 -0
  295. com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
  296. spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
  297. spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
  298. sparknlp/__init__.pyc +0 -0
  299. sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
  300. sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
  301. sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
  302. sparknlp/__pycache__/base.cpython-36.pyc +0 -0
  303. sparknlp/__pycache__/common.cpython-36.pyc +0 -0
  304. sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
  305. sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
  306. sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
  307. sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
  308. sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
  309. sparknlp/__pycache__/training.cpython-36.pyc +0 -0
  310. sparknlp/__pycache__/util.cpython-36.pyc +0 -0
  311. sparknlp/annotation.pyc +0 -0
  312. sparknlp/annotator.py +0 -3006
  313. sparknlp/annotator.pyc +0 -0
  314. sparknlp/base.py +0 -347
  315. sparknlp/base.pyc +0 -0
  316. sparknlp/common.py +0 -193
  317. sparknlp/common.pyc +0 -0
  318. sparknlp/embeddings.py +0 -40
  319. sparknlp/embeddings.pyc +0 -0
  320. sparknlp/internal.py +0 -288
  321. sparknlp/internal.pyc +0 -0
  322. sparknlp/pretrained.py +0 -123
  323. sparknlp/pretrained.pyc +0 -0
  324. sparknlp/storage.py +0 -32
  325. sparknlp/storage.pyc +0 -0
  326. sparknlp/training.py +0 -62
  327. sparknlp/training.pyc +0 -0
  328. sparknlp/util.pyc +0 -0
  329. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py
@@ -0,0 +1,521 @@
+ import math
+ import random
+ import sys
+
+ import numpy as np
+ import tensorflow as tf
+
+ from .sentence_grouper import SentenceGrouper
+ from ..tf2contrib import *
+
+
+ class NerModel:
+     # If session is not defined then a default session will be created
+     def __init__(self, session=None, dummy_tags=None, use_contrib=True, use_gpu_device=0):
+
+         tf.compat.v1.disable_v2_behavior()
+         tf.compat.v1.enable_v2_tensorshape()
+
+         self.word_repr = None
+         self.word_embeddings = None
+         self.session = session
+         self.session_created = False
+         self.dummy_tags = dummy_tags or []
+         self.use_contrib = use_contrib
+         self.use_gpu_device = use_gpu_device
+
+         if self.session is None:
+             self.session_created = True
+             self.session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
+                 allow_soft_placement=True,
+                 log_device_placement=False))
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("char_repr"):
+                 # shape = (batch size, sentence, word)
+                 self.char_ids = tf.compat.v1.placeholder(tf.int32, shape=[None, None, None], name="char_ids")
+
+                 # shape = (batch_size, sentence)
+                 self.word_lengths = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="word_lengths")
+
+             with tf.compat.v1.variable_scope("word_repr"):
+                 # shape = (batch size)
+                 self.sentence_lengths = tf.compat.v1.placeholder(tf.int32, shape=[None], name="sentence_lengths")
+
+             with tf.compat.v1.variable_scope("training", reuse=None):
+                 # shape = (batch, sentence)
+                 self.labels = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="labels")
+
+                 self.lr = tf.compat.v1.placeholder_with_default(0.005, shape=(), name="lr")
+                 self.dropout = tf.compat.v1.placeholder(tf.float32, shape=(), name="dropout")
+
+         self._char_bilstm_added = False
+         self._char_cnn_added = False
+         self._word_embeddings_added = False
+         self._context_added = False
+         self._encode_added = False
+
+     def add_bilstm_char_repr(self, nchars=101, dim=25, hidden=25):
+         self._char_bilstm_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+
+             with tf.compat.v1.variable_scope("char_repr_lstm"):
+                 # 1. Lookup for character embeddings
+                 char_range = math.sqrt(3 / dim)
+                 embeddings = tf.compat.v1.get_variable(name="char_embeddings",
+                                                        dtype=tf.float32,
+                                                        shape=[nchars, dim],
+                                                        initializer=tf.compat.v1.random_uniform_initializer(
+                                                            -char_range,
+                                                            char_range
+                                                        ),
+                                                        use_resource=False)
+
+                 # shape = (batch, sentence, word, char embeddings dim)
+                 char_embeddings = tf.nn.embedding_lookup(params=embeddings, ids=self.char_ids)
+                 # char_embeddings = tf.nn.dropout(char_embeddings, self.dropout)
+                 s = tf.shape(input=char_embeddings)
+
+                 # shape = (batch x sentence, word, char embeddings dim)
+                 char_embeddings_seq = tf.reshape(char_embeddings, shape=[-1, s[-2], dim])
+
+                 # shape = (batch x sentence)
+                 word_lengths_seq = tf.reshape(self.word_lengths, shape=[-1])
+
+                 # 2. Add Bidirectional LSTM
+                 model = tf.keras.Sequential([
+                     tf.keras.layers.Bidirectional(
+                         layer=tf.keras.layers.LSTM(hidden, return_sequences=False),
+                         merge_mode="concat"
+                     )
+                 ])
+
+                 inputs = char_embeddings_seq
+                 mask = tf.expand_dims(tf.sequence_mask(word_lengths_seq, dtype=tf.float32), axis=-1)
+
+                 # shape = (batch x sentence, 2 x hidden)
+                 output = model(inputs, mask=mask)
+
+                 # shape = (batch, sentence, 2 x hidden)
+                 char_repr = tf.reshape(output, shape=[-1, s[1], 2 * hidden])
+
+                 if self.word_repr is not None:
+                     self.word_repr = tf.concat([self.word_repr, char_repr], axis=-1)
+                 else:
+                     self.word_repr = char_repr
+
+     def add_cnn_char_repr(self, nchars=101, dim=25, nfilters=25, pad=2):
+         self._char_cnn_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+
+             with tf.compat.v1.variable_scope("char_repr_cnn"):
+                 # 1. Lookup for character embeddings
+                 char_range = math.sqrt(3 / dim)
+                 embeddings = tf.compat.v1.get_variable(name="char_embeddings", dtype=tf.float32,
+                                                        shape=[nchars, dim],
+                                                        initializer=tf.compat.v1.random_uniform_initializer(-char_range,
+                                                                                                            char_range),
+                                                        use_resource=False)
+
+                 # shape = (batch, sentence, word_len, embeddings dim)
+                 char_embeddings = tf.nn.embedding_lookup(params=embeddings, ids=self.char_ids)
+                 # char_embeddings = tf.nn.dropout(char_embeddings, self.dropout)
+                 s = tf.shape(input=char_embeddings)
+
+                 # shape = (batch x sentence, word_len, embeddings dim)
+                 char_embeddings = tf.reshape(char_embeddings, shape=[-1, s[-2], dim])
+
+                 # batch x sentence, word_len, nfilters
+                 conv1d = tf.keras.layers.Conv1D(
+                     filters=nfilters,
+                     kernel_size=[3],
+                     padding='same',
+                     activation=tf.nn.relu
+                 )(char_embeddings)
+
+                 # Max across each filter, shape = (batch x sentence, nfilters)
+                 char_repr = tf.reduce_max(input_tensor=conv1d, axis=1, keepdims=True)
+                 char_repr = tf.squeeze(char_repr, axis=[1])
+
+                 # (batch, sentence, nfilters)
+                 char_repr = tf.reshape(char_repr, shape=[s[0], s[1], nfilters])
+
+                 if self.word_repr is not None:
+                     self.word_repr = tf.concat([self.word_repr, char_repr], axis=-1)
+                 else:
+                     self.word_repr = char_repr
+
+     def add_pretrained_word_embeddings(self, dim=100):
+         self._word_embeddings_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("word_repr"):
+                 # shape = (batch size, sentence, dim)
+                 self.word_embeddings = tf.compat.v1.placeholder(tf.float32, shape=[None, None, dim],
+                                                                 name="word_embeddings")
+
+             if self.word_repr is not None:
+                 self.word_repr = tf.concat([self.word_repr, self.word_embeddings], axis=-1)
+             else:
+                 self.word_repr = self.word_embeddings
+
+     def _create_lstm_layer(self, inputs, hidden_size, lengths):
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             if not self.use_contrib:
+                 raise ValueError("NER Tensorflow graphs can no longer be built without tf.contrib. Set use_contrib=True.")
+
+             time_based = tf.transpose(a=inputs, perm=[1, 0, 2])
+
+             cell_fw = LSTMBlockFusedCell(hidden_size, use_peephole=True)
+             cell_bw = LSTMBlockFusedCell(hidden_size, use_peephole=True)
+             cell_bw = TimeReversedFusedRNN(cell_bw)
+
+             output_fw, _ = cell_fw(time_based, dtype=tf.float32, sequence_length=lengths)
+             output_bw, _ = cell_bw(time_based, dtype=tf.float32, sequence_length=lengths)
+
+             result = tf.concat([output_fw, output_bw], axis=-1)
+
+             return tf.transpose(a=result, perm=[1, 0, 2])
+
+     def _multiply_layer(self, source, result_size, activation=tf.nn.relu):
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             ntime_steps = tf.shape(input=source)[1]
+             source_size = source.shape[2]
+
+             W = tf.compat.v1.get_variable("W", shape=[source_size, result_size],
+                                           dtype=tf.float32,
+                                           initializer=tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0,
+                                                                                                       mode="fan_avg",
+                                                                                                       distribution="uniform"),
+                                           use_resource=False)
+
+             b = tf.compat.v1.get_variable("b", shape=[result_size], dtype=tf.float32, use_resource=False)
+
+             # batch x time, source_size
+             source = tf.reshape(source, [-1, source_size])
+             # batch x time, result_size
+             result = tf.matmul(source, W) + b
+
+             result = tf.reshape(result, [-1, ntime_steps, result_size])
+             if activation:
+                 result = activation(result)
+
+             return result
+
+     # Adds a Bi-LSTM context layer; each direction has hidden size hidden_size
+     def add_context_repr(self, ntags, hidden_size=100, height=1, residual=True):
+         assert self._word_embeddings_added or self._char_cnn_added or self._char_bilstm_added, \
+             "Add word embeddings by method add_pretrained_word_embeddings " + \
+             "or add a char representation by method add_bilstm_char_repr " + \
+             "or add_cnn_char_repr before adding the context layer"
+
+         self._context_added = True
+         self.ntags = ntags
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             context_repr = self._multiply_layer(self.word_repr, 2 * hidden_size)
+             # self.dropout holds a keep probability; tf.nn.dropout expects a drop rate, hence rate = 1 - keep_prob
+             context_repr = tf.nn.dropout(x=context_repr, rate=1 - self.dropout)
+
+             with tf.compat.v1.variable_scope("context_repr"):
+                 for i in range(height):
+                     with tf.compat.v1.variable_scope('lstm-{}'.format(i)):
+                         new_repr = self._create_lstm_layer(context_repr, hidden_size,
+                                                            lengths=self.sentence_lengths)
+
+                         context_repr = new_repr + context_repr if residual else new_repr
+
+                 context_repr = tf.nn.dropout(x=context_repr, rate=1 - self.dropout)
+
+                 # batch, sentence, ntags
+                 self.scores = self._multiply_layer(context_repr, ntags, activation=None)
+
+                 tf.identity(self.scores, "scores")
+
+                 self.predicted_labels = tf.argmax(input=self.scores, axis=-1)
+                 tf.identity(self.predicted_labels, "predicted_labels")
+
+     def add_inference_layer(self, crf=False, predictions_op_name=None):
+         assert self._context_added, \
+             "Add context representation layer by method add_context_repr before adding the inference layer"
+         self._inference_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+
+             with tf.compat.v1.variable_scope("inference", reuse=None):
+
+                 self.crf = tf.constant(crf, dtype=tf.bool, name="crf")
+
+                 if crf:
+                     transition_params = tf.compat.v1.get_variable("transition_params",
+                                                                   shape=[self.ntags, self.ntags],
+                                                                   initializer=tf.compat.v1.keras.initializers.VarianceScaling(
+                                                                       scale=1.0, mode="fan_avg",
+                                                                       distribution="uniform"),
+                                                                   use_resource=False)
+
+                     # CRF shape = (batch, sentence)
+                     log_likelihood, self.transition_params = crf_log_likelihood(
+                         self.scores,
+                         self.labels,
+                         self.sentence_lengths,
+                         transition_params
+                     )
+
+                     tf.identity(log_likelihood, "log_likelihood")
+                     tf.identity(self.transition_params, "transition_params")
+
+                     self.loss = tf.reduce_mean(input_tensor=-log_likelihood)
+                     if predictions_op_name:
+                         with tf.compat.v1.variable_scope("inference_tmp", reuse=None):
+                             tmp_prediction, _ = crf_decode(self.scores, self.transition_params, self.sentence_lengths)
+
+                         self.prediction = tf.identity(tmp_prediction, name=predictions_op_name)
+                     else:
+                         self.prediction, _ = crf_decode(self.scores, self.transition_params, self.sentence_lengths)
+                 else:
+                     # Softmax
+                     losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.scores, labels=self.labels)
+                     # shape = (batch, sentence, ntags)
+                     mask = tf.sequence_mask(self.sentence_lengths)
+                     # apply mask
+                     losses = tf.boolean_mask(tensor=losses, mask=mask)
+
+                     self.loss = tf.reduce_mean(input_tensor=losses)
+
+                     self.prediction = tf.math.argmax(input=self.scores, axis=-1, name=predictions_op_name)
+
+         tf.identity(self.loss, "loss")
+
+     # clip_gradient <= 0 disables gradient clipping
+     def add_training_op(self, clip_gradient=2.0, train_op_name=None):
+         assert self._inference_added, \
+             "Add inference layer by method add_inference_layer before adding the training op"
+         self._training_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+
+             with tf.compat.v1.variable_scope("training", reuse=None):
+                 if train_op_name:
+                     optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr, name=train_op_name)
+                 else:
+                     optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr)
+                 if clip_gradient > 0:
+                     gvs = optimizer.compute_gradients(self.loss)
+                     capped_gvs = [(tf.clip_by_value(grad, -clip_gradient, clip_gradient), var) for grad, var in gvs if
+                                   grad is not None]
+                     self.train_op = optimizer.apply_gradients(capped_gvs)
+                 else:
+                     self.train_op = optimizer.minimize(self.loss)
+
+         self.init_op = tf.compat.v1.variables_initializer(tf.compat.v1.global_variables(), name="init")
+
+     @staticmethod
+     def num_trues(array):
+         result = 0
+         for item in array:
+             if item:
+                 result += 1
+
+         return result
+
+     @staticmethod
+     def fill(array, l, val):
+         result = array[:]
+         for i in range(l - len(array)):
+             result.append(val)
+         return result
+
+     @staticmethod
+     def get_sentence_lengths(batch, idx="word_embeddings"):
+         return [len(row[idx]) for row in batch]
+
+     @staticmethod
+     def get_sentence_token_lengths(batch, idx="tag_ids"):
+         return [len(row[idx]) for row in batch]
+
+     @staticmethod
+     def get_word_lengths(batch, idx="char_ids"):
+         max_words = max([len(row[idx]) for row in batch])
+         return [NerModel.fill([len(chars) for chars in row[idx]], max_words, 0)
+                 for row in batch]
+
+     @staticmethod
+     def get_char_ids(batch, idx="char_ids"):
+         max_chars = max([max([len(char_ids) for char_ids in sentence[idx]]) for sentence in batch])
+         max_words = max([len(sentence[idx]) for sentence in batch])
+
+         return [
+             NerModel.fill(
+                 [NerModel.fill(char_ids, max_chars, 0) for char_ids in sentence[idx]],
+                 max_words, [0] * max_chars
+             )
+             for sentence in batch]
+
+     @staticmethod
+     def get_from_batch(batch, idx):
+         k = max([len(row[idx]) for row in batch])
+         return list([NerModel.fill(row[idx], k, 0) for row in batch])
+
+     @staticmethod
+     def get_tag_ids(batch, idx="tag_ids"):
+         return NerModel.get_from_batch(batch, idx)
+
+     @staticmethod
+     def get_word_embeddings(batch, idx="word_embeddings"):
+         embeddings_dim = len(batch[0][idx][0])
+         max_words = max([len(sentence[idx]) for sentence in batch])
+         return [
+             NerModel.fill([word_embedding for word_embedding in sentence[idx]],
+                           max_words, [0] * embeddings_dim
+                           )
+             for sentence in batch]
+
+     @staticmethod
+     def slice(dataset, batch_size=10):
+         grouper = SentenceGrouper([5, 10, 20, 50])
+         return grouper.slice(dataset, batch_size)
+
+     def init_variables(self):
+         self.session.run(self.init_op)
+
+     def train(self, train,
+               epoch_start=0,
+               epoch_end=100,
+               batch_size=32,
+               lr=0.01,
+               po=0,
+               dropout=0.65,
+               init_variables=False
+               ):
+
+         assert self._training_added, "Add training op by method add_training_op before running training"
+
+         if init_variables:
+             with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+                 self.session.run(tf.compat.v1.global_variables_initializer())
+
+         print('training started')
+         for epoch in range(epoch_start, epoch_end):
+             random.shuffle(train)
+             sum_loss = 0
+             for batch in NerModel.slice(train, batch_size):
+                 feed_dict = {
+                     self.sentence_lengths: NerModel.get_sentence_lengths(batch),
+                     self.word_embeddings: NerModel.get_word_embeddings(batch),
+
+                     self.word_lengths: NerModel.get_word_lengths(batch),
+                     self.char_ids: NerModel.get_char_ids(batch),
+                     self.labels: NerModel.get_tag_ids(batch),
+
+                     self.dropout: dropout,
+                     self.lr: lr / (1 + po * epoch)
+                 }
+                 mean_loss, _ = self.session.run([self.loss, self.train_op], feed_dict=feed_dict)
+                 sum_loss += mean_loss
+
+             print("epoch {}".format(epoch))
+             print("mean loss: {}".format(sum_loss))
+             print()
+             sys.stdout.flush()
+
+     def measure(self, dataset, batch_size=20, dropout=1.0):
+         predicted = {}
+         correct = {}
+         correct_predicted = {}
+
+         for batch in NerModel.slice(dataset, batch_size):
+             tags_ids = NerModel.get_tag_ids(batch)
+             sentence_lengths = NerModel.get_sentence_lengths(batch)
+
+             feed_dict = {
+                 self.sentence_lengths: sentence_lengths,
+                 self.word_embeddings: NerModel.get_word_embeddings(batch),
+
+                 self.word_lengths: NerModel.get_word_lengths(batch),
+                 self.char_ids: NerModel.get_char_ids(batch),
+                 self.labels: tags_ids,
+
+                 self.dropout: dropout
+             }
+
+             prediction = self.session.run(self.prediction, feed_dict=feed_dict)
+             batch_prediction = np.reshape(prediction, (len(batch), -1))
+
+             for i in range(len(batch)):
+                 is_word_start = batch[i]['is_word_start']
+
+                 for word in range(sentence_lengths[i]):
+                     if not is_word_start[word]:
+                         continue
+
+                     p = batch_prediction[i][word]
+                     c = tags_ids[i][word]
+
+                     if c in self.dummy_tags:
+                         continue
+
+                     predicted[p] = predicted.get(p, 0) + 1
+                     correct[c] = correct.get(c, 0) + 1
+                     if p == c:
+                         correct_predicted[p] = correct_predicted.get(p, 0) + 1
+
+         num_correct_predicted = sum([correct_predicted.get(i, 0) for i in range(1, self.ntags)])
+         num_predicted = sum([predicted.get(i, 0) for i in range(1, self.ntags)])
+         num_correct = sum([correct.get(i, 0) for i in range(1, self.ntags)])
+
+         prec = num_correct_predicted / (num_predicted or 1.)
+         rec = num_correct_predicted / (num_correct or 1.)
+
+         # Guard against division by zero when both precision and recall are 0
+         f1 = 2 * prec * rec / ((rec + prec) or 1.)
+
+         return prec, rec, f1
+
+     @staticmethod
+     def get_softmax(scores, threshold=None):
+         exp_scores = np.exp(scores)
+
+         for sentence in exp_scores:
+             for i in range(len(sentence)):
+                 probabilities = sentence[i] / np.sum(sentence[i])
+                 sentence[i] = [p if threshold is None or p >= threshold else 0 for p in probabilities]
+
+         return exp_scores
+
+     def predict(self, sentences, batch_size=20, threshold=None):
+         result = []
+
+         for batch in NerModel.slice(sentences, batch_size):
+             sentence_lengths = NerModel.get_sentence_lengths(batch)
+
+             feed_dict = {
+                 self.sentence_lengths: sentence_lengths,
+                 self.word_embeddings: NerModel.get_word_embeddings(batch),
+
+                 self.word_lengths: NerModel.get_word_lengths(batch),
+                 self.char_ids: NerModel.get_char_ids(batch),
+
+                 self.dropout: 1.0
+             }
+
+             prediction = self.session.run(self.prediction, feed_dict=feed_dict)
+             batch_prediction = np.reshape(prediction, (len(batch), -1))
+
+             for i in range(len(batch)):
+                 sentence = []
+                 for word in range(sentence_lengths[i]):
+                     tag = batch_prediction[i][word]
+                     sentence.append(tag)
+
+                 result.append(sentence)
+
+         return result
+
+     def close(self):
+         if self.session_created:
+             self.session.close()
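Note: the builder methods above are meant to be chained in order before training. The following is a minimal, illustrative sketch only, not code from the diff; the batch keys ("words", "word_embeddings", "char_ids", "tag_ids", "is_word_start") are inferred from the static helpers above, and the dimensions and sample row are made up:

    from sparknlp.training._tf_graph_builders.ner_dl.ner_model import NerModel

    ner = NerModel(use_gpu_device=0)  # allow_soft_placement lets this fall back to CPU
    ner.add_cnn_char_repr(nchars=101, dim=25, nfilters=25)
    ner.add_pretrained_word_embeddings(dim=100)
    ner.add_context_repr(ntags=3, hidden_size=100, height=1)
    ner.add_inference_layer(crf=True)
    ner.add_training_op(clip_gradient=2.0)
    ner.init_variables()

    # One hypothetical training row, shaped the way the static helpers read it
    row = {
        "words": ["John", "lives", "here"],
        "word_embeddings": [[0.0] * 100 for _ in range(3)],  # one vector per token
        "char_ids": [[1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13]],
        "tag_ids": [1, 0, 0],                                # gold label ids
        "is_word_start": [True, True, True],
    }
    ner.train([dict(row) for _ in range(64)], epoch_end=1, batch_size=32, dropout=0.65)
    ner.close()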
sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py
@@ -0,0 +1,62 @@
+ import os
+
+ import tensorflow as tf
+
+
+ class NerModelSaver:
+     def __init__(self, ner, encoder, embeddings_file=None):
+         self.ner = ner
+         self.encoder = encoder
+         self.embeddings_file = embeddings_file
+
+     @staticmethod
+     def restore_tensorflow_state(session, export_dir):
+         with tf.compat.v1.device('/gpu:0'):
+             save_nodes = [n.name for n in tf.compat.v1.get_default_graph().as_graph_def().node if n.name.startswith('save/')]
+             if len(save_nodes) == 0:
+                 # Constructing a Saver adds the save/* ops (save/Const, save/restore_all, ...) that are run by name below
+                 saver = tf.compat.v1.train.Saver()
+
+             variables_file = os.path.join(export_dir, 'variables')
+             session.run("save/restore_all", feed_dict={'save/Const:0': variables_file})
+
+     def save_models(self, folder):
+         with tf.compat.v1.device('/gpu:0'):
+             save_nodes = [n.name for n in tf.compat.v1.get_default_graph().as_graph_def().node if n.name.startswith('save/')]
+             if len(save_nodes) == 0:
+                 saver = tf.compat.v1.train.Saver()
+
+             variables_file = os.path.join(folder, 'variables')
+             self.ner.session.run('save/control_dependency', feed_dict={'save/Const:0': variables_file})
+             tf.compat.v1.train.write_graph(self.ner.session.graph, folder, 'saved_model.pb', False)
+
+     def save(self, export_dir):
+         def save_tags(file):
+             id2tag = {tag_id: tag for (tag, tag_id) in self.encoder.tag2id.items()}
+
+             with open(file, 'w') as f:
+                 for i in range(len(id2tag)):
+                     tag = id2tag[i]
+                     f.write(tag)
+                     f.write('\n')
+
+         def save_embeddings(src, dst):
+             from shutil import copyfile
+             copyfile(src, dst)
+             with open(dst + '.meta', 'w') as f:
+                 embeddings = self.encoder.embeddings
+                 dim = len(embeddings[0]) if embeddings else 0
+                 f.write(str(dim))
+
+         def save_chars(file):
+             id2char = {char_id: char for (char, char_id) in self.encoder.char2id.items()}
+             with open(file, 'w') as f:
+                 for i in range(1, len(id2char) + 1):
+                     f.write(id2char[i])
+
+         self.save_models(export_dir)
+         save_tags(os.path.join(export_dir, 'tags.csv'))
+
+         if self.embeddings_file:
+             save_embeddings(self.embeddings_file, os.path.join(export_dir, 'embeddings'))
+
+         save_chars(os.path.join(export_dir, 'chars.csv'))
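Note: NerModelSaver drives the graph's save machinery by op name rather than by holding a Saver reference: constructing tf.compat.v1.train.Saver() once adds the save/Const, save/restore_all and save/control_dependency nodes that restore_tensorflow_state and save_models then run. A hypothetical round trip (the export path is made up; encoder is assumed to expose tag2id, char2id and embeddings, as save() reads them):

    saver = NerModelSaver(ner, encoder)
    saver.save_models('/tmp/ner_export')   # writes the variables checkpoint and saved_model.pb
    saver.save('/tmp/ner_export')          # additionally writes tags.csv and chars.csv
    NerModelSaver.restore_tensorflow_state(ner.session, '/tmp/ner_export')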
sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py
@@ -0,0 +1,28 @@
+ class SentenceGrouper:
+     def __init__(self, bucket_lengths):
+         self.bucket_lengths = bucket_lengths
+
+     def get_bucket_id(self, length):
+         for i, bucket_len in enumerate(self.bucket_lengths):
+             if length <= bucket_len:
+                 return i
+
+         return len(self.bucket_lengths)
+
+     def slice(self, dataset, batch_size=32):
+         buckets = [[] for _ in self.bucket_lengths]
+         buckets.append([])  # overflow bucket for sentences longer than every bucket length
+
+         for entry in dataset:
+             length = len(entry['words'])
+             bucket_id = self.get_bucket_id(length)
+             buckets[bucket_id].append(entry)
+
+             if len(buckets[bucket_id]) >= batch_size:
+                 result = buckets[bucket_id][:]
+                 yield result
+                 buckets[bucket_id] = []
+
+         for bucket in buckets:
+             if len(bucket) > 0:
+                 yield bucket
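Note: SentenceGrouper batches sentences of similar length together to limit padding waste. A small illustrative run (the data is made up):

    grouper = SentenceGrouper([5, 10, 20, 50])
    data = [{"words": ["w"] * n} for n in (3, 4, 8, 9, 30, 60)]
    for batch in grouper.slice(data, batch_size=2):
        print([len(entry["words"]) for entry in batch])
    # [3, 4]   <- bucket of sentences up to 5 tokens, flushed at batch_size
    # [8, 9]   <- bucket of sentences up to 10 tokens
    # [30]     <- leftover from the <=50 bucket, yielded at the end
    # [60]     <- leftover from the overflow bucket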
sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py
@@ -0,0 +1,36 @@
+ """
+ This is a distribution of the tf.contrib module Python files available in TensorFlow 1.x:
+ https://github.com/tensorflow/tensorflow/blob/r1.15/tensorflow/contrib/rnn/python/ops
+ The original source code files are not modified; the only change is in this file.
+ This distribution includes just the Python ops of tf.contrib, and therefore not all
+ functionality of tf.contrib is enabled.
+ """
+ import tensorflow as tf
+
+ if tf.__version__.startswith("2"):
+     # TensorFlow 2.x, so use tensorflow_addons and the custom distribution of tf.contrib
+     import tensorflow_addons
+
+     tf = tf.compat.v1
+
+     crf_decode = tensorflow_addons.text.crf_decode
+     crf_log_likelihood = tensorflow_addons.text.crf_log_likelihood
+     USE_TF2 = True
+
+     from .lstm_ops import *
+     from .fused_rnn_cell import *
+     from .rnn import *
+     from tensorflow.compat.v1.nn.rnn_cell import *
+
+ elif tf.__version__.startswith("1.15"):
+     # TensorFlow 1.15, use the original tf.contrib
+
+     crf_decode = tf.contrib.crf.crf_decode
+     crf_log_likelihood = tf.contrib.crf.crf_log_likelihood
+     USE_TF2 = False
+
+     from tensorflow.contrib.rnn import *
+
+ else:
+     # Nothing can be done, exit
+     raise ValueError("This version of TensorFlow is not supported!")
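Note: under TensorFlow 2.x this shim resolves the CRF ops through tensorflow_addons, so that package must be installed alongside TensorFlow; under 1.15 the original tf.contrib is used. Downstream code such as ner_model.py above simply star-imports the module and picks up whichever backend was selected. A hedged sketch of inspecting the choice (names taken from the module above):

    from sparknlp.training._tf_graph_builders import tf2contrib

    print(tf2contrib.USE_TF2)      # True on TF 2.x with tensorflow_addons installed
    print(tf2contrib.crf_decode)   # tensorflow_addons.text.crf_decode or tf.contrib.crf.crf_decode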