returnn 1.20240829.92949.tar.gz → 1.20240829.174139.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of returnn might be problematic.
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/PKG-INFO +1 -1
- returnn-1.20240829.174139/_setup_info_generated.py +2 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/postprocessing.py +14 -13
- returnn-1.20240829.174139/returnn/frontend/conversions/espnet_e_branchformer.py +206 -0
- returnn-1.20240829.174139/returnn/frontend/conversions/torch_nn.py +68 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/decoder/transformer.py +10 -5
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/encoder/conformer.py +59 -47
- returnn-1.20240829.174139/returnn/frontend/encoder/e_branchformer.py +275 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn.egg-info/PKG-INFO +1 -1
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn.egg-info/SOURCES.txt +3 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm-inspection-profile.xml +2 -1
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/inspectionProfiles/Project_Default.xml +2 -1
- returnn-1.20240829.174139/tests/test_rf_encoder_conformer.py +405 -0
- returnn-1.20240829.92949/_setup_info_generated.py +0 -2
- returnn-1.20240829.92949/tests/test_rf_encoder_conformer.py +0 -57
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/.editorconfig +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/.gitignore +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/.gitmodules +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/.kateconfig +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/CHANGELOG.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/CODEOWNERS +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/CONTRIBUTING.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/LICENSE +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/MANIFEST.in +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/README.rst +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/12AX.cluster_map +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/_setup_returnn_env.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-fwd.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-horovod-mpi.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-horovod-mpi.py.sh +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-horovod-mpi.sh +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-hyper-param-tuning.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-iter-dataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-list-devices.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-lua-torch-layer.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-pretrain.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-record-and-push-to-webserver.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-returnn-as-framework.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-rf-pt-benchmark.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-rf.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-rhn-enwik8.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-sprint-interface.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-att-copy.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-attention.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-chunking-blstm.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-contribrnn-lstm.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-enc-dec.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-hard-att-copy.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-lstm-benchmark.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-maxgradnorm-lstm.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-native-lstm-lowmem.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-native-lstm.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-native-lstm2.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-native-lstm2.12ax.tuned.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-neural-transducer.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-rec-explicit-lstm.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-rec-explicit-rnn.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-rec-self-att.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-search-compiled-graph.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-tf-vanilla-lstm.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-timit-lstm-ctc.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-torch.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo-upd-mult-model.lstm.12ax.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/demo.sh +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/IAM_lines/a01-000u-00.png +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/IAM_lines/a01-007-04.png +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/IAM_lines/a01-007-06.png +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/README.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/chars.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/config_demo +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/config_fwd +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/config_real +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/create_IAM_dataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/decode.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/features/raw/demo.h5 +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/go.sh +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/lines.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/split/eval.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/split/train.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/IAM/split/valid.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial/create_test_h5.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial/forwardconfig +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial/go.sh +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial/trainconfig +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial_rgb/create_test_h5.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial_rgb/forwardconfig +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial_rgb/go.sh +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/demos/mdlstm/artificial_rgb/trainconfig +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/pyproject.toml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/requirements.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/__main__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/__old_mod_loader__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/__setup__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/config.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/audio.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/basic.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/bundle_file.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/cached.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/cached2.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/distrib_files.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/generating.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/hdf.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/lm.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/map.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/meta.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/multi_proc.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/normalization_data.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/numpy_dump.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/raw_wav.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/sprint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/stereo.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/util/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/util/feature_extraction.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/util/strings.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/util/vocabulary.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/engine/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/engine/base.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/engine/batch.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/__main__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/.git +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/.gitignore +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/LICENSE +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/aligner.gif +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/check.png +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/core.cu +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/core.h +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/core_cpu.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/LICENSE +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/MANIFEST.in +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/binding.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.cu +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/core.h +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/requirements.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/setup.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/pytorch_binding/warp_rna/test.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/ref_rna.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/setup.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op.cc +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/src/warp_rna_op_kernel_tmpl.h +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/tensorflow_binding/warp_rna/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/WarpRna/warp-rna/test.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/edit.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/reroute.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/select.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/subgraph.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/transform.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/extern/graph_editor/util.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/forward_iface.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_backend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/backend.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/backend.hpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/module.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/module.hpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/py_utils.hpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/tensor_ops.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_native/tensor_ops.hpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_numpy_backend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_random_journal.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/_utils.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/array_.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/attention.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/audio/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/audio/mel.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/audio/specaugment.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/backend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/build_from_dict.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/cond.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/const.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/container.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/control_flow_ctx.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/conv.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/conversions/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/conversions/hf_llama.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/decoder/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/device.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/dims.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/dropout.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/dtype.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/encoder/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/encoder/base.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/gradient.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/graph.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/hooks.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/init.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/label_smoothing.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/linear.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/loop.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/loss.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/math_.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/matmul.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/module.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/normalization.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/parameter.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/parametrizations.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/parametrize.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/piecewise_linear.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/rand.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/rec.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/reduce.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/run_ctx.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/signal.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/state.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/stepwise_scheduler.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/tensor_array.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/types.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/import_/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/import_/common.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/import_/git.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/import_/import_.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/learning_rate_control.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/log.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/native_op.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/native_op.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/pretrain.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/sprint/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/sprint/cache.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/sprint/control.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/sprint/error_signals.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/sprint/extern_interface.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/sprint/interface.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/_dim_extra.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/_tensor_extra.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/_tensor_mixin_base.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/_tensor_op_overloads.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/control_flow_ctx.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/dim.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/marked_dim.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/tensor.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/tensor_dict.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tensor/utils.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/compat.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/data_pipeline.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/distributed.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/engine.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/_backend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/_utils.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/cond.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/config_entry_points.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/debug_eager_mode.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/dims.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/layer.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/loop.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/make_layer.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/masked_computation.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/parameter_assign.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_layers/prev_tensor_ref.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_low_level/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/frontend_low_level/_backend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/horovod.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/hyper_param_tuning.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/layers/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/layers/base.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/layers/basic.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/layers/rec.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/layers/segmental_model.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/layers/signal_processing.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/layers/variable.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/native_op.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/network.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/sprint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/updater.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/util/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/util/basic.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/util/data.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/util/gradient_checkpoint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/util/ken_lm.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/tf/util/open_fst.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/data/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/data/extern_data.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/data/pipeline.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/data/queued_data_iter.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/data/returnn_dataset_wrapper.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/data/tensor_utils.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/distributed.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/engine.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/frontend/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/frontend/_backend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/frontend/_rand.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/frontend/bridge.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/frontend/raw_ops.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/updater.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/util/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/util/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/util/array_.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/util/diagnose_gpu.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/util/gradient_checkpoint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/torch/util/scaled_gradient.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/__init__.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/basic.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/better_exchook.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/bpe.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/debug.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/debug_helpers.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/file_cache.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/fsa.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/literal_py_to_pickle.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/math.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/multi_proc_non_daemonic_spawn.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/native_code_compiler.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/pprint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/py-to-pickle.cpp +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/py_compat.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/py_ext_mod_compiler.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/result_with_reason.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/sig_proc.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/task_system.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/train_proc_manager.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/util/watch_memory.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn.egg-info/dependency_links.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn.egg-info/top_level.txt +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/rnn.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/setup.cfg +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/setup.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/DummySprintExec.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/.gitignore +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/.name +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/codeStyleSettings.xml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/codeStyles/Project.xml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/codeStyles/codeStyleConfig.xml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/inspectionProfiles/profiles_settings.xml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/misc.xml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/modules.xml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/returnn.iml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/PyCharm.idea/scopes/scope_settings.xml +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/_set_num_threads1.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/_setup_returnn_env.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/_setup_test_env.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/bpe-unicode-demo.codes +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/bpe-unicode-demo.vocab +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/lexicon_opt.fst +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/lexicon_opt.isyms +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/lexicon_opt.jpg +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/lexicon_opt.osyms +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/lint_common.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/pycharm-inspect.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/pylint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/returnn-as-framework.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/rf_utils.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/spelling.dic +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_Config.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_Dataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_Fsa.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_GeneratingDataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_HDFDataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_LearningRateControl.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_Log.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_MultiProcDataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_Pretrain.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_ResNet.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_SprintDataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_SprintInterface.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TFEngine.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TFNativeOp.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TFNetworkLayer.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TFNetworkRecLayer.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TFNetworkSigProcLayer.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TFUpdater.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TFUtil.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TF_determinism.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TaskSystem.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TaskSystem_SharedMem.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_TranslationDataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_Util.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_demos.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_fork_exec.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_hdf_dump.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_array.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_attention.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_base.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_cond.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_const.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_container.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_conv.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_decoder_transformer.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_gradient.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_label_smoothing.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_loop.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_math.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_normalization.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_piecewise_linear.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_rec.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_reduce.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_rf_signal.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_tensor.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_tools.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_torch_dataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_torch_engine.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_torch_frontend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_torch_internal_frontend.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/test_torch_util.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tests/torch_utils.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/_setup_returnn_env.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/analyze-dataset-batches.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/bliss-collect-seq-lens.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/bliss-dump-text.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/bliss-get-segment-names.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/bliss-to-ogg-zip.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/bpe-create-lexicon.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/calculate-word-error-rate.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/cleanup-old-models.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/collect-orth-symbols.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/collect-words.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/compile_native_op.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/compile_tf_graph.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/debug-dump-search-scores.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/debug-plot-search-scores.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/dump-dataset-raw-strings.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/dump-dataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/dump-forward-stats.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/dump-forward.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/dump-network-json.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/dump-pickle.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/extract_state_tying_from_dataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/get-attention-weights.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/get-best-model-epoch.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/hdf_dump.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/hdf_dump_translation_dataset.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/import-blocks-mt-model.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/import-t2t-mt-model.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/.gitignore +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/Makefile +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/example/README.md +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/example/libs_list +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/example/network.040/i600_m600_m600.sgd_b16_lr0_cl2.newbobabs.keep_over_epoch.lstm2.config +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/example/rescore_lattice.sh +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/example/state_vars_list +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/example/tensor_names_list +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/file.h +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/htklatticerescorer.cc +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/htklatticerescorer.h +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/main.cc +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/rescorer.h +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/vocabulary.cc +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/lattice_rescorer/vocabulary.h +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/tf_avg_checkpoints.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/tf_inspect_checkpoint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/tf_inspect_summary_log.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/torch_avg_checkpoints.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/torch_export_to_onnx.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/torch_inspect_checkpoint.py +0 -0
- {returnn-1.20240829.92949 → returnn-1.20240829.174139}/tools/torch_inspect_checkpoint_and_opt.py +0 -0
{returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/datasets/postprocessing.py

@@ -5,7 +5,7 @@ Provides :class:`PostprocessingDataset`.
 from __future__ import annotations

 from itertools import islice
-from numpy.random import
+from numpy.random import RandomState
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

 from returnn.datasets.basic import DatasetSeq
@@ -45,9 +45,9 @@ class PostprocessingDataset(CachedDataset2):
                 "files": ["/path/to/data.hdf"],
             },
             # one of them, but not both:
-            # (data: TensorDict, *, rng: numpy.random.
+            # (data: TensorDict, *, rng: numpy.random.RandomState, **kwargs) -> TensorDict
             "map_seq": map_seq,
-            # (iter: Iterator[TensorDict], *, rng: numpy.random.
+            # (iter: Iterator[TensorDict], *, rng: numpy.random.RandomState, **kwargs) -> Iterator[TensorDict]
             "map_seq_stream": map_seqs,
             # only required when data shapes change wrt. the wrapped dataset:
             "map_outputs": {
@@ -67,17 +67,18 @@ class PostprocessingDataset(CachedDataset2):
         """
         :param dataset: inner dataset to be post-processed
         :param map_seq: post processor function operating on the single-segment level.
-            Signature: `(data: TensorDict, *, rng: numpy.random.
+            Signature: `(data: TensorDict, *, rng: numpy.random.RandomState, **kwargs) -> TensorDict`
             To avoid confusion on the order of how the processing functions are applied to the data, only one of
-
-            To ensure forwards compatibility, the function must accept
+            ``map_seq`` and ``map_seq_stream`` can be specified at a time.
+            To ensure forwards compatibility, the function must accept ``**kwargs`` as its last argument.
             This is enforced by passing randomly named parameters at runtime.
         :param map_seq_stream: post processor function operating on the multiple segment level via an iterator.
             Allows merging multiple segments into one, or generating multiple output segments from one input segment.
-            Signature:
+            Signature:
+                ``(iter: Iterator[TensorDict], *, rng: numpy.random.RandomState, **kwargs) -> Iterator[TensorDict]``
             To avoid confusion on the order of how the processing functions are applied to the data, only one of
-
-            To ensure forwards compatibility, the function must accept
+            ``map_seq`` and ``map_seq_stream`` can be specified at a time.
+            To ensure forwards compatibility, the function must accept ``**kwargs`` as its last argument.
             This is enforced by passing randomly named parameters at runtime.
         :param map_outputs: Type and axis specification of the outputs of the mapping functions,
             like extern_data and model_outputs.
@@ -99,7 +100,7 @@ class PostprocessingDataset(CachedDataset2):
         self._map_seq = map_seq
         self._map_seq_stream = map_seq_stream
         self._map_outputs = map_outputs
-        self._rng =
+        self._rng = RandomState(self._get_random_seed_for_epoch(0))

         self._dataset = init_dataset(self._dataset_def, parent_dataset=self)
         if self._map_seq_stream is None:
@@ -144,7 +145,7 @@ class PostprocessingDataset(CachedDataset2):
             self._num_seqs = 0
             return True

-        self._rng =
+        self._rng = RandomState(self._get_random_seed_for_epoch(epoch=epoch))
         assert self._dataset is not None
         self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
         self._data_iter = enumerate(self._build_mapping_iter())
@@ -181,7 +182,7 @@ class PostprocessingDataset(CachedDataset2):
         data_iter = self._iterate_dataset()
         if self._map_seq_stream is not None:
             data_iter = self._map_seq_stream(
-                data_iter, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.
+                data_iter, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.randint(0, 1000)}": None}
             )
             assert isinstance(
                 data_iter, Iterator
@@ -202,7 +203,7 @@ class PostprocessingDataset(CachedDataset2):
                 tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key)
             if self._map_seq is not None:
                 tensor_dict = self._map_seq(
-                    tensor_dict, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.
+                    tensor_dict, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.randint(0, 1000)}": None}
                 )
                 assert isinstance(
                     tensor_dict, TensorDict
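The change above switches the post-processing RNG to numpy.random.RandomState and documents the ``**kwargs`` requirement in the map_seq/map_seq_stream signatures. For orientation, a minimal config-side sketch of a map_seq function matching the new signature; this is not part of the release, and the function name, dataset path and scaling logic are illustrative only:

from returnn.tensor import TensorDict
from numpy.random import RandomState


def my_map_seq(data: TensorDict, *, rng: RandomState, **kwargs) -> TensorDict:
    # Per-segment post-processing: randomly rescale the "data" entry.
    # rng is now a numpy RandomState; **kwargs must be accepted for forward compatibility,
    # which the dataset enforces by passing a randomly named dummy keyword argument.
    scale = 0.9 + 0.2 * rng.random_sample()
    data.data["data"].raw_tensor = data.data["data"].raw_tensor * scale
    return data


train = {
    "class": "PostprocessingDataset",
    "dataset": {"class": "HDFDataset", "files": ["/path/to/data.hdf"]},
    "map_seq": my_map_seq,
}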
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Import ESPnet E-Branchformer model parameters
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from typing import TYPE_CHECKING, Union
|
|
7
|
+
import returnn.frontend as rf
|
|
8
|
+
from returnn.frontend.encoder.e_branchformer import EBranchformerLayer, FeedForwardConvGated
|
|
9
|
+
from returnn.frontend.decoder.transformer import FeedForward
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
import torch
|
|
13
|
+
from espnet2.asr.encoder.e_branchformer_encoder import EBranchformerEncoderLayer, ConvolutionalGatingMLP
|
|
14
|
+
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
|
|
15
|
+
PositionwiseFeedForward,
|
|
16
|
+
)
|
|
17
|
+
from espnet.nets.pytorch_backend.transformer.attention import (
|
|
18
|
+
MultiHeadedAttention,
|
|
19
|
+
RelPositionMultiHeadedAttention,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def import_params_espnet_e_branchformer_layer_to_rf(
|
|
24
|
+
model_espnet: EBranchformerEncoderLayer, model_rf: EBranchformerLayer
|
|
25
|
+
):
|
|
26
|
+
"""
|
|
27
|
+
Import params from ESPnet E-Branchformer layer to
|
|
28
|
+
RF :class:`returnn.frontend.encoder.e_branchformer.EBranchformerLayer`.
|
|
29
|
+
"""
|
|
30
|
+
from .torch_nn import (
|
|
31
|
+
import_params_torch_conv1d_to_rf,
|
|
32
|
+
import_params_torch_layer_norm_to_rf,
|
|
33
|
+
import_params_torch_linear_to_rf,
|
|
34
|
+
)
|
|
35
|
+
from espnet2.asr.encoder.e_branchformer_encoder import EBranchformerEncoderLayer
|
|
36
|
+
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
|
|
37
|
+
PositionwiseFeedForward,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
assert isinstance(model_espnet, EBranchformerEncoderLayer)
|
|
41
|
+
assert isinstance(model_rf, EBranchformerLayer)
|
|
42
|
+
|
|
43
|
+
assert isinstance(model_espnet.feed_forward, PositionwiseFeedForward)
|
|
44
|
+
assert isinstance(model_espnet.feed_forward_macaron, PositionwiseFeedForward)
|
|
45
|
+
|
|
46
|
+
import_params_espnet_positionwise_feed_forward_to_rf(model_espnet.feed_forward_macaron, model_rf.ffn1)
|
|
47
|
+
import_params_espnet_positionwise_feed_forward_to_rf(model_espnet.feed_forward, model_rf.ffn2)
|
|
48
|
+
|
|
49
|
+
import_params_torch_layer_norm_to_rf(model_espnet.norm_ff_macaron, model_rf.ffn1_layer_norm)
|
|
50
|
+
import_params_torch_layer_norm_to_rf(model_espnet.norm_ff, model_rf.ffn2_layer_norm)
|
|
51
|
+
import_params_torch_layer_norm_to_rf(model_espnet.norm_mha, model_rf.self_att_layer_norm)
|
|
52
|
+
import_params_torch_layer_norm_to_rf(model_espnet.norm_mlp, model_rf.cgmlp_layer_norm)
|
|
53
|
+
import_params_torch_layer_norm_to_rf(model_espnet.norm_final, model_rf.final_layer_norm)
|
|
54
|
+
|
|
55
|
+
# noinspection PyTypeChecker
|
|
56
|
+
import_params_espnet_multi_headed_attention_to_rf(model_espnet.attn, model_rf.self_att)
|
|
57
|
+
|
|
58
|
+
# noinspection PyTypeChecker
|
|
59
|
+
import_params_espnet_convolutional_gating_mlp_to_rf(model_espnet.cgmlp, model_rf.cgmlp)
|
|
60
|
+
|
|
61
|
+
import_params_torch_conv1d_to_rf(model_espnet.depthwise_conv_fusion, model_rf.merge.depthwise_conv_fusion)
|
|
62
|
+
import_params_torch_linear_to_rf(model_espnet.merge_proj, model_rf.merge.merge_proj)
|
|
63
|
+
|
|
64
|
+
num_params_espnet = 0
|
|
65
|
+
for k, v in model_espnet.named_parameters():
|
|
66
|
+
num_params_espnet += v.numel()
|
|
67
|
+
num_params_rf = 0
|
|
68
|
+
for k, v in model_rf.named_parameters():
|
|
69
|
+
num_params_rf += v.num_elements()
|
|
70
|
+
assert num_params_rf == num_params_espnet, f"num params RF {num_params_rf} != params ESPnet {num_params_espnet}"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def import_params_espnet_positionwise_feed_forward_to_rf(model_espnet: PositionwiseFeedForward, model_rf: FeedForward):
|
|
74
|
+
"""import"""
|
|
75
|
+
from .torch_nn import import_params_torch_linear_to_rf
|
|
76
|
+
|
|
77
|
+
assert model_rf.linear_ff.with_bias and model_rf.linear_out.with_bias
|
|
78
|
+
import_params_torch_linear_to_rf(model_espnet.w_1, model_rf.linear_ff)
|
|
79
|
+
import_params_torch_linear_to_rf(model_espnet.w_2, model_rf.linear_out)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def import_params_espnet_multi_headed_attention_to_rf(
|
|
83
|
+
model_espnet: Union[MultiHeadedAttention, RelPositionMultiHeadedAttention],
|
|
84
|
+
+    model_rf: Union[rf.SelfAttention, rf.RelPosSelfAttention],
+):
+    """import"""
+    import torch
+    from .torch_nn import import_params_torch_linear_to_rf
+    from espnet.nets.pytorch_backend.transformer.attention import (
+        MultiHeadedAttention,
+        RelPositionMultiHeadedAttention,
+    )
+
+    assert isinstance(model_espnet, (MultiHeadedAttention, RelPositionMultiHeadedAttention))
+    assert isinstance(model_rf, (rf.SelfAttention, rf.RelPosSelfAttention))
+    assert model_espnet.h == model_rf.num_heads.dimension
+    assert model_espnet.d_k == model_rf.key_dim_per_head.dimension
+    dim = model_espnet.d_k * model_espnet.h
+    nh = model_espnet.h
+    hdim = dim // nh
+
+    with torch.no_grad():
+        # Torch Linear: (out,in), but RF has (in,out).
+        q = model_espnet.linear_q.weight.T.reshape(dim, nh, hdim)  # (in,h,out/h)
+        k = model_espnet.linear_k.weight.T.reshape(dim, nh, hdim)  # (in,h,out/h)
+        v = model_espnet.linear_v.weight.T.reshape(dim, nh, hdim)  # (in,h,out/h)
+        q_bias = model_espnet.linear_q.bias.reshape(nh, hdim)  # (h,out/h)
+        k_bias = model_espnet.linear_k.bias.reshape(nh, hdim)  # (h,out/h)
+        v_bias = model_espnet.linear_v.bias.reshape(nh, hdim)  # (h,out/h)
+        qkv = torch.cat([q, k, v], dim=2)  # (in,h,out/h*3)
+        qkv = qkv.reshape(dim, 3 * dim)  # (in,out*3)
+        qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).reshape(dim * 3)  # (out*3,)
+        model_rf.qkv.weight.raw_tensor.copy_(qkv)
+        model_rf.qkv.bias.raw_tensor.copy_(qkv_bias)
+
+        import_params_torch_linear_to_rf(model_espnet.linear_out, model_rf.proj)
+
+        if isinstance(model_espnet, RelPositionMultiHeadedAttention):
+            assert isinstance(model_rf, rf.RelPosSelfAttention)
+            assert model_rf.linear_pos is not None
+            assert model_rf.pos_bias_u is not None and model_rf.pos_bias_v is not None
+
+            import_params_torch_linear_to_rf(model_espnet.linear_pos, model_rf.linear_pos)
+            _reorder_rel_pos_emb_espnet_to_rf_(model_rf.linear_pos.weight.raw_tensor, dim=0)
+            model_rf.pos_bias_u.raw_tensor.copy_(model_espnet.pos_bias_u)
+            model_rf.pos_bias_v.raw_tensor.copy_(model_espnet.pos_bias_v)
+        else:
+            assert not isinstance(model_rf, rf.RelPosSelfAttention)
+
+    num_params_espnet = 0
+    for k, v in model_espnet.named_parameters():
+        num_params_espnet += v.numel()
+    num_params_rf = 0
+    for k, v in model_rf.named_parameters():
+        num_params_rf += v.num_elements()
+    assert num_params_rf == num_params_espnet, f"num params RF {num_params_rf} != params ESPnet {num_params_espnet}"
+
+
+def _reorder_rel_pos_emb_espnet_to_rf(x: torch.Tensor, *, dim=-1) -> torch.Tensor:
+    if dim < 0:
+        dim += x.ndim
+    assert 0 <= dim < x.ndim
+    if dim != x.ndim - 1:
+        x = x.transpose(dim, -1)
+    # x: [..., D]
+    # x feat dims is sin/cos repeated after each other
+    *o, d = x.shape
+    x = x.reshape(*o, d // 2, 2)  # [..., D/2, 2]
+    # PT goes over indices T-1,T-2,...,0,1,2,...,T-1.
+    # RF goes the other way around.
+    # We don't flip here, to show that a linear transformation of the features is also fine.
+    # Flipping cos has no effect.
+    # Flipping sin would be equivalent to negating the positional encoding.
+    x[..., 0] = -x[..., 0]
+    # RF has first the sin, then the cos.
+    x = x.transpose(-1, -2).reshape(*o, d)  # [..., D]
+    if dim != x.ndim - 1:  # transpose back
+        x = x.transpose(dim, -1)
+    return x
+
+
+def _reorder_rel_pos_emb_espnet_to_rf_(x: torch.Tensor, *, dim=-1):
+    import torch
+
+    with torch.no_grad():
+        x.copy_(_reorder_rel_pos_emb_espnet_to_rf(x, dim=dim))
+
+
+def import_params_espnet_convolutional_gating_mlp_to_rf(
+    model_espnet: ConvolutionalGatingMLP, model_rf: FeedForwardConvGated
+):
+    """import"""
+    from .torch_nn import (
+        import_params_torch_linear_to_rf,
+        import_params_torch_layer_norm_to_rf,
+        import_params_torch_conv1d_to_rf,
+    )
+    from espnet2.asr.encoder.e_branchformer_encoder import ConvolutionalGatingMLP
+
+    assert isinstance(model_espnet, ConvolutionalGatingMLP)
+    assert isinstance(model_rf, FeedForwardConvGated)
+
+    import_params_torch_linear_to_rf(model_espnet.channel_proj1[0], model_rf.linear_ff)
+    _reorder_espnet_cgmlp_linear_ff_to_rf_(model_rf.linear_ff.weight.raw_tensor)
+    if model_rf.linear_ff.with_bias:
+        _reorder_espnet_cgmlp_linear_ff_to_rf_(model_rf.linear_ff.bias.raw_tensor)
+    import_params_torch_linear_to_rf(model_espnet.channel_proj2, model_rf.linear_out)
+    import_params_torch_layer_norm_to_rf(model_espnet.csgu.norm, model_rf.norm)
+    import_params_torch_conv1d_to_rf(model_espnet.csgu.conv, model_rf.conv)
+    assert model_espnet.csgu.linear is None
+
+    num_params_espnet = 0
+    for k, v in model_espnet.named_parameters():
+        num_params_espnet += v.numel()
+    num_params_rf = 0
+    for k, v in model_rf.named_parameters():
+        num_params_rf += v.num_elements()
+    assert num_params_rf == num_params_espnet, f"num params RF {num_params_rf} != params ESPnet {num_params_espnet}"
+
+
+def _reorder_espnet_cgmlp_linear_ff_to_rf_(w: torch.Tensor):
+    import torch
+
+    dims = list(w.shape)
+    with torch.no_grad():
+        w.copy_(w.reshape(*dims[:-1], 2, dims[-1] // 2).flip(-2).reshape(*dims))
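
As a minimal illustration (not part of the package contents), the reordering done by
_reorder_rel_pos_emb_espnet_to_rf above maps an ESPnet-style row with sin/cos interleaved
along the feature axis to the RF layout with all sin values first, then all cos values,
while negating the sin half; the numeric values here are made up:

    import torch

    row = torch.tensor([0.1, 0.2, 0.3, 0.4])  # interleaved: s0, c0, s1, c1
    d = row.shape[-1]
    out = row.reshape(d // 2, 2).clone()     # [[s0, c0], [s1, c1]]
    out[..., 0] = -out[..., 0]               # negate the sin half (reverses the position order)
    out = out.transpose(-1, -2).reshape(d)   # -> [-s0, -s1, c0, c1]: sin first, then cos, as RF expects
    print(out)  # tensor([-0.1000, -0.3000,  0.2000,  0.4000])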
returnn-1.20240829.174139/returnn/frontend/conversions/torch_nn.py
ADDED
@@ -0,0 +1,68 @@
+"""
+Import some of the torch.nn modules.
+"""
+
+from __future__ import annotations
+from typing import TYPE_CHECKING
+import returnn.frontend as rf
+
+if TYPE_CHECKING:
+    import torch
+
+
+def import_params_torch_linear_to_rf(model_pt: torch.nn.Linear, model_rf: rf.Linear):
+    """
+    import params from torch.nn.Linear to rf.Linear
+    """
+    import torch
+
+    assert isinstance(model_pt, torch.nn.Linear)
+    assert isinstance(model_rf, rf.Linear)
+    assert model_rf.with_bias == (model_pt.bias is not None)
+
+    with torch.no_grad():
+        model_rf.weight.raw_tensor.copy_(model_pt.weight.T)  # (in,out)
+        if model_rf.with_bias:
+            model_rf.bias.raw_tensor.copy_(model_pt.bias)  # (out,)
+
+
+def import_params_torch_conv1d_to_rf(model_pt: torch.nn.Conv1d, model_rf: rf.Conv1d):
+    """
+    import params from torch.nn.Conv1d to rf.Conv1d
+    """
+    import torch
+
+    assert isinstance(model_pt, torch.nn.Conv1d)
+    assert isinstance(model_rf, rf.Conv1d)
+    assert model_rf.with_bias == (model_pt.bias is not None)
+
+    with torch.no_grad():
+        # Torch shape: out_channels, in_channels // groups, *kernel_size
+        # RF shape: self.out_dim, self.filter_in_dim, *self.filter_size, i.e. should be same
+        model_rf.filter.raw_tensor.copy_(model_pt.weight)
+        if model_rf.with_bias:
+            model_rf.bias.raw_tensor.copy_(model_pt.bias)
+
+
+def import_params_torch_layer_norm_to_rf(model_pt: torch.nn.LayerNorm, model_rf: rf.LayerNorm):
+    """
+    Import the parameters from torch.nn.LayerNorm to rf.LayerNorm.
+    """
+    import torch
+
+    assert isinstance(model_pt, torch.nn.LayerNorm)
+    assert isinstance(model_rf, rf.LayerNorm)
+    assert model_pt.weight.shape[0] == model_rf.in_dim.dimension
+
+    with torch.no_grad():
+        model_rf.scale.raw_tensor.copy_(model_pt.weight)  # (in,)
+        model_rf.bias.raw_tensor.copy_(model_pt.bias)  # (in,)
+
+    num_params_pt = 0
+    for k, v in model_pt.named_parameters():
+        num_params_pt += v.numel()
+    num_params_rf = 0
+    for k, v in model_rf.named_parameters():
+        assert isinstance(v.raw_tensor, torch.nn.Parameter)
+        num_params_rf += v.num_elements()
+    assert num_params_rf == num_params_pt
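
A hypothetical usage sketch (not part of the package) for the new torch_nn helpers,
assuming the PyTorch backend is selected for the RETURNN frontend; after the import the
RF weight holds the same values as the torch weight, just transposed to the RF (in,out)
layout:

    import torch
    import returnn.frontend as rf
    from returnn.tensor import Dim
    from returnn.frontend.conversions.torch_nn import import_params_torch_linear_to_rf

    rf.select_backend_torch()  # RF parameters are then backed by torch.nn.Parameter

    in_dim, out_dim = Dim(8, name="in"), Dim(16, name="out")
    lin_pt = torch.nn.Linear(8, 16)
    lin_rf = rf.Linear(in_dim, out_dim)

    import_params_torch_linear_to_rf(lin_pt, lin_rf)
    assert torch.equal(lin_rf.weight.raw_tensor, lin_pt.weight.T)  # RF (in,out) vs torch (out,in)
    assert torch.equal(lin_rf.bias.raw_tensor, lin_pt.bias)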
{returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/decoder/transformer.py
RENAMED
@@ -161,7 +161,7 @@ class TransformerDecoder(rf.Module):
 
         self.layers = sequential(_copy.deepcopy(decoder_layer) for _ in range(num_layers))
 
-        self.final_layer_norm =
+        self.final_layer_norm = make_norm(norm, model_dim)
 
         self.logits = rf.Linear(model_dim, vocab_dim, with_bias=logits_with_bias)
 
@@ -287,7 +287,7 @@ class TransformerDecoderLayer(rf.Module):
            assert isinstance(ff, rf.Module)
 
        self.ff = ff
-       self.ff_layer_norm =
+       self.ff_layer_norm = make_norm(norm, out_dim)
 
        if self_att is None or isinstance(self_att, type) or isinstance(self_att, dict):
            self_att_opts_ = dict(
@@ -312,7 +312,7 @@ class TransformerDecoderLayer(rf.Module):
            self.self_att = _copy.deepcopy(self_att)
        else:
            raise TypeError(f"unexpected self_att type {self_att!r}")
-       self.self_att_layer_norm =
+       self.self_att_layer_norm = make_norm(norm, out_dim)
 
        self.cross_att = None
        self.cross_att_layer_norm = None
@@ -326,7 +326,7 @@ class TransformerDecoderLayer(rf.Module):
                num_heads=num_heads,
                att_dropout=att_dropout,
            )
-           self.cross_att_layer_norm =
+           self.cross_att_layer_norm = make_norm(norm, out_dim)
 
    def default_initial_state(self, *, batch_dims: Sequence[Dim]) -> rf.State:
        """default initial state"""
@@ -492,7 +492,12 @@ class FeedForwardGated(rf.Module):
        return x_ff2
 
 
-def
+def make_norm(norm: Union[type, Dict[str, Any], rf.Module, Callable], out_dim: Dim) -> Union[rf.Module, Callable]:
+    """
+    :param norm: norm type or dict or module or callable. e.g. ``rf.LayerNorm``
+    :param out_dim: model/out dim
+    :return: norm module or callable. e.g. ``rf.LayerNorm(out_dim)``
+    """
    if isinstance(norm, type):
        norm = norm(out_dim)
    elif isinstance(norm, dict):
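
As a hypothetical sketch (not part of the package): the point of factoring out make_norm is
that every pre-norm in the decoder, and in the Conformer encoder below, can be configured
through one argument; a class is instantiated with the model dimension, and a build-dict is
presumably resolved through rf.build_from_dict, analogous to the ff handling in conformer.py:

    import returnn.frontend as rf
    from returnn.tensor import Dim
    from returnn.frontend.decoder.transformer import make_norm

    rf.select_backend_torch()
    model_dim = Dim(512, name="model")

    norm_a = make_norm(rf.LayerNorm, model_dim)  # class -> rf.LayerNorm(model_dim)
    norm_b = make_norm(rf.build_dict(rf.RMSNorm), model_dim)  # dict form (assumed to go via rf.build_from_dict)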
{returnn-1.20240829.92949 → returnn-1.20240829.174139}/returnn/frontend/encoder/conformer.py
RENAMED
@@ -13,9 +13,10 @@ from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
 from returnn.util.basic import NotSpecified
 from .base import ISeqDownsamplingEncoder
+from ..decoder.transformer import FeedForward, make_norm
 
 
-class ConformerPositionwiseFeedForward(rf.Module):
+class ConformerPositionwiseFeedForward(FeedForward):
     """
     Conformer position-wise feedforward neural network layer
     FF -> Activation -> Dropout -> FF
@@ -25,37 +26,20 @@ class ConformerPositionwiseFeedForward(rf.Module):
         self,
         out_dim: Dim,
         *,
-        ff_dim: Dim,
-        dropout: float,
-        activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module],
+        ff_dim: Union[Dim, int] = NotSpecified,
+        dropout: float = 0.1,
+        activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] = rf.swish,
+        **kwargs,
     ):
         """
         :param out_dim: output feature dimension
         :param ff_dim: dimension of the feed-forward layers
         :param dropout: dropout value
-        :param activation: activation function
+        :param activation: activation function. swish by default, unlike the base :class:`FeedForward`
         """
-
-
-
-        self.dropout = dropout
-        self.dropout_broadcast = rf.dropout_broadcast_default()
-        if isinstance(activation, dict):
-            activation = rf.build_from_dict(activation)
-        elif not callable(activation):
-            raise TypeError(f"{self}: unexpected activation type {activation!r}")
-        self.activation = activation
-
-        self.linear_ff = rf.Linear(out_dim, ff_dim)
-        self.linear_out = rf.Linear(ff_dim, out_dim)
-
-    def __call__(self, inp: Tensor) -> Tensor:
-        """forward"""
-        x_ff1 = self.linear_ff(inp)
-        x_act = self.activation(x_ff1)
-        x_drop = rf.dropout(x_act, self.dropout, axis=self.dropout_broadcast and self.linear_ff.out_dim)
-        x_ff2 = self.linear_out(x_drop)
-        return x_ff2
+        if activation is NotSpecified:
+            activation = rf.swish
+        super().__init__(out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=activation, **kwargs)
 
 
 class ConformerConvBlock(rf.Module):
@@ -188,8 +172,9 @@ class ConformerEncoderLayer(rf.Module):
         self,
         out_dim: Dim = Dim(512, name="conformer-enc-default-out-dim"),
         *,
+        ff: Union[type, Dict[str, Any], rf.Module] = NotSpecified,
         ff_dim: Dim = NotSpecified,
-        ff_activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] =
+        ff_activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] = NotSpecified,
         dropout: float = 0.1,
         conv_kernel_size: int = 32,
         conv_norm: Union[rf.BatchNorm, type, Dict[str, Any], Any] = NotSpecified,
@@ -198,6 +183,7 @@ class ConformerEncoderLayer(rf.Module):
         self_att: Optional[Union[rf.RelPosSelfAttention, rf.Module, type, Dict[str, Any], Any]] = None,
         self_att_opts: Optional[Dict[str, Any]] = None,
         att_dropout: float = 0.1,
+        norm: Union[type, Dict[str, Any], rf.Module, Callable] = rf.LayerNorm,
     ):
         """
         :param out_dim: the output feature dimension
@@ -215,6 +201,7 @@ class ConformerEncoderLayer(rf.Module):
         :param self_att: the self-attention layer. RelPosSelfAttention originally and default
         :param self_att_opts: options for the self-attention layer, for :class:`nn.RelPosSelfAttention`
         :param att_dropout: attention dropout value
+        :param norm: pre-normalization for FF, conv and attention blocks
         """
         super().__init__()
 
@@ -222,17 +209,11 @@ class ConformerEncoderLayer(rf.Module):
         self.dropout_broadcast = rf.dropout_broadcast_default()
         self.out_dim = out_dim
 
-
-
-        self.ffn1 = ConformerPositionwiseFeedForward(
-            out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=ff_activation
-        )
-        self.ffn1_layer_norm = rf.LayerNorm(out_dim)
+        self.ffn1 = make_ff(ff=ff, out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, ff_activation=ff_activation)
+        self.ffn1_layer_norm = make_norm(norm, out_dim)
 
-        self.ffn2 = ConformerPositionwiseFeedForward(
-            out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=ff_activation
-        )
-        self.ffn2_layer_norm = rf.LayerNorm(out_dim)
+        self.ffn2 = make_ff(ff=ff, out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, ff_activation=ff_activation)
+        self.ffn2_layer_norm = make_norm(norm, out_dim)
 
         if conv_norm is NotSpecified or conv_norm is rf.BatchNorm:
             conv_norm_opts = conv_norm_opts.copy() if conv_norm_opts else {}
@@ -245,7 +226,7 @@ class ConformerEncoderLayer(rf.Module):
         if not callable(conv_norm):
             raise TypeError(f"{self}: unexpected conv_norm type {conv_norm!r}")
         self.conv_block = ConformerConvBlock(out_dim=out_dim, kernel_size=conv_kernel_size, norm=conv_norm)
-        self.conv_layer_norm = rf.LayerNorm(out_dim)
+        self.conv_layer_norm = make_norm(norm, out_dim)
 
         if self_att is None or isinstance(self_att, (dict, type)):
             self_att_opts_ = dict(
@@ -271,9 +252,9 @@ class ConformerEncoderLayer(rf.Module):
         if not callable(self_att):
             raise TypeError(f"{self}: invalid non-callable: self_att {self_att!r}")
         self.self_att = self_att
-        self.self_att_layer_norm = rf.LayerNorm(out_dim)
+        self.self_att_layer_norm = make_norm(norm, out_dim)
 
-        self.final_layer_norm = rf.LayerNorm(out_dim)
+        self.final_layer_norm = make_norm(norm, out_dim)
 
     def __call__(self, inp: Tensor, *, spatial_dim: Dim) -> Tensor:
         """forward"""
@@ -313,12 +294,12 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
         out_dim: Dim = Dim(512, name="conformer-enc-default-out-dim"),
         *,
         num_layers: int,
-        input_layer: Union[ConformerConvSubsample, ISeqDownsamplingEncoder, rf.Module, Any],
+        input_layer: Optional[Union[ConformerConvSubsample, ISeqDownsamplingEncoder, rf.Module, Any]],
         input_dropout: float = 0.1,
         ff_dim: Dim = NotSpecified,
-        ff_activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] =
+        ff_activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module] = NotSpecified,
         dropout: float = 0.1,
-        conv_kernel_size: int =
+        conv_kernel_size: int = NotSpecified,
         conv_norm: Union[rf.BatchNorm, type, Dict[str, Any], Any] = NotSpecified,
         num_heads: int = 4,
         att_dropout: float = 0.1,
@@ -352,8 +333,10 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
 
         # TODO once we figured out good defaults, we would create ConformerConvSubsample here when not given
         self.input_layer = input_layer
-        self.input_projection = rf.Linear(
-            self.input_layer.out_dim if self.input_layer else self.in_dim, self.out_dim, with_bias=False
+        self.input_projection = (
+            rf.Linear(self.input_layer.out_dim if self.input_layer else self.in_dim, self.out_dim, with_bias=False)
+            if input_layer
+            else None
         )
         self.input_dropout = input_dropout
 
@@ -368,6 +351,7 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
             num_heads=num_heads,
             att_dropout=att_dropout,
         )
+        encoder_layer_opts_ = {k: v for (k, v) in encoder_layer_opts_.items() if v is not NotSpecified}
        if encoder_layer_opts:
            encoder_layer_opts_.update(encoder_layer_opts)
        if not encoder_layer:
@@ -404,7 +388,35 @@ class ConformerEncoder(ISeqDownsamplingEncoder):
            x_subsample, out_spatial_dim = self.input_layer(source, in_spatial_dim=in_spatial_dim)
        else:
            x_subsample, out_spatial_dim = source, in_spatial_dim
-
-        x = rf.dropout(
+        x = self.input_projection(x_subsample) if self.input_projection else x_subsample
+        x = rf.dropout(x, self.input_dropout, axis=self.dropout_broadcast and self.out_dim)
        x = self.layers(x, spatial_dim=out_spatial_dim, collected_outputs=collected_outputs)
        return x, out_spatial_dim
+
+
+def make_ff(
+    *,
+    out_dim: Dim,
+    ff: Union[type, Dict[str, Any], rf.Module],
+    ff_dim: Union[Dim, int],
+    ff_activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module],
+    dropout: float,
+) -> Union[ConformerPositionwiseFeedForward, rf.Module]:
+    """
+    make the feed-forward part of the Conformer layer
+    """
+    if ff is NotSpecified:
+        ff = ConformerPositionwiseFeedForward
+    if isinstance(ff, rf.Module):
+        ff = _copy.deepcopy(ff)
+    else:
+        ff_kwargs = dict(out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=ff_activation)
+        ff_kwargs = {k: v for (k, v) in ff_kwargs.items() if v is not NotSpecified}
+        if isinstance(ff, type):
+            ff = ff(**ff_kwargs)
+        elif isinstance(ff, dict):
+            ff = rf.build_from_dict(ff, **ff_kwargs)
+        else:
+            raise TypeError(f"unexpected ff type {ff!r}")
+    assert isinstance(ff, rf.Module)
+    return ff
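
A hypothetical configuration sketch (not part of the package) of how the new ff and norm
hooks of ConformerEncoderLayer could be used, for example swapping in the gated feed-forward
from the decoder module and a different pre-norm; the specific dimensions and choices here
are made up:

    import returnn.frontend as rf
    from returnn.tensor import Dim
    from returnn.frontend.encoder.conformer import ConformerEncoderLayer
    from returnn.frontend.decoder.transformer import FeedForwardGated

    rf.select_backend_torch()
    enc_dim = Dim(256, name="enc")

    layer = ConformerEncoderLayer(
        out_dim=enc_dim,
        ff=rf.build_dict(FeedForwardGated),  # resolved by make_ff, which fills in out_dim/ff_dim/dropout
        ff_dim=Dim(1024, name="ff"),
        norm=rf.RMSNorm,  # each *_layer_norm above is then created via make_norm(norm, out_dim)
    )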