spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (329)
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. com/johnsnowlabs/nlp/__init__.py +4 -2
  4. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  5. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  6. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  7. sparknlp/__init__.py +281 -27
  8. sparknlp/annotation.py +137 -6
  9. sparknlp/annotation_audio.py +61 -0
  10. sparknlp/annotation_image.py +82 -0
  11. sparknlp/annotator/__init__.py +93 -0
  12. sparknlp/annotator/audio/__init__.py +16 -0
  13. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  14. sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
  15. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  16. sparknlp/annotator/chunk2_doc.py +85 -0
  17. sparknlp/annotator/chunker.py +137 -0
  18. sparknlp/annotator/classifier_dl/__init__.py +61 -0
  19. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  20. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
  21. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
  22. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
  23. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  24. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  25. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  26. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
  27. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
  28. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
  29. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  30. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  31. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
  32. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
  33. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  34. sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
  35. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
  36. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
  37. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
  38. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  39. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
  40. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
  41. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
  42. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  43. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  44. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
  45. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
  46. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
  47. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  48. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  49. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  50. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
  51. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  52. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
  53. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
  54. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
  55. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  56. sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
  57. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
  60. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
  61. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
  62. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  63. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
  64. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
  65. sparknlp/annotator/cleaners/__init__.py +15 -0
  66. sparknlp/annotator/cleaners/cleaner.py +202 -0
  67. sparknlp/annotator/cleaners/extractor.py +191 -0
  68. sparknlp/annotator/coref/__init__.py +1 -0
  69. sparknlp/annotator/coref/spanbert_coref.py +221 -0
  70. sparknlp/annotator/cv/__init__.py +29 -0
  71. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  72. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  73. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  74. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  75. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  76. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  77. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  78. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  79. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  80. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  81. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  82. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  83. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  84. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  85. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  86. sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
  87. sparknlp/annotator/dataframe_optimizer.py +216 -0
  88. sparknlp/annotator/date2_chunk.py +88 -0
  89. sparknlp/annotator/dependency/__init__.py +17 -0
  90. sparknlp/annotator/dependency/dependency_parser.py +294 -0
  91. sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
  92. sparknlp/annotator/document_character_text_splitter.py +228 -0
  93. sparknlp/annotator/document_normalizer.py +235 -0
  94. sparknlp/annotator/document_token_splitter.py +175 -0
  95. sparknlp/annotator/document_token_splitter_test.py +85 -0
  96. sparknlp/annotator/embeddings/__init__.py +45 -0
  97. sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
  98. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  99. sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
  100. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
  101. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  102. sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
  103. sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
  104. sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
  105. sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
  106. sparknlp/annotator/embeddings/doc2vec.py +352 -0
  107. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  108. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  109. sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
  110. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  111. sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
  112. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  113. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  114. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  115. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  116. sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
  117. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
  118. sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
  119. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  120. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  121. sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
  122. sparknlp/annotator/embeddings/word2vec.py +353 -0
  123. sparknlp/annotator/embeddings/word_embeddings.py +385 -0
  124. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
  125. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
  126. sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
  127. sparknlp/annotator/er/__init__.py +16 -0
  128. sparknlp/annotator/er/entity_ruler.py +267 -0
  129. sparknlp/annotator/graph_extraction.py +368 -0
  130. sparknlp/annotator/keyword_extraction/__init__.py +16 -0
  131. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
  132. sparknlp/annotator/ld_dl/__init__.py +16 -0
  133. sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
  134. sparknlp/annotator/lemmatizer.py +250 -0
  135. sparknlp/annotator/matcher/__init__.py +20 -0
  136. sparknlp/annotator/matcher/big_text_matcher.py +272 -0
  137. sparknlp/annotator/matcher/date_matcher.py +303 -0
  138. sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
  139. sparknlp/annotator/matcher/regex_matcher.py +221 -0
  140. sparknlp/annotator/matcher/text_matcher.py +290 -0
  141. sparknlp/annotator/n_gram_generator.py +141 -0
  142. sparknlp/annotator/ner/__init__.py +21 -0
  143. sparknlp/annotator/ner/ner_approach.py +94 -0
  144. sparknlp/annotator/ner/ner_converter.py +148 -0
  145. sparknlp/annotator/ner/ner_crf.py +397 -0
  146. sparknlp/annotator/ner/ner_dl.py +591 -0
  147. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  148. sparknlp/annotator/ner/ner_overwriter.py +166 -0
  149. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  150. sparknlp/annotator/normalizer.py +230 -0
  151. sparknlp/annotator/openai/__init__.py +16 -0
  152. sparknlp/annotator/openai/openai_completion.py +349 -0
  153. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  154. sparknlp/annotator/param/__init__.py +17 -0
  155. sparknlp/annotator/param/classifier_encoder.py +98 -0
  156. sparknlp/annotator/param/evaluation_dl_params.py +130 -0
  157. sparknlp/annotator/pos/__init__.py +16 -0
  158. sparknlp/annotator/pos/perceptron.py +263 -0
  159. sparknlp/annotator/sentence/__init__.py +17 -0
  160. sparknlp/annotator/sentence/sentence_detector.py +290 -0
  161. sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
  162. sparknlp/annotator/sentiment/__init__.py +17 -0
  163. sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
  164. sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
  165. sparknlp/annotator/seq2seq/__init__.py +35 -0
  166. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  167. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  168. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  169. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  170. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  171. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  172. sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
  173. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  174. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  175. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  176. sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
  177. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  178. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  179. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  180. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  181. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  182. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  183. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  184. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  185. sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
  186. sparknlp/annotator/similarity/__init__.py +0 -0
  187. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  188. sparknlp/annotator/spell_check/__init__.py +18 -0
  189. sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
  190. sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
  191. sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
  192. sparknlp/annotator/stemmer.py +79 -0
  193. sparknlp/annotator/stop_words_cleaner.py +190 -0
  194. sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
  195. sparknlp/annotator/token/__init__.py +19 -0
  196. sparknlp/annotator/token/chunk_tokenizer.py +118 -0
  197. sparknlp/annotator/token/recursive_tokenizer.py +205 -0
  198. sparknlp/annotator/token/regex_tokenizer.py +208 -0
  199. sparknlp/annotator/token/tokenizer.py +561 -0
  200. sparknlp/annotator/token2_chunk.py +76 -0
  201. sparknlp/annotator/ws/__init__.py +16 -0
  202. sparknlp/annotator/ws/word_segmenter.py +429 -0
  203. sparknlp/base/__init__.py +30 -0
  204. sparknlp/base/audio_assembler.py +95 -0
  205. sparknlp/base/doc2_chunk.py +169 -0
  206. sparknlp/base/document_assembler.py +164 -0
  207. sparknlp/base/embeddings_finisher.py +201 -0
  208. sparknlp/base/finisher.py +217 -0
  209. sparknlp/base/gguf_ranking_finisher.py +234 -0
  210. sparknlp/base/graph_finisher.py +125 -0
  211. sparknlp/base/has_recursive_fit.py +24 -0
  212. sparknlp/base/has_recursive_transform.py +22 -0
  213. sparknlp/base/image_assembler.py +172 -0
  214. sparknlp/base/light_pipeline.py +429 -0
  215. sparknlp/base/multi_document_assembler.py +164 -0
  216. sparknlp/base/prompt_assembler.py +207 -0
  217. sparknlp/base/recursive_pipeline.py +107 -0
  218. sparknlp/base/table_assembler.py +145 -0
  219. sparknlp/base/token_assembler.py +124 -0
  220. sparknlp/common/__init__.py +26 -0
  221. sparknlp/common/annotator_approach.py +41 -0
  222. sparknlp/common/annotator_model.py +47 -0
  223. sparknlp/common/annotator_properties.py +114 -0
  224. sparknlp/common/annotator_type.py +38 -0
  225. sparknlp/common/completion_post_processing.py +37 -0
  226. sparknlp/common/coverage_result.py +22 -0
  227. sparknlp/common/match_strategy.py +33 -0
  228. sparknlp/common/properties.py +1298 -0
  229. sparknlp/common/read_as.py +33 -0
  230. sparknlp/common/recursive_annotator_approach.py +35 -0
  231. sparknlp/common/storage.py +149 -0
  232. sparknlp/common/utils.py +39 -0
  233. sparknlp/functions.py +315 -5
  234. sparknlp/internal/__init__.py +1199 -0
  235. sparknlp/internal/annotator_java_ml.py +32 -0
  236. sparknlp/internal/annotator_transformer.py +37 -0
  237. sparknlp/internal/extended_java_wrapper.py +63 -0
  238. sparknlp/internal/params_getters_setters.py +71 -0
  239. sparknlp/internal/recursive.py +70 -0
  240. sparknlp/logging/__init__.py +15 -0
  241. sparknlp/logging/comet.py +467 -0
  242. sparknlp/partition/__init__.py +16 -0
  243. sparknlp/partition/partition.py +244 -0
  244. sparknlp/partition/partition_properties.py +902 -0
  245. sparknlp/partition/partition_transformer.py +200 -0
  246. sparknlp/pretrained/__init__.py +17 -0
  247. sparknlp/pretrained/pretrained_pipeline.py +158 -0
  248. sparknlp/pretrained/resource_downloader.py +216 -0
  249. sparknlp/pretrained/utils.py +35 -0
  250. sparknlp/reader/__init__.py +15 -0
  251. sparknlp/reader/enums.py +19 -0
  252. sparknlp/reader/pdf_to_text.py +190 -0
  253. sparknlp/reader/reader2doc.py +124 -0
  254. sparknlp/reader/reader2image.py +136 -0
  255. sparknlp/reader/reader2table.py +44 -0
  256. sparknlp/reader/reader_assembler.py +159 -0
  257. sparknlp/reader/sparknlp_reader.py +461 -0
  258. sparknlp/training/__init__.py +20 -0
  259. sparknlp/training/_tf_graph_builders/__init__.py +0 -0
  260. sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
  261. sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
  262. sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
  263. sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
  264. sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
  265. sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
  266. sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
  267. sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
  268. sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
  269. sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
  270. sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
  271. sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
  272. sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
  273. sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
  274. sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
  275. sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
  276. sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
  277. sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
  278. sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
  279. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
  280. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
  281. sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
  282. sparknlp/training/conll.py +150 -0
  283. sparknlp/training/conllu.py +103 -0
  284. sparknlp/training/pos.py +103 -0
  285. sparknlp/training/pub_tator.py +76 -0
  286. sparknlp/training/spacy_to_annotation.py +57 -0
  287. sparknlp/training/tfgraphs.py +5 -0
  288. sparknlp/upload_to_hub.py +149 -0
  289. sparknlp/util.py +51 -5
  290. com/__init__.pyc +0 -0
  291. com/__pycache__/__init__.cpython-36.pyc +0 -0
  292. com/johnsnowlabs/__init__.pyc +0 -0
  293. com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
  294. com/johnsnowlabs/nlp/__init__.pyc +0 -0
  295. com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
  296. spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
  297. spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
  298. sparknlp/__init__.pyc +0 -0
  299. sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
  300. sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
  301. sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
  302. sparknlp/__pycache__/base.cpython-36.pyc +0 -0
  303. sparknlp/__pycache__/common.cpython-36.pyc +0 -0
  304. sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
  305. sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
  306. sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
  307. sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
  308. sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
  309. sparknlp/__pycache__/training.cpython-36.pyc +0 -0
  310. sparknlp/__pycache__/util.cpython-36.pyc +0 -0
  311. sparknlp/annotation.pyc +0 -0
  312. sparknlp/annotator.py +0 -3006
  313. sparknlp/annotator.pyc +0 -0
  314. sparknlp/base.py +0 -347
  315. sparknlp/base.pyc +0 -0
  316. sparknlp/common.py +0 -193
  317. sparknlp/common.pyc +0 -0
  318. sparknlp/embeddings.py +0 -40
  319. sparknlp/embeddings.pyc +0 -0
  320. sparknlp/internal.py +0 -288
  321. sparknlp/internal.pyc +0 -0
  322. sparknlp/pretrained.py +0 -123
  323. sparknlp/pretrained.pyc +0 -0
  324. sparknlp/storage.py +0 -32
  325. sparknlp/storage.pyc +0 -0
  326. sparknlp/training.py +0 -62
  327. sparknlp/training.pyc +0 -0
  328. sparknlp/util.pyc +0 -0
  329. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
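The hunks below have lost their per-file headers in this extract. Judging by their added-line counts, they appear to correspond to the ner_dl graph-builder modules listed above: ner_model.py (+532, the _tf_graph_builders_1x variant), ner_model_saver.py (+62), and sentence_grouper.py (+28).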
@@ -0,0 +1,532 @@
+ import math
+ import random
+ import sys
+
+ import numpy as np
+ import tensorflow as tf
+
+ from .sentence_grouper import SentenceGrouper
+
+
+ class NerModel:
+     # If session is not defined then the default session will be used.
+     def __init__(self, session=None, dummy_tags=None, use_contrib=True, use_gpu_device=0):
+         tf.compat.v1.disable_v2_behavior()
+
+         self.word_repr = None
+         self.word_embeddings = None
+         self.session = session
+         self.session_created = False
+         self.dummy_tags = dummy_tags or []
+         self.use_contrib = use_contrib
+         self.use_gpu_device = use_gpu_device
+
+         if self.session is None:
+             self.session_created = True
+             self.session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
+                 allow_soft_placement=True,
+                 log_device_placement=False))
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("char_repr"):
+                 # shape = (batch size, sentence, word)
+                 self.char_ids = tf.compat.v1.placeholder(tf.int32, shape=[None, None, None], name="char_ids")
+
+                 # shape = (batch_size, sentence)
+                 self.word_lengths = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="word_lengths")
+
+             with tf.compat.v1.variable_scope("word_repr"):
+                 # shape = (batch size)
+                 self.sentence_lengths = tf.compat.v1.placeholder(tf.int32, shape=[None], name="sentence_lengths")
+
+             with tf.compat.v1.variable_scope("training", reuse=None):
+                 # shape = (batch, sentence)
+                 self.labels = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="labels")
+
+                 self.lr = tf.compat.v1.placeholder_with_default(0.005, shape=(), name="lr")
+                 self.dropout = tf.compat.v1.placeholder(tf.float32, shape=(), name="dropout")
+
+         self._char_bilstm_added = False
+         self._char_cnn_added = False
+         self._word_embeddings_added = False
+         self._context_added = False
+         self._inference_added = False
+         self._training_added = False
+
+     def add_bilstm_char_repr(self, nchars=101, dim=25, hidden=25):
+         self._char_bilstm_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("char_repr_lstm"):
+                 # 1. Lookup for character embeddings
+                 char_range = math.sqrt(3 / dim)
+                 embeddings = tf.compat.v1.get_variable(name="char_embeddings",
+                                                        dtype=tf.float32,
+                                                        shape=[nchars, dim],
+                                                        initializer=tf.compat.v1.random_uniform_initializer(
+                                                            -char_range,
+                                                            char_range
+                                                        ),
+                                                        use_resource=False)
+
+                 # shape = (batch, sentence, word, char embeddings dim)
+                 char_embeddings = tf.nn.embedding_lookup(params=embeddings, ids=self.char_ids)
+                 s = tf.shape(input=char_embeddings)
+
+                 # shape = (batch x sentence, word, char embeddings dim)
+                 char_embeddings_seq = tf.reshape(char_embeddings, shape=[-1, s[-2], dim])
+
+                 # shape = (batch x sentence)
+                 word_lengths_seq = tf.reshape(self.word_lengths, shape=[-1])
+
+                 # 2. Add Bidirectional LSTM over the characters of each word
+                 model = tf.keras.Sequential([
+                     tf.keras.layers.Bidirectional(
+                         layer=tf.keras.layers.LSTM(hidden, return_sequences=False),
+                         merge_mode="concat"
+                     )
+                 ])
+
+                 inputs = char_embeddings_seq
+                 mask = tf.expand_dims(tf.sequence_mask(word_lengths_seq, dtype=tf.float32), axis=-1)
+
+                 # shape = (batch x sentence, 2 x hidden)
+                 output = model(inputs, mask=mask)
+
+                 # shape = (batch, sentence, 2 x hidden)
+                 char_repr = tf.reshape(output, shape=[-1, s[1], 2 * hidden])
+
+                 if self.word_repr is not None:
+                     self.word_repr = tf.concat([self.word_repr, char_repr], axis=-1)
+                 else:
+                     self.word_repr = char_repr
+
+     def add_cnn_char_repr(self, nchars=101, dim=25, nfilters=25, pad=2):
+         self._char_cnn_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("char_repr_cnn"):
+                 # 1. Lookup for character embeddings
+                 char_range = math.sqrt(3 / dim)
+                 embeddings = tf.compat.v1.get_variable(name="char_embeddings", dtype=tf.float32,
+                                                        shape=[nchars, dim],
+                                                        initializer=tf.compat.v1.random_uniform_initializer(-char_range,
+                                                                                                            char_range),
+                                                        use_resource=False)
+
+                 # shape = (batch, sentence, word_len, embeddings dim)
+                 char_embeddings = tf.nn.embedding_lookup(params=embeddings, ids=self.char_ids)
+                 s = tf.shape(input=char_embeddings)
+
+                 # shape = (batch x sentence, word_len, embeddings dim)
+                 char_embeddings = tf.reshape(char_embeddings, shape=[-1, s[-2], dim])
+
+                 # shape = (batch x sentence, word_len, nfilters)
+                 conv1d = tf.keras.layers.Conv1D(
+                     filters=nfilters,
+                     kernel_size=[3],
+                     padding='same',
+                     activation=tf.nn.relu
+                 )(char_embeddings)
+
+                 # Max across each filter, shape = (batch x sentence, nfilters)
+                 char_repr = tf.reduce_max(input_tensor=conv1d, axis=1, keepdims=True)
+                 char_repr = tf.squeeze(char_repr, axis=[1])
+
+                 # shape = (batch, sentence, nfilters)
+                 char_repr = tf.reshape(char_repr, shape=[s[0], s[1], nfilters])
+
+                 if self.word_repr is not None:
+                     self.word_repr = tf.concat([self.word_repr, char_repr], axis=-1)
+                 else:
+                     self.word_repr = char_repr
+
+     def add_pretrained_word_embeddings(self, dim=100):
+         self._word_embeddings_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("word_repr"):
+                 # shape = (batch size, sentence, dim)
+                 self.word_embeddings = tf.compat.v1.placeholder(tf.float32, shape=[None, None, dim],
+                                                                 name="word_embeddings")
+
+             if self.word_repr is not None:
+                 self.word_repr = tf.concat([self.word_repr, self.word_embeddings], axis=-1)
+             else:
+                 self.word_repr = self.word_embeddings
+
+     def _create_lstm_layer(self, inputs, hidden_size, lengths):
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             if not self.use_contrib:
+                 # return_sequences=True is required here: the context layer must
+                 # emit one output per token rather than only the final state.
+                 model = tf.keras.Sequential([
+                     tf.keras.layers.Bidirectional(
+                         layer=tf.keras.layers.LSTM(hidden_size, return_sequences=True),
+                         merge_mode="concat"
+                     )
+                 ])
+
+                 mask = tf.expand_dims(tf.sequence_mask(lengths, dtype=tf.float32), axis=-1)
+                 # inputs shape = (batch, sentence, inp)
+                 # output shape = (batch, sentence, 2 x hidden)
+                 output = model(inputs, mask=mask)
+                 batch = tf.shape(input=lengths)[0]
+
+                 return tf.reshape(output, shape=[batch, -1, 2 * hidden_size])
+
+             # tf.contrib path: fused LSTM cells operate on time-major tensors
+             time_based = tf.transpose(a=inputs, perm=[1, 0, 2])
+
+             cell_fw = tf.contrib.rnn.LSTMBlockFusedCell(hidden_size, use_peephole=True)
+             cell_bw = tf.contrib.rnn.LSTMBlockFusedCell(hidden_size, use_peephole=True)
+             cell_bw = tf.contrib.rnn.TimeReversedFusedRNN(cell_bw)
+
+             output_fw, _ = cell_fw(time_based, dtype=tf.float32, sequence_length=lengths)
+             output_bw, _ = cell_bw(time_based, dtype=tf.float32, sequence_length=lengths)
+
+             result = tf.concat([output_fw, output_bw], axis=-1)
+             return tf.transpose(a=result, perm=[1, 0, 2])
+
+     def _multiply_layer(self, source, result_size, activation=tf.nn.relu):
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             ntime_steps = tf.shape(input=source)[1]
+             source_size = source.shape[2]
+
+             W = tf.compat.v1.get_variable("W", shape=[source_size, result_size],
+                                           dtype=tf.float32,
+                                           initializer=tf.compat.v1.keras.initializers.VarianceScaling(
+                                               scale=1.0, mode="fan_avg", distribution="uniform"),
+                                           use_resource=False)
+             b = tf.compat.v1.get_variable("b", shape=[result_size], dtype=tf.float32, use_resource=False)
+
+             # shape = (batch x time, source_size)
+             source = tf.reshape(source, [-1, source_size])
+             # shape = (batch x time, result_size)
+             result = tf.matmul(source, W) + b
+
+             result = tf.reshape(result, [-1, ntime_steps, result_size])
+             if activation:
+                 result = activation(result)
+
+             return result
+
+     # Adds a Bi-LSTM context encoder with hidden_size units per direction.
+     def add_context_repr(self, ntags, hidden_size=100, height=1, residual=True):
+         assert self._word_embeddings_added or self._char_cnn_added or self._char_bilstm_added, \
+             "Add word embeddings by method add_pretrained_word_embeddings " \
+             "or add a char representation by method add_cnn_char_repr " \
+             "or add_bilstm_char_repr before adding the context layer"
+
+         self._context_added = True
+         self.ntags = ntags
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             context_repr = self._multiply_layer(self.word_repr, 2 * hidden_size)
+             # self.dropout holds a keep probability; tf.nn.dropout expects a drop rate.
+             context_repr = tf.nn.dropout(x=context_repr, rate=1 - self.dropout)
+
+             with tf.compat.v1.variable_scope("context_repr"):
+                 for i in range(height):
+                     with tf.compat.v1.variable_scope('lstm-{}'.format(i)):
+                         new_repr = self._create_lstm_layer(context_repr, hidden_size,
+                                                            lengths=self.sentence_lengths)
+                     context_repr = new_repr + context_repr if residual else new_repr
+
+                 context_repr = tf.nn.dropout(x=context_repr, rate=1 - self.dropout)
+
+                 # shape = (batch, sentence, ntags)
+                 self.scores = self._multiply_layer(context_repr, ntags, activation=None)
+                 tf.identity(self.scores, "scores")
+
+                 self.predicted_labels = tf.argmax(input=self.scores, axis=-1)
+                 tf.identity(self.predicted_labels, "predicted_labels")
+
+     def add_inference_layer(self, crf=False, predictions_op_name=None):
+         assert self._context_added, \
+             "Add a context representation layer by method add_context_repr before adding the inference layer"
+         self._inference_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("inference", reuse=None):
+                 self.crf = tf.constant(crf, dtype=tf.bool, name="crf")
+
+                 if crf:
+                     transition_params = tf.compat.v1.get_variable(
+                         "transition_params",
+                         shape=[self.ntags, self.ntags],
+                         initializer=tf.compat.v1.keras.initializers.VarianceScaling(
+                             scale=1.0, mode="fan_avg", distribution="uniform"),
+                         use_resource=False)
+
+                     # CRF log-likelihood, shape = (batch,)
+                     log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
+                         self.scores,
+                         self.labels,
+                         self.sentence_lengths,
+                         transition_params
+                     )
+
+                     tf.identity(log_likelihood, "log_likelihood")
+                     tf.identity(self.transition_params, "transition_params")
+
+                     self.loss = tf.reduce_mean(input_tensor=-log_likelihood)
+                     if predictions_op_name:
+                         with tf.compat.v1.variable_scope("inference_tmp", reuse=None):
+                             tmp_prediction, _ = tf.contrib.crf.crf_decode(self.scores, self.transition_params,
+                                                                           self.sentence_lengths)
+                         self.prediction = tf.identity(tmp_prediction, name=predictions_op_name)
+                     else:
+                         self.prediction, _ = tf.contrib.crf.crf_decode(self.scores, self.transition_params,
+                                                                        self.sentence_lengths)
+
+                 else:
+                     # Softmax cross-entropy per token, shape = (batch, sentence)
+                     losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.scores, labels=self.labels)
+                     # shape = (batch, sentence)
+                     mask = tf.sequence_mask(self.sentence_lengths)
+                     # keep only losses of real (non-padding) tokens
+                     losses = tf.boolean_mask(tensor=losses, mask=mask)
+
+                     self.loss = tf.reduce_mean(input_tensor=losses)
+
+                     self.prediction = tf.math.argmax(input=self.scores, axis=-1, name=predictions_op_name)
+
+                 tf.identity(self.loss, "loss")
+
+     # clip_gradient <= 0 disables gradient clipping
+     def add_training_op(self, clip_gradient=2.0, train_op_name=None):
+         assert self._inference_added, \
+             "Add an inference layer by method add_inference_layer before adding the training op"
+         self._training_added = True
+
+         with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+             with tf.compat.v1.variable_scope("training", reuse=None):
+                 if train_op_name:
+                     optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr, name=train_op_name)
+                 else:
+                     optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr)
+                 if clip_gradient > 0:
+                     gvs = optimizer.compute_gradients(self.loss)
+                     capped_gvs = [(tf.clip_by_value(grad, -clip_gradient, clip_gradient), var)
+                                   for grad, var in gvs if grad is not None]
+                     self.train_op = optimizer.apply_gradients(capped_gvs)
+                 else:
+                     self.train_op = optimizer.minimize(self.loss)
+
+             self.init_op = tf.compat.v1.variables_initializer(tf.compat.v1.global_variables(), name="init")
+
+     @staticmethod
+     def num_trues(array):
+         return sum(1 for item in array if item)
+
+     @staticmethod
+     def fill(array, l, val):
+         # Pad a copy of `array` with `val` up to length `l`.
+         result = array[:]
+         for i in range(l - len(array)):
+             result.append(val)
+         return result
+
+     @staticmethod
+     def get_sentence_lengths(batch, idx="word_embeddings"):
+         return [len(row[idx]) for row in batch]
+
+     @staticmethod
+     def get_sentence_token_lengths(batch, idx="tag_ids"):
+         return [len(row[idx]) for row in batch]
+
+     @staticmethod
+     def get_word_lengths(batch, idx="char_ids"):
+         max_words = max([len(row[idx]) for row in batch])
+         return [NerModel.fill([len(chars) for chars in row[idx]], max_words, 0)
+                 for row in batch]
+
+     @staticmethod
+     def get_char_ids(batch, idx="char_ids"):
+         max_chars = max([max([len(char_ids) for char_ids in sentence[idx]]) for sentence in batch])
+         max_words = max([len(sentence[idx]) for sentence in batch])
+
+         return [
+             NerModel.fill(
+                 [NerModel.fill(char_ids, max_chars, 0) for char_ids in sentence[idx]],
+                 max_words, [0] * max_chars
+             )
+             for sentence in batch]
+
+     @staticmethod
+     def get_from_batch(batch, idx):
+         k = max([len(row[idx]) for row in batch])
+         return list([NerModel.fill(row[idx], k, 0) for row in batch])
+
+     @staticmethod
+     def get_tag_ids(batch, idx="tag_ids"):
+         return NerModel.get_from_batch(batch, idx)
+
+     @staticmethod
+     def get_word_embeddings(batch, idx="word_embeddings"):
+         embeddings_dim = len(batch[0][idx][0])
+         max_words = max([len(sentence[idx]) for sentence in batch])
+         return [
+             NerModel.fill([word_embedding for word_embedding in sentence[idx]],
+                           max_words, [0] * embeddings_dim)
+             for sentence in batch]
+
+     @staticmethod
+     def slice(dataset, batch_size=10):
+         grouper = SentenceGrouper([5, 10, 20, 50])
+         return grouper.slice(dataset, batch_size)
+
+     def init_variables(self):
+         self.session.run(self.init_op)
+
+     def train(self, train,
+               epoch_start=0,
+               epoch_end=100,
+               batch_size=32,
+               lr=0.01,
+               po=0,
+               dropout=0.65,
+               init_variables=False
+               ):
+         assert self._training_added, "Add a training op by method add_training_op before running training"
+
+         if init_variables:
+             with tf.compat.v1.device('/gpu:{}'.format(self.use_gpu_device)):
+                 self.session.run(tf.compat.v1.global_variables_initializer())
+
+         print('training started')
+         for epoch in range(epoch_start, epoch_end):
+             random.shuffle(train)
+             sum_loss = 0
+             for batch in NerModel.slice(train, batch_size):
+                 feed_dict = {
+                     self.sentence_lengths: NerModel.get_sentence_lengths(batch),
+                     self.word_embeddings: NerModel.get_word_embeddings(batch),
+
+                     self.word_lengths: NerModel.get_word_lengths(batch),
+                     self.char_ids: NerModel.get_char_ids(batch),
+                     self.labels: NerModel.get_tag_ids(batch),
+
+                     self.dropout: dropout,
+                     # simple learning-rate decay over epochs
+                     self.lr: lr / (1 + po * epoch)
+                 }
+                 mean_loss, _ = self.session.run([self.loss, self.train_op], feed_dict=feed_dict)
+                 sum_loss += mean_loss
+
+             print("epoch {}".format(epoch))
+             print("mean loss: {}".format(sum_loss))
+             print()
+             sys.stdout.flush()
+
+     def measure(self, dataset, batch_size=20, dropout=1.0):
+         predicted = {}
+         correct = {}
+         correct_predicted = {}
+
+         for batch in NerModel.slice(dataset, batch_size):
+             tags_ids = NerModel.get_tag_ids(batch)
+             sentence_lengths = NerModel.get_sentence_lengths(batch)
+
+             feed_dict = {
+                 self.sentence_lengths: sentence_lengths,
+                 self.word_embeddings: NerModel.get_word_embeddings(batch),
+
+                 self.word_lengths: NerModel.get_word_lengths(batch),
+                 self.char_ids: NerModel.get_char_ids(batch),
+                 self.labels: tags_ids,
+
+                 self.dropout: dropout
+             }
+
+             prediction = self.session.run(self.prediction, feed_dict=feed_dict)
+             batch_prediction = np.reshape(prediction, (len(batch), -1))
+
+             for i in range(len(batch)):
+                 is_word_start = batch[i]['is_word_start']
+
+                 for word in range(sentence_lengths[i]):
+                     if not is_word_start[word]:
+                         continue
+
+                     p = batch_prediction[i][word]
+                     c = tags_ids[i][word]
+
+                     if c in self.dummy_tags:
+                         continue
+
+                     predicted[p] = predicted.get(p, 0) + 1
+                     correct[c] = correct.get(c, 0) + 1
+                     if p == c:
+                         correct_predicted[p] = correct_predicted.get(p, 0) + 1
+
+         num_correct_predicted = sum([correct_predicted.get(i, 0) for i in range(1, self.ntags)])
+         num_predicted = sum([predicted.get(i, 0) for i in range(1, self.ntags)])
+         num_correct = sum([correct.get(i, 0) for i in range(1, self.ntags)])
+
+         prec = num_correct_predicted / (num_predicted or 1.)
+         rec = num_correct_predicted / (num_correct or 1.)
+
+         f1 = 2 * prec * rec / ((rec + prec) or 1.)
+
+         return prec, rec, f1
+
+     @staticmethod
+     def get_softmax(scores, threshold=None):
+         exp_scores = np.exp(scores)
+
+         for sentence in exp_scores:
+             for i in range(len(sentence)):
+                 probabilities = sentence[i] / np.sum(sentence[i])
+                 # zero out probabilities below the threshold, if one is given
+                 sentence[i] = [p if threshold is None or p >= threshold else 0 for p in probabilities]
+
+         return exp_scores
+
+     def predict(self, sentences, batch_size=20, threshold=None):
+         result = []
+
+         for batch in NerModel.slice(sentences, batch_size):
+             sentence_lengths = NerModel.get_sentence_lengths(batch)
+
+             feed_dict = {
+                 self.sentence_lengths: sentence_lengths,
+                 self.word_embeddings: NerModel.get_word_embeddings(batch),
+
+                 self.word_lengths: NerModel.get_word_lengths(batch),
+                 self.char_ids: NerModel.get_char_ids(batch),
+
+                 # keep probability of 1.0 disables dropout at inference time
+                 self.dropout: 1.0
+             }
+
+             prediction = self.session.run(self.prediction, feed_dict=feed_dict)
+             batch_prediction = np.reshape(prediction, (len(batch), -1))
+
+             for i in range(len(batch)):
+                 sentence = []
+                 for word in range(sentence_lengths[i]):
+                     tag = batch_prediction[i][word]
+                     sentence.append(tag)
+
+                 result.append(sentence)
+
+         return result
+
+     def close(self):
+         if self.session_created:
+             self.session.close()
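For orientation, the class above is assembled by chaining its builder methods before training. The following is a minimal usage sketch, not part of the diff: `train_data` and `dev_data` are assumed to be lists of dicts carrying the keys the feed helpers expect ('words', 'word_embeddings', 'char_ids', 'tag_ids', 'is_word_start'), and the embedding dimension and tag count are placeholders.

    # Illustrative wiring of NerModel (assumed data layout, placeholder sizes).
    ner = NerModel(use_contrib=False)   # Keras LSTM path, no tf.contrib
    ner.add_cnn_char_repr(nchars=101, dim=25, nfilters=25)
    ner.add_pretrained_word_embeddings(dim=100)
    ner.add_context_repr(ntags=10, hidden_size=100, height=1, residual=True)
    ner.add_inference_layer(crf=False)
    ner.add_training_op(clip_gradient=2.0)
    ner.init_variables()

    ner.train(train_data, epoch_end=5, batch_size=32, lr=0.01, dropout=0.65)
    prec, rec, f1 = ner.measure(dev_data)
    ner.close()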
@@ -0,0 +1,62 @@
+ import os
+
+ import tensorflow as tf
+
+
+ class NerModelSaver:
+     def __init__(self, ner, encoder, embeddings_file=None):
+         self.ner = ner
+         self.encoder = encoder
+         self.embeddings_file = embeddings_file
+
+     @staticmethod
+     def restore_tensorflow_state(session, export_dir):
+         with tf.device('/gpu:0'):
+             save_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node
+                           if n.name.startswith('save/')]
+             if len(save_nodes) == 0:
+                 # Creating the Saver adds the save/restore ops to the graph.
+                 tf.train.Saver()
+
+             variables_file = os.path.join(export_dir, 'variables')
+             session.run("save/restore_all", feed_dict={'save/Const:0': variables_file})
+
+     def save_models(self, folder):
+         with tf.device('/gpu:0'):
+             save_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node
+                           if n.name.startswith('save/')]
+             if len(save_nodes) == 0:
+                 # Creating the Saver adds the save/restore ops to the graph.
+                 tf.train.Saver()
+
+             variables_file = os.path.join(folder, 'variables')
+             self.ner.session.run('save/control_dependency', feed_dict={'save/Const:0': variables_file})
+             tf.train.write_graph(self.ner.session.graph, folder, 'saved_model.pb', False)
+
+     def save(self, export_dir):
+         def save_tags(file):
+             id2tag = {tag_id: tag for (tag, tag_id) in self.encoder.tag2id.items()}
+
+             with open(file, 'w') as f:
+                 for i in range(len(id2tag)):
+                     tag = id2tag[i]
+                     f.write(tag)
+                     f.write('\n')
+
+         def save_embeddings(src, dst):
+             from shutil import copyfile
+             copyfile(src, dst)
+             with open(dst + '.meta', 'w') as f:
+                 embeddings = self.encoder.embeddings
+                 dim = len(embeddings[0]) if embeddings else 0
+                 f.write(str(dim))
+
+         def save_chars(file):
+             id2char = {char_id: char for (char, char_id) in self.encoder.char2id.items()}
+             with open(file, 'w') as f:
+                 # char ids start at 1; 0 is reserved for padding
+                 for i in range(1, len(id2char) + 1):
+                     f.write(id2char[i])
+
+         self.save_models(export_dir)
+         save_tags(os.path.join(export_dir, 'tags.csv'))
+
+         if self.embeddings_file:
+             save_embeddings(self.embeddings_file, os.path.join(export_dir, 'embeddings'))
+
+         save_chars(os.path.join(export_dir, 'chars.csv'))
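A hedged sketch of driving the saver, assuming an `encoder` object that exposes tag2id, char2id, and embeddings as the class expects; the paths and file names below are illustrative, not from the diff.

    # Hypothetical export call.
    saver = NerModelSaver(ner, encoder, embeddings_file='embeddings.bin')
    saver.save('/tmp/ner_export')
    # Writes saved_model.pb and variables via save_models(), plus tags.csv,
    # chars.csv, and (when embeddings_file is set) embeddings + embeddings.meta.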
@@ -0,0 +1,28 @@
+ class SentenceGrouper:
+     def __init__(self, bucket_lengths):
+         self.bucket_lengths = bucket_lengths
+
+     def get_bucket_id(self, length):
+         for i, bucket_len in enumerate(self.bucket_lengths):
+             if length <= bucket_len:
+                 return i
+
+         # sentences longer than every bucket go into the overflow bucket
+         return len(self.bucket_lengths)
+
+     def slice(self, dataset, batch_size=32):
+         # one bucket per configured length, plus an overflow bucket
+         buckets = [[] for _ in self.bucket_lengths]
+         buckets.append([])
+
+         for entry in dataset:
+             length = len(entry['words'])
+             bucket_id = self.get_bucket_id(length)
+             buckets[bucket_id].append(entry)
+
+             if len(buckets[bucket_id]) >= batch_size:
+                 result = buckets[bucket_id][:]
+                 yield result
+                 buckets[bucket_id] = []
+
+         # flush partially filled buckets
+         for bucket in buckets:
+             if len(bucket) > 0:
+                 yield bucket
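SentenceGrouper exists to keep padding small: sentences of similar length are batched together, with one overflow bucket for anything longer than the largest bucket length. A small demonstration with made-up entries keyed by 'words', as slice() expects:

    # Bucketed batching demo (illustrative data, not from the diff).
    grouper = SentenceGrouper([5, 10, 20, 50])
    data = [{'words': ['w'] * n} for n in (3, 4, 7, 8, 60)]
    for batch in grouper.slice(data, batch_size=2):
        print([len(e['words']) for e in batch])
    # Prints [3, 4] (the <=5 bucket), [7, 8] (the <=10 bucket),
    # then [60] flushed from the overflow bucket at the end.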