returnn 1.20230403.124714__tar.gz → 1.20230403.211148__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. See the release's page on the package registry for more details.

Files changed (360)
  1. {returnn-1.20230403.124714/returnn.egg-info → returnn-1.20230403.211148}/PKG-INFO +1 -1
  2. returnn-1.20230403.211148/_setup_info_generated.py +2 -0
  3. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/run_ctx.py +32 -20
  4. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_tensor_extra.py +8 -6
  5. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/config_entry_points.py +2 -2
  6. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/data/pipeline.py +8 -6
  7. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/engine.py +53 -0
  8. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/frontend/_rand.py +2 -1
  9. {returnn-1.20230403.124714 → returnn-1.20230403.211148/returnn.egg-info}/PKG-INFO +1 -1
  10. returnn-1.20230403.124714/_setup_info_generated.py +0 -2
  11. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.editorconfig +0 -0
  12. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.gitignore +0 -0
  13. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.gitmodules +0 -0
  14. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.kateconfig +0 -0
  15. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/CHANGELOG.md +0 -0
  16. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/CODEOWNERS +0 -0
  17. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/CONTRIBUTING.md +0 -0
  18. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/LICENSE +0 -0
  19. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/MANIFEST.in +0 -0
  20. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/README.rst +0 -0
  21. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/__init__.py +0 -0
  22. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/12AX.cluster_map +0 -0
  23. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/_setup_returnn_env.py +0 -0
  24. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-fwd.config +0 -0
  25. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-horovod-mpi.py +0 -0
  26. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-horovod-mpi.py.sh +0 -0
  27. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-horovod-mpi.sh +0 -0
  28. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-hyper-param-tuning.config +0 -0
  29. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-iter-dataset.py +0 -0
  30. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-list-devices.py +0 -0
  31. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-lua-torch-layer.config +0 -0
  32. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-pretrain.config +0 -0
  33. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-record-and-push-to-webserver.py +0 -0
  34. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-returnn-as-framework.py +0 -0
  35. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-rhn-enwik8.config +0 -0
  36. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-sprint-interface.py +0 -0
  37. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-att-copy.config +0 -0
  38. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-attention.config +0 -0
  39. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
  40. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
  41. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-enc-dec.config +0 -0
  42. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-hard-att-copy.config +0 -0
  43. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-lstm-benchmark.py +0 -0
  44. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
  45. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
  46. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm.12ax.config +0 -0
  47. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm2.12ax.config +0 -0
  48. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
  49. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-neural-transducer.12ax.config +0 -0
  50. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-explicit-lstm.config +0 -0
  51. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-explicit-rnn.config +0 -0
  52. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-self-att.config +0 -0
  53. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-search-compiled-graph.py +0 -0
  54. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
  55. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-timit-lstm-ctc.config +0 -0
  56. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-torch.config +0 -0
  57. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
  58. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo.sh +0 -0
  59. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
  60. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
  61. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
  62. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/README.txt +0 -0
  63. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/chars.txt +0 -0
  64. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/config_demo +0 -0
  65. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/config_fwd +0 -0
  66. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/config_real +0 -0
  67. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
  68. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/decode.py +0 -0
  69. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
  70. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/go.sh +0 -0
  71. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/lines.txt +0 -0
  72. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/split/eval.txt +0 -0
  73. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/split/train.txt +0 -0
  74. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/split/valid.txt +0 -0
  75. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/README.md +0 -0
  76. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/create_test_h5.py +0 -0
  77. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/forwardconfig +0 -0
  78. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/go.sh +0 -0
  79. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/trainconfig +0 -0
  80. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
  81. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
  82. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/go.sh +0 -0
  83. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
  84. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/pyproject.toml +0 -0
  85. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/requirements.txt +0 -0
  86. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__init__.py +0 -0
  87. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__main__.py +0 -0
  88. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__old_mod_loader__.py +0 -0
  89. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__setup__.py +0 -0
  90. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/config.py +0 -0
  91. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/__init__.py +0 -0
  92. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/audio.py +0 -0
  93. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/basic.py +0 -0
  94. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/bundle_file.py +0 -0
  95. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/cached.py +0 -0
  96. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/cached2.py +0 -0
  97. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/generating.py +0 -0
  98. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/hdf.py +0 -0
  99. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/lm.py +0 -0
  100. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/map.py +0 -0
  101. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/meta.py +0 -0
  102. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/multi_proc.py +0 -0
  103. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/normalization_data.py +0 -0
  104. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/numpy_dump.py +0 -0
  105. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/raw_wav.py +0 -0
  106. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/sprint.py +0 -0
  107. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/stereo.py +0 -0
  108. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/util/__init__.py +0 -0
  109. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/util/feature_extraction.py +0 -0
  110. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/util/vocabulary.py +0 -0
  111. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/engine/__init__.py +0 -0
  112. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/engine/base.py +0 -0
  113. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/engine/batch.py +0 -0
  114. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/__init__.py +0 -0
  115. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/__main__.py +0 -0
  116. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/.git +0 -0
  117. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
  118. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
  119. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
  120. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
  121. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
  122. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
  123. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
  124. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
  125. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
  126. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
  127. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
  128. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
  129. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
  130. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
  131. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
  132. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
  133. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
  134. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
  135. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
  136. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
  137. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
  138. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
  139. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
  140. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
  141. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/__init__.py +0 -0
  142. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/README.md +0 -0
  143. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/__init__.py +0 -0
  144. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/edit.py +0 -0
  145. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/reroute.py +0 -0
  146. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/select.py +0 -0
  147. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/subgraph.py +0 -0
  148. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/transform.py +0 -0
  149. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/util.py +0 -0
  150. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/__init__.py +0 -0
  151. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/_backend.py +0 -0
  152. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/_numpy_backend.py +0 -0
  153. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/_utils.py +0 -0
  154. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/array_.py +0 -0
  155. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/const.py +0 -0
  156. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/dims.py +0 -0
  157. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/dtype.py +0 -0
  158. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/init.py +0 -0
  159. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/linear.py +0 -0
  160. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/math_.py +0 -0
  161. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/matmul.py +0 -0
  162. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/module.py +0 -0
  163. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/parameter.py +0 -0
  164. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/rand.py +0 -0
  165. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/reduce.py +0 -0
  166. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/state.py +0 -0
  167. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/types.py +0 -0
  168. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/__init__.py +0 -0
  169. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/common.py +0 -0
  170. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/git.py +0 -0
  171. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/import_.py +0 -0
  172. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/learning_rate_control.py +0 -0
  173. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/log.py +0 -0
  174. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/native_op.cpp +0 -0
  175. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/native_op.py +0 -0
  176. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/pretrain.py +0 -0
  177. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/__init__.py +0 -0
  178. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/cache.py +0 -0
  179. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/control.py +0 -0
  180. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/error_signals.py +0 -0
  181. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/extern_interface.py +0 -0
  182. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/interface.py +0 -0
  183. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/README.md +0 -0
  184. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/__init__.py +0 -0
  185. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_dim_extra.py +0 -0
  186. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_tensor_mixin_base.py +0 -0
  187. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_tensor_op_overloads.py +0 -0
  188. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/control_flow_ctx.py +0 -0
  189. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/dim.py +0 -0
  190. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/marked_dim.py +0 -0
  191. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/tensor.py +0 -0
  192. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/tensor_dict.py +0 -0
  193. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/__init__.py +0 -0
  194. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/compat.py +0 -0
  195. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/data_pipeline.py +0 -0
  196. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/distributed.py +0 -0
  197. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/engine.py +0 -0
  198. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/__init__.py +0 -0
  199. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/_backend.py +0 -0
  200. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/_utils.py +0 -0
  201. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
  202. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/dims.py +0 -0
  203. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/layer.py +0 -0
  204. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/make_layer.py +0 -0
  205. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
  206. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_low_level/__init__.py +0 -0
  207. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_low_level/_backend.py +0 -0
  208. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/horovod.py +0 -0
  209. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/hyper_param_tuning.py +0 -0
  210. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/__init__.py +0 -0
  211. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/base.py +0 -0
  212. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/basic.py +0 -0
  213. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/rec.py +0 -0
  214. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/segmental_model.py +0 -0
  215. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/signal_processing.py +0 -0
  216. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/native_op.py +0 -0
  217. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/network.py +0 -0
  218. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/sprint.py +0 -0
  219. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/updater.py +0 -0
  220. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/__init__.py +0 -0
  221. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/basic.py +0 -0
  222. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/data.py +0 -0
  223. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/ken_lm.py +0 -0
  224. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/open_fst.py +0 -0
  225. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/README.md +0 -0
  226. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/__init__.py +0 -0
  227. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/data/__init__.py +0 -0
  228. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
  229. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/frontend/__init__.py +0 -0
  230. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/frontend/_backend.py +0 -0
  231. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/functional/README.md +0 -0
  232. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/functional/__init__.py +0 -0
  233. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/updater.py +0 -0
  234. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/__init__.py +0 -0
  235. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/basic.py +0 -0
  236. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/better_exchook.py +0 -0
  237. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/bpe.py +0 -0
  238. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/debug.py +0 -0
  239. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/debug_helpers.py +0 -0
  240. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/fsa.py +0 -0
  241. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/literal_py_to_pickle.py +0 -0
  242. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/pprint.py +0 -0
  243. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/py-to-pickle.cpp +0 -0
  244. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/sig_proc.py +0 -0
  245. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/task_system.py +0 -0
  246. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn.egg-info/SOURCES.txt +0 -0
  247. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn.egg-info/dependency_links.txt +0 -0
  248. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn.egg-info/top_level.txt +0 -0
  249. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/rnn.py +0 -0
  250. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/setup.cfg +0 -0
  251. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/setup.py +0 -0
  252. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/DummySprintExec.py +0 -0
  253. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm-inspection-profile.xml +0 -0
  254. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/.gitignore +0 -0
  255. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/.name +0 -0
  256. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
  257. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
  258. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
  259. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
  260. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
  261. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/misc.xml +0 -0
  262. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/modules.xml +0 -0
  263. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/returnn.iml +0 -0
  264. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
  265. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/_set_num_threads1.py +0 -0
  266. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/_setup_returnn_env.py +0 -0
  267. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/_setup_test_env.py +0 -0
  268. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/bpe-unicode-demo.codes +0 -0
  269. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/bpe-unicode-demo.vocab +0 -0
  270. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.fst +0 -0
  271. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.isyms +0 -0
  272. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.jpg +0 -0
  273. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.osyms +0 -0
  274. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lint_common.py +0 -0
  275. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/pycharm-inspect.py +0 -0
  276. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/pylint.py +0 -0
  277. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/returnn-as-framework.py +0 -0
  278. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/rf_utils.py +0 -0
  279. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/spelling.dic +0 -0
  280. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Config.py +0 -0
  281. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Dataset.py +0 -0
  282. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Fsa.py +0 -0
  283. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_GeneratingDataset.py +0 -0
  284. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_HDFDataset.py +0 -0
  285. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_LearningRateControl.py +0 -0
  286. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Log.py +0 -0
  287. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_MultiProcDataset.py +0 -0
  288. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_PTDataset.py +0 -0
  289. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Pretrain.py +0 -0
  290. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_ResNet.py +0 -0
  291. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_SprintDataset.py +0 -0
  292. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_SprintInterface.py +0 -0
  293. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFEngine.py +0 -0
  294. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNativeOp.py +0 -0
  295. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNetworkLayer.py +0 -0
  296. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNetworkRecLayer.py +0 -0
  297. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNetworkSigProcLayer.py +0 -0
  298. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFUpdater.py +0 -0
  299. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFUtil.py +0 -0
  300. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TF_determinism.py +0 -0
  301. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TaskSystem.py +0 -0
  302. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TaskSystem_SharedMem.py +0 -0
  303. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TranslationDataset.py +0 -0
  304. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Util.py +0 -0
  305. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_demos.py +0 -0
  306. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_fork_exec.py +0 -0
  307. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_hdf_dump.py +0 -0
  308. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_rf_base.py +0 -0
  309. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_tensor.py +0 -0
  310. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_tools.py +0 -0
  311. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_torch_frontend.py +0 -0
  312. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_torch_internal_frontend.py +0 -0
  313. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/_setup_returnn_env.py +0 -0
  314. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/analyze-dataset-batches.py +0 -0
  315. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-collect-seq-lens.py +0 -0
  316. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-dump-text.py +0 -0
  317. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-get-segment-names.py +0 -0
  318. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-to-ogg-zip.py +0 -0
  319. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bpe-create-lexicon.py +0 -0
  320. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/calculate-word-error-rate.py +0 -0
  321. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/cleanup-old-models.py +0 -0
  322. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/collect-orth-symbols.py +0 -0
  323. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/collect-words.py +0 -0
  324. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/compile_native_op.py +0 -0
  325. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/compile_tf_graph.py +0 -0
  326. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/debug-dump-search-scores.py +0 -0
  327. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/debug-plot-search-scores.py +0 -0
  328. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-dataset-raw-strings.py +0 -0
  329. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-dataset.py +0 -0
  330. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-forward-stats.py +0 -0
  331. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-forward.py +0 -0
  332. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-network-json.py +0 -0
  333. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-pickle.py +0 -0
  334. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/extract_state_tying_from_dataset.py +0 -0
  335. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/get-attention-weights.py +0 -0
  336. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/get-best-model-epoch.py +0 -0
  337. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/hdf_dump.py +0 -0
  338. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/hdf_dump_translation_dataset.py +0 -0
  339. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/import-blocks-mt-model.py +0 -0
  340. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/import-t2t-mt-model.py +0 -0
  341. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/.gitignore +0 -0
  342. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/Makefile +0 -0
  343. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/README.md +0 -0
  344. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/README.md +0 -0
  345. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/libs_list +0 -0
  346. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
  347. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
  348. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
  349. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/state_vars_list +0 -0
  350. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/tensor_names_list +0 -0
  351. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/file.h +0 -0
  352. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
  353. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
  354. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/main.cc +0 -0
  355. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/rescorer.h +0 -0
  356. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/vocabulary.cc +0 -0
  357. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/vocabulary.h +0 -0
  358. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/tf_avg_checkpoints.py +0 -0
  359. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/tf_inspect_checkpoint.py +0 -0
  360. {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/tf_inspect_summary_log.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20230403.124714
3
+ Version: 1.20230403.211148
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -0,0 +1,2 @@
1
+ version = '1.20230403.211148'
2
+ long_version = '1.20230403.211148+git.aa0221f'
@@ -9,11 +9,11 @@ or forwarding loop.
9
9
  from __future__ import annotations
10
10
  from typing import Optional, Union, Any, Sequence, Dict
11
11
  from dataclasses import dataclass
12
- from returnn.tensor import Tensor, Dim
12
+ from returnn.tensor import Tensor, Dim, TensorDict
13
13
  import returnn.frontend as rf
14
14
 
15
15
 
16
- __all__ = ["RunCtx", "Loss", "Output", "get_run_ctx", "init_train_step_run_ctx", "init_forward_step_run_ctx"]
16
+ __all__ = ["RunCtx", "Loss", "get_run_ctx", "init_train_step_run_ctx", "init_forward_step_run_ctx"]
17
17
 
18
18
 
19
19
  _run_ctx = None # type: Optional[RunCtx]
@@ -75,7 +75,7 @@ class RunCtx:
75
75
  """
76
76
  self.stage = stage
77
77
  self.losses = {} # type: Dict[str, Loss]
78
- self.outputs = {} # type: Dict[str, Output]
78
+ self.outputs = TensorDict()
79
79
 
80
80
  def mark_as_loss(
81
81
  self,
@@ -141,7 +141,7 @@ class RunCtx:
141
141
  custom_inv_norm_factor=custom_inv_norm_factor,
142
142
  )
143
143
 
144
- def mark_as_output(self, tensor: Union[Tensor, Any], name: str, *, shape: Optional[Sequence[int]] = None) -> None:
144
+ def mark_as_output(self, tensor: Union[Tensor, Any], name: str, *, dims: Optional[Sequence[int]] = None) -> None:
145
145
  """
146
146
  Mark this as an output.
147
147
  This has the effect that RETURNN will in any case construct the corresponding layer.
@@ -153,7 +153,7 @@ class RunCtx:
153
153
 
154
154
  :param tensor:
155
155
  :param name:
156
- :param shape: this specifies the order of the dims of the output, such that it is well-defined
156
+ :param dims: this specifies the order of the dims of the output, such that it is well-defined
157
157
  for some external application.
158
158
  If not specified, we try to infer BTF or BF as default, if that works, otherwise it will be an error.
159
159
  """
@@ -161,7 +161,32 @@ class RunCtx:
161
161
  if not isinstance(tensor, Tensor):
162
162
  tensor = rf.convert_to_tensor(tensor)
163
163
  assert name not in self.outputs
164
- self.outputs[name] = Output(tensor=tensor, name=name, shape=shape)
164
+ if dims is None:
165
+ rem_dims = list(tensor.dims)
166
+ dims = []
167
+ if tensor.have_batch_axis():
168
+ rem_dims.remove(tensor.get_batch_dim_tag())
169
+ dims.append(tensor.get_batch_dim_tag())
170
+ if tensor.have_time_axis():
171
+ rem_dims.remove(tensor.get_time_dim_tag())
172
+ dims.append(tensor.get_time_dim_tag())
173
+ static_dims = [d for d in dims if d.is_static()]
174
+ if len(static_dims) > 1:
175
+ raise Exception(
176
+ f"Cannot infer order of dims automatically for output {name!r}. Please specify a shape explicitly."
177
+ )
178
+ elif len(static_dims) == 1:
179
+ rem_dims.remove(static_dims[0])
180
+ dims.insert(0, static_dims[0])
181
+ if len(rem_dims) > 1:
182
+ raise Exception(
183
+ f"Cannot infer order of dims automatically for output {name!r}. Please specify a shape explicitly."
184
+ )
185
+ elif len(rem_dims) == 1:
186
+ dims.append(rem_dims[0])
187
+ tensor = tensor.copy_transpose(dims, allow_int=False)
188
+ tensor = tensor.copy(name=name)
189
+ self.outputs.data[name] = tensor
165
190
 
166
191
  def mark_as_default_output(self, tensor: Union[Tensor, Any], *, shape: Optional[Sequence[Dim]] = None) -> None:
167
192
  """
@@ -173,7 +198,7 @@ class RunCtx:
173
198
  :param tensor:
174
199
  :param shape:
175
200
  """
176
- self.mark_as_output(tensor, "output", shape=shape)
201
+ self.mark_as_output(tensor, "output", dims=shape)
177
202
 
178
203
  def total_loss(self) -> Union[Tensor, float]:
179
204
  """
@@ -233,16 +258,3 @@ class Loss:
233
258
  else:
234
259
  loss = self.get_summed_loss()
235
260
  return loss * self.scale
236
-
237
-
238
- @dataclass
239
- class Output:
240
- """
241
- Output via :func:`RunCtx.mark_as_output`.
242
-
243
- We collect all relevant information here.
244
- """
245
-
246
- tensor: Tensor
247
- name: str
248
- shape: Optional[Sequence[Dim]] = None
@@ -651,14 +651,16 @@ class _TensorMixin(_TensorMixinBase):
651
651
  assert self.time_dim_axis is not None
652
652
  return self.copy_move_axis(self.time_dim_axis, time_dim_axis)
653
653
 
654
- def copy_transpose(self, perm) -> _t.Tensor:
654
+ def copy_transpose(self, perm: Sequence[Union[int, Dim]], *, allow_int: bool = True) -> _t.Tensor:
655
655
  """
656
- :param list[int] perm: permutation of the axes, counted with batch-dim.
657
- Maps the new axes to the old axes
656
+ :param perm: permutation of the axes. Maps the new axes to the old axes
657
+ :param allow_int: allow int as axis, otherwise only :class:`Dim`
658
658
  :return: copy of myself with permuted axes
659
659
  """
660
- assert len(perm) == self.batch_ndim
661
- assert set(perm) == set(range(self.batch_ndim))
660
+ assert len(perm) == self.batch_ndim, f"{self}: invalid perm {perm!r} length"
661
+ perm_ = perm
662
+ perm = [self.get_axis_from_description(a, allow_int=allow_int) for a in perm]
663
+ assert set(perm) == set(range(self.batch_ndim)), f"{self}: invalid perm {perm_!r} (axes: {perm!r})"
662
664
  if all(perm[axis] == axis for axis in range(self.batch_ndim)):
663
665
  return self.copy()
664
666
 
@@ -3156,7 +3158,7 @@ class _TensorMixin(_TensorMixinBase):
3156
3158
  """
3157
3159
  import returnn.frontend as rf
3158
3160
 
3159
- rf.get_run_ctx().mark_as_output(self, name=name, shape=shape)
3161
+ rf.get_run_ctx().mark_as_output(self, name=name, dims=shape)
3160
3162
 
3161
3163
  def mark_as_default_output(self: Tensor, *, shape: Optional[Sequence[Dim]] = None) -> None:
3162
3164
  """
@@ -123,12 +123,12 @@ def get_net_dict(
123
123
  # Note that this logic might change.
124
124
  root_scope.marked_losses.append(loss_t)
125
125
 
126
- for out in rf.get_run_ctx().outputs.values():
126
+ for out in rf.get_run_ctx().outputs.data.values():
127
127
  if out.name == "output" and out.name not in root_scope.children:
128
128
  layer = root_scope.get_child(out.name)
129
129
  else:
130
130
  layer = root_scope.get_new_child(suggested_name=out.name)
131
- out_t = _utils.copy(out.tensor, name=layer)
131
+ out_t = _utils.copy(out, name=layer)
132
132
  if layer.name != "output":
133
133
  out_t.raw_tensor.layer_dict["is_output_layer"] = True
134
134
  root_scope.marked_outputs.append(out_t)
@@ -18,17 +18,19 @@ However, having this separate pure PyTorch implementation is useful to allow to
18
18
  other PyTorch datasets more directly, including also HuggingFace datasets.
19
19
  """
20
20
 
21
+ from __future__ import annotations
22
+ from typing import List, Dict
21
23
  import sys
22
24
  from copy import deepcopy
23
25
 
24
- import numpy as np
26
+ import numpy
25
27
  import torch
26
28
  import torch.utils.data
27
29
 
28
30
  from returnn.util.basic import NumbersDict
29
31
 
30
32
 
31
- def create_tensor(array: np.ndarray) -> torch.Tensor:
33
+ def create_tensor(array: numpy.ndarray) -> torch.Tensor:
32
34
  """
33
35
  Adjust non-supported dtypes
34
36
 
@@ -36,14 +38,14 @@ def create_tensor(array: np.ndarray) -> torch.Tensor:
36
38
  """
37
39
  # The only supported PyTorch dtypes are:
38
40
  # float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
39
- if array.dtype == np.uint32:
40
- array = np.asarray(array, dtype=np.int64)
41
+ if array.dtype == numpy.uint32:
42
+ array = numpy.asarray(array, dtype=numpy.int64)
41
43
  return torch.tensor(array)
42
44
 
43
45
 
44
- def collate_batch(batch):
46
+ def collate_batch(batch: List[Dict[str, numpy.ndarray]]) -> Dict[str, torch.Tensor]:
45
47
  """
46
- :param list[dict[str, numpy.ndarray]] batch:
48
+ :param batch:
47
49
  """
48
50
  assert isinstance(batch, list)
49
51
  assert batch, "batch is empty?"
@@ -294,6 +294,59 @@ class Engine(EngineBase):
294
294
  print("Load model %s" % (filename,), file=log.v4)
295
295
  model_state = torch.load(filename)
296
296
  self._model.load_state_dict(model_state)
297
+ preload_from_files = self.config.typed_value("preload_from_files", {})
298
+ if preload_from_files:
299
+ # see `preload_from_files` in tf engine and `returnn.tf.network.CustomCheckpointLoader`
300
+ is_training = self.config.value("task", "train") == "train"
301
+ is_first_train_epoch = epoch == 1 and (
302
+ is_training or self.config.value("task", "train") == "initialize_model"
303
+ )
304
+ # We use the reversed sorted order here to achieve consistent behavior with the TF engine.
305
+ # There, the keys are used in sorted order but if a variable is loaded,
306
+ # it will not be considered anymore afterwards.
307
+ # So the first occurrence is used.
308
+ # Here, we overwrite variables even if they have been loaded before.
309
+ # In order to get consistent behavior, we use the reversed order.
310
+ for preload_key, opts in reversed(sorted(preload_from_files.items())):
311
+ assert isinstance(opts, dict) and "filename" in opts
312
+ if opts.get("init_for_train", False):
313
+ if not is_first_train_epoch:
314
+ continue
315
+ else: # default: init for recog
316
+ if is_training:
317
+ continue
318
+ print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
319
+ preload_model_state = torch.load(opts["filename"])
320
+ if opts.get("checkpoint_key", None) is not None:
321
+ # This can be used if an external checkpoint saves a checkpoint a different structure that just the
322
+ # model state dict. E.g., if a checkpoint is created using
323
+ # `torch.save({"model": model.state_dict(), "optimizer": optimizer.state)_dict(), ...})`
324
+ # we can set checkpoint_key = "model" to load the model.
325
+ # Currently, this only supports single level dicts, but it could be extended if needed.
326
+ preload_model_state = preload_model_state[opts["checkpoint_key"]]
327
+ if opts.get("prefix", ""):
328
+ # Only params with this prefix should be loaded.
329
+ # They are expected to be in the checkpoint without this prefix.
330
+ # By adding the prefix to all params,
331
+ # we make sure that only params matching this condition are loaded.
332
+ # This is in line with the behavior of the TF engine.
333
+ preload_model_state = {opts["prefix"] + key: value for key, value in preload_model_state.items()}
334
+ ignore_params = opts.get("ignore_params", [])
335
+ ignore_params_prefixes = opts.get("ignore_params_prefixes", [])
336
+ for key in list(preload_model_state.keys()):
337
+ if key in ignore_params or any(
338
+ [key.startswith(ignore_key) for ignore_key in ignore_params_prefixes]
339
+ ):
340
+ print(f"Ignoring variable {key}", file=log.v3)
341
+ preload_model_state.pop(key)
342
+ for new_name, name_in_checkpoint in opts.get("var_name_mapping", {}).items():
343
+ preload_model_state[new_name] = preload_model_state.pop(name_in_checkpoint)
344
+ missing_keys, _ = self._model.load_state_dict(preload_model_state, strict=False)
345
+ if not opts.get("ignore_missing", False):
346
+ prefix_keys = [key for key in self._model.state_dict() if key.startswith(opts.get("prefix", ""))]
347
+ missing_prefix_keys = set(prefix_keys).intersection(set(missing_keys))
348
+ assert not missing_prefix_keys, f"Missing keys and ignore_missing=False: {missing_prefix_keys}"
349
+ print(f"Missing keys: {missing_keys}", file=log.v4)
297
350
 
298
351
  self._model.to(self._device)
299
352
 
@@ -10,7 +10,8 @@ import warnings
10
10
 
11
11
  def no_grad_trunc_normal_(tensor: torch.Tensor, mean, std, a, b, *, generator=None):
12
12
  """
13
- Code copied and adopted from torch.nn.init._no_grad_trunc_normal_.
13
+ Code copied and adopted from torch.nn.init._no_grad_trunc_normal_,
14
+ to support the extra `generator` argument (https://github.com/pytorch/pytorch/issues/98200).
14
15
 
15
16
  Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
16
17
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20230403.124714
3
+ Version: 1.20230403.211148
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +0,0 @@
1
- version = '1.20230403.124714'
2
- long_version = '1.20230403.124714+git.deb7a53'