returnn 1.20230413.135919__tar.gz → 1.20230413.141543__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {returnn-1.20230413.135919/returnn.egg-info → returnn-1.20230413.141543}/PKG-INFO +1 -1
- returnn-1.20230413.141543/_setup_info_generated.py +2 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/__init__.py +5 -1
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/_backend.py +2 -0
- returnn-1.20230413.141543/returnn/frontend/attention.py +211 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/_dim_extra.py +7 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/frontend/_backend.py +0 -1
- {returnn-1.20230413.135919 → returnn-1.20230413.141543/returnn.egg-info}/PKG-INFO +1 -1
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn.egg-info/SOURCES.txt +1 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_rf_base.py +27 -0
- returnn-1.20230413.135919/_setup_info_generated.py +0 -2
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/.editorconfig +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/.gitignore +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/.gitmodules +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/.kateconfig +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/CHANGELOG.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/CODEOWNERS +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/CONTRIBUTING.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/LICENSE +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/MANIFEST.in +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/README.rst +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/12AX.cluster_map +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-fwd.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-list-devices.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-pretrain.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-rf.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-torch.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/demo.sh +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/pyproject.toml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/requirements.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/__main__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/__setup__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/config.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/audio.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/basic.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/cached.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/generating.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/lm.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/map.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/meta.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/datasets/util/vocabulary.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/engine/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/engine/base.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/engine/batch.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/array_.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/cond.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/const.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/dims.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/dropout.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/init.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/linear.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/loss.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/math_.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/module.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/parameter.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/rand.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/run_ctx.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/state.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/types.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/import_/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/import_/common.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/import_/git.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/import_/import_.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/log.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/native_op.cpp +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/native_op.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/pretrain.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/sprint/cache.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/sprint/control.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/sprint/interface.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/_tensor_extra.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/dim.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/compat.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/distributed.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/engine.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/_backend.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/_utils.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/frontend_low_level/_backend.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/horovod.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/layers/basic.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/native_op.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/network.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/sprint.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/updater.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/util/basic.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/util/data.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/data/pipeline.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/data/tensor_utils.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/engine.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/frontend/_rand.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/frontend/bridge.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/functional/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/functional/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/updater.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/__init__.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/basic.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/better_exchook.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/bpe.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/debug.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/fsa.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/pprint.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/py_compat.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/util/task_system.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/rnn.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/setup.cfg +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/setup.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/DummySprintExec.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm-inspection-profile.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/_setup_test_env.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/lint_common.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/pylint.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/rf_utils.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/spelling.dic +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_Config.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_Dataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_Fsa.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_Log.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_PTDataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_Pretrain.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_ResNet.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TFEngine.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TFUtil.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_Util.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_demos.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_fork_exec.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_tensor.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_tools.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/collect-words.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/compile_native_op.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/dump-dataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/dump-forward.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/dump-network-json.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/dump-pickle.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/get-attention-weights.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/hdf_dump.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20230413.135919 → returnn-1.20230413.141543}/tools/tf_inspect_summary_log.py +0 -0
{returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/__init__.py

@@ -14,8 +14,11 @@ The convention for the user is to do::
 # Some most come first here when others directly use it,
 # e.g. `rf.Module` as a baseclass.
 from .module import *
+from .state import *
 
+# Now the rest, in alphabetical order.
 from .array_ import *
+from .attention import *
 from .cond import *
 from .const import *
 from .dims import *
@@ -29,9 +32,10 @@ from .parameter import *
 from .rand import *
 from .reduce import *
 from .run_ctx import *
-from .state import *
 from .types import *
 
+# Modules not in the main namespace but in sub namespaces.
 from . import init
 
+# And some functions from the internal backend API.
 from ._backend import select_backend_torch, select_backend_returnn_layers_tf
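
The added `from .attention import *` line exposes the new attention API (the new file further below) directly in the main `rf` namespace. A trivial check of that, as an illustration only (assumes this package version is installed; not code from the diff):

    import returnn.frontend as rf

    assert hasattr(rf, "dot_attention")
    assert hasattr(rf, "SelfAttention") and hasattr(rf, "CausalSelfAttention")
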
{returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/_backend.py

@@ -660,6 +660,8 @@ class Backend(Generic[T]):
         """
         # This default implementation works fine as long as the backend
         # does not have special treatments of Tensor and dim tags itself (like TF net dict backend).
+        if not out_dim.is_dim_known():
+            out_dim.copy_from(in_dim)
         out = source.copy_template_replace_dim_tag(axis=source.get_axis_from_description(in_dim), new_dim_tag=out_dim)
         out.raw_tensor = source.raw_tensor
         return out
{returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/frontend/attention.py (new file)

@@ -0,0 +1,211 @@
+"""
+Attention
+"""
+
+
+from __future__ import annotations
+from typing import Tuple, Union, Optional, Sequence
+from returnn.util.py_compat import Protocol
+from returnn.tensor import Tensor, Dim, single_step_dim
+import returnn.frontend as rf
+
+
+__all__ = [
+    "AttentionFunc",
+    "dot_attention",
+    "SelfAttentionBase",
+    "SelfAttention",
+    "CausalSelfAttention",
+    "CausalSelfAttentionState",
+]
+
+
+class AttentionFunc(Protocol):
+    """Protocol defining a generic attention function"""
+
+    def __call__(
+        self,
+        query: Tensor,
+        keys: Tensor,
+        values: Tensor,
+        *,
+        key_dim: Dim,
+        axis: Dim,
+        att_dropout: float = 0.1,
+    ):
+        ...
+
+
+def dot_attention(
+    query: Tensor, keys: Tensor, values: Tensor, *, key_dim: Dim, axis: Dim, att_dropout: float = 0.0
+) -> Tensor:
+    """
+    Calculates attention over the given axis, for given key dim.
+    Any other unrelated axes do not matter here.
+    This can be used for multi-head or single head.
+    The query can have other dimensions or not.
+
+    :param query: {..., key_dim}. For self-attention, do not use the `axis` as in `keys` and `values`,
+        but rather replace it by another new dim via :func:`replace_dim`.
+    :param keys: {..., axis, key_dim}
+    :param values: {..., axis}
+    :param key_dim: dim in keys and query, to be reduced to calculate the attention energies.
+    :param axis: in keys and values, to apply attention on. softmax will be over this axis, and then it will be reduced
+    :param att_dropout: dropout for attention weights
+    :return: like values but with axis removed, and maybe any additional axes from query
+    """
+    query *= key_dim.dimension**-0.5
+    energy = rf.matmul(query, keys, reduce=key_dim)
+    att_weights = rf.softmax(energy, axis=axis)
+    att_weights = rf.dropout(att_weights, att_dropout, axis=axis)
+    # Masking not needed because softmax should already have masked,
+    # so we have 0.0 att weights for padded frames.
+    att = rf.matmul(att_weights, values, reduce=axis, disable_masking=True)
+    return att
+
+
+# noinspection PyAbstractClass
+class SelfAttentionBase(rf.Module):
+    """
+    Shared base class for (non-causal) self attention (:class:`SelfAttention`)
+    and causal self attention (:class:`CausalSelfAttention`).
+
+    It uses :func:`dot_attention` for multi-headed dot-attention.
+    """
+
+    def __init__(
+        self,
+        in_dim: Dim,
+        proj_dim: Optional[Dim],
+        *,
+        key_dim_total: Dim,
+        value_dim_total: Dim,
+        num_heads: Union[int, Dim],
+        with_bias: bool = True,
+        att_dropout: float = 0.1,
+    ):
+        """
+        :param in_dim: input dim
+        :param proj_dim: if given, will add a final linear projection to this dim.
+            otherwise no projection after the attention
+        :param key_dim_total: total key dim. should be a multiple of num_heads
+        :param value_dim_total: total value dim. should be a multiple of num_heads
+        :param num_heads: number of heads
+        :param with_bias: whether to add bias to qkv and proj linear projections.
+            Was False in original Transformer, but many recent implementations use True by default.
+            Also see: https://github.com/rwth-i6/returnn_common/issues/234.
+        :param att_dropout: dropout for attention weights
+        """
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = proj_dim if proj_dim else value_dim_total
+        if isinstance(num_heads, int):
+            num_heads = Dim(num_heads, name="num_heads")
+        self.key_dim_total = key_dim_total
+        self.key_dim_per_head = key_dim_total.div_left(num_heads)
+        self.value_dim_total = value_dim_total
+        self.value_dim_per_head = value_dim_total.div_left(num_heads)
+        self.num_heads = num_heads
+        self.qkv_dim_total = 2 * key_dim_total + value_dim_total
+        self.qkv_dim_per_head = 2 * self.key_dim_per_head + self.value_dim_per_head
+        self.qkv = rf.Linear(in_dim, self.qkv_dim_total, with_bias=with_bias)
+        if proj_dim:
+            self.proj = rf.Linear(value_dim_total, proj_dim, with_bias=with_bias)
+        else:
+            self.proj = None
+        self.att_dropout = att_dropout
+
+    def forward_qkv(self, source: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        :return: q,k,v
+        """
+        qkv = self.qkv(source)
+        qkv = rf.split_dims(qkv, axis=self.qkv_dim_total, dims=(self.num_heads, self.qkv_dim_per_head))
+        q, k, v = rf.split(
+            qkv,
+            axis=self.qkv_dim_per_head,
+            out_dims=(self.key_dim_per_head, self.key_dim_per_head, self.value_dim_per_head),
+        )
+        return q, k, v
+
+    def attention(self, q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) -> Tensor:
+        """apply attention"""
+        att = dot_attention(q, k, v, key_dim=self.key_dim_per_head, axis=kv_axis, att_dropout=self.att_dropout)
+        output, _ = rf.merge_dims(att, dims=(self.num_heads, self.value_dim_per_head), out_dim=self.value_dim_total)
+        if self.proj:
+            output = self.proj(output)
+        return output
+
+
+class SelfAttention(SelfAttentionBase):
+    """
+    Classic self attention on sequence level
+    """
+
+    def __call__(self, source: Tensor, *, axis: Dim) -> Tensor:
+        """forward"""
+        q, k, v = self.forward_qkv(source)
+        kv_axis = Dim(None, name=f"{axis.name}-kv")
+        k, _ = rf.replace_dim(k, in_dim=axis, out_dim=kv_axis)
+        v, _ = rf.replace_dim(v, in_dim=axis, out_dim=kv_axis)
+        return self.attention(q, k, v, kv_axis=kv_axis)
+
+
+class CausalSelfAttention(SelfAttentionBase):
+    """
+    Classic causal self attention
+    """
+
+    def __call__(
+        self,
+        source: Tensor,
+        axis: Dim,
+        *,
+        state: CausalSelfAttentionState,
+    ) -> Tuple[Tensor, CausalSelfAttentionState]:
+        """forward"""
+        assert axis == single_step_dim  # not implemented otherwise currently...
+        q, k, v = self.forward_qkv(source)
+        assert state
+        hist_dim = Dim(None, name="kv-history")
+        new_state = CausalSelfAttentionState()
+        k, _ = rf.cum_concat_step(k, prev_accum=state.k_accum, out_spatial_dim=hist_dim, axis=state.accum_axis)
+        v, _ = rf.cum_concat_step(v, prev_accum=state.v_accum, out_spatial_dim=hist_dim, axis=state.accum_axis)
+        new_state.k_accum = k
+        new_state.v_accum = v
+        new_state.accum_axis = hist_dim
+        output = self.attention(q, k, v, kv_axis=hist_dim)
+        return output, new_state
+
+    def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> CausalSelfAttentionState:
+        """
+        For causal attention.
+        """
+        # Note: This dim tag is wrong. It should match to the expand_dim inside __call__.
+        # So the dim tag itself should be part of the layer state, and we need to define the initial value of it here.
+        # This is not really supported, in various ways, also including RETURNN.
+        # We just keep this code in place to be prepared for that.
+        # The reason it works right now is that we do an optimization where we replace zero init state by 0.
+        expand_dim = Dim(0, name="self_att_expand_dim_init")
+        return CausalSelfAttentionState(
+            k_accum=rf.zeros(list(batch_dims) + [expand_dim, self.num_heads, self.key_dim_per_head]),
+            v_accum=rf.zeros(list(batch_dims) + [expand_dim, self.num_heads, self.value_dim_per_head]),
+            accum_axis=expand_dim,
+        )
+
+
+class CausalSelfAttentionState(rf.State):
+    """
+    State for :class:`StepwiseCausalSelfAttention`.
+    """
+
+    def __init__(self, *, k_accum: Tensor = None, v_accum: Tensor = None, accum_axis: Dim = None):
+        """
+        :param k_accum: accumulated keys
+        :param v_accum: accumulated values
+        :param accum_axis:
+        """
+        super().__init__()
+        self.k_accum = k_accum
+        self.v_accum = v_accum
+        self.accum_axis = accum_axis
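
As an illustration of the new API (not code from the package; the dim sizes are arbitrary and an eager backend is assumed to be selected via the exported `rf.select_backend_torch`), a multi-head self-attention module can be set up and applied over a time axis like this:

    import returnn.frontend as rf
    from returnn.tensor import Tensor, Dim, batch_dim

    rf.select_backend_torch()  # assumes PyTorch is installed

    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))  # dynamic sequence length
    in_dim = Dim(12, name="in")

    self_att = rf.SelfAttention(
        in_dim=in_dim,
        proj_dim=Dim(12, name="proj"),
        key_dim_total=Dim(16, name="key-total"),    # 8 per head with 2 heads
        value_dim_total=Dim(16, name="value-total"),
        num_heads=2,
        att_dropout=0.0,
    )
    # Given a source Tensor x with dims {batch, time, in}, e.g. built via extern_data
    # as in test_dot_attention further below:
    #     out = self_att(x, axis=time_dim)
    # out then has dims {batch, time, proj}; internally the time axis of k/v is replaced
    # by a fresh kv axis so that dot_attention can reduce over it.
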
{returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/tensor/_dim_extra.py

@@ -1478,6 +1478,13 @@ class _DimMixin:
                 name="%s:batch" % self_base.description, shape=(), dtype="int32", batch_dim_axis=None
             )
 
+    def copy_from(self: Dim, other: Dim):
+        """define"""
+        self.size = other.size
+        self.capacity = other.capacity
+        self.dyn_size_ext = other.dyn_size_ext
+        self.derive_from(other)
+
     @classmethod
     def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):
         """
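
Together with the `Backend` change above, this lets `rf.replace_dim` accept a target dim that is not yet defined: when `out_dim.is_dim_known()` is false, the default implementation now defines it by copying size, capacity and dynamic sizes from `in_dim`. This is the pattern `SelfAttention.__call__` and the new test below rely on, sketched here with illustrative names (not code from the package):

    import returnn.frontend as rf
    from returnn.tensor import Dim

    # x: a Tensor with dims {batch, time, feat}; time_dim is its dynamic time axis.
    kv_axis = Dim(None, name="kv")  # size unknown at this point
    # x_kv, kv_axis = rf.replace_dim(x, in_dim=time_dim, out_dim=kv_axis)
    # Afterwards kv_axis carries the same (dynamic) sizes as time_dim but is a distinct
    # dim tag, so query and key/value axes stay distinguishable inside dot_attention.
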
{returnn-1.20230413.135919 → returnn-1.20230413.141543}/returnn/torch/frontend/_backend.py

@@ -255,7 +255,6 @@ class TorchBackend(Backend[torch.Tensor]):
         mask = tensor.get_sequence_mask_broadcast(axis=axis)
         inf_value = get_global_inf_value()
         tensor.raw_tensor = torch.where(mask, tensor.raw_tensor, -inf_value)
-        assert not axis.need_masking(), "not implemented"
         out.raw_tensor = torch.softmax(tensor.raw_tensor, dim=tensor.dims.index(axis))
         return out
 
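
The removed assert rejected softmax over axes that need masking, even though the preceding lines already handle this by pushing padded frames to a large negative value before the softmax, so their weights come out as (numerically) zero. This is also what the "softmax should already have masked" comment in `dot_attention` above relies on. A standalone plain-PyTorch sketch of that masking pattern (illustration only, not code from the package):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5, 0.0]])
    mask = torch.tensor([[True, True, True, False]])  # False marks a padded frame
    inf_value = 1e30  # stands in for get_global_inf_value()
    masked_logits = torch.where(mask, logits, torch.full_like(logits, -inf_value))
    weights = torch.softmax(masked_logits, dim=-1)
    # weights[..., -1] is 0.0 (up to float precision), i.e. padded frames get no attention
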
{returnn-1.20230413.135919 → returnn-1.20230413.141543}/tests/test_rf_base.py

@@ -184,3 +184,30 @@ def test_dropout():
     out.mark_as_default_output(shape=(batch_dim, time_dim, in_dim))
 
     run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
+
+
+def test_dot_attention():
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    key_dim = Dim(7, name="key")
+    value_dim = Dim(13, name="value")
+    extern_data = TensorDict(
+        {
+            "q": Tensor("q", [batch_dim, time_dim, key_dim], dtype="float32"),
+            "k": Tensor("k", [batch_dim, time_dim, key_dim], dtype="float32"),
+            "v": Tensor("v", [batch_dim, time_dim, value_dim], dtype="float32"),
+        }
+    )
+
+    class _Net(rf.Module):
+        def __call__(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+            kv_axis = Dim(None, name=f"kv-axis")
+            k, _ = rf.replace_dim(k, in_dim=time_dim, out_dim=kv_axis)
+            v, _ = rf.replace_dim(v, in_dim=time_dim, out_dim=kv_axis)
+            return rf.dot_attention(q, k, v, axis=kv_axis, key_dim=key_dim)
+
+    # noinspection PyShadowingNames
+    def _forward_step(*, model: _Net, extern_data: TensorDict):
+        out = model(q=extern_data["q"], k=extern_data["k"], v=extern_data["v"])
+        out.mark_as_default_output(shape=(batch_dim, time_dim, value_dim))
+
+    run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)