returnn 1.20230413.100132__tar.gz → 1.20230413.141543__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {returnn-1.20230413.100132/returnn.egg-info → returnn-1.20230413.141543}/PKG-INFO +1 -1
- returnn-1.20230413.141543/_setup_info_generated.py +2 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/__init__.py +5 -1
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/_backend.py +17 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/array_.py +35 -1
- returnn-1.20230413.141543/returnn/frontend/attention.py +211 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/_dim_extra.py +7 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/_tensor_extra.py +3 -1
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/_backend.py +14 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/frontend/_backend.py +25 -3
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/basic.py +15 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543/returnn.egg-info}/PKG-INFO +1 -1
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn.egg-info/SOURCES.txt +1 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_rf_base.py +27 -0
- returnn-1.20230413.100132/_setup_info_generated.py +0 -2
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/.editorconfig +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/.gitignore +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/.gitmodules +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/.kateconfig +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/CHANGELOG.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/CODEOWNERS +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/CONTRIBUTING.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/LICENSE +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/MANIFEST.in +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/README.rst +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/12AX.cluster_map +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-fwd.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-list-devices.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-pretrain.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-rf.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-torch.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/demo.sh +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/pyproject.toml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/requirements.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/__main__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/__setup__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/config.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/audio.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/basic.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/cached.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/generating.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/lm.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/map.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/meta.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/datasets/util/vocabulary.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/engine/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/engine/base.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/engine/batch.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/cond.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/const.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/dims.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/dropout.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/init.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/linear.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/loss.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/math_.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/module.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/parameter.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/rand.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/run_ctx.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/state.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/types.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/import_/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/import_/common.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/import_/git.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/import_/import_.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/log.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/native_op.cpp +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/native_op.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/pretrain.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/sprint/cache.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/sprint/control.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/sprint/interface.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/dim.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/compat.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/distributed.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/engine.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/_utils.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_low_level/_backend.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/horovod.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/layers/basic.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/native_op.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/network.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/sprint.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/updater.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/util/basic.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/util/data.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/data/pipeline.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/data/tensor_utils.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/engine.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/frontend/_rand.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/frontend/bridge.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/functional/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/functional/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/updater.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/__init__.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/better_exchook.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/bpe.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/debug.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/fsa.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/pprint.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/py_compat.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/task_system.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/rnn.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/setup.cfg +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/setup.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/DummySprintExec.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm-inspection-profile.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/_setup_test_env.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/lint_common.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/pylint.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/rf_utils.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/spelling.dic +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_Config.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_Dataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_Fsa.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_Log.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_PTDataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_Pretrain.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_ResNet.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TFEngine.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TFUtil.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_Util.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_demos.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_fork_exec.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_tensor.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_tools.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/collect-words.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/compile_native_op.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/dump-dataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/dump-forward.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/dump-network-json.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/dump-pickle.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/get-attention-weights.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/hdf_dump.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20230413.100132 → returnn-1.20230413.141543}/tools/tf_inspect_summary_log.py +0 -0
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/__init__.py

@@ -14,8 +14,11 @@ The convention for the user is to do::
 # Some most come first here when others directly use it,
 # e.g. `rf.Module` as a baseclass.
 from .module import *
+from .state import *

+# Now the rest, in alphabetical order.
 from .array_ import *
+from .attention import *
 from .cond import *
 from .const import *
 from .dims import *
@@ -29,9 +32,10 @@ from .parameter import *
 from .rand import *
 from .reduce import *
 from .run_ctx import *
-from .state import *
 from .types import *

+# Modules not in the main namespace but in sub namespaces.
 from . import init

+# And some functions from the internal backend API.
 from ._backend import select_backend_torch, select_backend_returnn_layers_tf
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/_backend.py

@@ -293,6 +293,21 @@ class Backend(Generic[T]):
         """
         raise NotImplementedError

+    @staticmethod
+    def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor:
+        """
+        Concatenates all previous frames over a time-axis.
+        See RETURNN :class:`CumConcatLayer` for details.
+
+        :param source: same dims as prev_accum except for the accum axis
+        :param prev_accum: previous accumulated tensor, shape {..., axis}
+        :param axis: the axis to accumulate over
+        :param out_spatial_dim: the spatial dim of the output will be this dim. like axis+1.
+        :return: accumulated. accumulated shape {..., out_spatial_dim},
+            same shape as prev_accum with axis replaced by out_spatial_dim.
+        """
+        raise NotImplementedError
+
     # Restrict the possible activation function names,
     # to not get unexpected behavior,
     # or unwanted incompatibilities.
@@ -645,6 +660,8 @@ class Backend(Generic[T]):
         """
         # This default implementation works fine as long as the backend
         # does not have special treatments of Tensor and dim tags itself (like TF net dict backend).
+        if not out_dim.is_dim_known():
+            out_dim.copy_from(in_dim)
         out = source.copy_template_replace_dim_tag(axis=source.get_axis_from_description(in_dim), new_dim_tag=out_dim)
         out.raw_tensor = source.raw_tensor
         return out
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/frontend/array_.py

@@ -12,7 +12,18 @@ from .types import RawTensorTypes

 T = TypeVar("T")

-__all__ = [
+__all__ = [
+    "convert_to_tensor",
+    "constant",
+    "cast",
+    "merge_dims",
+    "split_dims",
+    "split",
+    "cum_concat_step",
+    "masked_select",
+    "pack",
+    "gather",
+]


 def convert_to_tensor(
@@ -169,6 +180,29 @@ def split(source: Tensor, *, axis: Dim, out_dims: Sequence[Dim]) -> Tuple[Tensor
     return source._raw_backend.split(source, axis=axis, out_dims=out_dims)


+def cum_concat_step(
+    source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Optional[Dim] = None
+) -> Tuple[Tensor, Dim]:
+    """
+    Concatenates all previous frames over a time-axis.
+    See RETURNN :class:`CumConcatLayer` for details.
+
+    :param source: same dims as prev_accum except for the accum axis
+    :param prev_accum: previous accumulated tensor, shape {..., axis}
+    :param axis: the axis to accumulate over
+    :param out_spatial_dim: if given, the spatial dim of the output will be this dim. axis+1.
+    :return: (accumulated, out_spatial_dim). accumulated shape {..., out_spatial_dim},
+        same shape as prev_accum with axis replaced by out_spatial_dim.
+    """
+    if not out_spatial_dim:
+        out_spatial_dim = axis + 1
+    # noinspection PyProtectedMember
+    return (
+        source._raw_backend.cum_concat_step(source, prev_accum=prev_accum, axis=axis, out_spatial_dim=out_spatial_dim),
+        out_spatial_dim,
+    )
+
+
 def masked_select(
     tensor: Tensor, *, mask: Tensor, dims: Sequence[Dim], out_dim: Optional[Dim] = None
 ) -> Tuple[Tensor, Dim]:
returnn-1.20230413.141543/returnn/frontend/attention.py

@@ -0,0 +1,211 @@
+"""
+Attention
+"""
+
+
+from __future__ import annotations
+from typing import Tuple, Union, Optional, Sequence
+from returnn.util.py_compat import Protocol
+from returnn.tensor import Tensor, Dim, single_step_dim
+import returnn.frontend as rf
+
+
+__all__ = [
+    "AttentionFunc",
+    "dot_attention",
+    "SelfAttentionBase",
+    "SelfAttention",
+    "CausalSelfAttention",
+    "CausalSelfAttentionState",
+]
+
+
+class AttentionFunc(Protocol):
+    """Protocol defining a generic attention function"""
+
+    def __call__(
+        self,
+        query: Tensor,
+        keys: Tensor,
+        values: Tensor,
+        *,
+        key_dim: Dim,
+        axis: Dim,
+        att_dropout: float = 0.1,
+    ):
+        ...
+
+
+def dot_attention(
+    query: Tensor, keys: Tensor, values: Tensor, *, key_dim: Dim, axis: Dim, att_dropout: float = 0.0
+) -> Tensor:
+    """
+    Calculates attention over the given axis, for given key dim.
+    Any other unrelated axes do not matter here.
+    This can be used for multi-head or single head.
+    The query can have other dimensions or not.
+
+    :param query: {..., key_dim}. For self-attention, do not use the `axis` as in `keys` and `values`,
+        but rather replace it by another new dim via :func:`replace_dim`.
+    :param keys: {..., axis, key_dim}
+    :param values: {..., axis}
+    :param key_dim: dim in keys and query, to be reduced to calculate the attention energies.
+    :param axis: in keys and values, to apply attention on. softmax will be over this axis, and then it will be reduced
+    :param att_dropout: dropout for attention weights
+    :return: like values but with axis removed, and maybe any additional axes from query
+    """
+    query *= key_dim.dimension**-0.5
+    energy = rf.matmul(query, keys, reduce=key_dim)
+    att_weights = rf.softmax(energy, axis=axis)
+    att_weights = rf.dropout(att_weights, att_dropout, axis=axis)
+    # Masking not needed because softmax should already have masked,
+    # so we have 0.0 att weights for padded frames.
+    att = rf.matmul(att_weights, values, reduce=axis, disable_masking=True)
+    return att
+
+
+# noinspection PyAbstractClass
+class SelfAttentionBase(rf.Module):
+    """
+    Shared base class for (non-causal) self attention (:class:`SelfAttention`)
+    and causal self attention (:class:`CausalSelfAttention`).
+
+    It uses :func:`dot_attention` for multi-headed dot-attention.
+    """
+
+    def __init__(
+        self,
+        in_dim: Dim,
+        proj_dim: Optional[Dim],
+        *,
+        key_dim_total: Dim,
+        value_dim_total: Dim,
+        num_heads: Union[int, Dim],
+        with_bias: bool = True,
+        att_dropout: float = 0.1,
+    ):
+        """
+        :param in_dim: input dim
+        :param proj_dim: if given, will add a final linear projection to this dim.
+            otherwise no projection after the attention
+        :param key_dim_total: total key dim. should be a multiple of num_heads
+        :param value_dim_total: total value dim. should be a multiple of num_heads
+        :param num_heads: number of heads
+        :param with_bias: whether to add bias to qkv and proj linear projections.
+            Was False in original Transformer, but many recent implementations use True by default.
+            Also see: https://github.com/rwth-i6/returnn_common/issues/234.
+        :param att_dropout: dropout for attention weights
+        """
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = proj_dim if proj_dim else value_dim_total
+        if isinstance(num_heads, int):
+            num_heads = Dim(num_heads, name="num_heads")
+        self.key_dim_total = key_dim_total
+        self.key_dim_per_head = key_dim_total.div_left(num_heads)
+        self.value_dim_total = value_dim_total
+        self.value_dim_per_head = value_dim_total.div_left(num_heads)
+        self.num_heads = num_heads
+        self.qkv_dim_total = 2 * key_dim_total + value_dim_total
+        self.qkv_dim_per_head = 2 * self.key_dim_per_head + self.value_dim_per_head
+        self.qkv = rf.Linear(in_dim, self.qkv_dim_total, with_bias=with_bias)
+        if proj_dim:
+            self.proj = rf.Linear(value_dim_total, proj_dim, with_bias=with_bias)
+        else:
+            self.proj = None
+        self.att_dropout = att_dropout
+
+    def forward_qkv(self, source: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        :return: q,k,v
+        """
+        qkv = self.qkv(source)
+        qkv = rf.split_dims(qkv, axis=self.qkv_dim_total, dims=(self.num_heads, self.qkv_dim_per_head))
+        q, k, v = rf.split(
+            qkv,
+            axis=self.qkv_dim_per_head,
+            out_dims=(self.key_dim_per_head, self.key_dim_per_head, self.value_dim_per_head),
+        )
+        return q, k, v
+
+    def attention(self, q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) -> Tensor:
+        """apply attention"""
+        att = dot_attention(q, k, v, key_dim=self.key_dim_per_head, axis=kv_axis, att_dropout=self.att_dropout)
+        output, _ = rf.merge_dims(att, dims=(self.num_heads, self.value_dim_per_head), out_dim=self.value_dim_total)
+        if self.proj:
+            output = self.proj(output)
+        return output
+
+
+class SelfAttention(SelfAttentionBase):
+    """
+    Classic self attention on sequence level
+    """
+
+    def __call__(self, source: Tensor, *, axis: Dim) -> Tensor:
+        """forward"""
+        q, k, v = self.forward_qkv(source)
+        kv_axis = Dim(None, name=f"{axis.name}-kv")
+        k, _ = rf.replace_dim(k, in_dim=axis, out_dim=kv_axis)
+        v, _ = rf.replace_dim(v, in_dim=axis, out_dim=kv_axis)
+        return self.attention(q, k, v, kv_axis=kv_axis)
+
+
+class CausalSelfAttention(SelfAttentionBase):
+    """
+    Classic causal self attention
+    """
+
+    def __call__(
+        self,
+        source: Tensor,
+        axis: Dim,
+        *,
+        state: CausalSelfAttentionState,
+    ) -> Tuple[Tensor, CausalSelfAttentionState]:
+        """forward"""
+        assert axis == single_step_dim  # not implemented otherwise currently...
+        q, k, v = self.forward_qkv(source)
+        assert state
+        hist_dim = Dim(None, name="kv-history")
+        new_state = CausalSelfAttentionState()
+        k, _ = rf.cum_concat_step(k, prev_accum=state.k_accum, out_spatial_dim=hist_dim, axis=state.accum_axis)
+        v, _ = rf.cum_concat_step(v, prev_accum=state.v_accum, out_spatial_dim=hist_dim, axis=state.accum_axis)
+        new_state.k_accum = k
+        new_state.v_accum = v
+        new_state.accum_axis = hist_dim
+        output = self.attention(q, k, v, kv_axis=hist_dim)
+        return output, new_state
+
+    def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> CausalSelfAttentionState:
+        """
+        For causal attention.
+        """
+        # Note: This dim tag is wrong. It should match to the expand_dim inside __call__.
+        # So the dim tag itself should be part of the layer state, and we need to define the initial value of it here.
+        # This is not really supported, in various ways, also including RETURNN.
+        # We just keep this code in place to be prepared for that.
+        # The reason it works right now is that we do an optimization where we replace zero init state by 0.
+        expand_dim = Dim(0, name="self_att_expand_dim_init")
+        return CausalSelfAttentionState(
+            k_accum=rf.zeros(list(batch_dims) + [expand_dim, self.num_heads, self.key_dim_per_head]),
+            v_accum=rf.zeros(list(batch_dims) + [expand_dim, self.num_heads, self.value_dim_per_head]),
+            accum_axis=expand_dim,
+        )
+
+
+class CausalSelfAttentionState(rf.State):
+    """
+    State for :class:`StepwiseCausalSelfAttention`.
+    """
+
+    def __init__(self, *, k_accum: Tensor = None, v_accum: Tensor = None, accum_axis: Dim = None):
+        """
+        :param k_accum: accumulated keys
+        :param v_accum: accumulated values
+        :param accum_axis:
+        """
+        super().__init__()
+        self.k_accum = k_accum
+        self.v_accum = v_accum
+        self.accum_axis = accum_axis
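
Note: the core of the new `dot_attention` above is standard scaled dot-product attention. The following is a minimal raw-PyTorch sketch of the same computation (illustrative only, with made-up sizes, and without the named-dim handling, sequence masking, att_dropout, or multi-head packing of the RETURNN code):

    import torch

    batch, q_len, kv_len, key_dim, value_dim = 2, 3, 5, 7, 13
    q = torch.randn(batch, q_len, key_dim)
    k = torch.randn(batch, kv_len, key_dim)
    v = torch.randn(batch, kv_len, value_dim)

    q = q * key_dim**-0.5                        # scale the query, as in dot_attention
    energy = torch.einsum("bqd,btd->bqt", q, k)  # reduce over key_dim
    att_weights = torch.softmax(energy, dim=-1)  # softmax over the kv axis
    att = torch.einsum("bqt,btv->bqv", att_weights, v)  # weighted sum over the kv axis
    assert att.shape == (batch, q_len, value_dim)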
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/_dim_extra.py

@@ -1478,6 +1478,13 @@ class _DimMixin:
                 name="%s:batch" % self_base.description, shape=(), dtype="int32", batch_dim_axis=None
             )

+    def copy_from(self: Dim, other: Dim):
+        """define"""
+        self.size = other.size
+        self.capacity = other.capacity
+        self.dyn_size_ext = other.dyn_size_ext
+        self.derive_from(other)
+
     @classmethod
     def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):
         """
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tensor/_tensor_extra.py

@@ -2668,13 +2668,15 @@ class _TensorMixin(_TensorMixinBase):

     def get_sequence_mask_broadcast(self: Tensor, axis=None) -> _t.RawTensorType:
         """
-        :param int|None axis:
+        :param Dim|int|None axis:
         :return: seq mask of shape ((batch,time) or (time,batch)) + (1,)s for remaining dims
             if BT or TB major, and axis is T or None.
             In general compatible to placeholder, i.e. same ndim, with broadcast dims.
             We assert here that the axis is dynamic (:func:`is_axis_dynamic`), i.e. we have the size.
         :rtype: tf.Tensor
         """
+        if isinstance(axis, Dim):
+            axis = self.get_axis_from_description(axis)
         if axis is None:
             assert self.time_dim_axis is not None
             axis = self.time_dim_axis
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/tf/frontend_layers/_backend.py

@@ -192,6 +192,20 @@ class ReturnnLayersBackend(Backend[Layer]):
             for i, dim in enumerate(out_dims)
         )

+    @staticmethod
+    def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor:
+        """cum_concat_step"""
+        return rfl.make_layer(
+            {
+                "class": "cum_concat",
+                "from": source,
+                "state": {"state": prev_accum},
+                "out_spatial_dim": out_spatial_dim,
+                "axis": axis,
+            },
+            name="cum_concat",
+        )
+
     @staticmethod
     def activation(tensor: Tensor, func: str) -> Tensor:
         """activation"""
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/torch/frontend/_backend.py

@@ -9,7 +9,7 @@ import torch
 import numpy

 from returnn.tensor import Tensor, Dim
-from returnn.util.basic import prod, NotSpecified
+from returnn.util.basic import prod, NotSpecified, get_global_inf_value

 # noinspection PyProtectedMember
 from returnn.frontend._backend import Backend
@@ -212,6 +212,20 @@ class TorchBackend(Backend[torch.Tensor]):
             out.raw_tensor = out_raw_list[i]
         return out_tuple

+    @staticmethod
+    def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor:
+        """cum concat step"""
+        out = prev_accum.copy_template_replace_dim_tag(
+            axis=prev_accum.get_axis_from_description(axis),
+            new_dim_tag=out_spatial_dim,
+            name=f"{source.name}/cum_concat_step",
+        )
+        source_ = source.copy_compatible_to(prev_accum)
+        out.raw_tensor = torch.cat(
+            (prev_accum.raw_tensor, source_.raw_tensor), dim=prev_accum.get_axis_from_description(axis)
+        )
+        return out
+
     @staticmethod
     def activation_raw(raw_tensor: torch.Tensor, func: str) -> torch.Tensor:
         """
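
Note: semantically, the Torch `cum_concat_step` above appends one new frame to the previously accumulated frames along the accumulation axis; the `copy_template_replace_dim_tag` / `copy_compatible_to` calls only handle the named-dim bookkeeping. A minimal raw-PyTorch sketch of that step (illustrative only, with made-up shapes):

    import torch

    # prev_accum: previously accumulated frames, shape (batch, hist, feat)
    # source: one new frame, shape (batch, feat)
    prev_accum = torch.zeros(3, 4, 8)
    source = torch.randn(3, 8)
    # one cum-concat step: append the new frame, giving hist+1 frames (the out_spatial_dim)
    accum = torch.cat((prev_accum, source.unsqueeze(1)), dim=1)
    assert accum.shape == (3, 5, 8)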
@@ -236,7 +250,11 @@ class TorchBackend(Backend[torch.Tensor]):
         :return: softmax over axis
         """
         out = tensor.copy_template("softmax")
-
+        if axis.need_masking():
+            tensor = tensor.copy()
+            mask = tensor.get_sequence_mask_broadcast(axis=axis)
+            inf_value = get_global_inf_value()
+            tensor.raw_tensor = torch.where(mask, tensor.raw_tensor, -inf_value)
         out.raw_tensor = torch.softmax(tensor.raw_tensor, dim=tensor.dims.index(axis))
         return out

@@ -248,7 +266,11 @@ class TorchBackend(Backend[torch.Tensor]):
         :return: log_softmax over axis
         """
         out = tensor.copy_template("log_softmax")
-
+        if axis.need_masking():
+            tensor = tensor.copy()
+            mask = tensor.get_sequence_mask_broadcast(axis=axis)
+            inf_value = get_global_inf_value()
+            tensor.raw_tensor = torch.where(mask, tensor.raw_tensor, -inf_value)
         out.raw_tensor = torch.log_softmax(tensor.raw_tensor, dim=tensor.dims.index(axis))
         return out

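
Note: the softmax/log_softmax changes above mask padded frames with the (negated) global inf value before applying the softmax, so padded positions end up with exactly zero attention weight. A minimal raw-PyTorch illustration of that masking (illustrative only, with made-up values):

    import torch

    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    mask = torch.tensor([[True, True, True, False]])  # last frame is padding
    masked = torch.where(mask, logits, torch.tensor(-float("inf")))
    probs = torch.softmax(masked, dim=-1)
    assert probs[0, -1].item() == 0.0  # padded frame contributes nothing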
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/returnn/util/basic.py

@@ -3562,6 +3562,21 @@ def should_write_to_disk(config):
     return True


+_default_global_inf_value = float("inf")
+
+
+def get_global_inf_value() -> float:
+    """
+    :return: float("inf") by default, but tries to read `inf_value` from the global config
+    """
+    from returnn.config import get_global_config
+
+    config = get_global_config(raise_exception=False)
+    if not config:
+        return _default_global_inf_value
+    return config.float("inf_value", _default_global_inf_value)
+
+
 class NativeCodeCompiler(object):
     """
     Helper class to compile native C/C++ code on-the-fly.
{returnn-1.20230413.100132 → returnn-1.20230413.141543}/tests/test_rf_base.py

@@ -184,3 +184,30 @@ def test_dropout():
     out.mark_as_default_output(shape=(batch_dim, time_dim, in_dim))

     run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
+
+
+def test_dot_attention():
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    key_dim = Dim(7, name="key")
+    value_dim = Dim(13, name="value")
+    extern_data = TensorDict(
+        {
+            "q": Tensor("q", [batch_dim, time_dim, key_dim], dtype="float32"),
+            "k": Tensor("k", [batch_dim, time_dim, key_dim], dtype="float32"),
+            "v": Tensor("v", [batch_dim, time_dim, value_dim], dtype="float32"),
+        }
+    )
+
+    class _Net(rf.Module):
+        def __call__(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+            kv_axis = Dim(None, name=f"kv-axis")
+            k, _ = rf.replace_dim(k, in_dim=time_dim, out_dim=kv_axis)
+            v, _ = rf.replace_dim(v, in_dim=time_dim, out_dim=kv_axis)
+            return rf.dot_attention(q, k, v, axis=kv_axis, key_dim=key_dim)
+
+    # noinspection PyShadowingNames
+    def _forward_step(*, model: _Net, extern_data: TensorDict):
+        out = model(q=extern_data["q"], k=extern_data["k"], v=extern_data["v"])
+        out.mark_as_default_output(shape=(batch_dim, time_dim, value_dim))
+
+    run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)