spark-nlp 2.6.3rc1__py2.py3-none-any.whl → 6.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (329) hide show
  1. com/johnsnowlabs/ml/__init__.py +0 -0
  2. com/johnsnowlabs/ml/ai/__init__.py +10 -0
  3. com/johnsnowlabs/nlp/__init__.py +4 -2
  4. spark_nlp-6.2.1.dist-info/METADATA +362 -0
  5. spark_nlp-6.2.1.dist-info/RECORD +292 -0
  6. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/WHEEL +1 -1
  7. sparknlp/__init__.py +281 -27
  8. sparknlp/annotation.py +137 -6
  9. sparknlp/annotation_audio.py +61 -0
  10. sparknlp/annotation_image.py +82 -0
  11. sparknlp/annotator/__init__.py +93 -0
  12. sparknlp/annotator/audio/__init__.py +16 -0
  13. sparknlp/annotator/audio/hubert_for_ctc.py +188 -0
  14. sparknlp/annotator/audio/wav2vec2_for_ctc.py +161 -0
  15. sparknlp/annotator/audio/whisper_for_ctc.py +251 -0
  16. sparknlp/annotator/chunk2_doc.py +85 -0
  17. sparknlp/annotator/chunker.py +137 -0
  18. sparknlp/annotator/classifier_dl/__init__.py +61 -0
  19. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  20. sparknlp/annotator/classifier_dl/albert_for_question_answering.py +172 -0
  21. sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py +201 -0
  22. sparknlp/annotator/classifier_dl/albert_for_token_classification.py +179 -0
  23. sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py +211 -0
  24. sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py +225 -0
  25. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +161 -0
  26. sparknlp/annotator/classifier_dl/bert_for_question_answering.py +168 -0
  27. sparknlp/annotator/classifier_dl/bert_for_sequence_classification.py +202 -0
  28. sparknlp/annotator/classifier_dl/bert_for_token_classification.py +177 -0
  29. sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py +212 -0
  30. sparknlp/annotator/classifier_dl/camembert_for_question_answering.py +168 -0
  31. sparknlp/annotator/classifier_dl/camembert_for_sequence_classification.py +205 -0
  32. sparknlp/annotator/classifier_dl/camembert_for_token_classification.py +173 -0
  33. sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py +202 -0
  34. sparknlp/annotator/classifier_dl/classifier_dl.py +320 -0
  35. sparknlp/annotator/classifier_dl/deberta_for_question_answering.py +168 -0
  36. sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +198 -0
  37. sparknlp/annotator/classifier_dl/deberta_for_token_classification.py +175 -0
  38. sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +193 -0
  39. sparknlp/annotator/classifier_dl/distil_bert_for_question_answering.py +168 -0
  40. sparknlp/annotator/classifier_dl/distil_bert_for_sequence_classification.py +201 -0
  41. sparknlp/annotator/classifier_dl/distil_bert_for_token_classification.py +175 -0
  42. sparknlp/annotator/classifier_dl/distil_bert_for_zero_shot_classification.py +211 -0
  43. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  44. sparknlp/annotator/classifier_dl/longformer_for_question_answering.py +168 -0
  45. sparknlp/annotator/classifier_dl/longformer_for_sequence_classification.py +201 -0
  46. sparknlp/annotator/classifier_dl/longformer_for_token_classification.py +176 -0
  47. sparknlp/annotator/classifier_dl/mpnet_for_question_answering.py +148 -0
  48. sparknlp/annotator/classifier_dl/mpnet_for_sequence_classification.py +188 -0
  49. sparknlp/annotator/classifier_dl/mpnet_for_token_classification.py +173 -0
  50. sparknlp/annotator/classifier_dl/multi_classifier_dl.py +395 -0
  51. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  52. sparknlp/annotator/classifier_dl/roberta_for_question_answering.py +168 -0
  53. sparknlp/annotator/classifier_dl/roberta_for_sequence_classification.py +201 -0
  54. sparknlp/annotator/classifier_dl/roberta_for_token_classification.py +189 -0
  55. sparknlp/annotator/classifier_dl/roberta_for_zero_shot_classification.py +225 -0
  56. sparknlp/annotator/classifier_dl/sentiment_dl.py +378 -0
  57. sparknlp/annotator/classifier_dl/tapas_for_question_answering.py +170 -0
  58. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  59. sparknlp/annotator/classifier_dl/xlm_roberta_for_question_answering.py +168 -0
  60. sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +201 -0
  61. sparknlp/annotator/classifier_dl/xlm_roberta_for_token_classification.py +173 -0
  62. sparknlp/annotator/classifier_dl/xlm_roberta_for_zero_shot_classification.py +225 -0
  63. sparknlp/annotator/classifier_dl/xlnet_for_sequence_classification.py +201 -0
  64. sparknlp/annotator/classifier_dl/xlnet_for_token_classification.py +176 -0
  65. sparknlp/annotator/cleaners/__init__.py +15 -0
  66. sparknlp/annotator/cleaners/cleaner.py +202 -0
  67. sparknlp/annotator/cleaners/extractor.py +191 -0
  68. sparknlp/annotator/coref/__init__.py +1 -0
  69. sparknlp/annotator/coref/spanbert_coref.py +221 -0
  70. sparknlp/annotator/cv/__init__.py +29 -0
  71. sparknlp/annotator/cv/blip_for_question_answering.py +172 -0
  72. sparknlp/annotator/cv/clip_for_zero_shot_classification.py +193 -0
  73. sparknlp/annotator/cv/convnext_for_image_classification.py +269 -0
  74. sparknlp/annotator/cv/florence2_transformer.py +180 -0
  75. sparknlp/annotator/cv/gemma3_for_multimodal.py +346 -0
  76. sparknlp/annotator/cv/internvl_for_multimodal.py +280 -0
  77. sparknlp/annotator/cv/janus_for_multimodal.py +351 -0
  78. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  79. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  80. sparknlp/annotator/cv/paligemma_for_multimodal.py +308 -0
  81. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  82. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  83. sparknlp/annotator/cv/smolvlm_transformer.py +426 -0
  84. sparknlp/annotator/cv/swin_for_image_classification.py +242 -0
  85. sparknlp/annotator/cv/vision_encoder_decoder_for_image_captioning.py +240 -0
  86. sparknlp/annotator/cv/vit_for_image_classification.py +217 -0
  87. sparknlp/annotator/dataframe_optimizer.py +216 -0
  88. sparknlp/annotator/date2_chunk.py +88 -0
  89. sparknlp/annotator/dependency/__init__.py +17 -0
  90. sparknlp/annotator/dependency/dependency_parser.py +294 -0
  91. sparknlp/annotator/dependency/typed_dependency_parser.py +318 -0
  92. sparknlp/annotator/document_character_text_splitter.py +228 -0
  93. sparknlp/annotator/document_normalizer.py +235 -0
  94. sparknlp/annotator/document_token_splitter.py +175 -0
  95. sparknlp/annotator/document_token_splitter_test.py +85 -0
  96. sparknlp/annotator/embeddings/__init__.py +45 -0
  97. sparknlp/annotator/embeddings/albert_embeddings.py +230 -0
  98. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +539 -0
  99. sparknlp/annotator/embeddings/bert_embeddings.py +208 -0
  100. sparknlp/annotator/embeddings/bert_sentence_embeddings.py +224 -0
  101. sparknlp/annotator/embeddings/bge_embeddings.py +199 -0
  102. sparknlp/annotator/embeddings/camembert_embeddings.py +210 -0
  103. sparknlp/annotator/embeddings/chunk_embeddings.py +149 -0
  104. sparknlp/annotator/embeddings/deberta_embeddings.py +208 -0
  105. sparknlp/annotator/embeddings/distil_bert_embeddings.py +221 -0
  106. sparknlp/annotator/embeddings/doc2vec.py +352 -0
  107. sparknlp/annotator/embeddings/e5_embeddings.py +195 -0
  108. sparknlp/annotator/embeddings/e5v_embeddings.py +138 -0
  109. sparknlp/annotator/embeddings/elmo_embeddings.py +251 -0
  110. sparknlp/annotator/embeddings/instructor_embeddings.py +204 -0
  111. sparknlp/annotator/embeddings/longformer_embeddings.py +211 -0
  112. sparknlp/annotator/embeddings/minilm_embeddings.py +189 -0
  113. sparknlp/annotator/embeddings/mpnet_embeddings.py +192 -0
  114. sparknlp/annotator/embeddings/mxbai_embeddings.py +184 -0
  115. sparknlp/annotator/embeddings/nomic_embeddings.py +181 -0
  116. sparknlp/annotator/embeddings/roberta_embeddings.py +225 -0
  117. sparknlp/annotator/embeddings/roberta_sentence_embeddings.py +191 -0
  118. sparknlp/annotator/embeddings/sentence_embeddings.py +134 -0
  119. sparknlp/annotator/embeddings/snowflake_embeddings.py +202 -0
  120. sparknlp/annotator/embeddings/uae_embeddings.py +211 -0
  121. sparknlp/annotator/embeddings/universal_sentence_encoder.py +211 -0
  122. sparknlp/annotator/embeddings/word2vec.py +353 -0
  123. sparknlp/annotator/embeddings/word_embeddings.py +385 -0
  124. sparknlp/annotator/embeddings/xlm_roberta_embeddings.py +225 -0
  125. sparknlp/annotator/embeddings/xlm_roberta_sentence_embeddings.py +194 -0
  126. sparknlp/annotator/embeddings/xlnet_embeddings.py +227 -0
  127. sparknlp/annotator/er/__init__.py +16 -0
  128. sparknlp/annotator/er/entity_ruler.py +267 -0
  129. sparknlp/annotator/graph_extraction.py +368 -0
  130. sparknlp/annotator/keyword_extraction/__init__.py +16 -0
  131. sparknlp/annotator/keyword_extraction/yake_keyword_extraction.py +270 -0
  132. sparknlp/annotator/ld_dl/__init__.py +16 -0
  133. sparknlp/annotator/ld_dl/language_detector_dl.py +199 -0
  134. sparknlp/annotator/lemmatizer.py +250 -0
  135. sparknlp/annotator/matcher/__init__.py +20 -0
  136. sparknlp/annotator/matcher/big_text_matcher.py +272 -0
  137. sparknlp/annotator/matcher/date_matcher.py +303 -0
  138. sparknlp/annotator/matcher/multi_date_matcher.py +109 -0
  139. sparknlp/annotator/matcher/regex_matcher.py +221 -0
  140. sparknlp/annotator/matcher/text_matcher.py +290 -0
  141. sparknlp/annotator/n_gram_generator.py +141 -0
  142. sparknlp/annotator/ner/__init__.py +21 -0
  143. sparknlp/annotator/ner/ner_approach.py +94 -0
  144. sparknlp/annotator/ner/ner_converter.py +148 -0
  145. sparknlp/annotator/ner/ner_crf.py +397 -0
  146. sparknlp/annotator/ner/ner_dl.py +591 -0
  147. sparknlp/annotator/ner/ner_dl_graph_checker.py +293 -0
  148. sparknlp/annotator/ner/ner_overwriter.py +166 -0
  149. sparknlp/annotator/ner/zero_shot_ner_model.py +173 -0
  150. sparknlp/annotator/normalizer.py +230 -0
  151. sparknlp/annotator/openai/__init__.py +16 -0
  152. sparknlp/annotator/openai/openai_completion.py +349 -0
  153. sparknlp/annotator/openai/openai_embeddings.py +106 -0
  154. sparknlp/annotator/param/__init__.py +17 -0
  155. sparknlp/annotator/param/classifier_encoder.py +98 -0
  156. sparknlp/annotator/param/evaluation_dl_params.py +130 -0
  157. sparknlp/annotator/pos/__init__.py +16 -0
  158. sparknlp/annotator/pos/perceptron.py +263 -0
  159. sparknlp/annotator/sentence/__init__.py +17 -0
  160. sparknlp/annotator/sentence/sentence_detector.py +290 -0
  161. sparknlp/annotator/sentence/sentence_detector_dl.py +467 -0
  162. sparknlp/annotator/sentiment/__init__.py +17 -0
  163. sparknlp/annotator/sentiment/sentiment_detector.py +208 -0
  164. sparknlp/annotator/sentiment/vivekn_sentiment.py +242 -0
  165. sparknlp/annotator/seq2seq/__init__.py +35 -0
  166. sparknlp/annotator/seq2seq/auto_gguf_model.py +304 -0
  167. sparknlp/annotator/seq2seq/auto_gguf_reranker.py +334 -0
  168. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +336 -0
  169. sparknlp/annotator/seq2seq/bart_transformer.py +420 -0
  170. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  171. sparknlp/annotator/seq2seq/cpm_transformer.py +321 -0
  172. sparknlp/annotator/seq2seq/gpt2_transformer.py +363 -0
  173. sparknlp/annotator/seq2seq/llama2_transformer.py +343 -0
  174. sparknlp/annotator/seq2seq/llama3_transformer.py +381 -0
  175. sparknlp/annotator/seq2seq/m2m100_transformer.py +392 -0
  176. sparknlp/annotator/seq2seq/marian_transformer.py +374 -0
  177. sparknlp/annotator/seq2seq/mistral_transformer.py +348 -0
  178. sparknlp/annotator/seq2seq/nllb_transformer.py +420 -0
  179. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  180. sparknlp/annotator/seq2seq/phi2_transformer.py +326 -0
  181. sparknlp/annotator/seq2seq/phi3_transformer.py +330 -0
  182. sparknlp/annotator/seq2seq/phi4_transformer.py +387 -0
  183. sparknlp/annotator/seq2seq/qwen_transformer.py +340 -0
  184. sparknlp/annotator/seq2seq/starcoder_transformer.py +335 -0
  185. sparknlp/annotator/seq2seq/t5_transformer.py +425 -0
  186. sparknlp/annotator/similarity/__init__.py +0 -0
  187. sparknlp/annotator/similarity/document_similarity_ranker.py +379 -0
  188. sparknlp/annotator/spell_check/__init__.py +18 -0
  189. sparknlp/annotator/spell_check/context_spell_checker.py +911 -0
  190. sparknlp/annotator/spell_check/norvig_sweeting.py +358 -0
  191. sparknlp/annotator/spell_check/symmetric_delete.py +299 -0
  192. sparknlp/annotator/stemmer.py +79 -0
  193. sparknlp/annotator/stop_words_cleaner.py +190 -0
  194. sparknlp/annotator/tf_ner_dl_graph_builder.py +179 -0
  195. sparknlp/annotator/token/__init__.py +19 -0
  196. sparknlp/annotator/token/chunk_tokenizer.py +118 -0
  197. sparknlp/annotator/token/recursive_tokenizer.py +205 -0
  198. sparknlp/annotator/token/regex_tokenizer.py +208 -0
  199. sparknlp/annotator/token/tokenizer.py +561 -0
  200. sparknlp/annotator/token2_chunk.py +76 -0
  201. sparknlp/annotator/ws/__init__.py +16 -0
  202. sparknlp/annotator/ws/word_segmenter.py +429 -0
  203. sparknlp/base/__init__.py +30 -0
  204. sparknlp/base/audio_assembler.py +95 -0
  205. sparknlp/base/doc2_chunk.py +169 -0
  206. sparknlp/base/document_assembler.py +164 -0
  207. sparknlp/base/embeddings_finisher.py +201 -0
  208. sparknlp/base/finisher.py +217 -0
  209. sparknlp/base/gguf_ranking_finisher.py +234 -0
  210. sparknlp/base/graph_finisher.py +125 -0
  211. sparknlp/base/has_recursive_fit.py +24 -0
  212. sparknlp/base/has_recursive_transform.py +22 -0
  213. sparknlp/base/image_assembler.py +172 -0
  214. sparknlp/base/light_pipeline.py +429 -0
  215. sparknlp/base/multi_document_assembler.py +164 -0
  216. sparknlp/base/prompt_assembler.py +207 -0
  217. sparknlp/base/recursive_pipeline.py +107 -0
  218. sparknlp/base/table_assembler.py +145 -0
  219. sparknlp/base/token_assembler.py +124 -0
  220. sparknlp/common/__init__.py +26 -0
  221. sparknlp/common/annotator_approach.py +41 -0
  222. sparknlp/common/annotator_model.py +47 -0
  223. sparknlp/common/annotator_properties.py +114 -0
  224. sparknlp/common/annotator_type.py +38 -0
  225. sparknlp/common/completion_post_processing.py +37 -0
  226. sparknlp/common/coverage_result.py +22 -0
  227. sparknlp/common/match_strategy.py +33 -0
  228. sparknlp/common/properties.py +1298 -0
  229. sparknlp/common/read_as.py +33 -0
  230. sparknlp/common/recursive_annotator_approach.py +35 -0
  231. sparknlp/common/storage.py +149 -0
  232. sparknlp/common/utils.py +39 -0
  233. sparknlp/functions.py +315 -5
  234. sparknlp/internal/__init__.py +1199 -0
  235. sparknlp/internal/annotator_java_ml.py +32 -0
  236. sparknlp/internal/annotator_transformer.py +37 -0
  237. sparknlp/internal/extended_java_wrapper.py +63 -0
  238. sparknlp/internal/params_getters_setters.py +71 -0
  239. sparknlp/internal/recursive.py +70 -0
  240. sparknlp/logging/__init__.py +15 -0
  241. sparknlp/logging/comet.py +467 -0
  242. sparknlp/partition/__init__.py +16 -0
  243. sparknlp/partition/partition.py +244 -0
  244. sparknlp/partition/partition_properties.py +902 -0
  245. sparknlp/partition/partition_transformer.py +200 -0
  246. sparknlp/pretrained/__init__.py +17 -0
  247. sparknlp/pretrained/pretrained_pipeline.py +158 -0
  248. sparknlp/pretrained/resource_downloader.py +216 -0
  249. sparknlp/pretrained/utils.py +35 -0
  250. sparknlp/reader/__init__.py +15 -0
  251. sparknlp/reader/enums.py +19 -0
  252. sparknlp/reader/pdf_to_text.py +190 -0
  253. sparknlp/reader/reader2doc.py +124 -0
  254. sparknlp/reader/reader2image.py +136 -0
  255. sparknlp/reader/reader2table.py +44 -0
  256. sparknlp/reader/reader_assembler.py +159 -0
  257. sparknlp/reader/sparknlp_reader.py +461 -0
  258. sparknlp/training/__init__.py +20 -0
  259. sparknlp/training/_tf_graph_builders/__init__.py +0 -0
  260. sparknlp/training/_tf_graph_builders/graph_builders.py +299 -0
  261. sparknlp/training/_tf_graph_builders/ner_dl/__init__.py +0 -0
  262. sparknlp/training/_tf_graph_builders/ner_dl/create_graph.py +41 -0
  263. sparknlp/training/_tf_graph_builders/ner_dl/dataset_encoder.py +78 -0
  264. sparknlp/training/_tf_graph_builders/ner_dl/ner_model.py +521 -0
  265. sparknlp/training/_tf_graph_builders/ner_dl/ner_model_saver.py +62 -0
  266. sparknlp/training/_tf_graph_builders/ner_dl/sentence_grouper.py +28 -0
  267. sparknlp/training/_tf_graph_builders/tf2contrib/__init__.py +36 -0
  268. sparknlp/training/_tf_graph_builders/tf2contrib/core_rnn_cell.py +385 -0
  269. sparknlp/training/_tf_graph_builders/tf2contrib/fused_rnn_cell.py +183 -0
  270. sparknlp/training/_tf_graph_builders/tf2contrib/gru_ops.py +235 -0
  271. sparknlp/training/_tf_graph_builders/tf2contrib/lstm_ops.py +665 -0
  272. sparknlp/training/_tf_graph_builders/tf2contrib/rnn.py +245 -0
  273. sparknlp/training/_tf_graph_builders/tf2contrib/rnn_cell.py +4006 -0
  274. sparknlp/training/_tf_graph_builders_1x/__init__.py +0 -0
  275. sparknlp/training/_tf_graph_builders_1x/graph_builders.py +277 -0
  276. sparknlp/training/_tf_graph_builders_1x/ner_dl/__init__.py +0 -0
  277. sparknlp/training/_tf_graph_builders_1x/ner_dl/create_graph.py +34 -0
  278. sparknlp/training/_tf_graph_builders_1x/ner_dl/dataset_encoder.py +78 -0
  279. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model.py +532 -0
  280. sparknlp/training/_tf_graph_builders_1x/ner_dl/ner_model_saver.py +62 -0
  281. sparknlp/training/_tf_graph_builders_1x/ner_dl/sentence_grouper.py +28 -0
  282. sparknlp/training/conll.py +150 -0
  283. sparknlp/training/conllu.py +103 -0
  284. sparknlp/training/pos.py +103 -0
  285. sparknlp/training/pub_tator.py +76 -0
  286. sparknlp/training/spacy_to_annotation.py +57 -0
  287. sparknlp/training/tfgraphs.py +5 -0
  288. sparknlp/upload_to_hub.py +149 -0
  289. sparknlp/util.py +51 -5
  290. com/__init__.pyc +0 -0
  291. com/__pycache__/__init__.cpython-36.pyc +0 -0
  292. com/johnsnowlabs/__init__.pyc +0 -0
  293. com/johnsnowlabs/__pycache__/__init__.cpython-36.pyc +0 -0
  294. com/johnsnowlabs/nlp/__init__.pyc +0 -0
  295. com/johnsnowlabs/nlp/__pycache__/__init__.cpython-36.pyc +0 -0
  296. spark_nlp-2.6.3rc1.dist-info/METADATA +0 -36
  297. spark_nlp-2.6.3rc1.dist-info/RECORD +0 -48
  298. sparknlp/__init__.pyc +0 -0
  299. sparknlp/__pycache__/__init__.cpython-36.pyc +0 -0
  300. sparknlp/__pycache__/annotation.cpython-36.pyc +0 -0
  301. sparknlp/__pycache__/annotator.cpython-36.pyc +0 -0
  302. sparknlp/__pycache__/base.cpython-36.pyc +0 -0
  303. sparknlp/__pycache__/common.cpython-36.pyc +0 -0
  304. sparknlp/__pycache__/embeddings.cpython-36.pyc +0 -0
  305. sparknlp/__pycache__/functions.cpython-36.pyc +0 -0
  306. sparknlp/__pycache__/internal.cpython-36.pyc +0 -0
  307. sparknlp/__pycache__/pretrained.cpython-36.pyc +0 -0
  308. sparknlp/__pycache__/storage.cpython-36.pyc +0 -0
  309. sparknlp/__pycache__/training.cpython-36.pyc +0 -0
  310. sparknlp/__pycache__/util.cpython-36.pyc +0 -0
  311. sparknlp/annotation.pyc +0 -0
  312. sparknlp/annotator.py +0 -3006
  313. sparknlp/annotator.pyc +0 -0
  314. sparknlp/base.py +0 -347
  315. sparknlp/base.pyc +0 -0
  316. sparknlp/common.py +0 -193
  317. sparknlp/common.pyc +0 -0
  318. sparknlp/embeddings.py +0 -40
  319. sparknlp/embeddings.pyc +0 -0
  320. sparknlp/internal.py +0 -288
  321. sparknlp/internal.pyc +0 -0
  322. sparknlp/pretrained.py +0 -123
  323. sparknlp/pretrained.pyc +0 -0
  324. sparknlp/storage.py +0 -32
  325. sparknlp/storage.pyc +0 -0
  326. sparknlp/training.py +0 -62
  327. sparknlp/training.pyc +0 -0
  328. sparknlp/util.pyc +0 -0
  329. {spark_nlp-2.6.3rc1.dist-info → spark_nlp-6.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,4006 @@
1
+ # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Module for constructing RNN Cells."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import collections
21
+ import math
22
+
23
+ from tensorflow.contrib.compiler import jit
24
+ from tensorflow.contrib.layers.python.layers import layers
25
+ from tensorflow.contrib.rnn.python.ops import core_rnn_cell
26
+ from tensorflow.python.framework import constant_op
27
+ from tensorflow.python.framework import dtypes
28
+ from tensorflow.python.framework import op_def_registry
29
+ from tensorflow.python.framework import ops
30
+ from tensorflow.python.framework import tensor_shape
31
+ from tensorflow.python.keras import activations
32
+ from tensorflow.python.keras import initializers
33
+ from tensorflow.python.keras.engine import input_spec
34
+ from tensorflow.python.ops import array_ops
35
+ from tensorflow.python.ops import clip_ops
36
+ from tensorflow.python.ops import control_flow_ops
37
+ from tensorflow.python.ops import gen_array_ops
38
+ from tensorflow.python.ops import init_ops
39
+ from tensorflow.python.ops import math_ops
40
+ from tensorflow.python.ops import nn_impl # pylint: disable=unused-import
41
+ from tensorflow.python.ops import nn_ops
42
+ from tensorflow.python.ops import random_ops
43
+ from tensorflow.python.ops import rnn_cell_impl
44
+ from tensorflow.python.ops import variable_scope as vs
45
+ from tensorflow.python.platform import tf_logging as logging
46
+ from tensorflow.python.util import nest
47
+
48
+
49
+ def _get_concat_variable(name, shape, dtype, num_shards):
50
+ """Get a sharded variable concatenated into one tensor."""
51
+ sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
52
+ if len(sharded_variable) == 1:
53
+ return sharded_variable[0]
54
+
55
+ concat_name = name + "/concat"
56
+ concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
57
+ for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
58
+ if value.name == concat_full_name:
59
+ return value
60
+
61
+ concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
62
+ ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
63
+ return concat_variable
64
+
65
+
66
+ def _get_sharded_variable(name, shape, dtype, num_shards):
67
+ """Get a list of sharded variables with the given dtype."""
68
+ if num_shards > shape[0]:
69
+ raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
70
+ num_shards))
71
+ unit_shard_size = int(math.floor(shape[0] / num_shards))
72
+ remaining_rows = shape[0] - unit_shard_size * num_shards
73
+
74
+ shards = []
75
+ for i in range(num_shards):
76
+ current_size = unit_shard_size
77
+ if i < remaining_rows:
78
+ current_size += 1
79
+ shards.append(
80
+ vs.get_variable(
81
+ name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
82
+ return shards
83
+
84
+
85
+ def _norm(g, b, inp, scope):
86
+ shape = inp.get_shape()[-1:]
87
+ gamma_init = init_ops.constant_initializer(g)
88
+ beta_init = init_ops.constant_initializer(b)
89
+ with vs.variable_scope(scope):
90
+ # Initialize beta and gamma for use by layer_norm.
91
+ vs.get_variable("gamma", shape=shape, initializer=gamma_init)
92
+ vs.get_variable("beta", shape=shape, initializer=beta_init)
93
+ normalized = layers.layer_norm(inp, reuse=True, scope=scope)
94
+ return normalized
95
+
96
+
97
+ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
98
+ """Long short-term memory unit (LSTM) recurrent network cell.
99
+
100
+ The default non-peephole implementation is based on:
101
+
102
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
103
+
104
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
105
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
106
+
107
+ The peephole implementation is based on:
108
+
109
+ https://research.google.com/pubs/archive/43905.pdf
110
+
111
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
112
+ "Long short-term memory recurrent neural network architectures for
113
+ large scale acoustic modeling." INTERSPEECH, 2014.
114
+
115
+ The coupling of input and forget gate is based on:
116
+
117
+ http://arxiv.org/pdf/1503.04069.pdf
118
+
119
+ Greff et al. "LSTM: A Search Space Odyssey"
120
+
121
+ The class uses optional peep-hole connections, and an optional projection
122
+ layer.
123
+ Layer normalization implementation is based on:
124
+
125
+ https://arxiv.org/abs/1607.06450.
126
+
127
+ "Layer Normalization"
128
+ Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
129
+
130
+ and is applied before the internal nonlinearities.
131
+
132
+ """
133
+
134
+ def __init__(self,
135
+ num_units,
136
+ use_peepholes=False,
137
+ initializer=None,
138
+ num_proj=None,
139
+ proj_clip=None,
140
+ num_unit_shards=1,
141
+ num_proj_shards=1,
142
+ forget_bias=1.0,
143
+ state_is_tuple=True,
144
+ activation=math_ops.tanh,
145
+ reuse=None,
146
+ layer_norm=False,
147
+ norm_gain=1.0,
148
+ norm_shift=0.0):
149
+ """Initialize the parameters for an LSTM cell.
150
+
151
+ Args:
152
+ num_units: int, The number of units in the LSTM cell
153
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
154
+ initializer: (optional) The initializer to use for the weight and
155
+ projection matrices.
156
+ num_proj: (optional) int, The output dimensionality for the projection
157
+ matrices. If None, no projection is performed.
158
+ proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
159
+ provided, then the projected values are clipped elementwise to within
160
+ `[-proj_clip, proj_clip]`.
161
+ num_unit_shards: How to split the weight matrix. If >1, the weight
162
+ matrix is stored across num_unit_shards.
163
+ num_proj_shards: How to split the projection matrix. If >1, the
164
+ projection matrix is stored across num_proj_shards.
165
+ forget_bias: Biases of the forget gate are initialized by default to 1
166
+ in order to reduce the scale of forgetting at the beginning of
167
+ the training.
168
+ state_is_tuple: If True, accepted and returned states are 2-tuples of
169
+ the `c_state` and `m_state`. By default (False), they are concatenated
170
+ along the column axis. This default behavior will soon be deprecated.
171
+ activation: Activation function of the inner states.
172
+ reuse: (optional) Python boolean describing whether to reuse variables
173
+ in an existing scope. If not `True`, and the existing scope already has
174
+ the given variables, an error is raised.
175
+ layer_norm: If `True`, layer normalization will be applied.
176
+ norm_gain: float, The layer normalization gain initial value. If
177
+ `layer_norm` has been set to `False`, this argument will be ignored.
178
+ norm_shift: float, The layer normalization shift initial value. If
179
+ `layer_norm` has been set to `False`, this argument will be ignored.
180
+ """
181
+ super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
182
+ if not state_is_tuple:
183
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
184
+ "deprecated. Use state_is_tuple=True.", self)
185
+ self._num_units = num_units
186
+ self._use_peepholes = use_peepholes
187
+ self._initializer = initializer
188
+ self._num_proj = num_proj
189
+ self._proj_clip = proj_clip
190
+ self._num_unit_shards = num_unit_shards
191
+ self._num_proj_shards = num_proj_shards
192
+ self._forget_bias = forget_bias
193
+ self._state_is_tuple = state_is_tuple
194
+ self._activation = activation
195
+ self._reuse = reuse
196
+ self._layer_norm = layer_norm
197
+ self._norm_gain = norm_gain
198
+ self._norm_shift = norm_shift
199
+
200
+ if num_proj:
201
+ self._state_size = (
202
+ rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
203
+ if state_is_tuple else num_units + num_proj)
204
+ self._output_size = num_proj
205
+ else:
206
+ self._state_size = (
207
+ rnn_cell_impl.LSTMStateTuple(num_units, num_units)
208
+ if state_is_tuple else 2 * num_units)
209
+ self._output_size = num_units
210
+
211
+ @property
212
+ def state_size(self):
213
+ return self._state_size
214
+
215
+ @property
216
+ def output_size(self):
217
+ return self._output_size
218
+
219
+ def call(self, inputs, state):
220
+ """Run one step of LSTM.
221
+
222
+ Args:
223
+ inputs: input Tensor, 2D, batch x num_units.
224
+ state: if `state_is_tuple` is False, this must be a state Tensor,
225
+ `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
226
+ tuple of state Tensors, both `2-D`, with column sizes `c_state` and
227
+ `m_state`.
228
+
229
+ Returns:
230
+ A tuple containing:
231
+ - A `2-D, [batch x output_dim]`, Tensor representing the output of the
232
+ LSTM after reading `inputs` when previous state was `state`.
233
+ Here output_dim is:
234
+ num_proj if num_proj was set,
235
+ num_units otherwise.
236
+ - Tensor(s) representing the new state of LSTM after reading `inputs` when
237
+ the previous state was `state`. Same type and shape(s) as `state`.
238
+
239
+ Raises:
240
+ ValueError: If input size cannot be inferred from inputs via
241
+ static shape inference.
242
+ """
243
+ sigmoid = math_ops.sigmoid
244
+
245
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
246
+
247
+ if self._state_is_tuple:
248
+ (c_prev, m_prev) = state
249
+ else:
250
+ c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
251
+ m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
252
+
253
+ dtype = inputs.dtype
254
+ input_size = inputs.get_shape().with_rank(2).dims[1]
255
+ if input_size.value is None:
256
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
257
+ concat_w = _get_concat_variable(
258
+ "W",
259
+ [input_size.value + num_proj, 3 * self._num_units],
260
+ dtype,
261
+ self._num_unit_shards)
262
+
263
+ b = vs.get_variable(
264
+ "B",
265
+ shape=[3 * self._num_units],
266
+ initializer=init_ops.zeros_initializer(),
267
+ dtype=dtype)
268
+
269
+ # j = new_input, f = forget_gate, o = output_gate
270
+ cell_inputs = array_ops.concat([inputs, m_prev], 1)
271
+ lstm_matrix = math_ops.matmul(cell_inputs, concat_w)
272
+
273
+ # If layer normalization is applied, do not add bias
274
+ if not self._layer_norm:
275
+ lstm_matrix = nn_ops.bias_add(lstm_matrix, b)
276
+
277
+ j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
278
+
279
+ # Apply layer normalization
280
+ if self._layer_norm:
281
+ j = _norm(self._norm_gain, self._norm_shift, j, "transform")
282
+ f = _norm(self._norm_gain, self._norm_shift, f, "forget")
283
+ o = _norm(self._norm_gain, self._norm_shift, o, "output")
284
+
285
+ # Diagonal connections
286
+ if self._use_peepholes:
287
+ w_f_diag = vs.get_variable(
288
+ "W_F_diag", shape=[self._num_units], dtype=dtype)
289
+ w_o_diag = vs.get_variable(
290
+ "W_O_diag", shape=[self._num_units], dtype=dtype)
291
+
292
+ if self._use_peepholes:
293
+ f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
294
+ else:
295
+ f_act = sigmoid(f + self._forget_bias)
296
+ c = (f_act * c_prev + (1 - f_act) * self._activation(j))
297
+
298
+ # Apply layer normalization
299
+ if self._layer_norm:
300
+ c = _norm(self._norm_gain, self._norm_shift, c, "state")
301
+
302
+ if self._use_peepholes:
303
+ m = sigmoid(o + w_o_diag * c) * self._activation(c)
304
+ else:
305
+ m = sigmoid(o) * self._activation(c)
306
+
307
+ if self._num_proj is not None:
308
+ concat_w_proj = _get_concat_variable("W_P",
309
+ [self._num_units, self._num_proj],
310
+ dtype, self._num_proj_shards)
311
+
312
+ m = math_ops.matmul(m, concat_w_proj)
313
+ if self._proj_clip is not None:
314
+ # pylint: disable=invalid-unary-operand-type
315
+ m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
316
+ # pylint: enable=invalid-unary-operand-type
317
+
318
+ new_state = (
319
+ rnn_cell_impl.LSTMStateTuple(c, m)
320
+ if self._state_is_tuple else array_ops.concat([c, m], 1))
321
+ return m, new_state
322
+
323
+
324
class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.

  This implementation is based on:

    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  It uses peep-hole connections and optional cell clipping.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=1,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_unit_shards: int, How to split the weight matrix. If >1, the weight
        matrix is stored across num_unit_shards.
      forget_bias: float, Biases of the forget gate are initialized by default
        to 1 in order to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: int, The size of the input feature the LSTM spans over.
      frequency_skip: int, The amount the LSTM filter is shifted by in
        frequency.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    # One (c, m) pair of num_units each.
    # NOTE(review): call() slices `state` at offset 2 * fq * num_units for each
    # frequency step fq, so callers appear to supply a state sized
    # 2 * num_units *per frequency step*, not this flat 2 * num_units —
    # confirm against how the wrapping RNN sizes the state.
    self._state_size = 2 * num_units
    self._output_size = num_units
    self._reuse = reuse

  @property
  def output_size(self):
    # Per-frequency-step output width; call() concatenates one such slice
    # per frequency step into its m_out.
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: state Tensor, 2D, batch x state_size.

    Returns:
      A tuple containing:
      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, batch x state_size, Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    # Split the input into (possibly overlapping) frequency windows.
    freq_inputs = self._make_tf_features(inputs)
    dtype = inputs.dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    # One shared weight matrix across all frequency steps. Each step consumes
    # the frequency window plus the previous time-step output (m_prev) and the
    # previous frequency-step output (m_prev_freq): hence 2 * num_units extra
    # input columns, and 4 gates (i, j, f, o) of num_units each.
    concat_w = _get_concat_variable(
        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
        dtype, self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros(
        [inputs.shape.dims[0].value or inputs.get_shape()[0], self._num_units],
        dtype)
    for fq in range(len(freq_inputs)):
      # The state interleaves (c, m) pairs per frequency step.
      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
                               [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
                               [-1, self._num_units])
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      if self._use_peepholes:
        # Peepholes let the input/forget gates see the previous cell state.
        c = (
            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
            sigmoid(i + w_i_diag * c_prev) * tanh(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
      else:
        m = sigmoid(o) * tanh(c)
      # This step's output feeds the next frequency step as m_prev_freq.
      m_prev_freq = m
      if fq == 0:
        state_out = array_ops.concat([c, m], 1)
        m_out = m
      else:
        state_out = array_ops.concat([state_out, c, m], 1)
        m_out = array_ops.concat([m_out, m], 1)
    return m_out, state_out

  def _make_tf_features(self, input_feat):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, batch x num_units.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
        for that frequency index. Here output_dim is feature_size.
    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2).dims[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    # Number of windows of width feature_size, taken every frequency_skip
    # columns; windows overlap when frequency_skip < feature_size.
    num_feats = int(
        (input_size - self._feature_size) / (self._frequency_skip)) + 1
    freq_inputs = []
    for f in range(num_feats):
      cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
                                  [-1, self._feature_size])
      freq_inputs.append(cur_input)
    return freq_inputs
496
+
497
+
498
class GridLSTMCell(rnn_cell_impl.RNNCell):
  """Grid Long short-term memory unit (LSTM) recurrent network cell.

  The default is based on:
    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
    "Grid Long Short-Term Memory," Proc. ICLR 2016.
    http://arxiv.org/abs/1507.01526

  When peephole connections are used, the implementation is based on:
    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  The code uses optional peephole connections, shared_weights and cell clipping.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               state_is_tuple=True,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None, The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`. By default (False), they are concatenated
        along the column axis. This default behavior will soon be deprecated.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    Raises:
      ValueError: if the num_frequency_blocks list is not specified
    """
    super(GridLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._share_time_frequency_weights = share_time_frequency_weights
    self._couple_input_forget_gates = couple_input_forget_gates
    self._state_is_tuple = state_is_tuple
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    self._start_freqindex_list = start_freqindex_list
    self._end_freqindex_list = end_freqindex_list
    self._num_frequency_blocks = num_frequency_blocks
    self._total_blocks = 0
    self._reuse = reuse
    if self._num_frequency_blocks is None:
      raise ValueError("Must specify num_frequency_blocks")

    for block_index in range(len(self._num_frequency_blocks)):
      self._total_blocks += int(self._num_frequency_blocks[block_index])
    if state_is_tuple:
      # Build a namedtuple state type with one (c, m) pair per frequency step
      # per block, e.g. fields state_f00_b00_c, state_f00_b00_m, ...
      state_names = ""
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
                                                      state_names.strip(","))
      self._state_size = self._state_tuple_type(*(
          [num_units, num_units] * self._total_blocks))
    else:
      self._state_tuple_type = None
      self._state_size = num_units * self._total_blocks * 2
    # Each frequency step emits both the time-LSTM and frequency-LSTM outputs.
    self._output_size = num_units * self._total_blocks * 2

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  @property
  def state_tuple_type(self):
    # The namedtuple class used for tuple states, or None when
    # state_is_tuple=False.
    return self._state_tuple_type

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, feature_size].
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
        flag self._state_is_tuple.

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    # Prefer the static batch size; fall back to a dynamic shape op.
    batch_size = tensor_shape.dimension_value(
        inputs.shape[0]) or array_ops.shape(inputs)[0]
    freq_inputs = self._make_tf_features(inputs)
    m_out_lst = []
    state_out_lst = []
    # Each frequency block is processed independently; _compute creates
    # block-specific weight variables (suffixed with the block index).
    for block in range(len(freq_inputs)):
      m_out_lst_current, state_out_lst_current = self._compute(
          freq_inputs[block],
          block,
          state,
          batch_size,
          state_is_tuple=self._state_is_tuple)
      m_out_lst.extend(m_out_lst_current)
      state_out_lst.extend(state_out_lst_current)
    if self._state_is_tuple:
      state_out = self._state_tuple_type(*state_out_lst)
    else:
      state_out = array_ops.concat(state_out_lst, 1)
    m_out = array_ops.concat(m_out_lst, 1)
    return m_out, state_out

  def _compute(self,
               freq_inputs,
               block,
               state,
               batch_size,
               state_prefix="state",
               state_is_tuple=True):
    """Run the actual computation of one step LSTM.

    Args:
      freq_inputs: list of Tensors, 2D, [batch, feature_size].
      block: int, current frequency block index to process.
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
        the flag state_is_tuple.
      batch_size: int32, batch size.
      state_prefix: (optional) string, name prefix for states, defaults to
        "state".
      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.

    Returns:
      A tuple, containing:
      - A list of [batch, output_dim] Tensors, representing the output of the
        LSTM given the inputs and state.
      - A list of [batch, state_size] Tensors, representing the LSTM state
        values given the inputs and previous state.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    # With coupled gates the forget gate is derived from the input gate, so
    # only 3 gate columns (i, j, o) are needed instead of 4.
    num_gates = 3 if self._couple_input_forget_gates else 4
    dtype = freq_inputs[0].dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    # Frequency-LSTM weights; the time-LSTM either shares them
    # (share_time_frequency_weights) or gets its own W_t/B_t below.
    concat_w_f = _get_concat_variable(
        "W_f_%d" % block,
        [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
        dtype, self._num_unit_shards)
    b_f = vs.get_variable(
        "B_f_%d" % block,
        shape=[num_gates * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)
    if not self._share_time_frequency_weights:
      concat_w_t = _get_concat_variable("W_t_%d" % block, [
          actual_input_size + 2 * self._num_units, num_gates * self._num_units
      ], dtype, self._num_unit_shards)
      b_t = vs.get_variable(
          "B_t_%d" % block,
          shape=[num_gates * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)

    if self._use_peepholes:
      # Diagonal connections
      if not self._couple_input_forget_gates:
        w_f_diag_freqf = vs.get_variable(
            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
        w_f_diag_freqt = vs.get_variable(
            "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqf = vs.get_variable(
          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqt = vs.get_variable(
          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqf = vs.get_variable(
          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqt = vs.get_variable(
          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      if not self._share_time_frequency_weights:
        if not self._couple_input_forget_gates:
          w_f_diag_timef = vs.get_variable(
              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
          w_f_diag_timet = vs.get_variable(
              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timef = vs.get_variable(
            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timet = vs.get_variable(
            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timef = vs.get_variable(
            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timet = vs.get_variable(
            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    for freq_index in range(len(freq_inputs)):
      if state_is_tuple:
        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
        c_prev_time = getattr(state, name_prefix + "_c")
        m_prev_time = getattr(state, name_prefix + "_m")
      else:
        c_prev_time = array_ops.slice(
            state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
        m_prev_time = array_ops.slice(
            state, [0, (2 * freq_index + 1) * self._num_units],
            [-1, self._num_units])

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat(
          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)

      # F-LSTM
      lstm_matrix_freq = nn_ops.bias_add(
          math_ops.matmul(cell_inputs, concat_w_f), b_f)
      if self._couple_input_forget_gates:
        i_freq, j_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
        f_freq = None
      else:
        i_freq, j_freq, f_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
      # T-LSTM
      if self._share_time_frequency_weights:
        i_time = i_freq
        j_time = j_freq
        f_time = f_freq
        o_time = o_freq
      else:
        lstm_matrix_time = nn_ops.bias_add(
            math_ops.matmul(cell_inputs, concat_w_t), b_t)
        if self._couple_input_forget_gates:
          i_time, j_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
          f_time = None
        else:
          i_time, j_time, f_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)

      # F-LSTM c_freq
      # input gate activations
      if self._use_peepholes:
        i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
                           w_i_diag_freqt * c_prev_time)
      else:
        i_freq_g = sigmoid(i_freq)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_freq_g = 1.0 - i_freq_g
      else:
        if self._use_peepholes:
          f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
                             c_prev_freq + w_f_diag_freqt * c_prev_time)
        else:
          f_freq_g = sigmoid(f_freq + self._forget_bias)
      # cell state
      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # T-LSTM c_freq
      # input gate activations
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
                             w_i_diag_freqt * c_prev_time)
        else:
          i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
                             w_i_diag_timet * c_prev_time)
      else:
        i_time_g = sigmoid(i_time)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_time_g = 1.0 - i_time_g
      else:
        if self._use_peepholes:
          if self._share_time_frequency_weights:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
                               c_prev_freq + w_f_diag_freqt * c_prev_time)
          else:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
                               c_prev_freq + w_f_diag_timet * c_prev_time)
        else:
          f_time_g = sigmoid(f_time + self._forget_bias)
      # cell state
      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # F-LSTM m_freq
      if self._use_peepholes:
        m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
                         w_o_diag_freqt * c_time) * tanh(c_freq)
      else:
        m_freq = sigmoid(o_freq) * tanh(c_freq)

      # T-LSTM m_time
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
                           w_o_diag_freqt * c_time) * tanh(c_time)
        else:
          m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
                           w_o_diag_timet * c_time) * tanh(c_time)
      else:
        m_time = sigmoid(o_time) * tanh(c_time)

      m_prev_freq = m_freq
      c_prev_freq = c_freq
      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
      if freq_index == 0:
        state_out_lst = [c_time, m_time]
        m_out_lst = [m_time, m_freq]
      else:
        state_out_lst.extend([c_time, m_time])
        m_out_lst.extend([m_time, m_freq])

    return m_out_lst, state_out_lst

  def _make_tf_features(self, input_feat, slice_offset=0):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, [batch, num_units].
      slice_offset: (optional) Python int, default 0, the slicing offset is only
        used for the backward processing in the BidirectionalGridLSTMCell. It
        specifies a different starting point instead of always 0 to enable the
        forward and backward processing look at different frequency blocks.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, [batch, output_dim], Tensor representing the time-frequency
        feature for that frequency index. Here output_dim is feature_size.
    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2).dims[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    if slice_offset > 0:
      # Padding to the end
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, 0, slice_offset],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
    elif slice_offset < 0:
      # Padding to the front
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, -slice_offset, 0],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
      slice_offset = 0
    else:
      inputs = input_feat
    freq_inputs = []
    if not self._start_freqindex_list:
      if len(self._num_frequency_blocks) != 1:
        raise ValueError("Length of num_frequency_blocks"
                         " is not 1, but instead is %d" %
                         len(self._num_frequency_blocks))
      num_feats = int(
          (input_size - self._feature_size) / (self._frequency_skip)) + 1
      if num_feats != self._num_frequency_blocks[0]:
        raise ValueError(
            "Invalid num_frequency_blocks, requires %d but gets %d, please"
            " check the input size and filter config are correct." %
            (self._num_frequency_blocks[0], num_feats))
      block_inputs = []
      for f in range(num_feats):
        cur_input = array_ops.slice(
            inputs, [0, slice_offset + f * self._frequency_skip],
            [-1, self._feature_size])
        block_inputs.append(cur_input)
      freq_inputs.append(block_inputs)
    else:
      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
        # Bug fix: the lengths were previously passed as extra ValueError
        # arguments (logging-style), leaving the %d placeholders unfilled;
        # use explicit %-formatting so the message interpolates.
        raise ValueError("Length of start and end freqindex_list"
                         " does not match %d %d" %
                         (len(self._start_freqindex_list),
                          len(self._end_freqindex_list)))
      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
        # Bug fix: same logging-style argument mistake as above.
        raise ValueError("Length of num_frequency_blocks"
                         " is not equal to start_freqindex_list %d %d" %
                         (len(self._num_frequency_blocks),
                          len(self._start_freqindex_list)))
      for b in range(len(self._start_freqindex_list)):
        start_index = self._start_freqindex_list[b]
        end_index = self._end_freqindex_list[b]
        cur_size = end_index - start_index
        block_feats = int(
            (cur_size - self._feature_size) / (self._frequency_skip)) + 1
        if block_feats != self._num_frequency_blocks[b]:
          raise ValueError(
              "Invalid num_frequency_blocks, requires %d but gets %d, please"
              " check the input size and filter config are correct." %
              (self._num_frequency_blocks[b], block_feats))
        block_inputs = []
        for f in range(block_feats):
          cur_input = array_ops.slice(
              inputs,
              [0, start_index + slice_offset + f * self._frequency_skip],
              [-1, self._feature_size])
          block_inputs.append(cur_input)
        freq_inputs.append(block_inputs)
    return freq_inputs
963
+
964
+
965
class BidirectionalGridLSTMCell(GridLSTMCell):
  """Bidirectional GridLstm cell.

  The bidirection connection is only used in the frequency direction, which
  hence doesn't affect the time direction's real-time processing that is
  required for online recognition systems.
  The current implementation uses different weights for the two directions.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               backward_slice_offset=0,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None, The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      backward_slice_offset: (optional) int32, default 0, the starting offset to
        slice the feature for backward processing.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    # Force state_is_tuple=True in the parent; the state tuple type built there
    # is replaced below with one carrying fwd_/bwd_ prefixed fields.
    super(BidirectionalGridLSTMCell, self).__init__(
        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
        couple_input_forget_gates, True, reuse)
    self._backward_slice_offset = int(backward_slice_offset)
    state_names = ""
    for direction in ["fwd", "bwd"]:
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
                                                  block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
    self._state_tuple_type = collections.namedtuple(
        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
    # Two directions, so both the state and output sizes double relative to
    # the parent GridLSTMCell.
    self._state_size = self._state_tuple_type(*(
        [num_units, num_units] * self._total_blocks * 2))
    self._output_size = 2 * num_units * self._total_blocks * 2

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, num_units].
      state: tuple of Tensors, 2D, [batch, state_size].

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    # Prefer the static batch size; fall back to a dynamic shape op.
    batch_size = tensor_shape.dimension_value(
        inputs.shape[0]) or array_ops.shape(inputs)[0]
    fwd_inputs = self._make_tf_features(inputs)
    if self._backward_slice_offset:
      # Let the backward pass look at a shifted set of frequency windows.
      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
    else:
      bwd_inputs = fwd_inputs

    # Forward processing
    with vs.variable_scope("fwd"):
      fwd_m_out_lst = []
      fwd_state_out_lst = []
      for block in range(len(fwd_inputs)):
        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
            fwd_inputs[block],
            block,
            state,
            batch_size,
            state_prefix="fwd_state",
            state_is_tuple=True)
        fwd_m_out_lst.extend(fwd_m_out_lst_current)
        fwd_state_out_lst.extend(fwd_state_out_lst_current)
    # Backward processing
    bwd_m_out_lst = []
    bwd_state_out_lst = []
    with vs.variable_scope("bwd"):
      for block in range(len(bwd_inputs)):
        # Reverse the blocks
        bwd_inputs_reverse = bwd_inputs[block][::-1]
        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
            bwd_inputs_reverse,
            block,
            state,
            batch_size,
            state_prefix="bwd_state",
            state_is_tuple=True)
        bwd_m_out_lst.extend(bwd_m_out_lst_current)
        bwd_state_out_lst.extend(bwd_state_out_lst_current)
    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
    # Outputs are always concated as it is never used separately.
    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
    return m_out, state_out
1106
+
1107
+
1108
# pylint: disable=protected-access
# Alias core_rnn_cell's private fully-connected helper so the cells below can
# build their gate projections with one shared utility.
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name

# pylint: enable=protected-access
1113
+
1114
+
1115
class AttentionCellWrapper(rnn_cell_impl.RNNCell):
  """Basic attention cell wrapper.

  Implementation based on https://arxiv.org/abs/1601.06733.
  """

  def __init__(self,
               cell,
               attn_length,
               attn_size=None,
               attn_vec_size=None,
               input_size=None,
               state_is_tuple=True,
               reuse=None):
    """Create a cell with attention.

    Args:
      cell: an RNNCell, an attention is added to it.
      attn_length: integer, the size of an attention window.
      attn_size: integer, the size of an attention vector. Equal to
        cell.output_size by default.
      attn_vec_size: integer, the number of convolutional features calculated
        on attention state and a size of the hidden layer built from
        base cell state. Equal attn_size to by default.
      input_size: integer, the size of a hidden linear layer,
        built from inputs and attention. Derived from the input tensor
        by default.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`. By default (False), the states are all
        concatenated along the column axis.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if cell returns a state tuple but the flag
        `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
      raise ValueError(
          "Cell returns tuple of states, but the flag "
          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
    if attn_length <= 0:
      raise ValueError(
          "attn_length should be greater than zero, got %s" % str(attn_length))
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    # Default the attention sizes from the wrapped cell when not given.
    if attn_size is None:
      attn_size = cell.output_size
    if attn_vec_size is None:
      attn_vec_size = attn_size
    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._attn_vec_size = attn_vec_size
    self._input_size = input_size
    self._attn_size = attn_size
    self._attn_length = attn_length
    self._reuse = reuse
    # _Linear layers are created lazily on first call, once input shapes
    # are known.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    # State is (wrapped cell state, current attention vector, flattened
    # window of past attention states).
    size = (self._cell.state_size, self._attn_size,
            self._attn_size * self._attn_length)
    if self._state_is_tuple:
      return size
    else:
      return sum(list(size))

  @property
  def output_size(self):
    return self._attn_size

  def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA)."""
    if self._state_is_tuple:
      state, attns, attn_states = state
    else:
      # Concatenated state: slice the three components back apart.
      states = state
      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
      attns = array_ops.slice(states, [0, self._cell.state_size],
                              [-1, self._attn_size])
      attn_states = array_ops.slice(
          states, [0, self._cell.state_size + self._attn_size],
          [-1, self._attn_size * self._attn_length])
    # Restore the window dimension: [batch, attn_length, attn_size].
    attn_states = array_ops.reshape(attn_states,
                                    [-1, self._attn_length, self._attn_size])
    input_size = self._input_size
    if input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    # Project [inputs; attention] down to the cell's input size.
    if self._linear1 is None:
      self._linear1 = _Linear([inputs, attns], input_size, True)
    inputs = self._linear1([inputs, attns])
    cell_output, new_state = self._cell(inputs, state)
    if self._state_is_tuple:
      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
    else:
      new_state_cat = new_state
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
    with vs.variable_scope("attn_output_projection"):
      if self._linear2 is None:
        self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True)
      output = self._linear2([cell_output, new_attns])
    # Append this step's output to the attention window (the oldest entry
    # was already dropped inside _attention).
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])
    new_state = (new_state, new_attns, new_attn_states)
    if not self._state_is_tuple:
      new_state = array_ops.concat(list(new_state), 1)
    return output, new_state

  def _attention(self, query, attn_states):
    """Computes the attention vector over `attn_states` for `query`.

    Returns the attended vector and the attention window with its oldest
    entry dropped (the caller appends the new output).
    """
    conv2d = nn_ops.conv2d
    reduce_sum = math_ops.reduce_sum
    softmax = nn_ops.softmax
    tanh = math_ops.tanh

    with vs.variable_scope("attention"):
      # 1x1 convolution computes per-position features of the window.
      k = vs.get_variable("attn_w",
                          [1, 1, self._attn_size, self._attn_vec_size])
      v = vs.get_variable("attn_v", [self._attn_vec_size])
      hidden = array_ops.reshape(attn_states,
                                 [-1, self._attn_length, 1, self._attn_size])
      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
      if self._linear3 is None:
        self._linear3 = _Linear(query, self._attn_vec_size, True)
      y = self._linear3(query)
      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
      # Additive-style scores: v . tanh(W*h + U*q), softmaxed over positions.
      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
      a = softmax(s)
      # Weighted sum of window entries gives the new attention vector.
      d = reduce_sum(
          array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
      new_attns = array_ops.reshape(d, [-1, self._attn_size])
      # Drop the oldest window entry; caller appends the newest.
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
1257
+
1258
+
1259
class HighwayWrapper(rnn_cell_impl.RNNCell):
  """RNNCell wrapper that adds highway connection on cell input and output.

  Based on:
    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
    arXiv preprint arXiv:1505.00387, 2015.
    https://arxiv.org/abs/1505.00387
  """

  def __init__(self,
               cell,
               couple_carry_transform_gates=True,
               carry_bias_init=1.0):
    """Constructs a `HighwayWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean, should the Carry and Transform gate
        be coupled.
      carry_bias_init: float, carry gates bias initialization.
    """
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    """Mixes `inp` and `out` as `inp * carry + out * transform`."""
    input_size = inp.get_shape().with_rank(2).dims[1].value
    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
    # Positive carry bias initially favors passing the input through.
    carry_bias = vs.get_variable(
        "carry_b", [input_size],
        initializer=init_ops.constant_initializer(self._carry_bias_init))
    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
    if self._couple_carry_transform_gates:
      # Coupled gates: transform is the complement of carry.
      transform = 1 - carry
    else:
      transform_weight = vs.get_variable("transform_w",
                                         [input_size, input_size])
      # Negated bias mirrors the carry gate's initialization.
      transform_bias = vs.get_variable(
          "transform_b", [input_size],
          initializer=init_ops.constant_initializer(-self._carry_bias_init))
      transform = math_ops.sigmoid(
          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
    return inp * carry + out * transform

  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, outputs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.map_structure(assert_shape_match, inputs, outputs)
    # Apply the highway mix element-wise over the (possibly nested) structure.
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)
1340
+
1341
+
1342
class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
  """LSTM unit with layer normalization and recurrent dropout.

  This class adds layer normalization and recurrent dropout to a
  basic LSTM unit. Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  Recurrent dropout is based on:

    https://arxiv.org/abs/1603.05118

  "Recurrent Dropout without Memory Loss"
  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               input_size=None,
               activation=math_ops.tanh,
               layer_norm=True,
               norm_gain=1.0,
               norm_shift=0.0,
               dropout_keep_prob=1.0,
               dropout_prob_seed=None,
               reuse=None):
    """Initializes the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      activation: Activation function of the inner states.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
        recurrent dropout probability value. If float and 1.0, no dropout will
        be applied.
      dropout_prob_seed: (optional) integer, the randomness seed.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)

    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)

    self._num_units = num_units
    self._activation = activation
    self._forget_bias = forget_bias
    self._keep_prob = dropout_keep_prob
    self._seed = dropout_prob_seed
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift
    self._reuse = reuse

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def _norm(self, inp, scope, dtype=dtypes.float32):
    """Layer-normalizes `inp` under `scope` with this cell's gain/shift init."""
    shape = inp.get_shape()[-1:]
    gamma_init = init_ops.constant_initializer(self._norm_gain)
    beta_init = init_ops.constant_initializer(self._norm_shift)
    with vs.variable_scope(scope):
      # Initialize beta and gamma for use by layer_norm.
      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
    # reuse=True picks up the gamma/beta variables created just above.
    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
    return normalized

  def _linear(self, args):
    """Projects `args` to the 4 stacked gate pre-activations (i, j, f, o)."""
    out_size = 4 * self._num_units
    proj_size = args.get_shape()[-1]
    dtype = args.dtype
    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
    out = math_ops.matmul(args, weights)
    if not self._layer_norm:
      # With layer norm enabled, each gate is normalized (and shifted by its
      # own beta) afterwards, so no bias is added here.
      bias = vs.get_variable("bias", [out_size], dtype=dtype)
      out = nn_ops.bias_add(out, bias)
    return out

  def call(self, inputs, state):
    """LSTM cell with layer normalization and recurrent dropout."""
    c, h = state
    args = array_ops.concat([inputs, h], 1)
    concat = self._linear(args)
    dtype = args.dtype

    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    if self._layer_norm:
      # Normalize each gate pre-activation separately.
      i = self._norm(i, "input", dtype=dtype)
      j = self._norm(j, "transform", dtype=dtype)
      f = self._norm(f, "forget", dtype=dtype)
      o = self._norm(o, "output", dtype=dtype)

    g = self._activation(j)
    # Recurrent dropout is applied only to the candidate values g.
    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)

    new_c = (
        c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
    if self._layer_norm:
      new_c = self._norm(new_c, "state", dtype=dtype)
    new_h = self._activation(new_c) * math_ops.sigmoid(o)

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
1464
+
1465
+
1466
class NASCell(rnn_cell_impl.LayerRNNCell):
  """Neural Architecture Search (NAS) recurrent network cell.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.01578

  Barret Zoph and Quoc V. Le.
  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

  The class uses an optional projection layer.
  """

  # NAS cell's architecture base.
  _NAS_BASE = 8

  def __init__(self, num_units, num_proj=None, use_bias=False, reuse=None,
               **kwargs):
    """Initialize the parameters for a NAS cell.

    Args:
      num_units: int, The number of units in the NAS cell.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      use_bias: (optional) bool, If True then use biases within the cell. This
        is False by default.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      **kwargs: Additional keyword arguments.
    """
    super(NASCell, self).__init__(_reuse=reuse, **kwargs)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_bias = use_bias
    self._reuse = reuse

    # With a projection, m (output) has num_proj columns; otherwise num_units.
    if num_proj is not None:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def build(self, inputs_shape):
    """Creates the cell's variables once the input size is known."""
    input_size = tensor_shape.dimension_value(
        tensor_shape.TensorShape(inputs_shape).with_rank(2)[1])
    if input_size is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    # Variables for the NAS cell. `recurrent_kernel` is all matrices multiplying
    # the hiddenstate and `kernel` is all matrices multiplying the inputs.
    self.recurrent_kernel = self.add_variable(
        "recurrent_kernel", [num_proj, self._NAS_BASE * self._num_units])
    self.kernel = self.add_variable(
        "kernel", [input_size, self._NAS_BASE * self._num_units])

    if self._use_bias:
      self.bias = self.add_variable("bias",
                                    shape=[self._NAS_BASE * self._num_units],
                                    initializer=init_ops.zeros_initializer)

    # Projection layer if specified
    if self._num_proj is not None:
      self.projection_weights = self.add_variable(
          "projection_weights", [self._num_units, self._num_proj])

    self.built = True

  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        NAS Cell after reading `inputs` when previous state was `state`.
        Here output_dim is:
          num_proj if num_proj was set,
          num_units otherwise.
      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
        when the previous state was `state`. Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    relu = nn_ops.relu

    (c_prev, m_prev) = state

    m_matrix = math_ops.matmul(m_prev, self.recurrent_kernel)
    inputs_matrix = math_ops.matmul(inputs, self.kernel)

    if self._use_bias:
      m_matrix = nn_ops.bias_add(m_matrix, self.bias)

    # The NAS cell branches into 8 different splits for both the hiddenstate
    # and the input
    m_matrix_splits = array_ops.split(
        axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix)
    inputs_matrix_splits = array_ops.split(
        axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix)

    # First layer: each of the 8 branches combines its input split and
    # hidden-state split through a fixed elementwise op + nonlinearity.
    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

    # Second layer: pairwise combination of the first-layer branches.
    l2_0 = tanh(layer1_0 * layer1_1)
    l2_1 = tanh(layer1_2 + layer1_3)
    l2_2 = tanh(layer1_4 * layer1_5)
    l2_3 = sigmoid(layer1_6 + layer1_7)

    # Inject the cell
    l2_0 = tanh(l2_0 + c_prev)

    # Third layer
    l3_0_pre = l2_0 * l2_1
    new_c = l3_0_pre  # create new cell
    l3_0 = l3_0_pre
    l3_1 = tanh(l2_2 + l2_3)

    # Final layer
    new_m = tanh(l3_0 * l3_1)

    # Projection layer if specified
    if self._num_proj is not None:
      new_m = math_ops.matmul(new_m, self.projection_weights)

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
    return new_m, new_state
1620
+
1621
+
1622
class UGRNNCell(rnn_cell_impl.RNNCell):
  """Update Gate Recurrent Neural Network (UGRNN) cell.

  Compromise between a LSTM/GRU and a vanilla RNN. There is only one
  gate, and that is to determine whether the unit should be
  integrating or computing instantaneously. This is the recurrent
  idea of the feedforward Highway Network.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
  """

  def __init__(self,
               num_units,
               initializer=None,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters for an UGRNN cell.

    Args:
      num_units: int, The number of units in the UGRNN cell
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gate, used to reduce the scale of forgetting at the beginning
        of the training.
      activation: (optional) Activation function of the inner states.
        Default is `tf.tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._activation = activation
    self._reuse = reuse
    # Single _Linear layer, created lazily on the first call.
    self._linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Run one step of UGRNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_output: batch x num units, Tensor representing the output of the UGRNN
        after reading `inputs` when previous state was `state`. Identical to
        `new_state`.
      new_state: batch x num units, Tensor representing the state of the UGRNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear is None:
        # One projection produces both the gate and candidate pre-activations.
        self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
      rnn_matrix = self._linear(cell_inputs)

      [g_act, c_act] = array_ops.split(
          axis=1, num_or_size_splits=2, value=rnn_matrix)

      c = self._activation(c_act)
      g = sigmoid(g_act + self._forget_bias)
      # Highway-style mix: g decides between keeping the old state and
      # taking the new candidate.
      new_state = g * state + (1.0 - g) * c
      new_output = new_state

    return new_output, new_state
1714
+
1715
+
1716
class IntersectionRNNCell(rnn_cell_impl.RNNCell):
  """Intersection Recurrent Neural Network (+RNN) cell.

  Architecture with coupled recurrent gate as well as coupled depth
  gate, designed to improve information flow through stacked RNNs. As the
  architecture uses depth gating, the dimensionality of the depth
  output (y) also should not change through depth (input size == output size).
  To achieve this, the first layer of a stacked Intersection RNN projects
  the inputs to N (num units) dimensions. Therefore when initializing an
  IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
  and use default settings for subsequent layers.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.

  The Intersection RNN is built for use in deeply stacked
  RNNs so it may not achieve best performance with depth 1.
  """

  def __init__(self,
               num_units,
               num_in_proj=None,
               initializer=None,
               forget_bias=1.0,
               y_activation=nn_ops.relu,
               reuse=None):
    """Initialize the parameters for an +RNN cell.

    Args:
      num_units: int, The number of units in the +RNN cell
      num_in_proj: (optional) int, The input dimensionality for the RNN.
        If creating the first layer of an +RNN, this should be set to
        `num_units`. Otherwise, this should be set to `None` (default).
        If `None`, dimensionality of `inputs` should be equal to `num_units`,
        otherwise ValueError is thrown.
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      y_activation: (optional) Activation function of the states passed
        through depth. Default is 'tf.nn.relu`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._num_input_proj = num_in_proj
    self._y_activation = y_activation
    self._reuse = reuse
    # _Linear layers are created lazily on first call, once shapes are known.
    self._linear1 = None
    self._linear2 = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Run one step of the Intersection RNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_y: batch x num units, Tensor representing the output of the +RNN
        after reading `inputs` when previous state was `state`.
      new_state: batch x num units, Tensor representing the state of the +RNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from `inputs` via
        static shape inference.
      ValueError: If input size != output size (these must be equal when
        using the Intersection RNN).
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      # read-in projections (should be used for first layer in deep +RNN
      # to transform size of inputs from I --> N)
      if input_size.value != self._num_units:
        if self._num_input_proj:
          with vs.variable_scope("in_projection"):
            if self._linear1 is None:
              self._linear1 = _Linear(inputs, self._num_units, True)
            inputs = self._linear1(inputs)
        else:
          raise ValueError("Must have input size == output size for "
                           "Intersection RNN. To fix, num_in_proj should "
                           "be set to num_units at cell init.")

      n_dim = i_dim = self._num_units
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear2 is None:
        # One projection yields gate + candidate for both the recurrent (h)
        # and depth (y) paths: [gh, h, gy, y].
        self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
      rnn_matrix = self._linear2(cell_inputs)

      gh_act = rnn_matrix[:, :n_dim]  # b x n
      h_act = rnn_matrix[:, n_dim:2 * n_dim]  # b x n
      gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim]  # b x i
      y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim]  # b x i

      h = tanh(h_act)
      y = self._y_activation(y_act)
      gh = sigmoid(gh_act + self._forget_bias)
      gy = sigmoid(gy_act + self._forget_bias)

      new_state = gh * state + (1.0 - gh) * h  # passed thru time
      new_y = gy * inputs + (1.0 - gy) * y  # passed thru depth

    return new_y, new_state
1844
+
1845
+
1846
# Lazily-populated cache of the registered-op table; filled on first use
# (see CompiledWrapper.__call__) so module import stays cheap.
_REGISTERED_OPS = None
1847
+
1848
+
1849
class CompiledWrapper(rnn_cell_impl.RNNCell):
  """Wraps step execution in an XLA JIT scope."""

  def __init__(self, cell, compile_stateful=False):
    """Create CompiledWrapper cell.

    Args:
      cell: Instance of `RNNCell`.
      compile_stateful: Whether to compile stateful ops like initializers
        and random number generators (default: False).
    """
    self._cell = cell
    self._compile_stateful = compile_stateful

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
    if self._compile_stateful:
      # Compile everything, stateful ops included.
      compile_ops = True
    else:

      # Predicate: compile only ops the registry marks as stateless.
      def compile_ops(node_def):
        global _REGISTERED_OPS
        if _REGISTERED_OPS is None:
          # First use: populate the module-level op-registry cache.
          _REGISTERED_OPS = op_def_registry.get_registered_ops()
        return not _REGISTERED_OPS[node_def.op].is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
1888
+
1889
+
1890
def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
  """Returns an exponential distribution initializer.

  Args:
    minval: float or a scalar float Tensor. With value > 0. Lower bound of the
      range of random values to generate.
    maxval: float or a scalar float Tensor. With value > minval. Upper bound of
      the range of random values to generate.
    seed: An integer. Used to create random seeds.
    dtype: The data type.

  Returns:
    An initializer that generates tensors with an exponential distribution.
  """

  def _initializer(shape, dtype=dtype, partition_info=None):
    # partition_info is part of the initializer signature but unused here.
    del partition_info
    # Sample uniformly in log-space and exponentiate, so the logarithms of
    # the generated values are uniform over [log(minval), log(maxval)].
    log_low = math_ops.log(minval)
    log_high = math_ops.log(maxval)
    log_samples = random_ops.random_uniform(
        shape, log_low, log_high, dtype, seed=seed)
    return math_ops.exp(log_samples)

  return _initializer
1913
+
1914
+
1915
class PhasedLSTMCell(rnn_cell_impl.RNNCell):
  """Phased LSTM recurrent network cell.

  https://arxiv.org/pdf/1610.09513v1.pdf
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               leak=0.001,
               ratio_on=0.1,
               trainable_ratio_on=True,
               period_init_min=1.0,
               period_init_max=1000.0,
               reuse=None):
    """Initialize the Phased LSTM cell.

    Args:
      num_units: int, The number of units in the Phased LSTM cell.
      use_peepholes: bool, set True to enable peephole connections.
      leak: float or scalar float Tensor with value in [0, 1]. Leak applied
        during training.
      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
        period during which the gates are open.
      trainable_ratio_on: bool, whether ratio_on is trainable.
      period_init_min: float or scalar float Tensor. With value > 0.
        Minimum value of the initialized period.
        The period values are initialized by drawing from the distribution:
        e^U(log(period_init_min), log(period_init_max))
        Where U(.,.) is the uniform distribution.
      period_init_max: float or scalar float Tensor.
        With value > period_init_min. Maximum value of the initialized period.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    # We pass autocast=False because this layer can accept inputs of different
    # dtypes, so we do not want to automatically cast them to the same dtype.
    super(PhasedLSTMCell, self).__init__(_reuse=reuse, autocast=False)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._leak = leak
    self._ratio_on = ratio_on
    self._trainable_ratio_on = trainable_ratio_on
    self._period_init_min = period_init_min
    self._period_init_max = period_init_max
    self._reuse = reuse
    # _Linear layers are built lazily on the first call() so their input
    # sizes can be inferred from the actual inputs; cached thereafter.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def _mod(self, x, y):
    """Modulo function that propagates x gradients."""
    # The mod itself is wrapped in stop_gradient; adding x back makes the
    # whole expression equal mod(x, y) in value while its gradient w.r.t.
    # x is that of the identity.
    return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x

  def _get_cycle_ratio(self, time, phase, period):
    """Compute the cycle ratio in the dtype of the time."""
    # Cast phase/period to time's dtype so timestamps may be float64 even
    # when the cell variables are float32; the result is cast back.
    phase_casted = math_ops.cast(phase, dtype=time.dtype)
    period_casted = math_ops.cast(period, dtype=time.dtype)
    shifted_time = time - phase_casted
    cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
    return math_ops.cast(cycle_ratio, dtype=dtypes.float32)

  def call(self, inputs, state):
    """Phased LSTM Cell.

    Args:
      inputs: A tuple of 2 Tensor.
        The first Tensor has shape [batch, 1], and type float32 or float64.
        It stores the time.
        The second Tensor has shape [batch, features_size], and type float32.
        It stores the features.
      state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.

    Returns:
      A tuple containing:
      - A Tensor of float32, and shape [batch_size, num_units], representing the
        output of the cell.
      - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
        [batch_size, num_units], representing the new state and the output.
    """
    (c_prev, h_prev) = state
    (time, x) = inputs

    # Input/forget gates see [x, h_prev] (+ c_prev when peepholes are on).
    in_mask_gates = [x, h_prev]
    if self._use_peepholes:
      in_mask_gates.append(c_prev)

    with vs.variable_scope("mask_gates"):
      if self._linear1 is None:
        self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)

      mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
      [input_gate, forget_gate] = array_ops.split(
          axis=1, num_or_size_splits=2, value=mask_gates)

    with vs.variable_scope("new_input"):
      if self._linear2 is None:
        self._linear2 = _Linear([x, h_prev], self._num_units, True)
      new_input = math_ops.tanh(self._linear2([x, h_prev]))

    # Candidate cell state before the time gate is applied.
    new_c = (c_prev * forget_gate + input_gate * new_input)

    # Output gate peeks at the candidate new_c (not c_prev) when peepholes
    # are enabled.
    in_out_gate = [x, h_prev]
    if self._use_peepholes:
      in_out_gate.append(new_c)

    with vs.variable_scope("output_gate"):
      if self._linear3 is None:
        self._linear3 = _Linear(in_out_gate, self._num_units, True)
      output_gate = math_ops.sigmoid(self._linear3(in_out_gate))

    new_h = math_ops.tanh(new_c) * output_gate

    # Per-unit oscillation parameters of the time gate: periods are
    # log-uniform in [period_init_min, period_init_max], phases uniform in
    # [0, period).
    period = vs.get_variable(
        "period", [self._num_units],
        initializer=_random_exp_initializer(self._period_init_min,
                                            self._period_init_max))
    phase = vs.get_variable(
        "phase", [self._num_units],
        initializer=init_ops.random_uniform_initializer(0.,
                                                        period.initial_value))
    ratio_on = vs.get_variable(
        "ratio_on", [self._num_units],
        initializer=init_ops.constant_initializer(self._ratio_on),
        trainable=self._trainable_ratio_on)

    cycle_ratio = self._get_cycle_ratio(time, phase, period)

    # Piecewise openness k of the time gate: rises (k_up) in the first half
    # of the open phase, falls (k_down) in the second half, and leaks
    # (k_closed) while closed. Selection is done with two nested where()s.
    k_up = 2 * cycle_ratio / ratio_on
    k_down = 2 - k_up
    k_closed = self._leak * cycle_ratio

    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)

    # Blend the candidate state with the previous state by gate openness.
    new_c = k * new_c + (1 - k) * c_prev
    new_h = k * new_h + (1 - k) * h_prev

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)

    return new_h, new_state
+
2066
+
2067
class ConvLSTMCell(rnn_cell_impl.RNNCell):
  """Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self,
               conv_ndims,
               input_shape,
               output_channels,
               kernel_shape,
               use_bias=True,
               skip_connection=False,
               forget_bias=1.0,
               initializers=None,
               name="conv_lstm_cell"):
    """Construct ConvLSTMCell.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as int tuple, excluding the batch size.
      output_channels: int, number of output channels of the conv LSTM.
      kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3).
      use_bias: (bool) Use bias in convolutions.
      skip_connection: If set to `True`, concatenate the input to the
        output of the conv LSTM. Default: `False`.
      forget_bias: Forget bias.
      initializers: Unused.
      name: Name of the module.

    Raises:
      ValueError: If `skip_connection` is `True` and stride is different from 1
        or if `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTMCell, self).__init__(name=name)

    # input_shape's last entry is the channel axis, so a valid shape has
    # exactly conv_ndims spatial dims plus one channel dim.
    if conv_ndims != len(input_shape) - 1:
      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
          input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = input_shape
    self._output_channels = output_channels
    self._kernel_shape = list(kernel_shape)
    self._use_bias = use_bias
    self._forget_bias = forget_bias
    self._skip_connection = skip_connection

    # With a skip connection the input channels are concatenated onto the
    # output, widening the output's channel axis.
    self._total_output_channels = output_channels
    if self._skip_connection:
      self._total_output_channels += self._input_shape[-1]

    # State keeps the spatial dims of the input with output_channels
    # channels; both LSTM state tensors (c, h) share that shape.
    state_size = tensor_shape.TensorShape(
        self._input_shape[:-1] + [self._output_channels])
    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
    self._output_size = tensor_shape.TensorShape(
        self._input_shape[:-1] + [self._total_output_channels])

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state, scope=None):
    """Run one ConvLSTM step: convolve [inputs, hidden], gate, update state."""
    cell, hidden = state
    # One convolution produces all four gates stacked on the channel axis.
    new_hidden = _conv([inputs, hidden], self._kernel_shape,
                       4 * self._output_channels, self._use_bias)
    # Channel axis is conv_ndims + 1 (batch + spatial dims precede it).
    gates = array_ops.split(
        value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)

    input_gate, new_input, forget_gate, output_gate = gates
    # Standard LSTM update with forget-gate bias.
    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)

    if self._skip_connection:
      output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state
+
2150
+
2151
class Conv1DLSTMCell(ConvLSTMCell):
  """1D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_1d_lstm_cell", **kwargs):
    """Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
    # Thin wrapper: fixes conv_ndims=1 and forwards everything else.
    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs)
+
2161
+
2162
class Conv2DLSTMCell(ConvLSTMCell):
  """2D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
    """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
    # Thin wrapper: fixes conv_ndims=2 and forwards everything else.
    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)
+
2172
+
2173
class Conv3DLSTMCell(ConvLSTMCell):
  """3D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
    """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
    # Thin wrapper: fixes conv_ndims=3 and forwards everything else.
    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)
+
2183
+
2184
def _conv(args, filter_size, num_features, bias, bias_start=0.0):
  """Convolution.

  Args:
    args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
    batch x n, Tensors.
    filter_size: int tuple of filter shape (of size 1, 2 or 3).
    num_features: int, number of features.
    bias: Whether to use biases in the convolution layer.
    bias_start: starting value to initialize the bias; 0 by default.

  Returns:
    A 3D, 4D, or 5D Tensor with shape [batch ... num_features]

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """

  # Validate ranks and accumulate the total channel depth across args.
  shapes = [a.get_shape().as_list() for a in args]
  shape_length = len(shapes[0])
  total_arg_size_depth = 0
  for shape in shapes:
    if len(shape) not in [3, 4, 5]:
      raise ValueError("Conv Linear expects 3D, 4D "
                       "or 5D arguments: %s" % str(shapes))
    if len(shape) != shape_length:
      raise ValueError("Conv Linear expects all args "
                       "to be of same Dimension: %s" % str(shapes))
    else:
      total_arg_size_depth += shape[-1]
  dtype = args[0].dtype

  # Pick the conv op matching the tensor rank (1D/2D/3D convolution).
  if shape_length == 3:
    conv_op, strides = nn_ops.conv1d, 1
  elif shape_length == 4:
    conv_op, strides = nn_ops.conv2d, shape_length * [1]
  elif shape_length == 5:
    conv_op, strides = nn_ops.conv3d, shape_length * [1]

  # Now the computation.
  kernel = vs.get_variable(
      "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
  if len(args) == 1:
    conv_input = args[0]
  else:
    # Concatenate along the channel axis so one conv covers all args.
    conv_input = array_ops.concat(axis=shape_length - 1, values=args)
  res = conv_op(conv_input, kernel, strides, padding="SAME")
  if not bias:
    return res
  bias_term = vs.get_variable(
      "biases", [num_features],
      dtype=dtype,
      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
  return res + bias_term
+
2247
+
2248
class GLSTMCell(rnn_cell_impl.RNNCell):
  """Group LSTM cell (G-LSTM).

  The implementation is based on:

    https://arxiv.org/abs/1703.10722

  O. Kuchaiev and B. Ginsburg
  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.

  In brief, a G-LSTM cell consists of one LSTM sub-cell per group, where each
  sub-cell operates on an evenly-sized sub-vector of the input and produces an
  evenly-sized sub-vector of the output.  For example, a G-LSTM cell with 128
  units and 4 groups consists of 4 LSTMs sub-cells with 32 units each.  If that
  G-LSTM cell is fed a 200-dim input, then each sub-cell receives a 50-dim part
  of the input and produces a 32-dim part of the output.
  """

  def __init__(self,
               num_units,
               initializer=None,
               num_proj=None,
               number_of_groups=1,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters of G-LSTM cell.

    Args:
      num_units: int, The number of units in the G-LSTM cell
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, then it should be equivalent to LSTM cell
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      ValueError: If `num_units` or `num_proj` is not divisible by
        `number_of_groups`.
    """
    super(GLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._num_proj = num_proj
    self._forget_bias = forget_bias
    self._activation = activation
    self._number_of_groups = number_of_groups

    if self._num_units % self._number_of_groups != 0:
      raise ValueError("num_units must be divisible by number_of_groups")
    if self._num_proj:
      if self._num_proj % self._number_of_groups != 0:
        raise ValueError("num_proj must be divisible by number_of_groups")
      # _group_shape is [per-group output size, per-group unit count].
      self._group_shape = [
          int(self._num_proj / self._number_of_groups),
          int(self._num_units / self._number_of_groups)
      ]
    else:
      self._group_shape = [
          int(self._num_units / self._number_of_groups),
          int(self._num_units / self._number_of_groups)
      ]

    if num_proj:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units
    # One lazily-built _Linear per group (built on first call()), plus one
    # optional shared projection layer.
    self._linear1 = [None] * number_of_groups
    self._linear2 = None

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices inputs into groups to prepare for processing by cell's groups.

    Args:
      inputs: cell input or it's previous state,
              a Tensor, 2D, [batch x num_units]
      group_id: group id, a Scalar, for which to prepare input
      group_size: size of the group

    Returns:
      subset of inputs corresponding to group "group_id",
      a Tensor, 2D, [batch x num_units/number_of_groups]
    """
    # Relies on self._batch_size, which call() assigns before invoking this.
    return array_ops.slice(
        input_=inputs,
        begin=[0, group_id * group_size],
        size=[self._batch_size, group_size],
        name=("GLSTM_group%d_input_generation" % group_id))

  def call(self, inputs, state):
    """Run one step of G-LSTM.

    Args:
      inputs: input Tensor, 2D, [batch x num_inputs].  num_inputs must be
        statically-known and evenly divisible into groups.  The innermost
        vectors of the inputs are split into evenly-sized sub-vectors and fed
        into the per-group LSTM sub-cells.
      state: this must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        G-LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - LSTMStateTuple representing the new state of G-LSTM cell
        after reading `inputs` when the previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference, or if the input shape is incompatible
        with the number of groups.
    """
    (c_prev, m_prev) = state

    # Prefer the static batch size; fall back to a dynamic shape op. Stored
    # on self because _get_input_for_group reads it.
    self._batch_size = tensor_shape.dimension_value(
        inputs.shape[0]) or array_ops.shape(inputs)[0]

    # If the input size is statically-known, calculate and validate its group
    # size. Otherwise, use the output group size.
    input_size = tensor_shape.dimension_value(inputs.shape[1])
    if input_size is None:
      raise ValueError("input size must be statically known")
    if input_size % self._number_of_groups != 0:
      raise ValueError(
          "input size (%d) must be divisible by number_of_groups (%d)" %
          (input_size, self._number_of_groups))
    input_group_size = int(input_size / self._number_of_groups)

    dtype = inputs.dtype
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer):
      i_parts = []
      j_parts = []
      f_parts = []
      o_parts = []

      # Each group gets its own variable scope and its own _Linear producing
      # all four gate pre-activations for that group's sub-vector.
      for group_id in range(self._number_of_groups):
        with vs.variable_scope("group%d" % group_id):
          x_g_id = array_ops.concat(
              [
                  self._get_input_for_group(inputs, group_id, input_group_size),
                  self._get_input_for_group(m_prev, group_id,
                                            self._group_shape[0])
              ],
              axis=1)
          linear = self._linear1[group_id]
          if linear is None:
            linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
            self._linear1[group_id] = linear
          R_k = linear(x_g_id)  # pylint: disable=invalid-name
          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)

        i_parts.append(i_k)
        j_parts.append(j_k)
        f_parts.append(f_k)
        o_parts.append(o_k)

      # Biases are shared across groups: one full-width bias per gate,
      # added after the per-group parts are concatenated back together.
      bi = vs.get_variable(
          name="bias_i",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bj = vs.get_variable(
          name="bias_j",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bf = vs.get_variable(
          name="bias_f",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bo = vs.get_variable(
          name="bias_o",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))

      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)

    # Standard LSTM state update over the re-assembled full-width gates.
    c = (
        math_ops.sigmoid(f + self._forget_bias) * c_prev +
        math_ops.sigmoid(i) * math_ops.tanh(j))
    m = math_ops.sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      with vs.variable_scope("projection"):
        if self._linear2 is None:
          self._linear2 = _Linear(m, self._num_proj, False)
        m = self._linear2(m)

    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
    return m, new_state
+
2467
+
2468
class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf

  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The class uses optional peep-hole connections, optional cell clipping, and
  an optional projection layer.

  Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.

  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               forget_bias=1.0,
               activation=None,
               layer_norm=False,
               norm_gain=1.0,
               norm_shift=0.0,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training. Must set it manually to `0.0` when restoring from
        CudnnLSTM trained checkpoints.
      activation: Activation function of the inner states.  Default: `tanh`.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(LayerNormLSTMCell, self).__init__(_reuse=reuse)

    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._forget_bias = forget_bias
    self._activation = activation or math_ops.tanh
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift

    if num_proj:
      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj))
      self._output_size = num_proj
    else:
      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units))
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _linear(self,
              args,
              output_size,
              bias,
              bias_initializer=None,
              kernel_initializer=None,
              layer_norm=False):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable.

    Args:
      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
      output_size: int, second dimension of W[i].
      bias: boolean, whether to add a bias term or not.
      bias_initializer: starting value to initialize the bias
        (default is all zeros).
      kernel_initializer: starting value to initialize the weight.
      layer_norm: boolean, whether to apply layer normalization.


    Returns:
      A 2D Tensor with shape [batch x output_size] taking value
      sum_i(args[i] * W[i]), where each W[i] is a newly created Variable.

    Raises:
      ValueError: if some of the arguments has unspecified or wrong shape.
    """
    if args is None or (nest.is_sequence(args) and not args):
      raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
      args = [args]

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
      if shape.ndims != 2:
        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
      if tensor_shape.dimension_value(shape[1]) is None:
        raise ValueError("linear expects shape[1] to be provided for shape %s, "
                         "but saw %s" % (shape, shape[1]))
      else:
        total_arg_size += tensor_shape.dimension_value(shape[1])

    dtype = [a.dtype for a in args][0]

    # Now the computation.
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
      weights = vs.get_variable(
          "kernel", [total_arg_size, output_size],
          dtype=dtype,
          initializer=kernel_initializer)
      if len(args) == 1:
        res = math_ops.matmul(args[0], weights)
      else:
        res = math_ops.matmul(array_ops.concat(args, 1), weights)
      if not bias:
        return res
      # The bias lives in an unpartitioned sub-scope so it is never sharded.
      with vs.variable_scope(outer_scope) as inner_scope:
        inner_scope.set_partitioner(None)
        if bias_initializer is None:
          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
        biases = vs.get_variable(
            "bias", [output_size], dtype=dtype, initializer=bias_initializer)

      # NOTE(review): when layer_norm is True the bias variable is still
      # created above but never added here — presumably the shift is meant to
      # come from the normalization's own parameters applied later; confirm.
      if not layer_norm:
        res = nn_ops.bias_add(res, biases)

    return res

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: this must be a tuple of state Tensors,
       both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    (c_prev, m_prev) = state

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      # One fused matmul produces all four gate pre-activations; bias_add is
      # skipped inside _linear when layer normalization is enabled.
      lstm_matrix = self._linear(
          [inputs, m_prev],
          4 * self._num_units,
          bias=True,
          bias_initializer=None,
          layer_norm=self._layer_norm)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      # Normalize each gate pre-activation separately, before the
      # nonlinearities are applied.
      if self._layer_norm:
        i = _norm(self._norm_gain, self._norm_shift, i, "input")
        j = _norm(self._norm_gain, self._norm_shift, j, "transform")
        f = _norm(self._norm_gain, self._norm_shift, f, "forget")
        o = _norm(self._norm_gain, self._norm_shift, o, "output")

      # Diagonal connections
      if self._use_peepholes:
        with vs.variable_scope(unit_scope):
          w_f_diag = vs.get_variable(
              "w_f_diag", shape=[self._num_units], dtype=dtype)
          w_i_diag = vs.get_variable(
              "w_i_diag", shape=[self._num_units], dtype=dtype)
          w_o_diag = vs.get_variable(
              "w_o_diag", shape=[self._num_units], dtype=dtype)

      if self._use_peepholes:
        c = (
            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
            sigmoid(i + w_i_diag * c_prev) * self._activation(j))
      else:
        c = (
            sigmoid(f + self._forget_bias) * c_prev +
            sigmoid(i) * self._activation(j))

      if self._layer_norm:
        c = _norm(self._norm_gain, self._norm_shift, c, "state")

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type
      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        with vs.variable_scope("projection"):
          m = self._linear(m, self._num_proj, bias=False)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (rnn_cell_impl.LSTMStateTuple(c, m))
    return m, new_state
+
2737
+
2738
class SRUCell(rnn_cell_impl.LayerRNNCell):
  """SRU, Simple Recurrent Unit.

  Implementation based on
  Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).

  This variation of RNN cell is characterized by the simplified data
  dependence
  between hidden states of two consecutive time steps. Traditionally, hidden
  states from a cell at time step t-1 needs to be multiplied with a matrix
  W_hh before being fed into the ensuing cell at time step t.
  This flavor of RNN replaces the matrix multiplication between h_{t-1}
  and W_hh with a pointwise multiplication, resulting in performance
  gain.

  Args:
    num_units: int, The number of units in the SRU cell.
    activation: Nonlinearity to use. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope. If not `True`, and the existing scope already has
      the given variables, an error is raised.
    name: (optional) String, the name of the layer. Layers with the same name
      will share weights, but to avoid mistakes we require reuse=True in such
      cases.
    **kwargs: Additional keyword arguments.
  """

  def __init__(self, num_units, activation=None, reuse=None, name=None,
               **kwargs):
    super(SRUCell, self).__init__(_reuse=reuse, name=name, **kwargs)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh

    # Restrict inputs to be 2-dimensional matrices
    self.input_spec = input_spec.InputSpec(ndim=2)

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Create the SRU weights.

    One fused kernel produces the four per-step projections (x_bar, f, r,
    x_tx); the single bias covers only the two gates (f and r), so its size
    is `2 * num_units` rather than 4.

    Args:
      inputs_shape: `TensorShape` of the inputs; last dimension must be known.

    Raises:
      ValueError: if the input depth (inputs_shape[-1]) is not statically
        known.
    """
    if tensor_shape.dimension_value(inputs_shape[1]) is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = tensor_shape.dimension_value(inputs_shape[1])

    # pylint: disable=protected-access
    self._kernel = self.add_variable(
        rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 4 * self._num_units])
    # pylint: enable=protected-access
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,  # pylint: disable=protected-access
        shape=[2 * self._num_units],
        initializer=init_ops.zeros_initializer)

    # Fix: set the public `built` attribute (Layer API contract), not a
    # private `_built`, consistent with the other cells in this file
    # (IndRNNCell, IndyGRUCell, IndyLSTMCell).
    self.built = True

  def call(self, inputs, state):
    """Simple recurrent unit (SRU) with num_units cells."""

    # Single fused matmul for all four projections of the input.
    U = math_ops.matmul(inputs, self._kernel)  # pylint: disable=invalid-name
    x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split(
        value=U, num_or_size_splits=4, axis=1)

    # Forget gate f and reset gate r share one bias_add/sigmoid pass.
    f_r = math_ops.sigmoid(
        nn_ops.bias_add(
            array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
    f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)

    # Cell state depends on the previous state only elementwise (no W_hh).
    c = f * state + (1.0 - f) * x_bar
    # Highway-style output mixing the activated state with the raw input
    # projection x_tx.
    h = r * self._activation(c) + (1.0 - r) * x_tx

    return h, c
+
2818
+
2819
class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
  """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.

  The weight-norm implementation is based on:
  https://arxiv.org/abs/1602.07868
  Tim Salimans, Diederik P. Kingma.
  Weight Normalization: A Simple Reparameterization to Accelerate
  Training of Deep Neural Networks

  The default LSTM implementation based on:

  https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf

  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.

  The class uses optional peephole connections, optional cell clipping
  and an optional projection layer.

  The optional peephole implementation is based on:
  https://research.google.com/pubs/archive/43905.pdf
  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
  large scale acoustic modeling." INTERSPEECH, 2014.
  """

  def __init__(self,
               num_units,
               norm=True,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               forget_bias=1,
               activation=None,
               reuse=None):
    """Initialize the parameters of a weight-normalized LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      norm: If `True`, apply normalization to the weight matrices. If False,
        the result is identical to that obtained from `rnn_cell_impl.LSTMCell`
      use_peepholes: bool, set `True` to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(WeightNormLSTMCell, self).__init__(_reuse=reuse)

    self._scope = "wn_lstm_cell"
    self._num_units = num_units
    self._norm = norm
    self._initializer = initializer
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._activation = activation or math_ops.tanh
    self._forget_bias = forget_bias

    # Names used for the variables created in _linear; kept identical to
    # rnn_cell_impl so checkpoints are interchangeable when norm=False.
    self._weights_variable_name = "kernel"
    self._bias_variable_name = "bias"

    # With a projection layer, the emitted state m has num_proj columns
    # while the cell state c keeps num_units columns.
    if num_proj:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    # An LSTMStateTuple (c, h); see __init__ for the projected case.
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _normalize(self, weight, name):
    """Apply weight normalization.

    Args:
      weight: a 2D tensor with known number of columns.
      name: string, variable name for the normalizer.
    Returns:
      A tensor with the same shape as `weight`.
    """

    # w = g * v / ||v|| per output column, with g a learned per-column scale.
    output_size = weight.get_shape().as_list()[1]
    g = vs.get_variable(name, [output_size], dtype=weight.dtype)
    return nn_impl.l2_normalize(weight, axis=0) * g

  def _linear(self,
              args,
              output_size,
              norm,
              bias,
              bias_initializer=None,
              kernel_initializer=None):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

    Args:
      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
      output_size: int, second dimension of W[i].
      norm: bool, whether to normalize the weights.
      bias: boolean, whether to add a bias term or not.
      bias_initializer: starting value to initialize the bias
        (default is all zeros).
      kernel_initializer: starting value to initialize the weight.

    Returns:
      A 2D Tensor with shape [batch x output_size] equal to
      sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

    Raises:
      ValueError: if some of the arguments has unspecified or wrong shape.
    """
    if args is None or (nest.is_sequence(args) and not args):
      raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
      args = [args]

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
      if shape.ndims != 2:
        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
      if tensor_shape.dimension_value(shape[1]) is None:
        raise ValueError("linear expects shape[1] to be provided for shape %s, "
                         "but saw %s" % (shape, shape[1]))
      else:
        total_arg_size += tensor_shape.dimension_value(shape[1])

    dtype = [a.dtype for a in args][0]

    # Now the computation.
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
      # One fused weight matrix covers all args stacked along dim 0.
      weights = vs.get_variable(
          self._weights_variable_name, [total_arg_size, output_size],
          dtype=dtype,
          initializer=kernel_initializer)
      if norm:
        # Normalize each arg's sub-matrix (row slice) independently, with
        # its own scale variable norm_{i}.
        wn = []
        st = 0
        with ops.control_dependencies(None):
          for i in range(len(args)):
            en = st + tensor_shape.dimension_value(shapes[i][1])
            wn.append(
                self._normalize(weights[st:en, :], name="norm_{}".format(i)))
            st = en

          weights = array_ops.concat(wn, axis=0)

      if len(args) == 1:
        res = math_ops.matmul(args[0], weights)
      else:
        res = math_ops.matmul(array_ops.concat(args, 1), weights)
      if not bias:
        return res

      # Biases are never partitioned, hence the partitioner reset.
      with vs.variable_scope(outer_scope) as inner_scope:
        inner_scope.set_partitioner(None)
        if bias_initializer is None:
          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)

        biases = vs.get_variable(
            self._bias_variable_name, [output_size],
            dtype=dtype,
            initializer=bias_initializer)

        return nn_ops.bias_add(res, biases)

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: A tuple of state Tensors, both `2-D`, with column sizes
        `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    dtype = inputs.dtype
    num_units = self._num_units
    sigmoid = math_ops.sigmoid
    c, h = state

    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(self._scope, initializer=self._initializer):

      # Single fused linear map producing all four gates at once.
      concat = self._linear(
          [inputs, h], 4 * num_units, norm=self._norm, bias=True)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      if self._use_peepholes:
        # Diagonal peephole weights let the gates see the cell state.
        w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
        w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
        w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)

        new_c = (
            c * sigmoid(f + self._forget_bias + w_f_diag * c) +
            sigmoid(i + w_i_diag * c) * self._activation(j))
      else:
        new_c = (
            c * sigmoid(f + self._forget_bias) +
            sigmoid(i) * self._activation(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type
      if self._use_peepholes:
        # Output peephole reads the *new* (post-clip) cell state.
        new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
      else:
        new_h = sigmoid(o) * self._activation(new_c)

      if self._num_proj is not None:
        with vs.variable_scope("projection"):
          new_h = self._linear(
              new_h, self._num_proj, norm=self._norm, bias=False)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
                                         self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

      new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
      return new_h, new_state
+
3082
+
3083
class IndRNNCell(rnn_cell_impl.LayerRNNCell):
  """Independently Recurrent Neural Network (IndRNN) cell
  (cf. https://arxiv.org/abs/1803.04831).

  Each unit's recurrence depends only on its own previous state: the dense
  recurrent weight matrix of a vanilla RNN is replaced by an elementwise
  (Hadamard) product with a learned vector `u`.

  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope. If not `True`, and the existing scope already has
      the given variables, an error is raised.
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such
      cases.
    dtype: Default dtype of the layer (default of `None` means use the type
      of the first input). Required when `build` is called before `call`.
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Only 2-D [batch, features] inputs are accepted.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Create the input kernel, recurrent vector and bias."""
    depth = tensor_shape.dimension_value(inputs_shape[1])
    if depth is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    # Variable names and creation order are kept exactly as before for
    # checkpoint compatibility.
    # pylint: disable=protected-access
    self._kernel_w = self.add_variable(
        "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[depth, self._num_units])
    self._kernel_u = self.add_variable(
        "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """IndRNN: output = new_state = act(W * input + u * state + B)."""

    preactivation = math_ops.matmul(inputs, self._kernel_w)
    # Recurrent term is a pure elementwise product — no matrix multiply.
    preactivation += state * self._kernel_u
    new_h = self._activation(nn_ops.bias_add(preactivation, self._bias))
    return new_h, new_h
+
3154
+
3155
class IndyGRUCell(rnn_cell_impl.LayerRNNCell):
  r"""Independently Gated Recurrent Unit cell.

  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell,
  yet with the \\(U_r\\), \\(U_z\\), and \\(U\\) matrices in equations 5, 6, and
  8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal
  matrices, i.e. a Hadamard product with a single vector:

    $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j +
      [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$
    $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j +
      [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$
    $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j +
      [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$

  where \\(\circ\\) denotes the Hadamard operator. This means that each IndyGRU
  node sees only its own state, as opposed to seeing all states in the same
  layer.

  Args:
    num_units: int, The number of units in the GRU cell.
    activation: Nonlinearity to use. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope. If not `True`, and the existing scope already has
      the given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight
      matrices applied to the input.
    bias_initializer: (optional) The initializer to use for the bias.
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such
      cases.
    dtype: Default dtype of the layer (default of `None` means use the type
      of the first input). Required when `build` is called before `call`.
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None,
               dtype=None):
    super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Create gate and candidate weights; requires a static input depth."""
    if tensor_shape.dimension_value(inputs_shape[1]) is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = tensor_shape.dimension_value(inputs_shape[1])
    # The r and z gates are fused: kernels/bias carry 2 * num_units columns.
    # pylint: disable=protected-access
    self._gate_kernel_w = self.add_variable(
        "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 2 * self._num_units],
        initializer=self._kernel_initializer)
    self._gate_kernel_u = self.add_variable(
        "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, 2 * self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._gate_bias = self.add_variable(
        "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[2 * self._num_units],
        # Gate bias defaults to 1.0 (less resetting/updating early in
        # training), unlike the candidate bias which defaults to zeros.
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.constant_initializer(1.0, dtype=self.dtype)))
    # NOTE(review): this name lacks the "_w" suffix used by "gates/%s_w"
    # above, so the variable is named "candidate/kernel". It does not collide
    # with the "candidate/bias" below, but the asymmetry looks accidental —
    # renaming would break existing checkpoints, so it is left as-is.
    self._candidate_kernel_w = self.add_variable(
        "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, self._num_units],
        initializer=self._kernel_initializer)
    self._candidate_kernel_u = self.add_variable(
        "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._candidate_bias = self.add_variable(
        "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Recurrently independent Gated Recurrent Unit (GRU) with nunits cells."""

    # tile(state, [1, 2]) aligns the state with the fused [r, z] kernel so a
    # single elementwise product covers both gates.
    gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + (
        gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    # Reset gate scales the previous state before the candidate recurrence.
    r_state = r * state

    candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + (
        r_state * self._candidate_kernel_u)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)
    # Standard GRU interpolation between old state and candidate.
    new_h = u * state + (1 - u) * c
    return new_h, new_h
+
3277
+
3278
class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
  r"""Basic IndyLSTM recurrent network cell.

  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to
  BasicLSTMCell, yet with the \\(U_f\\), \\(U_i\\), \\(U_o\\) and \\(U_c\\)
  matrices in the regular LSTM equations replaced by diagonal matrices, i.e. a
  Hadamard product with a single vector:

    $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$
    $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$
    $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$
    $$c_t = f_t \circ c_{t-1} +
            i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$

  where \\(\circ\\) denotes the Hadamard operator. This means that each IndyLSTM
  node sees only its own state \\(h\\) and \\(c\\), as opposed to seeing all
  states in the same layer.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For a detailed analysis of IndyLSTMs, see https://arxiv.org/abs/1903.08023.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None,
               dtype=None):
    """Initialize the IndyLSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      activation: Activation function of the inner states. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      kernel_initializer: (optional) The initializer to use for the weight
        matrix applied to the inputs.
      bias_initializer: (optional) The initializer to use for the bias.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
    """
    super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    # (c, h) tuple, each of size num_units.
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Create fused i/j/f/o weights; requires a static input depth."""
    if tensor_shape.dimension_value(inputs_shape[1]) is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = tensor_shape.dimension_value(inputs_shape[1])
    # All four gates are fused into single kernels/bias of width
    # 4 * num_units; the recurrent kernel is a vector (diagonal matrix).
    # pylint: disable=protected-access
    self._kernel_w = self.add_variable(
        "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 4 * self._num_units],
        initializer=self._kernel_initializer)
    self._kernel_u = self.add_variable(
        "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, 4 * self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Independent Long short-term memory cell (IndyLSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (a
        `LSTMStateTuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    c, h = state

    gate_inputs = math_ops.matmul(inputs, self._kernel_w)
    # tile(h, [1, 4]) aligns h with the fused 4*num_units recurrent vector.
    gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(
        multiply(c, sigmoid(add(f, forget_bias_tensor))),
        multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
+
3416
+
3417
# State threaded between NTMCell steps: the wrapped controller's own state,
# the vectors last read from memory, the addressing weights per head, the
# memory matrix M, and the current time step.
NTMControllerState = collections.namedtuple(
    "NTMControllerState",
    "controller_state read_vector_list w_list M time")
3420
+
3421
+
3422
+ class NTMCell(rnn_cell_impl.LayerRNNCell):
3423
+ """Neural Turing Machine Cell with RNN controller.
3424
+
3425
+ Implementation based on:
3426
+ https://arxiv.org/abs/1807.08518
3427
+ Mark Collier, Joeran Beel
3428
+
3429
+ which is in turn based on the source code of:
3430
+ https://github.com/snowkylin/ntm
3431
+
3432
+ and of course the original NTM paper:
3433
+ Neural Turing Machines
3434
+ https://arxiv.org/abs/1410.5401
3435
+ A Graves, G Wayne, I Danihelka
3436
+ """
3437
+
3438
+ def __init__(self,
3439
+ controller,
3440
+ memory_size,
3441
+ memory_vector_dim,
3442
+ read_head_num,
3443
+ write_head_num,
3444
+ shift_range=1,
3445
+ output_dim=None,
3446
+ clip_value=20,
3447
+ dtype=dtypes.float32,
3448
+ name=None):
3449
+ """Initialize the NTM Cell.
3450
+
3451
+ Args:
3452
+ controller: an RNNCell, the RNN controller.
3453
+ memory_size: int, The number of memory locations in the NTM memory
3454
+ matrix
3455
+ memory_vector_dim: int, The dimensionality of each location in the NTM
3456
+ memory matrix
3457
+ read_head_num: int, The number of read heads from the controller into
3458
+ memory
3459
+ write_head_num: int, The number of write heads from the controller into
3460
+ memory
3461
+ shift_range: int, The number of places to the left/right it is possible
3462
+ to iterate the previous address to in a single step
3463
+ output_dim: int, The number of dimensions to make a linear projection of
3464
+ the NTM controller outputs to. If None, no linear projection is
3465
+ applied
3466
+ clip_value: float, The maximum absolute value the controller parameters
3467
+ are clipped to
3468
+ dtype: Default dtype of the layer (default of `None` means use the type
3469
+ of the first input). Required when `build` is called before `call`.
3470
+ name: String, the name of the layer. Layers with the same name will
3471
+ share weights, but to avoid mistakes we require reuse=True in such
3472
+ cases.
3473
+ """
3474
+ super(NTMCell, self).__init__(dtype=dtype, name=name)
3475
+
3476
+ rnn_cell_impl.assert_like_rnncell("NTM RNN controller cell", controller)
3477
+
3478
+ self.controller = controller
3479
+ self.memory_size = memory_size
3480
+ self.memory_vector_dim = memory_vector_dim
3481
+ self.read_head_num = read_head_num
3482
+ self.write_head_num = write_head_num
3483
+ self.clip_value = clip_value
3484
+
3485
+ self.output_dim = output_dim
3486
+ self.shift_range = shift_range
3487
+
3488
+ self.num_parameters_per_head = (
3489
+ self.memory_vector_dim + 2 * self.shift_range + 4)
3490
+ self.num_heads = self.read_head_num + self.write_head_num
3491
+ self.total_parameter_num = (
3492
+ self.num_parameters_per_head * self.num_heads +
3493
+ self.memory_vector_dim * 2 * self.write_head_num)
3494
+
3495
+ @property
3496
+ def state_size(self):
3497
+ return NTMControllerState(
3498
+ controller_state=self.controller.state_size,
3499
+ read_vector_list=[
3500
+ self.memory_vector_dim for _ in range(self.read_head_num)
3501
+ ],
3502
+ w_list=[
3503
+ self.memory_size
3504
+ for _ in range(self.read_head_num + self.write_head_num)
3505
+ ],
3506
+ M=tensor_shape.TensorShape([self.memory_size * self.memory_vector_dim]),
3507
+ time=tensor_shape.TensorShape([]))
3508
+
3509
+ @property
3510
+ def output_size(self):
3511
+ return self.output_dim
3512
+
3513
  def build(self, inputs_shape):
    """Create the NTM cell's variables.

    Creates the controller-output -> head-parameter projection, the final
    output projection, learned initial read vectors / address weights, and
    the learned initial memory matrix.

    Args:
      inputs_shape: `TensorShape` of the inputs; `inputs_shape[1]` (the
        feature dimension) must be statically known when `output_dim` was
        not given at construction time.

    Raises:
      ValueError: if `output_dim` is unset and the input feature dimension
        is statically unknown.
    """
    if self.output_dim is None:
      # NOTE(review): `.value` is the legacy TF1 `Dimension` accessor; this
      # code predates `tensor_shape.dimension_value`.
      if inputs_shape[1].value is None:
        raise ValueError(
            "Expected inputs.shape[-1] to be known, saw shape: %s" %
            inputs_shape)
      else:
        # Default the output size to the input size.
        self.output_dim = inputs_shape[1].value

    # Truncated-normal initializer with stddev 1/sqrt(fan_in), matching the
    # scaling of a standard linear layer.
    def _create_linear_initializer(input_size, dtype=dtypes.float32):
      stddev = 1.0 / math.sqrt(input_size)
      return init_ops.truncated_normal_initializer(stddev=stddev, dtype=dtype)

    # Projects controller output to the concatenated head parameters.
    self._params_kernel = self.add_variable(
        "parameters_kernel",
        shape=[self.controller.output_size, self.total_parameter_num],
        initializer=_create_linear_initializer(self.controller.output_size))

    self._params_bias = self.add_variable(
        "parameters_bias",
        shape=[self.total_parameter_num],
        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))

    # Maps [controller_output; all read vectors] to the cell output.
    self._output_kernel = self.add_variable(
        "output_kernel",
        shape=[
            self.controller.output_size +
            self.memory_vector_dim * self.read_head_num, self.output_dim
        ],
        initializer=_create_linear_initializer(self.controller.output_size +
                                               self.memory_vector_dim *
                                               self.read_head_num))

    self._output_bias = self.add_variable(
        "output_bias",
        shape=[self.output_dim],
        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))

    # Learned time-0 read vectors, one per read head (used by `call` at
    # prev_state.time == 0 instead of the zero_state placeholders).
    self._init_read_vectors = [
        self.add_variable(
            "initial_read_vector_%d" % i,
            shape=[1, self.memory_vector_dim],
            initializer=initializers.glorot_uniform())
        for i in range(self.read_head_num)
    ]

    # Learned time-0 address weights, one per head (read heads first).
    self._init_address_weights = [
        self.add_variable(
            "initial_address_weights_%d" % i,
            shape=[1, self.memory_size],
            initializer=initializers.glorot_uniform())
        for i in range(self.read_head_num + self.write_head_num)
    ]

    # Learned initial memory; small constant keeps early cosine similarities
    # well-defined without dominating them.
    self._M = self.add_variable(
        "memory",
        shape=[self.memory_size, self.memory_vector_dim],
        initializer=init_ops.constant_initializer(1e-6, dtype=self.dtype))

    self.built = True
3573
+
3574
  def call(self, x, prev_state):
    """Run one NTM step: address memory, read, write, and emit an output.

    Args:
      x: `[batch, input_dim]` input tensor for this time step.
      prev_state: `NTMControllerState` from the previous step (or from
        `zero_state` at time 0).

    Returns:
      A `(output, new_state)` pair: `output` is `[batch, output_dim]`
      clipped to `[-clip_value, clip_value]`; `new_state` is an
      `NTMControllerState` with `time` incremented by 1.
    """
    # Addressing Mechanisms (Sec 3.3)

    # At time 0 the read vectors come from the learned initial variables
    # (tanh-squashed) rather than the zero_state placeholders.  The matmul
    # with ones([1, 1]) routes the variable through an op so it participates
    # in the cond branch; the result is then tiled across the batch.
    def _prev_read_vector_list_initial_value():
      return [
          self._expand(
              math_ops.tanh(
                  array_ops.squeeze(
                      math_ops.matmul(
                          array_ops.ones([1, 1]), self._init_read_vectors[i]))),
              dim=0,
              N=x.shape[0].value or array_ops.shape(x)[0])
          for i in range(self.read_head_num)
      ]

    prev_read_vector_list = control_flow_ops.cond(
        math_ops.equal(prev_state.time,
                       0), _prev_read_vector_list_initial_value, lambda:
        prev_state.read_vector_list)
    # cond() collapses a single-element list to a bare tensor; re-wrap it.
    if self.read_head_num == 1:
      prev_read_vector_list = [prev_read_vector_list]

    controller_input = array_ops.concat([x] + prev_read_vector_list, axis=1)
    controller_output, controller_state = self.controller(
        controller_input, prev_state.controller_state)

    # Project the controller output to all head parameters, then clip.
    parameters = math_ops.matmul(controller_output, self._params_kernel)
    parameters = nn_ops.bias_add(parameters, self._params_bias)
    parameters = clip_ops.clip_by_value(parameters, -self.clip_value,
                                        self.clip_value)
    # First chunk: per-head addressing parameters (k, beta, g, s, gamma).
    head_parameter_list = array_ops.split(
        parameters[:, :self.num_parameters_per_head * self.num_heads],
        self.num_heads,
        axis=1)
    # Remaining chunk: erase/add vectors, two slices per write head.
    erase_add_list = array_ops.split(
        parameters[:, self.num_parameters_per_head * self.num_heads:],
        2 * self.write_head_num,
        axis=1)

    # Same time-0 substitution for the address weights (softmax-normalized).
    def _prev_w_list_initial_value():
      return [
          self._expand(
              nn_ops.softmax(
                  array_ops.squeeze(
                      math_ops.matmul(
                          array_ops.ones([1, 1]),
                          self._init_address_weights[i]))),
              dim=0,
              N=x.shape[0].value or array_ops.shape(x)[0])
          for i in range(self.read_head_num + self.write_head_num)
      ]

    prev_w_list = control_flow_ops.cond(
        math_ops.equal(prev_state.time, 0),
        _prev_w_list_initial_value, lambda: prev_state.w_list)
    if (self.read_head_num + self.write_head_num) == 1:
      prev_w_list = [prev_w_list]

    # Memory likewise starts from the learned initial memory, tiled over
    # the batch dimension.
    prev_M = control_flow_ops.cond(
        math_ops.equal(prev_state.time, 0), lambda: self._expand(
            self._M, dim=0, N=x.shape[0].value or array_ops.shape(x)[0]),
        lambda: prev_state.M)

    w_list = []
    for i, head_parameter in enumerate(head_parameter_list):
      # Slice this head's parameters:
      #   k:     key vector            (memory_vector_dim)
      #   beta:  key strength          (1, softplus -> >= 0)
      #   g:     interpolation gate    (1, sigmoid -> (0, 1))
      #   s:     shift weighting       (2 * shift_range + 1, softmax)
      #   gamma: sharpening factor     (1, softplus + 1 -> >= 1)
      k = math_ops.tanh(head_parameter[:, 0:self.memory_vector_dim])
      beta = nn_ops.softplus(head_parameter[:, self.memory_vector_dim])
      g = math_ops.sigmoid(head_parameter[:, self.memory_vector_dim + 1])
      s = nn_ops.softmax(head_parameter[:, self.memory_vector_dim +
                                        2:(self.memory_vector_dim + 2 +
                                           (self.shift_range * 2 + 1))])
      gamma = nn_ops.softplus(head_parameter[:, -1]) + 1
      w = self._addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i])
      w_list.append(w)

    # Reading (Sec 3.1)

    read_w_list = w_list[:self.read_head_num]
    read_vector_list = []
    for i in range(self.read_head_num):
      # eq (2): read vector is the address-weighted sum of memory rows.
      # NOTE(review): `dim=` is the deprecated spelling of `axis=`; kept as-is.
      read_vector = math_ops.reduce_sum(
          array_ops.expand_dims(read_w_list[i], dim=2) * prev_M, axis=1)
      read_vector_list.append(read_vector)

    # Writing (Sec 3.2)

    write_w_list = w_list[self.read_head_num:]
    M = prev_M
    for i in range(self.write_head_num):
      w = array_ops.expand_dims(write_w_list[i], axis=2)
      erase_vector = array_ops.expand_dims(
          math_ops.sigmoid(erase_add_list[i * 2]), axis=1)
      add_vector = array_ops.expand_dims(
          math_ops.tanh(erase_add_list[i * 2 + 1]), axis=1)
      # eq (3) erase, then eq (4) add.
      erase_M = array_ops.ones_like(M) - math_ops.matmul(w, erase_vector)
      M = M * erase_M + math_ops.matmul(w, add_vector)

    # Output projection over [controller_output; all read vectors], clipped.
    output = math_ops.matmul(
        array_ops.concat([controller_output] + read_vector_list, axis=1),
        self._output_kernel)
    output = nn_ops.bias_add(output, self._output_bias)
    output = clip_ops.clip_by_value(output, -self.clip_value, self.clip_value)

    return output, NTMControllerState(
        controller_state=controller_state,
        read_vector_list=read_vector_list,
        w_list=w_list,
        M=M,
        time=prev_state.time + 1)
3683
+
3684
+ def _expand(self, x, dim, N):
3685
+ return array_ops.concat([array_ops.expand_dims(x, dim) for _ in range(N)],
3686
+ axis=dim)
3687
+
3688
  def _addressing(self, k, beta, g, s, gamma, prev_M, prev_w):
    """Compute one head's address weights over memory (Sec 3.3 of the paper).

    Args:
      k: `[batch, memory_vector_dim]` key vector.
      beta: `[batch]` key strength (non-negative).
      g: `[batch]` interpolation gate in (0, 1).
      s: `[batch, memory_size]`-padded shift weighting (see `call`: only
        `2 * shift_range + 1` entries are non-trivially populated).
      gamma: `[batch]` sharpening exponent (>= 1).
      prev_M: `[batch, memory_size, memory_vector_dim]` memory matrix.
      prev_w: `[batch, memory_size]` previous address weights.

    Returns:
      `[batch, memory_size]` normalized address weights.
    """
    # Sec 3.3.1 Focusing by Content

    k = array_ops.expand_dims(k, axis=2)
    inner_product = math_ops.matmul(prev_M, k)
    k_norm = math_ops.sqrt(
        math_ops.reduce_sum(math_ops.square(k), axis=1, keepdims=True))
    M_norm = math_ops.sqrt(
        math_ops.reduce_sum(math_ops.square(prev_M), axis=2, keepdims=True))
    norm_product = M_norm * k_norm

    # eq (6): cosine similarity; 1e-8 guards against division by zero.
    K = array_ops.squeeze(inner_product / (norm_product + 1e-8))

    K_amplified = math_ops.exp(array_ops.expand_dims(beta, axis=1) * K)

    # eq (5): normalize the beta-amplified similarities over locations.
    w_c = K_amplified / math_ops.reduce_sum(K_amplified, axis=1, keepdims=True)

    # Sec 3.3.2 Focusing by Location

    g = array_ops.expand_dims(g, axis=1)

    # eq (7): interpolate between content weights and previous weights.
    w_g = g * w_c + (1 - g) * prev_w

    # Zero-pad the shift kernel out to memory_size so the circular
    # convolution below can be built from plain concat/stack slices.
    s = array_ops.concat([
        s[:, :self.shift_range + 1],
        array_ops.zeros([
            s.shape[0].value or array_ops.shape(s)[0], self.memory_size -
            (self.shift_range * 2 + 1)
        ]), s[:, -self.shift_range:]
    ],
                         axis=1)
    t = array_ops.concat(
        [array_ops.reverse(s, axis=[1]),
         array_ops.reverse(s, axis=[1])],
        axis=1)
    # Row i of s_matrix is the padded shift kernel rotated by i positions.
    s_matrix = array_ops.stack([
        t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1]
        for i in range(self.memory_size)
    ],
                               axis=1)

    # eq (8): circular convolution of w_g with the shift weighting.
    w_ = math_ops.reduce_sum(
        array_ops.expand_dims(w_g, axis=1) * s_matrix, axis=2)
    w_sharpen = math_ops.pow(w_, array_ops.expand_dims(gamma, axis=1))

    # eq (9): sharpen and renormalize.
    w = w_sharpen / math_ops.reduce_sum(w_sharpen, axis=1, keepdims=True)

    return w
3741
+
3742
+ def zero_state(self, batch_size, dtype):
3743
+ read_vector_list = [
3744
+ array_ops.zeros([batch_size, self.memory_vector_dim])
3745
+ for _ in range(self.read_head_num)
3746
+ ]
3747
+
3748
+ w_list = [
3749
+ array_ops.zeros([batch_size, self.memory_size])
3750
+ for _ in range(self.read_head_num + self.write_head_num)
3751
+ ]
3752
+
3753
+ controller_init_state = self.controller.zero_state(batch_size, dtype)
3754
+
3755
+ M = array_ops.zeros([batch_size, self.memory_size, self.memory_vector_dim])
3756
+
3757
+ return NTMControllerState(
3758
+ controller_state=controller_init_state,
3759
+ read_vector_list=read_vector_list,
3760
+ w_list=w_list,
3761
+ M=M,
3762
+ time=0)
3763
+
3764
+
3765
class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
  """MinimalRNN cell.

  Based on: https://arxiv.org/pdf/1806.05394v2.pdf

  Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz.
  "Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal
  Propagation in Recurrent Neural Networks." ICML, 2018.

  The input is first projected into the hidden space by a feedforward
  network; a single update gate then blends the projected input with the
  previous hidden state to produce the new hidden state.
  """

  def __init__(self,
               units,
               activation="tanh",
               kernel_initializer="glorot_uniform",
               bias_initializer="ones",
               name=None,
               dtype=None,
               **kwargs):
    """Initialize the parameters for a MinimalRNN cell.

    Args:
      units: int, The number of units in the MinimalRNN cell.
      activation: Nonlinearity to use in the feedforward network. Default:
        `tanh`.
      kernel_initializer: The initializer to use for the weight in the update
        gate and feedforward network. Default: `glorot_uniform`.
      bias_initializer: The initializer to use for the bias in the update
        gate. Default: `ones`.
      name: String, the name of the cell.
      dtype: Default dtype of the cell.
      **kwargs: Dict, keyword named properties for common cell attributes.
    """
    super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs)

    # Only rank-2 `[batch, features]` inputs are accepted.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self.units = units
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

  @property
  def state_size(self):
    """State is a single `[batch, units]` tensor."""
    return self.units

  @property
  def output_size(self):
    """Output equals the new hidden state, size `units`."""
    return self.units

  def build(self, inputs_shape):
    """Create the fused kernel and the gate bias."""
    feature_dim = inputs_shape[-1]
    if feature_dim is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % str(inputs_shape))

    # pylint: disable=protected-access
    # One fused matrix stacking W_x (input projection) on top of W and V
    # (the update-gate weights for the projected input and previous state).
    self.kernel = self.add_weight(
        name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[feature_dim + 2 * self.units, self.units],
        initializer=self.kernel_initializer)
    self.bias = self.add_weight(
        name=rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[self.units],
        initializer=self.bias_initializer)
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Run one step of MinimalRNN.

    Args:
      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
      state: state Tensor, must be 2-D, `[batch, state_size]`.

    Returns:
      A tuple `(output, new_state)`; both are `[batch_size, state_size]`
      and identical (the output is the new hidden state).

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    static_input_size = inputs.get_shape()[1]
    if tensor_shape.dimension_value(static_input_size) is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # Split the fused kernel back into the feedforward and gate weights.
    split_sizes = [
        tensor_shape.dimension_value(static_input_size), 2 * self.units
    ]
    feedforward_weight, gate_weight = array_ops.split(
        value=self.kernel, num_or_size_splits=split_sizes, axis=0)

    # Phi(x): project the input into the hidden space.
    projected = self.activation(math_ops.matmul(inputs, feedforward_weight))

    # Update gate u over [Phi(x); h_{t-1}].
    gate_preactivation = nn_ops.bias_add(
        math_ops.matmul(array_ops.concat([projected, state], 1), gate_weight),
        self.bias)
    update_gate = math_ops.sigmoid(gate_preactivation)

    # Convex combination of the old state and the projected input.
    new_state = update_gate * state + (1 - update_gate) * projected
    return new_state, new_state
3878
+
3879
+
3880
class CFNCell(rnn_cell_impl.LayerRNNCell):
  """Chaos Free Network cell.

  Based on: https://openreview.net/pdf?id=S1dIzvclg

  Thomas Laurent, James von Brecht.
  "A recurrent neural network without chaos." ICLR, 2017.

  The input is projected into the hidden space while the previous hidden
  state passes through a contractive mapping; the new hidden state is a
  linear combination of the two, controlled by decoupled forget (theta)
  and input (eta) gates.
  """

  def __init__(self,
               units,
               activation="tanh",
               kernel_initializer="glorot_uniform",
               bias_initializer="ones",
               name=None,
               dtype=None,
               **kwargs):
    """Initialize the parameters for a CFN cell.

    Args:
      units: int, The number of units in the CFN cell.
      activation: Nonlinearity to use. Default: `tanh`.
      kernel_initializer: Initializer for the `kernel` weights
        matrix. Default: `glorot_uniform`.
      bias_initializer: The initializer to use for the bias in the
        gates. Default: `ones`.
      name: String, the name of the cell.
      dtype: Default dtype of the cell.
      **kwargs: Dict, keyword named properties for common cell attributes.
    """
    super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs)

    # Only rank-2 `[batch, features]` inputs are accepted.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self.units = units
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

  @property
  def state_size(self):
    """State is a single `[batch, units]` tensor."""
    return self.units

  @property
  def output_size(self):
    """Output equals the new hidden state, size `units`."""
    return self.units

  def build(self, inputs_shape):
    """Create the fused input kernel, recurrent kernel, and gate bias."""
    feature_dim = inputs_shape[-1]
    if feature_dim is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % str(inputs_shape))

    # pylint: disable=protected-access
    # `self.kernel`           stacks V_{\theta}, V_{\eta}, W column-wise.
    # `self.recurrent_kernel` stacks U_{\theta}, U_{\eta}.
    # `self.bias`             stacks b_{\theta}, b_{\eta}.
    self.kernel = self.add_weight(
        shape=[feature_dim, 3 * self.units],
        name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        initializer=self.kernel_initializer)
    self.recurrent_kernel = self.add_weight(
        shape=[self.units, 2 * self.units],
        name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        initializer=self.kernel_initializer)
    self.bias = self.add_weight(
        shape=[2 * self.units],
        name=rnn_cell_impl._BIAS_VARIABLE_NAME,
        initializer=self.bias_initializer)
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Run one step of CFN.

    Args:
      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
      state: state Tensor, must be 2-D, `[batch, state_size]`.

    Returns:
      A tuple `(output, new_state)`; both are `[batch_size, state_size]`
      and identical (the output is the new hidden state).

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    static_input_size = inputs.get_shape()[-1]
    if tensor_shape.dimension_value(static_input_size) is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # Variable naming follows the paper: v = [V_theta, V_eta], w = W,
    # u = [U_theta, U_eta], b = [b_theta, b_eta].
    v, w = array_ops.split(
        value=self.kernel,
        num_or_size_splits=[2 * self.units, self.units],
        axis=1)
    u = self.recurrent_kernel
    b = self.bias

    # Both gates in one fused matmul/bias-add/sigmoid, then split apart.
    gate_preactivation = nn_ops.bias_add(
        math_ops.matmul(state, u) + math_ops.matmul(inputs, v), b)
    theta, eta = array_ops.split(
        value=math_ops.sigmoid(gate_preactivation),
        num_or_size_splits=2,
        axis=1)

    projected_input = math_ops.matmul(inputs, w)

    # The input gate is (1 - eta), differing from the paper: with the
    # default `ones` bias initializer this starts the input gate small.
    new_state = (theta * self.activation(state) +
                 (1 - eta) * self.activation(projected_input))

    return new_state, new_state