returnn 1.20240705.144031__tar.gz → 1.20240709.122157__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.

Files changed (451)
  1. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/PKG-INFO +1 -1
  2. returnn-1.20240709.122157/_setup_info_generated.py +2 -0
  3. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/meta.py +29 -1
  4. returnn-1.20240709.122157/returnn/torch/util/gradient_checkpoint.py +594 -0
  5. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/PKG-INFO +1 -1
  6. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/SOURCES.txt +2 -0
  7. returnn-1.20240709.122157/tests/test_torch_util.py +303 -0
  8. returnn-1.20240705.144031/_setup_info_generated.py +0 -2
  9. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.editorconfig +0 -0
  10. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.gitignore +0 -0
  11. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.gitmodules +0 -0
  12. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/.kateconfig +0 -0
  13. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/CHANGELOG.md +0 -0
  14. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/CODEOWNERS +0 -0
  15. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/CONTRIBUTING.md +0 -0
  16. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/LICENSE +0 -0
  17. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/MANIFEST.in +0 -0
  18. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/README.rst +0 -0
  19. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/__init__.py +0 -0
  20. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/12AX.cluster_map +0 -0
  21. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/_setup_returnn_env.py +0 -0
  22. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-fwd.config +0 -0
  23. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-horovod-mpi.py +0 -0
  24. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-horovod-mpi.py.sh +0 -0
  25. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-horovod-mpi.sh +0 -0
  26. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-hyper-param-tuning.config +0 -0
  27. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-iter-dataset.py +0 -0
  28. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-list-devices.py +0 -0
  29. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-lua-torch-layer.config +0 -0
  30. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-pretrain.config +0 -0
  31. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-record-and-push-to-webserver.py +0 -0
  32. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-returnn-as-framework.py +0 -0
  33. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-rf-pt-benchmark.py +0 -0
  34. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-rf.config +0 -0
  35. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-rhn-enwik8.config +0 -0
  36. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-sprint-interface.py +0 -0
  37. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-att-copy.config +0 -0
  38. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-attention.config +0 -0
  39. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
  40. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
  41. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-enc-dec.config +0 -0
  42. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-hard-att-copy.config +0 -0
  43. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-lstm-benchmark.py +0 -0
  44. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
  45. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
  46. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm.12ax.config +0 -0
  47. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm2.12ax.config +0 -0
  48. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
  49. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-neural-transducer.12ax.config +0 -0
  50. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-rec-explicit-lstm.config +0 -0
  51. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-rec-explicit-rnn.config +0 -0
  52. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-rec-self-att.config +0 -0
  53. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-search-compiled-graph.py +0 -0
  54. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
  55. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-timit-lstm-ctc.config +0 -0
  56. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-torch.config +0 -0
  57. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
  58. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/demo.sh +0 -0
  59. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
  60. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
  61. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
  62. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/README.txt +0 -0
  63. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/chars.txt +0 -0
  64. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/config_demo +0 -0
  65. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/config_fwd +0 -0
  66. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/config_real +0 -0
  67. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
  68. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/decode.py +0 -0
  69. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
  70. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/go.sh +0 -0
  71. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/lines.txt +0 -0
  72. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/split/eval.txt +0 -0
  73. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/split/train.txt +0 -0
  74. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/IAM/split/valid.txt +0 -0
  75. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/README.md +0 -0
  76. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/create_test_h5.py +0 -0
  77. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/forwardconfig +0 -0
  78. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/go.sh +0 -0
  79. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial/trainconfig +0 -0
  80. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
  81. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
  82. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/go.sh +0 -0
  83. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
  84. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/pyproject.toml +0 -0
  85. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/requirements.txt +0 -0
  86. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__init__.py +0 -0
  87. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__main__.py +0 -0
  88. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__old_mod_loader__.py +0 -0
  89. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/__setup__.py +0 -0
  90. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/config.py +0 -0
  91. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/__init__.py +0 -0
  92. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/audio.py +0 -0
  93. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/basic.py +0 -0
  94. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/bundle_file.py +0 -0
  95. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/cached.py +0 -0
  96. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/cached2.py +0 -0
  97. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/distrib_files.py +0 -0
  98. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/generating.py +0 -0
  99. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/hdf.py +0 -0
  100. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/lm.py +0 -0
  101. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/map.py +0 -0
  102. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/multi_proc.py +0 -0
  103. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/normalization_data.py +0 -0
  104. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/numpy_dump.py +0 -0
  105. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/raw_wav.py +0 -0
  106. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/sprint.py +0 -0
  107. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/stereo.py +0 -0
  108. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/__init__.py +0 -0
  109. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/feature_extraction.py +0 -0
  110. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/strings.py +0 -0
  111. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/util/vocabulary.py +0 -0
  112. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/engine/__init__.py +0 -0
  113. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/engine/base.py +0 -0
  114. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/engine/batch.py +0 -0
  115. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/__init__.py +0 -0
  116. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/__main__.py +0 -0
  117. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/.git +0 -0
  118. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
  119. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
  120. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
  121. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
  122. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
  123. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
  124. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
  125. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
  126. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
  127. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
  128. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
  129. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
  130. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
  131. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
  132. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
  133. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
  134. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
  135. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
  136. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
  137. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
  138. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
  139. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
  140. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
  141. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
  142. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/__init__.py +0 -0
  143. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/README.md +0 -0
  144. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/__init__.py +0 -0
  145. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/edit.py +0 -0
  146. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/reroute.py +0 -0
  147. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/select.py +0 -0
  148. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/subgraph.py +0 -0
  149. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/transform.py +0 -0
  150. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/extern/graph_editor/util.py +0 -0
  151. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/forward_iface.py +0 -0
  152. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/__init__.py +0 -0
  153. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_backend.py +0 -0
  154. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/__init__.py +0 -0
  155. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/backend.cpp +0 -0
  156. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/backend.hpp +0 -0
  157. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/module.cpp +0 -0
  158. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/module.hpp +0 -0
  159. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/py_utils.hpp +0 -0
  160. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/tensor_ops.cpp +0 -0
  161. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_native/tensor_ops.hpp +0 -0
  162. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_numpy_backend.py +0 -0
  163. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_random_journal.py +0 -0
  164. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/_utils.py +0 -0
  165. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/array_.py +0 -0
  166. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/attention.py +0 -0
  167. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/audio/__init__.py +0 -0
  168. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/audio/mel.py +0 -0
  169. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/audio/specaugment.py +0 -0
  170. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/backend.py +0 -0
  171. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/build_from_dict.py +0 -0
  172. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/cond.py +0 -0
  173. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/const.py +0 -0
  174. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/container.py +0 -0
  175. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/control_flow_ctx.py +0 -0
  176. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/conv.py +0 -0
  177. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/decoder/__init__.py +0 -0
  178. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/decoder/transformer.py +0 -0
  179. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/device.py +0 -0
  180. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/dims.py +0 -0
  181. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/dropout.py +0 -0
  182. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/dtype.py +0 -0
  183. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/encoder/__init__.py +0 -0
  184. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/encoder/base.py +0 -0
  185. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/encoder/conformer.py +0 -0
  186. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/gradient.py +0 -0
  187. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/graph.py +0 -0
  188. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/hooks.py +0 -0
  189. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/init.py +0 -0
  190. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/label_smoothing.py +0 -0
  191. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/linear.py +0 -0
  192. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/loop.py +0 -0
  193. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/loss.py +0 -0
  194. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/math_.py +0 -0
  195. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/matmul.py +0 -0
  196. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/module.py +0 -0
  197. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/normalization.py +0 -0
  198. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/parameter.py +0 -0
  199. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/piecewise_linear.py +0 -0
  200. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/rand.py +0 -0
  201. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/rec.py +0 -0
  202. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/reduce.py +0 -0
  203. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/run_ctx.py +0 -0
  204. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/signal.py +0 -0
  205. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/state.py +0 -0
  206. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/stepwise_scheduler.py +0 -0
  207. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/tensor_array.py +0 -0
  208. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/frontend/types.py +0 -0
  209. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/__init__.py +0 -0
  210. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/common.py +0 -0
  211. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/git.py +0 -0
  212. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/import_/import_.py +0 -0
  213. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/learning_rate_control.py +0 -0
  214. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/log.py +0 -0
  215. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/native_op.cpp +0 -0
  216. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/native_op.py +0 -0
  217. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/pretrain.py +0 -0
  218. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/__init__.py +0 -0
  219. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/cache.py +0 -0
  220. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/control.py +0 -0
  221. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/error_signals.py +0 -0
  222. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/extern_interface.py +0 -0
  223. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/sprint/interface.py +0 -0
  224. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/README.md +0 -0
  225. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/__init__.py +0 -0
  226. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_dim_extra.py +0 -0
  227. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_tensor_extra.py +0 -0
  228. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_tensor_mixin_base.py +0 -0
  229. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/_tensor_op_overloads.py +0 -0
  230. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/control_flow_ctx.py +0 -0
  231. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/dim.py +0 -0
  232. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/marked_dim.py +0 -0
  233. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/tensor.py +0 -0
  234. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/tensor_dict.py +0 -0
  235. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tensor/utils.py +0 -0
  236. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/__init__.py +0 -0
  237. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/compat.py +0 -0
  238. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/data_pipeline.py +0 -0
  239. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/distributed.py +0 -0
  240. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/engine.py +0 -0
  241. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/README.md +0 -0
  242. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/__init__.py +0 -0
  243. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/_backend.py +0 -0
  244. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/_utils.py +0 -0
  245. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/cond.py +0 -0
  246. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
  247. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
  248. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/dims.py +0 -0
  249. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/layer.py +0 -0
  250. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/loop.py +0 -0
  251. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/make_layer.py +0 -0
  252. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/masked_computation.py +0 -0
  253. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/parameter_assign.py +0 -0
  254. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
  255. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_low_level/__init__.py +0 -0
  256. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/frontend_low_level/_backend.py +0 -0
  257. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/horovod.py +0 -0
  258. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/hyper_param_tuning.py +0 -0
  259. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/__init__.py +0 -0
  260. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/base.py +0 -0
  261. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/basic.py +0 -0
  262. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/rec.py +0 -0
  263. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/segmental_model.py +0 -0
  264. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/signal_processing.py +0 -0
  265. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/layers/variable.py +0 -0
  266. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/native_op.py +0 -0
  267. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/network.py +0 -0
  268. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/sprint.py +0 -0
  269. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/updater.py +0 -0
  270. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/__init__.py +0 -0
  271. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/basic.py +0 -0
  272. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/data.py +0 -0
  273. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/gradient_checkpoint.py +0 -0
  274. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/ken_lm.py +0 -0
  275. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/tf/util/open_fst.py +0 -0
  276. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/README.md +0 -0
  277. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/__init__.py +0 -0
  278. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/__init__.py +0 -0
  279. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/extern_data.py +0 -0
  280. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/pipeline.py +0 -0
  281. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/queued_data_iter.py +0 -0
  282. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
  283. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/data/tensor_utils.py +0 -0
  284. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/distributed.py +0 -0
  285. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/engine.py +0 -0
  286. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/__init__.py +0 -0
  287. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/_backend.py +0 -0
  288. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/_rand.py +0 -0
  289. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/bridge.py +0 -0
  290. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/frontend/raw_ops.py +0 -0
  291. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/updater.py +0 -0
  292. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/README.md +0 -0
  293. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/__init__.py +0 -0
  294. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/diagnose_gpu.py +0 -0
  295. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/torch/util/scaled_gradient.py +0 -0
  296. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/__init__.py +0 -0
  297. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/basic.py +0 -0
  298. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/better_exchook.py +0 -0
  299. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/bpe.py +0 -0
  300. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/debug.py +0 -0
  301. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/debug_helpers.py +0 -0
  302. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/file_cache.py +0 -0
  303. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/fsa.py +0 -0
  304. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/literal_py_to_pickle.py +0 -0
  305. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/math.py +0 -0
  306. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/multi_proc_non_daemonic_spawn.py +0 -0
  307. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/native_code_compiler.py +0 -0
  308. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/pprint.py +0 -0
  309. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/py-to-pickle.cpp +0 -0
  310. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/py_compat.py +0 -0
  311. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/py_ext_mod_compiler.py +0 -0
  312. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/result_with_reason.py +0 -0
  313. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/sig_proc.py +0 -0
  314. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/task_system.py +0 -0
  315. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/train_proc_manager.py +0 -0
  316. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/util/watch_memory.py +0 -0
  317. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/dependency_links.txt +0 -0
  318. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn.egg-info/top_level.txt +0 -0
  319. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/rnn.py +0 -0
  320. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/setup.cfg +0 -0
  321. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/setup.py +0 -0
  322. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/DummySprintExec.py +0 -0
  323. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm-inspection-profile.xml +0 -0
  324. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/.gitignore +0 -0
  325. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/.name +0 -0
  326. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
  327. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
  328. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
  329. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
  330. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
  331. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/misc.xml +0 -0
  332. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/modules.xml +0 -0
  333. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/returnn.iml +0 -0
  334. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
  335. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/_set_num_threads1.py +0 -0
  336. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/_setup_returnn_env.py +0 -0
  337. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/_setup_test_env.py +0 -0
  338. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/bpe-unicode-demo.codes +0 -0
  339. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/bpe-unicode-demo.vocab +0 -0
  340. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.fst +0 -0
  341. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.isyms +0 -0
  342. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.jpg +0 -0
  343. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lexicon_opt.osyms +0 -0
  344. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/lint_common.py +0 -0
  345. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/pycharm-inspect.py +0 -0
  346. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/pylint.py +0 -0
  347. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/returnn-as-framework.py +0 -0
  348. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/rf_utils.py +0 -0
  349. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/spelling.dic +0 -0
  350. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Config.py +0 -0
  351. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Dataset.py +0 -0
  352. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Fsa.py +0 -0
  353. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_GeneratingDataset.py +0 -0
  354. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_HDFDataset.py +0 -0
  355. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_LearningRateControl.py +0 -0
  356. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Log.py +0 -0
  357. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_MultiProcDataset.py +0 -0
  358. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Pretrain.py +0 -0
  359. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_ResNet.py +0 -0
  360. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_SprintDataset.py +0 -0
  361. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_SprintInterface.py +0 -0
  362. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFEngine.py +0 -0
  363. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNativeOp.py +0 -0
  364. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNetworkLayer.py +0 -0
  365. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNetworkRecLayer.py +0 -0
  366. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFNetworkSigProcLayer.py +0 -0
  367. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFUpdater.py +0 -0
  368. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TFUtil.py +0 -0
  369. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TF_determinism.py +0 -0
  370. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TaskSystem.py +0 -0
  371. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TaskSystem_SharedMem.py +0 -0
  372. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_TranslationDataset.py +0 -0
  373. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_Util.py +0 -0
  374. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_demos.py +0 -0
  375. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_fork_exec.py +0 -0
  376. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_hdf_dump.py +0 -0
  377. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_array.py +0 -0
  378. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_attention.py +0 -0
  379. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_base.py +0 -0
  380. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_cond.py +0 -0
  381. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_const.py +0 -0
  382. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_container.py +0 -0
  383. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_conv.py +0 -0
  384. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_encoder_conformer.py +0 -0
  385. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_gradient.py +0 -0
  386. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_label_smoothing.py +0 -0
  387. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_loop.py +0 -0
  388. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_math.py +0 -0
  389. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_normalization.py +0 -0
  390. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_piecewise_linear.py +0 -0
  391. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_rec.py +0 -0
  392. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_reduce.py +0 -0
  393. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_rf_signal.py +0 -0
  394. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_tensor.py +0 -0
  395. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_tools.py +0 -0
  396. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_dataset.py +0 -0
  397. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_engine.py +0 -0
  398. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_frontend.py +0 -0
  399. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tests/test_torch_internal_frontend.py +0 -0
  400. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/_setup_returnn_env.py +0 -0
  401. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/analyze-dataset-batches.py +0 -0
  402. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-collect-seq-lens.py +0 -0
  403. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-dump-text.py +0 -0
  404. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-get-segment-names.py +0 -0
  405. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bliss-to-ogg-zip.py +0 -0
  406. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/bpe-create-lexicon.py +0 -0
  407. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/calculate-word-error-rate.py +0 -0
  408. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/cleanup-old-models.py +0 -0
  409. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/collect-orth-symbols.py +0 -0
  410. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/collect-words.py +0 -0
  411. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/compile_native_op.py +0 -0
  412. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/compile_tf_graph.py +0 -0
  413. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/debug-dump-search-scores.py +0 -0
  414. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/debug-plot-search-scores.py +0 -0
  415. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-dataset-raw-strings.py +0 -0
  416. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-dataset.py +0 -0
  417. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-forward-stats.py +0 -0
  418. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-forward.py +0 -0
  419. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-network-json.py +0 -0
  420. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/dump-pickle.py +0 -0
  421. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/extract_state_tying_from_dataset.py +0 -0
  422. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/get-attention-weights.py +0 -0
  423. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/get-best-model-epoch.py +0 -0
  424. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/hdf_dump.py +0 -0
  425. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/hdf_dump_translation_dataset.py +0 -0
  426. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/import-blocks-mt-model.py +0 -0
  427. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/import-t2t-mt-model.py +0 -0
  428. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/.gitignore +0 -0
  429. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/Makefile +0 -0
  430. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/README.md +0 -0
  431. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/README.md +0 -0
  432. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/libs_list +0 -0
  433. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
  434. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
  435. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
  436. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/state_vars_list +0 -0
  437. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/example/tensor_names_list +0 -0
  438. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/file.h +0 -0
  439. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
  440. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
  441. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/main.cc +0 -0
  442. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/rescorer.h +0 -0
  443. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/vocabulary.cc +0 -0
  444. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/lattice_rescorer/vocabulary.h +0 -0
  445. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/tf_avg_checkpoints.py +0 -0
  446. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/tf_inspect_checkpoint.py +0 -0
  447. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/tf_inspect_summary_log.py +0 -0
  448. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_avg_checkpoints.py +0 -0
  449. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_export_to_onnx.py +0 -0
  450. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_inspect_checkpoint.py +0 -0
  451. {returnn-1.20240705.144031 → returnn-1.20240709.122157}/tools/torch_inspect_checkpoint_and_opt.py +0 -0

{returnn-1.20240705.144031 → returnn-1.20240709.122157}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20240705.144031
+ Version: 1.20240709.122157
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer

returnn-1.20240709.122157/_setup_info_generated.py
@@ -0,0 +1,2 @@
+ version = '1.20240709.122157'
+ long_version = '1.20240709.122157+git.a0f8be5'

{returnn-1.20240705.144031 → returnn-1.20240709.122157}/returnn/datasets/meta.py
@@ -1391,6 +1391,7 @@ class ConcatSeqsDataset(CachedDataset2):
  seq_tag_delim=";",
  remove_in_between_postfix=None,
  repeat_in_between_last_frame_up_to_multiple_of=None,
+ pad_narrow_data_to_multiple_of_target_len=None,
  use_cache_manager=False,
  epoch_wise_filter=None,
  **kwargs,

@@ -1406,6 +1407,12 @@ class ConcatSeqsDataset(CachedDataset2):
  Now it could happen that ceildiv(data_len1 + data_len2, 6) < align_len1 + align_len2.
  This option would repeat intermediate ending frames such that data_len1 % 6 == 0,
  by setting it to {"data": 6}.
+ :param dict[str,(str,int)]|None pad_narrow_data_to_multiple_of_target_len: data_key -> (target_key, multiple).
+ Similar as repeat_in_between_last_frame_up_to_multiple_of, but works for more padding/alignment schemes.
+ Example: align_len == ceildiv(data_len - P, F) for all your sub-sequences, where P is a custom number,
+ repeat_in_between_last_frame_up_to_multiple_of would not work because align_len != ceildiv(data_len, F)
+ This option would pad/narrow so that align_len * F == data_len for all but the last sub-sequences
+ by setting it to {"data": ("classes", F)} to ensure concat_align_len == ceildiv(concat_data_len - P, F)
  :param bool use_cache_manager:
  :param dict[(int,int),dict] epoch_wise_filter: see :class:`EpochWiseFilter`
  """

@@ -1413,6 +1420,7 @@ class ConcatSeqsDataset(CachedDataset2):
  self.seq_tag_delim = seq_tag_delim
  self.remove_in_between_postfix = remove_in_between_postfix or {}
  self.repeat_in_between_last_frame_up_to_multiple_of = repeat_in_between_last_frame_up_to_multiple_of or {}
+ self.pad_narrow_data_to_multiple_of_target_len = pad_narrow_data_to_multiple_of_target_len or {}
  self.epoch_wise_filter = EpochWiseFilter(epoch_wise_filter) if epoch_wise_filter else None
  if isinstance(dataset, dict):
  dataset = dataset.copy()

@@ -1486,7 +1494,7 @@ class ConcatSeqsDataset(CachedDataset2):
  sub_seq_list.extend(sub_seq_tags)
  assert sub_seq_idx == len(sub_seq_list) and len(seq_list) == len(sub_seq_idxs)
  self.cur_sub_seq_idxs = sub_seq_idxs
- return self.sub_dataset.init_seq_order(seq_list=sub_seq_list)
+ return self.sub_dataset.init_seq_order(epoch=epoch, seq_list=sub_seq_list)

  def supports_seq_order_sorting(self) -> bool:
  """supports sorting"""

@@ -1539,6 +1547,11 @@ class ConcatSeqsDataset(CachedDataset2):
  key,
  sub_dataset_keys,
  )
+ for key in self.pad_narrow_data_to_multiple_of_target_len:
+ assert key in sub_dataset_keys, (
+ f"{self}: pad_narrow_data_to_multiple_of_target_len key {key}"
+ f" not in sub dataset data-keys {sub_dataset_keys}"
+ )
  for sub_seq_idx, sub_seq_tag in zip(sub_seq_idxs, sub_seq_tags):
  self.sub_dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1)
  sub_dataset_tag = self.sub_dataset.get_tag(sub_seq_idx)

@@ -1562,6 +1575,17 @@ class ConcatSeqsDataset(CachedDataset2):
  if data.shape[0] % multiple != 0:
  data = numpy.concatenate([data] + [data[-1:]] * (multiple - data.shape[0] % multiple), axis=0)
  assert data.shape[0] % multiple == 0
+ if key in self.pad_narrow_data_to_multiple_of_target_len and sub_seq_idx != sub_seq_idxs[-1]:
+ target_key, multiple = self.pad_narrow_data_to_multiple_of_target_len[key]
+ target_data = self.sub_dataset.get_data(sub_seq_idx, target_key)
+ len_diff = data.shape[0] - target_data.shape[0] * multiple
+ if len_diff > 0:
+ # if data longer than ref_data * frame_rate, narrow the data
+ data = data[:-len_diff]
+ elif len_diff < 0:
+ # if data shorter than ref_data * frame_rate, pad by repeating last frame
+ data = numpy.concatenate([data] + [data[-1:]] * -len_diff, axis=0)
+ assert data.shape[0] == target_data.shape[0] * multiple
  features[key].append(data)
  features = {key: numpy.concatenate(values, axis=0) for (key, values) in features.items()}
  return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
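
As a standalone worked example of the pad/narrow logic added above (not code from the package; the array sizes and frame rate are invented):

import numpy

def pad_or_narrow(data: numpy.ndarray, target_len: int, multiple: int) -> numpy.ndarray:
    """Pad (by repeating the last frame) or narrow data so that data.shape[0] == target_len * multiple."""
    len_diff = data.shape[0] - target_len * multiple
    if len_diff > 0:
        data = data[:-len_diff]  # too long: drop trailing frames
    elif len_diff < 0:
        data = numpy.concatenate([data] + [data[-1:]] * -len_diff, axis=0)  # too short: repeat last frame
    assert data.shape[0] == target_len * multiple
    return data

# 23 feature frames, 4 alignment labels, frame rate 6 -> padded to 24 frames
print(pad_or_narrow(numpy.arange(23), target_len=4, multiple=6).shape)  # (24,)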

@@ -1606,6 +1630,10 @@ class ConcatSeqsDataset(CachedDataset2):
  """
  return self.sub_dataset.get_data_shape(key)

+ def get_total_num_seqs(self) -> int:
+ """total num seqs"""
+ return len(self.full_seq_list)
+

  class ChunkShuffleDataset(CachedDataset2):
  """
@@ -0,0 +1,594 @@
1
+ """
2
+ Gradient checkpointing.
3
+
4
+ Following a lot of the code of the official
5
+ `torch.utils.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`__,
6
+ using ``torch.autograd.graph.saved_tensors_hooks``
7
+ and ``TorchDispatchMode``
8
+ but also handling the RNG fork and reset in a similar way.
9
+
10
+ See also :mod:`returnn.tf.util.gradient_checkpoint`:
11
+ same API and logic in TF, although it heavily makes use
12
+ of the TF computation graph, i.e. graph mode,
13
+ which makes this particular feature much easier to implement.
14
+
15
+ See also:
16
+ https://github.com/rwth-i6/returnn/issues/1552
17
+ https://discuss.pytorch.org/t/gradient-checkpointing/205416
18
+ https://gist.github.com/soulitzer/ec1049a947be046de7fbc2af61a4ee8c
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from typing import Optional, Union, Any, Callable, Sequence, List, Dict
24
+ from types import MethodType
25
+ from dataclasses import dataclass, field
26
+ import contextlib
27
+ from weakref import ref, WeakSet
28
+ import threading
29
+
30
+ import torch
31
+ from torch.utils.weak import WeakTensorKeyDictionary # needs Torch >=2.0.0
32
+
33
+ # noinspection PyProtectedMember
34
+ from torch.utils._python_dispatch import TorchDispatchMode
35
+
36
+ # PyTree is very common and semi-standard for PyTorch, e.g. __torch_dispatch__.
37
+ # We might use dm-tree or so alternatively here, but PyTree should be fine.
38
+ # noinspection PyProtectedMember
39
+ import torch.utils._pytree as pytree
40
+
41
+
42
+ __all__ = ["gradient_checkpoint_scope"]
43
+
44
+
45
+ # gradient_checkpoint_scope is the public API to the user.
46
+ # gradient_checkpoint_scope.__enter__ will enter two other scopes:
47
+ #
48
+ # - record_graph_scope: _RecordGraph(TorchDispatchMode),
49
+ # to record the computation graph for all ops within the scope.
50
+ #
51
+ # - saved_tensors_hooks_scope: torch.autograd.graph.saved_tensors_hooks,
52
+ # to overwrite what we store for backpropagation, and how to recompute it.
53
+ # Specifically, for all tensors which were created within the gradient_checkpoint_scope,
54
+ # we will never store them in the pack_hook,
55
+ # and unpack_hook will trigger the recomputation of the computation graph.
56
+ #
57
+ # gradient_checkpoint_scope.__exit__ will exit the record_graph_scope,
58
+ # but the saved_tensors_hooks_scope will stay alive as long as needed,
59
+ # while any of the created tensors are still alive.
60
+ # We keep a weak tensor key dictionary to map from the created raw tensors
61
+ # to the point in the recorded computation graph (specifically _GraphTensor objects).
62
+ # We just check whether any of the weak tensor refs is still alive.
63
+ #
64
+ # To keep saved_tensors_hooks_scope alive and make sure
65
+ # that other calls to torch.autograd.graph.saved_tensors_hooks are correctly handled,
66
+ # specifically that the order of enter/exit is correct,
67
+ # we hook into torch.autograd.graph.saved_tensors_hooks.__enter__/__exit__ itself.
68
+ # See _register_custom_saved_tensors_hooks below.
69
+ # Further, torch.autograd.graph.saved_tensors_hooks is thread local,
70
+ # so we can do any such logic only within the same thread.
71
+ # We also hook into Tensor.__del__ and also handle gradient_checkpoint_scope.__del__,
72
+ # but as that might run in a different thread, we cannot always do the cleanup there.
73
+ # We always check for this.
74
+ # (Note that this is due to the API of torch.autograd.graph.saved_tensors_hooks.
75
+ # We actually would want to always use it for a set of specified tensors.
76
+ # We also discuss some potentially better PyTorch API to implement this in an easier way:
77
+ # https://github.com/pytorch/pytorch/issues/129867)
78
+ #
79
+ # For the recomputation, we make sure that we properly reset the RNG and AMP states,
80
+ # and that we perform the recomputation in the exact same order, such that RNG state is correct.
81
+ #
82
+ # Once some recomputed tensor was used and is not needed anymore, the GC should free it.
83
+ # We try to make sure that no unnecessary references are kept alive.
84
+ #
85
+ # Also see test_gradient_checkpoint_scope() which tests this.
86
+
87
+
88
+ class gradient_checkpoint_scope:
89
+ """
90
+ Create a gradient checkpoint scope.
91
+ All tensors created within this scope will not be stored for backpropagation,
92
+ but will be recomputed on the fly during backpropagation.
93
+
94
+ Example::
95
+
96
+ a = ...
97
+ b = ...
98
+ c = ...
99
+ with gradient_checkpoint_scope():
100
+ x = a + b
101
+ y = x * c
102
+
103
+ In this example, the tensor ``x`` will not be stored for backpropagation,
104
+ i.e. the computation ``x = a + b`` will be recomputed during backpropagation.
105
+
106
+ Internally, this uses the PyTorch ``torch.autograd.graph.saved_tensors_hooks`` mechanism
107
+ to override what we store for backpropagation, and how to recompute it.
108
+ And we use the PyTorch ``TorchDispatchMode`` to intercept all operations within the scope.
109
+ Note that the usage of ``torch.autograd.graph.saved_tensors_hooks`` is tricky here
110
+ as we need it beyond the scope of the ``gradient_checkpoint_scope``,
111
+ specifically for all future usages of the tensor ``x`` in the example.
112
+ See the code documentation for more details on this.
113
+
114
+ Note, PyTorch itself also provides a gradient checkpointing API,
115
+ namely `torch.utils.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`__.
116
+ This API is different: You cannot easily specify what not to store / what to recompute.
117
+ You rather specify a start/end point what to *store* for backpropagation,
118
+ and then PyTorch will recompute everything in between.
119
+ For the example above, you define that ``y`` is the end point and will be stored.
120
+ It looks like this::
121
+
122
+ a = ...
123
+ b = ...
124
+ c = ...
125
+ y = torch.utils.checkpoint.checkpoint(lambda: (a + b) * c)
126
+
127
+ PyTorch will not recompute ``... * c`` here,
128
+ but it will recompute ``a + b``.
129
+ We find this API more cumbersome to use and less flexible,
130
+ because in many case, you know what you want to recompute, i.e. what you don't want to store.
131
+ The PyTorch API is more about what you want to store, and then it recomputes everything else in between.
132
+
133
+ See also:
134
+ https://github.com/rwth-i6/returnn/issues/1552
135
+ https://discuss.pytorch.org/t/gradient-checkpointing/205416
136
+ """
137
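+ # A minimal usage sketch (illustration only; the tensors here are made up):
+ #
+ #     a = torch.randn(8, requires_grad=True)
+ #     b = torch.randn(8)
+ #     c = torch.randn(8)
+ #     with gradient_checkpoint_scope():
+ #         x = a + b            # not stored for backprop
+ #     y = (x * c).sum()        # using x after the scope is fine
+ #     y.backward()             # x = a + b is recomputed here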
+
138
+ def __init__(self):
139
+ self.record_graph_scope = _RecordGraph()
140
+ self.record_graph_scope.graph.gradient_checkpoint_scope_backref = self
141
+ # Note: saved_tensors_hooks is thread local.
142
+ self.saved_tensors_hooks_scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
143
+ self.entered = False
144
+ self.entered_thread_ref = None
145
+ self.exit_args: Optional[tuple] = None
146
+ self.exited_saved_tensors_hooks_scope = False
147
+
148
+ def __enter__(self):
149
+ self.record_graph_scope.__enter__()
150
+ self.saved_tensors_hooks_scope.__enter__()
151
+ self.entered = True
152
+ self.entered_thread_ref = ref(threading.current_thread())
153
+
154
+ def __exit__(self, exc_type, exc_val, exc_tb):
155
+ self.exit_args = (exc_type, exc_val, exc_tb)
156
+ self.record_graph_scope.__exit__(exc_type, exc_val, exc_tb)
157
+ if self.record_graph_scope.graph.is_any_recorded_tensor_alive():
158
+ # Do not exit saved_tensors_hooks_scope here
159
+ # because we still want to pack any tensors which were captured in our graph
160
+ # by giving it a ref to the graph tensor.
161
+ # However, we must track any further external uses of torch.autograd.graph.saved_tensors_hooks,
162
+ # to be able to properly remove it from the stack at the right point.
163
+ _register_custom_saved_tensors_hooks(existing_scope=self.saved_tensors_hooks_scope)
164
+ _register_custom_saved_tensors_hooks_thread_local_callback(
165
+ _WeakMethod(self._custom_saved_tensors_hooks_callback, return_if_dead=False)
166
+ )
167
+ else: # no relevant tensors alive anymore
168
+ self.exit_saved_tensors_hooks_scope()
169
+
170
+ def _maybe_exit_saved_tensors_hooks_scope(self):
171
+ if self.exited_saved_tensors_hooks_scope:
172
+ return
173
+ if not self.exit_args:
174
+ return
175
+ # If we are in the right thread, maybe we can do the cleanup now.
176
+ if self.entered_thread_ref() is threading.current_thread():
177
+ if not self.record_graph_scope.graph.is_any_recorded_tensor_alive():
178
+ self.exit_saved_tensors_hooks_scope()
179
+
180
+ def __del__(self):
181
+ # Note, be very careful what we do in __del__ because it might be called in a different thread!
182
+ # Note that the __del__ will likely be called very late,
183
+ # as the reference to the _Graph is kept alive until we used it for backprop,
184
+ # as we keep this alive via _Graph.gradient_checkpoint_scope_backref
185
+ # as long as any _GraphTensor is alive due to backprop pack_hook.
186
+ self._maybe_exit_saved_tensors_hooks_scope()
187
+
188
+ def exit_saved_tensors_hooks_scope(self):
189
+ """
190
+ exit saved_tensors_hooks_scope if not yet done.
191
+ """
192
+ assert self.entered_thread_ref() is threading.current_thread()
193
+ if self.exit_args and not self.exited_saved_tensors_hooks_scope:
194
+ # Note that via _register_custom_saved_tensors_hooks,
195
+ # this saved_tensors_hooks_scope.__exit__ might get to our _custom_saved_tensors_hooks_exit below,
196
+ # which will make sure that the order of __exit__ is correct.
197
+ self.exited_saved_tensors_hooks_scope = True
198
+ self.saved_tensors_hooks_scope.__exit__(*self.exit_args)
199
+
200
+ def _pack_hook(self, x: torch.Tensor) -> Union[torch.Tensor, _GraphTensor]:
201
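+ # Called by autograd (via the still-active saved_tensors_hooks scope)
+ # whenever some op wants to save a tensor for backprop.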
+ if self.exit_args and not self.record_graph_scope.graph.is_any_recorded_tensor_alive():
202
+ # No raw tensors alive anymore in graph_tensor_from_raw_tensor,
203
+ # so we can exit saved_tensors_hooks_scope now.
204
+ # (We might not always catch this properly in the Tensor _DelHook,
205
+ # e.g. when Tensor.__del__ runs in a different thread.)
206
+ self.exit_saved_tensors_hooks_scope()
207
+ return x
208
+ # _RecordGraph.__torch_dispatch__ should have recorded all newly created tensors.
209
+ x_ = self.record_graph_scope.graph.graph_tensor_from_weak_raw_tensor.get(x, x)
210
+ if isinstance(x_, _GraphTensor):
211
+ x._RETURNN_grad_ckpt_del_hook = _DelHook(_WeakMethod(self._tensor_del_hook))
212
+ return x_
213
+
214
+ @staticmethod
215
+ def _unpack_hook(x: Union[torch.Tensor, _GraphTensor]) -> torch.Tensor:
216
+ if isinstance(x, _GraphTensor):
217
+ x.op.graph.gradient_checkpoint_scope_backref._maybe_exit_saved_tensors_hooks_scope()
218
+ x.op.graph.maybe_recompute()
219
+ return x.get_recomputed()
220
+ return x
221
+
222
+ def _tensor_del_hook(self):
223
+ # Some of the relevant tensors got deleted.
224
+ # If we are in the right thread, maybe we can do the cleanup now.
225
+ self._maybe_exit_saved_tensors_hooks_scope()
226
+
227
+ def _custom_saved_tensors_hooks_callback(self) -> bool:
228
+ assert self.entered_thread_ref() is threading.current_thread()
229
+ assert self.exit_args
230
+ if self.record_graph_scope.graph.is_any_recorded_tensor_alive():
231
+ return True # keep callback alive
232
+ else:
233
+ self.exit_saved_tensors_hooks_scope()
234
+ return False # we are done, can delete callback
235
+
236
+
237
+ class _RecordGraph(TorchDispatchMode):
238
+ def __init__(self):
239
+ super().__init__()
240
+ self.graph = _Graph([])
241
+
242
+ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
243
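+ # Called for every (ATen-level) op executed while this TorchDispatchMode is active,
+ # i.e. for every op inside the gradient_checkpoint_scope.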
+ kwargs = {} if kwargs is None else kwargs
244
+ graph = self.graph
245
+ graph.maybe_store_rng_state(torch.device("cpu"))
246
+ graph.maybe_store_amp_state(torch.device("cpu"))
247
+ pytree.tree_map(graph.maybe_store_rng_state, args)
248
+ pytree.tree_map(graph.maybe_store_rng_state, kwargs)
249
+ out = func(*args, **kwargs)
250
+ graph.record_op(func, args, kwargs, out)
251
+ return out
252
+
253
+
254
+ @dataclass
255
+ class _Graph:
256
+ ops_to_be_recomputed: List[_GraphOp] = field(default_factory=list)
257
+ graph_tensor_from_weak_raw_tensor: WeakTensorKeyDictionary[torch.Tensor, _GraphTensor] = field(
258
+ default_factory=WeakTensorKeyDictionary
259
+ )
260
+ stored_device_rng_states: Dict[torch.device, Any] = field(default_factory=dict)
261
+ stored_device_amp_states: Dict[torch.device, Any] = field(default_factory=dict)
262
+ # Keep scope alive as long as any _GraphTensor is alive due to backprop pack_hook.
263
+ gradient_checkpoint_scope_backref: Optional[gradient_checkpoint_scope] = None
264
+
265
+ def is_any_recorded_tensor_alive(self) -> bool:
266
+ """
267
+ :return: whether any recorded tensor is still alive.
268
+ Recorded tensors are outputs from any ops which were recorded,
269
+ i.e. ops under the gradient_checkpoint_scope.
270
+ """
271
+ # graph_tensor_from_weak_raw_tensor is a WeakTensorKeyDictionary,
272
+ # i.e. once there is no other strong reference to some Tensor anymore,
273
+ # it would also be removed from graph_tensor_from_weak_raw_tensor.
274
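+ # For illustration (assuming no other strong refs to t):
+ #     d = WeakTensorKeyDictionary(); t = torch.zeros(1); d[t] = "..."; del t
+ #     assert not d  # evaluates False once t is gone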
+ return bool(self.graph_tensor_from_weak_raw_tensor)
275
+
276
+ def record_op(self, func: Any, args: Sequence[Any], kwargs: Dict[str, Any], out: Any):
277
+ """record op"""
278
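+ # Note (informal): under TorchDispatchMode, func is typically a low-level ATen overload
+ # such as torch.ops.aten.add.Tensor, and out is a tensor or a pytree containing tensors.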
+ out_flat, _ = pytree.tree_flatten(out)
279
+ wrapped_args = pytree.tree_map_only(torch.Tensor, self.maybe_map_raw_tensor_to_graph_tensor, args)
280
+ wrapped_kwargs = pytree.tree_map_only(torch.Tensor, self.maybe_map_raw_tensor_to_graph_tensor, kwargs)
281
+ op = _GraphOp(
282
+ graph=self,
283
+ func=func,
284
+ args=wrapped_args,
285
+ kwargs=wrapped_kwargs,
286
+ out_flat_num=len(out_flat),
287
+ )
288
+ self.ops_to_be_recomputed.append(op)
289
+ for i, out_flat_elem in enumerate(out_flat):
290
+ if isinstance(out_flat_elem, torch.Tensor):
291
+ if out_flat_elem in self.graph_tensor_from_weak_raw_tensor:
292
+ continue
293
+ tensor_ = _GraphTensor(op=op, out_flat_idx=i)
294
+ self.graph_tensor_from_weak_raw_tensor[out_flat_elem] = tensor_
295
+
296
+ def maybe_store_rng_state(self, arg: Any):
297
+ """
298
+ Store RNG state if not yet stored for this device.
299
+ We store it only once for the first usage,
300
+ as we only restore it once for the recomputation,
301
+ and then we rely on performing the recomputation in the correct order,
302
+ which should be deterministic and lead to the same RNG output.
303
+ """
304
+ if isinstance(arg, torch.Tensor):
305
+ device = arg.device
306
+ elif isinstance(arg, torch.device):
307
+ device = arg
308
+ else:
309
+ return
310
+ if device not in self.stored_device_rng_states:
311
+ self.stored_device_rng_states[device] = _get_dev_rng_state(device)
312
+
313
+ def maybe_store_amp_state(self, arg: Any):
314
+ """store AMP state if not yet stored for this device."""
315
+ if isinstance(arg, torch.Tensor):
316
+ device = arg.device
317
+ elif isinstance(arg, torch.device):
318
+ device = arg
319
+ else:
320
+ return
321
+ if device not in self.stored_device_amp_states:
322
+ self.stored_device_amp_states[device] = _get_dev_amp_state(device)
323
+
324
+ def maybe_map_raw_tensor_to_graph_tensor(self, tensor: torch.Tensor) -> Union[_GraphTensor, torch.Tensor]:
325
+ """raw tensor to graph tensor if available, otherwise return raw tensor."""
326
+ return self.graph_tensor_from_weak_raw_tensor.get(tensor, tensor)
327
+
328
+ def maybe_recompute(self):
329
+ """
330
+ Recompute.
331
+
332
+ Make sure that the recomputations happen in the correct order,
333
+ to get any random number generator state correct.
334
+
335
+ Note that we considered providing an API here which would allow recomputing only a subset of the ops.
336
+ It would still compute everything from op idx 0 up to some given op idx, but not the rest.
337
+ On subsequent calls, it would then continue from the last idx up to the newly requested op idx.
338
+ This works fine except for one important aspect: the RNG state.
339
+ If there are any other ops in between which use the RNG state, the RNG state would not be correct anymore.
340
+ To allow this, we would then need to fetch the RNG state again and reset it again later,
341
+ which would add some further overhead.
342
+ To keep things simple and to avoid this overhead, we recompute all ops together right now.
343
+
344
+ However, we can at least remove the op from the list once it is computed.
345
+ So once any referenced tensor is not needed anymore, it can be garbage collected.
346
+ """
347
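+ # For illustration: if e.g. two dropout ops ran inside the scope, restoring the RNG state
+ # once and then replaying all ops in the recorded order reproduces the same masks for both.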
+ if not self.ops_to_be_recomputed:
348
+ return
349
+ with _reset_rng_states_scope(self.stored_device_rng_states), _reset_amp_states_scope(
350
+ self.stored_device_amp_states
351
+ ):
352
+ ops_reversed_queue = list(self.ops_to_be_recomputed)
353
+ ops_reversed_queue.reverse()
354
+ self.ops_to_be_recomputed.clear()
355
+ while ops_reversed_queue:
356
+ op = ops_reversed_queue.pop(-1)
357
+ op.recompute()
358
+ self.stored_device_rng_states.clear()
359
+ self.stored_device_amp_states.clear()
360
+
361
+
362
+ @dataclass
363
+ class _GraphOp:
364
+ graph: _Graph
365
+ func: Any
366
+ args: Optional[Sequence[Union[_GraphTensor, Any]]]
367
+ kwargs: Optional[Dict[str, Union[_GraphTensor, Any]]]
368
+ out_flat_num: int
369
+ recomputed_out_flat: Optional[Sequence[torch.Tensor]] = None
370
+
371
+ def recompute(self):
372
+ """recompute, assuming all args are recomputed."""
373
+ args = pytree.tree_map_only(_GraphTensor, _GraphTensor.get_recomputed, self.args)
374
+ kwargs = pytree.tree_map_only(_GraphTensor, _GraphTensor.get_recomputed, self.kwargs)
375
+ out = self.func(*args, **kwargs)
376
+ out_flat, _ = pytree.tree_flatten(out)
377
+ assert len(out_flat) == self.out_flat_num
378
+ self.recomputed_out_flat = out_flat
379
+ # Potentially free any referenced resources; we don't need them anymore.
380
+ self.args = None
381
+ self.kwargs = None
382
+ # self.func should be ok to keep; it should reference some low-level aten function.
383
+
384
+
385
+ @dataclass
386
+ class _GraphTensor:
387
+ op: _GraphOp
388
+ out_flat_idx: int
389
+
390
+ def get_recomputed(self) -> torch.Tensor:
391
+ """assuming it was recomputed, return the raw tensor."""
392
+ assert self.op.recomputed_out_flat is not None
393
+ return self.op.recomputed_out_flat[self.out_flat_idx]
394
+
395
+
396
+ @contextlib.contextmanager
397
+ def _reset_rng_states_scope(states: Dict[torch.device, Any]):
398
+ """
399
+ Reset RNG states scope.
400
+ Like torch.random.fork_rng but simpler.
401
+ """
402
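+ # Usage sketch: `with _reset_rng_states_scope(saved_states): ...` sets the given RNG states
+ # for the body and restores the previous states afterwards, also on exceptions.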
+ prev_states = {dev: _get_dev_rng_state(dev) for dev in states.keys()}
403
+ try:
404
+ for dev, state in states.items():
405
+ _set_dev_rng_state(dev, state)
406
+ yield
407
+ finally:
408
+ for dev, state in prev_states.items():
409
+ _set_dev_rng_state(dev, state)
410
+
411
+
412
+ def _get_dev_rng_state(dev: torch.device):
413
+ if dev.type == "cpu":
414
+ return torch.get_rng_state()
415
+ dev_mod = getattr(torch, dev.type)
416
+ return dev_mod.get_rng_state(dev)
417
+
418
+
419
+ def _set_dev_rng_state(dev: torch.device, state: Any):
420
+ if dev.type == "cpu":
421
+ torch.set_rng_state(state)
422
+ else:
423
+ dev_mod = getattr(torch, dev.type)
424
+ dev_mod.set_rng_state(state, dev)
425
+
426
+
427
+ @contextlib.contextmanager
428
+ def _reset_amp_states_scope(states: Dict[torch.device, Any]):
429
+ with contextlib.ExitStack() as stack:
430
+ for dev, state in states.items():
431
+ if not state:
432
+ continue
433
+ if dev.type == "cpu":
434
+ stack.enter_context(torch.cpu.amp.autocast(**state))
435
+ else:
436
+ device_module = getattr(torch, dev.type)
437
+ stack.enter_context(device_module.amp.autocast(**state))
438
+ yield
439
+
440
+
441
+ def _get_dev_amp_state(dev: torch.device):
442
+ if dev.type == "cpu":
443
+ if not torch.is_autocast_cpu_enabled():
444
+ return None
445
+ return {
446
+ "dtype": torch.get_autocast_cpu_dtype(),
447
+ "cache_enabled": torch.is_autocast_cache_enabled(),
448
+ }
449
+
450
+ if dev.type == "cuda":
451
+ if not torch.is_autocast_enabled():
452
+ return None
453
+ return {
454
+ "dtype": torch.get_autocast_gpu_dtype(),
455
+ "cache_enabled": torch.is_autocast_cache_enabled(),
456
+ }
457
+
458
+ device_module = getattr(torch, dev.type)
459
+ if hasattr(device_module, "is_autocast_enabled") and hasattr(device_module, "get_autocast_dtype"):
460
+ if not device_module.is_autocast_enabled():
461
+ return None
462
+ return {
463
+ "dtype": device_module.get_autocast_dtype(),
464
+ "cache_enabled": torch.is_autocast_cache_enabled(),
465
+ }
466
+
467
+ return None
468
+
469
+
470
+ class _DelHook:
471
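+ # Instances are attached as an attribute to raw tensors (see _pack_hook above),
+ # so the callback fires when such a tensor gets garbage collected.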
+ def __init__(self, callback):
472
+ self.callback = callback
473
+
474
+ def __del__(self):
475
+ self.callback()
476
+
477
+
478
+ class _WeakMethod:
479
+ # The type hint below is deliberately loose (also allows Callable) because mypy/PyCharm don't handle MethodType well.
480
+ def __init__(self, method: Union[MethodType, Callable], *, return_if_dead: Any = None):
481
+ assert isinstance(method, MethodType)
482
+ self.obj = ref(method.__self__)
483
+ self.func = method.__func__
484
+ self.return_if_dead = return_if_dead
485
+
486
+ def __call__(self, *args, **kwargs):
487
+ obj = self.obj()
488
+ if obj is None:
489
+ return self.return_if_dead
490
+ return self.func(obj, *args, **kwargs)
491
+
492
+
493
+ _orig_saved_tensors_hooks_enter = torch.autograd.graph.saved_tensors_hooks.__enter__
494
+ _orig_saved_tensors_hooks_exit = torch.autograd.graph.saved_tensors_hooks.__exit__
495
+ _custom_saved_tensors_hooks_tls_ctx = threading.local()
496
+ _custom_saved_tensors_hooks_lock = threading.Lock()  # only needed for non-thread-local state, i.e. the registered threads set and the patched methods
497
+ _custom_saved_tensors_hooks_registered_threads = WeakSet()
498
+
499
+
500
+ def _register_custom_saved_tensors_hooks(*, existing_scope: torch.autograd.graph.saved_tensors_hooks):
501
+ """
502
+ The purpose of our custom saved_tensors_hooks __enter__/__exit__ is to make sure that
503
+ the __exit__ calls happen in the correct (LIFO) order, even though our own scope may exit late.
504
+
505
+ See :func:`_custom_saved_tensors_hooks_enter` and :func:`_custom_saved_tensors_hooks_exit`.
506
+
507
+ There is no need to call :func:`_unregister_custom_saved_tensors_hooks` later.
508
+ It will be called automatically when the last scope is exited.
509
+ """
510
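+ # For illustration of the ordering problem this solves (hypothetical scopes s1, s2):
+ #     s1.__enter__(); s2.__enter__()   -> stack is [s1, s2]
+ #     s1.__exit__()                    -> s1 is only queued; s2 is still on top
+ #     s2.__exit__()                    -> now s2 and then s1 really exit, in LIFO order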
+ thread = threading.current_thread()
511
+ with _custom_saved_tensors_hooks_lock:
512
+ if thread in _custom_saved_tensors_hooks_registered_threads:
513
+ return
514
+ if getattr(_custom_saved_tensors_hooks_tls_ctx, "stack", None) is None:
515
+ _custom_saved_tensors_hooks_tls_ctx.stack = []
516
+ _custom_saved_tensors_hooks_tls_ctx.in_callback = False
517
+ _custom_saved_tensors_hooks_tls_ctx.callbacks = []
518
+ _custom_saved_tensors_hooks_tls_ctx.queued_exits = []
519
+ _custom_saved_tensors_hooks_tls_ctx.active = True
520
+ _custom_saved_tensors_hooks_tls_ctx.stack.append(existing_scope)
521
+ _custom_saved_tensors_hooks_registered_threads.add(thread)
522
+ if len(_custom_saved_tensors_hooks_registered_threads) == 1:
523
+ torch.autograd.graph.saved_tensors_hooks.__enter__ = _custom_saved_tensors_hooks_enter
524
+ torch.autograd.graph.saved_tensors_hooks.__exit__ = _custom_saved_tensors_hooks_exit
525
+
526
+
527
+ def _unregister_custom_saved_tensors_hooks():
528
+ thread = threading.current_thread()
529
+ with _custom_saved_tensors_hooks_lock:
530
+ assert thread in _custom_saved_tensors_hooks_registered_threads
531
+ assert (
532
+ not _custom_saved_tensors_hooks_tls_ctx.stack
533
+ and not _custom_saved_tensors_hooks_tls_ctx.callbacks
534
+ and not _custom_saved_tensors_hooks_tls_ctx.queued_exits
535
+ )
536
+ _custom_saved_tensors_hooks_tls_ctx.active = False
537
+ _custom_saved_tensors_hooks_registered_threads.remove(thread)
538
+ if not _custom_saved_tensors_hooks_registered_threads:
539
+ torch.autograd.graph.saved_tensors_hooks.__enter__ = _orig_saved_tensors_hooks_enter
540
+ torch.autograd.graph.saved_tensors_hooks.__exit__ = _orig_saved_tensors_hooks_exit
541
+
542
+
543
+ def _custom_saved_tensors_hooks_enter(self: torch.autograd.graph.saved_tensors_hooks):
544
+ _custom_saved_tensors_hooks_call_callbacks()
545
+ # The callbacks might have unregistered us. Only add to the stack if we are still active.
546
+ if _custom_saved_tensors_hooks_tls_ctx.active:
547
+ _custom_saved_tensors_hooks_tls_ctx.stack.append(self)
548
+ return _orig_saved_tensors_hooks_enter(self)
549
+
550
+
551
+ def _custom_saved_tensors_hooks_exit(self: torch.autograd.graph.saved_tensors_hooks, exc_type, exc_val, exc_tb):
552
+ if self not in _custom_saved_tensors_hooks_tls_ctx.stack:
553
+ raise Exception(
554
+ f"saved_tensors_hooks __exit__ mismatch."
555
+ f" stack {_custom_saved_tensors_hooks_tls_ctx.stack},"
556
+ f" queued_exits {_custom_saved_tensors_hooks_tls_ctx.queued_exits},"
557
+ f" got self {self}"
558
+ )
559
+ _custom_saved_tensors_hooks_tls_ctx.queued_exits.append(self)
560
+ _custom_saved_tensors_hooks_call_callbacks()
561
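+ # Really exit all scopes from the top of the stack whose __exit__ has already been requested.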
+ while _custom_saved_tensors_hooks_tls_ctx.stack:
562
+ scope = _custom_saved_tensors_hooks_tls_ctx.stack[-1]
563
+ if scope not in _custom_saved_tensors_hooks_tls_ctx.queued_exits:
564
+ # Need to wait for this scope to exit first.
565
+ # Once we exit it, we will then also exit the others as they reach the top.
566
+ break
567
+ _custom_saved_tensors_hooks_tls_ctx.stack.pop(-1)
568
+ _custom_saved_tensors_hooks_tls_ctx.queued_exits.remove(scope)
569
+ _orig_saved_tensors_hooks_exit(scope, exc_type, exc_val, exc_tb)
570
+ exc_type, exc_val, exc_tb = None, None, None # do not propagate this again (even though it's ignored anyway)
571
+ if not _custom_saved_tensors_hooks_tls_ctx.stack:
572
+ assert not _custom_saved_tensors_hooks_tls_ctx.queued_exits
573
+ if _custom_saved_tensors_hooks_tls_ctx.active: # might have been unregistered in the meantime by callbacks
574
+ _unregister_custom_saved_tensors_hooks()
575
+
576
+
577
+ def _register_custom_saved_tensors_hooks_thread_local_callback(cb: Callable[[], bool]):
578
+ """
579
+ Register some thread-local callback function which is called on saved_tensors_hooks __enter__ and __exit__.
580
+ If it returns True, it is kept alive, otherwise removed.
581
+ """
582
+ _custom_saved_tensors_hooks_tls_ctx.callbacks.append(cb)
583
+
584
+
585
+ def _custom_saved_tensors_hooks_call_callbacks():
586
+ if _custom_saved_tensors_hooks_tls_ctx.in_callback:
587
+ return # avoid recursive calls
588
+ try:
589
+ _custom_saved_tensors_hooks_tls_ctx.in_callback = True
590
+ _custom_saved_tensors_hooks_tls_ctx.callbacks = [
591
+ cb for cb in _custom_saved_tensors_hooks_tls_ctx.callbacks if cb()
592
+ ]
593
+ finally:
594
+ _custom_saved_tensors_hooks_tls_ctx.in_callback = False