returnn-1.20241020.5643.tar.gz → returnn-1.20241022.224754.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

Files changed (468)
  1. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/LICENSE +7 -0
  2. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/PKG-INFO +1 -1
  3. returnn-1.20241022.224754/_setup_info_generated.py +2 -0
  4. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/audio/mel.py +4 -1
  5. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/decoder/transformer.py +2 -6
  6. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/dtype.py +35 -1
  7. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/utils.py +3 -0
  8. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/data/extern_data.py +8 -1
  9. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/engine.py +54 -18
  10. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/frontend/_backend.py +10 -1
  11. returnn-1.20241022.224754/returnn/torch/optim/README.md +5 -0
  12. returnn-1.20241022.224754/returnn/torch/optim/__init__.py +3 -0
  13. returnn-1.20241022.224754/returnn/torch/optim/lion.py +205 -0
  14. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/updater.py +29 -17
  15. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/exception_helper.py +4 -1
  16. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn.egg-info/PKG-INFO +1 -1
  17. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn.egg-info/SOURCES.txt +3 -0
  18. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_base.py +28 -0
  19. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_torch_engine.py +49 -0
  20. returnn-1.20241020.5643/_setup_info_generated.py +0 -2
  21. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/.editorconfig +0 -0
  22. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/.gitignore +0 -0
  23. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/.gitmodules +0 -0
  24. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/.kateconfig +0 -0
  25. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/CHANGELOG.md +0 -0
  26. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/CODEOWNERS +0 -0
  27. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/CONTRIBUTING.md +0 -0
  28. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/MANIFEST.in +0 -0
  29. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/README.rst +0 -0
  30. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/__init__.py +0 -0
  31. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/12AX.cluster_map +0 -0
  32. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/_setup_returnn_env.py +0 -0
  33. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-fwd.config +0 -0
  34. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-horovod-mpi.py +0 -0
  35. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-horovod-mpi.py.sh +0 -0
  36. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-horovod-mpi.sh +0 -0
  37. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-hyper-param-tuning.config +0 -0
  38. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-iter-dataset.py +0 -0
  39. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-list-devices.py +0 -0
  40. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-lua-torch-layer.config +0 -0
  41. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-pretrain.config +0 -0
  42. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-record-and-push-to-webserver.py +0 -0
  43. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-returnn-as-framework.py +0 -0
  44. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-rf-pt-benchmark.py +0 -0
  45. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-rf.config +0 -0
  46. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-rhn-enwik8.config +0 -0
  47. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-sprint-interface.py +0 -0
  48. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-att-copy.config +0 -0
  49. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-attention.config +0 -0
  50. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
  51. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
  52. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-enc-dec.config +0 -0
  53. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-hard-att-copy.config +0 -0
  54. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-lstm-benchmark.py +0 -0
  55. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
  56. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
  57. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-native-lstm.12ax.config +0 -0
  58. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-native-lstm2.12ax.config +0 -0
  59. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
  60. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-neural-transducer.12ax.config +0 -0
  61. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-rec-explicit-lstm.config +0 -0
  62. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-rec-explicit-rnn.config +0 -0
  63. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-rec-self-att.config +0 -0
  64. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-search-compiled-graph.py +0 -0
  65. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
  66. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-timit-lstm-ctc.config +0 -0
  67. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-torch.config +0 -0
  68. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
  69. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/demo.sh +0 -0
  70. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
  71. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
  72. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
  73. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/README.txt +0 -0
  74. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/chars.txt +0 -0
  75. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/config_demo +0 -0
  76. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/config_fwd +0 -0
  77. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/config_real +0 -0
  78. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
  79. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/decode.py +0 -0
  80. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
  81. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/go.sh +0 -0
  82. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/lines.txt +0 -0
  83. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/split/eval.txt +0 -0
  84. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/split/train.txt +0 -0
  85. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/IAM/split/valid.txt +0 -0
  86. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/README.md +0 -0
  87. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial/create_test_h5.py +0 -0
  88. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial/forwardconfig +0 -0
  89. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial/go.sh +0 -0
  90. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial/trainconfig +0 -0
  91. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
  92. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
  93. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial_rgb/go.sh +0 -0
  94. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
  95. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/pyproject.toml +0 -0
  96. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/requirements.txt +0 -0
  97. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/__init__.py +0 -0
  98. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/__main__.py +0 -0
  99. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/__old_mod_loader__.py +0 -0
  100. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/__setup__.py +0 -0
  101. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/config.py +0 -0
  102. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/__init__.py +0 -0
  103. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/audio.py +0 -0
  104. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/basic.py +0 -0
  105. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/bundle_file.py +0 -0
  106. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/cached.py +0 -0
  107. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/cached2.py +0 -0
  108. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/distrib_files.py +0 -0
  109. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/generating.py +0 -0
  110. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/hdf.py +0 -0
  111. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/lm.py +0 -0
  112. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/map.py +0 -0
  113. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/meta.py +0 -0
  114. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/multi_proc.py +0 -0
  115. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/normalization_data.py +0 -0
  116. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/numpy_dump.py +0 -0
  117. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/postprocessing.py +0 -0
  118. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/raw_wav.py +0 -0
  119. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/sprint.py +0 -0
  120. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/stereo.py +0 -0
  121. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/util/__init__.py +0 -0
  122. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/util/feature_extraction.py +0 -0
  123. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/util/strings.py +0 -0
  124. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/datasets/util/vocabulary.py +0 -0
  125. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/engine/__init__.py +0 -0
  126. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/engine/base.py +0 -0
  127. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/engine/batch.py +0 -0
  128. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/__init__.py +0 -0
  129. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/__main__.py +0 -0
  130. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/.git +0 -0
  131. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
  132. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
  133. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
  134. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
  135. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
  136. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
  137. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
  138. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
  139. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
  140. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
  141. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
  142. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
  143. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
  144. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
  145. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
  146. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
  147. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
  148. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
  149. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
  150. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
  151. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
  152. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
  153. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
  154. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
  155. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/__init__.py +0 -0
  156. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/README.md +0 -0
  157. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/__init__.py +0 -0
  158. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/edit.py +0 -0
  159. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/reroute.py +0 -0
  160. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/select.py +0 -0
  161. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/subgraph.py +0 -0
  162. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/transform.py +0 -0
  163. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/extern/graph_editor/util.py +0 -0
  164. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/forward_iface.py +0 -0
  165. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/__init__.py +0 -0
  166. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_backend.py +0 -0
  167. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/__init__.py +0 -0
  168. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/backend.cpp +0 -0
  169. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/backend.hpp +0 -0
  170. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/module.cpp +0 -0
  171. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/module.hpp +0 -0
  172. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/py_utils.hpp +0 -0
  173. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/tensor_ops.cpp +0 -0
  174. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_native/tensor_ops.hpp +0 -0
  175. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_numpy_backend.py +0 -0
  176. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_random_journal.py +0 -0
  177. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/_utils.py +0 -0
  178. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/array_.py +0 -0
  179. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/attention.py +0 -0
  180. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/audio/__init__.py +0 -0
  181. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/audio/specaugment.py +0 -0
  182. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/backend.py +0 -0
  183. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/build_from_dict.py +0 -0
  184. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/cond.py +0 -0
  185. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/const.py +0 -0
  186. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/container.py +0 -0
  187. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/control_flow_ctx.py +0 -0
  188. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/conv.py +0 -0
  189. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/conversions/__init__.py +0 -0
  190. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/conversions/espnet_e_branchformer.py +0 -0
  191. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/conversions/hf_llama.py +0 -0
  192. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/conversions/torch_nn.py +0 -0
  193. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/decoder/__init__.py +0 -0
  194. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/device.py +0 -0
  195. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/dims.py +0 -0
  196. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/dropout.py +0 -0
  197. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/encoder/__init__.py +0 -0
  198. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/encoder/base.py +0 -0
  199. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/encoder/conformer.py +0 -0
  200. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/encoder/e_branchformer.py +0 -0
  201. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/encoder/transformer.py +0 -0
  202. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/gradient.py +0 -0
  203. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/graph.py +0 -0
  204. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/hooks.py +0 -0
  205. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/init.py +0 -0
  206. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/label_smoothing.py +0 -0
  207. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/linear.py +0 -0
  208. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/loop.py +0 -0
  209. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/loss.py +0 -0
  210. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/math_.py +0 -0
  211. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/matmul.py +0 -0
  212. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/module.py +0 -0
  213. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/normalization.py +0 -0
  214. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/parameter.py +0 -0
  215. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/parametrizations.py +0 -0
  216. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/parametrize.py +0 -0
  217. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/piecewise_linear.py +0 -0
  218. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/rand.py +0 -0
  219. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/rec.py +0 -0
  220. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/reduce.py +0 -0
  221. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/run_ctx.py +0 -0
  222. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/signal.py +0 -0
  223. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/state.py +0 -0
  224. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/stepwise_scheduler.py +0 -0
  225. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/tensor_array.py +0 -0
  226. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/frontend/types.py +0 -0
  227. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/import_/__init__.py +0 -0
  228. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/import_/common.py +0 -0
  229. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/import_/git.py +0 -0
  230. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/import_/import_.py +0 -0
  231. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/learning_rate_control.py +0 -0
  232. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/log.py +0 -0
  233. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/native_op.cpp +0 -0
  234. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/native_op.py +0 -0
  235. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/pretrain.py +0 -0
  236. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/sprint/__init__.py +0 -0
  237. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/sprint/cache.py +0 -0
  238. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/sprint/control.py +0 -0
  239. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/sprint/error_signals.py +0 -0
  240. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/sprint/extern_interface.py +0 -0
  241. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/sprint/interface.py +0 -0
  242. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/README.md +0 -0
  243. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/__init__.py +0 -0
  244. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/_dim_extra.py +0 -0
  245. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/_tensor_extra.py +0 -0
  246. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/_tensor_mixin_base.py +0 -0
  247. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/_tensor_op_overloads.py +0 -0
  248. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/control_flow_ctx.py +0 -0
  249. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/dim.py +0 -0
  250. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/marked_dim.py +0 -0
  251. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/tensor.py +0 -0
  252. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tensor/tensor_dict.py +0 -0
  253. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/__init__.py +0 -0
  254. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/compat.py +0 -0
  255. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/data_pipeline.py +0 -0
  256. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/distributed.py +0 -0
  257. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/engine.py +0 -0
  258. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/README.md +0 -0
  259. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/__init__.py +0 -0
  260. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/_backend.py +0 -0
  261. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/_utils.py +0 -0
  262. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/cond.py +0 -0
  263. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
  264. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
  265. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/dims.py +0 -0
  266. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/layer.py +0 -0
  267. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/loop.py +0 -0
  268. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/make_layer.py +0 -0
  269. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/masked_computation.py +0 -0
  270. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/parameter_assign.py +0 -0
  271. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
  272. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_low_level/__init__.py +0 -0
  273. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/frontend_low_level/_backend.py +0 -0
  274. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/horovod.py +0 -0
  275. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/hyper_param_tuning.py +0 -0
  276. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/layers/__init__.py +0 -0
  277. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/layers/base.py +0 -0
  278. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/layers/basic.py +0 -0
  279. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/layers/rec.py +0 -0
  280. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/layers/segmental_model.py +0 -0
  281. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/layers/signal_processing.py +0 -0
  282. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/layers/variable.py +0 -0
  283. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/native_op.py +0 -0
  284. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/network.py +0 -0
  285. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/sprint.py +0 -0
  286. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/updater.py +0 -0
  287. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/util/__init__.py +0 -0
  288. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/util/basic.py +0 -0
  289. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/util/data.py +0 -0
  290. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/util/gradient_checkpoint.py +0 -0
  291. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/util/ken_lm.py +0 -0
  292. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/tf/util/open_fst.py +0 -0
  293. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/README.md +0 -0
  294. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/__init__.py +0 -0
  295. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/data/__init__.py +0 -0
  296. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/data/pipeline.py +0 -0
  297. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/data/queued_data_iter.py +0 -0
  298. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
  299. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/data/tensor_utils.py +0 -0
  300. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/distributed.py +0 -0
  301. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/frontend/__init__.py +0 -0
  302. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/frontend/_rand.py +0 -0
  303. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/frontend/bridge.py +0 -0
  304. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/frontend/raw_ops.py +0 -0
  305. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/README.md +0 -0
  306. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/__init__.py +0 -0
  307. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/array_.py +0 -0
  308. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/diagnose_gpu.py +0 -0
  309. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/gradient_checkpoint.py +0 -0
  310. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/module.py +0 -0
  311. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/torch/util/scaled_gradient.py +0 -0
  312. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/__init__.py +0 -0
  313. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/basic.py +0 -0
  314. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/better_exchook.py +0 -0
  315. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/bpe.py +0 -0
  316. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/debug.py +0 -0
  317. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/debug_helpers.py +0 -0
  318. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/file_cache.py +0 -0
  319. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/fsa.py +0 -0
  320. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/literal_py_to_pickle.py +0 -0
  321. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/math.py +0 -0
  322. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/multi_proc_non_daemonic_spawn.py +0 -0
  323. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/native_code_compiler.py +0 -0
  324. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/pprint.py +0 -0
  325. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/py-to-pickle.cpp +0 -0
  326. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/py_compat.py +0 -0
  327. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/py_ext_mod_compiler.py +0 -0
  328. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/result_with_reason.py +0 -0
  329. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/sig_proc.py +0 -0
  330. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/task_system.py +0 -0
  331. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/train_proc_manager.py +0 -0
  332. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn/util/watch_memory.py +0 -0
  333. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn.egg-info/dependency_links.txt +0 -0
  334. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/returnn.egg-info/top_level.txt +0 -0
  335. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/rnn.py +0 -0
  336. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/setup.cfg +0 -0
  337. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/setup.py +0 -0
  338. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/DummySprintExec.py +0 -0
  339. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm-inspection-profile.xml +0 -0
  340. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/.gitignore +0 -0
  341. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/.name +0 -0
  342. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
  343. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
  344. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
  345. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
  346. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
  347. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/misc.xml +0 -0
  348. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/modules.xml +0 -0
  349. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/returnn.iml +0 -0
  350. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
  351. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/_set_num_threads1.py +0 -0
  352. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/_setup_returnn_env.py +0 -0
  353. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/_setup_test_env.py +0 -0
  354. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/bpe-unicode-demo.codes +0 -0
  355. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/bpe-unicode-demo.vocab +0 -0
  356. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/lexicon_opt.fst +0 -0
  357. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/lexicon_opt.isyms +0 -0
  358. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/lexicon_opt.jpg +0 -0
  359. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/lexicon_opt.osyms +0 -0
  360. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/lint_common.py +0 -0
  361. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/pycharm-inspect.py +0 -0
  362. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/pylint.py +0 -0
  363. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/returnn-as-framework.py +0 -0
  364. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/rf_utils.py +0 -0
  365. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/spelling.dic +0 -0
  366. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_Config.py +0 -0
  367. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_Dataset.py +0 -0
  368. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_Fsa.py +0 -0
  369. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_GeneratingDataset.py +0 -0
  370. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_HDFDataset.py +0 -0
  371. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_LearningRateControl.py +0 -0
  372. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_Log.py +0 -0
  373. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_MultiProcDataset.py +0 -0
  374. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_Pretrain.py +0 -0
  375. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_ResNet.py +0 -0
  376. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_SprintDataset.py +0 -0
  377. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_SprintInterface.py +0 -0
  378. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TFEngine.py +0 -0
  379. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TFNativeOp.py +0 -0
  380. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TFNetworkLayer.py +0 -0
  381. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TFNetworkRecLayer.py +0 -0
  382. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TFNetworkSigProcLayer.py +0 -0
  383. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TFUpdater.py +0 -0
  384. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TFUtil.py +0 -0
  385. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TF_determinism.py +0 -0
  386. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TaskSystem.py +0 -0
  387. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TaskSystem_SharedMem.py +0 -0
  388. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_TranslationDataset.py +0 -0
  389. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_Util.py +0 -0
  390. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_demos.py +0 -0
  391. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_fork_exec.py +0 -0
  392. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_hdf_dump.py +0 -0
  393. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_array.py +0 -0
  394. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_attention.py +0 -0
  395. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_cond.py +0 -0
  396. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_const.py +0 -0
  397. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_container.py +0 -0
  398. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_conv.py +0 -0
  399. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_decoder_transformer.py +0 -0
  400. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_encoder_conformer.py +0 -0
  401. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_gradient.py +0 -0
  402. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_label_smoothing.py +0 -0
  403. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_loop.py +0 -0
  404. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_math.py +0 -0
  405. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_normalization.py +0 -0
  406. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_piecewise_linear.py +0 -0
  407. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_rec.py +0 -0
  408. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_reduce.py +0 -0
  409. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_rf_signal.py +0 -0
  410. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_tensor.py +0 -0
  411. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_tools.py +0 -0
  412. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_torch_dataset.py +0 -0
  413. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_torch_frontend.py +0 -0
  414. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_torch_internal_frontend.py +0 -0
  415. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/test_torch_util.py +0 -0
  416. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tests/torch_utils.py +0 -0
  417. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/_setup_returnn_env.py +0 -0
  418. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/analyze-dataset-batches.py +0 -0
  419. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/bliss-collect-seq-lens.py +0 -0
  420. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/bliss-dump-text.py +0 -0
  421. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/bliss-get-segment-names.py +0 -0
  422. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/bliss-to-ogg-zip.py +0 -0
  423. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/bpe-create-lexicon.py +0 -0
  424. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/calculate-word-error-rate.py +0 -0
  425. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/cleanup-old-models.py +0 -0
  426. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/collect-orth-symbols.py +0 -0
  427. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/collect-words.py +0 -0
  428. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/compile_native_op.py +0 -0
  429. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/compile_tf_graph.py +0 -0
  430. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/debug-dump-search-scores.py +0 -0
  431. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/debug-plot-search-scores.py +0 -0
  432. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/dump-dataset-raw-strings.py +0 -0
  433. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/dump-dataset.py +0 -0
  434. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/dump-forward-stats.py +0 -0
  435. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/dump-forward.py +0 -0
  436. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/dump-network-json.py +0 -0
  437. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/dump-pickle.py +0 -0
  438. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/extract_state_tying_from_dataset.py +0 -0
  439. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/get-attention-weights.py +0 -0
  440. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/get-best-model-epoch.py +0 -0
  441. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/hdf_dump.py +0 -0
  442. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/hdf_dump_translation_dataset.py +0 -0
  443. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/import-blocks-mt-model.py +0 -0
  444. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/import-t2t-mt-model.py +0 -0
  445. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/.gitignore +0 -0
  446. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/Makefile +0 -0
  447. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/README.md +0 -0
  448. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/example/README.md +0 -0
  449. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/example/libs_list +0 -0
  450. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
  451. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
  452. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
  453. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/example/state_vars_list +0 -0
  454. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/example/tensor_names_list +0 -0
  455. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/file.h +0 -0
  456. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
  457. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
  458. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/main.cc +0 -0
  459. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/rescorer.h +0 -0
  460. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/vocabulary.cc +0 -0
  461. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/lattice_rescorer/vocabulary.h +0 -0
  462. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/tf_avg_checkpoints.py +0 -0
  463. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/tf_inspect_checkpoint.py +0 -0
  464. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/tf_inspect_summary_log.py +0 -0
  465. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/torch_avg_checkpoints.py +0 -0
  466. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/torch_export_to_onnx.py +0 -0
  467. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/torch_inspect_checkpoint.py +0 -0
  468. {returnn-1.20241020.5643 → returnn-1.20241022.224754}/tools/torch_inspect_checkpoint_and_opt.py +0 -0
@@ -176,4 +176,11 @@ Contains code from PyTorch:
176
176
  - Copyright 2016..2023 various developers
177
177
  - Various code snippets
178
178
 
179
+ Contains code from lion-pytorch:
180
+ - https://github.com/google/automl/blob/master/lion/lion_pytorch.py
181
+ - https://github.com/lucidrains/lion-pytorch/
182
+ - MIT License / Apache License 2.0
183
+ - Copyright 2023 Google Research, Phil Wang
184
+ - torch/optim/lion.py
185
+
179
186
  Various code snippets from StackOverflow, which are under Creative Commons / Public Domain.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20241020.5643
3
+ Version: 1.20241022.224754
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -0,0 +1,2 @@
1
+ version = '1.20241022.224754'
2
+ long_version = '1.20241022.224754+git.b5db365'
@@ -56,8 +56,8 @@ def mel_filterbank(
56
56
  filter_bank_matrix_np = _mel_filter_bank_matrix_np(
57
57
  f_min=f_min, f_max=f_max, sampling_rate=sampling_rate, fft_size=fft_length, nr_of_filters=out_dim.dimension
58
58
  )
59
- filter_bank_matrix_np = filter_bank_matrix_np.astype(x.dtype)
60
59
  filter_bank_matrix = rf.convert_to_tensor(filter_bank_matrix_np, dims=(in_dim, out_dim), _backend=backend)
60
+ filter_bank_matrix = rf.cast(filter_bank_matrix, dtype=x.dtype)
61
61
  filter_bank_matrix = rf.copy_to_device(filter_bank_matrix, x.device)
62
62
  if backend.executing_eagerly():
63
63
  if len(_mel_filter_bank_matrix_cache) > 100:
@@ -191,6 +191,9 @@ def log_mel_filterbank_from_raw(
191
191
  fft_length=n_fft,
192
192
  )
193
193
  power_spectrogram = rf.abs(spectrogram) ** 2.0
194
+ # stft might have upcasted this to float32 because some PyTorch versions don't support stft on bfloat16.
195
+ # https://github.com/pytorch/pytorch/issues/117844
196
+ power_spectrogram = rf.cast(power_spectrogram, dtype=raw_audio.dtype)
194
197
  mel_fbank = mel_filterbank(power_spectrogram, in_dim=in_dim_, out_dim=out_dim, sampling_rate=sampling_rate)
195
198
  log_mel_fbank = rf.safe_log(mel_fbank, eps=1e-10)
196
199
  if log_base != math.e:
@@ -101,15 +101,11 @@ class TransformerDecoder(rf.Module):
101
101
  if pos_enc is None:
102
102
  pass
103
103
  elif isinstance(pos_enc, dict):
104
- pos_enc = rf.build_from_dict(
105
- pos_enc, feat_dim=embed_dim or model_dim, dtype=self.input_embedding.weight.dtype
106
- )
104
+ pos_enc = rf.build_from_dict(pos_enc, feat_dim=embed_dim or model_dim)
107
105
  elif isinstance(pos_enc, rf.Module):
108
106
  pass
109
107
  elif isinstance(pos_enc, FunctionType):
110
- pos_enc = functools.partial(
111
- pos_enc, feat_dim=embed_dim or model_dim, dtype=self.input_embedding.weight.dtype
112
- )
108
+ pos_enc = functools.partial(pos_enc, feat_dim=embed_dim or model_dim)
113
109
  else:
114
110
  raise TypeError(f"unexpected pos_enc type {pos_enc!r}")
115
111
  self.pos_enc = pos_enc
@@ -3,9 +3,17 @@ DType helpers
3
3
  """
4
4
 
5
5
  from __future__ import annotations
6
+ from contextlib import contextmanager
6
7
 
7
8
 
8
- __all__ = ["get_default_float_dtype", "get_default_int_dtype", "get_default_array_index_dtype", "is_float_dtype"]
9
+ __all__ = [
10
+ "get_default_float_dtype",
11
+ "set_default_float_dtype",
12
+ "set_default_float_dtype_ctx",
13
+ "get_default_int_dtype",
14
+ "get_default_array_index_dtype",
15
+ "is_float_dtype",
16
+ ]
9
17
 
10
18
 
11
19
  _default_float_dtype: str = "float32"
@@ -21,6 +29,32 @@ def get_default_float_dtype() -> str:
21
29
  return _default_float_dtype
22
30
 
23
31
 
32
+ def set_default_float_dtype(dtype: str):
33
+ """
34
+ Set the default float dtype
35
+
36
+ :param dtype: the new default float dtype
37
+ """
38
+ global _default_float_dtype
39
+ assert isinstance(dtype, str)
40
+ _default_float_dtype = dtype
41
+
42
+
43
+ @contextmanager
44
+ def set_default_float_dtype_ctx(dtype: str):
45
+ """
46
+ :param dtype: see :func:`get_default_float_dtype`
47
+ """
48
+ global _default_float_dtype
49
+ assert isinstance(dtype, str)
50
+ old_default_float_dtype = _default_float_dtype
51
+ try:
52
+ _default_float_dtype = dtype
53
+ yield
54
+ finally:
55
+ _default_float_dtype = old_default_float_dtype
56
+
57
+
24
58
  def get_default_int_dtype() -> str:
25
59
  """
26
60
  https://data-apis.org/array-api/latest/API_specification/data_types.html#default-data-types
@@ -104,6 +104,9 @@ def tensor_fill_random_numpy_(
104
104
  x.raw_tensor = rnd.randint(0, 2, size=shape, dtype=x.dtype)
105
105
  elif x.dtype.startswith("float"):
106
106
  x.raw_tensor = rnd.normal(0.0, 1.0, size=shape).astype(x.dtype)
107
+ elif x.dtype == "bfloat16":
108
+ # Numpy does not support bfloat16, will later be casted to bfloat16
109
+ x.raw_tensor = rnd.normal(0.0, 1.0, size=shape).astype("float32")
107
110
  elif x.dtype.startswith("complex"):
108
111
  real = rnd.normal(0.0, 1.0, size=shape)
109
112
  imag = rnd.normal(0.0, 1.0, size=shape)
@@ -3,7 +3,7 @@ From raw dict to extern_data tensor dict.
3
3
  """
4
4
 
5
5
  from __future__ import annotations
6
- from typing import Any, Union, Dict, List, Sequence
6
+ from typing import Optional, Any, Union, Dict, List, Sequence
7
7
  import numpy
8
8
  import torch
9
9
  from returnn.tensor import Tensor, TensorDict, Dim
@@ -27,13 +27,18 @@ def raw_dict_to_extern_data(
27
27
  *,
28
28
  extern_data_template: TensorDict,
29
29
  device: Union[str, torch.device],
30
+ float_dtype: Optional[Union[str, torch.dtype]] = None,
30
31
  ) -> TensorDict:
31
32
  """
32
33
  :param extern_data_raw: This comes out of the DataLoader, via our collate_batch.
33
34
  :param extern_data_template: Specified via `extern_data` in the config.
34
35
  :param device: E.g. the GPU.
36
+ :param float_dtype:
35
37
  :return: tensor dict, like extern_data_template, but with raw tensors set to Torch tensors, on the right device.
36
38
  """
39
+ if isinstance(float_dtype, str):
40
+ float_dtype = getattr(torch, float_dtype)
41
+ assert isinstance(float_dtype, torch.dtype)
37
42
  assert isinstance(extern_data_raw, dict) and extern_data_raw
38
43
  batch_dim = get_batch_dim_from_extern_data(extern_data_template)
39
44
  for dim in _get_dyn_dims_from_extern_data(extern_data_template):
@@ -51,6 +56,8 @@ def raw_dict_to_extern_data(
51
56
  dim.dimension == raw_tensor.shape[i]
52
57
  ), f"shape mismatch for {k}: {raw_tensor.shape} vs {data.batch_shape}"
53
58
  if isinstance(raw_tensor, torch.Tensor):
59
+ if raw_tensor.dtype.is_floating_point and float_dtype:
60
+ raw_tensor = raw_tensor.to(dtype=float_dtype)
54
61
  data.dtype = str(raw_tensor.dtype).split(".")[-1] # just overwrite for now...
55
62
  data.raw_tensor = raw_tensor.to(device)
56
63
  elif isinstance(raw_tensor, numpy.ndarray):
@@ -4,7 +4,7 @@ Main engine for PyTorch
4
4
 
5
5
  from __future__ import annotations
6
6
  from typing import Optional, Any, Union, Callable, Dict, Set
7
- from contextlib import nullcontext
7
+ from contextlib import nullcontext, ExitStack, contextmanager
8
8
 
9
9
  import gc
10
10
  import os
@@ -129,6 +129,13 @@ class Engine(EngineBase):
129
129
  self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
130
130
  self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)
131
131
 
132
+ default_float_dtype = config.value("default_float_dtype", None)
133
+ if default_float_dtype is not None:
134
+ assert isinstance(default_float_dtype, str)
135
+ default_float_dtype = getattr(torch, default_float_dtype)
136
+ assert isinstance(default_float_dtype, torch.dtype)
137
+ self._default_float_dtype: Optional[torch.dtype] = default_float_dtype
138
+
132
139
  amp_options = self.config.opt_typed_value("torch_amp")
133
140
  grad_scaler_opts = self.config.typed_value("grad_scaler", NotSpecified)
134
141
  if amp_options is not None:
@@ -380,7 +387,10 @@ class Engine(EngineBase):
380
387
  cur_count_grad_accum = 0
381
388
 
382
389
  extern_data = extern_data_util.raw_dict_to_extern_data(
383
- extern_data_raw, extern_data_template=self.extern_data, device=self._device
390
+ extern_data_raw,
391
+ extern_data_template=self.extern_data,
392
+ device=self._device,
393
+ float_dtype=self._default_float_dtype,
384
394
  )
385
395
  self._run_step(extern_data, train_flag=True, train_func=True)
386
396
 
@@ -389,7 +399,7 @@ class Engine(EngineBase):
389
399
  losses_dict = NumbersDict(
390
400
  {
391
401
  name: (
392
- float(loss.get_summed_loss().raw_tensor.detach().cpu().numpy())
402
+ float(loss.get_summed_loss().raw_tensor.detach().cpu().item())
393
403
  if self._device != "meta"
394
404
  else float("nan")
395
405
  )
@@ -553,7 +563,10 @@ class Engine(EngineBase):
553
563
  torch.distributed.broadcast(_has_data, src=0)
554
564
 
555
565
  extern_data = extern_data_util.raw_dict_to_extern_data(
556
- extern_data_raw, extern_data_template=self.extern_data, device=self._device
566
+ extern_data_raw,
567
+ extern_data_template=self.extern_data,
568
+ device=self._device,
569
+ float_dtype=self._default_float_dtype,
557
570
  )
558
571
 
559
572
  self._run_step(extern_data, train_func=True)
@@ -566,7 +579,7 @@ class Engine(EngineBase):
566
579
  losses_dict = NumbersDict(
567
580
  {
568
581
  name: (
569
- float(loss.get_summed_loss().raw_tensor.detach().cpu().numpy())
582
+ float(loss.get_summed_loss().raw_tensor.detach().cpu().item())
570
583
  if self._device != "meta"
571
584
  else float("nan")
572
585
  )
@@ -686,6 +699,26 @@ class Engine(EngineBase):
686
699
 
687
700
  return data_loader
688
701
 
702
+ @contextmanager
703
+ def _run_ctx_mgr(self):
704
+ with ExitStack() as stack:
705
+ if self._use_autocast:
706
+ stack.enter_context(autocast(device_type=self._device.split(":")[0], dtype=self._autocast_dtype))
707
+ stack.enter_context(rf.set_default_device_ctx(self._device))
708
+ if self._default_float_dtype:
709
+ stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
710
+ stack.enter_context(self._set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
711
+ yield
712
+
713
+ @contextmanager
714
+ def _set_torch_default_dtype_ctx_mgr(self, dtype: torch.dtype):
715
+ old_dtype = torch.get_default_dtype()
716
+ try:
717
+ torch.set_default_dtype(dtype)
718
+ yield
719
+ finally:
720
+ torch.set_default_dtype(old_dtype)
721
+
689
722
  def _run_step(
690
723
  self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool, _inside_wrapped: bool = False
691
724
  ):
@@ -706,11 +739,7 @@ class Engine(EngineBase):
706
739
  expected_outputs=self._forward_step_expected_outputs, step=self.global_train_step, epoch=self.epoch
707
740
  )
708
741
 
709
- with (
710
- autocast(device_type=self._device.split(":")[0], dtype=self._autocast_dtype)
711
- if self._use_autocast
712
- else nullcontext()
713
- ), rf.set_default_device_ctx(self._device):
742
+ with self._run_ctx_mgr():
714
743
  sentinel_kw = util.get_fwd_compat_kwargs()
715
744
  if train_func:
716
745
  self._train_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw)
@@ -893,6 +922,8 @@ class Engine(EngineBase):
893
922
  )
894
923
  )
895
924
 
925
+ if self._default_float_dtype:
926
+ self._pt_model.to(dtype=self._default_float_dtype)
896
927
  self._pt_model.to(self._device)
897
928
 
898
929
  if model_epoch_filename and is_training:
@@ -906,11 +937,7 @@ class Engine(EngineBase):
906
937
 
907
938
  load_model_post_hooks = self.config.typed_value("load_model_post_hooks")
908
939
  if load_model_post_hooks:
909
- with (
910
- autocast(device_type=self._device.split(":")[0], dtype=self._autocast_dtype)
911
- if self._use_autocast
912
- else nullcontext()
913
- ), rf.set_default_device_ctx(self._device):
940
+ with self._run_ctx_mgr():
914
941
  sentinel_kw = util.get_fwd_compat_kwargs()
915
942
  for hook in load_model_post_hooks:
916
943
  hook(model=self._orig_model, **sentinel_kw)
@@ -1090,7 +1117,10 @@ class Engine(EngineBase):
1090
1117
  # Currently, this callback interface is intended to also be used by other backends,
1091
1118
  # and then the user can always assume Numpy arrays.
1092
1119
  if isinstance(raw, torch.Tensor): # might already be numpy array
1093
- raw = raw.detach().cpu().numpy()
1120
+ raw = raw.detach().cpu()
1121
+ if raw.dtype == torch.bfloat16:
1122
+ raw = raw.float()
1123
+ raw = raw.numpy()
1094
1124
  y.raw_tensor = raw
1095
1125
  return y
1096
1126
 
@@ -1120,7 +1150,10 @@ class Engine(EngineBase):
1120
1150
  # Also resets any dyn dims, which might have been set in the prev step.
1121
1151
  self._forward_step_expected_outputs.reset_content()
1122
1152
  extern_data = extern_data_util.raw_dict_to_extern_data(
1123
- extern_data_raw, extern_data_template=self.extern_data, device=self._device
1153
+ extern_data_raw,
1154
+ extern_data_template=self.extern_data,
1155
+ device=self._device,
1156
+ float_dtype=self._default_float_dtype,
1124
1157
  )
1125
1158
  try:
1126
1159
  self._run_step(extern_data, train_func=False)
@@ -1224,7 +1257,10 @@ def _to_raw(n: Union[int, float, Tensor]):
1224
1257
  if isinstance(n, (int, float)):
1225
1258
  return n
1226
1259
  if isinstance(n, Tensor):
1227
- return n.raw_tensor.detach().cpu().numpy()
1260
+ x = n.raw_tensor.detach().cpu()
1261
+ if x.dtype == torch.bfloat16:
1262
+ x = x.float()
1263
+ return x.numpy()
1228
1264
  raise TypeError(f"Unexpected {n} of type {type(n)}")
1229
1265
 
1230
1266
 
@@ -676,6 +676,9 @@ class TorchBackend(Backend[torch.Tensor]):
676
676
  if len(batch_dims) != 1:
677
677
  targets_raw = torch.reshape(targets_raw, (batch_n_elems, targets_raw.shape[-1])) # [B', S]
678
678
  targets_lengths = torch.reshape(targets_lengths, (batch_n_elems,)) # [B']
679
+ if log_probs.dtype == torch.bfloat16:
680
+ # Currently (PyTorch 2.5), ctc_loss does not support bfloat16.
681
+ log_probs = log_probs.to(torch.float32)
679
682
  loss_raw = torch.nn.functional.ctc_loss(
680
683
  log_probs=log_probs,
681
684
  targets=targets_raw,
@@ -691,7 +694,7 @@ class TorchBackend(Backend[torch.Tensor]):
691
694
  name="ctc_loss",
692
695
  dims=batch_dims,
693
696
  raw_tensor=loss_raw,
694
- dtype=logits.dtype,
697
+ dtype=TorchBackend.get_dtype_name_raw(loss_raw),
695
698
  )
696
699
  return loss
697
700
 
@@ -2039,6 +2042,12 @@ class TorchBackend(Backend[torch.Tensor]):
2039
2042
  pad_right = fft_length - frame_length - pad_left
2040
2043
  window_pt = torch.nn.functional.pad(window_pt, (pad_left, pad_right))
2041
2044
 
2045
+ orig_dtype = x_raw.dtype
2046
+ if orig_dtype == torch.bfloat16:
2047
+ # PyTorch stft does not support bfloat16 currently (PyTorch 2.5):
2048
+ # https://github.com/pytorch/pytorch/issues/117844
2049
+ # (Check back later here whether that's still the case...)
2050
+ x_raw = x_raw.to(torch.float32)
2042
2051
  y_raw = torch.stft(
2043
2052
  x_raw,
2044
2053
  n_fft=fft_length,
@@ -0,0 +1,5 @@
1
+ Here we can put some arbitrary external optimizers.
2
+ It might be copied from some existing code, or our own implementation.
3
+ It might also happen that some of these will be added to later versions of PyTorch.
4
+ So, regarding the user config, the optimizers here should be differentiated
5
+ by having the full module name, e.g. like ``returnn.torch.optim.lion.Lion``.
@@ -0,0 +1,3 @@
1
+ """
2
+ Any custom optimizer
3
+ """
@@ -0,0 +1,205 @@
1
+ """
2
+ Lion optimizer <https://arxiv.org/abs/2302.06675>
3
+
4
+ Code adapted from https://github.com/lucidrains/lion-pytorch/,
5
+ which is adapted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py.
6
+ """
7
+
8
+ from __future__ import annotations
9
+ from typing import Optional, Tuple, Callable
10
+ import inspect
11
+ import torch
12
+ from torch.optim.optimizer import Optimizer
13
+
14
+
15
class Lion(Optimizer):
    """
    Lion ("EvoLved Sign Momentum") optimizer <https://arxiv.org/abs/2302.06675>
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: Optional[bool] = None,
        decoupled_weight_decay: bool = False,
    ):
        """
        :param params: iterable of parameters to optimize, or dicts defining parameter groups
        :param lr: learning rate
        :param betas: (beta1, beta2). beta1 interpolates between momentum and current grad
            for the update direction, beta2 is the decay rate of the momentum running average.
        :param weight_decay: weight decay factor (multiplicative shrinking of the params per step)
        :param use_triton: whether to use the Triton CUDA kernel.
            By default (None), use it iff the Triton kernel is available (Triton importable).
        :param decoupled_weight_decay: if True, divide the weight decay by the initial lr in each step,
            so that later lr changes do not rescale the effective decay
        """
        assert lr > 0.0
        assert all(0.0 <= beta <= 1.0 for beta in betas)

        self._init_lr = lr
        self.decoupled_wd = decoupled_weight_decay

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)

        super().__init__(params, defaults)

        if use_triton is None:
            # triton_update_fn is None if Triton is not importable (or too old).
            use_triton = bool(triton_update_fn)
        self.use_triton = use_triton

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """
        Perform a single optimization step.

        :param closure: optional closure that reevaluates the model and returns the loss
        :return: the loss returned by the closure, if one was given, else None
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            wd = group["weight_decay"]

            # Maybe decoupled weight decay: make the effective decay independent of the initial lr.
            if self.decoupled_wd:
                wd /= self._init_lr

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]

                # Init state lazily - exponential moving average of gradient values.
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)
                exp_avg = state["exp_avg"]

                # The Triton kernel only handles CUDA tensors; otherwise use the pure PyTorch impl.
                if self.use_triton and p.is_cuda:
                    triton_update_fn(p, p.grad, exp_avg, lr, wd, beta1, beta2)
                else:
                    update_fn(p, p.grad, exp_avg, lr, wd, beta1, beta2)

        return loss


# update functions


def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    """
    Lion update function (pure PyTorch), modifying ``p`` and ``exp_avg`` in place.
    """
    # Stepweight decay: shrink the params multiplicatively.
    p.data.mul_(1.0 - lr * wd)

    # Weight update: sign of the beta1-interpolation between momentum and current grad.
    update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_()
    p.add_(update, alpha=-lr)

    # Decay the momentum running average coefficient.
    exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)
105
+
106
+
107
# Triton is optional: if it cannot be imported, triton_update_fn below is set to None
# and the Lion optimizer falls back to the pure PyTorch update_fn.
try:
    # noinspection PyPackageRequirements
    import triton

    # noinspection PyPackageRequirements
    import triton.language as tl
except ImportError as e:
    triton = None
    tl = None


# restore_value is not available in older versions of triton
if triton and "restore_value" in inspect.signature(triton.autotune).parameters:
    # triton cuda kernel

    # noinspection PyPep8Naming,PyArgumentList
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
            triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        ],
        key=["n_elements"],
        # restore_value: the autotuner benchmarks both configs and must undo the
        # in-place updates to params and momentum between trial runs.
        restore_value=["p_ptr", "exp_avg_ptr"],
    )
    @triton.jit
    def _triton_update_fn_kernel(
        p_ptr,
        grad_ptr,
        exp_avg_ptr,
        lr,
        wd,
        beta1,
        beta2,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance handles one contiguous block of BLOCK_SIZE elements.
        pid = tl.program_id(axis=0)

        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)

        # Mask out-of-range lanes in the last (partial) block.
        mask = offsets < n_elements

        # offsetted pointers

        offset_p_ptr = p_ptr + offsets
        offset_grad_ptr = grad_ptr + offsets
        offset_exp_avg_ptr = exp_avg_ptr + offsets

        # load

        p = tl.load(offset_p_ptr, mask=mask)
        grad = tl.load(offset_grad_ptr, mask=mask)
        exp_avg = tl.load(offset_exp_avg_ptr, mask=mask)

        # stepweight decay

        p = p * (1 - lr * wd)

        # diff between momentum running average and grad

        diff = exp_avg - grad

        # weight update
        # diff * beta1 + grad == beta1 * exp_avg + (1 - beta1) * grad

        update = diff * beta1 + grad

        # torch.sign
        # update_sign * can_update == -lr * sign(update), including 0 where update == 0.

        can_update = update != 0
        update_sign = tl.where(update > 0, -lr, lr)

        p = p + update_sign * can_update

        # decay the momentum running average coefficient
        # diff * beta2 + grad == beta2 * exp_avg + (1 - beta2) * grad

        exp_avg = diff * beta2 + grad

        # store new params and momentum running average coefficient

        tl.store(offset_p_ptr, p, mask=mask)
        tl.store(offset_exp_avg_ptr, exp_avg, mask=mask)

    def triton_update_fn(
        p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, lr: float, wd: float, beta1: float, beta2: float
    ):
        """
        Lion update function using triton kernel.
        Modifies ``p`` and ``exp_avg`` in place; all tensors must be on CUDA.
        """
        assert all([t.is_cuda for t in (p, grad, exp_avg)])
        n_elements = p.numel()

        # 1D launch grid; BLOCK_SIZE is chosen by the autotuner from the configs above.
        def _grid(meta):
            return tuple((triton.cdiv(n_elements, meta["BLOCK_SIZE"]),))

        _triton_update_fn_kernel[_grid](p, grad, exp_avg, lr, wd, beta1, beta2, n_elements)

else:
    # Triton not available (or too old): Lion.__init__ uses this as the "not available" signal.
    triton_update_fn = None
@@ -5,7 +5,7 @@ and model param update logic in general.
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- from typing import Optional, Union, Any, Type, Sequence, Iterable, Set, Dict, List, Tuple
8
+ from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable, Set, Dict, List, Tuple
9
9
  import os
10
10
  import gc
11
11
  import torch
@@ -38,28 +38,40 @@ def _init_optimizer_classes_dict():
38
38
  _OptimizerClassesDict[name.lower()] = cls
39
39
 
40
40
 
41
- def get_optimizer_class(class_name) -> Type[torch.optim.Optimizer]:
41
def get_optimizer_class(
    class_name: Union[str, Type[torch.optim.Optimizer], Callable[[], Type[torch.optim.Optimizer]]]
) -> Type[torch.optim.Optimizer]:
    """
    :param class_name: Optimizer class, either as str (e.g. "adam"), as type (torch.optim.Adam) or callable.
        If str, we support all torch.optim optimizers (ignoring case) (e.g. "adam"),
        or class names with full module path (e.g. "returnn.torch.optim.lion.Lion").
    :return: Optimizer class, e.g. torch.optim.Adam
    """
    # Make sure the case-insensitive name -> class mapping of torch.optim optimizers is built.
    _init_optimizer_classes_dict()
    if isinstance(class_name, type):
        assert issubclass(class_name, torch.optim.Optimizer)
        return class_name
    elif callable(class_name):
        # Checked after the type case above, because a type is also callable.
        # The callable is expected to return the optimizer class.
        return class_name()
    elif isinstance(class_name, str):
        if "." in class_name:
            # Full module path, e.g. "returnn.torch.optim.lion.Lion": import and resolve dynamically.
            import importlib

            mod_name, class_name_ = class_name.rsplit(".", 1)
            mod = importlib.import_module(mod_name)
            return getattr(mod, class_name_)

        # Plain name: case-insensitive lookup among the collected torch.optim classes.
        if class_name.lower() not in _OptimizerClassesDict:
            raise ValueError(
                "Optimizer %r not found in the available torch optimizers list: %s."
                % (
                    class_name.lower(),
                    ", ".join("'%s'" % key for key in _OptimizerClassesDict),
                )
            )
        return _OptimizerClassesDict[class_name.lower()]
    else:
        raise TypeError(f"Invalid optimizer class_name {class_name!r} type {type(class_name).__name__}")
63
75
 
64
76
 
65
77
  def _get_class_init_kwargs(optim_class):
@@ -411,7 +423,7 @@ class Updater:
411
423
  # If the user specified it as epsilon, parse it as eps for the optimizer
412
424
  if "eps" in optim_class_init_kwargs and "epsilon" in opt_kwargs:
413
425
  opt_kwargs["eps"] = opt_kwargs.pop("epsilon")
414
- if "learning_rate" in opt_kwargs:
426
+ if "learning_rate" in opt_kwargs or "lr" in opt_kwargs:
415
427
  raise ValueError("'learning_rate' should be set outside of the 'optimizer' dict.")
416
428
  lr = lr * opt_kwargs.pop("learning_rate_multiplier", 1.0)
417
429
  opt_kwargs["lr"] = lr
@@ -86,7 +86,10 @@ def _help_data_or_array(
86
86
  :return: (info,(min,max))
87
87
  """
88
88
  if isinstance(value, torch.Tensor):
89
- value = value.detach().cpu().numpy()
89
+ value = value.detach().cpu()
90
+ if value.dtype == torch.bfloat16:
91
+ value = value.float()
92
+ value = value.numpy()
90
93
  v_minmax = -1, -1
91
94
  if isinstance(value, np.ndarray):
92
95
  info = "shape %s, dtype %s" % (value.shape, value.dtype)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20241020.5643
3
+ Version: 1.20241022.224754
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer