returnn 1.20230408.155406__tar.gz → 1.20230409.122444__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (364)
  1. {returnn-1.20230408.155406/returnn.egg-info → returnn-1.20230409.122444}/PKG-INFO +1 -1
  2. returnn-1.20230409.122444/_setup_info_generated.py +2 -0
  3. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-rf.config +1 -0
  4. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-torch.config +1 -1
  5. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_backend.py +31 -1
  6. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/array_.py +11 -1
  7. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/const.py +3 -3
  8. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/run_ctx.py +29 -14
  9. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_dim_extra.py +35 -8
  10. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_extra.py +39 -1
  11. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/_backend.py +5 -5
  12. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_low_level/_backend.py +12 -0
  13. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/engine.py +50 -29
  14. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/_backend.py +38 -8
  15. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/bridge.py +3 -0
  16. {returnn-1.20230408.155406 → returnn-1.20230409.122444/returnn.egg-info}/PKG-INFO +1 -1
  17. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_demos.py +7 -0
  18. returnn-1.20230408.155406/_setup_info_generated.py +0 -2
  19. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.editorconfig +0 -0
  20. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.gitignore +0 -0
  21. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.gitmodules +0 -0
  22. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/.kateconfig +0 -0
  23. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/CHANGELOG.md +0 -0
  24. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/CODEOWNERS +0 -0
  25. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/CONTRIBUTING.md +0 -0
  26. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/LICENSE +0 -0
  27. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/MANIFEST.in +0 -0
  28. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/README.rst +0 -0
  29. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/__init__.py +0 -0
  30. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/12AX.cluster_map +0 -0
  31. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/_setup_returnn_env.py +0 -0
  32. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-fwd.config +0 -0
  33. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-horovod-mpi.py +0 -0
  34. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-horovod-mpi.py.sh +0 -0
  35. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-horovod-mpi.sh +0 -0
  36. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-hyper-param-tuning.config +0 -0
  37. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-iter-dataset.py +0 -0
  38. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-list-devices.py +0 -0
  39. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-lua-torch-layer.config +0 -0
  40. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-pretrain.config +0 -0
  41. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-record-and-push-to-webserver.py +0 -0
  42. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-returnn-as-framework.py +0 -0
  43. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-rhn-enwik8.config +0 -0
  44. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-sprint-interface.py +0 -0
  45. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-att-copy.config +0 -0
  46. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-attention.config +0 -0
  47. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
  48. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
  49. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-enc-dec.config +0 -0
  50. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-hard-att-copy.config +0 -0
  51. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-lstm-benchmark.py +0 -0
  52. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
  53. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
  54. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm.12ax.config +0 -0
  55. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm2.12ax.config +0 -0
  56. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
  57. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-neural-transducer.12ax.config +0 -0
  58. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-rec-explicit-lstm.config +0 -0
  59. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-rec-explicit-rnn.config +0 -0
  60. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-rec-self-att.config +0 -0
  61. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-search-compiled-graph.py +0 -0
  62. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
  63. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-timit-lstm-ctc.config +0 -0
  64. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
  65. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo.sh +0 -0
  66. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
  67. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
  68. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
  69. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/README.txt +0 -0
  70. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/chars.txt +0 -0
  71. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/config_demo +0 -0
  72. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/config_fwd +0 -0
  73. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/config_real +0 -0
  74. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
  75. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/decode.py +0 -0
  76. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
  77. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/go.sh +0 -0
  78. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/lines.txt +0 -0
  79. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/split/eval.txt +0 -0
  80. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/split/train.txt +0 -0
  81. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/IAM/split/valid.txt +0 -0
  82. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/README.md +0 -0
  83. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/create_test_h5.py +0 -0
  84. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/forwardconfig +0 -0
  85. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/go.sh +0 -0
  86. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial/trainconfig +0 -0
  87. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
  88. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
  89. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/go.sh +0 -0
  90. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
  91. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/pyproject.toml +0 -0
  92. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/requirements.txt +0 -0
  93. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__init__.py +0 -0
  94. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__main__.py +0 -0
  95. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__old_mod_loader__.py +0 -0
  96. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/__setup__.py +0 -0
  97. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/config.py +0 -0
  98. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/__init__.py +0 -0
  99. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/audio.py +0 -0
  100. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/basic.py +0 -0
  101. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/bundle_file.py +0 -0
  102. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/cached.py +0 -0
  103. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/cached2.py +0 -0
  104. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/generating.py +0 -0
  105. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/hdf.py +0 -0
  106. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/lm.py +0 -0
  107. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/map.py +0 -0
  108. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/meta.py +0 -0
  109. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/multi_proc.py +0 -0
  110. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/normalization_data.py +0 -0
  111. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/numpy_dump.py +0 -0
  112. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/raw_wav.py +0 -0
  113. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/sprint.py +0 -0
  114. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/stereo.py +0 -0
  115. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/util/__init__.py +0 -0
  116. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/util/feature_extraction.py +0 -0
  117. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/datasets/util/vocabulary.py +0 -0
  118. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/engine/__init__.py +0 -0
  119. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/engine/base.py +0 -0
  120. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/engine/batch.py +0 -0
  121. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/__init__.py +0 -0
  122. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/__main__.py +0 -0
  123. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/.git +0 -0
  124. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
  125. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
  126. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
  127. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
  128. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
  129. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
  130. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
  131. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
  132. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
  133. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
  134. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
  135. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
  136. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
  137. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
  138. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
  139. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
  140. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
  141. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
  142. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
  143. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
  144. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
  145. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
  146. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
  147. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
  148. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/__init__.py +0 -0
  149. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/README.md +0 -0
  150. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/__init__.py +0 -0
  151. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/edit.py +0 -0
  152. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/reroute.py +0 -0
  153. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/select.py +0 -0
  154. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/subgraph.py +0 -0
  155. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/transform.py +0 -0
  156. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/extern/graph_editor/util.py +0 -0
  157. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/__init__.py +0 -0
  158. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_numpy_backend.py +0 -0
  159. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_utils.py +0 -0
  160. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/dims.py +0 -0
  161. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/dtype.py +0 -0
  162. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/init.py +0 -0
  163. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/linear.py +0 -0
  164. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/loss.py +0 -0
  165. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/math_.py +0 -0
  166. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/matmul.py +0 -0
  167. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/module.py +0 -0
  168. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/parameter.py +0 -0
  169. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/rand.py +0 -0
  170. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/reduce.py +0 -0
  171. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/state.py +0 -0
  172. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/types.py +0 -0
  173. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/__init__.py +0 -0
  174. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/common.py +0 -0
  175. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/git.py +0 -0
  176. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/import_/import_.py +0 -0
  177. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/learning_rate_control.py +0 -0
  178. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/log.py +0 -0
  179. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/native_op.cpp +0 -0
  180. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/native_op.py +0 -0
  181. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/pretrain.py +0 -0
  182. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/__init__.py +0 -0
  183. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/cache.py +0 -0
  184. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/control.py +0 -0
  185. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/error_signals.py +0 -0
  186. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/extern_interface.py +0 -0
  187. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/sprint/interface.py +0 -0
  188. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/README.md +0 -0
  189. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/__init__.py +0 -0
  190. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_mixin_base.py +0 -0
  191. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_op_overloads.py +0 -0
  192. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/control_flow_ctx.py +0 -0
  193. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/dim.py +0 -0
  194. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/marked_dim.py +0 -0
  195. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/tensor.py +0 -0
  196. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/tensor_dict.py +0 -0
  197. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/__init__.py +0 -0
  198. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/compat.py +0 -0
  199. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/data_pipeline.py +0 -0
  200. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/distributed.py +0 -0
  201. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/engine.py +0 -0
  202. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/__init__.py +0 -0
  203. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/_utils.py +0 -0
  204. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
  205. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
  206. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/dims.py +0 -0
  207. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/layer.py +0 -0
  208. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/make_layer.py +0 -0
  209. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
  210. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_low_level/__init__.py +0 -0
  211. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/horovod.py +0 -0
  212. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/hyper_param_tuning.py +0 -0
  213. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/__init__.py +0 -0
  214. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/base.py +0 -0
  215. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/basic.py +0 -0
  216. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/rec.py +0 -0
  217. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/segmental_model.py +0 -0
  218. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/layers/signal_processing.py +0 -0
  219. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/native_op.py +0 -0
  220. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/network.py +0 -0
  221. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/sprint.py +0 -0
  222. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/updater.py +0 -0
  223. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/__init__.py +0 -0
  224. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/basic.py +0 -0
  225. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/data.py +0 -0
  226. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/ken_lm.py +0 -0
  227. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/util/open_fst.py +0 -0
  228. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/README.md +0 -0
  229. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/__init__.py +0 -0
  230. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/__init__.py +0 -0
  231. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/pipeline.py +0 -0
  232. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
  233. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/data/tensor_utils.py +0 -0
  234. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/__init__.py +0 -0
  235. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/_rand.py +0 -0
  236. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/functional/README.md +0 -0
  237. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/functional/__init__.py +0 -0
  238. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/updater.py +0 -0
  239. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/__init__.py +0 -0
  240. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/basic.py +0 -0
  241. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/better_exchook.py +0 -0
  242. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/bpe.py +0 -0
  243. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/debug.py +0 -0
  244. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/debug_helpers.py +0 -0
  245. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/fsa.py +0 -0
  246. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/literal_py_to_pickle.py +0 -0
  247. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/pprint.py +0 -0
  248. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/py-to-pickle.cpp +0 -0
  249. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/sig_proc.py +0 -0
  250. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/util/task_system.py +0 -0
  251. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn.egg-info/SOURCES.txt +0 -0
  252. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn.egg-info/dependency_links.txt +0 -0
  253. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn.egg-info/top_level.txt +0 -0
  254. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/rnn.py +0 -0
  255. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/setup.cfg +0 -0
  256. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/setup.py +0 -0
  257. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/DummySprintExec.py +0 -0
  258. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm-inspection-profile.xml +0 -0
  259. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/.gitignore +0 -0
  260. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/.name +0 -0
  261. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
  262. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
  263. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
  264. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
  265. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
  266. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/misc.xml +0 -0
  267. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/modules.xml +0 -0
  268. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/returnn.iml +0 -0
  269. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
  270. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/_set_num_threads1.py +0 -0
  271. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/_setup_returnn_env.py +0 -0
  272. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/_setup_test_env.py +0 -0
  273. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/bpe-unicode-demo.codes +0 -0
  274. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/bpe-unicode-demo.vocab +0 -0
  275. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.fst +0 -0
  276. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.isyms +0 -0
  277. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.jpg +0 -0
  278. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lexicon_opt.osyms +0 -0
  279. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/lint_common.py +0 -0
  280. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/pycharm-inspect.py +0 -0
  281. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/pylint.py +0 -0
  282. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/returnn-as-framework.py +0 -0
  283. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/rf_utils.py +0 -0
  284. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/spelling.dic +0 -0
  285. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Config.py +0 -0
  286. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Dataset.py +0 -0
  287. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Fsa.py +0 -0
  288. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_GeneratingDataset.py +0 -0
  289. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_HDFDataset.py +0 -0
  290. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_LearningRateControl.py +0 -0
  291. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Log.py +0 -0
  292. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_MultiProcDataset.py +0 -0
  293. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_PTDataset.py +0 -0
  294. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Pretrain.py +0 -0
  295. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_ResNet.py +0 -0
  296. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_SprintDataset.py +0 -0
  297. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_SprintInterface.py +0 -0
  298. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFEngine.py +0 -0
  299. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNativeOp.py +0 -0
  300. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNetworkLayer.py +0 -0
  301. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNetworkRecLayer.py +0 -0
  302. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFNetworkSigProcLayer.py +0 -0
  303. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFUpdater.py +0 -0
  304. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TFUtil.py +0 -0
  305. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TF_determinism.py +0 -0
  306. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TaskSystem.py +0 -0
  307. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TaskSystem_SharedMem.py +0 -0
  308. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_TranslationDataset.py +0 -0
  309. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_Util.py +0 -0
  310. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_fork_exec.py +0 -0
  311. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_hdf_dump.py +0 -0
  312. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_rf_base.py +0 -0
  313. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_tensor.py +0 -0
  314. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_tools.py +0 -0
  315. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_torch_frontend.py +0 -0
  316. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_torch_internal_frontend.py +0 -0
  317. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/_setup_returnn_env.py +0 -0
  318. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/analyze-dataset-batches.py +0 -0
  319. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-collect-seq-lens.py +0 -0
  320. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-dump-text.py +0 -0
  321. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-get-segment-names.py +0 -0
  322. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bliss-to-ogg-zip.py +0 -0
  323. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/bpe-create-lexicon.py +0 -0
  324. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/calculate-word-error-rate.py +0 -0
  325. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/cleanup-old-models.py +0 -0
  326. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/collect-orth-symbols.py +0 -0
  327. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/collect-words.py +0 -0
  328. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/compile_native_op.py +0 -0
  329. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/compile_tf_graph.py +0 -0
  330. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/debug-dump-search-scores.py +0 -0
  331. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/debug-plot-search-scores.py +0 -0
  332. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-dataset-raw-strings.py +0 -0
  333. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-dataset.py +0 -0
  334. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-forward-stats.py +0 -0
  335. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-forward.py +0 -0
  336. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-network-json.py +0 -0
  337. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/dump-pickle.py +0 -0
  338. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/extract_state_tying_from_dataset.py +0 -0
  339. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/get-attention-weights.py +0 -0
  340. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/get-best-model-epoch.py +0 -0
  341. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/hdf_dump.py +0 -0
  342. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/hdf_dump_translation_dataset.py +0 -0
  343. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/import-blocks-mt-model.py +0 -0
  344. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/import-t2t-mt-model.py +0 -0
  345. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/.gitignore +0 -0
  346. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/Makefile +0 -0
  347. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/README.md +0 -0
  348. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/README.md +0 -0
  349. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/libs_list +0 -0
  350. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
  351. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
  352. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
  353. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/state_vars_list +0 -0
  354. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/example/tensor_names_list +0 -0
  355. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/file.h +0 -0
  356. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
  357. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
  358. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/main.cc +0 -0
  359. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/rescorer.h +0 -0
  360. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/vocabulary.cc +0 -0
  361. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/lattice_rescorer/vocabulary.h +0 -0
  362. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/tf_avg_checkpoints.py +0 -0
  363. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/tf_inspect_checkpoint.py +0 -0
  364. {returnn-1.20230408.155406 → returnn-1.20230409.122444}/tools/tf_inspect_summary_log.py +0 -0
{returnn-1.20230408.155406/returnn.egg-info → returnn-1.20230409.122444}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20230408.155406
+ Version: 1.20230409.122444
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer

returnn-1.20230409.122444/_setup_info_generated.py
@@ -0,0 +1,2 @@
+ version = '1.20230409.122444'
+ long_version = '1.20230409.122444+git.0fd75ff'

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-rf.config
@@ -50,6 +50,7 @@ def train_step(*, model: Model, extern_data, **_kwargs):
      data = extern_data["data"]
      logits = model(data)
      targets = extern_data["classes"]
+     # TODO: use flattening on logits/targets
      loss = rf.cross_entropy(estimated=logits, estimated_type="logits", target=targets, axis=out_dim)
      loss.mark_as_loss(name="ce")

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/demos/demo-torch.config
@@ -58,7 +58,7 @@ def train_step(*, model: Model, extern_data, **_kwargs):
      targets = extern_data["classes"]
      targets_packed = torch.nn.utils.rnn.pack_padded_sequence(
          targets.raw_tensor, data.dims[1].dyn_size_ext.raw_tensor, batch_first=True, enforce_sorted=False)
-     loss = nn.CrossEntropyLoss()(logits_packed.data, targets_packed.data.long())
+     loss = nn.CrossEntropyLoss(reduction='none')(logits_packed.data, targets_packed.data.long())
      rf.get_run_ctx().mark_as_loss(name="cross_entropy", loss=loss)

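With reduction='none', CrossEntropyLoss returns one loss value per packed frame instead of a pre-averaged scalar, so RETURNN's run context can apply the sum/num-frames normalization itself. A hedged sketch of the shapes involved (illustrative names, not part of the diff):

    # logits_packed.data: [total_frames, n_classes]; targets_packed.data: [total_frames]
    loss = nn.CrossEntropyLoss(reduction='none')(logits_packed.data, targets_packed.data.long())
    # loss: [total_frames]; mark_as_loss then sums it and divides by the frame count for reporting
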
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/_backend.py
@@ -208,6 +208,29 @@ class Backend(Generic[T]):
          """
          raise NotImplementedError

+     @staticmethod
+     def cast_raw(raw_tensor: T, dtype: str) -> T:
+         """
+         :param raw_tensor:
+         :param dtype: e.g. "float32"
+         :return: raw tensor with dtype casted
+         """
+         raise NotImplementedError
+
+     @staticmethod
+     def cast(tensor: Tensor, dtype: str) -> Tensor:
+         """
+         :param tensor:
+         :param dtype: e.g. "float32"
+         :return: tensor with dtype casted
+         """
+         # Default implementation using cast_raw.
+         res = tensor.copy_template()
+         res.dtype = dtype
+         # noinspection PyProtectedMember
+         res.raw_tensor = tensor._raw_backend.cast_raw(tensor.raw_tensor, dtype)
+         return res
+
      # Restrict the possible activation function names,
      # to not get unexpected behavior,
      # or unwanted incompatibilities.
@@ -287,6 +310,13 @@ class Backend(Generic[T]):
          """
          raise NotImplementedError

+     @staticmethod
+     def have_sequence_mask_raw() -> bool:
+         """
+         :return: whether we have a sequence_mask_raw implementation
+         """
+         return False
+
      @staticmethod
      def sequence_mask_raw(lengths: T, *, batch_major: bool = True) -> T:
          """
@@ -309,7 +339,7 @@
          :return: context manager
          """
          # Default implementation for eager-based frameworks
-         pass  # nothing to do
+         yield  # nothing to do

      @staticmethod
      @contextlib.contextmanager

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/array_.py
@@ -12,7 +12,7 @@ from .types import RawTensorTypes

  T = TypeVar("T")

- __all__ = ["convert_to_tensor", "constant", "gather"]
+ __all__ = ["convert_to_tensor", "constant", "cast", "gather"]


  def convert_to_tensor(
@@ -77,6 +77,16 @@ def convert_to_tensor(
  constant = convert_to_tensor  # alias for some older code


+ def cast(tensor: Tensor, dtype: str) -> Tensor:
+     """
+     :param tensor:
+     :param dtype:
+     :return: tensor with the same data, but with a different dtype
+     """
+     # noinspection PyProtectedMember
+     return tensor._raw_backend.cast(tensor, dtype=dtype)
+
+
  # noinspection PyUnusedLocal
  def gather(
      source: Tensor,
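A minimal usage sketch for the new rf.cast (x is an illustrative Tensor; only the dtype changes, the dims stay the same):

    import returnn.frontend as rf

    x_f32 = rf.cast(x, dtype="float32")
    assert x_f32.dtype == "float32" and x_f32.dims == x.dims
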
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/const.py
@@ -14,7 +14,7 @@ __all__ = ["full", "constant", "fill", "zeros", "ones"]


  def full(
-     dims: Sequence[Dim], fill_value: RawTensorTypes, *, dtype: Optional[str] = None, sparse_dim: Optional[Dim] = None
+     *, dims: Sequence[Dim], fill_value: RawTensorTypes, dtype: Optional[str] = None, sparse_dim: Optional[Dim] = None
  ) -> Tensor:
      """
      full
@@ -46,11 +46,11 @@ def zeros(dims: Sequence[Dim], *, dtype: Optional[str] = None, sparse_dim: Optio
      """
      zeros. float by default.
      """
-     return full(dims, 0, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)
+     return full(dims=dims, fill_value=0, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)


  def ones(dims: Sequence[Dim], *, dtype: Optional[str] = None, sparse_dim: Optional[Dim] = None) -> Tensor:
      """
      ones. float by default.
      """
-     return full(dims, 1, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)
+     return full(dims=dims, fill_value=1, dtype=dtype or rf.get_default_float_dtype(), sparse_dim=sparse_dim)
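dims and fill_value of full are now keyword-only, so positional call sites need updating, as the zeros/ones changes above show. A sketch of the change at a hypothetical call site:

    # before: x = rf.full([batch_dim, time_dim], 0.0)
    x = rf.full(dims=[batch_dim, time_dim], fill_value=0.0)  # dtype defaults to the default float dtype
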
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/frontend/run_ctx.py
@@ -119,24 +119,18 @@ class RunCtx:
              E.g. if the overall normalization is sum(loss)/sum(num_frames), this is also what the optimizer will use,
              otherwise the optimizer will just use sum(loss).
          :param custom_inv_norm_factor:
-             The standard norm factor is 1/sum(target_seq_len) if the target has a time-axis,
-             or 1/sum(output_seq_len) if there is no target and the output has a time-axis,
+             The standard inv norm factor is sum(target_seq_len) if the target has a time-axis,
+             or sum(output_seq_len) if there is no target and the output has a time-axis,
              or 1 otherwise. (See :func:`Loss.init` for details.)
              This is used for proper normalization of accumulated loss/error per epoch
              and also proper normalization per batch for reporting,
              no matter if use_normalized_loss is True or False.
              If you want to change this norm factor, you can set this.
-             Basically, for all reporting, it uses sum(loss) * sum(custom_inv_norm_factor).
+             Basically, for all reporting, it uses sum(loss) / sum(custom_inv_norm_factor).
          """
          assert self.stage == "train_step"
          if not isinstance(loss, Tensor):
              assert isinstance(loss, _backend.global_backend.RawTensorType)
-             assert _backend.global_backend.get_ndim_raw(loss) == 0, (
-                 f"mark_as_loss(<loss with shape {_backend.global_backend.get_known_shape_raw(loss)}>, {name!r}):"
-                 " Only scalar raw losses are supported,"
-                 " because we cannot know whether there are any dynamic dims which might require padding."
-                 " Explicitly convert to a Tensor first and specify dim tags."
-             )
              loss = rf.convert_to_tensor(loss)
          assert name not in self.losses
          self.losses[name] = Loss(
@@ -220,31 +214,52 @@ class Loss:

      scale: float = 1.0
      as_error: bool = False
-     use_normalized_loss: bool = False
+     use_normalized_loss: bool = False  # for the gradient / total loss
      use_flatten_frames: bool = True
      custom_inv_norm_factor: Optional[Tensor] = None

+     _summed_loss_cached: Optional[Tensor] = None
+     _mean_loss_cached: Optional[Tensor] = None
+
      def get_summed_loss(self) -> Tensor:
          """
          :return: sum of loss (scalar)
          """
          if not self.loss.dims:
              return self.loss
-         return rf.reduce_sum(self.loss, axis=self.loss.dims)
+         if self._summed_loss_cached is not None:
+             return self._summed_loss_cached
+         if self._mean_loss_cached is not None:
+             return self._mean_loss_cached / self.get_inv_norm_factor()
+         self._summed_loss_cached = rf.reduce_sum(self.loss, axis=self.loss.dims)
+         return self._summed_loss_cached

      def get_mean_loss(self) -> Tensor:
          """
          :return: sum of loss (scalar)
          """
+         if self._mean_loss_cached is not None:
+             return self._mean_loss_cached
          if self.custom_inv_norm_factor:
-             return self.get_summed_loss() * self.custom_inv_norm_factor
+             loss = self.get_summed_loss()
+             loss /= rf.cast(self.custom_inv_norm_factor, dtype=loss.dtype)
+             return loss
          if not self.loss.dims:
              return self.loss
-         return rf.reduce_mean(self.loss, axis=self.loss.dims)
+         self._mean_loss_cached = rf.reduce_mean(self.loss, axis=self.loss.dims)
+         return self._mean_loss_cached
+
+     def get_inv_norm_factor(self) -> Union[int, Tensor]:
+         """
+         :return: inverse norm factor (scalar)
+         """
+         if self.custom_inv_norm_factor:
+             return self.custom_inv_norm_factor
+         return self.loss.num_elements()

      def get_scaled_reduced_loss(self) -> Tensor:
          """
-         :return: scaled reduced loss (scalar), as it is supposed to be used for calculating the
+         :return: scaled reduced loss (scalar), as it is supposed to be used for calculating the train gradient
          """
          if self.use_normalized_loss:
              loss = self.get_mean_loss()
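The corrected docstring and the new get_inv_norm_factor make the intended invariant explicit: mean = sum / inv_norm_factor, where the inverse norm factor is the frame count (or custom_inv_norm_factor if set). A hedged sketch, assuming a Loss with per-frame values:

    summed = loss.get_summed_loss()        # scalar: sum over all frames
    inv_norm = loss.get_inv_norm_factor()  # scalar: frame count, e.g. sum(seq_lens)
    mean = loss.get_mean_loss()            # equals summed / inv_norm (modulo the caching above)
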
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_dim_extra.py
@@ -747,7 +747,11 @@ class _DimMixin:
          :return: whether dim is static or dynamic but with scalar dyn_size_ext
          """
          if self.is_static():
+             if self.capacity is not None:
+                 return self.size < self.capacity
              return False
+         if self.capacity is not None:
+             return True
          if not self.dyn_size_ext:
              return True  # unknown
          return self.dyn_size_ext.batch_ndim > 0
@@ -1516,6 +1520,21 @@
          If `self.src_data` has a placeholder, will use the shape from there.
          Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).

+         :return: max(size or dyn_size)
+         """
+         res = self.get_dim_value_tensor()
+         if isinstance(res, _t.Tensor):
+             assert res.dims == ()
+             return res.raw_tensor
+         assert isinstance(res, int)
+         return res
+
+     def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
+         """
+         Infers the dim this axis should have if unbroadcasted.
+         If `self.src_data` has a placeholder, will use the shape from there.
+         Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
          :return: max(size or dyn_size)
          """
          import returnn.frontend as rf
@@ -1530,25 +1549,33 @@
                  # Masking is not always possible here, e.g.
                  # self = Dim{'self-att-keys'['time:var:extern_data:classes'[B]]}.
                  use_time_mask=False,
-             ).raw_tensor
-             return self.dyn_size_ext.placeholder
+             )
+             return self.dyn_size_ext
          if self.is_batch_dim():
+             res = None
              if self._extra and self._extra.src_data:
-                 return self._extra.src_data.get_batch_dim()
-             if self.batch:
-                 return self.batch.dim
+                 res = self._extra.src_data.get_batch_dim()
+             elif self.batch:
+                 res = self.batch.dim
+             if isinstance(res, int):
+                 return res
+             if res is not None:
+                 return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
          if (
              self._extra
              and self._extra.src_data is not None
              and self._extra.src_axis is not None
              and self._extra.src_data.placeholder is not None
          ):
-             return self._extra.src_data.get_dim(self._extra.src_axis)
+             res = self._extra.src_data.get_dim(self._extra.src_axis)
+             if isinstance(res, int):
+                 return res
+             return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
          self.complete_dyn_size()
          if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None:
              if self.dyn_size_ext.batch_ndim > 0:
-                 return rf.reduce_max(self.dyn_size_ext, axis=self.dyn_size_ext.dim_tags).raw_tensor
-             return self.dyn_size_ext.placeholder
+                 return rf.reduce_max(self.dyn_size_ext, axis=self.dyn_size_ext.dim_tags)
+             return self.dyn_size_ext
          raise Exception("%s: need placeholder, self.dimension or self.dyn_size for dim value" % self)

      def axis_split_info(self):

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tensor/_tensor_extra.py
@@ -2676,7 +2676,11 @@ class _TensorMixin(_TensorMixinBase):
          backend = tag.dyn_size_ext._raw_backend
          assert set(tag.dyn_size_ext.dim_tags).issubset(self.dim_tags)  # https://github.com/rwth-i6/returnn/issues/721
          with backend.name_scope_raw("get_sequence_mask_broadcast"):
-             if tag.dyn_size_ext.have_batch_axis() and tag.dyn_size_ext.batch_ndim == 1:  # just [B]
+             if (
+                 backend.have_sequence_mask_raw()
+                 and tag.dyn_size_ext.have_batch_axis()
+                 and tag.dyn_size_ext.batch_ndim == 1
+             ):  # just [B]
                  # This is the common case where the size is of shape [B].
                  # We make use of sequence_mask or sequence_mask_time_major in that case,
                  # which is optimized by caching.
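have_sequence_mask_raw() acts as a capability flag: a backend opts into this optimized [B] fast path by returning True and implementing sequence_mask_raw; otherwise the generic broadcast-compare mask is built. A hypothetical backend sketch (MyBackend/MyRawTensor are illustrative):

    class MyBackend(Backend[MyRawTensor]):
        @staticmethod
        def have_sequence_mask_raw() -> bool:
            return True  # opt in: sequence_mask_raw below is implemented

        @staticmethod
        def sequence_mask_raw(lengths: MyRawTensor, *, batch_major: bool = True) -> MyRawTensor:
            ...  # backend-specific, possibly cached mask construction
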
@@ -2733,11 +2737,45 @@
          assert tag.dyn_size_ext
          return tag.dyn_size_ext.copy_compatible_to(self, check_dtype=False, check_sparse=False).placeholder

+     def num_elements(self: Tensor) -> Union[int, Tensor]:
+         """
+         :return: number of elements in this tensor, i.e. prod(self.shape)
+         :rtype: tf.Tensor
+         """
+         if all(dim.is_static() for dim in self.dims):
+             n = 1
+             for dim in self.dims:
+                 n *= dim.dimension
+             return n
+
+         import returnn.frontend as rf
+
+         n = 1
+         dims = list(self.dims)
+         dims.sort(key=lambda dim: -dim.dyn_size_ext.batch_ndim if dim.dyn_size_ext else 0)
+         while dims:
+             dim = dims.pop(0)
+             if dim.is_static():
+                 n *= dim.dimension
+                 continue
+             # E.g. dyn_size_ext is shape [B], and self has shape [B,T].
+             # Due to the sorting of dims above, dims will be [T,B], and we will first process T.
+             # We want to sum over dyn_size_ext, but then we need to remove the other dims it covers.
+             for dim_ in dim.dyn_size_ext.dims:
+                 assert dim_ in dims  # num elements not really well-defined then
+                 assert not dim_.need_masking()  # not implemented
+                 dims.remove(dim_)
+             n_ = rf.reduce_sum(dim.dyn_size_ext, axis=dim.dyn_size_ext.dims)
+             n *= n_
+         return n
+
      def copy_masked(self: Tensor, mask_value) -> Tensor:
          """
          :param float|int|tf.Tensor mask_value:
          """
          assert self.placeholder is not None
+         if not any(dim.need_masking() for dim in self.dims):
+             return self.copy()
          assert self._raw_backend.is_tensorflow  # not implemented otherwise for now
          from returnn.tf.util.basic import mask_dyn_seq_len_nd

2781
 
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_layers/_backend.py
@@ -122,6 +122,11 @@ class ReturnnLayersBackend(Backend[Layer]):
          """transpose_raw is a no-op in this backend"""
          return raw_tensor

+     @staticmethod
+     def cast(tensor: Tensor, dtype: str) -> Tensor:
+         """cast"""
+         return rfl.make_layer({"class": "cast", "from": tensor, "dtype": dtype}, name="cast")
+
      @staticmethod
      def activation(tensor: Tensor, func: str) -> Tensor:
          """activation"""
@@ -172,11 +177,6 @@
          log_probs = rf.log_softmax(logits, axis=axis)
          return -rf.matmul(targets, log_probs, reduce=axis)

-     @staticmethod
-     def sequence_mask_raw(lengths: Layer, *, batch_major: bool = True) -> Layer:
-         """sequence mask"""
-         raise NotImplementedError  # TODO
-
      @staticmethod
      def create_parameter_raw(tensor: rf.Parameter) -> Layer:
          """create parameter"""

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/tf/frontend_low_level/_backend.py
@@ -194,6 +194,11 @@ class TFBackend(Backend[tf.Tensor]):
          with tf_util.same_control_flow_ctx(raw_tensor):
              return tf.tile(raw_tensor, [1] * axis + [dim] + [1] * (raw_tensor.shape.ndims - axis - 1))

+     @staticmethod
+     def cast_raw(raw_tensor: tf.Tensor, dtype: str) -> tf.Tensor:
+         """cast"""
+         return tf.cast(raw_tensor, dtype)
+
      @staticmethod
      def activation_raw(raw_tensor: tf.Tensor, func: str) -> tf.Tensor:
          """
@@ -212,6 +217,13 @@
              raise ValueError(f"unknown activation function {func!r}")
          return f(raw_tensor)

+     @staticmethod
+     def have_sequence_mask_raw() -> bool:
+         """
+         :return: whether we have sequence_mask
+         """
+         return True
+
      @staticmethod
      def sequence_mask_raw(lengths: tf.Tensor, *, batch_major: bool = True) -> tf.Tensor:
          """

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/engine.py
@@ -141,34 +141,44 @@ class Engine(EngineBase):
          self._pt_model.train()

          accumulated_losses_dict = NumbersDict()
+         accumulated_inv_norm_factors_dict = NumbersDict()
          step_idx = 0
          for data in self._train_dataloader:
              self._run_step(data)

              train_ctx = rf.get_run_ctx()
-             losses_dict = train_ctx.losses
              total_loss = train_ctx.total_loss()
+             losses_dict = NumbersDict(
+                 {
+                     name: float(loss.get_summed_loss().raw_tensor.detach().cpu().numpy())
+                     for name, loss in train_ctx.losses.items()
+                 }
+             )
+             inv_norm_factors_dict = NumbersDict(
+                 {name: float(_to_raw(loss.get_inv_norm_factor())) for name, loss in train_ctx.losses.items()}
+             )

              self._updater.get_optimizer().zero_grad()
              total_loss.raw_tensor.backward()
              self._updater.get_optimizer().step()

-             losses_dict = {
-                 "train_loss_" + name: float(loss.loss.raw_tensor.detach().cpu().numpy())
-                 for name, loss in losses_dict.items()
-             }
-             accumulated_losses_dict += NumbersDict(losses_dict)
-             print("step %i, loss: %f" % (step_idx, total_loss.raw_tensor.detach().cpu().numpy()), file=log.v4)
+             accumulated_losses_dict += losses_dict
+             accumulated_inv_norm_factors_dict += inv_norm_factors_dict
+             print(f"step {step_idx}, loss: {dict(losses_dict / inv_norm_factors_dict)}", file=log.v4)

              step_idx += 1
              self._train_step += 1

          print("Trained %i steps" % step_idx)

-         accumulated_losses_dict = accumulated_losses_dict / step_idx
-         self.learning_rate_control.set_epoch_error(self.epoch, dict(accumulated_losses_dict))
+         accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict
+         self.learning_rate_control.set_epoch_error(
+             self.epoch, {f"train_loss_{k}": v for k, v in accumulated_losses_dict.items()}
+         )
          self.learning_rate_control.save()

+         print(f"Total train loss: {dict(accumulated_losses_dict)}", file=log.v3)
+
          if self.epoch % self._save_model_epoch_interval == 0 or self.epoch == self._final_epoch:
              self._save_model()
              self._save_optimizer()
@@ -186,8 +196,8 @@

          data_loader = self._eval_dataloaders[dataset_name]

-         accumulated_loss = 0.0
          accumulated_losses_dict = NumbersDict()
+         accumulated_inv_norm_factors_dict = NumbersDict()
          step_idx = 0

          with torch.no_grad():
@@ -195,29 +205,31 @@

                  self._run_step(data)
                  train_ctx = rf.get_run_ctx()
-                 losses_dict = train_ctx.losses
-                 total_loss = train_ctx.total_loss()
-
-                 total_loss = total_loss.raw_tensor.detach().cpu().numpy()
-                 losses_dict = {
-                     dataset_name + "_loss_" + name: float(loss.loss.raw_tensor.detach().cpu().numpy())
-                     for name, loss in losses_dict.items()
-                 }
-                 print("step %i, loss: %f" % (step_idx, total_loss), file=log.v4)
-
-                 accumulated_loss += total_loss
-                 accumulated_losses_dict += NumbersDict(losses_dict)
-                 step_idx += 1

-         assert step_idx > 0, "No data in dataset '{}'.".format(dataset_name)
-         accumulated_loss = accumulated_loss / step_idx
-         accumulated_losses_dict = accumulated_losses_dict / step_idx
+                 losses_dict = NumbersDict(
+                     {
+                         name: float(loss.get_summed_loss().raw_tensor.detach().cpu().numpy())
+                         for name, loss in train_ctx.losses.items()
+                     }
+                 )
+                 inv_norm_factors_dict = NumbersDict(
+                     {name: float(_to_raw(loss.get_inv_norm_factor())) for name, loss in train_ctx.losses.items()}
+                 )
+
+                 accumulated_losses_dict += losses_dict
+                 accumulated_inv_norm_factors_dict += inv_norm_factors_dict
+                 print(f"step {step_idx}, loss: {dict(losses_dict / inv_norm_factors_dict)}", file=log.v4)
+                 step_idx += 1

-         self.learning_rate_control.set_epoch_error(self.epoch, dict(accumulated_losses_dict))
+         assert step_idx > 0, f"No data in dataset {dataset_name!r}."
+         accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict

-         print("Total loss for '{}': {:.6}".format(dataset_name, accumulated_loss), file=log.v3)
+         self.learning_rate_control.set_epoch_error(
+             self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()}
+         )
+         self.learning_rate_control.save()

-         self.learning_rate_control.save()
+         print(f"Total loss for {dataset_name!r}: {dict(accumulated_losses_dict)}", file=log.v3)

      def _create_data_loader(self, dataset: Dataset) -> DataLoader2:
          """
@@ -312,6 +324,7 @@
          else:
              raise TypeError(f"get_model returned {model} of type {type(model)}, expected rf.Module or torch.nn.Module")
          assert isinstance(self._pt_model, torch.nn.Module)
+         print("Model:", self._pt_model, file=log.v4)

          if checkpoint_state is not None:
              self._pt_model.load_state_dict(checkpoint_state["model"])
@@ -404,3 +417,11 @@
          os.makedirs(directory, exist_ok=True)

          self._updater.save_optimizer(filename)
+
+
+ def _to_raw(n: Union[int, float, Tensor]):
+     if isinstance(n, (int, float)):
+         return n
+     if isinstance(n, Tensor):
+         return n.raw_tensor.detach().cpu().numpy()
+     raise TypeError(f"Unexpected {n} of type {type(n)}")
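Both the train and eval loops now accumulate summed losses and inverse norm factors separately and divide only at the end, so the epoch value is frame-weighted rather than a mean of per-step means. The arithmetic, sketched with plain floats (numbers are illustrative):

    sums = {"ce": 12.0 + 9.0}  # summed loss: 12.0 over 8 frames, then 9.0 over 6 frames
    frames = {"ce": 8 + 6}     # accumulated inv norm factors
    epoch = {k: sums[k] / frames[k] for k in sums}  # {"ce": 1.5}, exact even with uneven step sizes
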
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/_backend.py
@@ -116,6 +116,11 @@ class TorchBackend(Backend[torch.Tensor]):
          """
          return raw_tensor.unsqueeze(axis)

+     @staticmethod
+     def cast_raw(raw_tensor: torch.Tensor, dtype: str) -> torch.Tensor:
+         """cast"""
+         return raw_tensor.to(dtype=TorchBackend.as_dtype_raw(dtype))
+
      @staticmethod
      def activation_raw(raw_tensor: torch.Tensor, func: str) -> torch.Tensor:
          """
@@ -411,6 +416,21 @@

          return result_tensor

+     @staticmethod
+     def range_over_dim(dim: Dim) -> Tensor[torch.Tensor]:
+         """
+         :param dim:
+         :return: tensor with shape [dim]
+         """
+         out = Tensor(
+             "range",
+             dims=[dim],
+             sparse_dim=dim,
+             dtype=dim.dyn_size_ext.dtype if dim.dyn_size_ext else rf.get_default_array_index_dtype(),
+         )
+         out.raw_tensor = torch.arange(dim.get_dim_value())
+         return out
+
      @staticmethod
      def reduce(
          source: Tensor[torch.Tensor],
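A usage sketch, assuming the frontend exposes this backend method as rf.range_over_dim (time_dim is an illustrative Dim with dim value 7):

    idx = rf.range_over_dim(time_dim)  # Tensor with dims [time_dim], values 0..6
    # idx.sparse_dim == time_dim, so idx can directly serve as indices into that axis
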
@@ -422,15 +442,25 @@
          """reduce"""
          assert mode in Backend._AllowedReduceModes
          if isinstance(axis, Dim):
-             assert not axis.need_masking()  # not implemented
-         else:
-             assert all(not dim.need_masking() for dim in axis)  # not implemented
+             axis = [axis]
+         assert all(isinstance(dim, Dim) for dim in axis)
+         if use_time_mask is not False and any(dim.need_masking() for dim in axis):
+             source = source.copy()
+             dtype = source.raw_tensor.dtype
+             if mode == "max":
+                 mask_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+             elif mode == "min":
+                 mask_value = torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max
+             elif mode == "sum":
+                 mask_value = 0
+             else:
+                 raise NotImplementedError(f"reduce_{mode} not implemented with masking on tensor {source!r}.")
+             for i, dim in enumerate(axis):
+                 if dim.need_masking():
+                     mask = source.get_sequence_mask_broadcast(axis=i)
+                     source.raw_tensor = torch.where(mask, source.raw_tensor, mask_value)
          func = getattr(torch, mode)
-         raw_dims = (
-             [source.get_axis_from_description(axis)]
-             if isinstance(axis, Dim)
-             else [source.get_axis_from_description(dim) for dim in axis]
-         )
+         raw_dims = [source.get_axis_from_description(dim) for dim in axis]
          res_dims = [dim for i, dim in enumerate(source.dims) if i not in raw_dims]
          if not res_dims:
              raw_result = func(source.raw_tensor)
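With the masking branch in place, reduce over a dynamic axis no longer requires pre-masked input for max/min/sum: padded frames are overwritten with the dtype's min/max (or 0 for sum) so they cannot affect the result. A hedged usage sketch:

    # x: Tensor with dims [batch_dim, time_dim], time_dim dynamic
    y = rf.reduce_max(x, axis=time_dim)  # padding is filled with dtype-min first, so it never wins
    # only max/min/sum are handled; other modes with masking still raise NotImplementedError
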
{returnn-1.20230408.155406 → returnn-1.20230409.122444}/returnn/torch/frontend/bridge.py
@@ -76,6 +76,9 @@ class _RFModuleAsPTModule(torch.nn.Module):
              pt_mod = rf_module_to_pt_module(rf_mod)
              self.add_module(name, pt_mod)

+     def _get_name(self):
+         return self._rf_module.__class__.__name__ + "[RF→PT]"
+
      @property
      def rf_module(self) -> rf.Module:
          """RF module"""
{returnn-1.20230408.155406 → returnn-1.20230409.122444/returnn.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20230408.155406
+ Version: 1.20230409.122444
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer

{returnn-1.20230408.155406 → returnn-1.20230409.122444}/tests/test_demos.py
@@ -139,6 +139,13 @@ def test_demo_torch_task12ax():
      # TODO also check FER. So far this is not properly reported. https://github.com/rwth-i6/returnn/issues/1120


+ @unittest.skipIf(not torch, "no PyTorch")
+ def test_demo_rf_torch_task12ax():
+     cleanup_tmp_models("demos/demo-rf.config")
+     run(py, "rnn.py", "demos/demo-rf.config", print_stdout=True)
+     # TODO also check FER. So far this is not properly reported. https://github.com/rwth-i6/returnn/issues/1120
+
+
  def test_demo_iter_dataset_task12ax():
      # there should be no actual TF dependency, we just iterate the dataset
      cleanup_tmp_models("demos/demo-tf-vanilla-lstm.12ax.config")

returnn-1.20230408.155406/_setup_info_generated.py
@@ -1,2 +0,0 @@
- version = '1.20230408.155406'
- long_version = '1.20230408.155406+git.03aed81'