spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (329) hide show
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. com/johnsnowlabs/nlp/__init__.py +4 -2
  4. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  5. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  6. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  7. sparknlp/__init__.py +281 -27
  8. sparknlp/annotation.py +137 -6
  9. sparknlp/annotation_audio.py +61 -0
  10. sparknlp/annotation_image.py +82 -0
  11. sparknlp/annotator/__init__.py +93 -0
  12. sparknlp/annotator/audio/__init__.py +16 -0
  13. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  14. sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
  15. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  16. sparknlp/annotator/chunk2_doc.py +85 -0
  17. sparknlp/annotator/chunker.py +137 -0
  18. sparknlp/annotator/classifier_dl/__init__.py +61 -0
  19. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  20. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
  21. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
  22. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
  23. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  24. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  25. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  26. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
  27. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
  28. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
  29. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  30. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  31. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
  32. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
  33. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  34. sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
  35. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
  36. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
  37. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
  38. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  39. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
  40. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
  41. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
  42. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  43. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  44. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
  45. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
  46. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
  47. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  48. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  49. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  50. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
  51. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  52. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
  53. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
  54. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
  55. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  56. sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
  57. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
  60. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
  61. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
  62. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  63. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
  64. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
  65. sparknlp/annotator/cleaners/__init__.py +15 -0
  66. sparknlp/annotator/cleaners/cleaner.py +202 -0
  67. sparknlp/annotator/cleaners/extractor.py +191 -0
  68. sparknlp/annotator/coref/__init__.py +1 -0
  69. sparknlp/annotator/coref/spanbert_coref.py +221 -0
  70. sparknlp/annotator/cv/__init__.py +29 -0
  71. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  72. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  73. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  74. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  75. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  76. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  77. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  78. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  79. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  80. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  81. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  82. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  83. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  84. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  85. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  86. sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
  87. sparknlp/annotator/dataframe_optimizer.py +216 -0
  88. sparknlp/annotator/date2_chunk.py +88 -0
  89. sparknlp/annotator/dependency/__init__.py +17 -0
  90. sparknlp/annotator/dependency/dependency_parser.py +294 -0
  91. sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
  92. sparknlp/annotator/document_character_text_splitter.py +228 -0
  93. sparknlp/annotator/document_normalizer.py +235 -0
  94. sparknlp/annotator/document_token_splitter.py +175 -0
  95. sparknlp/annotator/document_token_splitter_test.py +85 -0
  96. sparknlp/annotator/embeddings/__init__.py +45 -0
  97. sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
  98. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  99. sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
  100. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
  101. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  102. sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
  103. sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
  104. sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
  105. sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
  106. sparknlp/annotator/embeddings/doc2vec.py +352 -0
  107. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  108. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  109. sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
  110. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  111. sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
  112. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  113. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  114. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  115. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  116. sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
  117. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
  118. sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
  119. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  120. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  121. sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
  122. sparknlp/annotator/embeddings/word2vec.py +353 -0
  123. sparknlp/annotator/embeddings/word_embeddings.py +385 -0
  124. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
  125. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
  126. sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
  127. sparknlp/annotator/er/__init__.py +16 -0
  128. sparknlp/annotator/er/entity_ruler.py +267 -0
  129. sparknlp/annotator/graph_extraction.py +368 -0
  130. sparknlp/annotator/keyword_extraction/__init__.py +16 -0
  131. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
  132. sparknlp/annotator/ld_dl/__init__.py +16 -0
  133. sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
  134. sparknlp/annotator/lemmatizer.py +250 -0
  135. sparknlp/annotator/matcher/__init__.py +20 -0
  136. sparknlp/annotator/matcher/big_text_matcher.py +272 -0
  137. sparknlp/annotator/matcher/date_matcher.py +303 -0
  138. sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
  139. sparknlp/annotator/matcher/regex_matcher.py +221 -0
  140. sparknlp/annotator/matcher/text_matcher.py +290 -0
  141. sparknlp/annotator/n_gram_generator.py +141 -0
  142. sparknlp/annotator/ner/__init__.py +21 -0
  143. sparknlp/annotator/ner/ner_approach.py +94 -0
  144. sparknlp/annotator/ner/ner_converter.py +148 -0
  145. sparknlp/annotator/ner/ner_crf.py +397 -0
  146. sparknlp/annotator/ner/ner_dl.py +591 -0
  147. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  148. sparknlp/annotator/ner/ner_overwriter.py +166 -0
  149. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  150. sparknlp/annotator/normalizer.py +230 -0
  151. sparknlp/annotator/openai/__init__.py +16 -0
  152. sparknlp/annotator/openai/openai_completion.py +349 -0
  153. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  154. sparknlp/annotator/param/__init__.py +17 -0
  155. sparknlp/annotator/param/classifier_encoder.py +98 -0
  156. sparknlp/annotator/param/evaluation_dl_params.py +130 -0
  157. sparknlp/annotator/pos/__init__.py +16 -0
  158. sparknlp/annotator/pos/perceptron.py +263 -0
  159. sparknlp/annotator/sentence/__init__.py +17 -0
  160. sparknlp/annotator/sentence/sentence_detector.py +290 -0
  161. sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
  162. sparknlp/annotator/sentiment/__init__.py +17 -0
  163. sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
  164. sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
  165. sparknlp/annotator/seq2seq/__init__.py +35 -0
  166. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  167. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  168. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  169. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  170. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  171. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  172. sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
  173. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  174. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  175. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  176. sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
  177. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  178. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  179. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  180. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  181. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  182. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  183. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  184. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  185. sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
  186. sparknlp/annotator/similarity/__init__.py +0 -0
  187. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  188. sparknlp/annotator/spell_check/__init__.py +18 -0
  189. sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
  190. sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
  191. sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
  192. sparknlp/annotator/stemmer.py +79 -0
  193. sparknlp/annotator/stop_words_cleaner.py +190 -0
  194. sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
  195. sparknlp/annotator/token/__init__.py +19 -0
  196. sparknlp/annotator/token/chunk_tokenizer.py +118 -0
  197. sparknlp/annotator/token/recursive_tokenizer.py +205 -0
  198. sparknlp/annotator/token/regex_tokenizer.py +208 -0
  199. sparknlp/annotator/token/tokenizer.py +561 -0
  200. sparknlp/annotator/token2_chunk.py +76 -0
  201. sparknlp/annotator/ws/__init__.py +16 -0
  202. sparknlp/annotator/ws/word_segmenter.py +429 -0
  203. sparknlp/base/__init__.py +30 -0
  204. sparknlp/base/audio_assembler.py +95 -0
  205. sparknlp/base/doc2_chunk.py +169 -0
  206. sparknlp/base/document_assembler.py +164 -0
  207. sparknlp/base/embeddings_finisher.py +201 -0
  208. sparknlp/base/finisher.py +217 -0
  209. sparknlp/base/gguf_ranking_finisher.py +234 -0
  210. sparknlp/base/graph_finisher.py +125 -0
  211. sparknlp/base/has_recursive_fit.py +24 -0
  212. sparknlp/base/has_recursive_transform.py +22 -0
  213. sparknlp/base/image_assembler.py +172 -0
  214. sparknlp/base/light_pipeline.py +429 -0
  215. sparknlp/base/multi_document_assembler.py +164 -0
  216. sparknlp/base/prompt_assembler.py +207 -0
  217. sparknlp/base/recursive_pipeline.py +107 -0
  218. sparknlp/base/table_assembler.py +145 -0
  219. sparknlp/base/token_assembler.py +124 -0
  220. sparknlp/common/__init__.py +26 -0
  221. sparknlp/common/annotator_approach.py +41 -0
  222. sparknlp/common/annotator_model.py +47 -0
  223. sparknlp/common/annotator_properties.py +114 -0
  224. sparknlp/common/annotator_type.py +38 -0
  225. sparknlp/common/completion_post_processing.py +37 -0
  226. sparknlp/common/coverage_result.py +22 -0
  227. sparknlp/common/match_strategy.py +33 -0
  228. sparknlp/common/properties.py +1298 -0
  229. sparknlp/common/read_as.py +33 -0
  230. sparknlp/common/recursive_annotator_approach.py +35 -0
  231. sparknlp/common/storage.py +149 -0
  232. sparknlp/common/utils.py +39 -0
  233. sparknlp/functions.py +315 -5
  234. sparknlp/internal/__init__.py +1199 -0
  235. sparknlp/internal/annotator_java_ml.py +32 -0
  236. sparknlp/internal/annotator_transformer.py +37 -0
  237. sparknlp/internal/extended_java_wrapper.py +63 -0
  238. sparknlp/internal/params_getters_setters.py +71 -0
  239. sparknlp/internal/recursive.py +70 -0
  240. sparknlp/logging/__init__.py +15 -0
  241. sparknlp/logging/comet.py +467 -0
  242. sparknlp/partition/__init__.py +16 -0
  243. sparknlp/partition/partition.py +244 -0
  244. sparknlp/partition/partition_properties.py +902 -0
  245. sparknlp/partition/partition_transformer.py +200 -0
  246. sparknlp/pretrained/__init__.py +17 -0
  247. sparknlp/pretrained/pretrained_pipeline.py +158 -0
  248. sparknlp/pretrained/resource_downloader.py +216 -0
  249. sparknlp/pretrained/utils.py +35 -0
  250. sparknlp/reader/__init__.py +15 -0
  251. sparknlp/reader/enums.py +19 -0
  252. sparknlp/reader/pdf_to_text.py +190 -0
  253. sparknlp/reader/reader2doc.py +124 -0
  254. sparknlp/reader/reader2image.py +136 -0
  255. sparknlp/reader/reader2table.py +44 -0
  256. sparknlp/reader/reader_assembler.py +159 -0
  257. sparknlp/reader/sparknlp_reader.py +461 -0
  258. sparknlp/training/__init__.py +20 -0
  259. sparknlp/training/_tf_graph_builders/__init__.py +0 -0
  260. sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
  261. sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
  262. sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
  263. sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
  264. sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
  265. sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
  266. sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
  267. sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
  268. sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
  269. sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
  270. sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
  271. sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
  272. sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
  273. sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
  274. sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
  275. sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
  276. sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
  277. sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
  278. sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
  279. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
  280. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
  281. sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
  282. sparknlp/training/conll.py +150 -0
  283. sparknlp/training/conllu.py +103 -0
  284. sparknlp/training/pos.py +103 -0
  285. sparknlp/training/pub_tator.py +76 -0
  286. sparknlp/training/spacy_to_annotation.py +57 -0
  287. sparknlp/training/tfgraphs.py +5 -0
  288. sparknlp/upload_to_hub.py +149 -0
  289. sparknlp/util.py +51 -5
  290. com/__init__.pyc +0 -0
  291. com/__pycache__/__init__.cpython-36.pyc +0 -0
  292. com/johnsnowlabs/__init__.pyc +0 -0
  293. com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
  294. com/johnsnowlabs/nlp/__init__.pyc +0 -0
  295. com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
  296. spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
  297. spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
  298. sparknlp/__init__.pyc +0 -0
  299. sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
  300. sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
  301. sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
  302. sparknlp/__pycache__/base.cpython-36.pyc +0 -0
  303. sparknlp/__pycache__/common.cpython-36.pyc +0 -0
  304. sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
  305. sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
  306. sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
  307. sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
  308. sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
  309. sparknlp/__pycache__/training.cpython-36.pyc +0 -0
  310. sparknlp/__pycache__/util.cpython-36.pyc +0 -0
  311. sparknlp/annotation.pyc +0 -0
  312. sparknlp/annotator.py +0 -3006
  313. sparknlp/annotator.pyc +0 -0
  314. sparknlp/base.py +0 -347
  315. sparknlp/base.pyc +0 -0
  316. sparknlp/common.py +0 -193
  317. sparknlp/common.pyc +0 -0
  318. sparknlp/embeddings.py +0 -40
  319. sparknlp/embeddings.pyc +0 -0
  320. sparknlp/internal.py +0 -288
  321. sparknlp/internal.pyc +0 -0
  322. sparknlp/pretrained.py +0 -123
  323. sparknlp/pretrained.pyc +0 -0
  324. sparknlp/storage.py +0 -32
  325. sparknlp/storage.pyc +0 -0
  326. sparknlp/training.py +0 -62
  327. sparknlp/training.pyc +0 -0
  328. sparknlp/util.pyc +0 -0
  329. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,665 @@
1
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """LSTM Block Cell ops."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import abc
21
+
22
+ import six
23
+ from tensorflow.python.framework import dtypes
24
+ from tensorflow.python.framework import ops
25
+ from tensorflow.python.keras.engine import input_spec
26
+ from tensorflow.python.layers import base as base_layer
27
+ from tensorflow.python.ops import array_ops
28
+ from tensorflow.python.ops import gen_rnn_ops
29
+ from tensorflow.python.ops import init_ops
30
+ from tensorflow.python.ops import math_ops
31
+ from tensorflow.python.ops import nn_ops
32
+ from tensorflow.python.ops import rnn_cell_impl
33
+
34
+ LayerRNNCell = rnn_cell_impl.LayerRNNCell # pylint: disable=invalid-name
35
+
36
+
37
+ # pylint: disable=invalid-name
38
+ def _lstm_block_cell(x,
39
+ cs_prev,
40
+ h_prev,
41
+ w,
42
+ b,
43
+ wci=None,
44
+ wcf=None,
45
+ wco=None,
46
+ forget_bias=None,
47
+ cell_clip=None,
48
+ use_peephole=None,
49
+ name=None):
50
+ r"""Computes the LSTM cell forward propagation for 1 time step.
51
+
52
+ This implementation uses 1 weight matrix and 1 bias vector, and there's an
53
+ optional peephole connection.
54
+
55
+ This kernel op implements the following mathematical equations:
56
+
57
+ ```python
58
+ xh = [x, h_prev]
59
+ [i, ci, f, o] = xh * w + b
60
+ f = f + forget_bias
61
+
62
+ if not use_peephole:
63
+ wci = wcf = wco = 0
64
+
65
+ i = sigmoid(cs_prev * wci + i)
66
+ f = sigmoid(cs_prev * wcf + f)
67
+ ci = tanh(ci)
68
+
69
+ cs = ci .* i + cs_prev .* f
70
+ cs = clip(cs, cell_clip)
71
+
72
+ o = sigmoid(cs * wco + o)
73
+ co = tanh(cs)
74
+ h = co .* o
75
+ ```
76
+
77
+ Args:
78
+ x: A `Tensor`. Must be one of the following types: `float32`.
79
+ The input to the LSTM cell, shape (batch_size, num_inputs).
80
+ cs_prev: A `Tensor`. Must have the same type as `x`.
81
+ Value of the cell state at previous time step.
82
+ h_prev: A `Tensor`. Must have the same type as `x`.
83
+ Output of the previous cell at previous time step.
84
+ w: A `Tensor`. Must have the same type as `x`. The weight matrix.
85
+ b: A `Tensor`. Must have the same type as `x`. The bias vector.
86
+ wci: A `Tensor`. Must have the same type as `x`.
87
+ The weight matrix for input gate peephole connection.
88
+ wcf: A `Tensor`. Must have the same type as `x`.
89
+ The weight matrix for forget gate peephole connection.
90
+ wco: A `Tensor`. Must have the same type as `x`.
91
+ The weight matrix for output gate peephole connection.
92
+ forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
93
+ cell_clip: An optional `float`. Defaults to `-1` (no clipping).
94
+ Value to clip the 'cs' value to. Disable by setting to negative value.
95
+ use_peephole: An optional `bool`. Defaults to `False`.
96
+ Whether to use peephole weights.
97
+ name: A name for the operation (optional).
98
+
99
+ Returns:
100
+ A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
101
+ i: A `Tensor`. Has the same type as `x`. The input gate.
102
+ cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
103
+ f: A `Tensor`. Has the same type as `x`. The forget gate.
104
+ o: A `Tensor`. Has the same type as `x`. The output gate.
105
+ ci: A `Tensor`. Has the same type as `x`. The cell input.
106
+ co: A `Tensor`. Has the same type as `x`. The cell after the tanh.
107
+ h: A `Tensor`. Has the same type as `x`. The output h vector.
108
+
109
+ Raises:
110
+ ValueError: If cell_size is None.
111
+ """
112
+ if wci is None:
113
+ cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
114
+ if cell_size is None:
115
+ raise ValueError("cell_size from `cs_prev` should not be None.")
116
+ wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
117
+ wcf = wci
118
+ wco = wci
119
+
120
+ # pylint: disable=protected-access
121
+ return gen_rnn_ops.lstm_block_cell(
122
+ x=x,
123
+ cs_prev=cs_prev,
124
+ h_prev=h_prev,
125
+ w=w,
126
+ wci=wci,
127
+ wcf=wcf,
128
+ wco=wco,
129
+ b=b,
130
+ forget_bias=forget_bias,
131
+ cell_clip=cell_clip if cell_clip is not None else -1,
132
+ use_peephole=use_peephole,
133
+ name=name)
134
+ # pylint: enable=protected-access
135
+
136
+
137
+ def _block_lstm(seq_len_max,
138
+ x,
139
+ w,
140
+ b,
141
+ cs_prev=None,
142
+ h_prev=None,
143
+ wci=None,
144
+ wcf=None,
145
+ wco=None,
146
+ forget_bias=None,
147
+ cell_clip=None,
148
+ use_peephole=None,
149
+ name=None):
150
+ r"""TODO(williamchan): add doc.
151
+
152
+ Args:
153
+ seq_len_max: A `Tensor` of type `int64`.
154
+ x: A list of at least 1 `Tensor` objects of the same type.
155
+ w: A `Tensor`. Must have the same type as `x`.
156
+ b: A `Tensor`. Must have the same type as `x`.
157
+ cs_prev: A `Tensor`. Must have the same type as `x`.
158
+ h_prev: A `Tensor`. Must have the same type as `x`.
159
+ wci: A `Tensor`. Must have the same type as `x`.
160
+ wcf: A `Tensor`. Must have the same type as `x`.
161
+ wco: A `Tensor`. Must have the same type as `x`.
162
+ forget_bias: An optional `float`. Defaults to `1`.
163
+ cell_clip: An optional `float`. Defaults to `-1` (no clipping).
164
+ use_peephole: An optional `bool`. Defaults to `False`.
165
+ name: A name for the operation (optional).
166
+
167
+ Returns:
168
+ A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
169
+ i: A list with the same number of `Tensor` objects as `x` of `Tensor`
170
+ objects of the same type as x.
171
+ cs: A list with the same number of `Tensor` objects as `x` of `Tensor`
172
+ objects of the same type as x.
173
+ f: A list with the same number of `Tensor` objects as `x` of `Tensor`
174
+ objects of the same type as x.
175
+ o: A list with the same number of `Tensor` objects as `x` of `Tensor`
176
+ objects of the same type as x.
177
+ ci: A list with the same number of `Tensor` objects as `x` of `Tensor`
178
+ objects of the same type as x.
179
+ co: A list with the same number of `Tensor` objects as `x` of `Tensor`
180
+ objects of the same type as x.
181
+ h: A list with the same number of `Tensor` objects as `x` of `Tensor`
182
+ objects of the same type as x.
183
+
184
+ Raises:
185
+ ValueError: If `b` does not have a valid shape.
186
+ """
187
+ dtype = x[0].dtype
188
+ batch_size = x[0].get_shape().with_rank(2).dims[0].value
189
+ cell_size4 = b.get_shape().with_rank(1).dims[0].value
190
+ if cell_size4 is None:
191
+ raise ValueError("`b` shape must not be None.")
192
+ cell_size = cell_size4 / 4
193
+ zero_state = None
194
+ if cs_prev is None or h_prev is None:
195
+ zero_state = array_ops.constant(
196
+ 0, dtype=dtype, shape=[batch_size, cell_size])
197
+ if cs_prev is None:
198
+ cs_prev = zero_state
199
+ if h_prev is None:
200
+ h_prev = zero_state
201
+ if wci is None:
202
+ wci = array_ops.constant(0, dtype=dtype, shape=[cell_size])
203
+ wcf = wci
204
+ wco = wci
205
+
206
+ # pylint: disable=protected-access
207
+ i, cs, f, o, ci, co, h = gen_rnn_ops.block_lstm(
208
+ seq_len_max=seq_len_max,
209
+ x=array_ops.stack(x),
210
+ cs_prev=cs_prev,
211
+ h_prev=h_prev,
212
+ w=w,
213
+ wci=wci,
214
+ wcf=wcf,
215
+ wco=wco,
216
+ b=b,
217
+ forget_bias=forget_bias,
218
+ cell_clip=cell_clip if cell_clip is not None else -1,
219
+ name=name,
220
+ use_peephole=use_peephole)
221
+
222
+ return array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(
223
+ f), array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(
224
+ co), array_ops.unstack(h)
225
+ # pylint: enable=protected-access
226
+ # pylint: enable=invalid-name
227
+
228
+
229
+ @ops.RegisterGradient("LSTMBlockCell")
230
+ def _LSTMBlockCellGrad(op, *grad):
231
+ """Gradient for LSTMBlockCell."""
232
+ (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
233
+ (i, cs, f, o, ci, co, _) = op.outputs
234
+ (_, cs_grad, _, _, _, _, h_grad) = grad
235
+
236
+ batch_size = x.get_shape().with_rank(2).dims[0].value
237
+ if batch_size is None:
238
+ batch_size = -1
239
+ input_size = x.get_shape().with_rank(2).dims[1].value
240
+ if input_size is None:
241
+ raise ValueError("input_size from `x` should not be None.")
242
+ cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
243
+ if cell_size is None:
244
+ raise ValueError("cell_size from `cs_prev` should not be None.")
245
+
246
+ (cs_prev_grad, dgates, wci_grad, wcf_grad,
247
+ wco_grad) = gen_rnn_ops.lstm_block_cell_grad(
248
+ x=x,
249
+ cs_prev=cs_prev,
250
+ h_prev=h_prev,
251
+ w=w,
252
+ wci=wci,
253
+ wcf=wcf,
254
+ wco=wco,
255
+ b=b,
256
+ i=i,
257
+ cs=cs,
258
+ f=f,
259
+ o=o,
260
+ ci=ci,
261
+ co=co,
262
+ cs_grad=cs_grad,
263
+ h_grad=h_grad,
264
+ use_peephole=op.get_attr("use_peephole"))
265
+
266
+ # Backprop from dgates to xh.
267
+ xh_grad = math_ops.matmul(dgates, w, transpose_b=True)
268
+
269
+ x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
270
+ x_grad.get_shape().merge_with(x.get_shape())
271
+
272
+ h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
273
+ (batch_size, cell_size))
274
+ h_prev_grad.get_shape().merge_with(h_prev.get_shape())
275
+
276
+ # Backprop from dgates to w.
277
+ xh = array_ops.concat([x, h_prev], 1)
278
+ w_grad = math_ops.matmul(xh, dgates, transpose_a=True)
279
+ w_grad.get_shape().merge_with(w.get_shape())
280
+
281
+ # Backprop from dgates to b.
282
+ b_grad = nn_ops.bias_add_grad(dgates)
283
+ b_grad.get_shape().merge_with(b.get_shape())
284
+
285
+ return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
286
+ wco_grad, b_grad)
287
+
288
+
289
class LSTMBlockCell(LayerRNNCell):
  """Basic LSTM recurrent network cell backed by a single fused op.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add `forget_bias` (default: 1) to the biases of the forget gate in
  order to reduce the scale of forgetting in the beginning of the training.

  Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be
  much faster. The weight and bias matrices should be compatible as long as
  the variable scope matches.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               cell_clip=None,
               use_peephole=False,
               dtype=None,
               reuse=None,
               name="lstm_cell"):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      cell_clip: An optional `float`. `None` (the default) means no
        clipping.
      use_peephole: Whether to use peephole connections or not.
      dtype: the variable dtype of this layer. Default to tf.float32.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases. By default this is "lstm_cell", for variable-name
        compatibility with `tf.compat.v1.nn.rnn_cell.LSTMCell`.

    When restoring from CudnnLSTM-trained checkpoints, must use
    CudnnCompatibleLSTMBlockCell instead.
    """
    super(LSTMBlockCell, self).__init__(_reuse=reuse, dtype=dtype, name=name)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._use_peephole = use_peephole
    # The fused kernel interprets -1 as "no clipping".
    self._cell_clip = -1 if cell_clip is None else cell_clip
    # Variable names kept identical to rnn_cell_impl.LSTMCell so checkpoints
    # are interchangeable.
    self._names = {
        "W": "kernel",
        "b": "bias",
        "wci": "w_i_diag",
        "wcf": "w_f_diag",
        "wco": "w_o_diag",
        "scope": "lstm_cell"
    }
    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

  @property
  def state_size(self):
    # State is an (c, h) pair, each of size num_units.
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    # Require a statically known input depth; the kernel shape depends on it.
    input_size = inputs_shape.dims[1].value
    if not input_size:
      raise ValueError(
          "Expecting inputs_shape[1] to be set: %s" % str(inputs_shape))
    # One fused kernel/bias holds all four gates (i, c, f, o).
    self._kernel = self.add_variable(
        self._names["W"], [input_size + self._num_units, self._num_units * 4])
    self._bias = self.add_variable(
        self._names["b"], [self._num_units * 4],
        initializer=init_ops.constant_initializer(0.0))
    if self._use_peephole:
      # Diagonal peephole weights, one vector per gated connection.
      for attr, key in (("_w_i_diag", "wci"), ("_w_f_diag", "wcf"),
                        ("_w_o_diag", "wco")):
        setattr(self, attr,
                self.add_variable(self._names[key], [self._num_units]))

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM)."""
    if len(state) != 2:
      raise ValueError("Expecting state to be a tuple with length 2.")

    if self._use_peephole:
      peep_i = self._w_i_diag
      peep_f = self._w_f_diag
      peep_o = self._w_o_diag
    else:
      # The fused op always takes peephole tensors; feed zeros when unused.
      zeros = array_ops.zeros([self._num_units], dtype=self.dtype)
      peep_i = peep_f = peep_o = zeros

    cs_prev, h_prev = state
    # _lstm_block_cell returns a 7-tuple; only the new cell state (index 1)
    # and the output (index 6) are needed here.
    gates = _lstm_block_cell(
        inputs,
        cs_prev,
        h_prev,
        self._kernel,
        self._bias,
        wci=peep_i,
        wcf=peep_f,
        wco=peep_o,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    cs = gates[1]
    h = gates[6]

    return h, rnn_cell_impl.LSTMStateTuple(cs, h)
398
+
399
+
400
@six.add_metaclass(abc.ABCMeta)
class LSTMBlockWrapper(base_layer.Layer):
  """This is a helper class that provides housekeeping for LSTM cells.

  This may be useful for alternative LSTM and similar type of cells.
  The subclasses must implement `_call_cell` method and `num_units` property.
  """

  @abc.abstractproperty
  def num_units(self):
    """Number of units in this cell (output dimension)."""

  @abc.abstractmethod
  def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
                 sequence_length):
    """Run this LSTM on inputs, starting from the given state.

    This method must be implemented by subclasses and does the actual work
    of calling the cell.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An int32
        or int64 vector (tensor) size [batch_size], values in [0, time_len) or
        None.

    Returns:
      A pair containing:

      - State: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
    """
    pass

  def call(self, inputs, initial_state=None, dtype=None, sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
      initial_state: a tuple `(initial_cell_state, initial_output)` with tensors
        of shape `[batch_size, self._num_units]`. If this is not provided, the
        cell is expected to create a zero initial state of type `dtype`.
      dtype: The data type for the initial state and expected output. Required
        if `initial_state` is not provided or RNN state has a heterogeneous
        dtype.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
        time_len).`
        Defaults to `time_len` for each element.

    Returns:
      A pair containing:

      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
        or a list of time_len tensors of shape `[batch_size, output_size]`,
        to match the type of the `inputs`.
      - Final state: a tuple `(cell_state, output)` matching `initial_state`.

    Raises:
      ValueError: in case of shape mismatches
    """
    # A list of per-step tensors is stacked into a single time-major tensor;
    # the output is unstacked back to a list at the end to mirror the input.
    is_list = isinstance(inputs, list)
    if is_list:
      inputs = array_ops.stack(inputs)
    inputs_shape = inputs.get_shape().with_rank(3)
    if not inputs_shape[2]:
      raise ValueError("Expecting inputs_shape[2] to be set: %s" % inputs_shape)
    # batch_size / time_len fall back to dynamic shapes when not known
    # statically.
    batch_size = inputs_shape.dims[1].value
    if batch_size is None:
      batch_size = array_ops.shape(inputs)[1]
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    # Provide default values for initial_state and dtype
    if initial_state is None:
      if dtype is None:
        raise ValueError("Either initial_state or dtype needs to be specified")
      z = array_ops.zeros(
          array_ops.stack([batch_size, self.num_units]), dtype=dtype)
      initial_state = z, z
    else:
      if len(initial_state) != 2:
        raise ValueError(
            "Expecting initial_state to be a tuple with length 2 or None")
      if dtype is None:
        dtype = initial_state[0].dtype

    # create the actual cell
    if sequence_length is not None:
      sequence_length = ops.convert_to_tensor(sequence_length)
    initial_cell_state, initial_output = initial_state  # pylint: disable=unpacking-non-sequence
    cell_states, outputs = self._call_cell(
        inputs, initial_cell_state, initial_output, dtype, sequence_length)

    if sequence_length is not None:
      # Mask out the part beyond sequence_length: zeroes outputs for steps
      # past each sequence's end so padding does not leak into results.
      mask = array_ops.transpose(
          array_ops.sequence_mask(sequence_length, time_len, dtype=dtype),
          [1, 0])
      mask = array_ops.tile(
          array_ops.expand_dims(mask, [-1]), [1, 1, self.num_units])
      outputs *= mask
      # Prepend initial states to cell_states and outputs for indexing to work
      # correctly, since we want to access the last valid state at
      # sequence_length - 1, which can even be -1, corresponding to the
      # initial state.
      mod_cell_states = array_ops.concat(
          [array_ops.expand_dims(initial_cell_state, [0]), cell_states], 0)
      mod_outputs = array_ops.concat(
          [array_ops.expand_dims(initial_output, [0]), outputs], 0)
      # With the prepended row, index `sequence_length` (not `- 1`) selects
      # the last valid step per batch element.
      final_cell_state = self._gather_states(mod_cell_states, sequence_length,
                                             batch_size)
      final_output = self._gather_states(mod_outputs, sequence_length,
                                         batch_size)
    else:
      # No sequence_lengths used: final state is the last state
      final_cell_state = cell_states[-1]
      final_output = outputs[-1]

    if is_list:
      # Input was a list, so return a list
      outputs = array_ops.unstack(outputs)

    final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
    return outputs, final_state

  def _gather_states(self, data, indices, batch_size):
    """Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
    # Pairs each batch column with its per-example time index for gather_nd.
    return array_ops.gather_nd(
        data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))
536
+
537
+
538
class LSTMBlockFusedCell(LSTMBlockWrapper):
  """FusedRNNCell implementation of LSTM.

  This is an extremely efficient LSTM implementation, that uses a single TF op
  for the entire LSTM. It should be both faster and more memory-efficient than
  LSTMBlockCell defined above.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               cell_clip=None,
               use_peephole=False,
               reuse=None,
               dtype=None,
               name="lstm_fused_cell"):
    """Initialize the LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      cell_clip: clip the cell to this value. Defaults is no cell clipping.
      use_peephole: Whether to use peephole connections or not.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope. If not `True`, and the existing scope already has the
        given variables, an error is raised.
      dtype: the dtype of variables of this layer.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases. By default this is "lstm_fused_cell", for variable-name
        compatibility with `tf.compat.v1.nn.rnn_cell.LSTMCell`.
    """
    super(LSTMBlockFusedCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype)
    self._num_units = num_units
    self._forget_bias = forget_bias
    # The fused op interprets -1 as "no clipping".
    self._cell_clip = cell_clip if cell_clip is not None else -1
    self._use_peephole = use_peephole

    # Inputs must be 3-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=3)

  @property
  def num_units(self):
    """Number of units in this cell (output dimension)."""
    return self._num_units

  def build(self, input_shape):
    # NOTE(review): unlike LSTMBlockCell.build, this does not validate that
    # the input depth is statically known — confirm callers guarantee it.
    input_size = input_shape.dims[2].value
    # One fused kernel/bias holds all four gates (i, c, f, o).
    self._kernel = self.add_variable(
        "kernel", [input_size + self._num_units, self._num_units * 4])
    self._bias = self.add_variable(
        "bias", [self._num_units * 4],
        initializer=init_ops.constant_initializer(0.0))
    if self._use_peephole:
      # Diagonal peephole weights, one vector per gated connection.
      self._w_i_diag = self.add_variable("w_i_diag", [self._num_units])
      self._w_f_diag = self.add_variable("w_f_diag", [self._num_units])
      self._w_o_diag = self.add_variable("w_o_diag", [self._num_units])

    self.built = True

  def _call_cell(self,
                 inputs,
                 initial_cell_state=None,
                 initial_output=None,
                 dtype=None,
                 sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
        time_len)` or None.

    Returns:
      A pair containing:

      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
        output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
        output_size]`
    """

    inputs_shape = inputs.get_shape().with_rank(3)
    # Use the dynamic time dimension when it is not statically known.
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    if self._use_peephole:
      wci = self._w_i_diag
      wco = self._w_o_diag
      wcf = self._w_f_diag
    else:
      # The fused op always takes peephole tensors; feed zeros when unused.
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype)

    # seq_len_max lets the op skip computation beyond the longest sequence.
    if sequence_length is None:
      max_seq_len = math_ops.cast(time_len, dtypes.int64)
    else:
      max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length),
                                  dtypes.int64)

    # block_lstm returns seven outputs; only the cell-state (cs) and
    # output (h) sequences are needed here.
    _, cs, _, _, _, _, h = gen_rnn_ops.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=self._kernel,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h