returnn 1.20230403.124714__tar.gz → 1.20230403.211148__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- {returnn-1.20230403.124714/returnn.egg-info → returnn-1.20230403.211148}/PKG-INFO +1 -1
- returnn-1.20230403.211148/_setup_info_generated.py +2 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/run_ctx.py +32 -20
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_tensor_extra.py +8 -6
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/config_entry_points.py +2 -2
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/data/pipeline.py +8 -6
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/engine.py +53 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/frontend/_rand.py +2 -1
- {returnn-1.20230403.124714 → returnn-1.20230403.211148/returnn.egg-info}/PKG-INFO +1 -1
- returnn-1.20230403.124714/_setup_info_generated.py +0 -2
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.editorconfig +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.gitignore +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.gitmodules +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/.kateconfig +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/CHANGELOG.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/CODEOWNERS +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/CONTRIBUTING.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/LICENSE +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/MANIFEST.in +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/README.rst +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/12AX.cluster_map +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-fwd.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-list-devices.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-pretrain.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-torch.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo.sh +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/pyproject.toml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/requirements.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__main__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/__setup__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/config.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/audio.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/basic.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/cached.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/generating.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/lm.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/map.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/meta.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/util/vocabulary.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/engine/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/engine/base.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/engine/batch.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/_backend.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/array_.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/const.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/dims.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/init.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/linear.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/math_.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/module.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/parameter.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/rand.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/state.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/frontend/types.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/common.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/git.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/import_/import_.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/log.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/native_op.cpp +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/native_op.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/pretrain.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/cache.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/control.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/sprint/interface.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_dim_extra.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/dim.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/compat.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/distributed.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/engine.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/_backend.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/_utils.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/frontend_low_level/_backend.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/horovod.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/basic.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/native_op.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/network.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/sprint.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/updater.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/basic.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/data.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/frontend/_backend.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/functional/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/functional/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/torch/updater.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/__init__.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/basic.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/better_exchook.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/bpe.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/debug.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/fsa.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/pprint.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/util/task_system.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn.egg-info/SOURCES.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/rnn.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/setup.cfg +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/setup.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/DummySprintExec.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm-inspection-profile.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/_setup_test_env.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/lint_common.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/pylint.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/rf_utils.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/spelling.dic +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Config.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Dataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Fsa.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Log.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_PTDataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Pretrain.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_ResNet.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFEngine.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TFUtil.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_Util.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_demos.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_fork_exec.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_rf_base.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_tensor.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_tools.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/collect-words.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/compile_native_op.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-dataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-forward.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-network-json.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/dump-pickle.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/get-attention-weights.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/hdf_dump.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20230403.124714 → returnn-1.20230403.211148}/tools/tf_inspect_summary_log.py +0 -0
|
@@ -9,11 +9,11 @@ or forwarding loop.
|
|
|
9
9
|
from __future__ import annotations
|
|
10
10
|
from typing import Optional, Union, Any, Sequence, Dict
|
|
11
11
|
from dataclasses import dataclass
|
|
12
|
-
from returnn.tensor import Tensor, Dim
|
|
12
|
+
from returnn.tensor import Tensor, Dim, TensorDict
|
|
13
13
|
import returnn.frontend as rf
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
__all__ = ["RunCtx", "Loss", "
|
|
16
|
+
__all__ = ["RunCtx", "Loss", "get_run_ctx", "init_train_step_run_ctx", "init_forward_step_run_ctx"]
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
_run_ctx = None # type: Optional[RunCtx]
|
|
@@ -75,7 +75,7 @@ class RunCtx:
|
|
|
75
75
|
"""
|
|
76
76
|
self.stage = stage
|
|
77
77
|
self.losses = {} # type: Dict[str, Loss]
|
|
78
|
-
self.outputs =
|
|
78
|
+
self.outputs = TensorDict()
|
|
79
79
|
|
|
80
80
|
def mark_as_loss(
|
|
81
81
|
self,
|
|
@@ -141,7 +141,7 @@ class RunCtx:
|
|
|
141
141
|
custom_inv_norm_factor=custom_inv_norm_factor,
|
|
142
142
|
)
|
|
143
143
|
|
|
144
|
-
def mark_as_output(self, tensor: Union[Tensor, Any], name: str, *,
|
|
144
|
+
def mark_as_output(self, tensor: Union[Tensor, Any], name: str, *, dims: Optional[Sequence[int]] = None) -> None:
|
|
145
145
|
"""
|
|
146
146
|
Mark this as an output.
|
|
147
147
|
This has the effect that RETURNN will in any case construct the corresponding layer.
|
|
@@ -153,7 +153,7 @@ class RunCtx:
|
|
|
153
153
|
|
|
154
154
|
:param tensor:
|
|
155
155
|
:param name:
|
|
156
|
-
:param
|
|
156
|
+
:param dims: this specifies the order of the dims of the output, such that it is well-defined
|
|
157
157
|
for some external application.
|
|
158
158
|
If not specified, we try to infer BTF or BF as default, if that works, otherwise it will be an error.
|
|
159
159
|
"""
|
|
@@ -161,7 +161,32 @@ class RunCtx:
|
|
|
161
161
|
if not isinstance(tensor, Tensor):
|
|
162
162
|
tensor = rf.convert_to_tensor(tensor)
|
|
163
163
|
assert name not in self.outputs
|
|
164
|
-
|
|
164
|
+
if dims is None:
|
|
165
|
+
rem_dims = list(tensor.dims)
|
|
166
|
+
dims = []
|
|
167
|
+
if tensor.have_batch_axis():
|
|
168
|
+
rem_dims.remove(tensor.get_batch_dim_tag())
|
|
169
|
+
dims.append(tensor.get_batch_dim_tag())
|
|
170
|
+
if tensor.have_time_axis():
|
|
171
|
+
rem_dims.remove(tensor.get_time_dim_tag())
|
|
172
|
+
dims.append(tensor.get_time_dim_tag())
|
|
173
|
+
static_dims = [d for d in dims if d.is_static()]
|
|
174
|
+
if len(static_dims) > 1:
|
|
175
|
+
raise Exception(
|
|
176
|
+
f"Cannot infer order of dims automatically for output {name!r}. Please specify a shape explicitly."
|
|
177
|
+
)
|
|
178
|
+
elif len(static_dims) == 1:
|
|
179
|
+
rem_dims.remove(static_dims[0])
|
|
180
|
+
dims.insert(0, static_dims[0])
|
|
181
|
+
if len(rem_dims) > 1:
|
|
182
|
+
raise Exception(
|
|
183
|
+
f"Cannot infer order of dims automatically for output {name!r}. Please specify a shape explicitly."
|
|
184
|
+
)
|
|
185
|
+
elif len(rem_dims) == 1:
|
|
186
|
+
dims.append(rem_dims[0])
|
|
187
|
+
tensor = tensor.copy_transpose(dims, allow_int=False)
|
|
188
|
+
tensor = tensor.copy(name=name)
|
|
189
|
+
self.outputs.data[name] = tensor
|
|
165
190
|
|
|
166
191
|
def mark_as_default_output(self, tensor: Union[Tensor, Any], *, shape: Optional[Sequence[Dim]] = None) -> None:
|
|
167
192
|
"""
|
|
@@ -173,7 +198,7 @@ class RunCtx:
|
|
|
173
198
|
:param tensor:
|
|
174
199
|
:param shape:
|
|
175
200
|
"""
|
|
176
|
-
self.mark_as_output(tensor, "output",
|
|
201
|
+
self.mark_as_output(tensor, "output", dims=shape)
|
|
177
202
|
|
|
178
203
|
def total_loss(self) -> Union[Tensor, float]:
|
|
179
204
|
"""
|
|
@@ -233,16 +258,3 @@ class Loss:
|
|
|
233
258
|
else:
|
|
234
259
|
loss = self.get_summed_loss()
|
|
235
260
|
return loss * self.scale
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
@dataclass
|
|
239
|
-
class Output:
|
|
240
|
-
"""
|
|
241
|
-
Output via :func:`RunCtx.mark_as_output`.
|
|
242
|
-
|
|
243
|
-
We collect all relevant information here.
|
|
244
|
-
"""
|
|
245
|
-
|
|
246
|
-
tensor: Tensor
|
|
247
|
-
name: str
|
|
248
|
-
shape: Optional[Sequence[Dim]] = None
|
|
@@ -651,14 +651,16 @@ class _TensorMixin(_TensorMixinBase):
|
|
|
651
651
|
assert self.time_dim_axis is not None
|
|
652
652
|
return self.copy_move_axis(self.time_dim_axis, time_dim_axis)
|
|
653
653
|
|
|
654
|
-
def copy_transpose(self, perm) -> _t.Tensor:
|
|
654
|
+
def copy_transpose(self, perm: Sequence[Union[int, Dim]], *, allow_int: bool = True) -> _t.Tensor:
|
|
655
655
|
"""
|
|
656
|
-
:param
|
|
657
|
-
|
|
656
|
+
:param perm: permutation of the axes. Maps the new axes to the old axes
|
|
657
|
+
:param allow_int: allow int as axis, otherwise only :class:`Dim`
|
|
658
658
|
:return: copy of myself with permuted axes
|
|
659
659
|
"""
|
|
660
|
-
assert len(perm) == self.batch_ndim
|
|
661
|
-
|
|
660
|
+
assert len(perm) == self.batch_ndim, f"{self}: invalid perm {perm!r} length"
|
|
661
|
+
perm_ = perm
|
|
662
|
+
perm = [self.get_axis_from_description(a, allow_int=allow_int) for a in perm]
|
|
663
|
+
assert set(perm) == set(range(self.batch_ndim)), f"{self}: invalid perm {perm_!r} (axes: {perm!r})"
|
|
662
664
|
if all(perm[axis] == axis for axis in range(self.batch_ndim)):
|
|
663
665
|
return self.copy()
|
|
664
666
|
|
|
@@ -3156,7 +3158,7 @@ class _TensorMixin(_TensorMixinBase):
|
|
|
3156
3158
|
"""
|
|
3157
3159
|
import returnn.frontend as rf
|
|
3158
3160
|
|
|
3159
|
-
rf.get_run_ctx().mark_as_output(self, name=name,
|
|
3161
|
+
rf.get_run_ctx().mark_as_output(self, name=name, dims=shape)
|
|
3160
3162
|
|
|
3161
3163
|
def mark_as_default_output(self: Tensor, *, shape: Optional[Sequence[Dim]] = None) -> None:
|
|
3162
3164
|
"""
|
|
@@ -123,12 +123,12 @@ def get_net_dict(
|
|
|
123
123
|
# Note that this logic might change.
|
|
124
124
|
root_scope.marked_losses.append(loss_t)
|
|
125
125
|
|
|
126
|
-
for out in rf.get_run_ctx().outputs.values():
|
|
126
|
+
for out in rf.get_run_ctx().outputs.data.values():
|
|
127
127
|
if out.name == "output" and out.name not in root_scope.children:
|
|
128
128
|
layer = root_scope.get_child(out.name)
|
|
129
129
|
else:
|
|
130
130
|
layer = root_scope.get_new_child(suggested_name=out.name)
|
|
131
|
-
out_t = _utils.copy(out
|
|
131
|
+
out_t = _utils.copy(out, name=layer)
|
|
132
132
|
if layer.name != "output":
|
|
133
133
|
out_t.raw_tensor.layer_dict["is_output_layer"] = True
|
|
134
134
|
root_scope.marked_outputs.append(out_t)
|
|
@@ -18,17 +18,19 @@ However, having this separate pure PyTorch implementation is useful to allow to
|
|
|
18
18
|
other PyTorch datasets more directly, including also HuggingFace datasets.
|
|
19
19
|
"""
|
|
20
20
|
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
from typing import List, Dict
|
|
21
23
|
import sys
|
|
22
24
|
from copy import deepcopy
|
|
23
25
|
|
|
24
|
-
import numpy
|
|
26
|
+
import numpy
|
|
25
27
|
import torch
|
|
26
28
|
import torch.utils.data
|
|
27
29
|
|
|
28
30
|
from returnn.util.basic import NumbersDict
|
|
29
31
|
|
|
30
32
|
|
|
31
|
-
def create_tensor(array:
|
|
33
|
+
def create_tensor(array: numpy.ndarray) -> torch.Tensor:
|
|
32
34
|
"""
|
|
33
35
|
Adjust non-supported dtypes
|
|
34
36
|
|
|
@@ -36,14 +38,14 @@ def create_tensor(array: np.ndarray) -> torch.Tensor:
|
|
|
36
38
|
"""
|
|
37
39
|
# The only supported PyTorch dtypes are:
|
|
38
40
|
# float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
|
|
39
|
-
if array.dtype ==
|
|
40
|
-
array =
|
|
41
|
+
if array.dtype == numpy.uint32:
|
|
42
|
+
array = numpy.asarray(array, dtype=numpy.int64)
|
|
41
43
|
return torch.tensor(array)
|
|
42
44
|
|
|
43
45
|
|
|
44
|
-
def collate_batch(batch):
|
|
46
|
+
def collate_batch(batch: List[Dict[str, numpy.ndarray]]) -> Dict[str, torch.Tensor]:
|
|
45
47
|
"""
|
|
46
|
-
:param
|
|
48
|
+
:param batch:
|
|
47
49
|
"""
|
|
48
50
|
assert isinstance(batch, list)
|
|
49
51
|
assert batch, "batch is empty?"
|
|
@@ -294,6 +294,59 @@ class Engine(EngineBase):
|
|
|
294
294
|
print("Load model %s" % (filename,), file=log.v4)
|
|
295
295
|
model_state = torch.load(filename)
|
|
296
296
|
self._model.load_state_dict(model_state)
|
|
297
|
+
preload_from_files = self.config.typed_value("preload_from_files", {})
|
|
298
|
+
if preload_from_files:
|
|
299
|
+
# see `preload_from_files` in tf engine and `returnn.tf.network.CustomCheckpointLoader`
|
|
300
|
+
is_training = self.config.value("task", "train") == "train"
|
|
301
|
+
is_first_train_epoch = epoch == 1 and (
|
|
302
|
+
is_training or self.config.value("task", "train") == "initialize_model"
|
|
303
|
+
)
|
|
304
|
+
# We use the reversed sorted order here to achieve consistent behavior with the TF engine.
|
|
305
|
+
# There, the keys are used in sorted order but if a variable is loaded,
|
|
306
|
+
# it will not be considered anymore afterwards.
|
|
307
|
+
# So the first occurrence is used.
|
|
308
|
+
# Here, we overwrite variables even if they have been loaded before.
|
|
309
|
+
# In order to get consistent behavior, we use the reversed order.
|
|
310
|
+
for preload_key, opts in reversed(sorted(preload_from_files.items())):
|
|
311
|
+
assert isinstance(opts, dict) and "filename" in opts
|
|
312
|
+
if opts.get("init_for_train", False):
|
|
313
|
+
if not is_first_train_epoch:
|
|
314
|
+
continue
|
|
315
|
+
else: # default: init for recog
|
|
316
|
+
if is_training:
|
|
317
|
+
continue
|
|
318
|
+
print(f"Pre-load weights for key '{preload_key}' from {opts['filename']}", file=log.v3)
|
|
319
|
+
preload_model_state = torch.load(opts["filename"])
|
|
320
|
+
if opts.get("checkpoint_key", None) is not None:
|
|
321
|
+
# This can be used if an external checkpoint saves a checkpoint a different structure that just the
|
|
322
|
+
# model state dict. E.g., if a checkpoint is created using
|
|
323
|
+
# `torch.save({"model": model.state_dict(), "optimizer": optimizer.state)_dict(), ...})`
|
|
324
|
+
# we can set checkpoint_key = "model" to load the model.
|
|
325
|
+
# Currently, this only supports single level dicts, but it could be extended if needed.
|
|
326
|
+
preload_model_state = preload_model_state[opts["checkpoint_key"]]
|
|
327
|
+
if opts.get("prefix", ""):
|
|
328
|
+
# Only params with this prefix should be loaded.
|
|
329
|
+
# They are expected to be in the checkpoint without this prefix.
|
|
330
|
+
# By adding the prefix to all params,
|
|
331
|
+
# we make sure that only params matching this condition are loaded.
|
|
332
|
+
# This is in line with the behavior of the TF engine.
|
|
333
|
+
preload_model_state = {opts["prefix"] + key: value for key, value in preload_model_state.items()}
|
|
334
|
+
ignore_params = opts.get("ignore_params", [])
|
|
335
|
+
ignore_params_prefixes = opts.get("ignore_params_prefixes", [])
|
|
336
|
+
for key in list(preload_model_state.keys()):
|
|
337
|
+
if key in ignore_params or any(
|
|
338
|
+
[key.startswith(ignore_key) for ignore_key in ignore_params_prefixes]
|
|
339
|
+
):
|
|
340
|
+
print(f"Ignoring variable {key}", file=log.v3)
|
|
341
|
+
preload_model_state.pop(key)
|
|
342
|
+
for new_name, name_in_checkpoint in opts.get("var_name_mapping", {}).items():
|
|
343
|
+
preload_model_state[new_name] = preload_model_state.pop(name_in_checkpoint)
|
|
344
|
+
missing_keys, _ = self._model.load_state_dict(preload_model_state, strict=False)
|
|
345
|
+
if not opts.get("ignore_missing", False):
|
|
346
|
+
prefix_keys = [key for key in self._model.state_dict() if key.startswith(opts.get("prefix", ""))]
|
|
347
|
+
missing_prefix_keys = set(prefix_keys).intersection(set(missing_keys))
|
|
348
|
+
assert not missing_prefix_keys, f"Missing keys and ignore_missing=False: {missing_prefix_keys}"
|
|
349
|
+
print(f"Missing keys: {missing_keys}", file=log.v4)
|
|
297
350
|
|
|
298
351
|
self._model.to(self._device)
|
|
299
352
|
|
|
@@ -10,7 +10,8 @@ import warnings
|
|
|
10
10
|
|
|
11
11
|
def no_grad_trunc_normal_(tensor: torch.Tensor, mean, std, a, b, *, generator=None):
|
|
12
12
|
"""
|
|
13
|
-
Code copied and adopted from torch.nn.init._no_grad_trunc_normal_
|
|
13
|
+
Code copied and adopted from torch.nn.init._no_grad_trunc_normal_,
|
|
14
|
+
to support the extra `generator` argument (https://github.com/pytorch/pytorch/issues/98200).
|
|
14
15
|
|
|
15
16
|
Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
|
16
17
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-hyper-param-tuning.config
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-record-and-push-to-webserver.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-chunking-blstm.12ax.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-contribrnn-lstm.12ax.config
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-maxgradnorm-lstm.12ax.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm-lowmem.12ax.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm.12ax.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm2.12ax.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-native-lstm2.12ax.tuned.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-neural-transducer.12ax.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-explicit-lstm.config
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-rec-explicit-rnn.config
RENAMED
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-search-compiled-graph.py
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-tf-vanilla-lstm.12ax.config
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/demo-upd-mult-model.lstm.12ax.config
RENAMED
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/create_IAM_dataset.py
RENAMED
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/IAM/features/raw/demo.h5
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/create_test_h5.py
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial/forwardconfig
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/forwardconfig
RENAMED
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/demos/mdlstm/artificial_rgb/trainconfig
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/normalization_data.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/datasets/util/feature_extraction.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/.git
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/.gitignore
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/LICENSE
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/README.md
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/aligner.gif
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/check.png
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core.cu
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core.h
RENAMED
|
File without changes
|
{returnn-1.20230403.124714 → returnn-1.20230403.211148}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|